Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 37 additions & 18 deletions src/cpu/x64/jit_avx512_core_amx_conv_kernel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3325,7 +3325,6 @@ void jit_avx512_core_amx_bwd_data_kernel_t::store_output_vector_int8(
mov(reg_ptr_sum_zp, reinterpret_cast<size_t>(p_sum_zp));
}

int scale_offset = jcp.is_ic_scale * (sizeof(float) * icb * ic_block);
if (jcp.with_bias) {
int bias_offset = jcp.typesize_bia * icb * ic_block;
auto bias_addr = EVEX_compress_addr(reg_bias, bias_offset);
Expand All @@ -3335,8 +3334,22 @@ void jit_avx512_core_amx_bwd_data_kernel_t::store_output_vector_int8(
/* add bias to zmm_accum */
vcvtdq2ps(zmm_out, zmm_out);
const Zmm zmm_out_msk = zmm_mask(zmm_out, mask_flag);
vmulps(zmm_out_msk, zmm_out,
EVEX_compress_addr(reg_ptr_scales, scale_offset));

if (jcp.with_src_scales) {
mov(reg_ptr_src_scales, ptr[param1 + GET_OFF(src_scales)]);
vmulps(zmm_out_msk, zmm_out,
EVEX_compress_addr(reg_ptr_src_scales, 0, /* bcast = */ true));
}

if (jcp.with_wei_scales) {
mov(reg_ptr_wei_scales, ptr[param1 + GET_OFF(wei_scales)]);
const int scale_offset
= jcp.is_ic_scale * (sizeof(float) * icb * ic_block);
vmulps(zmm_out_msk, zmm_out,
EVEX_compress_addr(reg_ptr_wei_scales, scale_offset,
/* bcast = */ !jcp.is_ic_scale));
}

if (jcp.with_bias) vaddps(zmm_out, zmm_out, zmm_bias);

/* Do post-ops */
Expand All @@ -3354,7 +3367,11 @@ void jit_avx512_core_amx_bwd_data_kernel_t::store_output_vector_int8(
}
if (maybe_eltwise(1)) eltwise_injector_->compute_vector(zmm_out.getIdx());

if (jcp.dst_scale) { vmulps(zmm_out_msk, zmm_out, zmm_dst_scale); }
if (jcp.with_dst_scales) {
mov(reg_ptr_dst_scales, ptr[param1 + GET_OFF(dst_scales)]);
vmulps(zmm_out_msk, zmm_out,
EVEX_compress_addr(reg_ptr_dst_scales, 0, /* bcast = */ true));
}

// Properly saturate the accumulators for integer datatypes
if (one_of(jcp.dsrc_dt, u8, s8, s32)) {
Expand Down Expand Up @@ -3643,12 +3660,6 @@ void jit_avx512_core_amx_bwd_data_kernel_t::generate() {

if (jcp.with_bias) mov(reg_bias, ptr[param1 + GET_OFF(bias)]);

if (jcp.dst_scale) {
mov(reg_ptr_dst_scales, ptr[param1 + GET_OFF(dst_scale)]);
vmovups(zmm_dst_scale, EVEX_compress_addr(reg_ptr_dst_scales, 0));
}
mov(reg_ptr_scales, ptr[param1 + GET_OFF(scales)]);

mov(reg_last_h, ptr[param1 + GET_OFF(last_h)]);

const int inp_stride = jcp.oc_block_int * jcp.typesize_in;
Expand Down Expand Up @@ -4012,10 +4023,11 @@ status_t jit_avx512_core_amx_bwd_data_kernel_t::init_conf(jit_conv_conf_t &jcp,
jcp.wsp_buffer_size = (size_t)jcp.nb_ih_blocking * jcp.nb_ic_blocking
* jcp.full_tile_width * jcp.ic_block;

const auto &wei_scales = attr.scales_.get(DNNL_ARG_WEIGHTS);
const auto &dst_scales = attr.scales_.get(DNNL_ARG_DST);
jcp.is_ic_scale = wei_scales.get_mask() > 0;
jcp.dst_scale = !dst_scales.has_default_values();
jcp.is_ic_scale = attr.scales_.get_mask(DNNL_ARG_WEIGHTS) > 0;
jcp.with_src_scales = !attr.scales_.get(DNNL_ARG_SRC).has_default_values();
jcp.with_wei_scales
= !attr.scales_.get(DNNL_ARG_WEIGHTS).has_default_values();
jcp.with_dst_scales = !attr.scales_.get(DNNL_ARG_DST).has_default_values();

return status::success;
}
Expand All @@ -4032,10 +4044,15 @@ void jit_avx512_core_amx_bwd_data_kernel_t::init_scratchpad(
assert(jcp.ngroups == 1);
scratchpad.book(key_conv_padded_bias, jcp.ic, jcp.typesize_bia);
}
scratchpad.book(key_conv_amx_tilecfg, 1, 64); // 1 whole cacheline
// One cache-line for each thread for a palette.
scratchpad.book(key_conv_amx_tilecfg, jcp.nthr * AMX_PALETTE_SIZE,
sizeof(char), 0, PAGE_4K);

book_precomputed_scales(
scratchpad, attr.scales_, jcp.ngroups * jcp.ic_without_padding);
if (jcp.with_dst_scales) {
// See brgemm_types.hpp comment for `with_dst_scales`.
scratchpad.book(key_conv_dst_scales,
static_cast<size_t>(jcp.nthr) * sizeof(float), 4096);
}
}

const int jit_avx512_core_amx_bwd_weights_kernel_t::max_ur_w = 32;
Expand Down Expand Up @@ -5529,7 +5546,9 @@ status_t jit_avx512_core_amx_bwd_weights_kernel_t::init_scratchpad(
scratchpad.book(key_conv_padded_bias,
jcp.ngroups * jcp.nb_oc * jcp.oc_block, jcp.typesize_bia);
}
scratchpad.book(key_conv_amx_tilecfg, 1, 64); // 1 whole cacheline
// One cache-line for each thread for a palette.
scratchpad.book(key_conv_amx_tilecfg, jcp.nthr * AMX_PALETTE_SIZE,
sizeof(char), 0, PAGE_4K);

constexpr size_t scratchpad_limit_by_absolute_value = (size_t)32
<< 30; // 32Gb - TODO: may it's too large?
Expand Down
5 changes: 2 additions & 3 deletions src/cpu/x64/jit_avx512_core_amx_conv_kernel.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -532,7 +532,8 @@ struct jit_avx512_core_amx_bwd_data_kernel_t : public jit_generator_t {
const Xbyak::Reg64 reg_wsp_ptr = r12;

const Xbyak::Reg64 reg_bias = r11;
const Xbyak::Reg64 reg_ptr_scales = r10;
const Xbyak::Reg64 reg_ptr_src_scales = r10;
const Xbyak::Reg64 reg_ptr_wei_scales = r10;
const Xbyak::Reg64 reg_ptr_dst_scales = r10;
const Xbyak::Reg64 reg_ptr_sum_scale = r9;
const Xbyak::Reg64 reg_ptr_sum_zp = abi_not_param1;
Expand All @@ -557,8 +558,6 @@ struct jit_avx512_core_amx_bwd_data_kernel_t : public jit_generator_t {
const Xbyak::Zmm zmm_zero = zmm30;
const Xbyak::Zmm zmm_prev_dst = zmm29;
const Xbyak::Zmm zmm_sum_zp = zmm28;
/* dst scale */
const Xbyak::Zmm &zmm_dst_scale = zmm27;

// AUX: Steps, shifts and offsets
size_t get_inp_ocb_step() const;
Expand Down
202 changes: 9 additions & 193 deletions src/cpu/x64/jit_avx512_core_amx_conv_utils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,20 +20,16 @@
#include "common/dnnl_thread.hpp"
#include "common/utils.hpp"

#include "cpu/x64/jit_avx512_core_amx_convolution.hpp"
#include "cpu/x64/jit_primitive_conf.hpp"

namespace dnnl {
namespace impl {
namespace cpu {
namespace x64 {
namespace amx_utils {

using namespace dnnl::impl::memory_tracking::names;
using namespace dnnl::impl::utils;

#define wht_blk_off(d, g, ...) \
(with_groups ? (d).blk_off((g), __VA_ARGS__) : (d).blk_off(__VA_ARGS__))

struct spatial_features_3d_t {

spatial_features_3d_t(const jit_conv_conf_t &jcp)
Expand Down Expand Up @@ -109,15 +105,14 @@ struct spatial_features_3d_t {
inline int get_output_offset() const { return output_offset_; }

private:
const int input_size_;
const int filter_size_;
const int dilate_;
const int stride_;
const int init_pad_; // f_pad
const int end_pad_; // back_pad
const bool is_fast_path_; // 'dilate_ == 1 && stride_ == 1'
const bool
compute_extended_features_; // eq. '(!is_fast_path_) && dilate_ == 1'
int input_size_;
int filter_size_;
int dilate_;
int stride_;
int init_pad_; // f_pad
int end_pad_; // back_pad
bool is_fast_path_; // 'dilate_ == 1 && stride_ == 1'
bool compute_extended_features_; // eq. '(!is_fast_path_) && dilate_ == 1'

int filter_;
int lower_offset_; // d_lo
Expand All @@ -127,185 +122,6 @@ struct spatial_features_3d_t {
int end_overflow_; // d_b_overflow
};

// Driver loop for the AMX backward-data convolution kernel. It is also used
// when jcp.prop_kind is not backward_data, which this code treats as the
// deconvolution case (see `is_deconv`).
//
// The iteration space (mb, groups, id, ih chunks, iw blocks, ic chunks) is
// flattened into `work_amount` items and split across `jcp.nthr` threads with
// balance211. For each item a thread:
//   1. stages the required diff_dst rows into its own slice of the
//      key_conv_amx_inp_buffer scratchpad via kernel->bwd_data_copy_kernel()
//      (skipped when mb/id/ih chunk/iw block/group match the previous item,
//      i.e. only the ic chunk changed), and
//   2. invokes the compute kernel once per step of jcp.nb_ih_blocking in ih.
//
// Parameters:
//   ctx        - execution context; only its scratchpad grantor is used here.
//   jcp        - convolution configuration.
//   kernel     - compute kernel; also provides tile_configure() and the
//                copy kernel.
//   diff_dst, weights, diff_src - raw byte pointers to the tensors.
//   bias       - raw byte pointer to the bias, may be null.
//   oscales    - forwarded as p.scales, offset by the input-channel index
//                when jcp.is_ic_scale is set.
//   dst_scales - only element 0 is forwarded (p.dst_scale).
//   *_d        - memory descriptors used to compute block offsets.
inline void execute_backward_convolution_body(const exec_ctx_t &ctx,
const jit_conv_conf_t &jcp,
const std::unique_ptr<jit_avx512_core_amx_bwd_data_kernel_t> &kernel,
const char *diff_dst, const char *weights, const char *bias,
const float *oscales, const float *dst_scales, char *diff_src,
const memory_desc_wrapper &diff_dst_d,
const memory_desc_wrapper &weights_d, const memory_desc_wrapper &bias_d,
const memory_desc_wrapper &diff_src_d) {
// ic blocks are processed in groups of nb_ic_blocking; no tail is handled.
assert(jcp.nb_ic % jcp.nb_ic_blocking == 0);

const bool is_deconv = jcp.prop_kind != prop_kind::backward_data;
// `with_groups` is read implicitly by the wht_blk_off macro.
const bool with_groups = weights_d.ndims() == diff_src_d.ndims() + 1;

const size_t diff_dst_dt_size = jcp.typesize_in;
const size_t wei_dt_size = jcp.typesize_in;
const size_t bia_dt_size = jcp.typesize_bia;
const size_t diff_src_dt_size = jcp.typesize_out;

// Weight offsets in elements (scaled by wei_dt_size where they are used).
// The position of the ic-block index in blk_off differs between the
// deconvolution and the plain backward-data case.
const dim_t wei_g_shift = wht_blk_off(weights_d, 1, 0);
const dim_t wei_ic_shift = is_deconv
? wht_blk_off(weights_d, 0, jcp.nb_ic_blocking)
: wht_blk_off(weights_d, 0, 0, jcp.nb_ic_blocking);
// NOTE(review): presumably the stride of one step along kernel depth -
// it is multiplied by d_lo below; confirm against the weights layout.
const size_t wht_d_stride = wht_blk_off(weights_d, 0, 0, 0, 1);

auto inp_p_buffer = ctx.get_scratchpad_grantor().template get<char>(
key_conv_amx_inp_buffer);
auto wsp = ctx.get_scratchpad_grantor().template get<int32_t>(
key_conv_amx_wsp_buffer);
auto tcfg = ctx.get_scratchpad_grantor().template get<char>(
key_conv_amx_tilecfg);

const int ic_chunks = jcp.nb_ic / jcp.nb_ic_blocking;
const int ih_chunks = utils::div_up(jcp.ih, jcp.ih_blk_size);
const int work_amount
= jcp.mb * jcp.ngroups * jcp.id * ih_chunks * jcp.nb_iw * ic_chunks;

// Initialize the tile configuration in memory, so that each thread can
// load this configuration from memory via `amx_tile_configure(tcfg)`.
// NOTE(review): tcfg is null-checked here but passed unconditionally to
// amx_tile_configure() inside the parallel region - confirm it cannot be
// null at that point.
if (tcfg) kernel->tile_configure(tcfg);
const bool is_1d = jcp.ndims == 3;
const bool is_3d = jcp.ndims == 5;

parallel(jcp.nthr, [&](const int ithr, const int nthr) {
int start {0}, end {0};
balance211(work_amount, nthr, ithr, start, end);

auto p = jit_conv_args_t();
amx_tile_configure(tcfg);
spatial_features_3d_t sfd(jcp);

int mb {0}, g {0}, id_s {0}, ihc {0}, iwb {0}, icc {0};
nd_iterator_init(start, mb, jcp.mb, g, jcp.ngroups, id_s, jcp.id, ihc,
ih_chunks, iwb, jcp.nb_iw, icc, ic_chunks);
// Coordinates of the last work item processed by this thread; used to
// decide whether the staged diff_dst buffer can be reused.
int last_copied_mb = -1;
int last_copied_id = -1;
int last_copied_ihc = -1;
int last_copied_iwb = -1;
int last_copied_g = -1;
while (start < end) {
// Per-thread slice of the staging buffer.
char *inp_buffer = inp_p_buffer
+ ithr * jcp.inp_buffer_size * diff_dst_dt_size;

assert(IMPLICATION(
jcp.ngroups > 1, jcp.ic == jcp.ic_without_padding));
// First input channel handled by this work item.
int ic = g * jcp.ic + icc * jcp.nb_ic_blocking * jcp.ic_block;
// For nspc layouts the channel index is used as-is; otherwise it is
// converted to a block index.
int icb = jcp.is_nspc ? ic : ic / jcp.ic_block;
assert(IMPLICATION(
jcp.ngroups > 1, jcp.oc == jcp.oc_without_padding));
const int ocb = g * (jcp.is_nspc ? jcp.oc : jcp.nb_oc);
auto bias_w = bias ? bias + (bias_d.blk_off(ic) * bia_dt_size)
: nullptr;

const int ih_b = ihc * jcp.ih_blk_size;
const int ih_e = nstl::min(jcp.ih, ih_b + jcp.ih_blk_size);
const int iw = iwb * jcp.iw_block;
// True when everything except the ic chunk matches the previous work
// item, so the rows already staged in inp_buffer are still valid.
bool is_inp_buffer_relevant = true && last_copied_mb == mb
&& last_copied_id == id_s && last_copied_ihc == ihc
&& last_copied_iwb == iwb && last_copied_g == g;

// Depth-dimension parameters for the current id_s.
sfd.update_params(id_s);
p.kd_padding = sfd.get_filter_padding();
const int d_lo = sfd.get_lower_offset();
const int d_oj = sfd.get_output_offset();

int ih_step = jcp.nb_ih_blocking;
for (int ih = ih_b; ih < ih_e; ih += ih_step) {
if (!is_inp_buffer_relevant) {
// Effective (dilated) kernel extents.
const int gen_kh = (jcp.kh - 1) * (jcp.dilate_h + 1) + 1;
const int gen_kw = (jcp.kw - 1) * (jcp.dilate_w + 1) + 1;
// dox: x-index dilated by strides (dox = ox * stride_x)
const int doh = ih + jcp.t_pad - (gen_kh - 1);
const int dow = iw + jcp.l_pad - (gen_kw - 1);
const int doh_b = ih_b + jcp.t_pad - (gen_kh - 1);
const int doh_l = (jcp.oh - 1) * jcp.stride_h; // last oh
const int dow_l = (jcp.ow - 1) * jcp.stride_w; // last ow

// dox_{s,f}: start and finish indices for copy kernel
// After the first step of a chunk the start is advanced by
// gen_kh - 1, skipping rows staged by the previous step.
const int doh_s = doh + (ih == ih_b ? 0 : gen_kh - 1);
const int doh_f = doh + (ih_step - 1) + (gen_kh - 1);
const int delta_h = doh_f - doh_s + 1;
// Overflow counts: inside (0, doh_l) they come from the stride
// remainder; outside they are the distance past the boundary.
const int doh_t_overflow = 0 < doh_s && doh_s < doh_l
? nstl::additive_inverse_modulo(doh_s, jcp.stride_h)
: nstl::max(0, -doh_s);
const int doh_b_overflow = 0 < doh_f && doh_f < doh_l
? nstl::modulo(doh_f, jcp.stride_h)
: nstl::max(0, nstl::min(delta_h, doh_f - doh_l));
int dow_s = dow;
int dow_f = dow + jcp.owp - 1;
const int delta_w = dow_f - dow_s + 1;
const int dow_l_overflow = 0 < dow_s && dow_s < dow_l
? nstl::additive_inverse_modulo(dow_s, jcp.stride_w)
: nstl::max(0, -dow_s);
const int dow_r_overflow = 0 < dow_f && dow_f < dow_l
? nstl::modulo(dow_f, jcp.stride_w)
: nstl::max(0, nstl::min(delta_w, dow_f - dow_l));
// First diff_dst row/column to read, clamped at zero.
const int oh_s
= nstl::max(0, utils::div_up(doh_s, jcp.stride_h));
const int ow_s
= nstl::max(0, utils::div_up(dow_s, jcp.stride_w));
// how many real data rows to copy (including padding)
p.t_overflow = nstl::min(delta_h, doh_t_overflow);
p.b_overflow = nstl::min<size_t>(
delta_h - p.t_overflow, doh_b_overflow);
p.kh_padding = nstl::max<size_t>(
0, delta_h - p.t_overflow - p.b_overflow);
p.l_overflow = nstl::min(delta_w, dow_l_overflow);
p.kw_padding = nstl::max<size_t>(
0, delta_w - dow_l_overflow - dow_r_overflow);
p.r_overflow = nstl::min<size_t>(
delta_w - dow_l_overflow, dow_r_overflow);
size_t inp_offset = is_1d
? diff_dst_d.blk_off(mb, ocb, ow_s)
: is_3d
? diff_dst_d.blk_off(mb, ocb, d_oj, oh_s, ow_s)
: diff_dst_d.blk_off(mb, ocb, oh_s, ow_s);
// Copy kernel arguments: read from diff_dst, write into the
// staging buffer at the row matching doh_s.
p.src = diff_dst + diff_dst_dt_size * inp_offset;
p.dst = inp_buffer
+ (size_t)(doh_s - doh_b) * jcp.owp
* jcp.oc_block_int * diff_dst_dt_size;

kernel->bwd_data_copy_kernel()(&p);
}

size_t diff_src_offset = is_1d ? diff_src_d.blk_off(mb, icb, iw)
: is_3d ? diff_src_d.blk_off(mb, icb, id_s, ih, iw)
: diff_src_d.blk_off(mb, icb, ih, iw);
// Compute kernel arguments: here p.dst points into the staged
// diff_dst buffer and p.src at the diff_src location.
// NOTE(review): the field names look swapped relative to the data
// flow - presumably the kernel reads p.dst and writes p.src;
// confirm against the kernel.
p.dst = inp_buffer
+ (size_t)(ih - ih_b) * jcp.owp * jcp.oc_block_int
* diff_dst_dt_size;
p.src = diff_src + diff_src_dt_size * diff_src_offset;
p.filt = weights
+ wei_dt_size
* (g * wei_g_shift + icc * wei_ic_shift
+ d_lo * wht_d_stride);
p.bias = bias_w;
p.scales = &oscales[jcp.is_ic_scale * ic];
p.dst_scale = &dst_scales[0];
// Per-thread slice of the int32 accumulator workspace.
p.acc_s32 = wsp + ithr * jcp.wsp_buffer_size;
// NOTE(review): true when this step does not extend past ih_e; the
// name suggests "last block" - confirm against the kernel's use.
p.last_h = (ih + ih_step <= ih_e);
p.iwb = iwb;
p.ic_blocks = icc * jcp.nb_ic_blocking;

(*kernel)(&p);
}
// Remember what is now staged so the next work item can reuse it.
last_copied_mb = mb;
last_copied_id = id_s;
last_copied_ihc = ihc;
last_copied_iwb = iwb;
last_copied_g = g;
++start;
nd_iterator_step(mb, jcp.mb, g, jcp.ngroups, id_s, jcp.id, ihc,
ih_chunks, iwb, jcp.nb_iw, icc, ic_chunks);
}
amx_tile_release();
});
}

#undef wht_blk_off

} // namespace amx_utils
} // namespace x64
} // namespace cpu
Expand Down
Loading