@@ -195,6 +195,84 @@ class threadpool_t : public dnnl::threadpool_interop::threadpool_iface {
195195} // namespace testing
196196} // namespace dnnl
197197
198+ #elif defined(DNNL_TEST_THREADPOOL_USE_EIGEN_ASYNC)
199+
200+ // absl sources define its own version of `CHECK` macro. oneDNN's version is not
201+ // needed further the file, thus, disable it for compilation reason.
202+ #undef CHECK
203+
204+ #define EIGEN_USE_THREADS
205+ #include " Eigen/ThreadPool"
206+
207+ #include " xla/backends/cpu/runtime/work_queue.h"
208+ #include " xla/tsl/concurrency/async_value_ref.h"
209+ #include " xla/tsl/concurrency/chain.h"
210+
211+ #include < cstddef>
212+ #include < cstdint>
213+ #include < functional>
214+
215+ namespace dnnl {
216+ namespace testing {
217+
218+ static tsl::AsyncValueRef<tsl::Chain> OkDoneEventSingleton () {
219+ static std::unique_ptr<tsl::AsyncValueOwningRef<tsl::Chain>> singleton =
220+ [] {
221+ static auto storage = std::make_unique<
222+ tsl::internal::AsyncValueStorage<tsl::Chain>>();
223+ return std::make_unique<tsl::AsyncValueOwningRef<tsl::Chain>>(
224+ tsl::MakeAvailableAsyncValueRef<tsl::Chain>(*storage));
225+ }();
226+ return singleton->AsRef ();
227+ }
228+
229+ class threadpool_t : public dnnl ::threadpool_interop::threadpool_iface {
230+ private:
231+ // Original `OneDnnThreadPool` at
232+ // `xla/backends/cpu/runtime/onednn/onednn_threadpool.h` takes
233+ // `Eigen::ThreadPoolInterface` instead. Since `Eigen::ThreadPool` is
234+ // a parent class, which is an alias to `NonBlockingThreadPool`, it fits
235+ // the need.
236+ std::unique_ptr<Eigen::ThreadPool> thread_pool_;
237+
238+ // Async value that signals completion of the last scheduled parallel loop.
239+ // This is used only when is_async_ is true.
240+ tsl::AsyncValueRef<tsl::Chain> done_event_;
241+
242+ public:
243+ explicit threadpool_t (int num_threads = 0 ) {
244+ if (num_threads <= 0 ) num_threads = read_num_threads_from_env ();
245+ thread_pool_.reset (new Eigen::ThreadPool (num_threads));
246+ done_event_ = OkDoneEventSingleton ();
247+ }
248+ int get_num_threads () const override { return thread_pool_->NumThreads (); }
249+ bool get_in_parallel () const override { return false ; }
250+ uint64_t get_flags () const override { return ASYNCHRONOUS; }
251+ void parallel_for (int n, const std::function<void (int , int )> &fn) override {
252+ // If we are using oneDNN with async support, we need to schedule the
253+ // parallel loop using the done_event_. This allows us to return
254+ // immediately and not block the caller thread.
255+ auto parallelize = [this , n, fn](tsl::Chain) {
256+ return xla::cpu::Worker::Parallelize (thread_pool_.get (),
257+ thread_pool_->NumThreads (), n,
258+ [fn, n](size_t i) { fn (static_cast <int >(i), n); });
259+ };
260+
261+ done_event_ = done_event_.FlatMap (parallelize);
262+ }
263+ void wait () override {
264+ // While performing asynchronous execution, wait() method is needed to
265+ // notify the user that the output is ready. oneDNN will not call wait()
266+ // inside the library to avoid deadlock.
267+ tsl::BlockUntilReady (done_event_);
268+ }
269+
270+ tsl::AsyncValueRef<tsl::Chain> done_event () const { return done_event_; }
271+ };
272+
273+ } // namespace testing
274+ } // namespace dnnl
275+
198276#elif defined(DNNL_TEST_THREADPOOL_USE_TBB)
199277#include " tbb/parallel_for.h"
200278#include " tbb/task_arena.h"
0 commit comments