Skip to content

Commit 6468abc

Browse files
mgouicemdzarukin
authored andcommitted
api: threadpool: add wait() method for proper async support
1 parent 1b1d658 commit 6468abc

File tree

4 files changed

+32
-2
lines changed

4 files changed

+32
-2
lines changed

include/oneapi/dnnl/dnnl_threadpool_iface.hpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,9 @@ struct threadpool_iface {
5757
/// Returns threadpool behavior flags bit mask (see below).
5858
virtual uint64_t get_flags() const = 0;
5959

60+
// Does nothing if SYNCHRONOUS, waits for all jobs for ASYNCHRONOUS
61+
virtual void wait() = 0;
62+
6063
/// If set, parallel_for() returns immediately and oneDNN needs implement
6164
/// waiting for the submitted closures to finish execution on its own.
6265
static constexpr uint64_t ASYNCHRONOUS = 1;

src/cpu/cpu_stream.hpp

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,15 @@ struct cpu_stream_t : public stream_t {
3838

3939
dnnl::impl::status_t wait() override {
4040
// CPU execution is synchronous so return immediately
41+
#if DNNL_CPU_RUNTIME == DNNL_RUNTIME_THREADPOOL
42+
dnnl::threadpool_interop::threadpool_iface *tp;
43+
auto rc = this->get_threadpool(&tp);
44+
if (rc == status::success) {
45+
if (tp->get_flags()
46+
& threadpool_interop::threadpool_iface::ASYNCHRONOUS)
47+
tp->wait();
48+
}
49+
#endif
4150
return dnnl::impl::status::success;
4251
}
4352

tests/benchdnn/utils/parallel.cpp

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*******************************************************************************
2-
* Copyright 2022-2023 Intel Corporation
2+
* Copyright 2022-2025 Intel Corporation
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -25,30 +25,39 @@
2525
#define ACTIVATE_THREADPOOL
2626
#endif
2727

28+
void synchronize() {
29+
#if DNNL_CPU_THREADING_RUNTIME == DNNL_RUNTIME_THREADPOOL
30+
dnnl::testing::get_threadpool()->wait();
31+
#endif
32+
}
33+
2834
// Note: no need in deactivation as `scoped_activation` object will deactivate
2935
// it automatically at destruction.
30-
3136
void benchdnn_parallel_nd(int64_t D0, const std::function<void(int64_t)> &f) {
3237
ACTIVATE_THREADPOOL;
3338
dnnl::impl::parallel_nd(D0, f);
39+
synchronize();
3440
}
3541

3642
void benchdnn_parallel_nd(int64_t D0, int64_t D1,
3743
const std::function<void(int64_t, int64_t)> &f) {
3844
ACTIVATE_THREADPOOL;
3945
dnnl::impl::parallel_nd(D0, D1, f);
46+
synchronize();
4047
}
4148

4249
void benchdnn_parallel_nd(int64_t D0, int64_t D1, int64_t D2,
4350
const std::function<void(int64_t, int64_t, int64_t)> &f) {
4451
ACTIVATE_THREADPOOL;
4552
dnnl::impl::parallel_nd(D0, D1, D2, f);
53+
synchronize();
4654
}
4755

4856
void benchdnn_parallel_nd(int64_t D0, int64_t D1, int64_t D2, int64_t D3,
4957
const std::function<void(int64_t, int64_t, int64_t, int64_t)> &f) {
5058
ACTIVATE_THREADPOOL;
5159
dnnl::impl::parallel_nd(D0, D1, D2, D3, f);
60+
synchronize();
5261
}
5362

5463
void benchdnn_parallel_nd(int64_t D0, int64_t D1, int64_t D2, int64_t D3,
@@ -57,6 +66,7 @@ void benchdnn_parallel_nd(int64_t D0, int64_t D1, int64_t D2, int64_t D3,
5766
&f) {
5867
ACTIVATE_THREADPOOL;
5968
dnnl::impl::parallel_nd(D0, D1, D2, D3, D4, f);
69+
synchronize();
6070
}
6171

6272
void benchdnn_parallel_nd(int64_t D0, int64_t D1, int64_t D2, int64_t D3,
@@ -65,6 +75,7 @@ void benchdnn_parallel_nd(int64_t D0, int64_t D1, int64_t D2, int64_t D3,
6575
int64_t, int64_t, int64_t, int64_t, int64_t, int64_t)> &f) {
6676
ACTIVATE_THREADPOOL;
6777
dnnl::impl::parallel_nd(D0, D1, D2, D3, D4, D5, f);
78+
synchronize();
6879
}
6980

7081
int benchdnn_get_max_threads() {

tests/test_thread.cpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -186,6 +186,10 @@ class threadpool_t : public dnnl::threadpool_interop::threadpool_iface {
186186

187187
counter.Wait();
188188
};
189+
190+
void wait() override {
191+
// Nothing to do, runtime is synchronous
192+
}
189193
};
190194

191195
} // namespace testing
@@ -210,6 +214,7 @@ class threadpool_t : public dnnl::threadpool_interop::threadpool_iface {
210214
tbb::parallel_for(
211215
0, n, [&](int i) { fn(i, n); }, tbb::static_partitioner());
212216
}
217+
void wait() override {}
213218
};
214219

215220
} // namespace testing
@@ -282,6 +287,8 @@ class threadpool_t : public dnnl::threadpool_interop::threadpool_iface {
282287
}
283288
}
284289

290+
void wait() override {}
291+
285292
private:
286293
int num_threads_;
287294
std::mutex master_mutex_;

0 commit comments

Comments
 (0)