Skip to content

Commit f8e69e6

Browse files
committed
xe: conv: jit: rework walk order heuristic
1 parent fd51c7d commit f8e69e6

File tree

1 file changed

+48
-17
lines changed

1 file changed

+48
-17
lines changed

src/gpu/intel/conv/jit/config.cpp

Lines changed: 48 additions & 17 deletions
Original file line number | Diff line number | Diff line change
@@ -1644,31 +1644,62 @@ walk_order_t compute_walk_order(const config_t &cfg) {
16441644
size_t ab_bytes = get_memory_footprint(cfg, inner, outer);
16451645
if (ab_bytes <= l3_size) grid_inner = std::move(outer);
16461646
}
1647+
1648+
auto &w_inner = grid_inner[pvars::ow];
1649+
auto &h_inner = grid_inner[pvars::oh];
1650+
1651+
// Prefer square spatial dimensions to increase cache reuse due to iteration
1652+
// over kernel spatial dimensions.
1653+
auto rebalance_hw = [&]() {
1654+
if (!cfg.prb().is_fwd) return false;
1655+
if (grid_tile[pvars::oh] % (h_inner * 2)) return false;
1656+
if (w_inner % 2) return false;
1657+
if (w_inner < h_inner * 4) return false;
1658+
return true;
1659+
};
1660+
1661+
while (rebalance_hw()) {
1662+
w_inner /= 2;
1663+
h_inner *= 2;
1664+
}
1665+
16471666
// Add the blocks in this order:
16481667
// - Step 1. Add grid_inner blocks (fitting L3 cache)
16491668
// - Step 2. Add the remaining M/N blocks
16501669
// - Step 3. Add the remaining B/K blocks
16511670
// Within a step follow the default walk order between dimensions.
16521671
walk_order_t walk_order;
16531672
for (int step = 0; step < 3; step++) {
1654-
for (auto &b : default_walk_order.blocks()) {
1655-
switch (step) {
1656-
case 0:
1657-
if (grid_inner.has(b.dim)) {
1658-
walk_order.add(b.dim, grid_inner[b.dim], 0);
1659-
}
1660-
break;
1661-
case 1:
1662-
case 2:
1663-
dim_t rem = utils::div_up(
1664-
grid_tile[b.dim], grid_inner.get(b.dim, 1));
1665-
if (rem == 1) continue;
1666-
auto bmnk = to_gemm(b.dim, prb);
1667-
bool is_bk = utils::one_of(bmnk, pvars::b, pvars::k);
1668-
if ((step == 2) != is_bk) continue;
1669-
walk_order.add(b.dim, rem, 0);
1670-
break;
1673+
if (step == 0) {
1674+
// Transpose spatial for better reuse
1675+
auto blocks = default_walk_order.blocks();
1676+
for (size_t i = 0; i < blocks.size() - 1; i++) {
1677+
if (cfg.prb().is_fwd && blocks[i].dim == pvars::ow
1678+
&& blocks[i + 1].dim == pvars::oh) {
1679+
std::swap(blocks[i], blocks[i + 1]);
1680+
}
1681+
}
1682+
for (auto &b : blocks) {
1683+
if (grid_inner.has(b.dim)) {
1684+
walk_order.add(b.dim, grid_inner[b.dim], 0);
1685+
printf("step %d: %s: %ld\n", step, b.dim.str().c_str(),
1686+
grid_inner[b.dim]);
1687+
}
16711688
}
1689+
continue;
1690+
}
1691+
1692+
for (auto &b : default_walk_order.blocks()) {
1693+
dim_t rem
1694+
= utils::div_up(grid_tile[b.dim], grid_inner.get(b.dim, 1));
1695+
if (rem == 1) continue;
1696+
auto bmnk = to_gemm(b.dim, prb);
1697+
bool is_bk = utils::one_of(bmnk, pvars::b, pvars::k);
1698+
if ((step == 2) != is_bk) continue;
1699+
walk_order.add(b.dim, rem, 0);
1700+
printf("step %d: %s: %ld\n", step, b.dim.str().c_str(),
1701+
grid_inner[b.dim]);
1702+
break;
16721703
}
16731704
}
16741705
walk_order.finalize(grid_tile);

0 commit comments

Comments
 (0)