diff --git a/src/cpu/x64/jit_avx512_core_amx_conv_kernel.cpp b/src/cpu/x64/jit_avx512_core_amx_conv_kernel.cpp index f5c6a4eeb6e..f9ac821a304 100644 --- a/src/cpu/x64/jit_avx512_core_amx_conv_kernel.cpp +++ b/src/cpu/x64/jit_avx512_core_amx_conv_kernel.cpp @@ -3325,7 +3325,6 @@ void jit_avx512_core_amx_bwd_data_kernel_t::store_output_vector_int8( mov(reg_ptr_sum_zp, reinterpret_cast(p_sum_zp)); } - int scale_offset = jcp.is_ic_scale * (sizeof(float) * icb * ic_block); if (jcp.with_bias) { int bias_offset = jcp.typesize_bia * icb * ic_block; auto bias_addr = EVEX_compress_addr(reg_bias, bias_offset); @@ -3335,8 +3334,22 @@ void jit_avx512_core_amx_bwd_data_kernel_t::store_output_vector_int8( /* add bias to zmm_accum */ vcvtdq2ps(zmm_out, zmm_out); const Zmm zmm_out_msk = zmm_mask(zmm_out, mask_flag); - vmulps(zmm_out_msk, zmm_out, - EVEX_compress_addr(reg_ptr_scales, scale_offset)); + + if (jcp.with_src_scales) { + mov(reg_ptr_src_scales, ptr[param1 + GET_OFF(src_scales)]); + vmulps(zmm_out_msk, zmm_out, + EVEX_compress_addr(reg_ptr_src_scales, 0, /* bcast = */ true)); + } + + if (jcp.with_wei_scales) { + mov(reg_ptr_wei_scales, ptr[param1 + GET_OFF(wei_scales)]); + const int scale_offset + = jcp.is_ic_scale * (sizeof(float) * icb * ic_block); + vmulps(zmm_out_msk, zmm_out, + EVEX_compress_addr(reg_ptr_wei_scales, scale_offset, + /* bcast = */ !jcp.is_ic_scale)); + } + if (jcp.with_bias) vaddps(zmm_out, zmm_out, zmm_bias); /* Do post-ops */ @@ -3354,7 +3367,11 @@ void jit_avx512_core_amx_bwd_data_kernel_t::store_output_vector_int8( } if (maybe_eltwise(1)) eltwise_injector_->compute_vector(zmm_out.getIdx()); - if (jcp.dst_scale) { vmulps(zmm_out_msk, zmm_out, zmm_dst_scale); } + if (jcp.with_dst_scales) { + mov(reg_ptr_dst_scales, ptr[param1 + GET_OFF(dst_scales)]); + vmulps(zmm_out_msk, zmm_out, + EVEX_compress_addr(reg_ptr_dst_scales, 0, /* bcast = */ true)); + } // Properly saturate the accumulators for integer datatypes if (one_of(jcp.dsrc_dt, u8, s8, s32)) { @@ -3643,12 +3660,6 @@ void jit_avx512_core_amx_bwd_data_kernel_t::generate() { if (jcp.with_bias) mov(reg_bias, ptr[param1 + GET_OFF(bias)]); - if (jcp.dst_scale) { - mov(reg_ptr_dst_scales, ptr[param1 + GET_OFF(dst_scale)]); - vmovups(zmm_dst_scale, EVEX_compress_addr(reg_ptr_dst_scales, 0)); - } - mov(reg_ptr_scales, ptr[param1 + GET_OFF(scales)]); - mov(reg_last_h, ptr[param1 + GET_OFF(last_h)]); const int inp_stride = jcp.oc_block_int * jcp.typesize_in; @@ -4012,10 +4023,11 @@ status_t jit_avx512_core_amx_bwd_data_kernel_t::init_conf(jit_conv_conf_t &jcp, jcp.wsp_buffer_size = (size_t)jcp.nb_ih_blocking * jcp.nb_ic_blocking * jcp.full_tile_width * jcp.ic_block; - const auto &wei_scales = attr.scales_.get(DNNL_ARG_WEIGHTS); - const auto &dst_scales = attr.scales_.get(DNNL_ARG_DST); - jcp.is_ic_scale = wei_scales.get_mask() > 0; - jcp.dst_scale = !dst_scales.has_default_values(); + jcp.is_ic_scale = attr.scales_.get_mask(DNNL_ARG_WEIGHTS) > 0; + jcp.with_src_scales = !attr.scales_.get(DNNL_ARG_SRC).has_default_values(); + jcp.with_wei_scales + = !attr.scales_.get(DNNL_ARG_WEIGHTS).has_default_values(); + jcp.with_dst_scales = !attr.scales_.get(DNNL_ARG_DST).has_default_values(); return status::success; } @@ -4032,10 +4044,15 @@ void jit_avx512_core_amx_bwd_data_kernel_t::init_scratchpad( assert(jcp.ngroups == 1); scratchpad.book(key_conv_padded_bias, jcp.ic, jcp.typesize_bia); } - scratchpad.book(key_conv_amx_tilecfg, 1, 64); // 1 whole cacheline + // One cache-line for each thread for a palette. 
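+    // Each thread writes its tile configuration into its own
+    // AMX_PALETTE_SIZE (64-byte) slot of this buffer, so palettes can be
+    // built and loaded via amx_tile_configure() without threads sharing
+    // a cacheline.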
+ scratchpad.book(key_conv_amx_tilecfg, jcp.nthr * AMX_PALETTE_SIZE, + sizeof(char), 0, PAGE_4K); - book_precomputed_scales( - scratchpad, attr.scales_, jcp.ngroups * jcp.ic_without_padding); + if (jcp.with_dst_scales) { + // See brgemm_types.hpp comment for `with_dst_scales`. + scratchpad.book(key_conv_dst_scales, + static_cast(jcp.nthr) * sizeof(float), 4096); + } } const int jit_avx512_core_amx_bwd_weights_kernel_t::max_ur_w = 32; @@ -5529,7 +5546,9 @@ status_t jit_avx512_core_amx_bwd_weights_kernel_t::init_scratchpad( scratchpad.book(key_conv_padded_bias, jcp.ngroups * jcp.nb_oc * jcp.oc_block, jcp.typesize_bia); } - scratchpad.book(key_conv_amx_tilecfg, 1, 64); // 1 whole cacheline + // One cache-line for each thread for a palette. + scratchpad.book(key_conv_amx_tilecfg, jcp.nthr * AMX_PALETTE_SIZE, + sizeof(char), 0, PAGE_4K); constexpr size_t scratchpad_limit_by_absolute_value = (size_t)32 << 30; // 32Gb - TODO: may it's too large? diff --git a/src/cpu/x64/jit_avx512_core_amx_conv_kernel.hpp b/src/cpu/x64/jit_avx512_core_amx_conv_kernel.hpp index ae300e35b28..33ea1c2a7a6 100644 --- a/src/cpu/x64/jit_avx512_core_amx_conv_kernel.hpp +++ b/src/cpu/x64/jit_avx512_core_amx_conv_kernel.hpp @@ -532,7 +532,8 @@ struct jit_avx512_core_amx_bwd_data_kernel_t : public jit_generator_t { const Xbyak::Reg64 reg_wsp_ptr = r12; const Xbyak::Reg64 reg_bias = r11; - const Xbyak::Reg64 reg_ptr_scales = r10; + const Xbyak::Reg64 reg_ptr_src_scales = r10; + const Xbyak::Reg64 reg_ptr_wei_scales = r10; const Xbyak::Reg64 reg_ptr_dst_scales = r10; const Xbyak::Reg64 reg_ptr_sum_scale = r9; const Xbyak::Reg64 reg_ptr_sum_zp = abi_not_param1; @@ -557,8 +558,6 @@ struct jit_avx512_core_amx_bwd_data_kernel_t : public jit_generator_t { const Xbyak::Zmm zmm_zero = zmm30; const Xbyak::Zmm zmm_prev_dst = zmm29; const Xbyak::Zmm zmm_sum_zp = zmm28; - /* dst scale */ - const Xbyak::Zmm &zmm_dst_scale = zmm27; // AUX: Steps, shifts and offsets size_t get_inp_ocb_step() const; diff --git a/src/cpu/x64/jit_avx512_core_amx_conv_utils.hpp b/src/cpu/x64/jit_avx512_core_amx_conv_utils.hpp index 3f23c475182..404f46ba5cd 100644 --- a/src/cpu/x64/jit_avx512_core_amx_conv_utils.hpp +++ b/src/cpu/x64/jit_avx512_core_amx_conv_utils.hpp @@ -20,7 +20,7 @@ #include "common/dnnl_thread.hpp" #include "common/utils.hpp" -#include "cpu/x64/jit_avx512_core_amx_convolution.hpp" +#include "cpu/x64/jit_primitive_conf.hpp" namespace dnnl { namespace impl { @@ -28,12 +28,8 @@ namespace cpu { namespace x64 { namespace amx_utils { -using namespace dnnl::impl::memory_tracking::names; using namespace dnnl::impl::utils; -#define wht_blk_off(d, g, ...) \ - (with_groups ? (d).blk_off((g), __VA_ARGS__) : (d).blk_off(__VA_ARGS__)) - struct spatial_features_3d_t { spatial_features_3d_t(const jit_conv_conf_t &jcp) @@ -109,15 +105,14 @@ struct spatial_features_3d_t { inline int get_output_offset() const { return output_offset_; } private: - const int input_size_; - const int filter_size_; - const int dilate_; - const int stride_; - const int init_pad_; // f_pad - const int end_pad_; // back_pad - const bool is_fast_path_; // 'dilate_ == 1 && stride_ == 1' - const bool - compute_extended_features_; // eq. '(!is_fast_path_) && dilate_ == 1' + int input_size_; + int filter_size_; + int dilate_; + int stride_; + int init_pad_; // f_pad + int end_pad_; // back_pad + bool is_fast_path_; // 'dilate_ == 1 && stride_ == 1' + bool compute_extended_features_; // eq. 
'(!is_fast_path_) && dilate_ == 1' int filter_; int lower_offset_; // d_lo @@ -127,185 +122,6 @@ struct spatial_features_3d_t { int end_overflow_; // d_b_overflow }; -inline void execute_backward_convolution_body(const exec_ctx_t &ctx, - const jit_conv_conf_t &jcp, - const std::unique_ptr &kernel, - const char *diff_dst, const char *weights, const char *bias, - const float *oscales, const float *dst_scales, char *diff_src, - const memory_desc_wrapper &diff_dst_d, - const memory_desc_wrapper &weights_d, const memory_desc_wrapper &bias_d, - const memory_desc_wrapper &diff_src_d) { - assert(jcp.nb_ic % jcp.nb_ic_blocking == 0); - - const bool is_deconv = jcp.prop_kind != prop_kind::backward_data; - const bool with_groups = weights_d.ndims() == diff_src_d.ndims() + 1; - - const size_t diff_dst_dt_size = jcp.typesize_in; - const size_t wei_dt_size = jcp.typesize_in; - const size_t bia_dt_size = jcp.typesize_bia; - const size_t diff_src_dt_size = jcp.typesize_out; - - const dim_t wei_g_shift = wht_blk_off(weights_d, 1, 0); - const dim_t wei_ic_shift = is_deconv - ? wht_blk_off(weights_d, 0, jcp.nb_ic_blocking) - : wht_blk_off(weights_d, 0, 0, jcp.nb_ic_blocking); - const size_t wht_d_stride = wht_blk_off(weights_d, 0, 0, 0, 1); - - auto inp_p_buffer = ctx.get_scratchpad_grantor().template get( - key_conv_amx_inp_buffer); - auto wsp = ctx.get_scratchpad_grantor().template get( - key_conv_amx_wsp_buffer); - auto tcfg = ctx.get_scratchpad_grantor().template get( - key_conv_amx_tilecfg); - - const int ic_chunks = jcp.nb_ic / jcp.nb_ic_blocking; - const int ih_chunks = utils::div_up(jcp.ih, jcp.ih_blk_size); - const int work_amount - = jcp.mb * jcp.ngroups * jcp.id * ih_chunks * jcp.nb_iw * ic_chunks; - - // Initialize the tile configuration in memory, so that each thread can - // load this configuration from memory via `amx_tile_configure(tcfg)`. - if (tcfg) kernel->tile_configure(tcfg); - const bool is_1d = jcp.ndims == 3; - const bool is_3d = jcp.ndims == 5; - - parallel(jcp.nthr, [&](const int ithr, const int nthr) { - int start {0}, end {0}; - balance211(work_amount, nthr, ithr, start, end); - - auto p = jit_conv_args_t(); - amx_tile_configure(tcfg); - spatial_features_3d_t sfd(jcp); - - int mb {0}, g {0}, id_s {0}, ihc {0}, iwb {0}, icc {0}; - nd_iterator_init(start, mb, jcp.mb, g, jcp.ngroups, id_s, jcp.id, ihc, - ih_chunks, iwb, jcp.nb_iw, icc, ic_chunks); - int last_copied_mb = -1; - int last_copied_id = -1; - int last_copied_ihc = -1; - int last_copied_iwb = -1; - int last_copied_g = -1; - while (start < end) { - char *inp_buffer = inp_p_buffer - + ithr * jcp.inp_buffer_size * diff_dst_dt_size; - - assert(IMPLICATION( - jcp.ngroups > 1, jcp.ic == jcp.ic_without_padding)); - int ic = g * jcp.ic + icc * jcp.nb_ic_blocking * jcp.ic_block; - int icb = jcp.is_nspc ? ic : ic / jcp.ic_block; - assert(IMPLICATION( - jcp.ngroups > 1, jcp.oc == jcp.oc_without_padding)); - const int ocb = g * (jcp.is_nspc ? jcp.oc : jcp.nb_oc); - auto bias_w = bias ? 
bias + (bias_d.blk_off(ic) * bia_dt_size) - : nullptr; - - const int ih_b = ihc * jcp.ih_blk_size; - const int ih_e = nstl::min(jcp.ih, ih_b + jcp.ih_blk_size); - const int iw = iwb * jcp.iw_block; - bool is_inp_buffer_relevant = true && last_copied_mb == mb - && last_copied_id == id_s && last_copied_ihc == ihc - && last_copied_iwb == iwb && last_copied_g == g; - - sfd.update_params(id_s); - p.kd_padding = sfd.get_filter_padding(); - const int d_lo = sfd.get_lower_offset(); - const int d_oj = sfd.get_output_offset(); - - int ih_step = jcp.nb_ih_blocking; - for (int ih = ih_b; ih < ih_e; ih += ih_step) { - if (!is_inp_buffer_relevant) { - const int gen_kh = (jcp.kh - 1) * (jcp.dilate_h + 1) + 1; - const int gen_kw = (jcp.kw - 1) * (jcp.dilate_w + 1) + 1; - // dox: x-index dilated by strides (dox = ox * stride_x) - const int doh = ih + jcp.t_pad - (gen_kh - 1); - const int dow = iw + jcp.l_pad - (gen_kw - 1); - const int doh_b = ih_b + jcp.t_pad - (gen_kh - 1); - const int doh_l = (jcp.oh - 1) * jcp.stride_h; // last oh - const int dow_l = (jcp.ow - 1) * jcp.stride_w; // last ow - - // dox_{s,f}: start and finish indices for copy kernel - const int doh_s = doh + (ih == ih_b ? 0 : gen_kh - 1); - const int doh_f = doh + (ih_step - 1) + (gen_kh - 1); - const int delta_h = doh_f - doh_s + 1; - const int doh_t_overflow = 0 < doh_s && doh_s < doh_l - ? nstl::additive_inverse_modulo(doh_s, jcp.stride_h) - : nstl::max(0, -doh_s); - const int doh_b_overflow = 0 < doh_f && doh_f < doh_l - ? nstl::modulo(doh_f, jcp.stride_h) - : nstl::max(0, nstl::min(delta_h, doh_f - doh_l)); - int dow_s = dow; - int dow_f = dow + jcp.owp - 1; - const int delta_w = dow_f - dow_s + 1; - const int dow_l_overflow = 0 < dow_s && dow_s < dow_l - ? nstl::additive_inverse_modulo(dow_s, jcp.stride_w) - : nstl::max(0, -dow_s); - const int dow_r_overflow = 0 < dow_f && dow_f < dow_l - ? nstl::modulo(dow_f, jcp.stride_w) - : nstl::max(0, nstl::min(delta_w, dow_f - dow_l)); - const int oh_s - = nstl::max(0, utils::div_up(doh_s, jcp.stride_h)); - const int ow_s - = nstl::max(0, utils::div_up(dow_s, jcp.stride_w)); - // how many real data rows to copy (including padding) - p.t_overflow = nstl::min(delta_h, doh_t_overflow); - p.b_overflow = nstl::min( - delta_h - p.t_overflow, doh_b_overflow); - p.kh_padding = nstl::max( - 0, delta_h - p.t_overflow - p.b_overflow); - p.l_overflow = nstl::min(delta_w, dow_l_overflow); - p.kw_padding = nstl::max( - 0, delta_w - dow_l_overflow - dow_r_overflow); - p.r_overflow = nstl::min( - delta_w - dow_l_overflow, dow_r_overflow); - size_t inp_offset = is_1d - ? diff_dst_d.blk_off(mb, ocb, ow_s) - : is_3d - ? diff_dst_d.blk_off(mb, ocb, d_oj, oh_s, ow_s) - : diff_dst_d.blk_off(mb, ocb, oh_s, ow_s); - p.src = diff_dst + diff_dst_dt_size * inp_offset; - p.dst = inp_buffer - + (size_t)(doh_s - doh_b) * jcp.owp - * jcp.oc_block_int * diff_dst_dt_size; - - kernel->bwd_data_copy_kernel()(&p); - } - - size_t diff_src_offset = is_1d ? diff_src_d.blk_off(mb, icb, iw) - : is_3d ? 
diff_src_d.blk_off(mb, icb, id_s, ih, iw) - : diff_src_d.blk_off(mb, icb, ih, iw); - p.dst = inp_buffer - + (size_t)(ih - ih_b) * jcp.owp * jcp.oc_block_int - * diff_dst_dt_size; - p.src = diff_src + diff_src_dt_size * diff_src_offset; - p.filt = weights - + wei_dt_size - * (g * wei_g_shift + icc * wei_ic_shift - + d_lo * wht_d_stride); - p.bias = bias_w; - p.scales = &oscales[jcp.is_ic_scale * ic]; - p.dst_scale = &dst_scales[0]; - p.acc_s32 = wsp + ithr * jcp.wsp_buffer_size; - p.last_h = (ih + ih_step <= ih_e); - p.iwb = iwb; - p.ic_blocks = icc * jcp.nb_ic_blocking; - - (*kernel)(&p); - } - last_copied_mb = mb; - last_copied_id = id_s; - last_copied_ihc = ihc; - last_copied_iwb = iwb; - last_copied_g = g; - ++start; - nd_iterator_step(mb, jcp.mb, g, jcp.ngroups, id_s, jcp.id, ihc, - ih_chunks, iwb, jcp.nb_iw, icc, ic_chunks); - } - amx_tile_release(); - }); -} - -#undef wht_blk_off - } // namespace amx_utils } // namespace x64 } // namespace cpu diff --git a/src/cpu/x64/jit_avx512_core_amx_convolution.cpp b/src/cpu/x64/jit_avx512_core_amx_convolution.cpp index 8b915c4a5bf..179c63d23e6 100644 --- a/src/cpu/x64/jit_avx512_core_amx_convolution.cpp +++ b/src/cpu/x64/jit_avx512_core_amx_convolution.cpp @@ -848,22 +848,189 @@ status_t jit_avx512_core_amx_convolution_bwd_data_t::execute_backward( const memory_desc_wrapper diff_dst_d(pd()->diff_dst_md()); const memory_desc_wrapper weights_d(pd()->weights_md(0)); - // unused in kernel for bf16, but attributes have scales buffer by default - // and using it here simplifies the shared `execute_backward_loop`. - DEFINE_ARG_SCALES_BUFFER(src_scales, DNNL_ARG_SRC); - DEFINE_ARG_SCALES_BUFFER(wei_scales, DNNL_ARG_WEIGHTS); - DEFINE_ARG_SCALES_BUFFER(dst_scales, DNNL_ARG_DST); - - const int wei_scale_mask = pd()->attr()->scales_.get_mask(DNNL_ARG_WEIGHTS); - const float *oscales = scale_utils::precompute_scales( - ctx.get_scratchpad_grantor(), src_scales, wei_scales, pd()->IC(), - pd()->OC(), false, wei_scale_mask > 0, pd()->attr(), - jit_scale_precompute_.get()); - - amx_utils::execute_backward_convolution_body(ctx, pd()->jcp_, kernel_, - diff_dst, weights, nullptr /* no bias */, oscales, dst_scales, - diff_src, diff_dst_d, weights_d, - memory_desc_wrapper(nullptr) /* no bias */, diff_src_d); + const void *src_scales + = CTX_IN_MEM(const void *, DNNL_ARG_ATTR_SCALES | DNNL_ARG_SRC); + const void *wei_scales + = CTX_IN_MEM(const void *, DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS); + const void *dst_scales + = CTX_IN_MEM(const void *, DNNL_ARG_ATTR_SCALES | DNNL_ARG_DST); + + const auto &jcp = pd()->jcp_; + assert(jcp.nb_ic % jcp.nb_ic_blocking == 0); + + const size_t diff_dst_dt_size = jcp.typesize_in; + const size_t wei_dt_size = jcp.typesize_in; + const size_t diff_src_dt_size = jcp.typesize_out; + + const dim_t wei_g_shift = wht_blk_off(weights_d, 1, 0); + const dim_t wei_ic_shift = wht_blk_off(weights_d, 0, 0, jcp.nb_ic_blocking); + const size_t wht_d_stride = wht_blk_off(weights_d, 0, 0, 0, 1); + + auto inp_p_buffer = ctx.get_scratchpad_grantor().template get( + key_conv_amx_inp_buffer); + auto wsp = ctx.get_scratchpad_grantor().template get( + key_conv_amx_wsp_buffer); + auto global_tcfg = ctx.get_scratchpad_grantor().template get( + key_conv_amx_tilecfg); + + const int ic_chunks = jcp.nb_ic / jcp.nb_ic_blocking; + const int ih_chunks = utils::div_up(jcp.ih, jcp.ih_blk_size); + const int work_amount + = jcp.mb * jcp.ngroups * jcp.id * ih_chunks * jcp.nb_iw * ic_chunks; + + const bool is_1d = jcp.ndims == 3; + const bool is_3d = jcp.ndims == 5; + + 
parallel(jcp.nthr, [&](const int ithr, const int nthr) { + int start {0}, end {0}; + balance211(work_amount, nthr, ithr, start, end); + + auto p = jit_conv_args_t(); + + char *const __restrict tcfg = global_tcfg + ithr * AMX_PALETTE_SIZE; + kernel_->tile_configure(tcfg); + amx_tile_configure(tcfg); + + amx_utils::spatial_features_3d_t sfd(jcp); + + float *dst_scales_inv_ptr = nullptr; + if (jcp.with_dst_scales) { + const float *dst_scales_ptr + = static_cast(dst_scales); + dst_scales_inv_ptr + = ctx.get_scratchpad_grantor().template get( + key_conv_dst_scales) + + ithr; + dst_scales_inv_ptr[0] = 1.f / dst_scales_ptr[0]; + } + + int mb {0}, g {0}, id_s {0}, ihc {0}, iwb {0}, icc {0}; + nd_iterator_init(start, mb, jcp.mb, g, jcp.ngroups, id_s, jcp.id, ihc, + ih_chunks, iwb, jcp.nb_iw, icc, ic_chunks); + int last_copied_mb = -1; + int last_copied_id = -1; + int last_copied_ihc = -1; + int last_copied_iwb = -1; + int last_copied_g = -1; + while (start < end) { + char *inp_buffer = inp_p_buffer + + ithr * jcp.inp_buffer_size * diff_dst_dt_size; + + assert(IMPLICATION( + jcp.ngroups > 1, jcp.ic == jcp.ic_without_padding)); + int ic = g * jcp.ic + icc * jcp.nb_ic_blocking * jcp.ic_block; + int icb = jcp.is_nspc ? ic : ic / jcp.ic_block; + assert(IMPLICATION( + jcp.ngroups > 1, jcp.oc == jcp.oc_without_padding)); + const int ocb = g * (jcp.is_nspc ? jcp.oc : jcp.nb_oc); + + const int ih_b = ihc * jcp.ih_blk_size; + const int ih_e = nstl::min(jcp.ih, ih_b + jcp.ih_blk_size); + const int iw = iwb * jcp.iw_block; + bool is_inp_buffer_relevant = true && last_copied_mb == mb + && last_copied_id == id_s && last_copied_ihc == ihc + && last_copied_iwb == iwb && last_copied_g == g; + + sfd.update_params(id_s); + p.kd_padding = sfd.get_filter_padding(); + const int d_lo = sfd.get_lower_offset(); + const int d_oj = sfd.get_output_offset(); + + int ih_step = jcp.nb_ih_blocking; + for (int ih = ih_b; ih < ih_e; ih += ih_step) { + if (!is_inp_buffer_relevant) { + const int gen_kh = (jcp.kh - 1) * (jcp.dilate_h + 1) + 1; + const int gen_kw = (jcp.kw - 1) * (jcp.dilate_w + 1) + 1; + // dox: x-index dilated by strides (dox = ox * stride_x) + const int doh = ih + jcp.t_pad - (gen_kh - 1); + const int dow = iw + jcp.l_pad - (gen_kw - 1); + const int doh_b = ih_b + jcp.t_pad - (gen_kh - 1); + const int doh_l = (jcp.oh - 1) * jcp.stride_h; // last oh + const int dow_l = (jcp.ow - 1) * jcp.stride_w; // last ow + + // dox_{s,f}: start and finish indices for copy kernel + const int doh_s = doh + (ih == ih_b ? 0 : gen_kh - 1); + const int doh_f = doh + (ih_step - 1) + (gen_kh - 1); + const int delta_h = doh_f - doh_s + 1; + const int doh_t_overflow = 0 < doh_s && doh_s < doh_l + ? nstl::additive_inverse_modulo(doh_s, jcp.stride_h) + : nstl::max(0, -doh_s); + const int doh_b_overflow = 0 < doh_f && doh_f < doh_l + ? nstl::modulo(doh_f, jcp.stride_h) + : nstl::max(0, nstl::min(delta_h, doh_f - doh_l)); + int dow_s = dow; + int dow_f = dow + jcp.owp - 1; + const int delta_w = dow_f - dow_s + 1; + const int dow_l_overflow = 0 < dow_s && dow_s < dow_l + ? nstl::additive_inverse_modulo(dow_s, jcp.stride_w) + : nstl::max(0, -dow_s); + const int dow_r_overflow = 0 < dow_f && dow_f < dow_l + ? 
nstl::modulo(dow_f, jcp.stride_w) + : nstl::max(0, nstl::min(delta_w, dow_f - dow_l)); + const int oh_s + = nstl::max(0, utils::div_up(doh_s, jcp.stride_h)); + const int ow_s + = nstl::max(0, utils::div_up(dow_s, jcp.stride_w)); + // how many real data rows to copy (including padding) + p.t_overflow = nstl::min(delta_h, doh_t_overflow); + p.b_overflow = nstl::min( + delta_h - p.t_overflow, doh_b_overflow); + p.kh_padding = nstl::max( + 0, delta_h - p.t_overflow - p.b_overflow); + p.l_overflow = nstl::min(delta_w, dow_l_overflow); + p.kw_padding = nstl::max( + 0, delta_w - dow_l_overflow - dow_r_overflow); + p.r_overflow = nstl::min( + delta_w - dow_l_overflow, dow_r_overflow); + size_t inp_offset = is_1d + ? diff_dst_d.blk_off(mb, ocb, ow_s) + : is_3d + ? diff_dst_d.blk_off(mb, ocb, d_oj, oh_s, ow_s) + : diff_dst_d.blk_off(mb, ocb, oh_s, ow_s); + p.src = diff_dst + diff_dst_dt_size * inp_offset; + p.dst = inp_buffer + + (size_t)(doh_s - doh_b) * jcp.owp + * jcp.oc_block_int * diff_dst_dt_size; + + kernel_->bwd_data_copy_kernel()(&p); + } + + size_t diff_src_offset = is_1d ? diff_src_d.blk_off(mb, icb, iw) + : is_3d ? diff_src_d.blk_off(mb, icb, id_s, ih, iw) + : diff_src_d.blk_off(mb, icb, ih, iw); + p.dst = inp_buffer + + (size_t)(ih - ih_b) * jcp.owp * jcp.oc_block_int + * diff_dst_dt_size; + p.src = diff_src + diff_src_dt_size * diff_src_offset; + p.filt = weights + + wei_dt_size + * (g * wei_g_shift + icc * wei_ic_shift + + d_lo * wht_d_stride); + p.bias = nullptr; + p.src_scales = src_scales; + p.wei_scales = jcp.with_wei_scales + ? static_cast(wei_scales) + + jcp.is_ic_scale * ic + : nullptr; + p.dst_scales = dst_scales_inv_ptr; + p.acc_s32 = wsp + ithr * jcp.wsp_buffer_size; + p.last_h = (ih + ih_step <= ih_e); + p.iwb = iwb; + p.ic_blocks = icc * jcp.nb_ic_blocking; + + (*kernel_)(&p); + } + last_copied_mb = mb; + last_copied_id = id_s; + last_copied_ihc = ihc; + last_copied_iwb = iwb; + last_copied_g = g; + ++start; + nd_iterator_step(mb, jcp.mb, g, jcp.ngroups, id_s, jcp.id, ihc, + ih_chunks, iwb, jcp.nb_iw, icc, ic_chunks); + } + amx_tile_release(); + }); return status::success; } @@ -2009,15 +2176,16 @@ void jit_avx512_core_amx_convolution_bwd_weights_t::execute_backward_weights( const exec_ctx_t &ctx) const { prepare_scratchpad_data(ctx); - auto tcfg = ctx.get_scratchpad_grantor().template get( + auto global_tcfg = ctx.get_scratchpad_grantor().template get( key_conv_amx_tilecfg); - kernel_->tile_configure(tcfg); const auto &jcp = pd()->jcp_; parallel(nthr_, [&](const int ithr, const int nthr) { assert(nthr_ == nthr); assert(utils::one_of(pd()->ndims(), 3, 4, 5)); + char *const __restrict tcfg = global_tcfg + ithr * AMX_PALETTE_SIZE; + kernel_->tile_configure(tcfg); amx_tile_configure(tcfg); thread_info_t thread_info(this, ctx, ithr); diff --git a/src/cpu/x64/jit_avx512_core_amx_convolution.hpp b/src/cpu/x64/jit_avx512_core_amx_convolution.hpp index 7cb739244cf..5e5c766029d 100644 --- a/src/cpu/x64/jit_avx512_core_amx_convolution.hpp +++ b/src/cpu/x64/jit_avx512_core_amx_convolution.hpp @@ -30,7 +30,6 @@ #include "cpu/x64/cpu_barrier.hpp" #include "cpu/x64/cpu_reducer.hpp" #include "cpu/x64/jit_avx512_core_amx_conv_kernel.hpp" -#include "cpu/x64/jit_avx512_core_scale_precompute.hpp" #include "cpu/x64/jit_transpose_utils.hpp" namespace dnnl { @@ -124,7 +123,6 @@ struct jit_avx512_core_amx_convolution_fwd_t : public primitive_t { const memory_tracking::grantor_t &scratchpad) const; std::unique_ptr kernel_; - std::unique_ptr jit_scale_precompute_; }; struct 
jit_avx512_core_amx_convolution_bwd_data_t : public primitive_t { @@ -175,19 +173,6 @@ struct jit_avx512_core_amx_convolution_bwd_data_t : public primitive_t { pd()->jcp_, *pd()->attr()))); CHECK(kernel_->create_kernel()); - // JIT to precompute scales - const bool is_jit_supported = mayiuse(avx512_core); - const auto attr = pd()->attr(); - const auto &attr_scales = attr->scales_; - if (is_jit_supported && pd()->OC() > 1 - && req_copy_scales(attr_scales)) { - int wei_scale_mask = attr_scales.get_mask(DNNL_ARG_WEIGHTS); - if (wei_scale_mask > 0) { - CHECK(safe_ptr_assign(jit_scale_precompute_, - new jit_avx512_core_scale_precompute_t(attr))); - CHECK(jit_scale_precompute_->create_kernel()); - } - } return status::success; } @@ -205,7 +190,6 @@ struct jit_avx512_core_amx_convolution_bwd_data_t : public primitive_t { const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); } std::unique_ptr kernel_; - std::unique_ptr jit_scale_precompute_; }; struct jit_avx512_core_amx_convolution_bwd_weights_t : public primitive_t { diff --git a/src/cpu/x64/jit_avx512_core_amx_deconvolution.cpp b/src/cpu/x64/jit_avx512_core_amx_deconvolution.cpp index cabd338661d..993ab1b86fe 100644 --- a/src/cpu/x64/jit_avx512_core_amx_deconvolution.cpp +++ b/src/cpu/x64/jit_avx512_core_amx_deconvolution.cpp @@ -31,6 +31,7 @@ namespace cpu { namespace x64 { using namespace dnnl::impl::memory_tracking::names; +using namespace dnnl::impl::utils; #define wht_blk_off(d, g, ...) \ (pd()->with_groups() ? (d).blk_off((g), __VA_ARGS__) \ @@ -74,22 +75,190 @@ status_t jit_avx512_core_amx_deconvolution_fwd_t::execute_forward( prepare_padded_bias(bias, ctx.get_scratchpad_grantor()); - DEFINE_ARG_SCALES_BUFFER(src_scales, DNNL_ARG_SRC); - DEFINE_ARG_SCALES_BUFFER(wei_scales, DNNL_ARG_WEIGHTS); - DEFINE_ARG_SCALES_BUFFER(dst_scales, DNNL_ARG_DST); + const void *src_scales + = CTX_IN_MEM(const void *, DNNL_ARG_ATTR_SCALES | DNNL_ARG_SRC); + const void *wei_scales + = CTX_IN_MEM(const void *, DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS); + const void *dst_scales + = CTX_IN_MEM(const void *, DNNL_ARG_ATTR_SCALES | DNNL_ARG_DST); - const int wei_scale_mask = pd()->attr()->scales_.get_mask(DNNL_ARG_WEIGHTS); - const float *oscales = precompute_scales(ctx.get_scratchpad_grantor(), - src_scales, wei_scales, src_d.dims()[1], dst_d.dims()[1], false, - wei_scale_mask > 0, pd()->attr()); + const auto &jcp = pd()->jcp_; + assert(jcp.nb_ic % jcp.nb_ic_blocking == 0); - // The body of bwd/d convolution harness is called with: - // 1. src as input instead of diff_dst - // 2. 
dst as output instead of diff_src - amx_utils::execute_backward_convolution_body(ctx, pd()->jcp_, kernel_, src, - weights, bias, oscales, dst_scales, dst, src_d, weights_d, bias_d, - dst_d); + const size_t src_dt_size = jcp.typesize_in; + const size_t wei_dt_size = jcp.typesize_in; + const size_t bia_dt_size = jcp.typesize_bia; + const size_t dst_dt_size = jcp.typesize_out; + const dim_t wei_g_shift = wht_blk_off(weights_d, 1, 0); + const dim_t wei_ic_shift = wht_blk_off(weights_d, 0, jcp.nb_ic_blocking); + const size_t wht_d_stride = wht_blk_off(weights_d, 0, 0, 0, 1); + + auto inp_p_buffer = ctx.get_scratchpad_grantor().template get( + key_conv_amx_inp_buffer); + auto wsp = ctx.get_scratchpad_grantor().template get( + key_conv_amx_wsp_buffer); + auto global_tcfg = ctx.get_scratchpad_grantor().template get( + key_conv_amx_tilecfg); + + const int ic_chunks = jcp.nb_ic / jcp.nb_ic_blocking; + const int ih_chunks = utils::div_up(jcp.ih, jcp.ih_blk_size); + const int work_amount + = jcp.mb * jcp.ngroups * jcp.id * ih_chunks * jcp.nb_iw * ic_chunks; + + const bool is_1d = jcp.ndims == 3; + const bool is_3d = jcp.ndims == 5; + + parallel(jcp.nthr, [&](const int ithr, const int nthr) { + int start {0}, end {0}; + balance211(work_amount, nthr, ithr, start, end); + + auto p = jit_conv_args_t(); + + char *const __restrict tcfg = global_tcfg + ithr * AMX_PALETTE_SIZE; + kernel_->tile_configure(tcfg); + amx_tile_configure(tcfg); + + amx_utils::spatial_features_3d_t sfd(jcp); + + float *dst_scales_inv_ptr = nullptr; + if (jcp.with_dst_scales) { + const float *dst_scales_ptr + = static_cast(dst_scales); + dst_scales_inv_ptr + = ctx.get_scratchpad_grantor().template get( + key_conv_dst_scales) + + ithr; + dst_scales_inv_ptr[0] = 1.f / dst_scales_ptr[0]; + } + + int mb {0}, g {0}, id_s {0}, ihc {0}, iwb {0}, icc {0}; + nd_iterator_init(start, mb, jcp.mb, g, jcp.ngroups, id_s, jcp.id, ihc, + ih_chunks, iwb, jcp.nb_iw, icc, ic_chunks); + int last_copied_mb = -1; + int last_copied_id = -1; + int last_copied_ihc = -1; + int last_copied_iwb = -1; + int last_copied_g = -1; + while (start < end) { + char *inp_buffer + = inp_p_buffer + ithr * jcp.inp_buffer_size * src_dt_size; + + assert(IMPLICATION( + jcp.ngroups > 1, jcp.ic == jcp.ic_without_padding)); + int ic = g * jcp.ic + icc * jcp.nb_ic_blocking * jcp.ic_block; + int icb = jcp.is_nspc ? ic : ic / jcp.ic_block; + assert(IMPLICATION( + jcp.ngroups > 1, jcp.oc == jcp.oc_without_padding)); + const int ocb = g * (jcp.is_nspc ? jcp.oc : jcp.nb_oc); + auto bias_w = bias ? 
bias + (bias_d.blk_off(ic) * bia_dt_size) + : nullptr; + + const int ih_b = ihc * jcp.ih_blk_size; + const int ih_e = nstl::min(jcp.ih, ih_b + jcp.ih_blk_size); + const int iw = iwb * jcp.iw_block; + bool is_inp_buffer_relevant = true && last_copied_mb == mb + && last_copied_id == id_s && last_copied_ihc == ihc + && last_copied_iwb == iwb && last_copied_g == g; + + sfd.update_params(id_s); + p.kd_padding = sfd.get_filter_padding(); + const int d_lo = sfd.get_lower_offset(); + const int d_oj = sfd.get_output_offset(); + + int ih_step = jcp.nb_ih_blocking; + for (int ih = ih_b; ih < ih_e; ih += ih_step) { + if (!is_inp_buffer_relevant) { + const int gen_kh = (jcp.kh - 1) * (jcp.dilate_h + 1) + 1; + const int gen_kw = (jcp.kw - 1) * (jcp.dilate_w + 1) + 1; + // dox: x-index dilated by strides (dox = ox * stride_x) + const int doh = ih + jcp.t_pad - (gen_kh - 1); + const int dow = iw + jcp.l_pad - (gen_kw - 1); + const int doh_b = ih_b + jcp.t_pad - (gen_kh - 1); + const int doh_l = (jcp.oh - 1) * jcp.stride_h; // last oh + const int dow_l = (jcp.ow - 1) * jcp.stride_w; // last ow + + // dox_{s,f}: start and finish indices for copy kernel + const int doh_s = doh + (ih == ih_b ? 0 : gen_kh - 1); + const int doh_f = doh + (ih_step - 1) + (gen_kh - 1); + const int delta_h = doh_f - doh_s + 1; + const int doh_t_overflow = 0 < doh_s && doh_s < doh_l + ? nstl::additive_inverse_modulo(doh_s, jcp.stride_h) + : nstl::max(0, -doh_s); + const int doh_b_overflow = 0 < doh_f && doh_f < doh_l + ? nstl::modulo(doh_f, jcp.stride_h) + : nstl::max(0, nstl::min(delta_h, doh_f - doh_l)); + int dow_s = dow; + int dow_f = dow + jcp.owp - 1; + const int delta_w = dow_f - dow_s + 1; + const int dow_l_overflow = 0 < dow_s && dow_s < dow_l + ? nstl::additive_inverse_modulo(dow_s, jcp.stride_w) + : nstl::max(0, -dow_s); + const int dow_r_overflow = 0 < dow_f && dow_f < dow_l + ? nstl::modulo(dow_f, jcp.stride_w) + : nstl::max(0, nstl::min(delta_w, dow_f - dow_l)); + const int oh_s + = nstl::max(0, utils::div_up(doh_s, jcp.stride_h)); + const int ow_s + = nstl::max(0, utils::div_up(dow_s, jcp.stride_w)); + // how many real data rows to copy (including padding) + p.t_overflow = nstl::min(delta_h, doh_t_overflow); + p.b_overflow = nstl::min( + delta_h - p.t_overflow, doh_b_overflow); + p.kh_padding = nstl::max( + 0, delta_h - p.t_overflow - p.b_overflow); + p.l_overflow = nstl::min(delta_w, dow_l_overflow); + p.kw_padding = nstl::max( + 0, delta_w - dow_l_overflow - dow_r_overflow); + p.r_overflow = nstl::min( + delta_w - dow_l_overflow, dow_r_overflow); + size_t inp_offset = is_1d ? src_d.blk_off(mb, ocb, ow_s) + : is_3d ? src_d.blk_off(mb, ocb, d_oj, oh_s, ow_s) + : src_d.blk_off(mb, ocb, oh_s, ow_s); + p.src = src + src_dt_size * inp_offset; + p.dst = inp_buffer + + (size_t)(doh_s - doh_b) * jcp.owp + * jcp.oc_block_int * src_dt_size; + + kernel_->bwd_data_copy_kernel()(&p); + } + + size_t dst_offset = is_1d ? dst_d.blk_off(mb, icb, iw) + : is_3d ? dst_d.blk_off(mb, icb, id_s, ih, iw) + : dst_d.blk_off(mb, icb, ih, iw); + p.dst = inp_buffer + + (size_t)(ih - ih_b) * jcp.owp * jcp.oc_block_int + * src_dt_size; + p.src = dst + dst_dt_size * dst_offset; + p.filt = weights + + wei_dt_size + * (g * wei_g_shift + icc * wei_ic_shift + + d_lo * wht_d_stride); + p.bias = bias_w; + p.src_scales = src_scales; + p.wei_scales = jcp.with_wei_scales + ? 
static_cast(wei_scales) + + jcp.is_ic_scale * ic + : nullptr; + p.dst_scales = dst_scales_inv_ptr; + p.acc_s32 = wsp + ithr * jcp.wsp_buffer_size; + p.last_h = (ih + ih_step <= ih_e); + p.iwb = iwb; + p.ic_blocks = icc * jcp.nb_ic_blocking; + + (*kernel_)(&p); + } + last_copied_mb = mb; + last_copied_id = id_s; + last_copied_ihc = ihc; + last_copied_iwb = iwb; + last_copied_g = g; + ++start; + nd_iterator_step(mb, jcp.mb, g, jcp.ngroups, id_s, jcp.id, ihc, + ih_chunks, iwb, jcp.nb_iw, icc, ic_chunks); + } + amx_tile_release(); + }); return status::success; } diff --git a/src/cpu/x64/jit_avx512_core_scale_precompute.cpp b/src/cpu/x64/jit_avx512_core_scale_precompute.cpp deleted file mode 100644 index fbbe4ec1f17..00000000000 --- a/src/cpu/x64/jit_avx512_core_scale_precompute.cpp +++ /dev/null @@ -1,203 +0,0 @@ -/******************************************************************************* -* Copyright 2024-2025 Intel Corporation -* -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. -* You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. -*******************************************************************************/ - -#include - -#include "common/memory_tracking.hpp" - -#include "cpu/x64/cpu_isa_traits.hpp" -#include "cpu/x64/jit_avx512_core_scale_precompute.hpp" -#include "cpu/x64/jit_generator.hpp" - -namespace dnnl { -namespace impl { -namespace cpu { -namespace x64 { - -using namespace Xbyak; - -namespace scale_utils { - -const float *precompute_scales(const memory_tracking::grantor_t &scratchpad, - const float *src_scales, const float *wei_scales, dim_t IC, dim_t OC, - const bool wei_scale_per_ic, const bool wei_scale_per_oc, - const primitive_attr_t *attr, - const jit_avx512_core_scale_precompute_t *const jit_scale_precompute, - float scale_adjust_factor, bool req_transpose) { - - const float *scales = nullptr; - const dim_t wei_scale_count - = (wei_scale_per_ic ? IC : 1) * (wei_scale_per_oc ? OC : 1); - ; - if (jit_scale_precompute) { - const auto &attr_scales = attr->scales_; - const int wei_scale_mask = attr_scales.get_mask(DNNL_ARG_WEIGHTS); - size_t size = 0; - auto loc_scales = scratchpad.template get( - memory_tracking::names::key_precomputed_scales, &size); - const dim_t count = nstl::min( - static_cast(size / sizeof(float)), wei_scale_count); - const auto wei_scale_stride_ic - = wei_scale_per_ic ? wei_scale_per_oc ? OC : 1 : 0; - const auto with_wei_scale - = !attr_scales.has_default_values(DNNL_ARG_WEIGHTS); - const auto wei_scale_has_groups = with_wei_scale - && !attr_scales.get(DNNL_ARG_WEIGHTS).has_default_groups(); - const auto wei_scale_group_stride = wei_scale_has_groups - ? 
wei_scale_stride_ic * sizeof(float) - : 0; - - // JIT run-time params - jit_call_t jrp(src_scales, wei_scales, loc_scales, count, - wei_scale_group_stride); - - assert(req_copy_scales(attr_scales, scale_adjust_factor)); - assert(mayiuse(avx512_core)); - assert(wei_scale_mask > 0); - if (wei_scale_has_groups) { - assert(count == wei_scale_count); - const auto wei_scale_groups_ic - = attr_scales.get_group(DNNL_ARG_WEIGHTS, 0); - const dim_t wei_scale_nb_ic = IC / wei_scale_groups_ic; - const auto wei_scale_dt_sz = types::data_type_size( - attr_scales.get_data_type(DNNL_ARG_WEIGHTS)); - for (int nb_ic = 0; nb_ic < wei_scale_nb_ic; nb_ic++) { - const auto offset = nb_ic * wei_scale_stride_ic; - jrp.nelems_ = wei_scale_stride_ic; - jrp.wei_scales_ = (char *)wei_scales + offset * wei_scale_dt_sz; - jrp.scales_ = &loc_scales[offset * wei_scale_groups_ic]; - (*jit_scale_precompute)(&jrp); - } - } else - (*jit_scale_precompute)(&jrp); - - scales = loc_scales; - MAYBE_UNUSED(wei_scale_mask); - } else - scales = cpu::precompute_scales(scratchpad, src_scales, wei_scales, IC, - OC, wei_scale_per_ic, wei_scale_per_oc, attr, - scale_adjust_factor, req_transpose); - return scales; -} - -} // namespace scale_utils - -#define GET_OFF(field) offsetof(scale_utils::jit_call_t, field) - -void jit_avx512_core_scale_precompute_t::store( - const int offset_base, const bool compute_tail) { - mov(reg_aux_dst_scales_, reg_dst_scales_); - const auto addr_offset = static_cast(offset_base) * sizeof(float); - const Vmm vmm_m_dst = compute_tail ? vmm_dst_ | ktail_f32_mask_ : vmm_dst_; - for (size_t g = 0; g < wei_groups_ic_; g++) { - vmovups(ptr[reg_aux_dst_scales_ + addr_offset], vmm_m_dst); - add(reg_aux_dst_scales_, reg_groups_stride_); - } -} - -void jit_avx512_core_scale_precompute_t::cvt2ps(data_type_t type_in, - const Vmm vmm_in, const Xbyak::Operand &op, - const bool mask_flag = false) { - const Vmm vmm = mask_flag ? vmm_in | ktail_f32_mask_ | T_z : vmm_in; - switch (type_in) { - case data_type::f32: vmovups(vmm, op); break; - case data_type::f16: vcvtph2psx(vmm, op); break; - case data_type::bf16: - vpmovzxwd(vmm, op); - vpslld(vmm_in, vmm_in, 0x10); - break; - default: assert(!"unsupported data type"); - } -} - -void jit_avx512_core_scale_precompute_t::compute_scale( - const int offset_base, const bool compute_tail) { - const size_t wei_addr_offset - = static_cast(offset_base) * wei_scales_dsz_; - const Vmm vmm_m_wei_scales = compute_tail - ? 
vmm_wei_scales_ | ktail_f32_mask_ | T_z - : vmm_wei_scales_; - - cvt2ps(wei_scales_dt_, vmm_wei_scales_, - ptr[reg_wei_scales_ + wei_addr_offset], compute_tail); - if (compute_scale_factor_) - vmulps(vmm_m_wei_scales, vmm_scale_factor_, vmm_m_wei_scales); - - vmulps(vmm_dst_, vmm_m_wei_scales, ptr_b[reg_src_scales_]); - store(offset_base, compute_tail); -} - -void jit_avx512_core_scale_precompute_t::setup_mask() { - mov(reg_mask_, 1); - shl(reg_mask_, reg_tail_.cvt8()); - sub(reg_mask_, 1); - kmovw(ktail_f32_mask_, reg_mask_); -} - -void jit_avx512_core_scale_precompute_t::generate() { - - preamble(); - - // get params - mov(reg_src_scales_, ptr[abi_param1 + GET_OFF(src_scales_)]); - mov(reg_wei_scales_, ptr[abi_param1 + GET_OFF(wei_scales_)]); - mov(reg_dst_scales_, ptr[abi_param1 + GET_OFF(scales_)]); - mov(reg_nelems_, ptr[abi_param1 + GET_OFF(nelems_)]); - mov(reg_groups_stride_, ptr[abi_param1 + GET_OFF(stride_per_groups_)]); - if (compute_scale_factor_) { - const Xmm xmm_scale_factor_(vmm_scale_factor_.getIdx()); - mov(reg_scale_factor_, float2int(scale_adjust_factor_)); - vmovq(xmm_scale_factor_, reg_scale_factor_); - vbroadcastss(vmm_scale_factor_, xmm_scale_factor_); - } - - constexpr int n_unroll = 2; - Xbyak::Label l_simd_loop[n_unroll + 2], l_done; - for (int i = n_unroll; i >= 0; i--) { - const int unroll = 1 << i; // 4, 2, 1 - const size_t addr_step = static_cast(simd_w_) * unroll; - L(l_simd_loop[i + 1]); - { - cmp(reg_nelems_, addr_step); - jl(l_simd_loop[i], T_NEAR); - for (int offset_base = 0; offset_base < unroll; offset_base++) { - compute_scale(offset_base * simd_w_, false); - } - add(reg_wei_scales_, addr_step * wei_scales_dsz_); - add(reg_dst_scales_, addr_step * sizeof(float)); - sub(reg_nelems_, addr_step); - jmp(l_simd_loop[i + 1], T_NEAR); - } - } - L(l_simd_loop[0]); - - test(reg_nelems_, reg_nelems_); - jz(l_done, T_NEAR); - - mov(reg_tail_, reg_nelems_); - setup_mask(); - - compute_scale(0, true); - - L(l_done); - - postamble(); -} - -} // namespace x64 -} // namespace cpu -} // namespace impl -} // namespace dnnl diff --git a/src/cpu/x64/jit_avx512_core_scale_precompute.hpp b/src/cpu/x64/jit_avx512_core_scale_precompute.hpp deleted file mode 100644 index 66c2a80663d..00000000000 --- a/src/cpu/x64/jit_avx512_core_scale_precompute.hpp +++ /dev/null @@ -1,130 +0,0 @@ -/******************************************************************************* -* Copyright 2024-2025 Intel Corporation -* -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. -* You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. 
-*******************************************************************************/ - -#ifndef CPU_X64_JIT_AVX512_CORE_SCALE_PRECOMPUTE_HPP -#define CPU_X64_JIT_AVX512_CORE_SCALE_PRECOMPUTE_HPP - -#include - -#include "common/c_types_map.hpp" -#include "common/nstl.hpp" -#include "common/type_helpers.hpp" - -#include "cpu/scale_utils.hpp" - -#include "cpu/x64/cpu_isa_traits.hpp" -#include "cpu/x64/jit_generator.hpp" - -#include "oneapi/dnnl/dnnl_debug.h" - -namespace dnnl { -namespace impl { -namespace cpu { -namespace x64 { - -struct jit_avx512_core_scale_precompute_t; - -namespace scale_utils { -struct jit_call_t { - jit_call_t(const float *src_scales, const float *wei_scales, float *scales, - size_t nelems, size_t stride_per_groups) - : src_scales_(src_scales) - , wei_scales_(wei_scales) - , scales_(scales) - , nelems_(nelems) - , stride_per_groups_(stride_per_groups) {} - - const void *src_scales_; - const void *wei_scales_; - float *scales_; - size_t nelems_; - size_t stride_per_groups_; -}; - -const float *precompute_scales(const memory_tracking::grantor_t &scratchpad, - const float *src_scales, const float *wei_scales, dim_t IC, dim_t OC, - const bool wei_scale_per_ic, const bool wei_scale_per_oc, - const primitive_attr_t *attr, - const jit_avx512_core_scale_precompute_t *const jit_scale_precompute, - float scale_adjust_factor = 1.0f, bool req_transpose = false); -} // namespace scale_utils - -struct jit_avx512_core_scale_precompute_t : public jit_generator_t { - - DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx512_core_scale_precompute_t) - - jit_avx512_core_scale_precompute_t( - const primitive_attr_t *attr, const float scale_adjust_factor = 1) - : jit_generator_t(jit_name()) - , attr_(attr) - , with_wei_scales_(!attr_->scales_.has_default_values(DNNL_ARG_WEIGHTS)) - , wei_scales_dt_(with_wei_scales_ - ? 
attr_->scales_.get_data_type(DNNL_ARG_WEIGHTS)
-                          : data_type::f32)
-        , wei_scales_dsz_(types::data_type_size(wei_scales_dt_))
-        , wei_groups_ic_(attr_->scales_.get_group(DNNL_ARG_WEIGHTS, 0))
-        , scale_adjust_factor_(scale_adjust_factor)
-        , compute_scale_factor_(scale_adjust_factor_ != 1) {}
-
-    void generate() override;
-
-    void operator()(scale_utils::jit_call_t *params) const {
-        jit_generator_t::operator()(params);
-        msan_unpoison(params->scales_, params->nelems_ * sizeof(float));
-    }
-
-private:
-    constexpr static int simd_w_
-            = cpu_isa_traits_t<avx512_core>::vlen / sizeof(float);
-    using Vmm = typename cpu_isa_traits_t<avx512_core>::Vmm;
-
-    const primitive_attr_t *attr_;
-    const bool with_wei_scales_;
-    const data_type_t wei_scales_dt_;
-    const size_t wei_scales_dsz_;
-    const size_t wei_groups_ic_;
-    const float scale_adjust_factor_;
-    const bool compute_scale_factor_;
-
-    Xbyak::Reg64 reg_src_scales_ = r15;
-    Xbyak::Reg64 reg_wei_scales_ = r14;
-    Xbyak::Reg64 reg_dst_scales_ = r13;
-    Xbyak::Reg64 reg_scale_factor_ = r12;
-    Xbyak::Reg64 reg_nelems_ = r11;
-    Xbyak::Reg64 reg_groups_stride_ = r10;
-    Xbyak::Reg64 reg_aux_dst_scales_ = r9;
-    Xbyak::Reg64 reg_tail_ = rcx;
-    Xbyak::Reg32 reg_mask_ = eax;
-
-    const Xbyak::Opmask ktail_f32_mask_ = Xbyak::Opmask(1);
-
-    const Vmm vmm_dst_ = Vmm(0);
-    const Vmm vmm_wei_scales_ = Vmm(1);
-    const Vmm vmm_scale_factor_ = Vmm(2);
-
-    void setup_mask();
-    void store(const int offset_base, const bool compute_tail);
-    void cvt2ps(data_type_t type_in, const Vmm vmm_in, const Xbyak::Operand &op,
-            const bool mask_flag);
-    void compute_scale(const int offset_base, const bool compute_tail);
-};
-
-} // namespace x64
-} // namespace cpu
-} // namespace impl
-} // namespace dnnl
-
-#endif // CPU_X64_JIT_AVX512_CORE_SCALE_PRECOMPUTE_HPP
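
Reviewer note, not part of the patch: with the precompute path deleted above, the kernel now applies all three scale factors itself — the common src scale as a broadcast multiply, the weights scale per input channel when jcp.is_ic_scale is set, and the inverse of the dst scale (computed once per thread into the key_conv_dst_scales scratchpad) after the post-ops. Below is a minimal scalar sketch of that flow, assuming the order used in store_output_vector_int8(); the function and parameter names are illustrative, not oneDNN API.

#include <cstddef>
#include <cstdio>

// Scalar model of the in-kernel scale handling (illustrative names only).
static float apply_scales(float acc, float bias, const float *src_scale,
        const float *wei_scales, std::size_t ic, bool per_ic_wei,
        const float *dst_scale_inv) {
    if (src_scale) acc *= src_scale[0]; // with_src_scales: broadcast multiply
    if (wei_scales) acc *= wei_scales[per_ic_wei ? ic : 0]; // with_wei_scales
    acc += bias; // bias is added after the src/wei scales
    // ... sum/eltwise post-ops run at this point in the kernel ...
    if (dst_scale_inv) acc *= dst_scale_inv[0]; // 1.f / dst_scale, per thread
    return acc; // then saturated and converted to jcp.dsrc_dt
}

int main() {
    const float src = 0.5f, wei[2] = {0.25f, 2.f}, dst_inv = 1.f / 0.125f;
    // Per-ic weights scale, ic = 1: 100 * 0.5 * 2.0 * 8.0 = 800
    std::printf("%g\n", apply_scales(100.f, 0.f, &src, wei, 1, true, &dst_inv));
    return 0;
}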