diff --git a/doc/primitives/matmul.md b/doc/primitives/matmul.md
index a39b9616e55..9bc293601f7 100644
--- a/doc/primitives/matmul.md
+++ b/doc/primitives/matmul.md
@@ -98,12 +98,11 @@ types for source, destination, weights, and bias tensors:
 | Source | Weights | Destination | Bias |
 |:-----------------|:---------------------------------------|:---------------------------------|:----------------------------|
 | f64 | f64 | f64 | f64, f32, f16, bf16, s8, u8 |
-| f32 | f32 | f32 | f32, bf16, f16, u8, s8 |
+| f32 | f32, u8, s8, u4, s4 | f32 | f32, bf16, f16, u8, s8 |
 | f16 | f16, u8, s8, u4, s4 | f16, u8, s8 | f32 |
-| f16 | f16, u8, s8 | f32 | f32, f16 |
+| f16 | f16, u8, s8, u4, s4 | f32, f16 | f32, f16 |
 | bf16 | bf16, u8, s8, u4, s4 | f32, bf16 | f32, bf16 |
-| f32, bf16, f16 | u8, s8 | f32, bf16, f16 | f32, bf16, f16 |
-| f32, bf16, f16 | u8, s8 | f32, bf16, f16 | f32, bf16, f16 |
+| f32, bf16, f16 | u8, s8, u4, s4 | f32, bf16, f16 | f32, bf16, f16 |
 | bf16, f16 | f8_e5m2, f8_e4m3, f4_e2m1, f4_e3m0 | f32, f16, bf16 | f32, bf16, f16 |
 | f8_e5m2, f8_e4m3 | f8_e5m2, f8_e4m3 | f32, f16, bf16, f8_e5m2, f8_e4m3 | f32, bf16, f16 |
 | f4_e2m1, f4_e3m0 | f4_e2m1, f4_e3m0 | f32, f16, bf16, f4_e2m1, f4_e3m0 | f32, bf16, f16 |
@@ -146,8 +145,8 @@ The following attributes and post-ops are supported:
 
 | Type | Operation | Description | Restrictions |
 |:----------|:---------------------------------------------------------------|:------------------------------------------------------------------------------|:------------------------------------|
-| Attribute | [Scales](@ref dnnl::primitive_attr::set_scales_mask) | Scales the result by given scaling factor(s) | |
-| Attribute | [Zero-points](@ref dnnl::primitive_attr::set_zero_points_mask) | Sets zero-point(s) for the corresponding tensors | `int8` computations only |
+| Attribute | [Scales](@ref dnnl::primitive_attr::set_scales_mask) | Scales the result by given scaling factor(s) | |
+| Attribute | [Zero-points](@ref dnnl::primitive_attr::set_zero_points_mask) | Sets zero-point(s) for the corresponding tensors | |
 | Attribute | [Dropout](@ref dnnl::primitive_attr::set_dropout) | Applies pseudo-random dropout to destination buffer, also fills mask buffer | |
 | Attribute | [Precomputed reductions](@ref dnnl::primitive_attr::set_precomputed_reductions) | Sets precomputed reductions for the corresponding tensors | Requires weight zero-points and full matrix mask |
 | Post-op | [Eltwise](@ref dnnl::post_ops::append_eltwise) | Applies an @ref dnnl_api_eltwise operation to the result | |
@@ -156,9 +155,13 @@ The following attributes and post-ops are supported:
 | Post-op | [Prelu](@ref dnnl::post_ops::append_prelu) | Applies an @ref dnnl_api_prelu operation to the result | |
 
 The following masks are supported by the primitive:
-- 0, which applies one scale / zero point value to an entire tensor, and
-- 2, which applies a scale value per column along the
-  `n`dimension for `DNNL_ARG_WEIGHTS`.
+- 0, which applies one scale / zero point value to an entire tensor
+- 1, which applies scale / zero point values along the `k`-dimension
+  for `DNNL_ARG_WEIGHTS`. Values can be grouped along this dimension
+  by specifying the scales / zero points shapes for the attribute
+  (see the code example @ref weights_decompression_matmul_cpp).
+- 2, which applies scale / zero point values per column along the
+  `n`-dimension for `DNNL_ARG_WEIGHTS`.
 
 When scales and/or zero-points masks are specified, the user must provide the
 corresponding scales and/or zero-points as additional
diff --git a/examples/tutorials/matmul/weights_decompression_matmul.cpp b/examples/tutorials/matmul/weights_decompression_matmul.cpp
index 905480b75a3..8679b8ec95f 100644
--- a/examples/tutorials/matmul/weights_decompression_matmul.cpp
+++ b/examples/tutorials/matmul/weights_decompression_matmul.cpp
@@ -78,7 +78,7 @@ void init_vector(std::vector<float> &v) {
 int number_of_runs = 1;
 
 // Create a MatMul primitive descriptor for the following op:
-// C_f32 = A_f32 * (B_s8 - zp_B) * sc_B[:]
+// C_f32 = A_f32 * (B_s8 - zp_B[:]) * sc_B[:]
 //
 // Here:
 // - Matrices A and C are of f32 data type.
@@ -96,15 +96,15 @@ matmul::primitive_desc matmul_pd_create(
     // Create attributes and indicate that the alpha and zero points are
     // runtime parameters
     primitive_attr attr;
-    // Set scales with multiple scales along K and N dimensions and with groups along K.
+    // Set scales with multiple values along K and N dimensions and with groups along K.
     attr.set_scales(DNNL_ARG_WEIGHTS, /* mask */ (1 << 0) + (1 << 1), {G, 1},
             memory::data_type::f32);
-    // Set a single zero point with s8 data type.
-    attr.set_zero_points(
-            DNNL_ARG_WEIGHTS, /* mask */ 0, {}, memory::data_type::s8);
+    // Set zero points with multiple values along K and N dimensions and with groups along K.
+    attr.set_zero_points(DNNL_ARG_WEIGHTS, /* mask */ (1 << 0) + (1 << 1),
+            {G, 1}, memory::data_type::s8);
     // Set fpmath mode with `apply_to_int=true` to apply fpmath mode behavior to
     // integral primitives (in this example, matmul).
-    attr.set_fpmath_mode(fpmath_mode::bf16, true);
+    attr.set_fpmath_mode(fpmath_mode::strict, true);
 
     // Create a MatMul primitive descriptor
     return matmul::primitive_desc(eng, a_md, b_md, c_md, attr);
@@ -136,7 +136,7 @@ void infer(const matmul &matmul_p, int64_t M, int64_t N, int64_t K, int64_t G,
     // De-quantization parameters (eg. Scale and Shift)
    const int64_t n_groups = K / G;
     memory sc_B_mem({{N, n_groups}, memory::data_type::f32, {1, N}}, eng);
-    memory zp_B_mem({{1}, memory::data_type::s8, {1}}, eng);
+    memory zp_B_mem({{N, n_groups}, memory::data_type::s8, {1, N}}, eng);
 
     // the function below fills dnnl::memory with some values
     // these memories, typically, come from the previous layers / operations
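
For reviewers who want to exercise the grouped decompression path end to end, below is a minimal standalone sketch (not part of the patch) that configures the attributes and runtime arguments the way the updated example does. The shapes `M`, `N`, `K`, `G`, the CPU engine, the plain `ab` layouts, and the uninitialized buffers are illustrative assumptions only; the API calls themselves follow the diff above.

```cpp
// Minimal sketch: s8 weights decompressed with per-group scales and zero
// points (assumed shapes and CPU engine; buffers left uninitialized).
#include <unordered_map>
#include "oneapi/dnnl/dnnl.hpp"

using namespace dnnl;

int main() {
    const memory::dim M = 128, N = 64, K = 256, G = 32; // G must divide K
    const memory::dim n_groups = K / G;

    engine eng(engine::kind::cpu, 0);
    stream s(eng);

    // f32 activations and destination, s8 weights decompressed on the fly.
    memory::desc a_md({M, K}, memory::data_type::f32, memory::format_tag::ab);
    memory::desc b_md({K, N}, memory::data_type::s8, memory::format_tag::ab);
    memory::desc c_md({M, N}, memory::data_type::f32, memory::format_tag::ab);

    // Mask (1 << 0) + (1 << 1) selects both the K and N dimensions of the
    // weights; the {G, 1} group shape shares one scale / zero-point value
    // across every G consecutive elements along K, i.e. (K / G) x N values.
    primitive_attr attr;
    attr.set_scales(DNNL_ARG_WEIGHTS, (1 << 0) + (1 << 1), {G, 1},
            memory::data_type::f32);
    attr.set_zero_points(DNNL_ARG_WEIGHTS, (1 << 0) + (1 << 1), {G, 1},
            memory::data_type::s8);
    // apply_to_int = true lets the integer weights take the fpmath-mode path.
    attr.set_fpmath_mode(fpmath_mode::strict, true);

    // Primitive descriptor creation throws if this configuration is not
    // supported by the engine.
    matmul::primitive_desc matmul_pd(eng, a_md, b_md, c_md, attr);
    matmul matmul_p(matmul_pd);

    memory A_mem(a_md, eng), B_mem(b_md, eng), C_mem(c_md, eng);
    // Per-group decompression parameters, laid out as in the example above.
    memory sc_B_mem({{N, n_groups}, memory::data_type::f32, {1, N}}, eng);
    memory zp_B_mem({{N, n_groups}, memory::data_type::s8, {1, N}}, eng);

    // Scales and zero points are passed at execution time through the
    // DNNL_ARG_ATTR_* argument kinds combined with DNNL_ARG_WEIGHTS.
    matmul_p.execute(s,
            {{DNNL_ARG_SRC, A_mem}, {DNNL_ARG_WEIGHTS, B_mem},
                    {DNNL_ARG_DST, C_mem},
                    {DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS, sc_B_mem},
                    {DNNL_ARG_ATTR_ZERO_POINTS | DNNL_ARG_WEIGHTS, zp_B_mem}});
    s.wait();
    return 0;
}
```

The sketch deliberately keeps the quantization parameters as separate runtime memory objects rather than baking them into the weights, which is the point of the mask/group mechanism documented in matmul.md: the same primitive can be reused while only the per-group scale and zero-point buffers change.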