flatmm_sn_32x128x512_1x4x1_16x16x32.hpp Source File

flatmm_sn_32x128x512_1x4x1_16x16x32.hpp Source File

Composable Kernel: flatmm_sn_32x128x512_1x4x1_16x16x32.hpp Source File
flatmm_sn_32x128x512_1x4x1_16x16x32.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
6#include "ck_tile/core.hpp"
9
10namespace ck_tile {
11
12// "S"tream update output along "N"
13// A in smem, B load from global
14// require 4 wave, occupancy=1c
16{
17 static constexpr index_t Block_M = 32; // block tile size along M
18 static constexpr index_t Block_N = 128; // block tile size along N (n is consumed in steps of Block_N)
19 static constexpr index_t Block_K = 512; // block tile size along K
20
21 static constexpr index_t WarpPerBlock_M = 1; // warps along M
22 static constexpr index_t WarpPerBlock_N = 4; // warps along N (1x4x1 -> 4 waves total)
23 static constexpr index_t WarpPerBlock_K = 1; // warps along K
24
25 static constexpr index_t Warp_M = 16; // per-warp tile along M
26 static constexpr index_t Warp_N = 16; // per-warp tile along N
27 static constexpr index_t Warp_K = 32; // per-warp tile along K
28
29 static constexpr index_t BlockSize = 256; // threads per block: 4 waves x 64 lanes
30
31 // static constexpr index_t KPack = 2; // this is used to guarantee every thread can do dwordx4
32
33 // TODO: Nr/Kr/W still need to take KPack into account
34 static constexpr index_t Block_W = Warp_N * Warp_K; // 16 * 32 = 512 element
35 static constexpr index_t Block_Nr = Block_N / Warp_N; // 128 / 16 = 8
36 static constexpr index_t Block_Kr = Block_K / Warp_K; // 512 / 32 = 16
37
38 static constexpr index_t Repeat_M = Block_M / (Warp_M * WarpPerBlock_M); // 32 / (16 * 1) = 2
39 static constexpr index_t Repeat_N = Block_N / (Warp_N * WarpPerBlock_N); // 128 / (16 * 4) = 2
40 static constexpr index_t Repeat_K = Block_K / (Warp_K * WarpPerBlock_K); // 512 / (32 * 1) = 16
41
42 static CK_TILE_DEVICE constexpr auto MakeCBlockDist()
43 {
44 constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding<
49 sequence<2, 1>, // !! note here is different
51
53
54 constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
55 c_block_outer_dstr_encoding, typename WG::CWarpDstrEncoding{});
56 constexpr auto c_block_dstr = make_static_tile_distribution(c_block_dstr_encode);
57 return c_block_dstr;
58 }
59
61 {
62 // y y p p p y
63 // reg before shfl M0(2)*N0(2)*Nl(4)*Nw(4)*Mw(16)*Nv(4)
64 // but order is N0*M0*Nv
65 // in LDS we need store as
66 // M0(2)* N0(2) * Nl(4) * Nw(4) * (Mw(16)*Nv(4) + 4)
67 // y y wave-id lid/16 lid%16 v
68 constexpr index_t nbufs = 2;
69 return 2 * 2 * 4 * 4 * (16 * 4 + 4) * sizeof(bf16_t) * nbufs;
70 }
71};
72
74{
77
78 // TODO: need paired with tile_window_linear!
79 // TODO: need call init_raw() before call this function!
80 // template <typename AWindow, typename BWindow, typename OWindow, typename ScaleTensor>
81 template <typename BRes,
82 typename BCoords,
83 typename ORes,
84 typename OCoords,
85 typename OFlags,
86 typename ScaleTensor>
88 operator()(const BRes& res_b,
89 const BCoords& cached_coords_b,
90 const ORes& res_o,
91 const OCoords& cached_coords_o,
92 const OFlags& o_flags, // this should be in sgpr
93 CK_TILE_LDS_ADDR void* smem,
94 index_t n, // loop along n dim
95 const ScaleTensor& scale_,
96 index_t tile_offset_b, // stride b is fixed to blockKr * blockW, but still can adjust
97 index_t tile_offset_o)
98 {
99 static_assert(BCoords::size() == 8); // 8
100 static_assert(OCoords::size() == 8);
101
102 const index_t tile_stride_b_bytes = tile_offset_b * sizeof(BDataType);
103 const index_t tile_stride_o_bytes = tile_offset_o * sizeof(ODataType);
104
105 static_assert(ScaleTensor::size() == 2);
106 float s0 = scale_[number<0>{}];
107 float s1 = scale_[number<1>{}];
108
109 index_t loop_cnt = n / Block_N;
110
111 register float v_c0 asm("v64");
112 register float v_c1 asm("v65");
113 register float v_c2 asm("v66");
114 register float v_c3 asm("v67");
115 register float v_c4 asm("v68");
116 register float v_c5 asm("v69");
117 register float v_c6 asm("v70");
118 register float v_c7 asm("v71");
119 register float v_c8 asm("v72");
120 register float v_c9 asm("v73");
121 register float v_c10 asm("v74");
122 register float v_c11 asm("v75");
123 register float v_c12 asm("v76");
124 register float v_c13 asm("v77");
125 register float v_c14 asm("v78");
126 register float v_c15 asm("v79");
127 register float v_c16 asm("v80");
128 register float v_c17 asm("v81");
129 register float v_c18 asm("v82");
130 register float v_c19 asm("v83");
131 register float v_c20 asm("v84");
132 register float v_c21 asm("v85");
133 register float v_c22 asm("v86");
134 register float v_c23 asm("v87");
135 register float v_c24 asm("v88");
136 register float v_c25 asm("v89");
137 register float v_c26 asm("v90");
138 register float v_c27 asm("v91");
139 register float v_c28 asm("v92");
140 register float v_c29 asm("v93");
141 register float v_c30 asm("v94");
142 register float v_c31 asm("v95");
143 int32_t nan_hi = 0x7fff0000;
144 int32_t nan_lo = 0x00007fff;
145
146 // in smem, the layout is M0(2)*K0(128)*M1(16)*K1(4)
147 // every threads need 8xK in contiguous register
148 // ... and every wave need the same data
149 int lane_id = threadIdx.x % 64;
150 int sld_y_os = (lane_id % 16) * 4 + (lane_id / 16) * 128;
151 sld_y_os *= 2;
152
153 // y y p p p y
154 // reg before shfl M0(2)*N0(2)*Nl(4)*Nw(4)*Mw(16)*Nv(4)
155 // but order is N0*M0*Nv
156 // in LDS we need store as
157 // M0(2)* N0(2) * Nl(4) * Nw(4) * (Mw(16)*Nv(4) + 4)
158 // y y wave-id lid/16 lid%16 v
159 // sst(v3) = (v0/16*34 + v0%16 * 2 + wid*136) * 4
160 int sfl_sst = (threadIdx.x % 16 * 4) + (threadIdx.x / 16) * (64 + 4);
161 sfl_sst *= 2;
162
163 // from LDS we need load as
164 // M0(2)* N0(2) * Nl(4) * Nw(4) * (Mw(16) * Nv(4) + 4)
165 // ( 2 issue) (rem 32-lane) (4 wave*4issue) 2lane*1issue(pk2)
166 // sld(v4) = v0/2 *34*4 + v0 % 2 *4 + wid*2 *4
167 int sfl_sld = (lane_id % 2) * 2 + (lane_id / 2) * (64 + 4) + (threadIdx.x / 64) * 4;
168 sfl_sld *= 2;
169
170 // B nr->kr
171 // clang-format off
172#pragma clang diagnostic push
173#pragma clang diagnostic ignored "-Winline-asm"
174 asm volatile(
175#define CK_TILE_FLATMM_UK_MFMA CK_TILE_FLATMM_UK_MFMA_BF16
177 :[smem_]"+r"(smem),
178 [s_loop_cnt]"+s"(loop_cnt),
179 [c0]"+v" (v_c0),
180 [c1]"+v" (v_c1),
181 [c2]"+v" (v_c2),
182 [c3]"+v" (v_c3),
183 [c4]"+v" (v_c4),
184 [c5]"+v" (v_c5),
185 [c6]"+v" (v_c6),
186 [c7]"+v" (v_c7),
187 [c8]"+v" (v_c8),
188 [c9]"+v" (v_c9),
189 [c10]"+v"(v_c10),
190 [c11]"+v"(v_c11),
191 [c12]"+v"(v_c12),
192 [c13]"+v"(v_c13),
193 [c14]"+v"(v_c14),
194 [c15]"+v"(v_c15),
195 [c16]"+v"(v_c16),
196 [c17]"+v"(v_c17),
197 [c18]"+v"(v_c18),
198 [c19]"+v"(v_c19),
199 [c20]"+v"(v_c20),
200 [c21]"+v"(v_c21),
201 [c22]"+v"(v_c22),
202 [c23]"+v"(v_c23),
203 [c24]"+v"(v_c24),
204 [c25]"+v"(v_c25),
205 [c26]"+v"(v_c26),
206 [c27]"+v"(v_c27),
207 [c28]"+v"(v_c28),
208 [c29]"+v"(v_c29),
209 [c30]"+v"(v_c30),
210 [c31]"+v"(v_c31)
211 :
212 [sld_a_base]"n"(0),
213 [shfl_base]"n"(0),
214 [v_sld_y_os]"v"(sld_y_os),
215 [v_sfl_sld]"v"(sfl_sld),
216 [v_sfl_sst]"v"(sfl_sst),
217 [s_res_o0]"s"(res_o[0]),
218 [s_res_o1]"s"(res_o[1]),
219 //[s_res_o2]"s"(res_o[2]),
220 //[s_res_o3]"s"(res_o[3]),
221 [s_res_b0]"s"(res_b[0]),
222 [s_res_b1]"s"(res_b[1]),
223 [s_res_b2]"s"(res_b[2]),
224 [s_res_b3]"s"(res_b[3]),
225 [v_os_o0]"v"(static_cast<index_t>(cached_coords_o[number<0>{}] * sizeof(ODataType))),
226 [v_os_o1]"v"(static_cast<index_t>(cached_coords_o[number<1>{}] * sizeof(ODataType))),
227 [v_os_o2]"v"(static_cast<index_t>(cached_coords_o[number<2>{}] * sizeof(ODataType))),
228 [v_os_o3]"v"(static_cast<index_t>(cached_coords_o[number<3>{}] * sizeof(ODataType))),
229 [v_os_o4]"v"(static_cast<index_t>(cached_coords_o[number<4>{}] * sizeof(ODataType))),
230 [v_os_o5]"v"(static_cast<index_t>(cached_coords_o[number<5>{}] * sizeof(ODataType))),
231 [v_os_o6]"v"(static_cast<index_t>(cached_coords_o[number<6>{}] * sizeof(ODataType))),
232 [v_os_o7]"v"(static_cast<index_t>(cached_coords_o[number<7>{}] * sizeof(ODataType))),
233 [v_os_b0]"v"(static_cast<index_t>(cached_coords_b[number<0>{}] * sizeof(BDataType))),
234 [v_os_b1]"v"(static_cast<index_t>(cached_coords_b[number<1>{}] * sizeof(BDataType))),
235 [v_os_b2]"v"(static_cast<index_t>(cached_coords_b[number<2>{}] * sizeof(BDataType))),
236 [v_os_b3]"v"(static_cast<index_t>(cached_coords_b[number<3>{}] * sizeof(BDataType))),
237 [v_os_b4]"v"(static_cast<index_t>(cached_coords_b[number<4>{}] * sizeof(BDataType))),
238 [v_os_b5]"v"(static_cast<index_t>(cached_coords_b[number<5>{}] * sizeof(BDataType))),
239 [v_os_b6]"v"(static_cast<index_t>(cached_coords_b[number<6>{}] * sizeof(BDataType))),
240 [v_os_b7]"v"(static_cast<index_t>(cached_coords_b[number<7>{}] * sizeof(BDataType))),
241
242 [s_tile_os_o]"s"(tile_stride_o_bytes),
243 [s_tile_os_b]"s"(tile_stride_b_bytes),
244 [scale_0]"v"(s0),
245 [scale_1]"v"(s1),
246 [v_nan_lo]"v"(nan_lo),
247 [v_nan_hi]"v"(nan_hi),
248 [s_execflag_0]"s"(o_flags[number<0>{}]),
249 [s_execflag_1]"s"(o_flags[number<1>{}]),
250 [s_execflag_2]"s"(o_flags[number<2>{}]),
251 [s_execflag_3]"s"(o_flags[number<3>{}]),
252 [s_execflag_4]"s"(o_flags[number<4>{}]),
253 [s_execflag_5]"s"(o_flags[number<5>{}]),
254 [s_execflag_6]"s"(o_flags[number<6>{}]),
255 [s_execflag_7]"s"(o_flags[number<7>{}])
256 :
257 "memory", "a0", "a1", "a2", "a3", "a4", "a5", "a6", "a7", "a8", "a9",
258 "a10", "a11", "a12", "a13", "a14", "a15", "a16", "a17", "a18", "a19",
259 "a20", "a21", "a22", "a23", "a24", "a25", "a26", "a27", "a28", "a29",
260 "a30", "a31", "a32", "a33", "a34", "a35", "a36", "a37", "a38", "a39",
261 "a40", "a41", "a42", "a43", "a44", "a45", "a46", "a47", "a48", "a49",
262 "a50", "a51", "a52", "a53", "a54", "a55", "a56", "a57", "a58", "a59",
263 "a60", "a61", "a62", "a63", "a64", "a65", "a66", "a67", "a68", "a69",
264 "a70", "a71", "a72", "a73", "a74", "a75", "a76", "a77", "a78", "a79",
265 "a80", "a81", "a82", "a83", "a84", "a85", "a86", "a87", "a88", "a89",
266 "a90", "a91", "a92", "a93", "a94", "a95", "a96", "a97", "a98", "a99",
267 "a100", "a101", "a102", "a103", "a104", "a105", "a106", "a107",
268 "a108", "a109", "a110", "a111", "a112", "a113", "a114", "a115",
269 "a116", "a117", "a118", "a119", "a120", "a121", "a122", "a123",
270 "a124", "a125", "a126", "a127", "a128", "a129", "a130", "a131",
271 "a132", "a133", "a134", "a135", "a136", "a137", "a138", "a139",
272 "a140", "a141", "a142", "a143", "a144", "a145", "a146", "a147",
273 "a148", "a149", "a150", "a151", "a152", "a153", "a154", "a155",
274 "a156", "a157", "a158", "a159", "a160", "a161", "a162", "a163",
275 "a164", "a165", "a166", "a167", "a168", "a169", "a170", "a171",
276 "a172", "a173", "a174", "a175", "a176", "a177", "a178", "a179",
277 "a180", "a181", "a182", "a183", "a184", "a185", "a186", "a187",
278 "a188", "a189", "a190", "a191", "a192", "a193", "a194", "a195",
279 "a196", "a197", "a198", "a199", "a200", "a201", "a202", "a203",
280 "a204", "a205", "a206", "a207", "a208", "a209", "a210", "a211",
281 "a212", "a213", "a214", "a215", "a216", "a217", "a218", "a219",
282 "a220", "a221", "a222", "a223", "a224", "a225", "a226", "a227",
283 "a228", "a229", "a230", "a231", "a232", "a233", "a234", "a235",
284 "a236", "a237", "a238", "a239", "a240", "a241", "a242", "a243",
285 "a244", "a245", "a246", "a247", "a248", "a249", "a250", "a251",
286 "a252", "a253", "a254", "a255",
287 "s8", "s9", "s12", "s13", "s14", "s15", "s38", "s39", "s52", "s86",
288 "s36", "s37",
289 "v50", "v54", "v55",
290 "v64","v65","v66","v67","v68","v69","v70","v71",
291 "v72","v73","v74","v75","v76","v77","v78","v79",
292 "v80","v81","v82","v83","v84","v85","v86","v87",
293 "v88","v89","v90","v91","v92","v93","v94","v95",
294 "v128", "v129", "v130", "v131",
295 "v132", "v133", "v134", "v135", "v136", "v137", "v138", "v139",
296 "v140", "v141", "v142", "v143", "v144", "v145", "v146", "v147",
297 "v148", "v149", "v150", "v151", "v152", "v153", "v154", "v155",
298 "v156", "v157", "v158", "v159", "v160", "v161", "v162", "v163",
299 "v164", "v165", "v166", "v167", "v168", "v169", "v170", "v171",
300 "v172", "v173", "v174", "v175", "v176", "v177", "v178", "v179",
301 "v180", "v181", "v182", "v183", "v184", "v185", "v186", "v187",
302 "v188", "v189", "v190", "v191", "v192", "v193", "v194", "v195",
303 "v196", "v197", "v198", "v199", "v200", "v201", "v202", "v203",
304 "v204", "v205", "v206", "v207", "v208", "v209", "v210", "v211",
305 "v212", "v213", "v214", "v215", "v216", "v217", "v218", "v219",
306 "v220", "v221", "v222", "v223", "v224", "v225", "v226", "v227",
307 "v228", "v229", "v230", "v231", "v232", "v233", "v234", "v235",
308 "v236", "v237", "v238", "v239", "v240", "v241", "v242", "v243",
309 "v244", "v245", "v246", "v247", "v248", "v249", "v250", "v251",
310 "v252", "v253", "v254", "v255"
311 );
312#pragma clang diagnostic pop
313 // clang-format on
314 }
315};
316
318{
321
322 // TODO: need paired with tile_window_linear!
323 // TODO: need call init_raw() before call this function!
324 // template <typename AWindow, typename BWindow, typename OWindow, typename ScaleTensor>
325 template <typename BRes,
326 typename BCoords,
327 typename ORes,
328 typename OCoords,
329 typename OFlags,
330 typename ScaleTensor>
331 CK_TILE_DEVICE auto
332 operator()(const BRes& res_b,
333 const BCoords& cached_coords_b,
334 const ORes& res_o,
335 const OCoords& cached_coords_o,
336 const OFlags& o_flags, // this should be in sgpr
337 CK_TILE_LDS_ADDR void* smem,
338 index_t n, // loop along n dim
339 const ScaleTensor& scale_,
340 index_t tile_offset_b, // stride b is fixed to blockKr * blockW, but still can adjust
341 index_t tile_offset_o)
342 {
343 static_assert(BCoords::size() == 8); // 8
344 static_assert(OCoords::size() == 8);
345
346 const index_t tile_stride_b_bytes = tile_offset_b * sizeof(BDataType);
347 const index_t tile_stride_o_bytes = tile_offset_o * sizeof(ODataType);
348
349 static_assert(ScaleTensor::size() == 2);
350 float s0 = scale_[number<0>{}];
351 float s1 = scale_[number<1>{}];
352
353 index_t loop_cnt = n / Block_N;
354
355 register float v_c0 asm("v64");
356 register float v_c1 asm("v65");
357 register float v_c2 asm("v66");
358 register float v_c3 asm("v67");
359 register float v_c4 asm("v68");
360 register float v_c5 asm("v69");
361 register float v_c6 asm("v70");
362 register float v_c7 asm("v71");
363 register float v_c8 asm("v72");
364 register float v_c9 asm("v73");
365 register float v_c10 asm("v74");
366 register float v_c11 asm("v75");
367 register float v_c12 asm("v76");
368 register float v_c13 asm("v77");
369 register float v_c14 asm("v78");
370 register float v_c15 asm("v79");
371 register float v_c16 asm("v80");
372 register float v_c17 asm("v81");
373 register float v_c18 asm("v82");
374 register float v_c19 asm("v83");
375 register float v_c20 asm("v84");
376 register float v_c21 asm("v85");
377 register float v_c22 asm("v86");
378 register float v_c23 asm("v87");
379 register float v_c24 asm("v88");
380 register float v_c25 asm("v89");
381 register float v_c26 asm("v90");
382 register float v_c27 asm("v91");
383 register float v_c28 asm("v92");
384 register float v_c29 asm("v93");
385 register float v_c30 asm("v94");
386 register float v_c31 asm("v95");
387 int32_t nan_hi = 0x7fff0000;
388 int32_t nan_lo = 0x00007fff;
389
390 // in smem, the layout is M0(2)*K0(128)*M1(16)*K1(4)
391 // every threads need 8xK in contiguous register
392 // ... and every wave need the same data
393 int lane_id = threadIdx.x % 64;
394 int sld_y_os = (lane_id % 16) * 4 + (lane_id / 16) * 128;
395 sld_y_os *= 2;
396
397 // y y p p p y
398 // reg before shfl M0(2)*N0(2)*Nl(4)*Nw(4)*Mw(16)*Nv(4)
399 // but order is N0*M0*Nv
400 // in LDS we need store as
401 // M0(2)* N0(2) * Nl(4) * Nw(4) * (Mw(16)*Nv(4) + 4)
402 // y y wave-id lid/16 lid%16 v
403 // sst(v3) = (v0/16*34 + v0%16 * 2 + wid*136) * 4
404 int sfl_sst = (threadIdx.x % 16 * 4) + (threadIdx.x / 16) * (64 + 4);
405 sfl_sst *= 2;
406
407 // from LDS we need load as
408 // M0(2)* N0(2) * Nl(4) * Nw(4) * (Mw(16) * Nv(4) + 4)
409 // ( 2 issue) (rem 32-lane) (4 wave*4issue) 2lane*1issue(pk2)
410 // sld(v4) = v0/2 *34*4 + v0 % 2 *4 + wid*2 *4
411 int sfl_sld = (lane_id % 2) * 2 + (lane_id / 2) * (64 + 4) + (threadIdx.x / 64) * 4;
412 sfl_sld *= 2;
413
414 // B nr->kr
415 // clang-format off
416#pragma clang diagnostic push
417#pragma clang diagnostic ignored "-Winline-asm"
418 asm volatile(
419#define CK_TILE_FLATMM_UK_MFMA CK_TILE_FLATMM_UK_MFMA_FP16
421 :[smem_]"+r"(smem),
422 [s_loop_cnt]"+s"(loop_cnt),
423 [c0]"+v" (v_c0),
424 [c1]"+v" (v_c1),
425 [c2]"+v" (v_c2),
426 [c3]"+v" (v_c3),
427 [c4]"+v" (v_c4),
428 [c5]"+v" (v_c5),
429 [c6]"+v" (v_c6),
430 [c7]"+v" (v_c7),
431 [c8]"+v" (v_c8),
432 [c9]"+v" (v_c9),
433 [c10]"+v"(v_c10),
434 [c11]"+v"(v_c11),
435 [c12]"+v"(v_c12),
436 [c13]"+v"(v_c13),
437 [c14]"+v"(v_c14),
438 [c15]"+v"(v_c15),
439 [c16]"+v"(v_c16),
440 [c17]"+v"(v_c17),
441 [c18]"+v"(v_c18),
442 [c19]"+v"(v_c19),
443 [c20]"+v"(v_c20),
444 [c21]"+v"(v_c21),
445 [c22]"+v"(v_c22),
446 [c23]"+v"(v_c23),
447 [c24]"+v"(v_c24),
448 [c25]"+v"(v_c25),
449 [c26]"+v"(v_c26),
450 [c27]"+v"(v_c27),
451 [c28]"+v"(v_c28),
452 [c29]"+v"(v_c29),
453 [c30]"+v"(v_c30),
454 [c31]"+v"(v_c31)
455 :
456 [sld_a_base]"n"(0),
457 [shfl_base]"n"(0),
458 [v_sld_y_os]"v"(sld_y_os),
459 [v_sfl_sld]"v"(sfl_sld),
460 [v_sfl_sst]"v"(sfl_sst),
461 [s_res_o0]"s"(res_o[0]),
462 [s_res_o1]"s"(res_o[1]),
463 //[s_res_o2]"s"(res_o[2]),
464 //[s_res_o3]"s"(res_o[3]),
465 [s_res_b0]"s"(res_b[0]),
466 [s_res_b1]"s"(res_b[1]),
467 [s_res_b2]"s"(res_b[2]),
468 [s_res_b3]"s"(res_b[3]),
469 [v_os_o0]"v"(static_cast<index_t>(cached_coords_o[number<0>{}] * sizeof(ODataType))),
470 [v_os_o1]"v"(static_cast<index_t>(cached_coords_o[number<1>{}] * sizeof(ODataType))),
471 [v_os_o2]"v"(static_cast<index_t>(cached_coords_o[number<2>{}] * sizeof(ODataType))),
472 [v_os_o3]"v"(static_cast<index_t>(cached_coords_o[number<3>{}] * sizeof(ODataType))),
473 [v_os_o4]"v"(static_cast<index_t>(cached_coords_o[number<4>{}] * sizeof(ODataType))),
474 [v_os_o5]"v"(static_cast<index_t>(cached_coords_o[number<5>{}] * sizeof(ODataType))),
475 [v_os_o6]"v"(static_cast<index_t>(cached_coords_o[number<6>{}] * sizeof(ODataType))),
476 [v_os_o7]"v"(static_cast<index_t>(cached_coords_o[number<7>{}] * sizeof(ODataType))),
477 [v_os_b0]"v"(static_cast<index_t>(cached_coords_b[number<0>{}] * sizeof(BDataType))),
478 [v_os_b1]"v"(static_cast<index_t>(cached_coords_b[number<1>{}] * sizeof(BDataType))),
479 [v_os_b2]"v"(static_cast<index_t>(cached_coords_b[number<2>{}] * sizeof(BDataType))),
480 [v_os_b3]"v"(static_cast<index_t>(cached_coords_b[number<3>{}] * sizeof(BDataType))),
481 [v_os_b4]"v"(static_cast<index_t>(cached_coords_b[number<4>{}] * sizeof(BDataType))),
482 [v_os_b5]"v"(static_cast<index_t>(cached_coords_b[number<5>{}] * sizeof(BDataType))),
483 [v_os_b6]"v"(static_cast<index_t>(cached_coords_b[number<6>{}] * sizeof(BDataType))),
484 [v_os_b7]"v"(static_cast<index_t>(cached_coords_b[number<7>{}] * sizeof(BDataType))),
485
486 [s_tile_os_o]"s"(tile_stride_o_bytes),
487 [s_tile_os_b]"s"(tile_stride_b_bytes),
488 [scale_0]"v"(s0),
489 [scale_1]"v"(s1),
490 [v_nan_lo]"v"(nan_lo),
491 [v_nan_hi]"v"(nan_hi),
492 [s_execflag_0]"s"(o_flags[number<0>{}]),
493 [s_execflag_1]"s"(o_flags[number<1>{}]),
494 [s_execflag_2]"s"(o_flags[number<2>{}]),
495 [s_execflag_3]"s"(o_flags[number<3>{}]),
496 [s_execflag_4]"s"(o_flags[number<4>{}]),
497 [s_execflag_5]"s"(o_flags[number<5>{}]),
498 [s_execflag_6]"s"(o_flags[number<6>{}]),
499 [s_execflag_7]"s"(o_flags[number<7>{}])
500 :
501 "memory", "a0", "a1", "a2", "a3", "a4", "a5", "a6", "a7", "a8", "a9",
502 "a10", "a11", "a12", "a13", "a14", "a15", "a16", "a17", "a18", "a19",
503 "a20", "a21", "a22", "a23", "a24", "a25", "a26", "a27", "a28", "a29",
504 "a30", "a31", "a32", "a33", "a34", "a35", "a36", "a37", "a38", "a39",
505 "a40", "a41", "a42", "a43", "a44", "a45", "a46", "a47", "a48", "a49",
506 "a50", "a51", "a52", "a53", "a54", "a55", "a56", "a57", "a58", "a59",
507 "a60", "a61", "a62", "a63", "a64", "a65", "a66", "a67", "a68", "a69",
508 "a70", "a71", "a72", "a73", "a74", "a75", "a76", "a77", "a78", "a79",
509 "a80", "a81", "a82", "a83", "a84", "a85", "a86", "a87", "a88", "a89",
510 "a90", "a91", "a92", "a93", "a94", "a95", "a96", "a97", "a98", "a99",
511 "a100", "a101", "a102", "a103", "a104", "a105", "a106", "a107",
512 "a108", "a109", "a110", "a111", "a112", "a113", "a114", "a115",
513 "a116", "a117", "a118", "a119", "a120", "a121", "a122", "a123",
514 "a124", "a125", "a126", "a127", "a128", "a129", "a130", "a131",
515 "a132", "a133", "a134", "a135", "a136", "a137", "a138", "a139",
516 "a140", "a141", "a142", "a143", "a144", "a145", "a146", "a147",
517 "a148", "a149", "a150", "a151", "a152", "a153", "a154", "a155",
518 "a156", "a157", "a158", "a159", "a160", "a161", "a162", "a163",
519 "a164", "a165", "a166", "a167", "a168", "a169", "a170", "a171",
520 "a172", "a173", "a174", "a175", "a176", "a177", "a178", "a179",
521 "a180", "a181", "a182", "a183", "a184", "a185", "a186", "a187",
522 "a188", "a189", "a190", "a191", "a192", "a193", "a194", "a195",
523 "a196", "a197", "a198", "a199", "a200", "a201", "a202", "a203",
524 "a204", "a205", "a206", "a207", "a208", "a209", "a210", "a211",
525 "a212", "a213", "a214", "a215", "a216", "a217", "a218", "a219",
526 "a220", "a221", "a222", "a223", "a224", "a225", "a226", "a227",
527 "a228", "a229", "a230", "a231", "a232", "a233", "a234", "a235",
528 "a236", "a237", "a238", "a239", "a240", "a241", "a242", "a243",
529 "a244", "a245", "a246", "a247", "a248", "a249", "a250", "a251",
530 "a252", "a253", "a254", "a255",
531 "s8", "s9", "s12", "s13", "s14", "s15", "s38", "s39", "s52", "s86",
532 "s36", "s37",
533 "v50", "v54", "v55",
534 "v64","v65","v66","v67","v68","v69","v70","v71",
535 "v72","v73","v74","v75","v76","v77","v78","v79",
536 "v80","v81","v82","v83","v84","v85","v86","v87",
537 "v88","v89","v90","v91","v92","v93","v94","v95",
538 "v128", "v129", "v130", "v131",
539 "v132", "v133", "v134", "v135", "v136", "v137", "v138", "v139",
540 "v140", "v141", "v142", "v143", "v144", "v145", "v146", "v147",
541 "v148", "v149", "v150", "v151", "v152", "v153", "v154", "v155",
542 "v156", "v157", "v158", "v159", "v160", "v161", "v162", "v163",
543 "v164", "v165", "v166", "v167", "v168", "v169", "v170", "v171",
544 "v172", "v173", "v174", "v175", "v176", "v177", "v178", "v179",
545 "v180", "v181", "v182", "v183", "v184", "v185", "v186", "v187",
546 "v188", "v189", "v190", "v191", "v192", "v193", "v194", "v195",
547 "v196", "v197", "v198", "v199", "v200", "v201", "v202", "v203",
548 "v204", "v205", "v206", "v207", "v208", "v209", "v210", "v211",
549 "v212", "v213", "v214", "v215", "v216", "v217", "v218", "v219",
550 "v220", "v221", "v222", "v223", "v224", "v225", "v226", "v227",
551 "v228", "v229", "v230", "v231", "v232", "v233", "v234", "v235",
552 "v236", "v237", "v238", "v239", "v240", "v241", "v242", "v243",
553 "v244", "v245", "v246", "v247", "v248", "v249", "v250", "v251",
554 "v252", "v253", "v254", "v255"
555 );
556#pragma clang diagnostic pop
557 // clang-format on
558 }
559};
560
561} // namespace ck_tile
#define CK_TILE_DEVICE
Definition config.hpp:41
#define CK_TILE_LDS_ADDR
Definition config.hpp:58
#define CK_TILE_HOST_DEVICE
Definition config.hpp:42
CK_TILE_HOST_DEVICE constexpr auto make_embed_tile_distribution_encoding(OuterDstr, InnerDstr)
Definition tile_distribution_encoding.hpp:457
Definition tile/core/algorithm/cluster_descriptor.hpp:13
bfloat16_t bf16_t
Definition bfloat16.hpp:113
constant< v > number
Definition tile/core/numeric/integral_constant.hpp:37
int32_t int32_t
Definition integer.hpp:10
WarpGemmImpl< WarpGemmAttributeMfmaIterateKAndTransposedCDistribution< WarpGemmAttributeMfmaImplF16F16F32M16N16K16< WGAttrCtlEnum::Default_ >, 2, AttrNumAccess > > WarpGemmMfmaF16F16F32M16N16K32TransposedCDistribution
Definition warp_gemm.hpp:106
int32_t index_t
Definition integer.hpp:9
CK_TILE_HOST_DEVICE constexpr auto make_static_tile_distribution(StaticTileDistributionEncoding_)
Definition tile_distribution.hpp:480
Definition flatmm_sn_32x128x512_1x4x1_16x16x32.hpp:74
bf16_t BDataType
Definition flatmm_sn_32x128x512_1x4x1_16x16x32.hpp:75
CK_TILE_DEVICE auto operator()(const BRes &res_b, const BCoords &cached_coords_b, const ORes &res_o, const OCoords &cached_coords_o, const OFlags &o_flags, CK_TILE_LDS_ADDR void *smem, index_t n, const ScaleTensor &scale_, index_t tile_offset_b, index_t tile_offset_o)
Definition flatmm_sn_32x128x512_1x4x1_16x16x32.hpp:88
bf16_t ODataType
Definition flatmm_sn_32x128x512_1x4x1_16x16x32.hpp:76
Definition flatmm_sn_32x128x512_1x4x1_16x16x32.hpp:16
static constexpr index_t WarpPerBlock_M
Definition flatmm_sn_32x128x512_1x4x1_16x16x32.hpp:21
static constexpr index_t WarpPerBlock_N
Definition flatmm_sn_32x128x512_1x4x1_16x16x32.hpp:22
static constexpr index_t Block_N
Definition flatmm_sn_32x128x512_1x4x1_16x16x32.hpp:18
static constexpr index_t Warp_K
Definition flatmm_sn_32x128x512_1x4x1_16x16x32.hpp:27
static constexpr index_t Repeat_M
Definition flatmm_sn_32x128x512_1x4x1_16x16x32.hpp:38
static constexpr index_t Block_Nr
Definition flatmm_sn_32x128x512_1x4x1_16x16x32.hpp:35
static constexpr index_t WarpPerBlock_K
Definition flatmm_sn_32x128x512_1x4x1_16x16x32.hpp:23
static constexpr index_t BlockSize
Definition flatmm_sn_32x128x512_1x4x1_16x16x32.hpp:29
static CK_TILE_HOST_DEVICE constexpr ck_tile::index_t GetSmemSize()
Definition flatmm_sn_32x128x512_1x4x1_16x16x32.hpp:60
static constexpr index_t Block_K
Definition flatmm_sn_32x128x512_1x4x1_16x16x32.hpp:19
static constexpr index_t Block_W
Definition flatmm_sn_32x128x512_1x4x1_16x16x32.hpp:34
static constexpr index_t Repeat_K
Definition flatmm_sn_32x128x512_1x4x1_16x16x32.hpp:40
static constexpr index_t Block_Kr
Definition flatmm_sn_32x128x512_1x4x1_16x16x32.hpp:36
static constexpr index_t Warp_N
Definition flatmm_sn_32x128x512_1x4x1_16x16x32.hpp:26
static CK_TILE_DEVICE constexpr auto MakeCBlockDist()
Definition flatmm_sn_32x128x512_1x4x1_16x16x32.hpp:42
static constexpr index_t Warp_M
Definition flatmm_sn_32x128x512_1x4x1_16x16x32.hpp:25
static constexpr index_t Repeat_N
Definition flatmm_sn_32x128x512_1x4x1_16x16x32.hpp:39
static constexpr index_t Block_M
Definition flatmm_sn_32x128x512_1x4x1_16x16x32.hpp:17
Definition flatmm_sn_32x128x512_1x4x1_16x16x32.hpp:318
CK_TILE_DEVICE auto operator()(const BRes &res_b, const BCoords &cached_coords_b, const ORes &res_o, const OCoords &cached_coords_o, const OFlags &o_flags, CK_TILE_LDS_ADDR void *smem, index_t n, const ScaleTensor &scale_, index_t tile_offset_b, index_t tile_offset_o)
Definition flatmm_sn_32x128x512_1x4x1_16x16x32.hpp:332
bf16_t BDataType
Definition flatmm_sn_32x128x512_1x4x1_16x16x32.hpp:319
bf16_t ODataType
Definition flatmm_sn_32x128x512_1x4x1_16x16x32.hpp:320
Definition tile_distribution_encoding.hpp:26
Definition tile/core/container/tuple.hpp:192