[GPU] Dynamic Dst Scale #4245
base: main
Conversation
make test

Force-pushed f58ea11 to 0c409b2
src/gpu/intel/mx_scale.cl (Outdated)
```c
float scale_val
        = cvt_e8m0_to_f32(cvt_f32_to_e8m0(max_group)) / DST_DATA_FMAX;
```
The CPU reference implementation also seems to be missing it (and the comment is missing a closing paren), but:

oneDNN/src/cpu/matmul/ref_matmul.cpp, lines 324 to 327 in a031d34:

```c
// MXSPEC does round_down_pow2(dst_d.data_type() /
// round_down_pow2(max_dst_group) so the rounding
// to a power of two happens before the division,
// and not after.
```

Suggested change:

```diff
-float scale_val
-        = cvt_e8m0_to_f32(cvt_f32_to_e8m0(max_group)) / DST_DATA_FMAX;
+#define E8M0(x) cvt_e8m0_to_f32(cvt_f32_to_e8m0(x))
+float scale_val = E8M0(max_group) / E8M0(DST_DATA_FMAX);
+#undef E8M0
```
Since scale_val can be outside the range of e8m0, this will need an additional outer E8M0. Without it, we'd be scaling with and storing different values. Consider: all values in the group are zero.
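To illustrate the all-zero-group case, here is a hedged Python sketch. The `e8m0` helper, the `DST_FMAX` value, and the rounding mode are my own stand-ins for the kernel's `cvt_*` helpers, not oneDNN code:

```python
import math

# Toy e8m0 round-trip (my own sketch, not the oneDNN helpers): round a
# positive float down to a power of two and clamp the exponent to the
# e8m0 range [-127, 127].  Zero has no e8m0 encoding, so it clamps to
# the smallest representable scale, 2**-127.
def e8m0(x):
    if x <= 0.0:
        return 2.0 ** -127
    e = math.floor(math.log2(x))  # rounding mode is a guess here
    return 2.0 ** max(-127, min(127, e))

DST_FMAX = 448.0  # fp8 e4m3 max, used here as an example dst type

# All-zero group: the inner quotient falls below the e8m0 range, so the
# value used for scaling and the value stored as the scale would differ
# unless the quotient is passed through e8m0 once more.
max_group = 0.0
inner = e8m0(max_group) / DST_FMAX  # ~2**-135.8, outside e8m0 range
scale = e8m0(inner)                 # the outer E8M0 clamps to 2**-127
```

Without the outer rounding, the kernel would divide by `inner` but store `scale`, which differ here.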
fixed, thanks.
@dzarukin a similar change is required to make the benchdnn ref implementation align with this behavior, added as part of this PR.
Summoning @mgouicem as an architect of the feature.
```cpp
for (int i = 0; i < 16; i += 4) {
    h->mov(4, tmp.ud(0)(1), max.ud(i)(1));
    h->sel(4 | ge, max.f(0), max.f(0)(1), tmp.f(0)(1));
}
h->mov(2, tmp.ud(0)(1), max.ud(2)(1));
h->sel(2 | ge, max, max.f(0)(1), tmp.f(0)(1));
h->mov(2, tmp.ud(0)(1), max.ud(1)(1));
h->sel(1 | ge, max, max.f(0)(1), tmp.f(0)(1));
```
Would this work:

```diff
-for (int i = 0; i < 16; i += 4) {
-    h->mov(4, tmp.ud(0)(1), max.ud(i)(1));
-    h->sel(4 | ge, max.f(0), max.f(0)(1), tmp.f(0)(1));
-}
-h->mov(2, tmp.ud(0)(1), max.ud(2)(1));
-h->sel(2 | ge, max, max.f(0)(1), tmp.f(0)(1));
-h->mov(2, tmp.ud(0)(1), max.ud(1)(1));
-h->sel(1 | ge, max, max.f(0)(1), tmp.f(0)(1));
+h->sel(8 | ge, max.ud(0)(1), max.ud(0)(2), max.ud(1)(2));
+h->sel(4 | ge, max.ud(0)(1), max.ud(0)(2), max.ud(1)(2));
+h->sel(3 | ge, max.ud(0)(1), max.ud(0)(2), max.ud(1)(2));
+h->sel(1 | ge, max.ud(0)(1), max.ud(0)(0), max.ud(1)(0));
```
I guess there'd need to be some special handling of nan/infs before this sequence to avoid propagating them.
@atkassen -- you can add 0x80400000:ud to the inputs prior to the max, and then subtract after the max.
Another optimization to consider (maybe not in this PR) is a fully vectorized horizontal reduction, where you recombine partly reduced vectors as you go, so that you can get full SIMD usage at each stage -- vISA example here (DEFINE_HREDUCE16_FLOAT).
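A scalar C model of that fully vectorized reduction shape (a sketch of the recombination idea, not the vISA `DEFINE_HREDUCE16_FLOAT` macro itself): each stage combines the two halves of the vector element-wise, so a SIMD machine keeps width/2, then width/4, ... lanes busy per instruction instead of serializing the reduction.

```c
#include <string.h>

/* Horizontal max of 16 floats via log2(16) = 4 halving stages; each
 * inner loop corresponds to one full-width SIMD sel in the real code. */
float hmax16(const float v[16]) {
    float t[16];
    memcpy(t, v, sizeof t);
    for (int width = 8; width >= 1; width /= 2)
        for (int i = 0; i < width; ++i)
            t[i] = t[i] >= t[i + width] ? t[i] : t[i + width];
    return t[0];
}
```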
Force-pushed 0c409b2 to b4cb984

Force-pushed b4cb984 to 3237d6b
make test

Force-pushed 820c954 to 158d43c

make test

Force-pushed 158d43c to 5fc7b98

make test
Description

Enable MXFP4/FP8 dynamic dst scale generation in the reference and JIT implementations.

Fixes MFDNN-14330
Checklist

General

Do all unit and benchdnn tests (`make test` and `make test_benchdnn_*`) pass locally for each commit?