blockwise_gemm_pipeline_xdlops_moe_blockscale_b_preshuffle_gufusion_v1.hpp Source File

blockwise_gemm_pipeline_xdlops_moe_blockscale_b_preshuffle_gufusion_v1.hpp Source File

Composable Kernel: blockwise_gemm_pipeline_xdlops_moe_blockscale_b_preshuffle_gufusion_v1.hpp Source File
blockwise_gemm_pipeline_xdlops_moe_blockscale_b_preshuffle_gufusion_v1.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
7
8namespace ck {
9
10// Compute optimized pipeline
11// GlobalPrefetchStages: 2
12// LocalPreFillStages: 1
13// LocalPreFetchStages: 1
14// LocalSharedMemoryBuffer: 1
15
16template <BlockGemmPipelineScheduler BlkGemmPipelineVer,
17 index_t BlockSize,
18 typename ADataType,
19 typename BDataType,
20 typename ComputeDataType,
21 typename AccDataType,
22 typename ATileDesc,
23 typename BTileDesc,
24 typename AMmaTileDesc,
25 typename BMmaTileDesc,
26 index_t ABlockTransferSrcScalarPerVector,
27 index_t BBlockTransferSrcScalarPerVector,
28 index_t MPerBlock,
29 index_t NPerBlock,
30 index_t KPerBlock,
31 index_t MScaleBlock,
32 index_t NScaleBlock,
33 index_t KScaleBlock,
34 index_t MPerXDL,
35 index_t NPerXDL,
36 index_t MRepeat,
37 index_t NRepeat,
38 index_t KPacks>
42
43template <index_t BlockSize,
44 typename ADataType,
45 typename BDataType,
46 typename ComputeDataType,
47 typename AccDataType,
48 typename ATileDesc,
49 typename BTileDesc,
50 typename AMmaTileDesc,
51 typename BMmaTileDesc,
52 index_t ABlockTransferSrcScalarPerVector,
53 index_t BBlockTransferSrcScalarPerVector,
54 index_t MPerBlock,
55 index_t NPerBlock,
56 index_t KPerBlock,
57 index_t MScaleBlock,
58 index_t NScaleBlock,
59 index_t KScaleBlock,
60 index_t MPerXDL,
61 index_t NPerXDL,
62 index_t MRepeat,
63 index_t NRepeat,
64 index_t KPack
65 // ,bool TransposeC //disable transposec right now...
66 >
69 BlockSize,
70 ADataType,
71 BDataType,
72 ComputeDataType,
73 AccDataType,
74 ATileDesc,
75 BTileDesc,
76 AMmaTileDesc,
77 BMmaTileDesc,
78 ABlockTransferSrcScalarPerVector,
79 BBlockTransferSrcScalarPerVector,
80 MPerBlock,
81 NPerBlock,
82 KPerBlock,
83 MScaleBlock,
84 NScaleBlock,
85 KScaleBlock,
86 MPerXDL,
87 NPerXDL,
88 MRepeat,
89 NRepeat,
90 KPack> : BlockwiseGemmXdlops_pipeline_base<BlockSize,
91 ADataType,
92 BDataType,
93 ComputeDataType,
94 AccDataType,
95 ATileDesc,
96 BTileDesc,
97 AMmaTileDesc,
98 BMmaTileDesc,
99 ABlockTransferSrcScalarPerVector,
100 BBlockTransferSrcScalarPerVector,
101 MPerBlock,
102 NPerBlock,
103 KPerBlock,
104 MPerXDL,
105 NPerXDL,
106 MRepeat,
107 NRepeat,
108 KPack,
109 true>
110
111{
113 ADataType,
114 BDataType,
115 ComputeDataType,
116 AccDataType,
117 ATileDesc,
118 BTileDesc,
119 AMmaTileDesc,
120 BMmaTileDesc,
121 ABlockTransferSrcScalarPerVector,
122 BBlockTransferSrcScalarPerVector,
123 MPerBlock,
124 NPerBlock,
125 KPerBlock,
126 MPerXDL,
127 NPerXDL,
128 MRepeat,
129 NRepeat,
130 KPack,
131 true>;
132 using Base::A_K1;
133 using Base::B_K1;
134 using Base::I0;
135 using Base::I1;
136 using Base::KGroup;
137 using Base::KRepeat;
138 using Base::xdlops_gemm;
139 using typename Base::HotLoopInstList;
140
153
154 using Base::MWaves;
155 using Base::NWaves;
156 using Base::WaveSize;
157
158 static constexpr index_t PrefetchStages = 2;
159 static constexpr index_t PrefillStages = 1;
160 static constexpr index_t GlobalBufferNum = 2;
161
162 template <typename TileDesc_M0_M1_M2_K>
163 __host__ __device__ static constexpr auto MakeAGemmMmaTileDescriptor(const TileDesc_M0_M1_M2_K&)
164 {
165 constexpr index_t M0 = TileDesc_M0_M1_M2_K{}.GetLength(Number<0>{});
166 constexpr index_t M1 = TileDesc_M0_M1_M2_K{}.GetLength(Number<1>{});
167 constexpr index_t M2 = TileDesc_M0_M1_M2_K{}.GetLength(Number<2>{});
168 constexpr index_t K2 = KPack / KGroup;
169 constexpr index_t K1 = WaveSize / NPerXDL;
170 constexpr index_t K0 = KRepeat * KGroup;
171
173 TileDesc_M0_M1_M2_K{},
181 }
182
183 static constexpr auto a_block_desc_m0_m1_m2_k0_k1_k2 =
185
186 __host__ __device__ static constexpr bool BlockHasHotloop(index_t num_loop)
187 {
188 return num_loop > PrefetchStages;
189 }
190
191 __host__ __device__ static constexpr TailNumber BlockLoopTailNum(index_t num_loop)
192 {
193 return num_loop % 2 == 0 ? TailNumber::Even : TailNumber::Odd;
194 }
195
196 __device__ static constexpr auto HotLoopScheduler()
197 {
198 constexpr auto num_ds_read_inst_a = HotLoopInstList::A_LDS_Read_Inst_Num;
199 constexpr auto num_buffer_load_inst_a = HotLoopInstList::A_Buffer_Load_Inst_Num;
200 constexpr auto num_buffer_load_inst_b =
202 constexpr auto mfma_interleave = MPerXDL == 32 ? 1 : 2;
203 // B global
205 ignore = i;
206 if constexpr(MPerBlock >= 128 && NPerBlock >= 64)
207 {
208 __builtin_amdgcn_sched_group_barrier(0x008, 2 * mfma_interleave, 0);
209 }
210 else
211 {
212 __builtin_amdgcn_sched_group_barrier(0x008, mfma_interleave, 0);
213 }
214 __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
215 });
216
217 // A global
219 ignore = i;
220 __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
221 __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write
222 __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
223 __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
224 });
225
226 // A local
227 static_for<0, MPerXDL == 32 ? num_ds_read_inst_a / 2 : num_ds_read_inst_a, 1>{}(
228 [&](auto i) {
229 ignore = i;
230 __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
231 __builtin_amdgcn_sched_group_barrier(0x100, MPerXDL == 32 ? 2 : 1, 0); // DS read
232 });
233 }
234
235 template <bool HasMainLoop,
236 int NumKBlockPerScale,
237 TailNumber TailNum,
238 typename AGridDesc,
239 typename ABlockDesc,
240 typename ABlockTransfer,
241 typename AGridBuffer,
242 typename ABlockBuffer,
243 typename ABlockTransferStep,
244 typename BGridDesc,
245 typename BBlockDesc,
246 typename BBlockTransfer,
247 typename BGridBuffer,
248 typename BBlockBuffer,
249 typename BBlockTransferStep,
250 typename CScaleThreadDesc,
251 typename CThreadBuffer,
252 typename AScaleGridBuffer,
253 typename AScaleGridDesc,
254 typename AScaleThreadDesc,
255 typename AScaleThreadTransfer,
256 typename AScaleThreadTransferStep,
257 typename BScaleGridBuffer,
258 typename BScaleGridDesc,
259 typename BScaleThreadDesc,
260 typename BScaleThreadTransfer,
261 typename BScaleThreadTransferStep>
262 __device__ void Run(
263 // ABlockCopy
264 const AGridDesc& a_grid_desc,
265 const ABlockDesc& a_block_desc,
266 ABlockTransfer& a_blockwise_copy,
267 const AGridBuffer& a_grid_buf,
268 ABlockBuffer& a_block_buf,
269 const ABlockTransferStep& a_block_copy_step,
270 // BBlockCopy
271 const BGridDesc& b_grid_desc,
272 const BBlockDesc& b_block_desc,
273 BBlockTransfer& b_blockwise_copy,
274 BBlockTransfer& b_blockwise_copy_up,
275 const BGridBuffer& b_grid_buf,
276 const BGridBuffer& b_grid_buf_up,
277 BBlockBuffer& b_block_buf,
278 const BBlockTransferStep& b_block_copy_step,
279 // CThread
280 const CScaleThreadDesc& c_scale_thread_desc,
281 CThreadBuffer& c_thread_buf,
282 CThreadBuffer& c_thread_buf_up,
283 // AScaleThreadCopy
284 const AScaleGridDesc& a_scale_grid_desc,
285 const AScaleThreadDesc& a_scale_thread_desc,
286 AScaleThreadTransfer& a_scale_thread_copy,
287 const AScaleGridBuffer& a_scale_grid_buf,
288 const AScaleThreadTransferStep& a_scale_thread_copy_step,
289 // BScaleThreadCopy
290 const BScaleGridDesc& b_scale_grid_desc,
291 const BScaleThreadDesc& b_scale_thread_desc,
292 BScaleThreadTransfer& b_scale_thread_copy,
293 BScaleThreadTransfer& b_scale_thread_copy_up,
294 const BScaleGridBuffer& b_scale_grid_buf,
295 const BScaleGridBuffer& b_scale_grid_buf_up,
296 const BScaleThreadTransferStep& b_scale_thread_copy_step,
297 // num_loop
298 index_t num_loop) const
299 {
300 ignore = b_block_desc;
301 ignore = b_block_buf;
302 // __builtin_amdgcn_sched_barrier(0);
304 a_thread_desc_.GetElementSpaceSize());
306 b_thread_desc_.GetElementSpaceSize());
307
308 StaticallyIndexedArray<decltype(b_thread_buf), Number<2>{}> b_thread_bufs;
309 StaticallyIndexedArray<decltype(b_thread_buf), Number<2>{}> b_thread_bufs_up;
310 constexpr auto b_block_origin_idx = make_tuple(I0, I0, I0, I0);
311
313 a_scale_thread_desc.GetElementSpaceSize());
315 b_scale_thread_desc.GetElementSpaceSize());
317 b_scale_thread_desc.GetElementSpaceSize());
319 c_scale_thread_desc.GetElementSpaceSize());
321 c_scale_thread_desc.GetElementSpaceSize());
322
323 // Global prefetch A1 B1
324 a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, I0);
325 b_blockwise_copy.Run(b_grid_desc,
326 b_grid_buf,
328 b_block_origin_idx,
329 b_thread_bufs(I0));
330 b_blockwise_copy_up.Run(b_grid_desc,
331 b_grid_buf_up,
333 b_block_origin_idx,
334 b_thread_bufs_up(I0));
335
336 a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
337 b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
338 b_blockwise_copy_up.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
339
340 a_scale_thread_copy.Run(a_scale_grid_desc,
341 a_scale_grid_buf,
342 a_scale_thread_desc,
343 make_tuple(I0, I0),
344 a_scale_thread_buf);
345
346 if constexpr(NumKBlockPerScale == 1)
347 {
348 a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc,
349 a_scale_thread_copy_step.At(Number<1>{}));
350 }
351 else
352 {
353 a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc,
354 a_scale_thread_copy_step.At(Number<0>{}));
355 }
356
357 b_scale_thread_copy.Run(b_scale_grid_desc,
358 b_scale_grid_buf,
359 b_scale_thread_desc,
360 make_tuple(I0, I0),
361 b_scale_thread_buf);
362
363 b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc, b_scale_thread_copy_step);
364
365 b_scale_thread_copy_up.Run(b_scale_grid_desc,
366 b_scale_grid_buf_up,
367 b_scale_thread_desc,
368 make_tuple(I0, I0),
369 b_scale_thread_buf_up);
370
371 b_scale_thread_copy_up.MoveSrcSliceWindow(b_scale_grid_desc, b_scale_thread_copy_step);
372
373 // __builtin_amdgcn_sched_barrier(0);
374
375 constexpr auto num_scale_k_block = CScaleThreadDesc{}.GetLength(Number<0>{});
376 constexpr auto num_scale_m_block = CScaleThreadDesc{}.GetLength(Number<1>{});
377 constexpr auto num_scale_n_block = CScaleThreadDesc{}.GetLength(Number<2>{});
381 constexpr index_t c_offset =
382 CScaleThreadDesc{}.CalculateOffset(make_tuple(k0, m0, n0));
383 constexpr index_t a_offset =
384 AScaleThreadDesc{}.CalculateOffset(make_tuple(m0, k0));
385 constexpr index_t b_offset =
386 BScaleThreadDesc{}.CalculateOffset(make_tuple(n0, k0));
387
388 c_scale_thread_buf(Number<c_offset>{}) =
389 a_scale_thread_buf[Number<a_offset>{}] *
390 b_scale_thread_buf[Number<b_offset>{}];
391 c_scale_thread_buf_up(Number<c_offset>{}) =
392 a_scale_thread_buf[Number<a_offset>{}] *
393 b_scale_thread_buf_up[Number<b_offset>{}];
394 });
395 });
396 });
397
398 // Local prefill A1
399 a_blockwise_copy.RunWrite(a_block_desc, a_block_buf, I0);
400
401 // Global prefetch A2
402 a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, I0);
403 a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
404
405 a_scale_thread_copy.Run(a_scale_grid_desc,
406 a_scale_grid_buf,
407 a_scale_thread_desc,
408 make_tuple(I0, I0),
409 a_scale_thread_buf);
410
411 if constexpr(NumKBlockPerScale == 1)
412 {
413 a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc,
414 a_scale_thread_copy_step.At(Number<1>{}));
415 }
416 else
417 {
418 a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc,
419 a_scale_thread_copy_step.At(Number<0>{}));
420 }
421
422 b_scale_thread_copy.Run(b_scale_grid_desc,
423 b_scale_grid_buf,
424 b_scale_thread_desc,
425 make_tuple(I0, I0),
426 b_scale_thread_buf);
427
428 b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc, b_scale_thread_copy_step);
429
430 b_scale_thread_copy_up.Run(b_scale_grid_desc,
431 b_scale_grid_buf_up,
432 b_scale_thread_desc,
433 make_tuple(I0, I0),
434 b_scale_thread_buf_up);
435
436 b_scale_thread_copy_up.MoveSrcSliceWindow(b_scale_grid_desc, b_scale_thread_copy_step);
437
439 AccDataType,
440 1,
441 xdlops_gemm.GetRegSizePerXdlops(),
442 true>
443 c_thread_buf_per_scale;
445 AccDataType,
446 1,
447 xdlops_gemm.GetRegSizePerXdlops(),
448 true>
449 c_thread_buf_per_scale_up;
450
451 // Local prefetch A1
453 static_for<0, MRepeat, 1>{}([&](auto m0) {
454 static_for<0, KRepeat, 1>{}([&](auto k0) {
455 static_for<0, KGroup, 1>{}([&](auto kg0) {
456 a_thread_copy_.Run(
459 a_block_buf,
462 a_thread_buf);
463 });
464 });
465 });
466
467 // Initialize C
468 c_thread_buf.Clear();
469 c_thread_buf_up.Clear();
470
471 // __builtin_amdgcn_sched_barrier(0);
472
473 // main body
474 if constexpr(HasMainLoop)
475 {
476 index_t i = 0;
477 do
478 {
479 auto LoopFunc = [&](auto mfma_reg_buf, auto local_read_buf) {
480 b_blockwise_copy.Run(b_grid_desc,
481 b_grid_buf,
483 b_block_origin_idx,
484 b_thread_bufs(local_read_buf));
485 b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
486
487 b_blockwise_copy_up.Run(b_grid_desc,
488 b_grid_buf_up,
490 b_block_origin_idx,
491 b_thread_bufs_up(local_read_buf));
492 b_blockwise_copy_up.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
494 a_blockwise_copy.RunWrite(a_block_desc, a_block_buf, mfma_reg_buf);
495
496 a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, local_read_buf);
497 a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
498
499 static_for<0, MRepeat, 1>{}([&](auto m0) {
500 static_for<0, NRepeat, 1>{}([&](auto n0) {
501 static_for<0, num_scale_k_block, 1>{}([&](auto kscale0) {
502 static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) {
503 c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{})
504 .template AsType<AccDataType>()(Number<t>{}) = 0;
505 c_thread_buf_per_scale_up.GetVectorTypeReference(Number<0>{})
506 .template AsType<AccDataType>()(Number<t>{}) = 0;
507 });
508 vector_type<AccDataType, 2> c_scale_thread_vec;
509 vector_type<AccDataType, 2> c_scale_thread_vec_up;
510 constexpr index_t cscale_offset =
511 CScaleThreadDesc{}.CalculateOffset(
512 make_tuple(kscale0, m0, n0 * num_scale_n_block / NRepeat));
513
514 c_scale_thread_vec.template AsType<AccDataType>()(Number<0>{}) =
515 c_scale_thread_buf[Number<cscale_offset>{}];
516 c_scale_thread_vec.template AsType<AccDataType>()(Number<1>{}) =
517 c_scale_thread_buf[Number<cscale_offset>{}];
518 c_scale_thread_vec_up.template AsType<AccDataType>()(Number<0>{}) =
519 c_scale_thread_buf_up[Number<cscale_offset>{}];
520 c_scale_thread_vec_up.template AsType<AccDataType>()(Number<1>{}) =
521 c_scale_thread_buf_up[Number<cscale_offset>{}];
522
523 static_for<0, KRepeat / num_scale_k_block, 1>{}([&](auto k0) {
527
528 static_for<0, KPack, 1>{}([&](auto ik) {
529 a_thread_vec.template AsType<ComputeDataType>()(ik) =
530 a_thread_buf[Number<a_thread_desc_.CalculateOffset(
531 make_tuple(m0,
532 I0,
533 I0,
534 kscale0 * KRepeat / num_scale_k_block +
535 k0,
536 I0,
537 ik))>{}];
538 b_thread_vec.template AsType<ComputeDataType>()(ik) =
539 b_thread_bufs[mfma_reg_buf][Number<
540 b_thread_desc_.CalculateOffset(make_tuple(
541 n0,
542 I0,
543 kscale0 * KRepeat / num_scale_k_block + k0,
544 ik))>{}];
545 b_thread_vec_up.template AsType<ComputeDataType>()(ik) =
546 b_thread_bufs_up[mfma_reg_buf][Number<
547 b_thread_desc_.CalculateOffset(make_tuple(
548 n0,
549 I0,
550 kscale0 * KRepeat / num_scale_k_block + k0,
551 ik))>{}];
552 });
553
554 using mfma_input_type =
555 typename vector_type<ComputeDataType,
556 xdlops_gemm.K1PerXdlops>::type;
557
558 xdlops_gemm.template Run<>(
559 a_thread_vec.template AsType<mfma_input_type>(),
560 b_thread_vec.template AsType<mfma_input_type>(),
561 c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{}));
562 xdlops_gemm.template Run<>(
563 a_thread_vec.template AsType<mfma_input_type>(),
564 b_thread_vec_up.template AsType<mfma_input_type>(),
565 c_thread_buf_per_scale_up.GetVectorTypeReference(
566 Number<0>{}));
567 });
568
569 constexpr index_t c_offset =
570 c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
571
572 static_for<0, xdlops_gemm.GetRegSizePerXdlops() / 2, 1>{}(
573 [&](auto t) {
574 using pk_fma_type =
576
577 c_thread_buf.GetVectorTypeReference(Number<c_offset>{})
578 .template AsType<pk_fma_type>()(t) =
579 __builtin_elementwise_fma(
580 c_thread_buf_per_scale
581 .GetVectorTypeReference(Number<0>{})
582 .template AsType<pk_fma_type>()[t],
583 c_scale_thread_vec
584 .template AsType<pk_fma_type>()[Number<0>{}],
585 c_thread_buf
586 .GetVectorTypeReference(Number<c_offset>{})
587 .template AsType<pk_fma_type>()[t]);
588 c_thread_buf_up.GetVectorTypeReference(Number<c_offset>{})
589 .template AsType<pk_fma_type>()(t) =
590 __builtin_elementwise_fma(
591 c_thread_buf_per_scale_up
592 .GetVectorTypeReference(Number<0>{})
593 .template AsType<pk_fma_type>()[t],
594 c_scale_thread_vec_up
595 .template AsType<pk_fma_type>()[Number<0>{}],
596 c_thread_buf_up
597 .GetVectorTypeReference(Number<c_offset>{})
598 .template AsType<pk_fma_type>()[t]);
599 });
600 });
601 });
602 });
603
605
606 static_for<0, MRepeat, 1>{}([&](auto m0) {
607 static_for<0, KRepeat, 1>{}([&](auto k0) {
608 static_for<0, KGroup, 1>{}([&](auto kg0) {
609 a_thread_copy_.Run(
612 a_block_buf,
615 a_thread_buf);
616 });
617 });
618 });
619
621 __builtin_amdgcn_sched_barrier(0);
622
623 static_for<0, MRepeat, 1>{}([&](auto m0) {
626 constexpr index_t c_offset =
627 CScaleThreadDesc{}.CalculateOffset(make_tuple(k0, m0, n0));
628 constexpr index_t a_offset =
629 AScaleThreadDesc{}.CalculateOffset(make_tuple(m0, k0));
630 constexpr index_t b_offset =
631 BScaleThreadDesc{}.CalculateOffset(make_tuple(n0, k0));
632
633 c_scale_thread_buf(Number<c_offset>{}) =
634 a_scale_thread_buf[Number<a_offset>{}] *
635 b_scale_thread_buf[Number<b_offset>{}];
636 c_scale_thread_buf_up(Number<c_offset>{}) =
637 a_scale_thread_buf[Number<a_offset>{}] *
638 b_scale_thread_buf_up[Number<b_offset>{}];
639 });
640 });
641 });
642
643 a_scale_thread_copy.Run(a_scale_grid_desc,
644 a_scale_grid_buf,
645 a_scale_thread_desc,
646 make_tuple(I0, I0),
647 a_scale_thread_buf);
648
649 if constexpr(NumKBlockPerScale == 1)
650 {
651 a_scale_thread_copy.MoveSrcSliceWindow(
652 a_scale_grid_desc, a_scale_thread_copy_step.At(Number<1>{}));
653 }
654 else
655 {
656 a_scale_thread_copy.MoveSrcSliceWindow(
657 a_scale_grid_desc, a_scale_thread_copy_step.At(Number<0>{}));
658 }
659
660 b_scale_thread_copy.Run(b_scale_grid_desc,
661 b_scale_grid_buf,
662 b_scale_thread_desc,
663 make_tuple(I0, I0),
664 b_scale_thread_buf);
665
666 b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc,
667 b_scale_thread_copy_step);
668 b_scale_thread_copy_up.Run(b_scale_grid_desc,
669 b_scale_grid_buf_up,
670 b_scale_thread_desc,
671 make_tuple(I0, I0),
672 b_scale_thread_buf_up);
673
674 b_scale_thread_copy_up.MoveSrcSliceWindow(b_scale_grid_desc,
675 b_scale_thread_copy_step);
676 };
677
678 LoopFunc(I0, I1);
679 LoopFunc(I1, I0);
680
681 i += 2;
682 } while(i < (num_loop - 2));
683 }
684
685 // tail
686 if constexpr(TailNum == TailNumber::Even)
687 {
688 b_blockwise_copy.Run(b_grid_desc,
689 b_grid_buf,
691 b_block_origin_idx,
692 b_thread_bufs(I1));
693
694 b_blockwise_copy_up.Run(b_grid_desc,
695 b_grid_buf_up,
697 b_block_origin_idx,
698 b_thread_bufs_up(I1));
700 a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
701
702 static_for<0, MRepeat, 1>{}([&](auto m0) {
703 static_for<0, NRepeat, 1>{}([&](auto n0) {
704 static_for<0, num_scale_k_block, 1>{}([&](auto kscale0) {
705 static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) {
706 c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{})
707 .template AsType<AccDataType>()(Number<t>{}) = 0;
708 c_thread_buf_per_scale_up.GetVectorTypeReference(Number<0>{})
709 .template AsType<AccDataType>()(Number<t>{}) = 0;
710 });
711 vector_type<AccDataType, 2> c_scale_thread_vec;
712 vector_type<AccDataType, 2> c_scale_thread_vec_up;
713 constexpr index_t cscale_offset = CScaleThreadDesc{}.CalculateOffset(
714 make_tuple(kscale0, m0, n0 * num_scale_n_block / NRepeat));
715
716 c_scale_thread_vec.template AsType<AccDataType>()(Number<0>{}) =
717 c_scale_thread_buf[Number<cscale_offset>{}];
718 c_scale_thread_vec.template AsType<AccDataType>()(Number<1>{}) =
719 c_scale_thread_buf[Number<cscale_offset>{}];
720 c_scale_thread_vec_up.template AsType<AccDataType>()(Number<0>{}) =
721 c_scale_thread_buf_up[Number<cscale_offset>{}];
722 c_scale_thread_vec_up.template AsType<AccDataType>()(Number<1>{}) =
723 c_scale_thread_buf_up[Number<cscale_offset>{}];
724
725 static_for<0, KRepeat / num_scale_k_block, 1>{}([&](auto k0) {
729
730 static_for<0, KPack, 1>{}([&](auto ik) {
731 a_thread_vec.template AsType<ComputeDataType>()(ik) =
732 a_thread_buf[Number<a_thread_desc_.CalculateOffset(
733 make_tuple(m0,
734 I0,
735 I0,
736 kscale0 * KRepeat / num_scale_k_block + k0,
737 I0,
738 ik))>{}];
739 b_thread_vec.template AsType<ComputeDataType>()(ik) =
740 b_thread_bufs[I0][Number<b_thread_desc_.CalculateOffset(
741 make_tuple(n0,
742 I0,
743 kscale0 * KRepeat / num_scale_k_block + k0,
744 ik))>{}];
745 b_thread_vec_up.template AsType<ComputeDataType>()(ik) =
746 b_thread_bufs_up[I0][Number<b_thread_desc_.CalculateOffset(
747 make_tuple(n0,
748 I0,
749 kscale0 * KRepeat / num_scale_k_block + k0,
750 ik))>{}];
751 });
752
753 using mfma_input_type =
754 typename vector_type<ComputeDataType,
755 xdlops_gemm.K1PerXdlops>::type;
756
757 xdlops_gemm.template Run<>(
758 a_thread_vec.template AsType<mfma_input_type>(),
759 b_thread_vec.template AsType<mfma_input_type>(),
760 c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{}));
761 xdlops_gemm.template Run<>(
762 a_thread_vec.template AsType<mfma_input_type>(),
763 b_thread_vec_up.template AsType<mfma_input_type>(),
764 c_thread_buf_per_scale_up.GetVectorTypeReference(Number<0>{}));
765 });
766 constexpr index_t c_offset =
767 c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
768
769 static_for<0, xdlops_gemm.GetRegSizePerXdlops() / 2, 1>{}([&](auto t) {
770 using pk_fma_type = typename vector_type<AccDataType, 2>::type;
771
772 c_thread_buf.GetVectorTypeReference(Number<c_offset>{})
773 .template AsType<pk_fma_type>()(t) = __builtin_elementwise_fma(
774 c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{})
775 .template AsType<pk_fma_type>()[t],
776 c_scale_thread_vec.template AsType<pk_fma_type>()[Number<0>{}],
777 c_thread_buf.GetVectorTypeReference(Number<c_offset>{})
778 .template AsType<pk_fma_type>()[t]);
779 c_thread_buf_up.GetVectorTypeReference(Number<c_offset>{})
780 .template AsType<pk_fma_type>()(t) = __builtin_elementwise_fma(
781 c_thread_buf_per_scale_up.GetVectorTypeReference(Number<0>{})
782 .template AsType<pk_fma_type>()[t],
783 c_scale_thread_vec_up.template AsType<pk_fma_type>()[Number<0>{}],
784 c_thread_buf_up.GetVectorTypeReference(Number<c_offset>{})
785 .template AsType<pk_fma_type>()[t]);
786 });
787 });
788 });
789 });
790
791 static_for<0, MRepeat, 1>{}([&](auto m0) {
794 constexpr index_t c_offset =
795 CScaleThreadDesc{}.CalculateOffset(make_tuple(k0, m0, n0));
796 constexpr index_t a_offset =
797 AScaleThreadDesc{}.CalculateOffset(make_tuple(m0, k0));
798 constexpr index_t b_offset =
799 BScaleThreadDesc{}.CalculateOffset(make_tuple(n0, k0));
800
801 c_scale_thread_buf(Number<c_offset>{}) =
802 a_scale_thread_buf[Number<a_offset>{}] *
803 b_scale_thread_buf[Number<b_offset>{}];
804 c_scale_thread_buf_up(Number<c_offset>{}) =
805 a_scale_thread_buf[Number<a_offset>{}] *
806 b_scale_thread_buf_up[Number<b_offset>{}];
807 });
808 });
809 });
810
812
813 static_for<0, MRepeat, 1>{}([&](auto m0) {
814 static_for<0, KRepeat, 1>{}([&](auto k0) {
815 static_for<0, KGroup, 1>{}([&](auto kg0) {
816 a_thread_copy_.Run(
819 a_block_buf,
822 a_thread_buf);
823 });
824 });
825 });
826
827 // __builtin_amdgcn_sched_barrier(0);
828
829 static_for<0, MRepeat, 1>{}([&](auto m0) {
830 static_for<0, NRepeat, 1>{}([&](auto n0) {
831 static_for<0, num_scale_k_block, 1>{}([&](auto kscale0) {
832 static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) {
833 c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{})
834 .template AsType<AccDataType>()(Number<t>{}) = 0;
835 c_thread_buf_per_scale_up.GetVectorTypeReference(Number<0>{})
836 .template AsType<AccDataType>()(Number<t>{}) = 0;
837 });
838 vector_type<AccDataType, 2> c_scale_thread_vec;
839 vector_type<AccDataType, 2> c_scale_thread_vec_up;
840 constexpr index_t cscale_offset = CScaleThreadDesc{}.CalculateOffset(
841 make_tuple(kscale0, m0, n0 * num_scale_n_block / NRepeat));
842
843 c_scale_thread_vec.template AsType<AccDataType>()(Number<0>{}) =
844 c_scale_thread_buf[Number<cscale_offset>{}];
845 c_scale_thread_vec.template AsType<AccDataType>()(Number<1>{}) =
846 c_scale_thread_buf[Number<cscale_offset>{}];
847 c_scale_thread_vec_up.template AsType<AccDataType>()(Number<0>{}) =
848 c_scale_thread_buf_up[Number<cscale_offset>{}];
849 c_scale_thread_vec_up.template AsType<AccDataType>()(Number<1>{}) =
850 c_scale_thread_buf_up[Number<cscale_offset>{}];
851
852 static_for<0, KRepeat / num_scale_k_block, 1>{}([&](auto k0) {
856
857 static_for<0, KPack, 1>{}([&](auto ik) {
858 a_thread_vec.template AsType<ComputeDataType>()(ik) =
859 a_thread_buf[Number<a_thread_desc_.CalculateOffset(
860 make_tuple(m0,
861 I0,
862 I0,
863 kscale0 * KRepeat / num_scale_k_block + k0,
864 I0,
865 ik))>{}];
866 b_thread_vec.template AsType<ComputeDataType>()(ik) =
867 b_thread_bufs[I1][Number<b_thread_desc_.CalculateOffset(
868 make_tuple(n0,
869 I0,
870 kscale0 * KRepeat / num_scale_k_block + k0,
871 ik))>{}];
872 b_thread_vec_up.template AsType<ComputeDataType>()(ik) =
873 b_thread_bufs_up[I1][Number<b_thread_desc_.CalculateOffset(
874 make_tuple(n0,
875 I0,
876 kscale0 * KRepeat / num_scale_k_block + k0,
877 ik))>{}];
878 });
879
880 using mfma_input_type =
881 typename vector_type<ComputeDataType,
882 xdlops_gemm.K1PerXdlops>::type;
883
884 xdlops_gemm.template Run<>(
885 a_thread_vec.template AsType<mfma_input_type>(),
886 b_thread_vec.template AsType<mfma_input_type>(),
887 c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{}));
888 xdlops_gemm.template Run<>(
889 a_thread_vec.template AsType<mfma_input_type>(),
890 b_thread_vec_up.template AsType<mfma_input_type>(),
891 c_thread_buf_per_scale_up.GetVectorTypeReference(Number<0>{}));
892 });
893 constexpr index_t c_offset =
894 c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
895
896 static_for<0, xdlops_gemm.GetRegSizePerXdlops() / 2, 1>{}([&](auto t) {
897 using pk_fma_type = typename vector_type<AccDataType, 2>::type;
898
899 c_thread_buf.GetVectorTypeReference(Number<c_offset>{})
900 .template AsType<pk_fma_type>()(t) = __builtin_elementwise_fma(
901 c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{})
902 .template AsType<pk_fma_type>()[t],
903 c_scale_thread_vec.template AsType<pk_fma_type>()[Number<0>{}],
904 c_thread_buf.GetVectorTypeReference(Number<c_offset>{})
905 .template AsType<pk_fma_type>()[t]);
906 c_thread_buf_up.GetVectorTypeReference(Number<c_offset>{})
907 .template AsType<pk_fma_type>()(t) = __builtin_elementwise_fma(
908 c_thread_buf_per_scale_up.GetVectorTypeReference(Number<0>{})
909 .template AsType<pk_fma_type>()[t],
910 c_scale_thread_vec_up.template AsType<pk_fma_type>()[Number<0>{}],
911 c_thread_buf_up.GetVectorTypeReference(Number<c_offset>{})
912 .template AsType<pk_fma_type>()[t]);
913 });
914 });
915 });
916 });
917 }
918 else if constexpr(TailNum == TailNumber::Odd)
919 {
920 static_for<0, MRepeat, 1>{}([&](auto m0) {
921 static_for<0, NRepeat, 1>{}([&](auto n0) {
922 static_for<0, num_scale_k_block, 1>{}([&](auto kscale0) {
923 static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) {
924 c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{})
925 .template AsType<AccDataType>()(Number<t>{}) = 0;
926 c_thread_buf_per_scale_up.GetVectorTypeReference(Number<0>{})
927 .template AsType<AccDataType>()(Number<t>{}) = 0;
928 });
929 vector_type<AccDataType, 2> c_scale_thread_vec;
930 vector_type<AccDataType, 2> c_scale_thread_vec_up;
931 constexpr index_t cscale_offset = CScaleThreadDesc{}.CalculateOffset(
932 make_tuple(kscale0, m0, n0 * num_scale_n_block / NRepeat));
933
934 c_scale_thread_vec.template AsType<AccDataType>()(Number<0>{}) =
935 c_scale_thread_buf[Number<cscale_offset>{}];
936 c_scale_thread_vec.template AsType<AccDataType>()(Number<1>{}) =
937 c_scale_thread_buf[Number<cscale_offset>{}];
938 c_scale_thread_vec_up.template AsType<AccDataType>()(Number<0>{}) =
939 c_scale_thread_buf_up[Number<cscale_offset>{}];
940 c_scale_thread_vec_up.template AsType<AccDataType>()(Number<1>{}) =
941 c_scale_thread_buf_up[Number<cscale_offset>{}];
942
943 static_for<0, KRepeat / num_scale_k_block, 1>{}([&](auto k0) {
947
948 static_for<0, KPack, 1>{}([&](auto ik) {
949 a_thread_vec.template AsType<ComputeDataType>()(ik) =
950 a_thread_buf[Number<a_thread_desc_.CalculateOffset(
951 make_tuple(m0,
952 I0,
953 I0,
954 kscale0 * KRepeat / num_scale_k_block + k0,
955 I0,
956 ik))>{}];
957 b_thread_vec.template AsType<ComputeDataType>()(ik) =
958 b_thread_bufs[I0][Number<b_thread_desc_.CalculateOffset(
959 make_tuple(n0,
960 I0,
961 kscale0 * KRepeat / num_scale_k_block + k0,
962 ik))>{}];
963 b_thread_vec_up.template AsType<ComputeDataType>()(ik) =
964 b_thread_bufs_up[I0][Number<b_thread_desc_.CalculateOffset(
965 make_tuple(n0,
966 I0,
967 kscale0 * KRepeat / num_scale_k_block + k0,
968 ik))>{}];
969 });
970
971 using mfma_input_type =
972 typename vector_type<ComputeDataType,
973 xdlops_gemm.K1PerXdlops>::type;
974
975 xdlops_gemm.template Run<>(
976 a_thread_vec.template AsType<mfma_input_type>(),
977 b_thread_vec.template AsType<mfma_input_type>(),
978 c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{}));
979 xdlops_gemm.template Run<>(
980 a_thread_vec.template AsType<mfma_input_type>(),
981 b_thread_vec_up.template AsType<mfma_input_type>(),
982 c_thread_buf_per_scale_up.GetVectorTypeReference(Number<0>{}));
983 });
984 constexpr index_t c_offset =
985 c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
986
987 static_for<0, xdlops_gemm.GetRegSizePerXdlops() / 2, 1>{}([&](auto t) {
988 using pk_fma_type = typename vector_type<AccDataType, 2>::type;
989
990 c_thread_buf.GetVectorTypeReference(Number<c_offset>{})
991 .template AsType<pk_fma_type>()(t) = __builtin_elementwise_fma(
992 c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{})
993 .template AsType<pk_fma_type>()[t],
994 c_scale_thread_vec.template AsType<pk_fma_type>()[Number<0>{}],
995 c_thread_buf.GetVectorTypeReference(Number<c_offset>{})
996 .template AsType<pk_fma_type>()[t]);
997 c_thread_buf_up.GetVectorTypeReference(Number<c_offset>{})
998 .template AsType<pk_fma_type>()(t) = __builtin_elementwise_fma(
999 c_thread_buf_per_scale_up.GetVectorTypeReference(Number<0>{})
1000 .template AsType<pk_fma_type>()[t],
1001 c_scale_thread_vec_up.template AsType<pk_fma_type>()[Number<0>{}],
1002 c_thread_buf_up.GetVectorTypeReference(Number<c_offset>{})
1003 .template AsType<pk_fma_type>()[t]);
1004 });
1005 });
1006 });
1007 });
1008 }
1009 }
1010
1011 protected:
1012 // MRepeat MWave MLane KRepeat KLane KPack
1013 // KRepeat -> MRepeat-> Mwave->KLane->MLane->KPack
1016
1018 ComputeDataType,
1020 decltype(a_thread_desc_),
1021 Sequence<1, 1, 1, 1, 1, KPack / KGroup>,
1023 5,
1024 A_K1,
1025 A_K1>;
1026
1028
1031
1032 static constexpr BTileDesc b_block_desc_n0_n1_k0_k1;
1033
1035};
1036
1037} // namespace ck
Definition ck.hpp:268
__host__ __device__ constexpr auto make_static_buffer(Number< N >)
Definition static_buffer.hpp:186
__host__ __device__ constexpr auto make_pass_through_transform(const LowLength &low_length)
Definition multi_index_transform_helper.hpp:12
typename detail::StaticallyIndexedArrayImpl< T, N >::type StaticallyIndexedArray
Definition utility/statically_indexed_array.hpp:45
int32_t index_t
Definition ck.hpp:299
integral_constant< index_t, N > Number
Definition number.hpp:12
TailNumber
Definition blkgemmpipe_scheduler.hpp:31
@ Even
Definition blkgemmpipe_scheduler.hpp:34
@ Odd
Definition blkgemmpipe_scheduler.hpp:33
@ Vgpr
Definition amd_address_space.hpp:20
constexpr detail::ignore_t ignore
Definition utility/ignore.hpp:20
BlockGemmPipelineScheduler
Definition blkgemmpipe_scheduler.hpp:25
@ Intrawave
Definition blkgemmpipe_scheduler.hpp:26
__host__ __device__ constexpr auto make_naive_tensor_descriptor_packed(const Tuple< Lengths... > &lengths)
Definition tensor_descriptor_helper.hpp:101
__host__ __device__ constexpr auto make_tuple(Xs &&... xs)
Definition utility/tuple.hpp:211
__host__ __device__ constexpr auto transform_tensor_descriptor(const OldTensorDescriptor &old_tensor_desc, const NewTransforms &new_transforms, NewLowerDimensionOldVisibleIdss, NewUpperDimensionNewVisibleIdss)
Definition tensor_description/tensor_descriptor.hpp:319
__device__ void block_sync_lds()
Definition synchronization.hpp:16
__host__ __device__ constexpr auto make_unmerge_transform(const UpLengths &up_lengths, integral_constant< bool, Use24BitIntegerCalculation >=integral_constant< bool, false >{})
Definition multi_index_transform_helper.hpp:90
ck::BlockwiseGemmXdlops_pipeline_hotloop_inst< BlockSize, MPerBlock, NPerBlock, KPerBlock, ABlockTransferSrcScalarPerVector, BBlockTransferSrcScalarPerVector, A_K1, B_K1, A_K1, B_K1, MRepeat, NRepeat, MPerXDL, NPerXDL, xdlops_gemm.KPerXdlops > HotLoopInstList
Definition blockwise_gemm_pipeline_xdlops_base.hpp:82
__device__ void Run(const AGridDesc &a_grid_desc, const ABlockDesc &a_block_desc, ABlockTransfer &a_blockwise_copy, const AGridBuffer &a_grid_buf, ABlockBuffer &a_block_buf, const ABlockTransferStep &a_block_copy_step, const BGridDesc &b_grid_desc, const BBlockDesc &b_block_desc, BBlockTransfer &b_blockwise_copy, BBlockTransfer &b_blockwise_copy_up, const BGridBuffer &b_grid_buf, const BGridBuffer &b_grid_buf_up, BBlockBuffer &b_block_buf, const BBlockTransferStep &b_block_copy_step, const CScaleThreadDesc &c_scale_thread_desc, CThreadBuffer &c_thread_buf, CThreadBuffer &c_thread_buf_up, const AScaleGridDesc &a_scale_grid_desc, const AScaleThreadDesc &a_scale_thread_desc, AScaleThreadTransfer &a_scale_thread_copy, const AScaleGridBuffer &a_scale_grid_buf, const AScaleThreadTransferStep &a_scale_thread_copy_step, const BScaleGridDesc &b_scale_grid_desc, const BScaleThreadDesc &b_scale_thread_desc, BScaleThreadTransfer &b_scale_thread_copy, BScaleThreadTransfer &b_scale_thread_copy_up, const BScaleGridBuffer &b_scale_grid_buf, const BScaleGridBuffer &b_scale_grid_buf_up, const BScaleThreadTransferStep &b_scale_thread_copy_step, index_t num_loop) const
Definition blockwise_gemm_pipeline_xdlops_moe_blockscale_b_preshuffle_gufusion_v1.hpp:262
ThreadwiseTensorSliceTransfer_v4< ADataType, ComputeDataType, decltype(a_block_desc_m0_m1_m2_k0_k1_k2), decltype(a_thread_desc_), Sequence< 1, 1, 1, 1, 1, KPack/KGroup >, Sequence< 0, 1, 2, 3, 4, 5 >, 5, A_K1, A_K1 > AThreadCopy
Definition blockwise_gemm_pipeline_xdlops_moe_blockscale_b_preshuffle_gufusion_v1.hpp:1017
BlockwiseGemmXdlops_pipeline_base< BlockSize, ADataType, BDataType, ComputeDataType, AccDataType, ATileDesc, BTileDesc, AMmaTileDesc, BMmaTileDesc, ABlockTransferSrcScalarPerVector, BBlockTransferSrcScalarPerVector, MPerBlock, NPerBlock, KPerBlock, MPerXDL, NPerXDL, MRepeat, NRepeat, KPack, true > Base
Definition blockwise_gemm_pipeline_xdlops_moe_blockscale_b_preshuffle_gufusion_v1.hpp:112
Definition blockwise_gemm_pipeline_xdlops_moe_blockscale_b_preshuffle_gufusion_v1.hpp:40
Definition utility/sequence.hpp:43
Definition static_buffer.hpp:75
Definition threadwise_tensor_slice_transfer.hpp:1260
Definition functional2.hpp:33
Definition dtype_vector.hpp:10