gridwise_gemm_xdl_cshuffle_conv_v3.hpp Source File

gridwise_gemm_xdl_cshuffle_conv_v3.hpp Source File

Composable Kernel: gridwise_gemm_xdl_cshuffle_conv_v3.hpp Source File
gridwise_gemm_xdl_cshuffle_conv_v3.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
16
17namespace ck {
18
19template <typename ALayout,
20 typename BLayout,
21 typename CLayout,
22 typename ADataType,
23 typename BDataType,
24 typename AccDataType,
25 typename CShuffleDataType,
26 typename CDataType,
27 typename AElementwiseOperation,
28 typename BElementwiseOperation,
29 typename CElementwiseOperation,
31 index_t BlockSize,
32 index_t MPerBlock,
33 index_t NPerBlock,
34 index_t KPerBlock,
35 index_t AK1Value,
36 index_t BK1Value,
37 index_t MPerXdl,
38 index_t NPerXdl,
39 index_t MXdlPerWave,
40 index_t NXdlPerWave,
41 typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
42 typename ABlockTransferThreadClusterArrangeOrder,
43 typename ABlockTransferSrcAccessOrder,
44 index_t ABlockTransferSrcVectorDim,
45 index_t ABlockTransferSrcScalarPerVector,
46 index_t ABlockTransferDstScalarPerVector_AK1,
47 bool AThreadTransferSrcResetCoordinateAfterRun,
48 index_t ABlockLdsExtraMCustom,
49 typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
50 typename BBlockTransferThreadClusterArrangeOrder,
51 typename BBlockTransferSrcAccessOrder,
52 index_t BBlockTransferSrcVectorDim,
53 index_t BBlockTransferSrcScalarPerVector,
54 index_t BBlockTransferDstScalarPerVector_BK1,
55 bool BThreadTransferSrcResetCoordinateAfterRun,
56 index_t BBlockLdsExtraNCustom,
57 index_t CShuffleMXdlPerWavePerShuffle,
58 index_t CShuffleNXdlPerWavePerShuffle,
59 typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
60 index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
63 typename ComputeTypeA = CDataType,
64 typename ComputeTypeB = ComputeTypeA>
66{
// Compile-time index constants used to address descriptor dimensions.
67 static constexpr auto I0 = Number<0>{};
68 static constexpr auto I1 = Number<1>{};
69 static constexpr auto I2 = Number<2>{};
70 static constexpr auto I3 = Number<3>{};
71 static constexpr auto I4 = Number<4>{};
72 static constexpr auto I5 = Number<5>{};
73 static constexpr auto I6 = Number<6>{};
74 static constexpr auto I7 = Number<7>{};
75
76 // K1 should be Number<...>
// AK0/BK0: number of AK1-/BK1-wide K vectors in one KPerBlock tile (integer division;
// KPerBlock is presumably required to be a multiple of AK1Value/BK1Value — TODO confirm).
77 static constexpr auto AK0Number = Number<KPerBlock / AK1Value>{};
78 static constexpr auto BK0Number = Number<KPerBlock / BK1Value>{};
79 static constexpr auto AK1Number = Number<AK1Value>{};
80 static constexpr auto BK1Number = Number<BK1Value>{};
81
// Least common multiple of the A and B K1 vector widths. The same lcm is recomputed
// locally as the LDS alignment (max_lds_align) and as the K read granularity (CalculateKRead).
82 static constexpr auto lcm_AK1_BK1 = math::lcm(AK1Number, BK1Number);
// NOTE(review): parts of this condition (original lines 84, 86-87) are missing from this
// extract; the visible terms compare lcm_AK1_BK1 against 4 and 32.
83 static constexpr bool is_single_rate_mfma =
85 lcm_AK1_BK1 <= 4) ||
88 lcm_AK1_BK1 < 32))
89 ? true
90 : false;
91 static constexpr auto is_scale_mfma = false;
// KPack: presumably the number of K elements packed per MFMA step; it is derived from the
// k_per_blk of the MFMA instruction chosen by MfmaSelector (line 93 is missing here).
// NOTE(review): ComputeTypeA is passed for both type slots of MfmaSelector; confirm this is
// intended when ComputeTypeB differs from ComputeTypeA.
92 static constexpr index_t KPack =
94 MfmaSelector<ComputeTypeA,
95 MPerXdl,
96 NPerXdl,
97 ComputeTypeA,
99 is_scale_mfma>::selected_mfma.k_per_blk);
100
102
103 __host__ static auto CalculateGridSize(index_t M, index_t N, index_t KBatch, index_t Batch)
104 {
105 return std::make_tuple(Block2CTileMap::CalculateGridSize(M, N), KBatch, Batch);
106 }
107
108 __host__ static auto CalculateMPadded(index_t M)
109 {
110 return math::integer_least_multiple(M, MPerBlock);
111 }
112
113 __host__ static auto CalculateNPadded(index_t N)
114 {
115 return math::integer_least_multiple(N, NPerBlock);
116 }
117
118 __host__ static auto CalculateKPadded(index_t K)
119 {
120 return math::integer_divide_ceil(K, KPerBlock) * KPerBlock;
121 }
122
123 __host__ static auto CalculateAK0Padded(index_t K, index_t K_Batch = 1)
124 {
125 auto K_t = K_Batch * KPerBlock;
126 return (K + K_t - 1) / K_t * (KPerBlock / AK1Value);
127 }
128
129 __host__ static auto CalculateBK0Padded(index_t K, index_t K_Batch = 1)
130 {
131 auto K_t = K_Batch * KPerBlock;
132 return (K + K_t - 1) / K_t * (KPerBlock / BK1Value);
133 }
134
135 __host__ static auto CalculateKPadded(index_t K, index_t K_Batch = 1)
136 {
137 auto K_t = K_Batch * KPerBlock;
138 return (K + K_t - 1) / K_t * KPerBlock;
139 }
140
141 __host__ static auto CalculateKRead(index_t K, index_t K_Batch = 1)
142 {
143 constexpr auto KReadVec = math::lcm(AK1Number, BK1Number);
144 auto K_t = K_Batch * KReadVec;
145 return (K + K_t - 1) / K_t * KReadVec;
146 }
147
148 __host__ static auto CalculateMBlock(index_t M)
149 {
150 return math::integer_divide_ceil(M, MPerBlock);
151 }
152
153 __host__ static auto CalculateNBlock(index_t N)
154 {
155 return math::integer_divide_ceil(N, NPerBlock);
156 }
157
// Builds the MMA-tile view of a (K0, MN, K1) block descriptor. Callers pass the xdl tiles
// per wave, the wave count and the rows/cols per xdl tile along M (for A) or N (for B).
// Only the descriptor's type is used: the argument is unnamed, and a default-constructed
// TileDesc_K0_MN_K1{} is used instead, so the result depends only on the descriptor type.
// NOTE(review): the transform itself (original lines 164, 166-170) is missing from this
// extract.
158 template <index_t MNXdlPerWave, index_t MNWaves, index_t MNPerXdl, typename TileDesc_K0_MN_K1>
159 __host__ __device__ static constexpr auto MakeGemmMmaTileDescriptor(const TileDesc_K0_MN_K1&)
160 {
// K0 and K1 are the lengths of dimensions 0 and 2 of the input descriptor.
161 constexpr index_t K0 = TileDesc_K0_MN_K1{}.GetLength(Number<0>{});
162 constexpr index_t K1 = TileDesc_K0_MN_K1{}.GetLength(Number<2>{});
163
165 TileDesc_K0_MN_K1{},
171 }
172
173 template <typename ABlockDesc_AK0_M_AK1>
174 __host__ __device__ static constexpr auto
175 MakeAMmaTileDescriptor_M0_M1_M2_K(const ABlockDesc_AK0_M_AK1&)
176 {
177 constexpr index_t MWaves = MPerBlock / (MXdlPerWave * MPerXdl);
178
179 return MakeGemmMmaTileDescriptor<MXdlPerWave, MWaves, MPerXdl>(ABlockDesc_AK0_M_AK1{});
180 }
181
182 template <typename BBlockDesc_BK0_N_BK1>
183 __host__ __device__ static constexpr auto
184 MakeBMmaTileDescriptor_N0_N1_N2_K(const BBlockDesc_BK0_N_BK1&)
185 {
186 constexpr index_t NWaves = NPerBlock / (NXdlPerWave * NPerXdl);
187
188 return MakeGemmMmaTileDescriptor<NXdlPerWave, NWaves, NPerXdl>(BBlockDesc_BK0_N_BK1{});
189 }
190
// Host-side GEMM problem description: sizes (M, N, K), leading strides, the split-K factor
// (KBatch) and per-split K quantities derived with the Calculate* helpers.
// NOTE(review): the initializers of MPadded/NPadded/MBlock/NBlock (original lines 207-208,
// 213-214) and the member declarations (lines 228-242) are missing from this extract.
191 struct Problem
192 {
193 __host__ Problem(index_t M_,
194 index_t N_,
195 index_t K_,
196 index_t StrideA_,
197 index_t StrideB_,
198 index_t StrideC_,
199 index_t KBatch_)
200 : M{M_},
201 N{N_},
202 K{K_},
203 StrideA{StrideA_},
204 StrideB{StrideB_},
205 StrideC{StrideC_},
206 KBatch{KBatch_},
// Per-split K extent rounded up to a multiple of lcm(AK1, BK1).
209 KRead{CalculateKRead(K_, KBatch_)},
// Per-split K extent rounded up to a whole number of KPerBlock tiles.
210 KPadded{CalculateKPadded(K_, KBatch_)},
// Per-split count of AK1-/BK1-wide K vectors, padded to whole KPerBlock tiles.
211 AK0{CalculateAK0Padded(K_, KBatch_)},
212 BK0{CalculateBK0Padded(K_, KBatch_)},
215 {
216 }
217
// Prints the sizes, strides and derived quantities to std::cout.
218 __host__ void Print() const
219 {
220 std::cout << "problem {" << "M:" << M << ", " << "N:" << N << ", " << "K:" << K << ", "
221 << "SA:" << StrideA << ", " << "SB:" << StrideB << ", " << "SC:" << StrideC
222 << ", " << "MP:" << MPadded << ", " << "NP:" << NPadded << ", "
223 << "KRead:" << KRead << ", " << "KP:" << KPadded << ", " << "AK0:" << AK0
224 << ", " << "BK0:" << BK0 << ", " << "MBlock: " << MBlock << ", "
225 << "NBlock: " << NBlock << "}" << std::endl;
226 }
227
243 };
244
245 // Argument: the Problem description plus the A, B (read-only) and C data pointers.
// NOTE(review): the struct header (original line 246) is missing from this extract; the
// `: Problem{...}` initializer below shows that Problem is a base class.
// The pointers are stored as given. Presumably these are the global-memory buffers later
// handed to Run (TODO confirm against the device-op caller).
247 {
248 __host__ Argument(const ADataType* p_a_grid_,
249 const BDataType* p_b_grid_,
250 CDataType* p_c_grid_,
251 index_t M_,
252 index_t N_,
253 index_t K_,
254 index_t StrideA_,
255 index_t StrideB_,
256 index_t StrideC_,
257 index_t k_batch_)
258 : Problem{M_, N_, K_, StrideA_, StrideB_, StrideC_, k_batch_},
259 p_a_grid{p_a_grid_},
260 p_b_grid{p_b_grid_},
261 p_c_grid{p_c_grid_}
262 {
263 }
264
265 const ADataType* p_a_grid;
266 const BDataType* p_b_grid;
267 CDataType* p_c_grid;
268 };
269
// Returns the LDS descriptor of the A tile, indexed as (AK0PerBlock, MPerBlock, AK1).
// The layout is selected at compile time:
//   - a padded layout when ABlockLdsExtraM is non-zero (always forced on gfx950);
//   - a layout folded by MLdsLayer and then permuted (presumably the RowMajor-A branch;
//     its `else if` condition, original line 290, is missing from this extract);
//   - a kfold/mpair-permuted layout for ColumnMajor A.
// NOTE(review): several descriptor/transform argument lines are missing from this extract.
270 __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()
271 {
// Wave grid of the block, and threads per wave derived from BlockSize.
272 constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
273 constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);
274 constexpr index_t WaveSize = BlockSize / (MWave * NWave);
275#if defined(__gfx950__)
276 // Force use padded layout on gfx950 to reduce bank conflicts
277 constexpr index_t ABlockLdsExtraM = 1;
278#else
279 constexpr index_t ABlockLdsExtraM = ABlockLdsExtraMCustom;
280#endif
281 // A matrix in LDS memory, dst of blockwise copy
282 if constexpr(ABlockLdsExtraM)
283 {
287 }
288 // The xor tensor transformation requires extra VGPRs and can cause register spills
289 // in some cases.
291 {
// MLdsLayer = max(1, (32 * 4 / KPerBlock) / sizeof(ADataType)); 32 * 4 is presumably
// 32 LDS banks x 4 bytes — TODO confirm.
292 constexpr auto MLdsLayer = 32 * 4 / KPerBlock / sizeof(ADataType) < 1
293 ? 1
294 : 32 * 4 / KPerBlock / sizeof(ADataType);
295 constexpr auto a_lds_block_desc = make_naive_tensor_descriptor(
297 AK0Number * Number<MLdsLayer>{}, Number<MPerBlock / MLdsLayer>{}, AK1Number),
299
300 constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor(
301 a_lds_block_desc,
307
308 constexpr auto a_lds_block_desc_ak0_mldslayer_m_ak1 = transform_tensor_descriptor(
309 a_lds_block_desc_permuted,
315
316 constexpr auto a_lds_block_desc_ak0_m_ak1 = transform_tensor_descriptor(
317 a_lds_block_desc_ak0_mldslayer_m_ak1,
324
325 return a_lds_block_desc_ak0_m_ak1;
326 }
327 else // ColumnMajor A
328 {
329 // kfold and mpair dimension is not always required.
330 // more dimension in merge_transform increase the difficulty of generating immarg offset
331 // for compiler.
// M0 x M1 split of MPerBlock; M0 is the thread-cluster length along M.
332 constexpr auto M0 = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I1);
333 constexpr auto M1 = MPerBlock / M0;
334
// K0 work per thread on the write side (blockwise copy cluster) and on the read side
// (WaveSize / MPerXdl threads along K).
335 constexpr auto KThreadWrite = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I0);
336 constexpr auto K0PerThreadWrite = AK0Number / KThreadWrite;
337 constexpr auto KThreadRead = WaveSize / MPerXdl;
338 constexpr auto K0PerThreadRead = AK0Number / KThreadRead;
339
// kfold: how many AK1 x M0 rows fit in 128 bytes (1 if one row already exceeds 128 bytes).
340 constexpr auto kfold = (AK1Number * M0 * sizeof(ADataType) > 128)
341 ? 1
342 : 128 / (AK1Number * M0 * sizeof(ADataType));
343 constexpr auto KThreadReadPerm =
344 (kfold * K0PerThreadWrite / K0PerThreadRead) > 1
345 ? KThreadRead / (kfold * K0PerThreadWrite / K0PerThreadRead)
346 : KThreadRead;
347
348 // 1 <= mpair <= M0: AK1 x MPerXdl rows per 128 bytes, clamped to [1, M0]
349 constexpr auto mpair = (AK1Number * MPerXdl * sizeof(ADataType) > 128)
350 ? 1
351 : ((128 / (AK1Number * MPerXdl * sizeof(ADataType))) > M0
352 ? M0
353 : 128 / (AK1Number * MPerXdl * sizeof(ADataType)));
354
355 constexpr auto a_lds_block_desc = make_naive_tensor_descriptor_packed(
359 Number<kfold * M0 / mpair>{},
361 AK1Number));
362
363 constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor(
364 a_lds_block_desc,
369 make_tuple(Number<KThreadReadPerm * M1>{}, Number<kfold * M0 / mpair>{})),
376
377 constexpr auto a_lds_block_desc_unmerged = transform_tensor_descriptor(
378 a_lds_block_desc_permuted,
387 Sequence<1>{},
388 Sequence<2>{},
389 Sequence<3>{},
390 Sequence<4>{},
391 Sequence<5>{}),
393 Sequence<2>{},
396 Sequence<6>{},
397 Sequence<7>{}));
398
399 constexpr auto a_lds_block_desc_ak0_m_ak1 = transform_tensor_descriptor(
400 a_lds_block_desc_unmerged,
403 Number<KThreadWrite / kfold / KThreadReadPerm>{},
411
412 return a_lds_block_desc_ak0_m_ak1;
413 }
414 }
415
// Returns the LDS descriptor of the B tile, indexed as (BK0PerBlock, NPerBlock, BK1).
// This mirrors the A-side getter:
//   - a padded layout when BBlockLdsExtraN is non-zero (always forced on gfx950);
//   - a layout folded by NLdsLayer and then permuted (presumably the ColumnMajor-B branch;
//     its `else if` condition, original line 434, is missing from this extract);
//   - a kfold/npair-permuted layout for RowMajor B.
// NOTE(review): several descriptor/transform argument lines are missing from this extract.
416 __device__ static constexpr auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1()
417 {
// Wave grid of the block, and threads per wave derived from BlockSize.
418 constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
419 constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);
420 constexpr index_t WaveSize = BlockSize / (MWave * NWave);
421#if defined(__gfx950__)
422 // Force use padded layout on gfx950 to reduce bank conflicts
423 constexpr index_t BBlockLdsExtraN = 1;
424#else
425 constexpr index_t BBlockLdsExtraN = BBlockLdsExtraNCustom;
426#endif
427 // B matrix in LDS memory, dst of blockwise copy
428 if constexpr(BBlockLdsExtraN)
429 {
433 }
435 {
436 // NLdsLayer * K0 as logical Bank
// NLdsLayer = max(1, (32 * 4 / KPerBlock) / sizeof(BDataType)).
437 constexpr auto NLdsLayer = 32 * 4 / KPerBlock / sizeof(BDataType) < 1
438 ? 1
439 : 32 * 4 / KPerBlock / sizeof(BDataType);
// NOTE(review): the lone `;` below is an empty statement with no effect.
440 ;
441 constexpr auto b_lds_block_desc = make_naive_tensor_descriptor(
443 BK0Number * Number<NLdsLayer>{}, Number<NPerBlock / NLdsLayer>{}, BK1Number),
445
446 constexpr auto b_lds_block_desc_permuted = transform_tensor_descriptor(
447 b_lds_block_desc,
453
454 constexpr auto b_lds_block_desc_bk0_nldslayer_n_bk1 = transform_tensor_descriptor(
455 b_lds_block_desc_permuted,
461
462 constexpr auto b_lds_block_desc_bk0_n_bk1 = transform_tensor_descriptor(
463 b_lds_block_desc_bk0_nldslayer_n_bk1,
470
471 return b_lds_block_desc_bk0_n_bk1;
472 }
473 else // RowMajor B
474 {
// N0 x N1 split of NPerBlock; N0 is the thread-cluster length along N.
475 constexpr auto N0 = BBlockTransferThreadClusterLengths_BK0_N_BK1{}.At(I1);
476 constexpr auto N1 = NPerBlock / N0;
477
// K0 work per thread on the write side (blockwise copy cluster) and on the read side
// (WaveSize / NPerXdl threads along K).
478 constexpr auto KThreadWrite = BBlockTransferThreadClusterLengths_BK0_N_BK1{}.At(I0);
479 constexpr auto K0PerThreadWrite = BK0Number / KThreadWrite;
480 constexpr auto KThreadRead = WaveSize / NPerXdl;
481 constexpr auto K0PerThreadRead = BK0Number / KThreadRead;
482
// kfold: how many BK1 x N0 rows fit in 128 bytes (1 if one row already exceeds 128 bytes).
483 constexpr auto kfold = (BK1Number * N0 * sizeof(BDataType) > 128)
484 ? 1
485 : 128 / (BK1Number * N0 * sizeof(BDataType));
486 constexpr auto KThreadReadPerm =
487 (kfold * K0PerThreadWrite / K0PerThreadRead) > 1
488 ? KThreadRead / (kfold * K0PerThreadWrite / K0PerThreadRead)
489 : KThreadRead;
490
491 // 1 <= npair <= N0: BK1 x NPerXdl rows per 128 bytes, clamped to [1, N0]
492 constexpr auto npair = (BK1Number * NPerXdl * sizeof(BDataType) > 128)
493 ? 1
494 : ((128 / (BK1Number * NPerXdl * sizeof(BDataType))) > N0
495 ? N0
496 : 128 / (BK1Number * NPerXdl * sizeof(BDataType)));
497
498 constexpr auto b_lds_block_desc = make_naive_tensor_descriptor_packed(
502 Number<kfold * N0 / npair>{},
504 BK1Number));
505
506 constexpr auto b_lds_block_desc_permuted = transform_tensor_descriptor(
507 b_lds_block_desc,
512 make_tuple(Number<KThreadReadPerm * N1>{}, Number<kfold * N0 / npair>{})),
519
520 constexpr auto b_lds_block_desc_unmerged = transform_tensor_descriptor(
521 b_lds_block_desc_permuted,
530 Sequence<1>{},
531 Sequence<2>{},
532 Sequence<3>{},
533 Sequence<4>{},
534 Sequence<5>{}),
536 Sequence<2>{},
539 Sequence<6>{},
540 Sequence<7>{}));
541
542 constexpr auto b_lds_block_desc_bk0_n_bk1 = transform_tensor_descriptor(
543 b_lds_block_desc_unmerged,
546 Number<KThreadWrite / kfold / KThreadReadPerm>{},
554
555 return b_lds_block_desc_bk0_n_bk1;
556 }
557 }
558
// Body of the C-shuffle LDS descriptor getter. Its signature (original line 559) is
// missing from this extract. It returns a descriptor named in
// (MBlock, MPerBlock, NBlock, NPerBlock) order, presumably sized for one shuffle slice
// (TODO confirm). MWave/NWave are the wave counts per block along M and N.
// NOTE(review): the descriptor construction (original lines 565-567, 569) is missing here.
560 {
561 constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
562 constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);
563
564 constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
568 I1,
570
571 return c_shuffle_block_desc_mblock_mperblock_nblock_nperblock;
572 }
573
// Blockwise GEMM pipeline type, chosen at compile time from the pipeline version and
// scheduler, block size, data/compute types, source vector widths, tile sizes and KPack.
// NOTE(review): the alias head (original lines 576-577) and some arguments (lines 585-590)
// are missing from this extract. This is presumably the `BlockwiseGemmPipe` type used by
// CalculateHasMainKBlockLoop, CalculateKBlockLoopTailNum and Run.
575
578 BlkGemmPipelineVer,
579 BlkGemmPipeSched,
580 BlockSize,
581 ADataType,
582 BDataType,
583 ComputeTypeA,
584 AccDataType,
591 ABlockTransferSrcScalarPerVector,
592 BBlockTransferSrcScalarPerVector,
593 MPerBlock,
594 NPerBlock,
595 KPerBlock,
596 MPerXdl,
597 NPerXdl,
598 MXdlPerWave,
599 NXdlPerWave,
600 KPack>())>;
601
// Returns the number of LDS bytes one workgroup needs.
// The A and B tiles are placed back to back, each rounded up to a multiple of
// lcm(AK1, BK1) elements. Run later places the C-shuffle buffer at the start of the same
// allocation, so the result is the larger of (A + B bytes) and (C-shuffle bytes).
// NOTE(review): the C-shuffle descriptor initializer (original line 619) is missing from
// this extract.
602 __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
603 {
604 // LDS allocation for A and B: be careful of alignment
605 constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
606 constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();
607
608 // LDS alignment, in elements (not bytes)
609 constexpr auto max_lds_align = math::lcm(AK1Number, BK1Number);
610
611 constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
612 a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
613
614 constexpr auto b_block_space_size_aligned = math::integer_least_multiple(
615 b_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align);
616
617 // LDS allocation for C shuffle in LDS
618 constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
620
621 constexpr auto c_block_size =
622 c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize();
623
// Element counts are converted to bytes here with each buffer's own element type.
624 return math::max((a_block_space_size_aligned * sizeof(ADataType) +
625 b_block_space_size_aligned * sizeof(BDataType)),
626 c_block_size * sizeof(CShuffleDataType));
627 }
628
629 __host__ static constexpr bool CalculateHasMainKBlockLoop(index_t K)
630 {
631 const index_t num_loop = K / KPerBlock;
632
633 return BlockwiseGemmPipe::BlockHasHotloop(num_loop);
634 }
635
636 __host__ static constexpr TailNumber CalculateKBlockLoopTailNum(index_t K)
637 {
638 const index_t num_loop = K / KPerBlock;
639
640 return BlockwiseGemmPipe::BlockLoopTailNum(num_loop);
641 }
642
// Wraps a C grid descriptor (c_grid_desc_m_n) in a transformed view that, judging by the
// name, splits M into (MBlock, MPerBlock) and N into (NBlock, NPerBlock).
// MBlock/NBlock are runtime tile counts.
// NOTE(review): the transform arguments (original lines 649-652) are missing from this
// extract.
643 template <typename CGridDesc>
644 __host__ __device__ static constexpr auto MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
645 const CGridDesc& c_grid_desc_m_n, index_t MBlock, index_t NBlock)
646 {
647 const auto c_grid_desc_mblock_mperblock_nblock_nperblock = transform_tensor_descriptor(
648 c_grid_desc_m_n,
653
654 return c_grid_desc_mblock_mperblock_nblock_nperblock;
655 }
656
657 // return block_id to C matrix tile idx (m0, n0) mapping
658 // if arch = gfx942
660
// Per-workgroup GEMM body using one LDS allocation (p_shared):
//   1. map blockIdx.x to a C tile (block_m_id, block_n_id) through Block2CTileMap, and
//      return early if that tile index is not valid for the C grid descriptor;
//   2. run the blockwise GEMM pipeline, which copies A/B tiles from the global buffers
//      into LDS and accumulates into the per-thread buffer c_thread_buf;
//   3. shuffle the accumulators through LDS (p_shared reused as CShuffleDataType) and
//      write them to C with CGlobalMemoryDataOperation, applying CElementwiseOperation.
// k_id is the dim-0 (AK0 / BK0) origin of the A and B copy windows.
// NOTE(review): several lines are missing from this extract. They include the C/A/B buffer
// declarations (original lines 682, 789, 792) and the statements that follow the
// "make sure it's safe to ... LDS" comments in the shuffle loop (lines 991, 1001,
// presumably LDS barriers).
661 template <typename AGridDesc_AK0_M_K1,
662 typename BGridDesc_BK0_N_K1,
663 typename CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
664 bool HasMainKBlockLoop,
665 InMemoryDataOperationEnum CGlobalMemoryDataOperation,
666 TailNumber TailNum = TailNumber::Odd>
667 __device__ static void Run(const ADataType* p_a_grid,
668 const BDataType* p_b_grid,
669 CDataType* p_c_grid,
670 void* p_shared,
671 const Problem& problem,
672 const AGridDesc_AK0_M_K1& a_grid_desc_ak0_m_ak1,
673 const BGridDesc_BK0_N_K1& b_grid_desc_bk0_n_bk1,
674 const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock&
675 c_grid_desc_mblock_mperblock_nblock_nperblock,
676 const index_t k_id = 0)
677 {
678 const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
679 p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
680 const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
681 p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize());
683 p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
684
685 const AElementwiseOperation a_element_op{};
686 const BElementwiseOperation b_element_op{};
687 const CElementwiseOperation c_element_op{};
688
689 // divide block work by [M, N]
// NOTE(review): the third constructor argument (4) is a hard-coded value whose meaning is
// defined by Block2CTileMap (not visible here).
690 const auto block_2_ctile_map = Block2CTileMap{problem.M, problem.N, 4};
691
692 const auto block_work_idx = block_2_ctile_map.CalculateBottomIndex(
693 make_multi_index(static_cast<index_t>(blockIdx.x)));
694
// The whole workgroup exits here if its tile index lies outside the C tile grid.
695 if(!block_2_ctile_map.ValidCTileIndex(
696 block_work_idx,
697 make_tuple(c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0),
698 c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2))))
699 {
700 return;
701 }
702
703 const index_t block_m_id = __builtin_amdgcn_readfirstlane(block_work_idx[I0]);
704 const index_t block_n_id = __builtin_amdgcn_readfirstlane(block_work_idx[I1]);
705
706 // HACK: readfirstlane forces m/n_block_data_idx_on_grid into SGPRs
707 const index_t m_block_data_idx_on_grid =
708 __builtin_amdgcn_readfirstlane(block_m_id * MPerBlock);
709
710 const index_t n_block_data_idx_on_grid =
711 __builtin_amdgcn_readfirstlane(block_n_id * NPerBlock);
712
713 // LDS alignment, in elements: lcm of the A and B K1 vector widths
714 constexpr auto max_lds_align = math::lcm(AK1Number, BK1Number);
715
716 // A matrix in LDS memory, dst of blockwise copy
717 constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
718
719 // B matrix in LDS memory, dst of blockwise copy
720 constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();
721
722 // A matrix blockwise copy
723 auto a_blockwise_copy =
725 AElementwiseOperation,
729 ABlockTransferThreadClusterLengths_AK0_M_AK1,
730 ABlockTransferThreadClusterArrangeOrder,
731 ADataType,
732 ADataType,
733 decltype(a_grid_desc_ak0_m_ak1),
734 decltype(a_block_desc_ak0_m_ak1),
735 ABlockTransferSrcAccessOrder,
737 ABlockTransferSrcVectorDim,
738 2,
739 ABlockTransferSrcScalarPerVector,
740 ABlockTransferDstScalarPerVector_AK1,
741 1,
742 1,
743 AThreadTransferSrcResetCoordinateAfterRun,
744 true,
745 BlockwiseGemmPipe::GlobalBufferNum>(
746 a_grid_desc_ak0_m_ak1,
747 make_multi_index(k_id, m_block_data_idx_on_grid, 0),
748 a_element_op,
749 a_block_desc_ak0_m_ak1,
750 make_multi_index(0, 0, 0),
752
753 // B matrix blockwise copy
754 auto b_blockwise_copy =
756 BElementwiseOperation,
760 BBlockTransferThreadClusterLengths_BK0_N_BK1,
761 BBlockTransferThreadClusterArrangeOrder,
762 BDataType,
763 BDataType,
764 decltype(b_grid_desc_bk0_n_bk1),
765 decltype(b_block_desc_bk0_n_bk1),
766 BBlockTransferSrcAccessOrder,
768 BBlockTransferSrcVectorDim,
769 2,
770 BBlockTransferSrcScalarPerVector,
771 BBlockTransferDstScalarPerVector_BK1,
772 1,
773 1,
774 BThreadTransferSrcResetCoordinateAfterRun,
775 true,
776 BlockwiseGemmPipe::GlobalBufferNum>(
777 b_grid_desc_bk0_n_bk1,
778 make_multi_index(k_id, n_block_data_idx_on_grid, 0),
779 b_element_op,
780 b_block_desc_bk0_n_bk1,
781 make_multi_index(0, 0, 0),
783
784 // LDS allocation for A and B: be careful of alignment
785 constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
786 a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
787
788 // LDS layout: A tile at the start of p_shared, B tile right after the aligned A tile
// NOTE(review): the B offset converts A elements to B elements with integer division. It is
// exact only if a_block_space_size_aligned * sizeof(ADataType) is a multiple of
// sizeof(BDataType).
790 static_cast<ADataType*>(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
791
793 static_cast<BDataType*>(p_shared) +
794 a_block_space_size_aligned * sizeof(ADataType) / sizeof(BDataType),
795 b_block_desc_bk0_n_bk1.GetElementSpaceSize());
796
// Copy-window step along K0: one KPerBlock tile, i.e. KPerBlock / K1 vectors.
797 constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1Number, 0, 0);
798 constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock / BK1Number, 0, 0);
799
800 // Blockwise GEMM pipeline
801 static_assert(std::is_default_constructible_v<BlockwiseGemmPipe>);
802 auto blockwise_gemm_pipeline = BlockwiseGemmPipe{};
803 auto c_thread_buf = blockwise_gemm_pipeline.GetCThreadBuffer();
804
// K-block iterations for this workgroup: the K extent of the A grid descriptor
// (AK0 * AK1) divided by KPerBlock * KBatch, using integer division.
805 const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(
806 (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) /
807 (KPerBlock * problem.KBatch));
808
809 blockwise_gemm_pipeline.template Run<HasMainKBlockLoop, TailNum>(a_grid_desc_ak0_m_ak1,
810 a_block_desc_ak0_m_ak1,
811 a_blockwise_copy,
812 a_grid_buf,
813 a_block_buf,
814 a_block_slice_copy_step,
815 b_grid_desc_bk0_n_bk1,
816 b_block_desc_bk0_n_bk1,
817 b_blockwise_copy,
818 b_grid_buf,
819 b_block_buf,
820 b_block_slice_copy_step,
821 c_thread_buf,
822 num_k_block_main_loop);
823
824 // shuffle C and write out
825 {
826 static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 &&
827 NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0,
828 "wrong!");
829
830 constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
831 constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);
832
833 // TODO: hacky, fix it!
834 constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
835 blockwise_gemm_pipeline.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
836
837 // TODO: hacky, fix it!
838 // c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths
839 constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp =
840 blockwise_gemm_pipeline.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
841
842 constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0);
843 constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1);
844 constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2);
845 constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3);
846 constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4);
847 constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5);
848 constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6);
849 constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7);
850
851 constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
853
// The C-shuffle buffer aliases the start of p_shared, the same LDS that held the A/B tiles
// during the main loop.
854 auto c_shuffle_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
855 static_cast<CShuffleDataType*>(p_shared),
856 c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
857
858 constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor(
859 c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
863 Number<CShuffleMXdlPerWavePerShuffle>{}, // M0 (MXdlPerWave) per shuffle
864 M1, // M1 = MWave
865 M2, // M2 * M3 * M4 = MPerXdl
866 M3,
867 M4)),
870 Number<CShuffleNXdlPerWavePerShuffle>{}, // N0 (NXdlPerWave) per shuffle
871 N1, // N1 = NWave
872 N2))), // N2 = NPerXdl
876
877 // calculate origin of thread output tensor on global memory
878 // blockwise GEMM c matrix starting index
879 const auto c_thread_mtx_on_block =
880 blockwise_gemm_pipeline.CalculateCThreadOriginDataIndex(I0, I0, I0, I0);
881
882 const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0];
883 const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1];
884
// Decompose the thread's flat M/N origin into (M0..M4) and (N0..N2) coordinates.
885 const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor =
887 make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))),
890
891 const auto m_thread_data_on_block_idx =
892 m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex(
893 make_multi_index(m_thread_data_on_block));
894
895 const auto n_thread_data_on_block_to_n0_n1_n2_adaptor =
900
901 const auto n_thread_data_on_block_idx =
902 n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex(
903 make_multi_index(n_thread_data_on_block));
904
905 // shuffle: threadwise copy C from VGPR to LDS
906 auto c_thread_copy_vgpr_to_lds =
908 CShuffleDataType,
909 decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2),
910 decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2),
912 Sequence<CShuffleMXdlPerWavePerShuffle,
913 CShuffleNXdlPerWavePerShuffle,
914 I1,
915 I1,
916 M2,
917 I1,
918 M4,
919 I1>,
921 7,
922 1,
924 1,
925 true>{
926 c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
928 0,
929 m_thread_data_on_block_idx[I1],
930 n_thread_data_on_block_idx[I1],
931 m_thread_data_on_block_idx[I2],
932 m_thread_data_on_block_idx[I3],
933 m_thread_data_on_block_idx[I4],
934 n_thread_data_on_block_idx[I2]),
936
937 // shuffle: blockwise copy C from LDS to global
938 auto c_shuffle_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r1<
939 ThisThreadBlock, // ThreadGroup
940 CElementwiseOperation, // ElementwiseOperation,
941 CGlobalMemoryDataOperation, // DstInMemOp,
942 Sequence<1,
943 CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
944 1,
945 CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths,
946 CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
947 Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder,
948 CShuffleDataType, // typename SrcData,
949 CDataType, // typename DstData,
950 decltype(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock),
951 decltype(c_grid_desc_mblock_mperblock_nblock_nperblock),
952 Sequence<0, 1, 2, 3>, // typename DimAccessOrder,
953 3, // index_t VectorDim,
954 CShuffleBlockTransferScalarPerVector_NPerBlock, // index_t ScalarPerVector,
955 true, // bool ThreadTransferSrcResetCoordinateAfterRun,
956 false> // bool ThreadTransferDstResetCoordinateAfterRun>
957 {c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
958 make_multi_index(0, 0, 0, 0),
959 c_grid_desc_mblock_mperblock_nblock_nperblock,
960 make_multi_index(block_m_id, 0, block_n_id, 0),
961 c_element_op};
962
963 // space filling curve for threadwise C in VGPR
964 constexpr auto sfc_c_vgpr =
967 Sequence<CShuffleMXdlPerWavePerShuffle,
968 CShuffleNXdlPerWavePerShuffle,
969 1,
970 1,
971 M2,
972 1,
973 M4,
974 1>>{};
975
976 // space filling curve for shuffled blockwise C in global mem
977 constexpr auto sfc_c_global =
980 Sequence<1,
981 CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
982 1,
983 CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{};
984
985 constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess();
986
987 static_assert(num_access == sfc_c_global.GetNumOfAccess(), "wrong!");
988
// One compile-time iteration per shuffle slice: write a slice of the accumulators to LDS,
// copy it from LDS to C in global memory, then advance the C window.
989 static_for<0, num_access, 1>{}([&](auto access_id) {
990 // make sure it's safe to write to LDS
992
993 // each thread writes its slice of c_thread_buf from VGPR to LDS
994 c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2,
995 sfc_c_vgpr.GetIndexTupleOfNumber(access_id),
996 c_thread_buf,
997 c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
998 c_shuffle_block_buf);
999
1000 // make sure it's safe to read from LDS
1002
1003 // the whole workgroup copies the shuffled slice from LDS to global memory
1004 c_shuffle_block_copy_lds_to_global.Run(
1005 c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
1006 c_shuffle_block_buf,
1007 c_grid_desc_mblock_mperblock_nblock_nperblock,
1008 c_grid_buf);
1009
1010 if constexpr(access_id < num_access - 1)
1011 {
1012 constexpr auto c_global_step = sfc_c_global.GetForwardStep(access_id);
1013
1014 // advance the C destination window to the next slice (skipped after the last one)
1015 c_shuffle_block_copy_lds_to_global.MoveDstSliceWindow(
1016 c_grid_desc_mblock_mperblock_nblock_nperblock, c_global_step);
1017 }
1018 });
1019 }
1020 }
1021
1022 template <typename AGridDesc_AK0_M_K1,
1023 typename BGridDesc_BK0_N_K1,
1024 typename CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
1025 bool HasMainKBlockLoop,
1026 InMemoryDataOperationEnum CGlobalMemoryDataOperation,
1027 TailNumber TailNum = TailNumber::Odd>
1028 __device__ static void Run_2Lds(const ADataType* p_a_grid,
1029 const BDataType* p_b_grid,
1030 CDataType* p_c_grid,
1031 void* p_shared_0,
1032 void* p_shared_1,
1033 const Problem& problem,
1034 const AGridDesc_AK0_M_K1& a_grid_desc_ak0_m_ak1,
1035 const BGridDesc_BK0_N_K1& b_grid_desc_bk0_n_bk1,
1036 const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock&
1037 c_grid_desc_mblock_mperblock_nblock_nperblock,
1038 const index_t k_id = 0)
1039 {
1040 const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
1041 p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
1042 const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
1043 p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize());
1045 p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
1046
1047 const AElementwiseOperation a_element_op{};
1048 const BElementwiseOperation b_element_op{};
1049 const CElementwiseOperation c_element_op{};
1050
1051 // divide block work by [M, N]
1052 const auto block_2_ctile_map = Block2CTileMap{problem.M, problem.N, 4};
1053
1054 const auto block_work_idx = block_2_ctile_map.CalculateBottomIndex(
1055 make_multi_index(static_cast<index_t>(blockIdx.x)));
1056
1057 if(!block_2_ctile_map.ValidCTileIndex(
1058 block_work_idx,
1059 make_tuple(c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0),
1060 c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2))))
1061 {
1062 return;
1063 }
1064
1065 const index_t block_m_id = __builtin_amdgcn_readfirstlane(block_work_idx[I0]);
1066 const index_t block_n_id = __builtin_amdgcn_readfirstlane(block_work_idx[I1]);
1067
1068 // HACK: this force m/n_block_data_idx_on_grid into SGPR
1069 const index_t m_block_data_idx_on_grid =
1070 __builtin_amdgcn_readfirstlane(block_m_id * MPerBlock);
1071
1072 const index_t n_block_data_idx_on_grid =
1073 __builtin_amdgcn_readfirstlane(block_n_id * NPerBlock);
1074
1075 // lds max alignment
1076 constexpr auto max_lds_align = math::lcm(AK1Number, BK1Number);
1077
1078 // A matrix in LDS memory, dst of blockwise copy
1079 constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
1080
1081 // B matrix in LDS memory, dst of blockwise copy
1082 constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();
1083
1084 // A matrix blockwise copy
1085 auto a_blockwise_copy =
1087 AElementwiseOperation,
1091 ABlockTransferThreadClusterLengths_AK0_M_AK1,
1092 ABlockTransferThreadClusterArrangeOrder,
1093 ADataType,
1094 ADataType,
1095 decltype(a_grid_desc_ak0_m_ak1),
1096 decltype(a_block_desc_ak0_m_ak1),
1097 ABlockTransferSrcAccessOrder,
1099 ABlockTransferSrcVectorDim,
1100 2,
1101 ABlockTransferSrcScalarPerVector,
1102 ABlockTransferDstScalarPerVector_AK1,
1103 1,
1104 1,
1105 AThreadTransferSrcResetCoordinateAfterRun,
1106 true,
1107 BlockwiseGemmPipe::GlobalBufferNum>(
1108 a_grid_desc_ak0_m_ak1,
1109 make_multi_index(k_id, m_block_data_idx_on_grid, 0),
1110 a_element_op,
1111 a_block_desc_ak0_m_ak1,
1112 make_multi_index(0, 0, 0),
1114
1115 // B matrix blockwise copy
1116 auto b_blockwise_copy =
1118 BElementwiseOperation,
1122 BBlockTransferThreadClusterLengths_BK0_N_BK1,
1123 BBlockTransferThreadClusterArrangeOrder,
1124 BDataType,
1125 BDataType,
1126 decltype(b_grid_desc_bk0_n_bk1),
1127 decltype(b_block_desc_bk0_n_bk1),
1128 BBlockTransferSrcAccessOrder,
1130 BBlockTransferSrcVectorDim,
1131 2,
1132 BBlockTransferSrcScalarPerVector,
1133 BBlockTransferDstScalarPerVector_BK1,
1134 1,
1135 1,
1136 BThreadTransferSrcResetCoordinateAfterRun,
1137 true,
1138 BlockwiseGemmPipe::GlobalBufferNum>(
1139 b_grid_desc_bk0_n_bk1,
1140 make_multi_index(k_id, n_block_data_idx_on_grid, 0),
1141 b_element_op,
1142 b_block_desc_bk0_n_bk1,
1143 make_multi_index(0, 0, 0),
1145
1146 // LDS allocation for A and B: be careful of alignment
1147 constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
1148 a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
1149
1150 auto a_block_buf_ping = make_dynamic_buffer<AddressSpaceEnum::Lds>(
1151 static_cast<ADataType*>(p_shared_0), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
1152
1153 auto b_block_buf_ping = make_dynamic_buffer<AddressSpaceEnum::Lds>(
1154 static_cast<BDataType*>(p_shared_0) +
1155 a_block_space_size_aligned * sizeof(ADataType) / sizeof(BDataType),
1156 b_block_desc_bk0_n_bk1.GetElementSpaceSize());
1157
1158 auto a_block_buf_pong = make_dynamic_buffer<AddressSpaceEnum::Lds>(
1159 static_cast<ADataType*>(p_shared_1), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
1160
1161 auto b_block_buf_pong = make_dynamic_buffer<AddressSpaceEnum::Lds>(
1162 static_cast<BDataType*>(p_shared_1) +
1163 a_block_space_size_aligned * sizeof(ADataType) / sizeof(BDataType),
1164 b_block_desc_bk0_n_bk1.GetElementSpaceSize());
1165
1166 auto a_block_bufs = make_tuple(a_block_buf_ping, a_block_buf_pong);
1167 auto b_block_bufs = make_tuple(b_block_buf_ping, b_block_buf_pong);
1168
1169 constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1Number, 0, 0);
1170 constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock / BK1Number, 0, 0);
1171
1172 // Blockwise GEMM pipeline
1173 static_assert(std::is_default_constructible_v<BlockwiseGemmPipe>);
1174 auto blockwise_gemm_pipeline = BlockwiseGemmPipe{};
1175 auto c_thread_buf = blockwise_gemm_pipeline.GetCThreadBuffer();
1176
1177 const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(
1178 (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) /
1179 (KPerBlock * problem.KBatch));
1180
1181 blockwise_gemm_pipeline.template Run<HasMainKBlockLoop, TailNum>(a_grid_desc_ak0_m_ak1,
1182 a_block_desc_ak0_m_ak1,
1183 a_blockwise_copy,
1184 a_grid_buf,
1185 a_block_bufs,
1186 a_block_slice_copy_step,
1187 b_grid_desc_bk0_n_bk1,
1188 b_block_desc_bk0_n_bk1,
1189 b_blockwise_copy,
1190 b_grid_buf,
1191 b_block_bufs,
1192 b_block_slice_copy_step,
1193 c_thread_buf,
1194 num_k_block_main_loop);
1195
1196 // shuffle C and write out
1197 {
1198 static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 &&
1199 NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0,
1200 "wrong!");
1201
1202 constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
1203 constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);
1204
1205 // TODO: hacky, fix it!
1206 constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
1207 blockwise_gemm_pipeline.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
1208
1209 // TODO: hacky, fix it!
1210 // c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths
1211 constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp =
1212 blockwise_gemm_pipeline.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
1213
1214 constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0);
1215 constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1);
1216 constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2);
1217 constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3);
1218 constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4);
1219 constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5);
1220 constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6);
1221 constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7);
1222
1223 constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
1225
1226 auto c_shuffle_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
1227 static_cast<CShuffleDataType*>(p_shared_0),
1228 c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
1229
1230 constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor(
1231 c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
1232 make_tuple(
1235 Number<CShuffleMXdlPerWavePerShuffle>{}, // M0 (MXdlPerWave) per shuffle
1236 M1, // M1 = MWave
1237 M2, // M2 * M3 * M4 = MPerXdl
1238 M3,
1239 M4)),
1242 Number<CShuffleNXdlPerWavePerShuffle>{}, // N0 (NXdlPerWave) per shuffle
1243 N1, // N1 = NWave
1244 N2))), // N2 = NPerXdl
1246 make_tuple(
1248
1249 // calculate origin of thread output tensor on global memory
1250 // blockwise GEMM c matrix starting index
1251 const auto c_thread_mtx_on_block =
1252 blockwise_gemm_pipeline.CalculateCThreadOriginDataIndex(I0, I0, I0, I0);
1253
1254 const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0];
1255 const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1];
1256
1257 const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor =
1259 make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))),
1262
1263 const auto m_thread_data_on_block_idx =
1264 m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex(
1265 make_multi_index(m_thread_data_on_block));
1266
1267 const auto n_thread_data_on_block_to_n0_n1_n2_adaptor =
1272
1273 const auto n_thread_data_on_block_idx =
1274 n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex(
1275 make_multi_index(n_thread_data_on_block));
1276
1277 // shuffle: threadwise copy C from VGPR to LDS
1278 auto c_thread_copy_vgpr_to_lds =
1280 CShuffleDataType,
1281 decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2),
1282 decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2),
1284 Sequence<CShuffleMXdlPerWavePerShuffle,
1285 CShuffleNXdlPerWavePerShuffle,
1286 I1,
1287 I1,
1288 M2,
1289 I1,
1290 M4,
1291 I1>,
1293 7,
1294 1,
1296 1,
1297 true>{
1298 c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
1300 0,
1301 m_thread_data_on_block_idx[I1],
1302 n_thread_data_on_block_idx[I1],
1303 m_thread_data_on_block_idx[I2],
1304 m_thread_data_on_block_idx[I3],
1305 m_thread_data_on_block_idx[I4],
1306 n_thread_data_on_block_idx[I2]),
1308
1309 // shuffle: blockwise copy C from LDS to global
1310 auto c_shuffle_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r1<
1311 ThisThreadBlock, // ThreadGroup
1312 CElementwiseOperation, // ElementwiseOperation,
1313 CGlobalMemoryDataOperation, // DstInMemOp,
1314 Sequence<1,
1315 CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
1316 1,
1317 CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths,
1318 CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
1319 Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder,
1320 CShuffleDataType, // typename SrcData,
1321 CDataType, // typename DstData,
1322 decltype(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock),
1323 decltype(c_grid_desc_mblock_mperblock_nblock_nperblock),
1324 Sequence<0, 1, 2, 3>, // typename DimAccessOrder,
1325 3, // index_t VectorDim,
1326 CShuffleBlockTransferScalarPerVector_NPerBlock, // index_t ScalarPerVector,
1327 true, // bool ThreadTransferSrcResetCoordinateAfterRun,
1328 false> // bool ThreadTransferDstResetCoordinateAfterRun>
1329 {c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
1330 make_multi_index(0, 0, 0, 0),
1331 c_grid_desc_mblock_mperblock_nblock_nperblock,
1332 make_multi_index(block_m_id, 0, block_n_id, 0),
1333 c_element_op};
1334
1335 // space filling curve for threadwise C in VGPR
1336 constexpr auto sfc_c_vgpr =
1339 Sequence<CShuffleMXdlPerWavePerShuffle,
1340 CShuffleNXdlPerWavePerShuffle,
1341 1,
1342 1,
1343 M2,
1344 1,
1345 M4,
1346 1>>{};
1347
1348 // space filling curve for shuffled blockwise C in global mem
1349 constexpr auto sfc_c_global =
1352 Sequence<1,
1353 CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
1354 1,
1355 CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{};
1356
1357 constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess();
1358
1359 static_assert(num_access == sfc_c_global.GetNumOfAccess(), "wrong!");
1360
1361 static_for<0, num_access, 1>{}([&](auto access_id) {
1362 // make sure it's safe to write to LDS
1364
1365 // each thread write its data from VGPR to LDS
1366 c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2,
1367 sfc_c_vgpr.GetIndexTupleOfNumber(access_id),
1368 c_thread_buf,
1369 c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
1370 c_shuffle_block_buf);
1371
1372 // make sure it's safe to read from LDS
1374
1375 // each block copy its data from LDS to global
1376 c_shuffle_block_copy_lds_to_global.Run(
1377 c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
1378 c_shuffle_block_buf,
1379 c_grid_desc_mblock_mperblock_nblock_nperblock,
1380 c_grid_buf);
1381
1382 if constexpr(access_id < num_access - 1)
1383 {
1384 constexpr auto c_global_step = sfc_c_global.GetForwardStep(access_id);
1385
1386 // move on C
1387 c_shuffle_block_copy_lds_to_global.MoveDstSliceWindow(
1388 c_grid_desc_mblock_mperblock_nblock_nperblock, c_global_step);
1389 }
1390 });
1391 }
1392 }
1393};
1394
1395} // namespace ck
#define IS_VALID_COMPILATION_PARAMETER_IMPL(CDataType_)
Definition device_base.hpp:178
__host__ __device__ constexpr auto integer_least_multiple(X x, Y y)
Definition utility/math.hpp:78
__host__ __device__ constexpr T max(T x)
Definition utility/math.hpp:84
__host__ __device__ constexpr auto integer_divide_ceil(X x, Y y)
Definition utility/math.hpp:72
__host__ __device__ constexpr auto lcm(X x, Y y)
Definition utility/math.hpp:198
GemmSpecialization
Definition gemm_specialization.hpp:11
Definition ck.hpp:268
__host__ __device__ constexpr auto make_multi_index(Xs &&... xs)
Definition array_multi_index.hpp:15
__host__ __device__ constexpr auto make_pass_through_transform(const LowLength &low_length)
Definition multi_index_transform_helper.hpp:12
int32_t index_t
Definition ck.hpp:299
__host__ __device__ constexpr auto make_naive_tensor_descriptor(const Tuple< Lengths... > &lengths, const Tuple< Strides... > &strides)
Definition tensor_descriptor_helper.hpp:49
InMemoryDataOperationEnum
Definition ck.hpp:277
@ Set
Definition ck.hpp:278
__host__ __device__ constexpr auto make_single_stage_tensor_adaptor(const Transforms &transforms, LowerDimensionOldTopIdss, UpperDimensionNewTopIdss)
Definition tensor_description/tensor_adaptor.hpp:425
remove_cv_t< remove_reference_t< T > > remove_cvref_t
Definition type.hpp:297
BlockGemmPipelineVersion
Definition blkgemmpipe_scheduler.hpp:12
@ v4
Definition blkgemmpipe_scheduler.hpp:17
__host__ __device__ constexpr auto make_freeze_transform(const LowerIndex &low_idx)
Definition multi_index_transform_helper.hpp:151
__host__ __device__ constexpr auto make_xor_with_modulo_transform(const LowLengths &low_lengths)
Definition multi_index_transform_helper.hpp:185
integral_constant< index_t, N > Number
Definition number.hpp:12
TailNumber
Definition blkgemmpipe_scheduler.hpp:31
@ Odd
Definition blkgemmpipe_scheduler.hpp:33
constexpr auto BlockGemmPipeline_Selector()
Definition blockwise_gemm_pipeline_wmma_selector.hpp:32
__host__ __device__ constexpr auto make_merge_transform(const LowLengths &low_lengths)
Definition multi_index_transform_helper.hpp:55
__host__ __device__ constexpr auto make_merge_transform_v3_division_mod(const LowLengths &low_lengths)
Definition multi_index_transform_helper.hpp:84
BlockGemmPipelineScheduler
Definition blkgemmpipe_scheduler.hpp:25
@ Intrawave
Definition blkgemmpipe_scheduler.hpp:26
__host__ __device__ constexpr auto make_naive_tensor_descriptor_packed(const Tuple< Lengths... > &lengths)
Definition tensor_descriptor_helper.hpp:101
__host__ __device__ constexpr auto make_tuple(Xs &&... xs)
Definition utility/tuple.hpp:211
__host__ __device__ constexpr auto transform_tensor_descriptor(const OldTensorDescriptor &old_tensor_desc, const NewTransforms &new_transforms, NewLowerDimensionOldVisibleIdss, NewUpperDimensionNewVisibleIdss)
Definition tensor_description/tensor_descriptor.hpp:319
__device__ void block_sync_lds()
Definition synchronization.hpp:16
__host__ __device__ constexpr auto make_unmerge_transform(const UpLengths &up_lengths, integral_constant< bool, Use24BitIntegerCalculation >=integral_constant< bool, false >{})
Definition multi_index_transform_helper.hpp:90
__host__ __device__ constexpr auto make_dynamic_buffer(T *p, ElementSpaceSize element_space_size)
Definition dynamic_buffer.hpp:472
Definition block_to_ctile_map.hpp:271
__host__ static __device__ constexpr index_t CalculateGridSize(index_t M, index_t N)
Definition block_to_ctile_map.hpp:283
__host__ Argument(const ADataType *p_a_grid_, const BDataType *p_b_grid_, CDataType *p_c_grid_, index_t M_, index_t N_, index_t K_, index_t StrideA_, index_t StrideB_, index_t StrideC_, index_t k_batch_)
Definition gridwise_gemm_xdl_cshuffle_conv_v3.hpp:248
CDataType * p_c_grid
Definition gridwise_gemm_xdl_cshuffle_conv_v3.hpp:267
const BDataType * p_b_grid
Definition gridwise_gemm_xdl_cshuffle_conv_v3.hpp:266
const ADataType * p_a_grid
Definition gridwise_gemm_xdl_cshuffle_conv_v3.hpp:265
__host__ void Print() const
Definition gridwise_gemm_xdl_cshuffle_conv_v3.hpp:218
index_t K
Definition gridwise_gemm_xdl_cshuffle_conv_v3.hpp:230
index_t NPadded
Definition gridwise_gemm_xdl_cshuffle_conv_v3.hpp:236
index_t StrideA
Definition gridwise_gemm_xdl_cshuffle_conv_v3.hpp:231
index_t MPadded
Definition gridwise_gemm_xdl_cshuffle_conv_v3.hpp:235
index_t BK0
Definition gridwise_gemm_xdl_cshuffle_conv_v3.hpp:240
index_t StrideB
Definition gridwise_gemm_xdl_cshuffle_conv_v3.hpp:232
index_t MBlock
Definition gridwise_gemm_xdl_cshuffle_conv_v3.hpp:241
index_t N
Definition gridwise_gemm_xdl_cshuffle_conv_v3.hpp:229
index_t KBatch
Definition gridwise_gemm_xdl_cshuffle_conv_v3.hpp:234
index_t NBlock
Definition gridwise_gemm_xdl_cshuffle_conv_v3.hpp:242
index_t KPadded
Definition gridwise_gemm_xdl_cshuffle_conv_v3.hpp:238
index_t M
Definition gridwise_gemm_xdl_cshuffle_conv_v3.hpp:228
index_t AK0
Definition gridwise_gemm_xdl_cshuffle_conv_v3.hpp:239
__host__ Problem(index_t M_, index_t N_, index_t K_, index_t StrideA_, index_t StrideB_, index_t StrideC_, index_t KBatch_)
Definition gridwise_gemm_xdl_cshuffle_conv_v3.hpp:193
index_t KRead
Definition gridwise_gemm_xdl_cshuffle_conv_v3.hpp:237
index_t StrideC
Definition gridwise_gemm_xdl_cshuffle_conv_v3.hpp:233
Definition gridwise_gemm_xdl_cshuffle_conv_v3.hpp:66
static __device__ void Run(const ADataType *p_a_grid, const BDataType *p_b_grid, AccDataType *p_c_grid, void *p_shared, const Problem &problem, const AGridDesc_AK0_M_K1 &a_grid_desc_ak0_m_ak1, const BGridDesc_BK0_N_K1 &b_grid_desc_bk0_n_bk1, const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock &c_grid_desc_mblock_mperblock_nblock_nperblock, const index_t k_id=0)
Definition gridwise_gemm_xdl_cshuffle_conv_v3.hpp:667
remove_cvref_t< decltype(BlockGemmPipeline_Selector< BlkGemmPipelineVer, BlkGemmPipeSched, BlockSize, ADataType, BDataType, ComputeTypeA, AccDataType, decltype(GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()), decltype(GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1()), decltype(MakeAMmaTileDescriptor_M0_M1_M2_K(GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1())), decltype(MakeBMmaTileDescriptor_N0_N1_N2_K(GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1())), ABlockTransferSrcScalarPerVector, BBlockTransferSrcScalarPerVector, MPerBlock, NPerBlock, KPerBlock, MPerXdl, NPerXdl, MXdlPerWave, NXdlPerWave, KPack >())> BlockwiseGemmPipe
Definition gridwise_gemm_xdl_cshuffle_conv_v3.hpp:576
static __device__ void Run_2Lds(const ADataType *p_a_grid, const BDataType *p_b_grid, AccDataType *p_c_grid, void *p_shared_0, void *p_shared_1, const Problem &problem, const AGridDesc_AK0_M_K1 &a_grid_desc_ak0_m_ak1, const BGridDesc_BK0_N_K1 &b_grid_desc_bk0_n_bk1, const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock &c_grid_desc_mblock_mperblock_nblock_nperblock, const index_t k_id=0)
Definition gridwise_gemm_xdl_cshuffle_conv_v3.hpp:1028
Selects the appropriate MFMA instruction type and configuration for given data types and tile sizes o...
Definition xdlops_gemm.hpp:1208
Definition utility/sequence.hpp:43
Definition tensor_space_filling_curve.hpp:20
Blockwise data transfer.
Definition thread_group_tensor_slice_transfer_v4r1.hpp:46
Definition thread_group_tensor_slice_transfer_v6r1.hpp:34
Definition threadwise_tensor_slice_transfer.hpp:39
static constexpr value_type value
Definition utility/integral_constant.hpp:13
Definition functional2.hpp:33
Definition device_base.hpp:197
Definition tensor_operation/gpu/element/unary_element_wise_operation.hpp:340