@@ -1644,31 +1644,62 @@ walk_order_t compute_walk_order(const config_t &cfg) {
16441644 size_t ab_bytes = get_memory_footprint (cfg, inner, outer);
16451645 if (ab_bytes <= l3_size) grid_inner = std::move (outer);
16461646 }
1647+
1648+ auto &w_inner = grid_inner[pvars::ow];
1649+ auto &h_inner = grid_inner[pvars::oh];
1650+
1651+ // Prefer square spatial dimensions to increase cache reuse due to iteration
1652+ // over kernel spatial dimensions.
1653+ auto rebalance_hw = [&]() {
1654+ if (!cfg.prb ().is_fwd ) return false ;
1655+ if (grid_tile[pvars::oh] % (h_inner * 2 )) return false ;
1656+ if (w_inner % 2 ) return false ;
1657+ if (w_inner < h_inner * 4 ) return false ;
1658+ return true ;
1659+ };
1660+
1661+ while (rebalance_hw ()) {
1662+ w_inner /= 2 ;
1663+ h_inner *= 2 ;
1664+ }
1665+
16471666 // Add the blocks in this order:
16481667 // - Step 1. Add grid_inner blocks (fitting L3 cache)
16491668 // - Step 2. Add the remaining M/N blocks
16501669 // - Step 3. Add the remaining B/K blocks
16511670 // Within a step follow the default walk order between dimensions.
16521671 walk_order_t walk_order;
16531672 for (int step = 0 ; step < 3 ; step++) {
1654- for (auto &b : default_walk_order.blocks ()) {
1655- switch (step) {
1656- case 0 :
1657- if (grid_inner.has (b.dim )) {
1658- walk_order.add (b.dim , grid_inner[b.dim ], 0 );
1659- }
1660- break ;
1661- case 1 :
1662- case 2 :
1663- dim_t rem = utils::div_up (
1664- grid_tile[b.dim ], grid_inner.get (b.dim , 1 ));
1665- if (rem == 1 ) continue ;
1666- auto bmnk = to_gemm (b.dim , prb);
1667- bool is_bk = utils::one_of (bmnk, pvars::b, pvars::k);
1668- if ((step == 2 ) != is_bk) continue ;
1669- walk_order.add (b.dim , rem, 0 );
1670- break ;
1673+ if (step == 0 ) {
1674+ // Transpose spatial for better reuse
1675+ auto blocks = default_walk_order.blocks ();
1676+ for (size_t i = 0 ; i < blocks.size () - 1 ; i++) {
1677+ if (cfg.prb ().is_fwd && blocks[i].dim == pvars::ow
1678+ && blocks[i + 1 ].dim == pvars::oh) {
1679+ std::swap (blocks[i], blocks[i + 1 ]);
1680+ }
1681+ }
1682+ for (auto &b : blocks) {
1683+ if (grid_inner.has (b.dim )) {
1684+ walk_order.add (b.dim , grid_inner[b.dim ], 0 );
1685+ printf (" step %d: %s: %ld\n " , step, b.dim .str ().c_str (),
1686+ grid_inner[b.dim ]);
1687+ }
16711688 }
1689+ continue ;
1690+ }
1691+
1692+ for (auto &b : default_walk_order.blocks ()) {
1693+ dim_t rem
1694+ = utils::div_up (grid_tile[b.dim ], grid_inner.get (b.dim , 1 ));
1695+ if (rem == 1 ) continue ;
1696+ auto bmnk = to_gemm (b.dim , prb);
1697+ bool is_bk = utils::one_of (bmnk, pvars::b, pvars::k);
1698+ if ((step == 2 ) != is_bk) continue ;
1699+ walk_order.add (b.dim , rem, 0 );
1700+ printf (" step %d: %s: %ld\n " , step, b.dim .str ().c_str (),
1701+ grid_inner[b.dim ]);
1702+ break ;
16721703 }
16731704 }
16741705 walk_order.finalize (grid_tile);
0 commit comments