gridwise_moe_gemm.hpp Source File

gridwise_moe_gemm.hpp Source File

Composable Kernel: gridwise_moe_gemm.hpp Source File
gridwise_moe_gemm.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
16
18
19#define DEBUG_LOG 0
20
21namespace ck {
22
23// Currently we do not have an elegant way to put the single-LDS-buffer and double-LDS-buffer pipelines in
24// the same kernel function. Blockers:
25// 1. Two separate __shared__ declarations are the key to making sure data accesses operate on
26// two distinct LDS chunks.
27// 2. Occupied __shared__ memory is not released until the whole shader ends, i.e. AB and C may not
28// share the same LDS buffer when we declare __shared__ inside blkgemmpipe.
29
35
// MoE GEMM kernel entry point using a single LDS (__shared__) buffer.
// The body is compiled only for gfx9 / gfx11 / gfx12 targets, and only when
// GridwiseGemm::IsValidCompilationParameter<CGlobalMemoryDataOperation>() holds; otherwise
// the kernel does nothing.
// NOTE(review): one template-parameter line is elided from this listing; the body uses a
// `TailNum` template argument that is presumably declared there — confirm against the header.
36template <typename GridwiseGemm,
37 bool HasMainKBlockLoop,
38 InMemoryDataOperationEnum CGlobalMemoryDataOperation,
39 index_t MinimumOccupancy = 1,
41__global__ void
42#if CK_USE_LAUNCH_BOUNDS
43__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
44#endif
45 // __attribute__((amdgpu_waves_per_eu(1, 1)))
46 kernel_moe_gemm(typename GridwiseGemm::Argument karg)
47{
48#if defined(__gfx9__) || defined(__gfx11__) || defined(__gfx12__)
49 if constexpr(GridwiseGemm::template IsValidCompilationParameter<CGlobalMemoryDataOperation>())
50 {
// Single workgroup-wide LDS buffer sized by the gridwise GEMM.
51 __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
52
// blockIdx.z selects the split-K batch. Constructing SplitKBatchOffset also overwrites
// karg.K (this kernel's by-value copy) with the K extent handled by that batch.
53 auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z);
54
// A and B are advanced to this batch's K slice; karg itself is passed as the Problem.
55 GridwiseGemm::template Run<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
56 karg.p_sorted_token_ids,
57 karg.p_sorted_expert_ids,
58 karg.p_max_token_id,
59 karg.p_a_grid + splitk_batch_offset.a_k_split_offset,
60 karg.p_b_grid + splitk_batch_offset.b_k_split_offset,
61 karg.p_ds_grid,
62 karg.p_c_grid,
63 p_shared,
64 karg,
65 karg.a_element_op,
66 karg.b_element_op,
67 karg.c_element_op);
68 }
69#else
// Unsupported target: mark karg as used; the kernel does nothing.
70 ignore = karg;
71#endif // end of if (defined(__gfx9__) || defined(__gfx11__) || defined(__gfx12__))
72}
73
// MoE GEMM kernel entry point using two separate LDS (__shared__) buffers, each sized
// GridwiseGemm::GetSharedMemoryNumberOfByte(), and dispatching to GridwiseGemm::Run_2Lds.
// The body is compiled only for gfx9 / gfx11 / gfx12 targets, and only when
// GridwiseGemm::IsValidCompilationParameter<CGlobalMemoryDataOperation>() holds.
// NOTE(review): one template-parameter line is elided from this listing; the body uses a
// `TailNum` template argument that is presumably declared there — confirm against the header.
74template <typename GridwiseGemm,
75 bool HasMainKBlockLoop,
76 InMemoryDataOperationEnum CGlobalMemoryDataOperation,
77 index_t MinimumOccupancy = 1,
79__global__ void
80#if CK_USE_LAUNCH_BOUNDS
81__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
82#endif
83 // __attribute__((amdgpu_waves_per_eu(1, 1)))
84 kernel_moe_gemm_2lds(typename GridwiseGemm::Argument karg)
85{
86#if defined(__gfx9__) || defined(__gfx11__) || defined(__gfx12__)
87 if constexpr(GridwiseGemm::template IsValidCompilationParameter<CGlobalMemoryDataOperation>())
88 {
// Two distinct __shared__ declarations, so the pipeline operates on two separate LDS chunks.
89 __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
90 __shared__ char p_shared1[GridwiseGemm::GetSharedMemoryNumberOfByte()];
91
// blockIdx.z selects the split-K batch; this also rewrites karg.K to that batch's K extent.
92 auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z);
93
94 GridwiseGemm::template Run_2Lds<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
95 karg.p_sorted_token_ids,
96 karg.p_sorted_expert_ids,
97 karg.p_max_token_id,
98 karg.p_a_grid + splitk_batch_offset.a_k_split_offset,
99 karg.p_b_grid + splitk_batch_offset.b_k_split_offset,
100 karg.p_ds_grid,
101 karg.p_c_grid,
102 p_shared,
103 p_shared1,
104 karg,
105 karg.a_element_op,
106 karg.b_element_op,
107 karg.c_element_op);
108 }
109#else
// Unsupported target: mark karg as used; the kernel does nothing.
110 ignore = karg;
111#endif // end of if (defined(__gfx9__) || defined(__gfx11__) || defined(__gfx12__))
112}
113
114template <typename ALayout,
115 typename BLayout,
116 typename DsLayout,
117 typename CLayout,
118 typename ADataType,
119 typename BDataType,
120 typename AccDataType,
121 typename CShuffleDataType,
122 typename DsDataType,
123 typename CDataType,
124 typename AElementwiseOperation,
125 typename BElementwiseOperation,
126 typename CElementwiseOperation,
128 index_t BlockSize,
129 index_t MPerBlock,
130 index_t NPerBlock,
131 index_t KPerBlock,
132 index_t AK1Value,
133 index_t BK1Value,
134 index_t MPerXdl,
135 index_t NPerXdl,
136 index_t MXdlPerWave,
137 index_t NXdlPerWave,
138 typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
139 typename ABlockTransferThreadClusterArrangeOrder,
140 typename ABlockTransferSrcAccessOrder,
141 index_t ABlockTransferSrcVectorDim,
142 index_t ABlockTransferSrcScalarPerVector,
143 index_t ABlockTransferDstScalarPerVector_AK1,
144 bool AThreadTransferSrcResetCoordinateAfterRun,
145 index_t ABlockLdsExtraM,
146 typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
147 typename BBlockTransferThreadClusterArrangeOrder,
148 typename BBlockTransferSrcAccessOrder,
149 index_t BBlockTransferSrcVectorDim,
150 index_t BBlockTransferSrcScalarPerVector,
151 index_t BBlockTransferDstScalarPerVector_BK1,
152 bool BThreadTransferSrcResetCoordinateAfterRun,
153 index_t BBlockLdsExtraN,
154 index_t CShuffleMXdlPerWavePerShuffle,
155 index_t CShuffleNXdlPerWavePerShuffle,
156 typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
157 typename CDEShuffleBlockTransferScalarPerVectors,
160 index_t ActivationOperation = 0,
161 bool NSwizzle = false,
162 bool IsInputGemm = true,
163 bool MulRoutedWeight = true,
164 bool PerTokenQuant = false,
165 typename IndexType = index_t,
166 typename ComputeTypeA = CDataType,
167 typename ComputeTypeB = ComputeTypeA,
168 typename LDSTypeA = ADataType,
169 typename LDSTypeB = BDataType>
171{
172 static constexpr auto I0 = Number<0>{};
173 static constexpr auto I1 = Number<1>{};
174 static constexpr auto I2 = Number<2>{};
175 static constexpr auto I3 = Number<3>{};
176 static constexpr auto I4 = Number<4>{};
177 static constexpr auto I5 = Number<5>{};
178 static constexpr auto I6 = Number<6>{};
179 static constexpr auto I7 = Number<7>{};
180
182 CDEShuffleBlockTransferScalarPerVectors{}[I0];
183 // K1 should be Number<...>
184 static constexpr auto AK0Number = Number<KPerBlock / AK1Value>{};
185 static constexpr auto BK0Number = Number<KPerBlock / BK1Value>{};
186 static constexpr auto AK1Number = Number<AK1Value>{};
187 static constexpr auto BK1Number = Number<BK1Value>{};
188 static constexpr auto BlockSizeNumber = Number<BlockSize>{};
189
190 static constexpr index_t NumDTensor = DsDataType::Size();
191
193 static constexpr index_t KPack =
195 static constexpr index_t KLane =
197
// Number of groups each KPack is split into for the preshuffled B layout. KRepeat and
// MakeBGridDescriptor_Preshuffled divide KPack by this value.
// On the branch guarded by the (elided) condition: 2 when the selected MFMA has
// k_per_blk == 32, otherwise 1. On the other branch: always 1.
// NOTE(review): the `if constexpr` condition (and the mfma_selector it introduces) sits on a
// line elided from this listing — presumably an f8 / gfx950 check per the comment below.
198 static constexpr index_t KGroup = []() {
200 // On gfx950, we have an mfma that requires 32 f8 elements as input,
201 // split into 2 groups of 16 f8 elements.
202 // The 2 groups are not contiguous in the B preshuffled layout,
203 // and we do not want them to be contiguous in the B preshuffled layout
204 // because a memory instruction can only read 16 f8 elements at a time.
205 return mfma_selector::selected_mfma.k_per_blk == 32 ? 2 : 1;
206 else
207 return 1;
208 }();
209
210 static constexpr index_t KRepeat = KPerBlock / KLane / (KPack / KGroup);
211
212 static constexpr index_t NLane = NPerXdl;
213 static constexpr index_t NWave = NPerBlock / NPerXdl / NXdlPerWave;
214 // static constexpr index_t NumTokens = 1;
215 static constexpr index_t SortedTileSize = MPerBlock;
216
// Builds a tuple of null pointers whose i-th element has type `const DDataType*`, where
// DDataType is the i-th entry of DsDataType. DsGridPointer below is declared as the type
// of this result.
// NOTE(review): the trailing generate_tuple argument (presumably Number<NumDTensor>{}) is on
// a line elided from this listing.
217 static constexpr auto MakeDsGridPointer()
218 {
219 return generate_tuple(
220 [&](auto i) {
221 using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
222
223 return static_cast<const DDataType*>(nullptr);
224 },
226 }
227
// Tuple type of typed D-tensor pointers (e.g. Run's p_ds_grid parameter).
228 using DsGridPointer = decltype(MakeDsGridPointer());
229
231
// Packing factors for A and B: 2 on the branch guarded by the elided condition, 1 otherwise.
// They divide element offsets and sizes (SplitKBatchOffset, GetSharedMemoryNumberOfByte,
// and the expert offset in Run).
// NOTE(review): the `if constexpr` condition lines are elided from this listing; presumably
// they test for a packed sub-byte type such as pk_i4_t (named in a static_assert in
// MakeBGridDescriptor_BK0_N_BK1) — confirm.
232 static constexpr index_t APackedSize = []() {
234 return 2;
235 else
236 return 1;
237 }();
238
239 static constexpr index_t BPackedSize = []() {
241 return 2;
242 else
243 return 1;
244 }();
245
246 __host__ static auto CalculateGridSize(index_t M, index_t N)
247 {
248 const index_t nblock = math::integer_divide_ceil(N, NPerBlock);
249 const index_t mblock = math::integer_divide_ceil(M, MPerBlock);
250 const index_t gridx = NSwizzle ? nblock * mblock : nblock;
251 const index_t gridy = NSwizzle ? 1 : mblock;
252
253 return std::make_tuple(gridx, gridy, 1);
254 }
255
256 __host__ __device__ static auto CalculateMPadded(index_t M)
257 {
258 return math::integer_least_multiple(M, MPerBlock);
259 }
260
261 __host__ __device__ static auto CalculateNPadded(index_t N)
262 {
263 return math::integer_least_multiple(N, NPerBlock);
264 }
265
// Extents of the preshuffled B layout. Run() passes them to
// MakeBGridDescriptor_Preshuffled(BN0Shuffled, BK0Shuffled).
// NOTE(review): both return expressions are on lines elided from this listing.
266 __host__ __device__ static auto CalculateBN0Shuffled(index_t N)
267 {
269 }
270 __host__ __device__ static auto CalculateBK0Shuffled(index_t K)
271 {
273 }
274
275 __host__ __device__ static auto CalculateKPadded(index_t K)
276 {
277 return math::integer_divide_ceil(K, KPerBlock) * KPerBlock;
278 }
279
280 __host__ __device__ static auto CalculateAK0Padded(index_t K, index_t K_Batch = 1)
281 {
282 auto K_t = K_Batch * KPerBlock;
283 return (K + K_t - 1) / K_t * (KPerBlock / AK1Value);
284 }
285
286 __host__ __device__ static auto CalculateBK0Padded(index_t K, index_t K_Batch = 1)
287 {
288 auto K_t = K_Batch * KPerBlock;
289 return (K + K_t - 1) / K_t * (KPerBlock / BK1Value);
290 }
291
292 __host__ __device__ static auto CalculateKPadded(index_t K, index_t K_Batch = 1)
293 {
294 auto K_t = K_Batch * KPerBlock;
295 return (K + K_t - 1) / K_t * KPerBlock;
296 }
297
298 __host__ __device__ static auto CalculateKRead(index_t K, index_t K_Batch = 1)
299 {
300 constexpr auto KReadVec = math::lcm(AK1Number, BK1Number);
301 auto K_t = K_Batch * KReadVec;
302 return (K + K_t - 1) / K_t * KReadVec;
303 }
304
305 __host__ __device__ static auto CalculateMBlock(index_t M)
306 {
307 return math::integer_divide_ceil(M, MPerBlock);
308 }
309
310 __host__ __device__ static auto CalculateNBlock(index_t N)
311 {
312 return math::integer_divide_ceil(N, NPerBlock);
313 }
314
// Builds the MMA tile descriptor from a (K0, MN, K1) block descriptor. It is parameterized by
// XDL tiles per wave, number of waves, and per-XDL extent along M or N. K0 and K1 are read
// from dimensions 0 and 2 of the input descriptor.
// NOTE(review): the transform arguments are on lines elided from this listing.
315 template <index_t MNXdlPerWave, index_t MNWaves, index_t MNPerXdl, typename TileDesc_K0_MN_K1>
316 __host__ __device__ static constexpr auto MakeGemmMmaTileDescriptor(const TileDesc_K0_MN_K1&)
317 {
318 constexpr index_t K0 = TileDesc_K0_MN_K1{}.GetLength(Number<0>{});
319 constexpr index_t K1 = TileDesc_K0_MN_K1{}.GetLength(Number<2>{});
320
322 TileDesc_K0_MN_K1{},
328 }
329
// Global-memory descriptor for A, viewed as (AK0, M, AK1).
// M and/or K are right-padded to MPad / KPad according to GemmSpec. Run() passes
// M = NumTokens (input GEMM) or NumTokens * TopK (otherwise).
// NOTE(review): the layout conditions and several transform-argument lines are elided from
// this listing.
330 __host__ __device__ static auto MakeAGridDescriptor_AK0_M_AK1(
331 IndexType M, IndexType MPad, IndexType K, IndexType KPad, IndexType StrideA, IndexType AK0)
332 {
// Raw (M, K) view: strides (StrideA, 1), i.e. K contiguous, on the first layout branch;
// strides (1, StrideA), i.e. M contiguous, on the second.
333 const auto a_grid_desc_mraw_kraw = [&]() {
335 {
336 return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(StrideA, I1));
337 }
339 {
340 return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(I1, StrideA));
341 }
342 }();
343
344 using GemmSpecialization = tensor_operation::device::GemmSpecialization;
345
346 if constexpr(GemmSpec == GemmSpecialization::MKPadding ||
347 GemmSpec == GemmSpecialization::MNKPadding)
348 {
349 // pad both M and K
350 const auto a_grid_desc_m_k =
351 transform_tensor_descriptor(a_grid_desc_mraw_kraw,
353 make_right_pad_transform(K, KPad - K)),
356
357 const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor(
358 a_grid_desc_m_k,
363
364 return a_grid_desc_ak0_m_ak1;
365 }
366 else if constexpr(GemmSpec == GemmSpecialization::MPadding ||
367 GemmSpec == GemmSpecialization::MNPadding)
368 {
369 // pad M, but not K
370 const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor(
371 a_grid_desc_mraw_kraw,
373 make_right_pad_transform(M, MPad - M)),
376
377 return a_grid_desc_ak0_m_ak1;
378 }
379 else if constexpr(GemmSpec == GemmSpecialization::KPadding ||
380 GemmSpec == GemmSpecialization::NKPadding)
381 {
382 // pad K, but not M
383 const auto a_grid_desc_m_k = transform_tensor_descriptor(
384 a_grid_desc_mraw_kraw,
388
389 const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor(
390 a_grid_desc_m_k,
395
396 return a_grid_desc_ak0_m_ak1;
397 }
398 else
399 {
400 // not pad M or K
401 const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor(
402 a_grid_desc_mraw_kraw,
407
408 return a_grid_desc_ak0_m_ak1;
409 }
410 }
411
// Global-memory descriptor for preshuffled B. Run() applies it to p_b_grid offset to a single
// expert.
// Lengths are (N0 / NWave, NWave, K0, NkSwizzleNumber) with packed strides (innermost dimension
// contiguous), where NkSwizzleNumber = WaveSize * KPack / KGroup elements.
// NOTE(review): the opening line of the returned expression (presumably
// make_naive_tensor_descriptor) is elided from this listing.
412 __host__ __device__ static auto MakeBGridDescriptor_Preshuffled(index_t N0, index_t K0)
413 {
414 constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
415 constexpr index_t WaveSize = BlockSize / (MWave * NWave);
416 constexpr index_t NkSwizzleNumber = Number<WaveSize * KPack / KGroup>{};
418 make_tuple(N0 / NWave, NWave, K0, NkSwizzleNumber),
419 make_tuple(NWave * K0 * NkSwizzleNumber, K0 * NkSwizzleNumber, NkSwizzleNumber, I1));
420 }
421
// Global-memory descriptor for (non-preshuffled) B, viewed as (BK0, N, BK1).
// N and/or K are right-padded to NPad / KPad according to GemmSpec.
// NOTE(review): the layout conditions, the head of the static_assert and several
// transform-argument lines are elided from this listing.
422 __host__ __device__ static auto MakeBGridDescriptor_BK0_N_BK1(
423 index_t K, index_t KPad, index_t N, index_t NPad, index_t StrideB, index_t BK0)
424 {
// Raw (N, K) view: strides (1, StrideB), i.e. N contiguous, on the first layout branch;
// strides (StrideB, 1), i.e. K contiguous, on the second.
425 const auto b_grid_desc_nraw_kraw = [&]() {
427 {
428 return make_naive_tensor_descriptor(make_tuple(N, K), make_tuple(I1, StrideB));
429 }
431 {
432 return make_naive_tensor_descriptor(make_tuple(N, K), make_tuple(StrideB, I1));
433 }
434 }();
435
436 using GemmSpecialization = tensor_operation::device::GemmSpecialization;
437
439 GemmSpec != GemmSpecialization::Default),
440 "pk_i4_t does not support padding");
441
442 if constexpr(GemmSpec == GemmSpecialization::NKPadding ||
443 GemmSpec == GemmSpecialization::MNKPadding)
444 {
445 // pad both N and K
446 const auto b_grid_desc_n_k =
447 transform_tensor_descriptor(b_grid_desc_nraw_kraw,
449 make_right_pad_transform(K, KPad - K)),
452
453 const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
454 b_grid_desc_n_k,
459
460 return b_grid_desc_bk0_n_bk1;
461 }
462 else if constexpr(GemmSpec == GemmSpecialization::NPadding ||
463 GemmSpec == GemmSpecialization::MNPadding)
464 {
465 // pad N, but not K
466 const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
467 b_grid_desc_nraw_kraw,
469 make_right_pad_transform(N, NPad - N)),
472
473 return b_grid_desc_bk0_n_bk1;
474 }
475 else if constexpr(GemmSpec == GemmSpecialization::KPadding ||
476 GemmSpec == GemmSpecialization::MKPadding)
477 {
478 // pad K, but not N
479 const auto b_grid_desc_n_k = transform_tensor_descriptor(
480 b_grid_desc_nraw_kraw,
484
485 const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
486 b_grid_desc_n_k,
491
492 return b_grid_desc_bk0_n_bk1;
493 }
494 else
495 {
496 // not pad N or K
497 const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
498 b_grid_desc_nraw_kraw,
503
504 return b_grid_desc_bk0_n_bk1;
505 }
506 }
507
508 template <typename ABlockDesc_AK0_M_AK1>
509 __host__ __device__ static constexpr auto
510 MakeAMmaTileDescriptor_M0_M1_M2_K(const ABlockDesc_AK0_M_AK1&)
511 {
512 constexpr index_t MWaves = MPerBlock / (MXdlPerWave * MPerXdl);
513
514 return MakeGemmMmaTileDescriptor<MXdlPerWave, MWaves, MPerXdl>(ABlockDesc_AK0_M_AK1{});
515 }
516
517 template <typename BBlockDesc_BK0_N_BK1>
518 __host__ __device__ static constexpr auto
519 MakeBMmaTileDescriptor_N0_N1_N2_K(const BBlockDesc_BK0_N_BK1&)
520 {
521 return MakeGemmMmaTileDescriptor<NXdlPerWave, NWave, NPerXdl>(BBlockDesc_BK0_N_BK1{});
522 }
523
// Global-memory (M, N) descriptor for the output, right-padded to (MPad, NPad).
// NOTE(review): the layout conditions and part of the padding transform are on lines elided
// from this listing.
524 template <typename ELayout>
525 __host__ __device__ static auto MakeCGridDescriptor_M_N(
526 IndexType M, IndexType MPad, IndexType N, IndexType NPad, IndexType StrideC)
527 {
// Strides (StrideC, 1), i.e. N contiguous, on the first layout branch; (1, StrideC),
// i.e. M contiguous, on the second.
528 const auto c_grid_desc_mraw_nraw = [&]() {
530 {
531 return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideC, I1));
532 }
534 {
535 return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I1, StrideC));
536 }
537 }();
538
539 // pad M and N
540 return transform_tensor_descriptor(c_grid_desc_mraw_nraw,
542 make_right_pad_transform(N, NPad - N)),
545 }
546
// (M, N) descriptor for one auxiliary D tensor, right-padded like the output.
// One of the two strides is 0, so the D values are broadcast along one dimension: along N on
// the first layout branch (strides (StrideC, 0)) and along M on the second (strides (0, StrideC)).
// NOTE(review): the signature line, the layout conditions and part of the padding transform
// are elided from this listing; StrideC here is the per-D stride passed by
// MakeDsGridDescriptor_M_N.
547 template <typename DLayout>
548 __host__ __device__ static auto
550 {
551 const auto c_grid_desc_mraw_nraw = [&]() {
553 {
554 return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideC, I0));
555 }
557 {
558 return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I0, StrideC));
559 }
560 }();
561
562 // pad M and N
563 return transform_tensor_descriptor(c_grid_desc_mraw_nraw,
565 make_right_pad_transform(N, NPad - N)),
568 }
569
// Tuple of (M, N) descriptors, one per D tensor. Each is built by MakeDGridDescriptor_M_N
// with that tensor's layout (from DsLayout) and stride (StrideDs[i]).
// NOTE(review): the trailing generate_tuple count argument is on a line elided from this listing.
570 __host__ __device__ static auto MakeDsGridDescriptor_M_N(
571 index_t M, index_t MPad, index_t N, index_t NPad, std::array<index_t, NumDTensor> StrideDs)
572 {
573 return generate_tuple(
574 [&](auto i) {
575 using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;
576 return MakeDGridDescriptor_M_N<DLayout>(M, MPad, N, NPad, StrideDs[i]);
577 },
579 }
580
// Re-tiles each D tensor's (M, N) descriptor in ds_grid_desc_m_n into MBlock x NBlock block
// tiles, returning a tuple with one result per D tensor.
// NOTE(review): the signature line, the per-element call and the count argument are elided
// from this listing; presumably each element goes through
// MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock — confirm.
581 template <typename DsGridDesc>
583 const DsGridDesc& ds_grid_desc_m_n, index_t MBlock, index_t NBlock)
584 {
585 return generate_tuple(
586 [&](auto i) {
588 ds_grid_desc_m_n[i], MBlock, NBlock);
589 },
591 }
592
// Runtime GEMM problem description shared by host and device code.
// NumTokens / TopK describe the MoE routing. Run() derives the A and C row counts from them
// (NumTokens vs. NumTokens * TopK, depending on IsInputGemm).
// KRead, KPadded, AK0 and BK0 are derived from K and KBatch in the constructor.
// NOTE(review): initializers on elided lines presumably set MPadded, NPadded, MBlock and NBlock
// (all printed by Print()); most member declarations are also elided from this listing.
593 struct Problem
594 {
595 __host__ __device__ Problem(index_t NumTokens_,
596 index_t TopK_,
597 index_t M_,
598 index_t N_,
599 index_t K_,
600 index_t StrideA_,
601 index_t StrideB_,
602 std::array<index_t, NumDTensor> StrideDs_,
603 index_t StrideC_,
604 index_t KBatch_)
605 : NumTokens{NumTokens_},
606 TopK{TopK_},
607 M{M_},
608 N{N_},
609 K{K_},
610 StrideA{StrideA_},
611 StrideB{StrideB_},
612 StrideDs{StrideDs_},
613 StrideC{StrideC_},
614 KBatch{KBatch_},
617 KRead{CalculateKRead(K_, KBatch_)},
618 KPadded{CalculateKPadded(K_, KBatch_)},
619 AK0{CalculateAK0Padded(K_, KBatch_)},
620 BK0{CalculateBK0Padded(K_, KBatch_)},
623 {
624 }
625
// Host-only debug dump of the problem sizes and derived quantities to std::cout.
626 __host__ void Print() const
627 {
628 std::cout << "problem {" << "NumTokens:" << NumTokens << ", " << "TopK:" << TopK << ", "
629 << "M:" << M << ", " << "N:" << N << ", " << "K:" << K << ", "
630 << "SA:" << StrideA << ", " << "SB:" << StrideB << ", " << "SC:" << StrideC
631 << ", " << "MP:" << MPadded << ", " << "NP:" << NPadded << ", "
632 << "KRead:" << KRead << ", " << "KP:" << KPadded << ", " << "AK0:" << AK0
633 << ", " << "BK0:" << BK0 << ", " << "MBlock: " << MBlock << ", "
634 << "NBlock: " << NBlock << "}" << std::endl;
635 }
636
644 std::array<index_t, NumDTensor> StrideDs;
655 };
656
657 // Argument: the Problem plus the device pointers and element-wise operators passed to the kernels.
// The kernels take Argument by value and pass it to Run() as its Problem.
// NOTE(review): the struct head and some member declarations (e.g. the token/expert id pointers
// and p_ds_grid) are on lines elided from this listing.
659 {
660 __host__ Argument(const index_t* p_sorted_token_ids_,
661 const index_t* p_sorted_expert_ids_,
662 const index_t* p_max_token_id_,
663 const ADataType* p_a_grid_,
664 const BDataType* p_b_grid_,
665 std::array<const void*, NumDTensor> p_ds_grid_,
666 CDataType* p_c_grid_,
667 index_t NumTokens_,
668 index_t TopK_,
669 index_t M_,
670 index_t N_,
671 index_t K_,
672 index_t StrideA_,
673 index_t StrideB_,
674 std::array<index_t, NumDTensor> StrideDs_,
675 index_t StrideC_,
676 index_t k_batch_,
677 AElementwiseOperation a_element_op_,
678 BElementwiseOperation b_element_op_,
679 CElementwiseOperation c_element_op_)
680 : Problem{NumTokens_,
681 TopK_,
682 M_,
683 N_,
684 K_,
685 StrideA_,
686 StrideB_,
687 StrideDs_,
688 StrideC_,
689 k_batch_},
690 p_sorted_token_ids{p_sorted_token_ids_},
691 p_sorted_expert_ids{p_sorted_expert_ids_},
692 p_max_token_id{p_max_token_id_},
693 p_a_grid{p_a_grid_},
694 p_b_grid{p_b_grid_},
695 p_ds_grid{},
696 p_c_grid{p_c_grid_},
697 a_element_op{a_element_op_},
698 b_element_op{b_element_op_},
699 c_element_op{c_element_op_}
700 {
701
702 // populate the typed D pointers (only pointers are set here; descriptors are built separately)
703 static_for<0, NumDTensor, 1>{}([&](auto i) {
704 using DDataType_ = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
705
706 // D pointer: cast the type-erased void* to the i-th DsDataType
707 p_ds_grid(i) = static_cast<const DDataType_*>(p_ds_grid_[i]);
708 });
709 }
710
714 const ADataType* p_a_grid;
715 const BDataType* p_b_grid;
717 CDataType* p_c_grid;
718
719 const AElementwiseOperation a_element_op;
720 const BElementwiseOperation b_element_op;
721 const CElementwiseOperation c_element_op;
722 };
723
725 {
// Per-split-K-batch element offsets for A and B. The kernels add a_k_split_offset /
// b_k_split_offset to p_a_grid / p_b_grid.
// NOTE(review): the struct head, the member declarations and the layout conditions choosing
// between the branches below are on lines elided from this listing.
726 __device__ SplitKBatchOffset(Argument& karg, index_t k_id)
727 {
// A: either an offset along unit-stride K (divided by APackedSize) or K scaled by StrideA.
729 {
730 a_k_split_offset = k_id * karg.KRead / APackedSize;
731 }
733 {
734 a_k_split_offset = k_id * karg.KRead * karg.StrideA;
735 }
736
// B: either K scaled by StrideB, or the preshuffled layout, where K advances in units of NLane
// (divided by BPackedSize).
738 {
739 b_k_split_offset = k_id * karg.KRead * karg.StrideB;
740 }
742 {
743 // KPack * NLane * KLane * K0 * N0
744 b_k_split_offset = k_id * karg.KRead * NLane / BPackedSize;
745 }
746
// Side effect on the caller's karg: K becomes the extent handled by batch k_id. Every batch but
// the last gets KRead; the last gets the remainder K - KRead * (KBatch - 1).
747 if(k_id < karg.KBatch - 1)
748 {
749 karg.K = karg.KRead;
750 }
751 else
752 {
753 karg.K = karg.K - karg.KRead * (karg.KBatch - 1);
754 }
755 }
756
759 };
760
// LDS descriptor for the A block tile, viewed as (AK0, MPerBlock, AK1).
// Three compile-time variants:
// - ABlockLdsExtraM set: body elided here.
// - The xor-permuted branch: condition elided; presumably row-major A.
// - ColumnMajor A: folds K (kfold) and pairs M rows (mpair) before permuting and unmerging
//   back to (AK0, M, AK1).
// NOTE(review): many descriptor-transform argument lines are elided from this listing.
761 __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()
762 {
763 constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
764 constexpr index_t WaveSize = BlockSize / (MWave * NWave);
765
766 // A matrix in LDS memory, dst of blockwise copy
767 if constexpr(ABlockLdsExtraM)
768 {
772 }
773 // The xor tensor transformation requests more VGPRs than necessary, which can cause register
774 // spills in some cases.
776 {
777 constexpr auto a_lds_block_desc =
780
781 constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor(
782 a_lds_block_desc,
788
789 return a_lds_block_desc_permuted;
790 }
791 else // ColumnMajor A
792 {
793 // kfold and mpair dimension is not always required.
794 // more dimension in merge_transform increase the difficulty of generating immarg offset
795 // for compiler.
796 constexpr auto M0 = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I1);
797 constexpr auto M1 = MPerBlock / M0;
798
799 constexpr auto KThreadWrite = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I0);
800 constexpr auto K0PerThreadWrite = AK0Number / KThreadWrite;
801 constexpr auto KThreadRead = WaveSize / MPerXdl;
802 constexpr auto K0PerThreadRead = AK0Number / KThreadRead;
803
// kfold: 1 if one AK1 x M0 row of LDSTypeA already exceeds 128 bytes, otherwise
// 128 bytes divided by that row size.
804 constexpr auto kfold = (AK1Number * M0 * sizeof(LDSTypeA) > 128)
805 ? 1
806 : 128 / (AK1Number * M0 * sizeof(LDSTypeA));
807 constexpr auto KThreadReadPerm =
808 (kfold * K0PerThreadWrite / K0PerThreadRead) > 1
809 ? KThreadRead / (kfold * K0PerThreadWrite / K0PerThreadRead)
810 : KThreadRead;
811
812 // 1 <= mpair <= M0 (128 bytes / (AK1 x MPerXdl row size), clamped to M0)
813 constexpr auto mpair = (AK1Number * MPerXdl * sizeof(LDSTypeA) > 128)
814 ? 1
815 : ((128 / (AK1Number * MPerXdl * sizeof(LDSTypeA))) > M0
816 ? M0
817 : 128 / (AK1Number * MPerXdl * sizeof(LDSTypeA)));
818
819 constexpr auto a_lds_block_desc = make_naive_tensor_descriptor_packed(
823 Number<kfold * M0 / mpair>{},
825 AK1Number));
826
827 constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor(
828 a_lds_block_desc,
833 make_tuple(Number<KThreadReadPerm * M1>{}, Number<kfold * M0 / mpair>{})),
840
841 constexpr auto a_lds_block_desc_unmerged = transform_tensor_descriptor(
842 a_lds_block_desc_permuted,
851 Sequence<1>{},
852 Sequence<2>{},
853 Sequence<3>{},
854 Sequence<4>{},
855 Sequence<5>{}),
857 Sequence<2>{},
860 Sequence<6>{},
861 Sequence<7>{}));
862
863 constexpr auto a_lds_block_desc_ak0_m_ak1 = transform_tensor_descriptor(
864 a_lds_block_desc_unmerged,
867 Number<KThreadWrite / kfold / KThreadReadPerm>{},
875
876 return a_lds_block_desc_ak0_m_ak1;
877 }
878 }
879
// B block descriptor, with dimension order given by the comment below. Run() labels this
// descriptor "dummy", and GetSharedMemoryNumberOfByte() reserves no LDS for a B tile.
// NOTE(review): the returned expression is on lines elided from this listing.
880 __device__ static constexpr auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1()
881 {
882 // K0 -> N0/NWave -> NWave -> KLane -> NLane -> KPack
885 }
886
888 {
// LDS descriptor for the C-shuffle stage, indexed (MBlock, MPerBlock, NBlock, NPerBlock)
// judging by the variable name. GetSharedMemoryNumberOfByte() sizes the C-shuffle LDS region
// from a descriptor of the same name.
// NOTE(review): the function signature and most length arguments are on lines elided from
// this listing.
889 constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
890
891 constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
895 I1,
897
898 return c_shuffle_block_desc_mblock_mperblock_nblock_nperblock;
899 }
900
903 BlkGemmPipelineVer,
904 BlkGemmPipeSched,
905 BlockSize,
906 ADataType,
907 BDataType,
908 ComputeTypeA,
909 AccDataType,
916 ABlockTransferSrcScalarPerVector,
917 BBlockTransferSrcScalarPerVector,
918 MPerBlock,
919 NPerBlock,
920 KPerBlock,
921 MPerXdl,
922 NPerXdl,
923 MXdlPerWave,
924 NXdlPerWave,
925 KPack,
926 IsInputGemm>())>;
927
// Bytes of LDS per __shared__ buffer (kernel_moe_gemm declares one buffer of this size,
// kernel_moe_gemm_2lds two).
// The result is the larger of the A tile and the C-shuffle tile, not their sum, so the two
// stages share the same region. No B tile is counted.
928 __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
929 {
930 // LDS allocation for A (no B tile is counted here): be careful of alignment
931 constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
932 // lds max alignment
933 constexpr auto max_lds_align = math::lcm(AK1Number, BK1Number);
934
935 constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
936 a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
937
938 // LDS allocation for C shuffle in LDS
939 constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
941
942 constexpr auto c_block_size =
943 c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize();
944
// A bytes are divided by APackedSize (packed element types); C uses CShuffleDataType.
945 return math::max(a_block_space_size_aligned * sizeof(LDSTypeA) / APackedSize,
946 c_block_size * sizeof(CShuffleDataType));
947 }
948
950
951 // Host-side check that the runtime sizes in karg are supported by this kernel instantiation.
// Returns false on the first failed check; when DEBUG_LOG is non-zero, most checks print the
// reason to std::cout first.
// NOTE(review): the guard conditions selecting each check below are on lines elided from this
// listing; presumably they depend on GemmSpec and the A/B/C layouts — confirm.
952 __host__ static constexpr bool CheckValidity(const Argument& karg)
953 {
954 static_assert((MPerBlock % (MPerXdl * MXdlPerWave) == 0) &&
955 (NPerBlock % (NXdlPerWave * NPerXdl)) == 0,
956 "Invalid tuning param!");
957
963 {
964 if(!(karg.M % MPerBlock == 0))
965 {
966#if DEBUG_LOG
967 std::cout << "Arg M value is not a multiple of MPerBlock! M: " << karg.M << " "
968 << __FILE__ << ":" << __LINE__ << ", in function: " << __func__
969 << std::endl;
970
971#endif // DEBUG_LOG
972 return false;
973 }
974 }
975
981 {
982 if(!(karg.N % NPerBlock == 0))
983 {
984#if DEBUG_LOG
985 std::cout << "Arg N value is not a multiple of NPerBlock! N: " << karg.N << " "
986 << __FILE__ << ":" << __LINE__ << ", in function: " << __func__
987 << std::endl;
988
989#endif // DEBUG_LOG
990 return false;
991 }
992 }
993
998 {
999
1000 auto K_t = karg.KBatch * KPerBlock;
1001 if(!(karg.K % K_t == 0))
1002 {
1003#if DEBUG_LOG
1004 std::cout << "Arg K value is not a multiple of K_Batch * K0PerBlock * K1! K: "
1005 << karg.K << " " << __FILE__ << ":" << __LINE__
1006 << ", in function: " << __func__ << std::endl;
1007
1008#endif // DEBUG_LOG
1009 return false;
1010 }
1011 }
1012 else
1013 {
// Reject split-K configurations where the first KBatch - 1 batches (KReadPadSplited each)
// already cover all of K, leaving the last batch no K elements (see SplitKBatchOffset).
// NOTE(review): unlike the other checks, this rejection emits no DEBUG_LOG diagnostic.
1014 constexpr auto KReadVec = math::lcm(AK1Number, BK1Number);
1015 auto K_t = karg.KBatch * KReadVec;
1016 auto KReadPadSplited = math::integer_divide_ceil(karg.K, K_t) * KReadVec;
1017 if((KReadPadSplited * (karg.KBatch - 1)) >= karg.K)
1018 {
1019 return false;
1020 }
1021 }
1022
// A vector-load width must divide the contiguous dimension (K on one branch, M on the other).
1024 {
1025 if(karg.K % ABlockTransferSrcScalarPerVector != 0)
1026 {
1027#if DEBUG_LOG
1028 std::cout << "Arg K (" << karg.K
1029 << ") value is not a multiple of ABlockTransferSrcScalarPerVector ("
1030 << ABlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
1031 << __LINE__ << ", in function: " << __func__ << std::endl;
1032
1033#endif // DEBUG_LOG
1034 return false;
1035 }
1036 }
1037 else
1038 {
1039 if(karg.M % ABlockTransferSrcScalarPerVector != 0)
1040 {
1041#if DEBUG_LOG
1042 std::cout << "Arg M (" << karg.M
1043 << ") value is not a multiple of ABlockTransferSrcScalarPerVector ("
1044 << ABlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
1045 << __LINE__ << ", in function: " << __func__ << std::endl;
1046
1047#endif // DEBUG_LOG
1048 return false;
1049 }
1050 }
1051
// B vector-load width must divide N on one branch and K on the other.
1053 {
1054 if(karg.N % BBlockTransferSrcScalarPerVector != 0)
1055 {
1056#if DEBUG_LOG
1057 std::cout << "Arg N (" << karg.N
1058 << ") value is not a multiple of BBlockTransferSrcScalarPerVector ("
1059 << BBlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
1060 << __LINE__ << ", in function: " << __func__ << std::endl;
1061
1062#endif // DEBUG_LOG
1063 return false;
1064 }
1065 }
1066 else
1067 {
1068 if(karg.K % BBlockTransferSrcScalarPerVector != 0)
1069 {
1070#if DEBUG_LOG
1071 std::cout << "Arg K (" << karg.K
1072 << ") value is not a multiple of BBlockTransferSrcScalarPerVector ("
1073 << BBlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
1074 << __LINE__ << ", in function: " << __func__ << std::endl;
1075
1076#endif // DEBUG_LOG
1077 return false;
1078 }
1079 }
1080
// C-shuffle store width must divide N on one branch and M on the other (the conditions are
// on elided lines).
1082 {
1084 {
1085#if DEBUG_LOG
1086 std::cout << "Arg N (" << karg.N
1087 << ") value is not a multiple of "
1088 "CShuffleBlockTransferScalarPerVector_NPerBlock ("
1089 << CShuffleBlockTransferScalarPerVector_NPerBlock << " )! " << __FILE__
1090 << ":" << __LINE__ << ", in function: " << __func__ << std::endl;
1091
1092#endif // DEBUG_LOG
1093 return false;
1094 }
1095 }
1096 else
1097 {
1099 {
1100#if DEBUG_LOG
1101 std::cout << "Arg M (" << karg.M
1102 << ") value is not a multiple of "
1103 "CShuffleBlockTransferScalarPerVector_NPerBlock ("
1104 << CShuffleBlockTransferScalarPerVector_NPerBlock << " )! " << __FILE__
1105 << ":" << __LINE__ << ", in function: " << __func__ << std::endl;
1106
1107#endif // DEBUG_LOG
1108 return false;
1109 }
1110 }
1111
1112 // check gridwise gemm pipeline
// (compiled out: the prefetch-depth check below is disabled via #if 0)
1113#if 0
1114 const auto num_k_loop = karg.AK0 / (KPerBlock / AK1Value);
1115
1116 if(num_k_loop <= BlockwiseGemmPipe::PrefetchStages)
1117 {
1118 return false;
1119 }
1120#endif
1121 // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
1122 return true;
1123 }
1124
1125 __host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t K)
1126 {
1127 const index_t num_loop = K / KPerBlock;
1128
1129 return BlockwiseGemmPipe::BlockHasHotloop(num_loop);
1130 }
1131
1132 __host__ __device__ static constexpr TailNumber CalculateKBlockLoopTailNum(index_t K)
1133 {
1134 const index_t num_loop = K / KPerBlock;
1135
1136 return BlockwiseGemmPipe::BlockLoopTailNum(num_loop);
1137 }
1138
// Re-tiles an (M, N) grid descriptor into MBlock x NBlock block tiles. Judging by the result's
// name, the result is indexed (MBlock, MPerBlock, NBlock, NPerBlock).
// NOTE(review): the signature line and the transform arguments are on lines elided from this
// listing.
1139 template <typename CGridDesc>
1141 const CGridDesc& c_grid_desc_m_n, index_t MBlock, index_t NBlock)
1142 {
1143 const auto c_grid_desc_mblock_mperblock_nblock_nperblock = transform_tensor_descriptor(
1144 c_grid_desc_m_n,
1149
1150 return c_grid_desc_mblock_mperblock_nblock_nperblock;
1151 }
1152
1153 // return block_id to C matrix tile idx (m0, n0) mapping
1154 // if arch = gfx942
1155 // using Block2CTileMapDefault = BlockToCTileMap_Grouped_M00_N0_M01Adapt<8, MPerBlock,
1156 // NPerBlock>;
1157
1158 template <bool HasMainKBlockLoop,
1159 InMemoryDataOperationEnum CGlobalMemoryDataOperation,
1160 TailNumber TailNum = TailNumber::Odd>
1161 __device__ static void Run(const index_t* p_sorted_token_ids,
1162 const index_t* p_sorted_expert_ids,
1163 const index_t* p_max_token_id,
1164 const ADataType* p_a_grid,
1165 const BDataType* p_b_grid,
1166 DsGridPointer& p_ds_grid,
1167 CDataType* p_c_grid,
1168 void* p_shared,
1169 const Problem& problem,
1170 AElementwiseOperation a_element_op,
1171 BElementwiseOperation b_element_op,
1172 CElementwiseOperation c_element_op)
1173 {
1174 ignore = b_element_op;
1175 index_t BN0Shuffled = CalculateBN0Shuffled(problem.N);
1176 index_t BK0Shuffled = CalculateBK0Shuffled(problem.K);
1177 const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1(
1178 IsInputGemm ? problem.NumTokens : problem.NumTokens * problem.TopK,
1179 problem.MPadded,
1180 problem.K,
1181 problem.KPadded,
1182 problem.StrideA,
1183 problem.AK0);
1184 const auto b_grid_desc_bpreshuffled =
1185 MakeBGridDescriptor_Preshuffled(BN0Shuffled, BK0Shuffled);
1186 const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N<CLayout>(
1187 IsInputGemm ? problem.NumTokens * problem.TopK : problem.NumTokens,
1188 problem.MPadded,
1189 problem.N,
1190 problem.NPadded,
1191 problem.StrideC);
1192 const auto c_grid_desc_mblock_mperblock_nblock_nperblock =
1194 c_grid_desc_m_n, problem.MBlock, problem.NBlock);
1195 const index_t max_token_id = __builtin_amdgcn_readfirstlane(p_max_token_id[0]);
1196 // static_assert(NSwizzle == false, "to do fix: need another pr in sorting merged");
1197 const index_t expert_block_id = NSwizzle ? blockIdx.x / problem.NBlock : blockIdx.y;
1198 if(expert_block_id * MPerBlock >= max_token_id)
1199 return;
1200 const index_t expert_id =
1201 __builtin_amdgcn_readfirstlane(p_sorted_expert_ids[expert_block_id]);
1202 const auto block_mn = [&]() -> std::pair<int, int> {
1203 if constexpr(NSwizzle)
1204 {
1205 const index_t ecnt_prefix = p_max_token_id[1 + expert_id];
1206 const index_t prefix_block = ecnt_prefix * problem.NBlock;
1207 const index_t ecnt = p_max_token_id[2 + expert_id] - ecnt_prefix;
1208 const index_t expert_swizzle =
1209 ecnt > 0 ? ecnt : 1; // p_max_token_id[expert_id + 1]; // 2
1210 const index_t bid_new = blockIdx.x - prefix_block;
1211 const index_t nid = __builtin_amdgcn_readfirstlane(
1212 bid_new % 8 + bid_new / (8 * expert_swizzle) * 8);
1213 const index_t mid =
1214 __builtin_amdgcn_readfirstlane(ecnt_prefix + bid_new / 8 % expert_swizzle);
1215 return {nid, mid};
1216 }
1217 else
1218 {
1219 return {blockIdx.x, blockIdx.y};
1220 }
1221 }();
1222
1223 const index_t block_n_id = block_mn.first;
1224 const index_t block_m_id = block_mn.second;
1225 const index_t token0 =
1226 __builtin_amdgcn_readfirstlane(p_sorted_token_ids[block_m_id * MPerBlock] & 0xffffff);
1227
1228 // constexpr auto M0 = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I1);
1229 constexpr auto AMThreads = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I1);
1230 constexpr auto AK0Threads = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I0);
1231 constexpr auto AK1Threads = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I2);
1232 constexpr auto AKThreads = AK0Threads * AK1Threads;
1233 constexpr auto AMRepeats = MPerBlock / AMThreads;
1234 const index_t token_pos = block_m_id * MPerBlock + threadIdx.x / AKThreads * AMRepeats;
1235
1236 if(token_pos >= max_token_id || token0 >= problem.NumTokens)
1237 return;
1239 static_for<0, AMRepeats, 1>{}([&](auto m0) {
1240 const index_t fused_token = p_sorted_token_ids[token_pos + m0];
1241 index_t token_offset = fused_token & 0xffffff;
1242 if constexpr(!IsInputGemm)
1243 {
1244 token_offset = token_offset * problem.TopK + (fused_token >> 24);
1245 }
1246 gather_offsets(m0) = static_cast<IndexType>(token_offset) * problem.K;
1247 });
1248 const IndexType expert_stride =
1249 __builtin_amdgcn_readfirstlane(problem.N * problem.K * (IsInputGemm ? 2 : 1));
1250 const IndexType expert_offset = expert_id * expert_stride / BPackedSize;
1251 // N0, K0, Blocksize*KPack
1252 const index_t n_block_data_idx_on_grid =
1253 __builtin_amdgcn_readfirstlane(block_n_id * NXdlPerWave);
1254
1255 const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
1256 p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
1257 const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
1258 p_b_grid + expert_offset, b_grid_desc_bpreshuffled.GetElementSpaceSize());
1259 // A matrix in LDS memory, dst of blockwise copy
1260 constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
1261
1262 // B matrix in LDS memory, dst of blockwise copy
1263 // dummy
1264 constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();
1265 // A matrix blockwise copy
1266 auto a_blockwise_copy = ThreadGroupTensorSliceTransfer_v4r1_gather<
1268 AElementwiseOperation,
1272 ABlockTransferThreadClusterLengths_AK0_M_AK1,
1273 ABlockTransferThreadClusterArrangeOrder,
1274 ADataType,
1275 LDSTypeA,
1276 decltype(a_grid_desc_ak0_m_ak1),
1277 decltype(a_block_desc_ak0_m_ak1),
1278 ABlockTransferSrcAccessOrder,
1280 ABlockTransferSrcVectorDim,
1281 2,
1282 ABlockTransferSrcScalarPerVector,
1283 ABlockTransferDstScalarPerVector_AK1,
1284 1,
1285 1,
1286 AThreadTransferSrcResetCoordinateAfterRun,
1287 true,
1288 IndexType,
1289 1,
1290 BlockwiseGemmPipe::GlobalBufferNum>(a_grid_desc_ak0_m_ak1,
1291 make_multi_index(0, 0, 0),
1292 a_element_op,
1293 a_block_desc_ak0_m_ak1,
1294 make_multi_index(0, 0, 0),
1296 gather_offsets);
1297
1298 // Thread-wise copy
1299 // K0 -> N0/NWave -> NWave -> KLane -> NLane -> KPack
1301 b_block_desc_bk0_n_bk1.GetElementSpaceSize());
1302
1303 auto b_blockwise_copy = ThreadwiseTensorSliceTransfer_v2<
1304 BDataType,
1305 BDataType,
1306 decltype(b_grid_desc_bpreshuffled),
1307 decltype(b_block_desc_bk0_n_bk1),
1310 3,
1311 BBlockTransferSrcScalarPerVector,
1312 BThreadTransferSrcResetCoordinateAfterRun,
1313 true>(b_grid_desc_bpreshuffled,
1314 make_multi_index(n_block_data_idx_on_grid,
1316 0,
1317 KPack / KGroup * (get_thread_local_1d_id() % WarpSize)));
1318
1319 // LDS allocation for A and B: be careful of alignment
1320 // Cast after lds
1322 static_cast<LDSTypeA*>(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
1323
1324 constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1Number, 0, 0);
1325 constexpr auto b_block_slice_copy_step = make_multi_index(0, 0, KRepeat, 0);
1326
1327 // Blockwise GEMM pipeline
1328 static_assert(std::is_default_constructible_v<BlockwiseGemmPipe>);
1329 auto blockwise_gemm_pipeline = BlockwiseGemmPipe{};
1330 auto c_thread_buf = blockwise_gemm_pipeline.GetCThreadBuffer();
1331 decltype(c_thread_buf) c_thread_buf_up;
1332
1334 float,
1335 c_thread_buf.num_of_v_,
1336 c_thread_buf.s_per_v,
1337 true>
1338 c_thread_buf_fp32;
1339
1340 const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(
1341 (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) /
1342 KPerBlock);
1343 if constexpr(IsInputGemm)
1344 {
1345 const BDataType* p_b_grid_up = p_b_grid + expert_stride / 2 / BPackedSize;
1346 const auto b_grid_buf_up = make_dynamic_buffer<AddressSpaceEnum::Global>(
1347 p_b_grid_up + expert_offset, b_grid_desc_bpreshuffled.GetElementSpaceSize());
1348 auto b_blockwise_copy_up = ThreadwiseTensorSliceTransfer_v2<
1349 BDataType,
1350 BDataType,
1351 decltype(b_grid_desc_bpreshuffled),
1352 decltype(b_block_desc_bk0_n_bk1),
1355 3,
1356 BBlockTransferSrcScalarPerVector,
1357 BThreadTransferSrcResetCoordinateAfterRun,
1358 true>(b_grid_desc_bpreshuffled,
1359 make_multi_index(n_block_data_idx_on_grid,
1361 0,
1362 KPack / KGroup * (get_thread_local_1d_id() % WarpSize)));
1363
1364 blockwise_gemm_pipeline.template Run<HasMainKBlockLoop, TailNum>(
1365 a_grid_desc_ak0_m_ak1,
1366 a_block_desc_ak0_m_ak1,
1367 a_blockwise_copy,
1368 a_grid_buf,
1369 a_block_buf,
1370 a_block_slice_copy_step,
1371 b_grid_desc_bpreshuffled,
1372 b_blockwise_copy,
1373 b_blockwise_copy_up,
1374 b_grid_buf,
1375 b_grid_buf_up,
1376 b_block_buf,
1377 b_block_slice_copy_step,
1378 c_thread_buf,
1379 c_thread_buf_up,
1380 num_k_block_main_loop);
1381 }
1382 else
1383 {
1384 blockwise_gemm_pipeline.template Run<HasMainKBlockLoop, TailNum>(
1385 a_grid_desc_ak0_m_ak1,
1386 a_block_desc_ak0_m_ak1,
1387 a_blockwise_copy,
1388 a_grid_buf,
1389 a_block_buf,
1390 a_block_slice_copy_step,
1391 b_grid_desc_bpreshuffled,
1392 b_blockwise_copy,
1393 b_grid_buf,
1394 b_block_buf,
1395 b_block_slice_copy_step,
1396 c_thread_buf,
1397 num_k_block_main_loop);
1398 }
1399
1400 // shuffle C and write out
1401 {
1402 static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 &&
1403 NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0,
1404 "wrong!");
1405
1406 constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
1407
1408 // TODO: hacky, fix it!
1409 constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
1410 blockwise_gemm_pipeline.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
1411
1412 // TODO: hacky, fix it!
1413 // c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths
1414 constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp =
1415 blockwise_gemm_pipeline.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
1416
1417 constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0);
1418 constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1);
1419 constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2);
1420 constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3);
1421 constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4);
1422 constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5);
1423 constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6);
1424 constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7);
1425
1426 // mul scales
1427 const float* p_sorted_weights_0 = p_ds_grid[I0];
1428 const float* p_scale_b = p_ds_grid[I1];
1429
1430 static_assert(M0 * M1 * M2 * M3 * M4 == MPerBlock);
1431 static_assert(M4 == 4 || M4 == 8);
1432 const index_t m1 = get_warp_local_1d_id() / NWave;
1433 const index_t m3 = threadIdx.x % get_warp_size() / MPerXdl;
1434
1435 if(p_sorted_weights_0 != nullptr && p_scale_b != nullptr)
1436 {
1437 if constexpr(PerTokenQuant)
1438 {
1439 constexpr index_t scale_stride = (IsInputGemm ? 2 : 1);
1440 p_scale_b += expert_id * problem.N * scale_stride + block_n_id * NPerBlock +
1441 get_warp_local_1d_id() % NWave * NPerXdl + threadIdx.x % NPerXdl;
1442 }
1443 else
1444 {
1445 p_scale_b += expert_id;
1446 }
1447
1448 vector_type<int32_t, M4> scale_token_ids;
1449 vector_type<float, M4> topk_weights;
1450 static_for<0, NXdlPerWave, 1>{}([&](auto n0) {
1451 const float scale_b = p_scale_b[n0 * NWave * NPerXdl * PerTokenQuant];
1452 static_for<0, MXdlPerWave, 1>{}([&](auto m0) { // MXDLPerWave
1453 static_for<0, M2, 1>{}([&](auto m2) { // m_inst_num_groups_per_blk
1454 const index_t m_pos = block_m_id * MPerBlock + m0 * M1 * M2 * M3 * M4 +
1455 m1 * M2 * M3 * M4 + m2 * M3 * M4 + m3 * M4;
1456 if constexpr(PerTokenQuant)
1457 {
1458 scale_token_ids =
1460 p_sorted_token_ids + m_pos);
1461 }
1462 if constexpr(MulRoutedWeight)
1463 {
1465 p_ds_grid[I2] + m_pos);
1466 }
1467 static_for<0, M4, 1>{}([&](auto m4) { // m_inst_group_size
1468 float scale_a = [&]() {
1469 if constexpr(PerTokenQuant)
1470 {
1471 index_t fused_token =
1472 scale_token_ids.template AsType<index_t>()[m4];
1473 const index_t token_offset = fused_token & 0xffffff;
1474 return token_offset < problem.NumTokens
1475 ? p_sorted_weights_0[IsInputGemm
1476 ? token_offset
1477 : token_offset *
1478 problem.TopK +
1479 (fused_token >>
1480 24)]
1481 : 0.0;
1482 }
1483 else
1484 {
1485 return p_sorted_weights_0[0];
1486 }
1487 }();
1488 constexpr index_t c_offset =
1489 blockwise_gemm_pipeline.GetCThreadDesc().CalculateOffset(
1490 make_tuple(m0, n0, m2 * M4 + m4));
1491 constexpr auto cidx = Number<c_offset>{};
1492 if constexpr(IsInputGemm) // gu fusion
1493 {
1494 if constexpr(ActivationOperation == Activation::silu_and_mul)
1495 {
1496 const float scale_up =
1497 p_scale_b[(n0 * NWave * NPerXdl + problem.N) *
1498 PerTokenQuant];
1499 float gate = scale_a * scale_b * c_thread_buf[cidx];
1500 float up = scale_a * scale_up * c_thread_buf_up[cidx];
1501 if constexpr(MulRoutedWeight)
1502 {
1503 gate = gate * topk_weights.template AsType<float>()[m4];
1504 up = up * topk_weights.template AsType<float>()[m4];
1505 }
1507 {
1508 gate *= 16;
1509 up *= 16;
1510 }
1512 c_thread_buf_fp32(cidx) = gate * up;
1513 }
1514 else if(ActivationOperation == Activation::gelu_and_mul)
1515 {
1516 const float scale_up =
1517 p_scale_b[(n0 * NWave * NPerXdl + problem.N) *
1518 PerTokenQuant];
1519 float gate = scale_a * scale_b * c_thread_buf[cidx];
1520 float up = scale_a * scale_up * c_thread_buf_up[cidx];
1521 if constexpr(MulRoutedWeight)
1522 {
1523 gate = gate * topk_weights.template AsType<float>()[m4];
1524 up = up * topk_weights.template AsType<float>()[m4];
1525 }
1527 {
1528 gate *= 16;
1529 up *= 16;
1530 }
1532 c_thread_buf_fp32(cidx) = gate * up;
1533 }
1534 }
1535 else
1536 {
1537 c_thread_buf_fp32(cidx) =
1538 scale_a * scale_b * c_thread_buf[cidx];
1539 if constexpr(MulRoutedWeight)
1540 {
1541 c_thread_buf_fp32(cidx) =
1542 c_thread_buf_fp32(cidx) *
1543 topk_weights.template AsType<float>()[m4];
1544 }
1545 }
1546 });
1547 });
1548 });
1549 });
1550 }
1551 else
1552 {
1553 vector_type<float, M4> topk_weights; // for gemm2 only
1554 static_for<0, NXdlPerWave, 1>{}([&](auto n0) {
1555 static_for<0, MXdlPerWave, 1>{}([&](auto m0) { // MXDLPerWave
1556 static_for<0, M2, 1>{}([&](auto m2) { // m_inst_num_groups_per_blk
1557 const index_t m_pos = block_m_id * MPerBlock + m0 * M1 * M2 * M3 * M4 +
1558 m1 * M2 * M3 * M4 + m2 * M3 * M4 + m3 * M4;
1559 if constexpr(MulRoutedWeight)
1560 {
1562 p_ds_grid[I2] + m_pos);
1563 }
1564 static_for<0, M4, 1>{}([&](auto m4) { // m_inst_group_size
1565 constexpr index_t c_offset =
1566 blockwise_gemm_pipeline.GetCThreadDesc().CalculateOffset(
1567 make_tuple(m0, n0, m2 * M4 + m4));
1568 constexpr auto cidx = Number<c_offset>{};
1569
1570 if constexpr(IsInputGemm) // gu fusion
1571 {
1572 if constexpr(ActivationOperation == Activation::silu_and_mul)
1573 {
1574 float gate = c_thread_buf[cidx];
1575 float up = c_thread_buf_up[cidx];
1576 if constexpr(MulRoutedWeight)
1577 {
1578 gate = gate * topk_weights.template AsType<float>()[m4];
1579 up = up * topk_weights.template AsType<float>()[m4];
1580 }
1582 c_thread_buf_fp32(cidx) = gate * up;
1583 }
1584 else if(ActivationOperation == Activation::gelu_and_mul)
1585 {
1586 float gate = c_thread_buf[cidx];
1587 float up = c_thread_buf_up[cidx];
1588 if constexpr(MulRoutedWeight)
1589 {
1590 gate = gate * topk_weights.template AsType<float>()[m4];
1591 up = up * topk_weights.template AsType<float>()[m4];
1592 }
1594 c_thread_buf_fp32(cidx) = gate * up;
1595 }
1596 }
1597 else
1598 {
1599 c_thread_buf_fp32(cidx) = c_thread_buf[cidx];
1600 if constexpr(MulRoutedWeight)
1601 {
1602 c_thread_buf_fp32(cidx) =
1603 topk_weights.template AsType<float>()[m4] *
1604 c_thread_buf_fp32[cidx];
1605 }
1606 }
1607 });
1608 });
1609 });
1610 });
1611 }
1612
1613 constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
1615
1616 auto c_shuffle_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
1617 static_cast<CShuffleDataType*>(p_shared),
1618 c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
1619
1620 constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor(
1621 c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
1622 make_tuple(
1625 Number<CShuffleMXdlPerWavePerShuffle>{}, // M0 (MXdlPerWave) per shuffle
1626 M1, // M1 = MWave
1627 M2, // M2 * M3 * M4 = MPerXdl
1628 M3,
1629 M4)),
1632 Number<CShuffleNXdlPerWavePerShuffle>{}, // N0 (NXdlPerWave) per shuffle
1633 N1, // N1 = NWave
1634 N2))), // N2 = NPerXdl
1636 make_tuple(
1638
1639 // calculate origin of thread output tensor on global memory
1640 // blockwise GEMM c matrix starting index
1641 const auto c_thread_mtx_on_block =
1642 blockwise_gemm_pipeline.CalculateCThreadOriginDataIndex(I0, I0, I0, I0);
1643
1644 const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0];
1645 const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1];
1646
1647 const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor =
1649 make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))),
1652
1653 const auto m_thread_data_on_block_idx =
1654 m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex(
1655 make_multi_index(m_thread_data_on_block));
1656
1657 const auto n_thread_data_on_block_to_n0_n1_n2_adaptor =
1662
1663 const auto n_thread_data_on_block_idx =
1664 n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex(
1665 make_multi_index(n_thread_data_on_block));
1666
1667 // shuffle: threadwise copy C from VGPR to LDS
1668 auto c_thread_copy_vgpr_to_lds =
1670 CShuffleDataType,
1671 decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2),
1672 decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2),
1674 Sequence<CShuffleMXdlPerWavePerShuffle,
1675 CShuffleNXdlPerWavePerShuffle,
1676 I1,
1677 I1,
1678 M2,
1679 I1,
1680 M4,
1681 I1>,
1683 7,
1684 1,
1686 1,
1687 true>{
1688 c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
1690 0,
1691 m_thread_data_on_block_idx[I1],
1692 n_thread_data_on_block_idx[I1],
1693 m_thread_data_on_block_idx[I2],
1694 m_thread_data_on_block_idx[I3],
1695 m_thread_data_on_block_idx[I4],
1696 n_thread_data_on_block_idx[I2]),
1698
1699 using EDataType = CDataType;
1700
1701 const auto ds_grid_desc_m_n = MakeDsGridDescriptor_M_N(
1702 problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideDs);
1703
1704 const auto ds_grid_desc_mblock_mperblock_nblock_nperblock =
1706 ds_grid_desc_m_n, problem.MBlock, problem.NBlock);
1707
1708 const auto ds_grid_buf = generate_tuple(
1709 [&](auto i) {
1711 p_ds_grid[i], ds_grid_desc_m_n[i].GetElementSpaceSize());
1712 },
1714
1715 // tuple of reference to C/Ds tensor descriptors
1716 const auto c_ds_desc_refs = concat_tuple_of_reference(
1717 tie(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock),
1718 generate_tie([&](auto i) -> const auto& // return type should be reference
1719 { return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; },
1721
1722 // tuple of reference to C/Ds tensor descriptors
1723 const auto c_ds_buf_refs = concat_tuple_of_reference(
1724 tie(c_shuffle_block_buf),
1725 generate_tie([&](auto i) -> const auto& // return type should be reference
1726 { return ds_grid_buf[i]; },
1728
1729 // tuple of starting index of C/Ds blockwise copy
1730 const auto idx_c_ds_block_begin =
1733 [&](auto) {
1734 return make_multi_index(block_m_id, 0, block_n_id, 0);
1735 // return make_multi_index(block_work_idx[I0], 0,
1736 // block_work_idx[I1], 0);
1737 },
1739
1740 const auto e_grid_desc_mblock_mperblock_nblock_nperblock =
1741 c_grid_desc_mblock_mperblock_nblock_nperblock;
1742
1743 using CDEBlockTransferCluster =
1744 CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock;
1745 const auto EGlobalMemoryDataOperation = CGlobalMemoryDataOperation;
1746 constexpr index_t scatter_weight_idx = 3; // hack fix felix
1747 auto cde_block_copy_lds_and_global = ThreadGroupTensorSliceTransfer_v7r3_scatter<
1749 decltype(container_concat(make_tuple(CShuffleDataType{}), DsDataType{})),
1751 decltype(c_ds_desc_refs),
1752 decltype(tie(e_grid_desc_mblock_mperblock_nblock_nperblock)),
1753 CElementwiseOperation,
1754 Sequence<static_cast<index_t>(EGlobalMemoryDataOperation)>, // FIXME: make Sequence
1755 // support arbitray type
1756 Sequence<1,
1757 CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
1758 1,
1759 CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths,
1760 CDEBlockTransferCluster,
1761 Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder,
1762 Sequence<0, 1, 2, 3>, // typename SrcDimAccessOrder,
1763 Sequence<0, 1, 2, 3>, // typename DstDimAccessOrder,
1764 3, // index_t SrcVectorDim,
1765 3, // index_t DstVectorDim,
1766 CDEShuffleBlockTransferScalarPerVectors,
1771 false>>, // ThreadTransferSrcResetCoordinateAfterRunFlags
1772 Sequence<false>, // ThreadTransferDstResetCoordinateAfterRunFlags
1773 IndexType,
1774 1, // ScatterDim
1775 true, // OutputScatter: false, only use scatter weights
1776 scatter_weight_idx // ScatterWeightIdx: ascale
1777 >{c_ds_desc_refs,
1778 idx_c_ds_block_begin,
1779 tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
1780 make_tuple(make_multi_index(0, 0, block_n_id, 0)),
1781 c_element_op};
1782
1784 p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
1785 constexpr auto sfc_c_vgpr =
1788 Sequence<CShuffleMXdlPerWavePerShuffle,
1789 CShuffleNXdlPerWavePerShuffle,
1790 1,
1791 1,
1792 M2,
1793 1,
1794 M4,
1795 1>>{};
1796
1797 constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess();
1798
1799 // space filling curve for shuffled blockwise C/D/E
1800 constexpr auto sfc_cde_block =
1803 Sequence<1,
1804 CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
1805 1,
1806 CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{};
1807
1808 static_assert(num_access == sfc_cde_block.GetNumOfAccess(), "wrong!");
1809 constexpr auto EMThreads =
1810 CDEBlockTransferCluster{}.At(I0) * CDEBlockTransferCluster{}.At(I1);
1811 constexpr auto EMRepeats = CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl / EMThreads;
1812 constexpr auto ENThreads =
1813 CDEBlockTransferCluster{}.At(I2) * CDEBlockTransferCluster{}.At(I3);
1814 static_for<0, num_access, 1>{}([&](auto access_id) {
1815 // make sure it's safe to write to LDS
1817
1818 auto dstidx = sfc_cde_block.GetIndex(access_id);
1819 const index_t c_token_pos =
1820 block_m_id * MPerBlock + threadIdx.x / ENThreads * EMRepeats + dstidx(I1);
1821 static_for<0, EMRepeats, 1>{}([&](auto m0) {
1822 const index_t fused_token = p_sorted_token_ids[c_token_pos + m0];
1823 IndexType token_offset = fused_token & 0xffffff;
1824 if constexpr(IsInputGemm)
1825 {
1826 token_offset = token_offset * problem.TopK + (fused_token >> 24);
1827 }
1828 scatter_offsets(m0) = static_cast<IndexType>(token_offset) * problem.N;
1829 });
1830
1832
1833 // each thread write its data from VGPR to LDS
1834 c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2,
1835 sfc_c_vgpr.GetIndexTupleOfNumber(access_id),
1836 c_thread_buf_fp32,
1837 c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
1838 c_shuffle_block_buf);
1839
1840 // make sure it's safe to read from LDS
1842
1843 // each block copy its data from LDS to global
1844 cde_block_copy_lds_and_global.Run(
1845 c_ds_desc_refs,
1846 c_ds_buf_refs,
1847 tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
1848 tie(c_grid_buf),
1849 scatter_offsets);
1850
1851 if constexpr(access_id < num_access - 1)
1852 {
1853 constexpr auto cde_lds_and_global_step =
1854 sfc_cde_block.GetForwardStep(access_id);
1855
1856 // move on Ds
1857 static_for<0, NumDTensor, 1>{}([&](auto i) {
1858 cde_block_copy_lds_and_global.MoveSrcSliceWindow(
1859 c_ds_desc_refs, i + I1, cde_lds_and_global_step);
1860 });
1861
1862 // move on E
1863 cde_block_copy_lds_and_global.MoveDstSliceWindow(
1864 tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
1865 I0,
1866 cde_lds_and_global_step);
1867 }
1868 });
1869 }
1870 }
1871
1872 template <bool HasMainKBlockLoop,
1873 InMemoryDataOperationEnum CGlobalMemoryDataOperation,
1874 TailNumber TailNum = TailNumber::Odd>
1875 __device__ static void Run_2Lds(const index_t* p_sorted_token_ids,
1876 const index_t* p_sorted_expert_ids,
1877 const index_t* p_max_token_id,
1878 const ADataType* p_a_grid,
1879 const BDataType* p_b_grid,
1880 DsGridPointer& p_ds_grid,
1881 CDataType* p_c_grid,
1882 void* p_shared,
1883 void* p_shared1,
1884 const Problem& problem,
1885 AElementwiseOperation a_element_op,
1886 BElementwiseOperation b_element_op,
1887 CElementwiseOperation c_element_op)
1888 {
1889 ignore = b_element_op;
1890 index_t BN0Shuffled = CalculateBN0Shuffled(problem.N);
1891 index_t BK0Shuffled = CalculateBK0Shuffled(problem.K);
1892 const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1(
1893 IsInputGemm ? problem.NumTokens : problem.NumTokens * problem.TopK,
1894 problem.MPadded,
1895 problem.K,
1896 problem.KPadded,
1897 problem.StrideA,
1898 problem.AK0);
1899 const auto b_grid_desc_bpreshuffled =
1900 MakeBGridDescriptor_Preshuffled(BN0Shuffled, BK0Shuffled);
1901 const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N<CLayout>(
1902 IsInputGemm ? problem.NumTokens * problem.TopK : problem.NumTokens,
1903 problem.MPadded,
1904 problem.N,
1905 problem.NPadded,
1906 problem.StrideC);
1907 const auto c_grid_desc_mblock_mperblock_nblock_nperblock =
1909 c_grid_desc_m_n, problem.MBlock, problem.NBlock);
1910 const index_t max_token_id = __builtin_amdgcn_readfirstlane(p_max_token_id[0]);
1911 // static_assert(NSwizzle == false, "to do fix: need another pr in sorting merged");
1912 const index_t expert_block_id = NSwizzle ? blockIdx.x / problem.NBlock : blockIdx.y;
1913 if(expert_block_id * MPerBlock >= max_token_id)
1914 return;
1915 const index_t expert_id =
1916 __builtin_amdgcn_readfirstlane(p_sorted_expert_ids[expert_block_id]);
1917 const auto block_mn = [&]() -> std::pair<int, int> {
1918 if constexpr(NSwizzle)
1919 {
1920 const index_t ecnt_prefix = p_max_token_id[1 + expert_id];
1921 const index_t prefix_block = ecnt_prefix * problem.NBlock;
1922 const index_t ecnt = p_max_token_id[2 + expert_id] - ecnt_prefix;
1923 const index_t expert_swizzle =
1924 ecnt > 0 ? ecnt : 1; // p_max_token_id[expert_id + 1]; // 2
1925 const index_t bid_new = blockIdx.x - prefix_block;
1926 const index_t nid = __builtin_amdgcn_readfirstlane(
1927 bid_new % 8 + bid_new / (8 * expert_swizzle) * 8);
1928 const index_t mid =
1929 __builtin_amdgcn_readfirstlane(ecnt_prefix + bid_new / 8 % expert_swizzle);
1930 return {nid, mid};
1931 }
1932 else
1933 {
1934 return {blockIdx.x, blockIdx.y};
1935 }
1936 }();
1937
1938 const index_t block_n_id = block_mn.first;
1939 const index_t block_m_id = block_mn.second;
1940 const index_t token0 =
1941 __builtin_amdgcn_readfirstlane(p_sorted_token_ids[block_m_id * MPerBlock] & 0xffffff);
1942
1943 // constexpr auto M0 = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I1);
1944 constexpr auto AMThreads = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I1);
1945 constexpr auto AK0Threads = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I0);
1946 constexpr auto AK1Threads = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I2);
1947 constexpr auto AKThreads = AK0Threads * AK1Threads;
1948 constexpr auto AMRepeats = MPerBlock / AMThreads;
1949 const index_t token_pos = block_m_id * MPerBlock + threadIdx.x / AKThreads * AMRepeats;
1950
1951 if(token_pos >= max_token_id || token0 >= problem.NumTokens)
1952 return;
1954 static_for<0, AMRepeats, 1>{}([&](auto m0) {
1955 const index_t fused_token = p_sorted_token_ids[token_pos + m0];
1956 index_t token_offset = fused_token & 0xffffff;
1957 if constexpr(!IsInputGemm)
1958 {
1959 token_offset = token_offset * problem.TopK + (fused_token >> 24);
1960 }
1961 gather_offsets(m0) = static_cast<IndexType>(token_offset) * problem.K;
1962 });
1963 const IndexType expert_stride =
1964 __builtin_amdgcn_readfirstlane(problem.N * problem.K * (IsInputGemm ? 2 : 1));
1965 const IndexType expert_offset = expert_id * expert_stride / BPackedSize;
1966 // N0, K0, Blocksize*KPack
1967 const index_t n_block_data_idx_on_grid =
1968 __builtin_amdgcn_readfirstlane(block_n_id * NXdlPerWave);
1969
1970 const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
1971 p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
1972 const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
1973 p_b_grid + expert_offset, b_grid_desc_bpreshuffled.GetElementSpaceSize());
1974
1975 // A matrix in LDS memory, dst of blockwise copy
1976 constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
1977
1978 // B matrix in LDS memory, dst of blockwise copy
1979 // dummy
1980 constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();
1981 // A matrix blockwise copy
1982 auto a_blockwise_copy = ThreadGroupTensorSliceTransfer_v4r1_gather<
1984 AElementwiseOperation,
1988 ABlockTransferThreadClusterLengths_AK0_M_AK1,
1989 ABlockTransferThreadClusterArrangeOrder,
1990 ADataType,
1991 LDSTypeA,
1992 decltype(a_grid_desc_ak0_m_ak1),
1993 decltype(a_block_desc_ak0_m_ak1),
1994 ABlockTransferSrcAccessOrder,
1996 ABlockTransferSrcVectorDim,
1997 2,
1998 ABlockTransferSrcScalarPerVector,
1999 ABlockTransferDstScalarPerVector_AK1,
2000 1,
2001 1,
2002 AThreadTransferSrcResetCoordinateAfterRun,
2003 true,
2004 IndexType,
2005 1,
2006 2>(a_grid_desc_ak0_m_ak1,
2007 make_multi_index(0, 0, 0),
2008 a_element_op,
2009 a_block_desc_ak0_m_ak1,
2010 make_multi_index(0, 0, 0),
2012 gather_offsets);
2013
2014 // Thread-wise copy
2015 // K0 -> N0/NWave -> NWave -> KLane -> NLane -> KPack
2017 b_block_desc_bk0_n_bk1.GetElementSpaceSize());
2019 b_block_desc_bk0_n_bk1.GetElementSpaceSize());
2020 auto b_block_bufs = make_tuple(b_block_buf_ping, b_block_buf_pong);
2021
2022 auto b_blockwise_copy = ThreadwiseTensorSliceTransfer_v2<
2023 BDataType,
2024 BDataType,
2025 decltype(b_grid_desc_bpreshuffled),
2026 decltype(b_block_desc_bk0_n_bk1),
2029 3,
2030 BBlockTransferSrcScalarPerVector,
2031 BThreadTransferSrcResetCoordinateAfterRun,
2032 true>(b_grid_desc_bpreshuffled,
2033 make_multi_index(n_block_data_idx_on_grid,
2035 0,
2036 KPack / KGroup * (get_thread_local_1d_id() % WarpSize)));
2037
2038 // LDS allocation for A and B: be careful of alignment
2039 // Cast after lds
2040 auto a_block_buf_ping = make_dynamic_buffer<AddressSpaceEnum::Lds>(
2041 static_cast<ADataType*>(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
2042 auto a_block_buf_pong = make_dynamic_buffer<AddressSpaceEnum::Lds>(
2043 static_cast<ADataType*>(p_shared1), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
2044 auto a_block_bufs = make_tuple(a_block_buf_ping, a_block_buf_pong);
2045
2046 constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1Number, 0, 0);
2047 constexpr auto b_block_slice_copy_step = make_multi_index(0, 0, KRepeat, 0);
2048
2049 // Blockwise GEMM pipeline
2050 static_assert(std::is_default_constructible_v<BlockwiseGemmPipe>);
2051 auto blockwise_gemm_pipeline = BlockwiseGemmPipe{};
2052 auto c_thread_buf = blockwise_gemm_pipeline.GetCThreadBuffer();
2053 decltype(c_thread_buf) c_thread_buf_up;
2054
2056 float,
2057 c_thread_buf.num_of_v_,
2058 c_thread_buf.s_per_v,
2059 true>
2060 c_thread_buf_fp32;
2061
2062 const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(
2063 (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) /
2064 KPerBlock);
2065
2066 if constexpr(IsInputGemm)
2067 {
2068 const BDataType* p_b_grid_up = p_b_grid + expert_stride / 2 / BPackedSize;
2069 const auto b_grid_buf_up = make_dynamic_buffer<AddressSpaceEnum::Global>(
2070 p_b_grid_up + expert_offset, b_grid_desc_bpreshuffled.GetElementSpaceSize());
2071 auto b_blockwise_copy_up = ThreadwiseTensorSliceTransfer_v2<
2072 BDataType,
2073 BDataType,
2074 decltype(b_grid_desc_bpreshuffled),
2075 decltype(b_block_desc_bk0_n_bk1),
2078 3,
2079 BBlockTransferSrcScalarPerVector,
2080 BThreadTransferSrcResetCoordinateAfterRun,
2081 true>(b_grid_desc_bpreshuffled,
2082 make_multi_index(n_block_data_idx_on_grid,
2084 0,
2085 KPack / KGroup * (get_thread_local_1d_id() % WarpSize)));
2086 blockwise_gemm_pipeline.template Run<HasMainKBlockLoop, TailNum>(
2087 a_grid_desc_ak0_m_ak1,
2088 a_block_desc_ak0_m_ak1,
2089 a_blockwise_copy,
2090 a_grid_buf,
2091 a_block_bufs,
2092 a_block_slice_copy_step,
2093 b_grid_desc_bpreshuffled,
2094 b_blockwise_copy,
2095 b_blockwise_copy_up,
2096 b_grid_buf,
2097 b_grid_buf_up,
2098 b_block_bufs,
2099 b_block_slice_copy_step,
2100 c_thread_buf,
2101 c_thread_buf_up,
2102 num_k_block_main_loop);
2103 }
2104 else
2105 {
2106
2107 blockwise_gemm_pipeline.template Run<HasMainKBlockLoop, TailNum>(
2108 a_grid_desc_ak0_m_ak1,
2109 a_block_desc_ak0_m_ak1,
2110 a_blockwise_copy,
2111 a_grid_buf,
2112 a_block_bufs,
2113 a_block_slice_copy_step,
2114 b_grid_desc_bpreshuffled,
2115 b_blockwise_copy,
2116 b_grid_buf,
2117 b_block_bufs,
2118 b_block_slice_copy_step,
2119 c_thread_buf,
2120 num_k_block_main_loop);
2121 }
2122
2123 // shuffle C and write out
2124 {
2125 static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 &&
2126 NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0,
2127 "wrong!");
2128
2129 constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
2130
2131 // TODO: hacky, fix it!
2132 constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
2133 blockwise_gemm_pipeline.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
2134
2135 // TODO: hacky, fix it!
2136 // c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths
2137 constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp =
2138 blockwise_gemm_pipeline.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
2139
2140 constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0);
2141 constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1);
2142 constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2);
2143 constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3);
2144 constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4);
2145 constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5);
2146 constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6);
2147 constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7);
2148
2149 // mul scales
2150 const float* p_sorted_weights_0 = p_ds_grid[I0];
2151 const float* p_scale_b = p_ds_grid[I1];
2152
2153 static_assert(M0 * M1 * M2 * M3 * M4 == MPerBlock);
2154 static_assert(M4 == 4 || M4 == 8);
2155 const index_t m1 = get_warp_local_1d_id() / NWave;
2156 const index_t m3 = threadIdx.x % get_warp_size() / MPerXdl;
2157
2158 if(p_sorted_weights_0 != nullptr && p_scale_b != nullptr)
2159 {
2160 if constexpr(PerTokenQuant)
2161 {
2162 constexpr index_t scale_stride = (IsInputGemm ? 2 : 1);
2163 p_scale_b += expert_id * problem.N * scale_stride + block_n_id * NPerBlock +
2164 get_warp_local_1d_id() % NWave * NPerXdl + threadIdx.x % NPerXdl;
2165 }
2166 else
2167 {
2168 p_scale_b += expert_id;
2169 }
2170
2171 vector_type<int32_t, M4> scale_token_ids;
2172 vector_type<float, M4> topk_weights;
2173 static_for<0, NXdlPerWave, 1>{}([&](auto n0) {
2174 const float scale_b = p_scale_b[n0 * NWave * NPerXdl * PerTokenQuant];
2175 static_for<0, MXdlPerWave, 1>{}([&](auto m0) { // MXDLPerWave
2176 static_for<0, M2, 1>{}([&](auto m2) { // m_inst_num_groups_per_blk
2177 const index_t m_pos = block_m_id * MPerBlock + m0 * M1 * M2 * M3 * M4 +
2178 m1 * M2 * M3 * M4 + m2 * M3 * M4 + m3 * M4;
2179 if constexpr(PerTokenQuant)
2180 {
2181 scale_token_ids =
2183 p_sorted_token_ids + m_pos);
2184 }
2185 if constexpr(MulRoutedWeight)
2186 {
2188 p_ds_grid[I2] + m_pos);
2189 }
2190 static_for<0, M4, 1>{}([&](auto m4) { // m_inst_group_size
2191 float scale_a = [&]() {
2192 if constexpr(PerTokenQuant)
2193 {
2194 index_t fused_token =
2195 scale_token_ids.template AsType<index_t>()[m4];
2196 const index_t token_offset = fused_token & 0xffffff;
2197 return token_offset < problem.NumTokens
2198 ? p_sorted_weights_0[IsInputGemm
2199 ? token_offset
2200 : token_offset *
2201 problem.TopK +
2202 (fused_token >>
2203 24)]
2204 : 0.0;
2205 }
2206 else
2207 {
2208 return p_sorted_weights_0[0];
2209 }
2210 }();
2211 constexpr index_t c_offset =
2212 blockwise_gemm_pipeline.GetCThreadDesc().CalculateOffset(
2213 make_tuple(m0, n0, m2 * M4 + m4));
2214 constexpr auto cidx = Number<c_offset>{};
2215 if constexpr(IsInputGemm) // gu fusion
2216 {
2217 if constexpr(ActivationOperation == Activation::silu_and_mul)
2218 {
2219 const float scale_up =
2220 p_scale_b[(n0 * NWave * NPerXdl + problem.N) *
2221 PerTokenQuant];
2222 float gate = scale_a * scale_b * c_thread_buf[cidx];
2223 float up = scale_a * scale_up * c_thread_buf_up[cidx];
2224 if constexpr(MulRoutedWeight)
2225 {
2226 gate = gate * topk_weights.template AsType<float>()[m4];
2227 up = up * topk_weights.template AsType<float>()[m4];
2228 }
2230 {
2231 gate *= 16;
2232 up *= 16;
2233 }
2235 c_thread_buf_fp32(cidx) = gate * up;
2236 }
2237 else if(ActivationOperation == Activation::gelu_and_mul)
2238 {
2239 const float scale_up =
2240 p_scale_b[(n0 * NWave * NPerXdl + problem.N) *
2241 PerTokenQuant];
2242 float gate = scale_a * scale_b * c_thread_buf[cidx];
2243 float up = scale_a * scale_up * c_thread_buf_up[cidx];
2244 if constexpr(MulRoutedWeight)
2245 {
2246 gate = gate * topk_weights.template AsType<float>()[m4];
2247 up = up * topk_weights.template AsType<float>()[m4];
2248 }
2250 {
2251 gate *= 16;
2252 up *= 16;
2253 }
2255 c_thread_buf_fp32(cidx) = gate * up;
2256 }
2257 }
2258 else
2259 {
2260 c_thread_buf_fp32(cidx) =
2261 scale_a * scale_b * c_thread_buf[cidx];
2262 if constexpr(MulRoutedWeight)
2263 {
2264 c_thread_buf_fp32(cidx) =
2265 c_thread_buf_fp32(cidx) *
2266 topk_weights.template AsType<float>()[m4];
2267 }
2268 }
2269 });
2270 });
2271 });
2272 });
2273 }
2274 else
2275 {
2276 vector_type<float, M4> topk_weights; // for gemm2 only
2277 static_for<0, NXdlPerWave, 1>{}([&](auto n0) {
2278 static_for<0, MXdlPerWave, 1>{}([&](auto m0) { // MXDLPerWave
2279 static_for<0, M2, 1>{}([&](auto m2) { // m_inst_num_groups_per_blk
2280 const index_t m_pos = block_m_id * MPerBlock + m0 * M1 * M2 * M3 * M4 +
2281 m1 * M2 * M3 * M4 + m2 * M3 * M4 + m3 * M4;
2282 if constexpr(MulRoutedWeight)
2283 {
2285 p_ds_grid[I2] + m_pos);
2286 }
2287 static_for<0, M4, 1>{}([&](auto m4) { // m_inst_group_size
2288 constexpr index_t c_offset =
2289 blockwise_gemm_pipeline.GetCThreadDesc().CalculateOffset(
2290 make_tuple(m0, n0, m2 * M4 + m4));
2291 constexpr auto cidx = Number<c_offset>{};
2292
2293 if constexpr(IsInputGemm) // gu fusion
2294 {
2295 if constexpr(ActivationOperation == Activation::silu_and_mul)
2296 {
2297 float gate = c_thread_buf[cidx];
2298 float up = c_thread_buf_up[cidx];
2299 if constexpr(MulRoutedWeight)
2300 {
2301 gate = gate * topk_weights.template AsType<float>()[m4];
2302 up = up * topk_weights.template AsType<float>()[m4];
2303 }
2305 c_thread_buf_fp32(cidx) = gate * up;
2306 }
2307 else if(ActivationOperation == Activation::gelu_and_mul)
2308 {
2309 float gate = c_thread_buf[cidx];
2310 float up = c_thread_buf_up[cidx];
2311 if constexpr(MulRoutedWeight)
2312 {
2313 gate = gate * topk_weights.template AsType<float>()[m4];
2314 up = up * topk_weights.template AsType<float>()[m4];
2315 }
2317 c_thread_buf_fp32(cidx) = gate * up;
2318 }
2319 }
2320 else
2321 {
2322 c_thread_buf_fp32(cidx) = c_thread_buf[cidx];
2323 if constexpr(MulRoutedWeight)
2324 {
2325 c_thread_buf_fp32(cidx) =
2326 topk_weights.template AsType<float>()[m4] *
2327 c_thread_buf_fp32[cidx];
2328 }
2329 }
2330 });
2331 });
2332 });
2333 });
2334 }
2335
2336 constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
2338
2339 auto c_shuffle_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
2340 static_cast<CShuffleDataType*>(p_shared),
2341 c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
2342
2343 constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor(
2344 c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
2345 make_tuple(
2348 Number<CShuffleMXdlPerWavePerShuffle>{}, // M0 (MXdlPerWave) per shuffle
2349 M1, // M1 = MWave
2350 M2, // M2 * M3 * M4 = MPerXdl
2351 M3,
2352 M4)),
2355 Number<CShuffleNXdlPerWavePerShuffle>{}, // N0 (NXdlPerWave) per shuffle
2356 N1, // N1 = NWave
2357 N2))), // N2 = NPerXdl
2359 make_tuple(
2361
2362 // calculate origin of thread output tensor on global memory
2363 // blockwise GEMM c matrix starting index
2364 const auto c_thread_mtx_on_block =
2365 blockwise_gemm_pipeline.CalculateCThreadOriginDataIndex(I0, I0, I0, I0);
2366
2367 const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0];
2368 const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1];
2369
2370 const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor =
2372 make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))),
2375
2376 const auto m_thread_data_on_block_idx =
2377 m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex(
2378 make_multi_index(m_thread_data_on_block));
2379
2380 const auto n_thread_data_on_block_to_n0_n1_n2_adaptor =
2385
2386 const auto n_thread_data_on_block_idx =
2387 n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex(
2388 make_multi_index(n_thread_data_on_block));
2389
2390 // shuffle: threadwise copy C from VGPR to LDS
2391 auto c_thread_copy_vgpr_to_lds =
2393 CShuffleDataType,
2394 decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2),
2395 decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2),
2397 Sequence<CShuffleMXdlPerWavePerShuffle,
2398 CShuffleNXdlPerWavePerShuffle,
2399 I1,
2400 I1,
2401 M2,
2402 I1,
2403 M4,
2404 I1>,
2406 7,
2407 1,
2409 1,
2410 true>{
2411 c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
2413 0,
2414 m_thread_data_on_block_idx[I1],
2415 n_thread_data_on_block_idx[I1],
2416 m_thread_data_on_block_idx[I2],
2417 m_thread_data_on_block_idx[I3],
2418 m_thread_data_on_block_idx[I4],
2419 n_thread_data_on_block_idx[I2]),
2421
2422 using EDataType = CDataType;
2423
2424 const auto ds_grid_desc_m_n = MakeDsGridDescriptor_M_N(
2425 problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideDs);
2426
2427 const auto ds_grid_desc_mblock_mperblock_nblock_nperblock =
2429 ds_grid_desc_m_n, problem.MBlock, problem.NBlock);
2430
2431 const auto ds_grid_buf = generate_tuple(
2432 [&](auto i) {
2434 p_ds_grid[i], ds_grid_desc_m_n[i].GetElementSpaceSize());
2435 },
2437
2438 // tuple of reference to C/Ds tensor descriptors
2439 const auto c_ds_desc_refs = concat_tuple_of_reference(
2440 tie(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock),
2441 generate_tie([&](auto i) -> const auto& // return type should be reference
2442 { return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; },
2444
2445 // tuple of reference to C/Ds tensor descriptors
2446 const auto c_ds_buf_refs = concat_tuple_of_reference(
2447 tie(c_shuffle_block_buf),
2448 generate_tie([&](auto i) -> const auto& // return type should be reference
2449 { return ds_grid_buf[i]; },
2451
2452 // tuple of starting index of C/Ds blockwise copy
2453 const auto idx_c_ds_block_begin =
2456 [&](auto) {
2457 return make_multi_index(block_m_id, 0, block_n_id, 0);
2458 // return make_multi_index(block_work_idx[I0], 0,
2459 // block_work_idx[I1], 0);
2460 },
2462
2463 const auto e_grid_desc_mblock_mperblock_nblock_nperblock =
2464 c_grid_desc_mblock_mperblock_nblock_nperblock;
2465
2466 using CDEBlockTransferCluster =
2467 CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock;
2468 const auto EGlobalMemoryDataOperation = CGlobalMemoryDataOperation;
2469 constexpr index_t scatter_weight_idx = 3; // hack fix felix
2470 auto cde_block_copy_lds_and_global = ThreadGroupTensorSliceTransfer_v7r3_scatter<
2472 decltype(container_concat(make_tuple(CShuffleDataType{}), DsDataType{})),
2474 decltype(c_ds_desc_refs),
2475 decltype(tie(e_grid_desc_mblock_mperblock_nblock_nperblock)),
2476 CElementwiseOperation,
2477 Sequence<static_cast<index_t>(EGlobalMemoryDataOperation)>, // FIXME: make Sequence
2478 // support arbitray type
2479 Sequence<1,
2480 CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
2481 1,
2482 CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths,
2483 CDEBlockTransferCluster,
2484 Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder,
2485 Sequence<0, 1, 2, 3>, // typename SrcDimAccessOrder,
2486 Sequence<0, 1, 2, 3>, // typename DstDimAccessOrder,
2487 3, // index_t SrcVectorDim,
2488 3, // index_t DstVectorDim,
2489 CDEShuffleBlockTransferScalarPerVectors,
2494 false>>, // ThreadTransferSrcResetCoordinateAfterRunFlags
2495 Sequence<false>, // ThreadTransferDstResetCoordinateAfterRunFlags
2496 IndexType,
2497 1, // ScatterDim
2498 true, // OutputScatter: false, only use scatter weights
2499 scatter_weight_idx // ScatterWeightIdx: ascale
2500 >{c_ds_desc_refs,
2501 idx_c_ds_block_begin,
2502 tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
2503 make_tuple(make_multi_index(0, 0, block_n_id, 0)),
2504 c_element_op};
2505
2507 p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
2508 constexpr auto sfc_c_vgpr =
2511 Sequence<CShuffleMXdlPerWavePerShuffle,
2512 CShuffleNXdlPerWavePerShuffle,
2513 1,
2514 1,
2515 M2,
2516 1,
2517 M4,
2518 1>>{};
2519
2520 constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess();
2521
2522 // space filling curve for shuffled blockwise C/D/E
2523 constexpr auto sfc_cde_block =
2526 Sequence<1,
2527 CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
2528 1,
2529 CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{};
2530
2531 static_assert(num_access == sfc_cde_block.GetNumOfAccess(), "wrong!");
2532 constexpr auto EMThreads =
2533 CDEBlockTransferCluster{}.At(I0) * CDEBlockTransferCluster{}.At(I1);
2534 constexpr auto EMRepeats = CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl / EMThreads;
2535 constexpr auto ENThreads =
2536 CDEBlockTransferCluster{}.At(I2) * CDEBlockTransferCluster{}.At(I3);
2537 static_for<0, num_access, 1>{}([&](auto access_id) {
2538 // make sure it's safe to write to LDS
2540
2541 auto dstidx = sfc_cde_block.GetIndex(access_id);
2542 const index_t c_token_pos =
2543 block_m_id * MPerBlock + threadIdx.x / ENThreads * EMRepeats + dstidx(I1);
2544 static_for<0, EMRepeats, 1>{}([&](auto m0) {
2545 const index_t fused_token = p_sorted_token_ids[c_token_pos + m0];
2546 IndexType token_offset = fused_token & 0xffffff;
2547 if constexpr(IsInputGemm)
2548 {
2549 token_offset = token_offset * problem.TopK + (fused_token >> 24);
2550 }
2551 scatter_offsets(m0) = static_cast<IndexType>(token_offset) * problem.N;
2552 });
2553
2555
2556 // each thread write its data from VGPR to LDS
2557 c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2,
2558 sfc_c_vgpr.GetIndexTupleOfNumber(access_id),
2559 c_thread_buf_fp32,
2560 c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
2561 c_shuffle_block_buf);
2562
2563 // make sure it's safe to read from LDS
2565
2566 // each block copy its data from LDS to global
2567 cde_block_copy_lds_and_global.Run(
2568 c_ds_desc_refs,
2569 c_ds_buf_refs,
2570 tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
2571 tie(c_grid_buf),
2572 scatter_offsets);
2573
2574 if constexpr(access_id < num_access - 1)
2575 {
2576 constexpr auto cde_lds_and_global_step =
2577 sfc_cde_block.GetForwardStep(access_id);
2578
2579 // move on Ds
2580 static_for<0, NumDTensor, 1>{}([&](auto i) {
2581 cde_block_copy_lds_and_global.MoveSrcSliceWindow(
2582 c_ds_desc_refs, i + I1, cde_lds_and_global_step);
2583 });
2584
2585 // move on E
2586 cde_block_copy_lds_and_global.MoveDstSliceWindow(
2587 tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
2588 I0,
2589 cde_lds_and_global_step);
2590 }
2591 });
2592 }
2593 }
2594};
2595
2596} // namespace ck
#define CK_MAX_THREAD_PER_BLOCK
Definition ck.hpp:30
#define IS_VALID_COMPILATION_PARAMETER_IMPL(CDataType_)
Definition device_base.hpp:178
__host__ __device__ constexpr auto integer_least_multiple(X x, Y y)
Definition utility/math.hpp:78
__host__ __device__ constexpr T max(T x)
Definition utility/math.hpp:84
__host__ __device__ constexpr auto integer_divide_ceil(X x, Y y)
Definition utility/math.hpp:72
__host__ __device__ constexpr auto lcm(X x, Y y)
Definition utility/math.hpp:198
GemmSpecialization
Definition gemm_specialization.hpp:11
@ MKPadding
Definition gemm_specialization.hpp:18
@ KPadding
Definition gemm_specialization.hpp:16
@ NPadding
Definition gemm_specialization.hpp:15
@ MPadding
Definition gemm_specialization.hpp:14
@ MNKPadding
Definition gemm_specialization.hpp:20
@ MNPadding
Definition gemm_specialization.hpp:17
@ NKPadding
Definition gemm_specialization.hpp:19
Definition ck.hpp:268
__host__ __device__ constexpr auto make_multi_index(Xs &&... xs)
Definition array_multi_index.hpp:15
typename uniform_sequence_gen< NSize, I >::type uniform_sequence_gen_t
Definition utility/sequence.hpp:928
__device__ index_t get_warp_local_1d_id()
Definition get_id.hpp:45
__host__ __device__ constexpr auto make_static_buffer(Number< N >)
Definition static_buffer.hpp:186
f8_fnuz_t f8_t
Definition amd_ck_fp8.hpp:1762
__host__ __device__ constexpr auto make_pass_through_transform(const LowLength &low_length)
Definition multi_index_transform_helper.hpp:12
__host__ __device__ constexpr auto container_concat(const X &x, const Ys &... ys)
Definition utility/container_helper.hpp:320
typename detail::StaticallyIndexedArrayImpl< T, N >::type StaticallyIndexedArray
Definition utility/statically_indexed_array.hpp:45
int32_t index_t
Definition ck.hpp:299
__host__ __device__ constexpr auto make_naive_tensor_descriptor(const Tuple< Lengths... > &lengths, const Tuple< Strides... > &strides)
Definition tensor_descriptor_helper.hpp:49
InMemoryDataOperationEnum
Definition ck.hpp:277
@ Set
Definition ck.hpp:278
__host__ __device__ constexpr auto make_single_stage_tensor_adaptor(const Transforms &transforms, LowerDimensionOldTopIdss, UpperDimensionNewTopIdss)
Definition tensor_description/tensor_adaptor.hpp:425
remove_cv_t< remove_reference_t< T > > remove_cvref_t
Definition type.hpp:297
BlockGemmPipelineVersion
Definition blkgemmpipe_scheduler.hpp:12
@ v1
Definition blkgemmpipe_scheduler.hpp:14
__host__ __device__ constexpr auto make_freeze_transform(const LowerIndex &low_idx)
Definition multi_index_transform_helper.hpp:151
constexpr Tuple< Args &... > tie(Args &... args) noexcept
Definition utility/tuple.hpp:218
__host__ __device__ constexpr auto make_right_pad_transform(const LowLength &low_length, const RightPadLength &right_pad, integral_constant< bool, SkipIsValidCheck >=integral_constant< bool, false >{})
Definition multi_index_transform_helper.hpp:37
__host__ __device__ constexpr auto make_xor_with_modulo_transform(const LowLengths &low_lengths)
Definition multi_index_transform_helper.hpp:185
integral_constant< index_t, N > Number
Definition number.hpp:12
__global__ void kernel_moe_gemm(typename GridwiseGemm::Argument karg)
Definition gridwise_moe_gemm.hpp:46
TailNumber
Definition blkgemmpipe_scheduler.hpp:31
@ Even
Definition blkgemmpipe_scheduler.hpp:34
@ Odd
Definition blkgemmpipe_scheduler.hpp:33
@ Vgpr
Definition amd_address_space.hpp:20
__host__ __device__ PY c_style_pointer_cast(PX p_x)
Definition c_style_pointer_cast.hpp:15
typename tuple_element< I, TTuple >::type tuple_element_t
Definition utility/tuple.hpp:208
__host__ __device__ constexpr auto make_merge_transform(const LowLengths &low_lengths)
Definition multi_index_transform_helper.hpp:55
constexpr detail::ignore_t ignore
Definition utility/ignore.hpp:20
Activation
Definition gridwise_moe_gemm.hpp:31
@ silu_and_mul
Definition gridwise_moe_gemm.hpp:33
@ gelu_and_mul
Definition gridwise_moe_gemm.hpp:32
__device__ constexpr index_t get_warp_size()
Definition get_id.hpp:10
constexpr auto BlockGemmBPreshufflePipeline_Selector()
Definition blockwise_gemm_pipeline_xdlops_b_preshuffle_selector.hpp:41
constexpr bool is_same_v
Definition type.hpp:283
__host__ __device__ constexpr auto make_merge_transform_v3_division_mod(const LowLengths &low_lengths)
Definition multi_index_transform_helper.hpp:84
BlockGemmPipelineScheduler
Definition blkgemmpipe_scheduler.hpp:25
@ Intrawave
Definition blkgemmpipe_scheduler.hpp:26
__host__ __device__ constexpr auto generate_tuple(F &&f, Number< N >)
Definition tuple_helper.hpp:21
__host__ __device__ constexpr auto make_naive_tensor_descriptor_packed(const Tuple< Lengths... > &lengths)
Definition tensor_descriptor_helper.hpp:101
__host__ __device__ constexpr auto make_tuple(Xs &&... xs)
Definition utility/tuple.hpp:211
__device__ index_t get_thread_local_1d_id()
Definition get_id.hpp:41
typename sequence_merge< Sx, Sy >::type sequence_merge_t
Definition utility/sequence.hpp:925
__host__ __device__ constexpr auto transform_tensor_descriptor(const OldTensorDescriptor &old_tensor_desc, const NewTransforms &new_transforms, NewLowerDimensionOldVisibleIdss, NewUpperDimensionNewVisibleIdss)
Definition tensor_description/tensor_descriptor.hpp:319
__device__ void block_sync_lds()
Definition synchronization.hpp:16
__host__ __device__ constexpr auto make_unmerge_transform(const UpLengths &up_lengths, integral_constant< bool, Use24BitIntegerCalculation >=integral_constant< bool, false >{})
Definition multi_index_transform_helper.hpp:90
__global__ void kernel_moe_gemm_2lds(typename GridwiseGemm::Argument karg)
Definition gridwise_moe_gemm.hpp:84
__host__ __device__ constexpr auto make_dynamic_buffer(T *p, ElementSpaceSize element_space_size)
Definition dynamic_buffer.hpp:472
__host__ __device__ constexpr auto generate_tie(F &&f, Number< N >)
Definition tuple_helper.hpp:34
__host__ __device__ constexpr auto concat_tuple_of_reference(const Tuple< X &... > &tx, const Tuple< Y &... > &ty)
Definition tuple_helper.hpp:42
Definition gridwise_moe_gemm.hpp:659
const BDataType * p_b_grid
Definition gridwise_moe_gemm.hpp:715
const index_t * p_sorted_token_ids
Definition gridwise_moe_gemm.hpp:711
const index_t * p_sorted_expert_ids
Definition gridwise_moe_gemm.hpp:712
const AElementwiseOperation a_element_op
Definition gridwise_moe_gemm.hpp:719
const ADataType * p_a_grid
Definition gridwise_moe_gemm.hpp:714
__host__ Argument(const index_t *p_sorted_token_ids_, const index_t *p_sorted_expert_ids_, const index_t *p_max_token_id_, const ADataType *p_a_grid_, const BDataType *p_b_grid_, std::array< const void *, NumDTensor > p_ds_grid_, CDataType *p_c_grid_, index_t NumTokens_, index_t TopK_, index_t M_, index_t N_, index_t K_, index_t StrideA_, index_t StrideB_, std::array< index_t, NumDTensor > StrideDs_, index_t StrideC_, index_t k_batch_, AElementwiseOperation a_element_op_, BElementwiseOperation b_element_op_, CElementwiseOperation c_element_op_)
Definition gridwise_moe_gemm.hpp:660
const index_t * p_max_token_id
Definition gridwise_moe_gemm.hpp:713
const BElementwiseOperation b_element_op
Definition gridwise_moe_gemm.hpp:720
CDataType * p_c_grid
Definition gridwise_moe_gemm.hpp:717
DsGridPointer p_ds_grid
Definition gridwise_moe_gemm.hpp:716
const CElementwiseOperation c_element_op
Definition gridwise_moe_gemm.hpp:721
std::array< index_t, NumDTensor > StrideDs
Definition gridwise_moe_gemm.hpp:644
index_t NumTokens
Definition gridwise_moe_gemm.hpp:637
index_t MBlock
Definition gridwise_moe_gemm.hpp:653
index_t TopK
Definition gridwise_moe_gemm.hpp:638
index_t K
Definition gridwise_moe_gemm.hpp:641
__host__ void Print() const
Definition gridwise_moe_gemm.hpp:626
index_t NPadded
Definition gridwise_moe_gemm.hpp:648
index_t BK0
Definition gridwise_moe_gemm.hpp:652
index_t KRead
Definition gridwise_moe_gemm.hpp:649
index_t MPadded
Definition gridwise_moe_gemm.hpp:647
index_t AK0
Definition gridwise_moe_gemm.hpp:651
index_t StrideA
Definition gridwise_moe_gemm.hpp:642
index_t StrideC
Definition gridwise_moe_gemm.hpp:645
index_t M
Definition gridwise_moe_gemm.hpp:639
index_t KBatch
Definition gridwise_moe_gemm.hpp:646
__host__ __device__ Problem(index_t NumTokens_, index_t TopK_, index_t M_, index_t N_, index_t K_, index_t StrideA_, index_t StrideB_, std::array< index_t, NumDTensor > StrideDs_, index_t StrideC_, index_t KBatch_)
Definition gridwise_moe_gemm.hpp:595
index_t KPadded
Definition gridwise_moe_gemm.hpp:650
index_t StrideB
Definition gridwise_moe_gemm.hpp:643
index_t N
Definition gridwise_moe_gemm.hpp:640
index_t NBlock
Definition gridwise_moe_gemm.hpp:654
index_t a_k_split_offset
Definition gridwise_moe_gemm.hpp:757
index_t b_k_split_offset
Definition gridwise_moe_gemm.hpp:758
__device__ SplitKBatchOffset(Argument &karg, index_t k_id)
Definition gridwise_moe_gemm.hpp:726
Definition gridwise_moe_gemm.hpp:171
remove_cvref_t< decltype(BlockGemmBPreshufflePipeline_Selector< BlkGemmPipelineVer, BlkGemmPipeSched, BlockSize, ADataType, BDataType, ComputeTypeA, GemmAccDataType, decltype(GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()), decltype(GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1()), decltype(MakeAMmaTileDescriptor_M0_M1_M2_K(GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1())), decltype(MakeBMmaTileDescriptor_N0_N1_N2_K(GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1())), ABlockTransferSrcScalarPerVector, BBlockTransferSrcScalarPerVector, MPerBlock, NPerBlock, KPerBlock, MPerXdl, NPerXdl, MXdlPerWave, NXdlPerWave, KPack, IsInputGemm >())> BlockwiseGemmPipe
Definition gridwise_moe_gemm.hpp:901
static __device__ void Run_2Lds(const index_t *p_sorted_token_ids, const index_t *p_sorted_expert_ids, const index_t *p_max_token_id, const ADataType *p_a_grid, const BDataType *p_b_grid, DsGridPointer &p_ds_grid, CDataType *p_c_grid, void *p_shared, void *p_shared1, const Problem &problem, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CElementwiseOperation c_element_op)
Definition gridwise_moe_gemm.hpp:1875
static __device__ constexpr auto MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const CGridDesc &c_grid_desc_m_n, index_t MBlock, index_t NBlock)
Definition gridwise_moe_gemm.hpp:1140
static __device__ void Run(const index_t *p_sorted_token_ids, const index_t *p_sorted_expert_ids, const index_t *p_max_token_id, const ADataType *p_a_grid, const BDataType *p_b_grid, DsGridPointer &p_ds_grid, CDataType *p_c_grid, void *p_shared, const Problem &problem, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CElementwiseOperation c_element_op)
Definition gridwise_moe_gemm.hpp:1161
Selects the appropriate MFMA instruction type and configuration for given data types and tile sizes o...
Definition xdlops_gemm.hpp:1208
static constexpr index_t GetK1PerXdlops()
Definition xdlops_gemm.hpp:1810
static constexpr auto selected_mfma
Definition xdlops_gemm.hpp:1757
static constexpr index_t GetKPerXdlops()
Definition xdlops_gemm.hpp:1804
Definition utility/sequence.hpp:43
Definition tensor_space_filling_curve.hpp:20
Definition static_buffer.hpp:75
Blockwise data transfer.
Definition thread_group_tensor_slice_transfer_v4r1_gather.hpp:48
Definition thread_group_tensor_slice_transfer_v7r3_scatter.hpp:51
Definition threadwise_tensor_slice_transfer.hpp:39
Helper structure that facilitates transfer of source (grid) data to destination threads.
Definition threadwise_tensor_slice_transfer.hpp:234
Definition utility/tuple.hpp:117
static constexpr value_type value
Definition utility/integral_constant.hpp:13
Definition data_type.hpp:187
Definition functional2.hpp:33
Definition device_base.hpp:197
Definition tensor_operation/gpu/element/unary_element_wise_operation.hpp:1041
Definition tensor_operation/gpu/element/unary_element_wise_operation.hpp:340
Definition tensor_operation/gpu/element/unary_element_wise_operation.hpp:1087
Definition dtype_vector.hpp:10