Commit 12dec1e

fixes
1 parent b7720f2 commit 12dec1e

4 files changed: +40 -10 lines changed

cpp/oneapi/dal/algo/linear_regression/backend/cpu/train_kernel_norm_eq.cpp

Lines changed: 1 addition & 1 deletion
@@ -117,7 +117,7 @@ static train_result<Task> call_daal_kernel(const context_cpu& ctx,
                                            const table& data,
                                            const table& resp) {
     using dal::detail::check_mul_overflow;
-
+    std::cout << "here cpu branch" << std::endl;
     using model_t = model<Task>;
     using model_impl_t = detail::model_impl<Task>;
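
The added line is a temporary trace that makes the CPU dispatch path visible at runtime. As a rough, hypothetical sketch (the dataset, names, and driver below are illustrative and not part of this commit), a host-only training call goes through this normal-equations CPU kernel and would therefore print the new message:

// Hypothetical host-only driver (not part of this commit): training without a
// SYCL queue dispatches to the CPU kernel patched above, so the "here cpu branch"
// trace line is printed. Dataset and values are made up for illustration.
#include <iostream>

#include "oneapi/dal/algo/linear_regression.hpp"
#include "oneapi/dal/table/homogen.hpp"

namespace dal = oneapi::dal;

int main() {
    // Tiny in-memory dataset: 4 samples, 2 features, 1 response column.
    const float x[] = { 1.f, 2.f, 2.f, 1.f, 3.f, 4.f, 4.f, 3.f };
    const float y[] = { 5.f, 4.f, 11.f, 10.f };

    const auto x_train = dal::homogen_table::wrap(x, 4, 2);
    const auto y_train = dal::homogen_table::wrap(y, 4, 1);

    const auto lr_desc = dal::linear_regression::descriptor<>{};
    const auto result_train = dal::train(lr_desc, x_train, y_train); // hits the CPU branch

    const auto result_infer = dal::infer(lr_desc, x_train, result_train.get_model());
    std::cout << "Predicted rows: " << result_infer.get_responses().get_row_count() << std::endl;
    return 0;
}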

cpp/oneapi/dal/algo/linear_regression/backend/gpu/train_kernel_norm_eq_dpc.cpp

Lines changed: 1 addition & 1 deletion
@@ -45,7 +45,7 @@ static train_result<Task> call_dal_kernel(const context_gpu& ctx,
                                           const table& data,
                                           const table& resp) {
     using dal::detail::check_mul_overflow;
-
+    std::cout << "here gpu branch" << std::endl;
     using model_t = model<Task>;
     using model_impl_t = detail::model_impl<Task>;
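
The GPU kernel gets the matching trace. Again as a hypothetical sketch (not part of the commit), the same training call routed through a SYCL queue takes the data-parallel path and would print the GPU message instead; the assumption that host-wrapped tables are accepted by the queue-based train call belongs to the sketch, not to the diff:

// Hypothetical data-parallel driver (illustrative only): passing a sycl::queue
// makes the dispatcher pick the GPU kernel patched above, printing "here gpu branch".
#include <iostream>

#include <sycl/sycl.hpp>

#include "oneapi/dal/algo/linear_regression.hpp"
#include "oneapi/dal/table/homogen.hpp"

namespace dal = oneapi::dal;

int main() {
    sycl::queue queue{ sycl::gpu_selector_v };
    std::cout << "Running on: "
              << queue.get_device().get_info<sycl::info::device::name>() << std::endl;

    const float x[] = { 1.f, 2.f, 2.f, 1.f, 3.f, 4.f, 4.f, 3.f };
    const float y[] = { 5.f, 4.f, 11.f, 10.f };

    const auto x_train = dal::homogen_table::wrap(x, 4, 2);
    const auto y_train = dal::homogen_table::wrap(y, 4, 1);

    const auto lr_desc = dal::linear_regression::descriptor<>{};
    const auto result_train = dal::train(queue, lr_desc, x_train, y_train); // hits the GPU branch

    std::cout << "Model trained on " << x_train.get_row_count() << " rows" << std::endl;
    return 0;
}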

cpp/oneapi/dal/backend/dispatcher.hpp

Lines changed: 34 additions & 5 deletions
@@ -212,9 +212,7 @@ struct kernel_dispatcher<kernel_spec<single_node_cpu_kernel, CpuKernel>> {
                 // We have to specify return type for this lambda as compiler cannot
                 // infer it from a body that consist of single `throw` expression
                 using msg = detail::error_messages;
-                throw unimplemented{
-                    msg::spmd_version_of_algorithm_is_not_implemented_for_this_device()
-                };
+                throw unimplemented{ msg::algorithm_is_not_implemented_for_this_device() };
             });
     }
     template <typename... Args>
@@ -318,8 +316,39 @@ struct kernel_dispatcher<kernel_spec<single_node_cpu_kernel, CpuKernel>,
         return dispatch_by_device(
             policy.get_local(),
             [&]() -> gpu_kernel_return_t<GpuKernel, Args...> {
-                // We have to specify return type for this lambda as compiler cannot
-                // infer it from a body that consist of single `throw` expression
+                return CpuKernel{}(context_cpu{}, std::forward<Args>(args)...);
+            },
+            [&]() {
+                return GpuKernel{}(context_gpu{ policy }, std::forward<Args>(args)...);
+            });
+    }
+};
+
+/// Dispatcher for the case of multi-node CPU algorithm based on universal SPMD kernel and
+/// multi-node GPU algorithm based on universal SPMD kernel
+template <typename CpuKernel, typename GpuKernel>
+struct kernel_dispatcher<kernel_spec<universal_spmd_cpu_kernel, CpuKernel>,
+                         kernel_spec<universal_spmd_gpu_kernel, GpuKernel>> {
+    template <typename... Args>
+    auto operator()(const detail::spmd_host_policy& policy, Args&&... args) const {
+        return dispatch_by_device(
+            policy,
+            [&]() {
+                return CpuKernel{}(context_cpu{ policy }, std::forward<Args>(args)...);
+            },
+            [&]() {
+                using msg = detail::error_messages;
+                throw unimplemented{
+                    msg::spmd_version_of_algorithm_is_not_implemented_for_this_device()
+                };
+            });
+    }
+
+    template <typename... Args>
+    auto operator()(const detail::spmd_data_parallel_policy& policy, Args&&... args) const {
+        return dispatch_by_device(
+            policy.get_local(),
+            [&]() {
                 using msg = detail::error_messages;
                 throw unimplemented{
                     msg::spmd_version_of_algorithm_is_not_implemented_for_this_device()
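
The new specialization follows the existing pattern: dispatch_by_device receives one callable per device kind and runs exactly one of them, throwing unimplemented where no kernel exists. A simplified, standalone sketch of that pattern (the names below are illustrative, not oneDAL's actual internals) is:

// Simplified, standalone model of the dispatch pattern used by kernel_dispatcher
// (illustrative only; oneDAL's real dispatch_by_device and policy types differ).
#include <iostream>
#include <stdexcept>
#include <utility>

enum class device_kind { cpu, gpu };

// Runs exactly one of the two callables depending on the resolved device.
// Both callables must agree on the return type, which is why a lambda whose
// body is a single `throw` needs an explicit return type annotation.
template <typename CpuOp, typename GpuOp>
auto dispatch_by_device(device_kind device, CpuOp&& on_cpu, GpuOp&& on_gpu) {
    if (device == device_kind::cpu) {
        return std::forward<CpuOp>(on_cpu)();
    }
    return std::forward<GpuOp>(on_gpu)();
}

int main() {
    // CPU path: analogous to running the universal SPMD CPU kernel.
    dispatch_by_device(
        device_kind::cpu,
        [] { std::cout << "cpu kernel\n"; return 0; },
        [] { std::cout << "gpu kernel\n"; return 0; });

    // GPU path with no GPU kernel available: analogous to the
    // `throw unimplemented{...}` branch in the specialization above.
    try {
        dispatch_by_device(
            device_kind::gpu,
            [] { return 0; },
            []() -> int {
                throw std::runtime_error{
                    "SPMD version of algorithm is not implemented for this device"
                };
            });
    }
    catch (const std::exception& e) {
        std::cout << e.what() << '\n';
    }
    return 0;
}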

samples/oneapi/dpc/mpi/sources/linear_regression_distr_mpi.cpp

Lines changed: 4 additions & 3 deletions
@@ -45,7 +45,7 @@ void run(sycl::queue &queue) {
     const auto y_test =
         dal::read<dal::table>(queue, dal::csv::data_source{ test_response_file_name });

-    auto comm = dal::preview::spmd::make_communicator<dal::preview::spmd::backend::mpi>();
+    auto comm = dal::preview::spmd::make_communicator<dal::preview::spmd::backend::mpi>(queue);
     auto rank_id = comm.get_rank();
     auto rank_count = comm.get_rank_count();

@@ -59,12 +59,13 @@ void run(sycl::queue &queue) {
     const auto result_train =
         dal::preview::train(comm, lr_desc, x_train_vec.at(rank_id), y_train_vec.at(rank_id));

-    const auto result_infer = dal::infer(lr_desc, x_test, result_train.get_model());
+    const auto result_infer =
+        dal::preview::infer(comm, lr_desc, x_test_vec.at(rank_id), result_train.get_model());

     if (comm.get_rank() == 0) {
         std::cout << "Prediction results:\n" << result_infer.get_responses() << std::endl;

-        std::cout << "Ground truth:\n" << y_test << std::endl;
+        std::cout << "Ground truth:\n" << y_test_vec.at(rank_id) << std::endl;
     }
 }
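
With the communicator now built on the queue and inference switched to the distributed preview API, every rank works on its own partition (x_test_vec.at(rank_id)) and rank 0 prints the results. A minimal sketch of the kind of main() that drives run(queue) in such an MPI sample (an assumption; the sample's actual main() is not shown in this diff) looks like:

// Hypothetical driver for run(sycl::queue&) above (the sample's real main() is
// not part of this diff): one MPI process per rank, each with its own queue.
#include <mpi.h>

#include <sycl/sycl.hpp>

void run(sycl::queue& queue); // defined in linear_regression_distr_mpi.cpp

int main(int argc, char** argv) {
    MPI_Init(&argc, &argv);

    // Each rank creates its own device queue; make_communicator<backend::mpi>(queue)
    // in the diff above binds the SPMD communicator to this queue.
    sycl::queue queue{ sycl::gpu_selector_v };
    run(queue);

    MPI_Finalize();
    return 0;
}

Such a sample is typically launched with an MPI launcher, e.g. mpirun -n <ranks> <sample binary>; the exact binary name and launcher flags depend on the build.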
