doc/primitives/matmul.md: 21 changes (12 additions, 9 deletions)
@@ -98,12 +98,11 @@ types for source, destination, weights, and bias tensors:
| Source           | Weights                            | Destination                      | Bias                        |
|:-----------------|:-----------------------------------|:---------------------------------|:----------------------------|
| f64              | f64                                | f64                              | f64, f32, f16, bf16, s8, u8 |
-| f32              | f32                                | f32                              | f32, bf16, f16, u8, s8      |
+| f32              | f32, u8, s8, u4, s4                | f32                              | f32, bf16, f16, u8, s8      |
| f16              | f16, u8, s8, u4, s4                | f16, u8, s8                      | f32                         |
-| f16              | f16, u8, s8                        | f32                              | f32, f16                    |
+| f16              | f16, u8, s8, u4, s4                | f32, f16                         | f32, f16                    |
| bf16             | bf16, u8, s8, u4, s4               | f32, bf16                        | f32, bf16                   |
-| f32, bf16, f16   | u8, s8                             | f32, bf16, f16                   | f32, bf16, f16              |
-| f32, bf16, f16   | u8, s8                             | f32, bf16, f16                   | f32, bf16, f16              |
+| f32, bf16, f16   | u8, s8, u4, s4                     | f32, bf16, f16                   | f32, bf16, f16              |
| bf16, f16        | f8_e5m2, f8_e4m3, f4_e2m1, f4_e3m0 | f32, f16, bf16                   | f32, bf16, f16              |
| f8_e5m2, f8_e4m3 | f8_e5m2, f8_e4m3                   | f32, f16, bf16, f8_e5m2, f8_e4m3 | f32, bf16, f16              |
| f4_e2m1, f4_e3m0 | f4_e2m1, f4_e3m0                   | f32, f16, bf16, f4_e2m1, f4_e3m0 | f32, bf16, f16              |
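For instance, the newly added `f32` source with int4 weights row can be exercised with a minimal sketch like the one below (hypothetical sizes and a CPU engine, not part of this patch; the fpmath-mode call mirrors the example further down):

```cpp
#include "dnnl.hpp"
using namespace dnnl;

int main() {
    engine eng(engine::kind::cpu, 0);

    // Hypothetical sizes, for illustration only.
    const memory::dim M = 64, K = 256, N = 128;

    memory::desc a_md({M, K}, memory::data_type::f32, memory::format_tag::ab);
    memory::desc b_md({K, N}, memory::data_type::s4, memory::format_tag::ab);
    memory::desc c_md({M, N}, memory::data_type::f32, memory::format_tag::ab);

    // Ask the library to decompress the integer weights up to f32.
    primitive_attr attr;
    attr.set_fpmath_mode(fpmath_mode::strict, /* apply_to_int = */ true);

    auto matmul_pd = matmul::primitive_desc(eng, a_md, b_md, c_md, attr);
    (void)matmul_pd;
    return 0;
}
```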
@@ -146,8 +145,8 @@ The following attributes and post-ops are supported:

| Type | Operation | Description | Restrictions |
|:----------|:---------------------------------------------------------------|:------------------------------------------------------------------------------|:------------------------------------|
-| Attribute | [Scales](@ref dnnl::primitive_attr::set_scales_mask)           | Scales the result by given scaling factor(s)     |                          |
-| Attribute | [Zero-points](@ref dnnl::primitive_attr::set_zero_points_mask) | Sets zero-point(s) for the corresponding tensors | `int8` computations only |
+| Attribute | [Scales](@ref dnnl::primitive_attr::set_scales_mask)           | Scales the result by given scaling factor(s)     |                          |
+| Attribute | [Zero-points](@ref dnnl::primitive_attr::set_zero_points_mask) | Sets zero-point(s) for the corresponding tensors |                          |
| Attribute | [Dropout](@ref dnnl::primitive_attr::set_dropout) | Applies pseudo-random dropout to destination buffer, also fills mask buffer | |
| Attribute | [Precomputed reductions](@ref dnnl::primitive_attr::set_precomputed_reductions) | Sets precomputed reductions for the corresponding tensors | Requires weight zero-points and full matrix mask |
| Post-op | [Eltwise](@ref dnnl::post_ops::append_eltwise) | Applies an @ref dnnl_api_eltwise operation to the result | |
@@ -156,9 +155,13 @@ The following attributes and post-ops are supported:
| Post-op | [Prelu](@ref dnnl::post_ops::append_prelu) | Applies an @ref dnnl_api_prelu operation to the result | |
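As one illustration (a sketch, not from this patch), appending a ReLU eltwise post-op through the same attribute mechanism:

```cpp
#include "dnnl.hpp"

// Build an attribute object carrying a ReLU post-op; alpha/beta are the
// eltwise parameters, and 0.f for both gives plain ReLU.
dnnl::primitive_attr make_relu_attr() {
    dnnl::post_ops ops;
    ops.append_eltwise(dnnl::algorithm::eltwise_relu, 0.f, 0.f);

    dnnl::primitive_attr attr;
    attr.set_post_ops(ops);
    return attr;
}
```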

The following masks are supported by the primitive:
-- 0, which applies one scale / zero point value to an entire tensor, and
-- 2, which applies a scale value per column along the
-  `n` dimension for `DNNL_ARG_WEIGHTS`.
+- 0, which applies one scale / zero point value to an entire tensor
+- 1, which applies scale / zero point values along the `k`-dimension
+  for `DNNL_ARG_WEIGHTS`. Values can be grouped along this dimension
+  by specifying scale / zero-point shapes on the attribute
+  (see the code example @ref weights_decompression_matmul_cpp).
+- 2, which applies scale / zero point values per column along the
+  `n`-dimension for `DNNL_ARG_WEIGHTS`.
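A rough sketch of how these masks map onto the attribute API (not part of this patch; `G` is a hypothetical group size, and the three calls are alternatives shown back to back, each overwriting the previous):

```cpp
#include "dnnl.hpp"

void set_weight_scales(dnnl::primitive_attr &attr, dnnl::memory::dim G) {
    // Mask 0: a single scale for the entire weights tensor.
    attr.set_scales_mask(DNNL_ARG_WEIGHTS, 0);

    // Mask 1 (bit 0 set): scales along the k-dimension of 2D weights,
    // grouped in blocks of G values.
    attr.set_scales(DNNL_ARG_WEIGHTS, /* mask */ 1 << 0,
            /* groups */ {G, 1}, dnnl::memory::data_type::f32);

    // Mask 2 (bit 1 set): one scale per column along the n-dimension.
    attr.set_scales_mask(DNNL_ARG_WEIGHTS, 1 << 1);
}
```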

When scales and/or zero-points masks are specified, the user must
provide the corresponding scales and/or zero-points as additional
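At execution time these scales and zero-points are passed as extra memory arguments. A hedged sketch (the memory and primitive names are assumed here; the shapes must match the masks and groups set on the attribute, as in the example below):

```cpp
// Bind attribute parameters by ORing DNNL_ARG_ATTR_SCALES /
// DNNL_ARG_ATTR_ZERO_POINTS with the argument they apply to.
matmul_prim.execute(strm,
        {{DNNL_ARG_SRC, A_f32_mem}, {DNNL_ARG_WEIGHTS, B_s8_mem},
                {DNNL_ARG_DST, C_f32_mem},
                {DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS, sc_B_mem},
                {DNNL_ARG_ATTR_ZERO_POINTS | DNNL_ARG_WEIGHTS, zp_B_mem}});
```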
examples/tutorials/matmul/weights_decompression_matmul.cpp: 14 changes (7 additions, 7 deletions)
@@ -78,7 +78,7 @@ void init_vector(std::vector<float> &v) {
int number_of_runs = 1;

// Create a MatMul primitive descriptor for the following op:
-// C_f32 = A_f32 * (B_s8 - zp_B) * sc_B[:]
+// C_f32 = A_f32 * (B_s8 - zp_B[:]) * sc_B[:]
//
// Here:
// - Matrices A and C are of f32 data type.
@@ -96,15 +96,15 @@ matmul::primitive_desc matmul_pd_create(
// Create attributes and indicate that the alpha and zero points are
// runtime parameters
primitive_attr attr;
-// Set scales with multiple scales along K and N dimensions and with groups along K.
+// Set scales with multiple values along K and N dimensions and with groups along K.
attr.set_scales(DNNL_ARG_WEIGHTS,
        /* mask */ (1 << 0) + (1 << 1), {G, 1}, memory::data_type::f32);
-// Set a single zero point with s8 data type.
-attr.set_zero_points(
-        DNNL_ARG_WEIGHTS, /* mask */ 0, {}, memory::data_type::s8);
+// Set zero points with multiple values along K and N dimensions and with groups along K.
+attr.set_zero_points(DNNL_ARG_WEIGHTS, /* mask */ (1 << 0) + (1 << 1),
+        {G, 1}, memory::data_type::s8);

mzhukova (Contributor) commented on Nov 7, 2025:
> Do we not need to also change the memory object for zero points somewhere in the example?

PR author replied:
> Yes, that's true. Updated!

mzhukova (Contributor) replied:
> Hmm, it seems like the example passed before, even though the dimensions of the memory object for zero points were not in sync with this setting. Do you have an understanding of why it worked? 👀
// Set fpmath mode with `apply_to_int=true` to apply fpmath mode behavior to
// integral primitives (in this example, matmul).
-attr.set_fpmath_mode(fpmath_mode::bf16, true);
+attr.set_fpmath_mode(fpmath_mode::strict, true);

// Create a MatMul primitive descriptor
return matmul::primitive_desc(eng, a_md, b_md, c_md, attr);
@@ -136,7 +136,7 @@ void infer(const matmul &matmul_p, int64_t M, int64_t N, int64_t K, int64_t G,
// De-quantization parameters (e.g., scale and shift)
const int64_t n_groups = K / G;
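// With mask (1 << 0) + (1 << 1) and groups {G, 1}, each parameter tensor
// holds N * (K / G) values: one per group of G elements along k, per column along n.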
memory sc_B_mem({{N, n_groups}, memory::data_type::f32, {1, N}}, eng);
-memory zp_B_mem({{1}, memory::data_type::s8, {1}}, eng);
+memory zp_B_mem({{N, n_groups}, memory::data_type::s8, {1, N}}, eng);

// the function below fills dnnl::memory with some values
// these memories, typically, come from the previous layers / operations