Skip to content

Commit 5b19907

Browse files
committed
Add test for join_node, polish the implementation
1 parent d4f3515 commit 5b19907

File tree

7 files changed

+185
-99
lines changed

7 files changed

+185
-99
lines changed

include/oneapi/tbb/detail/_flow_graph_indexer_impl.h

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
/*
2-
Copyright (c) 2005-2024 Intel Corporation
2+
Copyright (c) 2005-2025 Intel Corporation
3+
Copyright (c) 2025 UXL Foundation Contributors
34
45
Licensed under the Apache License, Version 2.0 (the "License");
56
you may not use this file except in compliance with the License.
@@ -227,18 +228,18 @@
227228
struct indexer_types {
228229
using output_type = tagged_msg<std::size_t, T0, TN...>;
229230
using input_ports_type = std::tuple<indexer_input_port<T0>, indexer_input_port<TN>...>;
230-
using indexer_FE_type = indexer_node_FE<input_ports_type, output_type, std::tuple<T0, TN...>>;
231231
using indexer_base_type = indexer_node_base<input_ports_type, output_type, std::tuple<T0, TN...>>;
232232
};
233233

234234
template<typename T0, typename... TN>
235235
class unfolded_indexer_node : public indexer_types<T0, TN...>::indexer_base_type {
236236
public:
237-
typedef std::tuple<T0, TN...> tuple_types;
238-
typedef typename indexer_types<T0, TN...>::input_ports_type input_ports_type;
239-
typedef typename indexer_types<T0, TN...>::output_type output_type;
237+
using input_ports_type = typename indexer_types<T0, TN...>::input_ports_type;
238+
using output_type = typename indexer_types<T0, TN...>::output_type;
239+
using tuple_types = std::tuple<T0, TN...>;
240+
240241
private:
241-
typedef typename indexer_types<T0, TN...>::indexer_base_type base_type;
242+
using base_type = typename indexer_types<T0, TN...>::indexer_base_type;
242243
public:
243244
unfolded_indexer_node(graph& g) : base_type(g) {}
244245
unfolded_indexer_node(const unfolded_indexer_node &other) : base_type(other) {}

include/oneapi/tbb/detail/_flow_graph_join_impl.h

Lines changed: 16 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
/*
2-
Copyright (c) 2005-2024 Intel Corporation
2+
Copyright (c) 2005-2025 Intel Corporation
3+
Copyright (c) 2025 UXL Foundation Contributors
34
45
Licensed under the Apache License, Version 2.0 (the "License");
56
you may not use this file except in compliance with the License.
@@ -1458,34 +1459,19 @@
14581459
// join base class type generator
14591460
template<template<class> class PT, typename OutputTuple, typename JP>
14601461
struct join_base {
1461-
typedef join_node_base<JP, typename wrap_tuple_elements<PT,OutputTuple>::type, OutputTuple> type;
1462+
using type = join_node_base<JP, typename wrap_tuple_elements<PT, OutputTuple>::type, OutputTuple>;
14621463
};
14631464

14641465
template<typename OutputTuple, typename K, typename KHash>
14651466
struct join_base<key_matching_port, OutputTuple, key_matching<K,KHash> > {
1466-
typedef key_matching<K, KHash> key_traits_type;
1467-
typedef K key_type;
1468-
typedef KHash key_hash_compare;
1469-
typedef join_node_base< key_traits_type,
1470-
// ports type
1471-
typename wrap_key_tuple_elements<key_matching_port,key_traits_type,OutputTuple>::type,
1472-
OutputTuple > type;
1473-
};
1467+
using key_type = K;
1468+
using key_hash_compare = KHash;
1469+
using key_traits_type = key_matching<key_type, key_hash_compare>;
14741470

1475-
//! unfolded_join_node : passes input_ports_type to join_node_base. We build the input port type
1476-
// using tuple_element. The class PT is the port type (reserving_port, queueing_port, key_matching_port)
1477-
// and should match the typename.
1478-
1479-
template<int M, template<class> class PT, typename OutputTuple, typename JP>
1480-
class unfolded_join_node : public join_base<M,PT,OutputTuple,JP>::type {
1481-
public:
1482-
typedef typename wrap_tuple_elements<M, PT, OutputTuple>::type input_ports_type;
1483-
typedef OutputTuple output_type;
1484-
private:
1485-
typedef join_node_base<JP, input_ports_type, output_type > base_type;
1486-
public:
1487-
unfolded_join_node(graph &g) : base_type(g) {}
1488-
unfolded_join_node(const unfolded_join_node &other) : base_type(other) {}
1471+
using type = join_node_base<key_traits_type,
1472+
// ports type
1473+
typename wrap_key_tuple_elements<key_matching_port, key_traits_type, OutputTuple>::type,
1474+
OutputTuple>;
14891475
};
14901476

14911477
#if __TBB_PREVIEW_MESSAGE_BASED_KEY_MATCHING
@@ -1504,25 +1490,23 @@
15041490
};
15051491
#endif /* __TBB_PREVIEW_MESSAGE_BASED_KEY_MATCHING */
15061492

1507-
//! unfolded_join_node : passes input_ports_type to join_node_base. We build the input port type
1508-
// using tuple_element. The class PT is the port type (reserving_port, queueing_port, key_matching_port)
1493+
//! unfolded_join_node : passes input_ports_type to join_node_base. We build the input port type
1494+
// using tuple_element. The class PortType is the port type (reserving_port, queueing_port, key_matching_port)
15091495
// and should match the typename.
1510-
1511-
template<template<class> class PT, typename OutputTuple, typename JP>
1512-
class unfolded_join_node : public join_base<PT, OutputTuple, JP>::type {
1496+
template<template<class> class PortType, typename OutputTuple, typename JoinPolicy>
1497+
class unfolded_join_node : public join_base<PortType, OutputTuple, JoinPolicy>::type {
15131498
public:
1514-
using input_ports_type = typename wrap_tuple_elements<PT, OutputTuple>::type;
1499+
using input_ports_type = typename wrap_tuple_elements<PortType, OutputTuple>::type;
15151500
using output_type = OutputTuple;
15161501
private:
1517-
using base_type = join_node_base<JP, input_ports_type, output_type>;
1502+
using base_type = join_node_base<JoinPolicy, input_ports_type, output_type>;
15181503
public:
15191504
unfolded_join_node(graph& g) : base_type(g) {}
15201505
unfolded_join_node(const unfolded_join_node &other) : base_type(other) {}
15211506
};
15221507

15231508
// key_matching unfolded_join_node. This must be a separate specialization because the constructors
15241509
// differ.
1525-
15261510
template<typename K, typename KHash, typename...Types>
15271511
class unfolded_join_node<key_matching_port, std::tuple<Types...>, key_matching<K, KHash>>
15281512
: public join_base<key_matching_port, std::tuple<Types...>, key_matching<K, KHash>>::type

include/oneapi/tbb/detail/_flow_graph_types_impl.h

Lines changed: 15 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
/*
2-
Copyright (c) 2005-2024 Intel Corporation
2+
Copyright (c) 2005-2025 Intel Corporation
3+
Copyright (c) 2025 UXL Foundation Contributors
34
45
Licensed under the Apache License, Version 2.0 (the "License");
56
you may not use this file except in compliance with the License.
@@ -43,24 +44,24 @@ struct KeyTrait {
4344
};
4445

4546
// wrap each element of a tuple in a template, and make a tuple of the result.
46-
template<template<class> class PT, typename TypeTuple>
47+
template<template<class> class PortType, typename TypeTuple>
4748
struct wrap_tuple_elements;
4849

4950
// A wrapper that generates the traits needed for each port of a key-matching join,
5051
// and the type of the tuple of input ports.
51-
template<template<class> class PT, typename KeyTraits, typename TypeTuple>
52+
template<template<class> class PortType, typename KeyTraits, typename TypeTuple>
5253
struct wrap_key_tuple_elements;
5354

54-
template<template<class> class PT, typename... Args>
55-
struct wrap_tuple_elements<PT, std::tuple<Args...> >{
56-
typedef typename std::tuple<PT<Args>... > type;
55+
template<template<class> class PortType, typename... Args>
56+
struct wrap_tuple_elements<PortType, std::tuple<Args...> >{
57+
using type = std::tuple<PortType<Args>...>;
5758
};
5859

59-
template<template<class> class PT, typename KeyTraits, typename... Args>
60-
struct wrap_key_tuple_elements<PT, KeyTraits, std::tuple<Args...> > {
61-
typedef typename KeyTraits::key_type K;
62-
typedef typename KeyTraits::hash_compare_type KHash;
63-
typedef typename std::tuple<PT<KeyTrait<K, KHash, Args> >... > type;
60+
template<template<class> class PortType, typename KeyTraits, typename... Args>
61+
struct wrap_key_tuple_elements<PortType, KeyTraits, std::tuple<Args...> > {
62+
using key_type = typename KeyTraits::key_type;
63+
using hash_compare_type = typename KeyTraits::hash_compare_type;
64+
using type = std::tuple<PortType<KeyTrait<key_type, hash_compare_type, Args>>...>;
6465
};
6566

6667
template< int... S > class sequence {};
@@ -309,9 +310,9 @@ struct do_if<T, false> {
309310

310311
using tbb::detail::punned_cast;
311312

312-
template<typename TagType, typename T0, typename... TN>
313+
template<typename TagType, typename... TN>
313314
class tagged_msg {
314-
using Tuple = std::tuple<T0, TN...>;
315+
using Tuple = std::tuple<TN...>;
315316

316317
class variant {
317318
static const size_t N = std::tuple_size<Tuple>::value;
@@ -387,7 +388,7 @@ class tagged_msg {
387388
bool is_a() const {return my_msg.template variant_is_a<V>();}
388389

389390
bool is_default_constructed() const {return my_msg.variant_is_default_constructed();}
390-
};
391+
}; // class tagged_msg
391392

392393
// template to simplify cast and test for tagged_msg in template contexts
393394
template<typename V, typename T>

include/oneapi/tbb/flow_graph.h

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
/*
22
Copyright (c) 2005-2025 Intel Corporation
3+
Copyright (c) 2025 UXL Foundation Contributors
34
45
Licensed under the Apache License, Version 2.0 (the "License");
56
you may not use this file except in compliance with the License.
@@ -2495,9 +2496,9 @@ class join_node<OutputTuple, key_matching<K, KHash> > : public unfolded_join_nod
24952496
join_node(graph &g) : unfolded_type(g) {}
24962497
#endif /* __TBB_PREVIEW_MESSAGE_BASED_KEY_MATCHING */
24972498

2498-
template<typename... Bodies, typename = typename std::enable_if<sizeof...(Bodies) == N>>
2499-
__TBB_requires(join_node_functions<OutputTuple, K, Bodies...>)
2500-
__TBB_NOINLINE_SYM join_node(graph& g, Bodies... bodies) : unfolded_type(g, bodies...) {
2499+
template <typename Body, typename... Bodies, typename = typename std::enable_if<1 + sizeof...(Bodies) == N>>
2500+
__TBB_requires(join_node_functions<OutputTuple, K, Body, Bodies...>)
2501+
__TBB_NOINLINE_SYM join_node(graph& g, Body body, Bodies... bodies) : unfolded_type(g, body, bodies...) {
25012502
fgt_multiinput_node<N>( CODEPTR(), FLOW_JOIN_NODE_TAG_MATCHING, &this->my_graph,
25022503
this->input_ports(), static_cast< sender< output_type > *>(this) );
25032504
}

test/conformance/conformance_flowgraph.h

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
/*
22
Copyright (c) 2020-2025 Intel Corporation
3+
Copyright (c) 2025 UXL Foundation Contributors
34
45
Licensed under the Apache License, Version 2.0 (the "License");
56
you may not use this file except in compliance with the License.
@@ -811,4 +812,25 @@ void test_with_reserving_join_node_class() {
811812
if at least one successor accepts the tuple must consume messages");
812813
}
813814
}
815+
816+
template <std::size_t N>
817+
struct edge_maker {
818+
template <typename Sender, typename NodeType>
819+
static void make(Sender& sender, NodeType& node) {
820+
oneapi::tbb::flow::make_edge(sender, oneapi::tbb::flow::input_port<N - 1>(node));
821+
edge_maker<N - 1>::make(sender, node);
822+
}
823+
824+
template <typename Sender, typename NodeType>
825+
static void make(std::vector<Sender>& senders, NodeType& node) {
826+
oneapi::tbb::flow::make_edge(senders[N - 1], oneapi::tbb::flow::input_port<N - 1>(node));
827+
edge_maker<N - 1>::make(senders, node);
828+
}
829+
};
830+
831+
template <>
832+
struct edge_maker<0> {
833+
template <typename Sender, typename NodeType>
834+
static void make(Sender&, NodeType&) {}
835+
};
814836
#endif // __TBB_test_conformance_conformance_flowgraph_H

test/conformance/conformance_indexer_node.cpp

Lines changed: 18 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
/*
2-
Copyright (c) 2020-2021 Intel Corporation
2+
Copyright (c) 2020-2025 Intel Corporation
3+
Copyright (c) 2025 UXL Foundation Contributors
34
45
Licensed under the Apache License, Version 2.0 (the "License");
56
you may not use this file except in compliance with the License.
@@ -19,6 +20,7 @@
1920
#endif
2021

2122
#include "conformance_flowgraph.h"
23+
#include <unordered_set>
2224

2325
//! \file conformance_indexer_node.cpp
2426
//! \brief Test for [flow_graph.indexer_node] specification
@@ -162,8 +164,6 @@ TEST_CASE("indexer_node output_type") {
162164
CHECK_MESSAGE((conformance::check_output_type<my_output_type, oneapi::tbb::flow::tagged_msg<size_t, int, float, input_msg>>()), "indexer_node output_type should returns a tagged_msg");
163165
}
164166

165-
#include <unordered_set>
166-
167167
template <std::size_t N, typename... Args>
168168
struct indexer_node_type_generator_impl {
169169
using type = typename indexer_node_type_generator_impl<N - 1, Args..., int>::type;
@@ -178,66 +178,40 @@ struct indexer_node_type_generator_impl<0, Args...> {
178178
template <std::size_t N>
179179
using indexer_node_type_generator_t = typename indexer_node_type_generator_impl<N>::type;
180180

181-
template <std::size_t N>
182-
struct edge_maker {
183-
template <typename SendersVector, typename IndexerNodeType>
184-
static void make(SendersVector& senders, IndexerNodeType& indexer) {
185-
oneapi::tbb::flow::make_edge(senders[N - 1], oneapi::tbb::flow::input_port<N - 1>(indexer));
186-
edge_maker<N - 1>::make(senders, indexer);
187-
}
188-
};
189-
190-
template <>
191-
struct edge_maker<1> {
192-
template <typename SendersVector, typename IndexerNodeType>
193-
static void make(SendersVector& senders, IndexerNodeType& indexer) {
194-
oneapi::tbb::flow::make_edge(senders[0], oneapi::tbb::flow::input_port<0>(indexer));
195-
}
196-
};
197-
198181
template <std::size_t NInputs>
199182
void test_indexer_node_with_n_inputs() {
200183
using namespace oneapi::tbb::flow;
201-
graph g;
202-
int message = 42;
203-
204-
using submitter_type = function_node<int, int>;
205-
206-
std::vector<submitter_type> submitters;
207-
submitters.reserve(NInputs);
208-
for (std::size_t i = 0; i < NInputs; ++i) {
209-
submitters.emplace_back(g, unlimited, [](int obj) { return obj; });
210-
}
211-
212184
using indexer_type = indexer_node_type_generator_t<NInputs>;
185+
int message = 42;
213186

187+
graph g;
188+
189+
broadcast_node<int> submitter(g);
214190
indexer_type indexer(g);
215191

216-
edge_maker<NInputs>::make(submitters, indexer);
217-
218192
using output_type = typename indexer_type::output_type;
219193

220-
std::unordered_set<std::size_t> indices;
221-
indices.reserve(NInputs);
194+
std::unordered_set<std::size_t> tags;
222195

223196
function_node<output_type> receiver(g, serial, [&](const output_type& indexer_output) {
224-
indices.emplace(indexer_output.tag());
225-
CHECK_MESSAGE(cast_to<int>(indexer_output) == message, "invalid message returned from indexer_node");
197+
auto result = tags.emplace(indexer_output.tag());
198+
CHECK_MESSAGE(result.second, "Duplicated tags returned from the indexer_node");
199+
CHECK_MESSAGE(cast_to<int>(indexer_output) == message, "Invalid message returned from indexer node");
226200
});
227201

202+
edge_maker<NInputs>::make(submitter, indexer);
228203
make_edge(indexer, receiver);
229204

230-
for (auto& submitter : submitters) {
231-
submitter.try_put(message);
232-
}
205+
submitter.try_put(message);
233206
g.wait_for_all();
234207

235-
CHECK_MESSAGE(indices.size() == NInputs, "Message from some port lost");
208+
CHECK_MESSAGE(tags.size() == NInputs, "Incorrect number of tags returned from the indexer_node");
236209
for (std::size_t i = 0; i < NInputs; ++i) {
237-
CHECK(indices.find(i) != indices.end());
210+
CHECK_MESSAGE(tags.count(i) == 1, "Some tag was not returned from indexer_node");
238211
}
239212
}
240213

241-
TEST_CASE("indexer_node with large number of inputs") {
214+
//! \brief \ref interface \ref requirement
215+
TEST_CASE("indexer_node with large number of input ports") {
242216
test_indexer_node_with_n_inputs<50>();
243-
}
217+
}

0 commit comments

Comments
 (0)