gridwise_gemm_xdl_cshuffle_v2.hpp Source File

gridwise_gemm_xdl_cshuffle_v2.hpp Source File

Composable Kernel: gridwise_gemm_xdl_cshuffle_v2.hpp Source File
gridwise_gemm_xdl_cshuffle_v2.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
17
18namespace ck {
19
20template <typename GridwiseGemm, bool HasMainKBlockLoop, index_t TailNum = 3>
21__global__ void
22#if CK_USE_LAUNCH_BOUNDS
23__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, 1)
24#endif
25 // __attribute__((amdgpu_waves_per_eu(1, 1)))
26 kernel_gemm_xdl_cshuffle_v2(typename GridwiseGemm::Argument karg)
27{
28#if defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx94__) || defined(__gfx11__) || \
29 defined(__gfx12__)
30 if constexpr(GridwiseGemm::template IsValidCompilationParameter<>())
31 {
32 // Pass two lds pointer is the key to tell compiler that ds_read/write
33 // operate on different lds chunk at same time without order dependecy
34 __shared__ char p_shared_0[GridwiseGemm::GetSharedMemoryNumberOfByte()];
35 __shared__ char p_shared_1[GridwiseGemm::GetSharedMemoryNumberOfByte()];
36
37 GridwiseGemm::template Run<HasMainKBlockLoop, TailNum>(
38 karg.p_a_grid, karg.p_b_grid, karg.p_c_grid, p_shared_0, p_shared_1, karg);
39 }
40#else
41 ignore = karg;
42#endif // end of if (defined(__gfx908__) || defined(__gfx90a__))
43}
44
45template <typename GridwiseGemm,
46 typename FloatA,
47 typename FloatB,
48 typename FloatC,
49 bool HasMainKBlockLoop>
50__global__ void
51#if CK_USE_LAUNCH_BOUNDS
52__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, 1)
53#endif
54 kernel_gemm_xdl_cshuffle_v2(const FloatA* p_a_grid,
55 const FloatB* p_b_grid,
56 FloatC* p_c_grid,
57 typename GridwiseGemm::Problem problem)
58{
59#if defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx94__) || defined(__gfx11__) || \
60 defined(__gfx12__)
61 if constexpr(GridwiseGemm::template IsValidCompilationParameter<>())
62 {
63 __shared__ char p_shared_0[GridwiseGemm::GetSharedMemoryNumberOfByte()];
64 __shared__ char p_shared_1[GridwiseGemm::GetSharedMemoryNumberOfByte()];
65
66 GridwiseGemm::template Run<HasMainKBlockLoop>(
67 p_a_grid, p_b_grid, p_c_grid, p_shared_0, p_shared_1, problem);
68 }
69#else
70 ignore = p_a_grid;
71 ignore = p_b_grid;
72 ignore = p_c_grid;
73 ignore = problem;
74#endif // end of if (defined(__gfx908__) || defined(__gfx90a__))
75}
76
77template <typename ALayout,
78 typename BLayout,
79 typename CLayout,
80 typename FloatA,
81 typename FloatB,
82 typename FloatGemmAcc,
83 typename FloatCShuffle,
84 typename FloatC,
85 typename AElementwiseOperation,
86 typename BElementwiseOperation,
87 typename CElementwiseOperation,
89 InMemoryDataOperationEnum CGlobalMemoryDataOperation,
90 index_t NumGemmKPrefetchStage,
91 index_t BlockSize,
92 index_t MPerBlock,
93 index_t NPerBlock,
94 index_t KPerBlock,
95 index_t AK1Value,
96 index_t BK1Value,
97 index_t MPerXdl,
98 index_t NPerXdl,
99 index_t MXdlPerWave,
100 index_t NXdlPerWave,
101 typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
102 typename ABlockTransferThreadClusterArrangeOrder,
103 typename ABlockTransferSrcAccessOrder,
104 index_t ABlockTransferSrcVectorDim,
105 index_t ABlockTransferSrcScalarPerVector,
106 index_t ABlockTransferDstScalarPerVector_AK1,
107 bool AThreadTransferSrcResetCoordinateAfterRun,
108 index_t ABlockLdsExtraM,
109 typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
110 typename BBlockTransferThreadClusterArrangeOrder,
111 typename BBlockTransferSrcAccessOrder,
112 index_t BBlockTransferSrcVectorDim,
113 index_t BBlockTransferSrcScalarPerVector,
114 index_t BBlockTransferDstScalarPerVector_BK1,
115 bool BThreadTransferSrcResetCoordinateAfterRun,
116 index_t BBlockLdsExtraN,
117 index_t CShuffleMXdlPerWavePerShuffle,
118 index_t CShuffleNXdlPerWavePerShuffle,
119 typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
120 index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
121 LoopScheduler LoopSched,
123 typename ComputeTypeA = FloatC,
124 typename ComputeTypeB = ComputeTypeA>
126{
127 static constexpr auto I0 = Number<0>{};
128 static constexpr auto I1 = Number<1>{};
129 static constexpr auto I2 = Number<2>{};
130 static constexpr auto I3 = Number<3>{};
131 static constexpr auto I4 = Number<4>{};
132 static constexpr auto I5 = Number<5>{};
133 static constexpr auto I6 = Number<6>{};
134 static constexpr auto I7 = Number<7>{};
135
136 // K1 should be Number<...>
137 static constexpr auto AK0Number = Number<KPerBlock / AK1Value>{};
138 static constexpr auto BK0Number = Number<KPerBlock / BK1Value>{};
139 static constexpr auto AK1Number = Number<AK1Value>{};
140 static constexpr auto BK1Number = Number<BK1Value>{};
141
143
144 __host__ static auto CalculateGridSize(index_t M, index_t N)
145 {
146 return std::make_tuple(Block2CTileMap::CalculateGridSize(M, N), 1, 1);
147 }
148
149 __host__ static auto CalculateMPadded(index_t M)
150 {
151 return math::integer_divide_ceil(M, MPerBlock) * MPerBlock;
152 }
153
154 __host__ static auto CalculateNPadded(index_t N)
155 {
156 return math::integer_divide_ceil(N, NPerBlock) * NPerBlock;
157 }
158
159 __host__ static auto CalculateKPadded(index_t K)
160 {
161 return math::integer_divide_ceil(K, KPerBlock) * KPerBlock;
162 }
163
164 __host__ static auto CalculateAK0(index_t K)
165 {
166 using GemmSpecialization = tensor_operation::device::GemmSpecialization;
167
168 if constexpr(GemmSpec == GemmSpecialization::MKPadding ||
169 GemmSpec == GemmSpecialization::MNKPadding ||
170 GemmSpec == GemmSpecialization::KPadding ||
171 GemmSpec == GemmSpecialization::NKPadding)
172 {
173 return CalculateKPadded(K) / AK1Value;
174 }
175 else
176 {
177 return K / AK1Value;
178 }
179 }
180
181 __host__ static auto CalculateBK0(index_t K)
182 {
183 using GemmSpecialization = tensor_operation::device::GemmSpecialization;
184
185 if constexpr(GemmSpec == GemmSpecialization::NKPadding ||
186 GemmSpec == GemmSpecialization::MNKPadding ||
187 GemmSpec == GemmSpecialization::KPadding ||
188 GemmSpec == GemmSpecialization::MKPadding)
189 {
190 return CalculateKPadded(K) / BK1Value;
191 }
192 else
193 {
194 return K / BK1Value;
195 }
196 }
197
198 __host__ static auto CalculateMBlock(index_t M)
199 {
200 return math::integer_divide_floor(M, MPerBlock);
201 }
202
203 __host__ static auto CalculateNBlock(index_t N)
204 {
205 return math::integer_divide_floor(N, NPerBlock);
206 }
207
208 template <index_t MNXdlPerWave, index_t MNWaves, index_t MNPerXdl, typename TileDesc_K0_MN_K1>
209 __host__ __device__ static constexpr auto MakeGemmMmaTileDescriptor(const TileDesc_K0_MN_K1&)
210 {
211 constexpr index_t K0 = TileDesc_K0_MN_K1{}.GetLength(Number<0>{});
212 constexpr index_t K1 = TileDesc_K0_MN_K1{}.GetLength(Number<2>{});
213
215 TileDesc_K0_MN_K1{},
221 }
222
223 __device__ static auto MakeAGridDescriptor_AK0_M_AK1(
224 index_t M, index_t MPad, index_t K, index_t KPad, index_t StrideA, index_t AK0)
225 {
226 const auto a_grid_desc_mraw_kraw = [&]() {
228 {
229 return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(StrideA, I1));
230 }
232 {
233 return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(I1, StrideA));
234 }
235 }();
236
237 using GemmSpecialization = tensor_operation::device::GemmSpecialization;
238
239 if constexpr(GemmSpec == GemmSpecialization::MKPadding ||
240 GemmSpec == GemmSpecialization::MNKPadding)
241 {
242 // pad both M and K
243 const auto a_grid_desc_m_k =
244 transform_tensor_descriptor(a_grid_desc_mraw_kraw,
246 make_right_pad_transform(K, KPad - K)),
249
250 const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor(
251 a_grid_desc_m_k,
256
257 return a_grid_desc_ak0_m_ak1;
258 }
259 else if constexpr(GemmSpec == GemmSpecialization::MPadding ||
260 GemmSpec == GemmSpecialization::MNPadding)
261 {
262 // pad M, but not K
263 const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor(
264 a_grid_desc_mraw_kraw,
266 make_right_pad_transform(M, MPad - M)),
269
270 return a_grid_desc_ak0_m_ak1;
271 }
272 else if constexpr(GemmSpec == GemmSpecialization::KPadding ||
273 GemmSpec == GemmSpecialization::NKPadding)
274 {
275 // pad K, but not M
276 const auto a_grid_desc_m_k = transform_tensor_descriptor(
277 a_grid_desc_mraw_kraw,
281
282 const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor(
283 a_grid_desc_m_k,
288
289 return a_grid_desc_ak0_m_ak1;
290 }
291 else
292 {
293 // not pad M or K
294 const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor(
295 a_grid_desc_mraw_kraw,
300
301 return a_grid_desc_ak0_m_ak1;
302 }
303 }
304
305 __device__ static auto MakeBGridDescriptor_BK0_N_BK1(
306 index_t K, index_t KPad, index_t N, index_t NPad, index_t StrideB, index_t BK0)
307 {
308 const auto b_grid_desc_nraw_kraw = [&]() {
310 {
311 return make_naive_tensor_descriptor(make_tuple(N, K), make_tuple(I1, StrideB));
312 }
314 {
315 return make_naive_tensor_descriptor(make_tuple(N, K), make_tuple(StrideB, I1));
316 }
317 }();
318
319 using GemmSpecialization = tensor_operation::device::GemmSpecialization;
320
321 if constexpr(GemmSpec == GemmSpecialization::NKPadding ||
322 GemmSpec == GemmSpecialization::MNKPadding)
323 {
324 // pad both N and K
325 const auto b_grid_desc_n_k =
326 transform_tensor_descriptor(b_grid_desc_nraw_kraw,
328 make_right_pad_transform(K, KPad - K)),
331
332 const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
333 b_grid_desc_n_k,
338
339 return b_grid_desc_bk0_n_bk1;
340 }
341 else if constexpr(GemmSpec == GemmSpecialization::NPadding ||
342 GemmSpec == GemmSpecialization::MNPadding)
343 {
344 // pad N, but not K
345 const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
346 b_grid_desc_nraw_kraw,
348 make_right_pad_transform(N, NPad - N)),
351
352 return b_grid_desc_bk0_n_bk1;
353 }
354 else if constexpr(GemmSpec == GemmSpecialization::KPadding ||
355 GemmSpec == GemmSpecialization::MKPadding)
356 {
357 // pad K, but not N
358 const auto b_grid_desc_n_k = transform_tensor_descriptor(
359 b_grid_desc_nraw_kraw,
363
364 const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
365 b_grid_desc_n_k,
370
371 return b_grid_desc_bk0_n_bk1;
372 }
373 else
374 {
375 // not pad N or K
376 const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
377 b_grid_desc_nraw_kraw,
382
383 return b_grid_desc_bk0_n_bk1;
384 }
385 }
386
387 template <typename ABlockDesc_AK0_M_AK1>
388 __host__ __device__ static constexpr auto
389 MakeAMmaTileDescriptor_M0_M1_M2_K(const ABlockDesc_AK0_M_AK1&)
390 {
391 constexpr index_t MWaves = MPerBlock / (MXdlPerWave * MPerXdl);
392
393 return MakeGemmMmaTileDescriptor<MXdlPerWave, MWaves, MPerXdl>(ABlockDesc_AK0_M_AK1{});
394 }
395
396 template <typename BBlockDesc_BK0_N_BK1>
397 __host__ __device__ static constexpr auto
398 MakeBMmaTileDescriptor_N0_N1_N2_K(const BBlockDesc_BK0_N_BK1&)
399 {
400 constexpr index_t NWaves = NPerBlock / (NXdlPerWave * NPerXdl);
401
402 return MakeGemmMmaTileDescriptor<NXdlPerWave, NWaves, NPerXdl>(BBlockDesc_BK0_N_BK1{});
403 }
404
405 __host__ __device__ static auto
407 {
408 const auto c_grid_desc_mraw_nraw = [&]() {
410 {
411 return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideC, I1));
412 }
414 {
415 return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I1, StrideC));
416 }
417 }();
418
419 using GemmSpecialization = tensor_operation::device::GemmSpecialization;
420
421 if constexpr(GemmSpec == GemmSpecialization::MNPadding ||
422 GemmSpec == GemmSpecialization::MNKPadding)
423 {
424 // pad M and N
425 return transform_tensor_descriptor(c_grid_desc_mraw_nraw,
427 make_right_pad_transform(N, NPad - N)),
430 }
431 else if constexpr(GemmSpec == GemmSpecialization::MPadding ||
432 GemmSpec == GemmSpecialization::MKPadding)
433 {
434 // pad M, but not N
436 c_grid_desc_mraw_nraw,
440 }
441 else if constexpr(GemmSpec == GemmSpecialization::NPadding ||
442 GemmSpec == GemmSpecialization::NKPadding)
443 {
444 // pad N, but not M
446 c_grid_desc_mraw_nraw,
450 }
451 else
452 {
453 // not pad M or N
454 return c_grid_desc_mraw_nraw;
455 }
456 }
457
458 struct Problem
459 {
460 __host__ Problem(index_t M_,
461 index_t N_,
462 index_t K_,
463 index_t StrideA_,
464 index_t StrideB_,
465 index_t StrideC_)
466 : M{M_},
467 N{N_},
468 K{K_},
469 StrideA{StrideA_},
470 StrideB{StrideB_},
471 StrideC{StrideC_},
475 AK0{CalculateAK0(K_)},
476 BK0{CalculateBK0(K_)},
479 {
480 }
481
482 __host__ void Print() const
483 {
484 std::cout << "problem {" << "M:" << M << ", " << "N:" << N << ", " << "K:" << K << ", "
485 << "SA:" << StrideA << ", " << "SB:" << StrideB << ", " << "SC:" << StrideC
486 << ", " << "MP:" << MPadded << ", " << "NP:" << NPadded << ", "
487 << "KP:" << KPadded << ", " << "AK0:" << AK0 << ", " << "BK0:" << BK0 << ", "
488 << "MBlock: " << MBlock << ", " << "NBlock: " << NBlock << "}" << std::endl;
489 }
490
504 };
505
506 // Argument
508 {
509 __host__ Argument(const FloatA* p_a_grid_,
510 const FloatB* p_b_grid_,
511 FloatC* p_c_grid_,
512 index_t M_,
513 index_t N_,
514 index_t K_,
515 index_t StrideA_,
516 index_t StrideB_,
517 index_t StrideC_)
518 : Problem{M_, N_, K_, StrideA_, StrideB_, StrideC_},
519 p_a_grid{p_a_grid_},
520 p_b_grid{p_b_grid_},
521 p_c_grid{p_c_grid_}
522 {
523 }
524
525 const FloatA* p_a_grid;
526 const FloatB* p_b_grid;
527 FloatC* p_c_grid;
528 };
529
530 // FIXME: pass GridwiseGemmPipe as a template arguement into GridwiseGemm
533
534 __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()
535 {
536 // A matrix in LDS memory, dst of blockwise copy
540 }
541
542 __device__ static constexpr auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1()
543 {
544 // B matrix in LDS memory, dst of blockwise copy
548 }
549
551 {
552 constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
553 constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);
554
555 constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
559 I1,
561
562 return c_shuffle_block_desc_mblock_mperblock_nblock_nperblock;
563 }
564
565 __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
566 {
567 // LDS allocation for A and B: be careful of alignment
568 constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
569 constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();
570
571 // lds max alignment
572 constexpr auto max_lds_align = math::lcm(AK1Number, BK1Number);
573
574 constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
575 a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
576
577 constexpr auto b_block_space_size_aligned = math::integer_least_multiple(
578 b_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align);
579
580 // LDS allocation for C shuffle in LDS
581 constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
583
584 constexpr auto c_block_size =
585 c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize();
586
587 return math::max((a_block_space_size_aligned * sizeof(ComputeTypeA) +
588 b_block_space_size_aligned * sizeof(ComputeTypeB)),
589 c_block_size * sizeof(FloatCShuffle));
590 }
591
592 template <
593 InMemoryDataOperationEnum CGlobalMemoryDataOperation_ = InMemoryDataOperationEnum::Set>
594 __device__ static bool constexpr IsValidCompilationParameter()
595 {
596 return ck::tensor_operation::device::IsValidGemmCompilationParameter<
597 BlockSize,
598 MPerBlock,
599 NPerBlock,
600 MPerXdl,
601 NPerXdl,
602 MXdlPerWave,
603 NXdlPerWave,
604 FloatC,
605 CGlobalMemoryDataOperation>();
606 }
607
608 // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01}
609 __host__ static constexpr bool CheckValidity(const Problem& problem)
610 {
611 static_assert((MPerBlock % (MPerXdl * MXdlPerWave) == 0) &&
612 (NPerBlock % (NXdlPerWave * NPerXdl)) == 0,
613 "Invalid tuning param!");
614
619 {
620 if(!(problem.M % MPerBlock == 0))
621 {
622 return false;
623 }
624 }
625
630 {
631 if(!(problem.N % NPerBlock == 0))
632 {
633 return false;
634 }
635 }
636
641 {
642 if(!(CalculateKPadded(problem.K) % AK1Value == 0) ||
643 !(CalculateKPadded(problem.K) % BK1Value == 0))
644 {
645 return false;
646 }
647 }
648 else
649 {
650 if(!(problem.K % AK1Value == 0) || !(problem.K % BK1Value == 0))
651 {
652 return false;
653 }
654 }
655
657 {
658 if(problem.K % ABlockTransferSrcScalarPerVector != 0)
659 {
660 return false;
661 }
662 }
663 else
664 {
665 if(problem.M % ABlockTransferSrcScalarPerVector != 0)
666 {
667 return false;
668 }
669 }
670
672 {
673 if(problem.N % BBlockTransferSrcScalarPerVector != 0)
674 {
675 return false;
676 }
677 }
678 else
679 {
680 if(problem.K % BBlockTransferSrcScalarPerVector != 0)
681 {
682 return false;
683 }
684 }
685
687 {
688 if(problem.N % CShuffleBlockTransferScalarPerVector_NPerBlock != 0)
689 {
690 return false;
691 }
692 }
693 else
694 {
695 if(problem.M % CShuffleBlockTransferScalarPerVector_NPerBlock != 0)
696 {
697 return false;
698 }
699 }
700
701 // check gridwise gemm pipeline
702 const auto num_k_loop = (CalculateAK0(problem.K) * AK1Value) / KPerBlock;
703
704 if(num_k_loop < 4)
705 {
706 return false;
707 }
708
709 // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
710 return true;
711 }
712
713 __host__ static constexpr bool CalculateHasMainKBlockLoop(index_t K)
714 {
715 const index_t num_loop = K / KPerBlock;
716
717 return num_loop > 3;
718 }
719
720 __host__ static constexpr index_t CalculateKBlockLoopTailNum(index_t K)
721 {
722 const index_t num_loop = K / KPerBlock;
723
724 if(num_loop % 2 == 1)
725 return 3;
726 else
727 return 2;
728 }
729
730 template <typename CGridDesc>
732 const CGridDesc& c_grid_desc_m_n, index_t MBlock, index_t NBlock)
733 {
734 const auto c_grid_desc_mblock_mperblock_nblock_nperblock = transform_tensor_descriptor(
735 c_grid_desc_m_n,
740
741 return c_grid_desc_mblock_mperblock_nblock_nperblock;
742 }
743
744 // return block_id to C matrix tile idx (m0, n0) mapping
745 // if arch = gfx942
747
748 template <bool HasMainKBlockLoop, index_t TailNum = 3>
749 __device__ static void Run(const FloatA* p_a_grid,
750 const FloatB* p_b_grid,
751 FloatC* p_c_grid,
752 void* p_shared_0,
753 void* p_shared_1,
754 const Problem& problem)
755 {
756 const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1(
757 problem.M, problem.MPadded, problem.K, problem.KPadded, problem.StrideA, problem.AK0);
758 const auto b_grid_desc_bk0_n_bk1 = MakeBGridDescriptor_BK0_N_BK1(
759 problem.K, problem.KPadded, problem.N, problem.NPadded, problem.StrideB, problem.BK0);
760 const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N(
761 problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideC);
762
763 const auto c_grid_desc_mblock_mperblock_nblock_nperblock =
765 c_grid_desc_m_n, problem.MBlock, problem.NBlock);
766
767 const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
768 p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
769 const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
770 p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize());
772 p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
773
774 const AElementwiseOperation a_element_op{};
775 const BElementwiseOperation b_element_op{};
776 const CElementwiseOperation c_element_op{};
777
778 // divide block work by [M, N]
779 const auto block_2_ctile_map = Block2CTileMap{problem.M, problem.N, 4};
780
781 const auto block_work_idx =
782 block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));
783
784 if(!block_2_ctile_map.ValidCTileIndex(
785 block_work_idx,
786 make_tuple(c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0),
787 c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2))))
788 {
789 return;
790 }
791#if 0
792 if(threadIdx.x == 0){
793 printf("Hardware assigned No. %03d workgroup of logical C tile (%02d, %02d) on %d th XCC Die, %d th SE, %d th CU\n",
795 block_work_idx[I0],
796 block_work_idx[I1],
797 __smid()>>6 & 0xf,
798 __smid()>>4 & 0x3,
799 __smid() & 0xf);
800 }
801#endif
802 // HACK: this force m/n_block_data_idx_on_grid into SGPR
803 const index_t m_block_data_idx_on_grid =
804 __builtin_amdgcn_readfirstlane(block_work_idx[I0] * MPerBlock);
805
806 const index_t n_block_data_idx_on_grid =
807 __builtin_amdgcn_readfirstlane(block_work_idx[I1] * NPerBlock);
808
809 // lds max alignment
810 constexpr auto max_lds_align = math::lcm(AK1Number, BK1Number);
811
812 // A matrix in LDS memory, dst of blockwise copy
813 constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
814
815 // B matrix in LDS memory, dst of blockwise copy
816 constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();
817
818 // A matrix blockwise copy
819 auto a_blockwise_copy =
821 AElementwiseOperation,
825 ABlockTransferThreadClusterLengths_AK0_M_AK1,
826 ABlockTransferThreadClusterArrangeOrder,
827 FloatA,
828 ComputeTypeA,
829 decltype(a_grid_desc_ak0_m_ak1),
830 decltype(a_block_desc_ak0_m_ak1),
831 ABlockTransferSrcAccessOrder,
833 ABlockTransferSrcVectorDim,
834 2,
835 ABlockTransferSrcScalarPerVector,
836 ABlockTransferDstScalarPerVector_AK1,
837 1,
838 1,
839 AThreadTransferSrcResetCoordinateAfterRun,
840 true>(
841 a_grid_desc_ak0_m_ak1,
842 make_multi_index(0, m_block_data_idx_on_grid, 0),
843 a_element_op,
844 a_block_desc_ak0_m_ak1,
845 make_multi_index(0, 0, 0),
847
848 // B matrix blockwise copy
849 auto b_blockwise_copy =
851 BElementwiseOperation,
855 BBlockTransferThreadClusterLengths_BK0_N_BK1,
856 BBlockTransferThreadClusterArrangeOrder,
857 FloatB,
858 ComputeTypeB,
859 decltype(b_grid_desc_bk0_n_bk1),
860 decltype(b_block_desc_bk0_n_bk1),
861 BBlockTransferSrcAccessOrder,
863 BBlockTransferSrcVectorDim,
864 2,
865 BBlockTransferSrcScalarPerVector,
866 BBlockTransferDstScalarPerVector_BK1,
867 1,
868 1,
869 BThreadTransferSrcResetCoordinateAfterRun,
870 true>(
871 b_grid_desc_bk0_n_bk1,
872 make_multi_index(0, n_block_data_idx_on_grid, 0),
873 b_element_op,
874 b_block_desc_bk0_n_bk1,
875 make_multi_index(0, 0, 0),
877
878 // GEMM definition
879 // c_mtx += transpose(a_mtx) * b_mtx
880 // a_mtx[K0PerBlock, MPerBlock] is in LDS
881 // b_mtx[K0PerBlock, NPerBlock] is in LDS
882 // c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in
883 // register
884 // sanity check
885 constexpr auto lcm_AK1_BK1 = math::lcm(AK1Number, BK1Number);
886 constexpr bool is_single_rate_mfma =
888 lcm_AK1_BK1 <= 4) ||
889 (is_same<ComputeTypeA, int8_t>::value && lcm_AK1_BK1 <= 8) ||
891 lcm_AK1_BK1 < 32))
892 ? true
893 : false;
894 constexpr auto is_scale_mfma = false;
895 constexpr index_t KPack = math::max(lcm_AK1_BK1,
896 MfmaSelector<ComputeTypeA,
897 MPerXdl,
898 NPerXdl,
899 ComputeTypeA,
900 is_single_rate_mfma,
901 is_scale_mfma>::selected_mfma.k_per_blk);
902
903 // auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector<
904 // BlockSize,
905 // ComputeType,
906 // FloatGemmAcc,
907 // decltype(a_block_desc_ak0_m_ak1),
908 // decltype(b_block_desc_bk0_n_bk1),
909 // MPerXdl,
910 // NPerXdl,
911 // MXdlPerWave,
912 // NXdlPerWave,
913 // KPack,
914 // LoopSched>();
915 auto blockwise_gemm_pipeline = BlockwiseGemmXdlops_pipeline_v4<
916 BlockSize,
917 ComputeTypeA,
918 FloatGemmAcc,
919 decltype(a_block_desc_ak0_m_ak1),
920 decltype(b_block_desc_bk0_n_bk1),
921 decltype(MakeAMmaTileDescriptor_M0_M1_M2_K(a_block_desc_ak0_m_ak1)),
922 decltype(MakeBMmaTileDescriptor_N0_N1_N2_K(b_block_desc_bk0_n_bk1)),
923 MPerBlock,
924 NPerBlock,
925 KPerBlock,
926 MPerXdl,
927 NPerXdl,
928 MXdlPerWave,
929 NXdlPerWave,
930 KPack>{}; // TransposeC
931
932 auto c_thread_buf = blockwise_gemm_pipeline.GetCThreadBuffer();
933
934 // LDS allocation for A and B: be careful of alignment
935 constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
936 a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
937
938 auto a_block_buf_ping = make_dynamic_buffer<AddressSpaceEnum::Lds>(
939 static_cast<ComputeTypeA*>(p_shared_0), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
940
941 auto b_block_buf_ping = make_dynamic_buffer<AddressSpaceEnum::Lds>(
942 static_cast<ComputeTypeB*>(p_shared_0) + a_block_space_size_aligned,
943 b_block_desc_bk0_n_bk1.GetElementSpaceSize());
944
945 auto a_block_buf_pong = make_dynamic_buffer<AddressSpaceEnum::Lds>(
946 static_cast<ComputeTypeA*>(p_shared_1), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
947
948 auto b_block_buf_pong = make_dynamic_buffer<AddressSpaceEnum::Lds>(
949 static_cast<ComputeTypeB*>(p_shared_1) + a_block_space_size_aligned,
950 b_block_desc_bk0_n_bk1.GetElementSpaceSize());
951
952 auto a_block_bufs = make_tuple(a_block_buf_ping, a_block_buf_pong);
953 auto b_block_bufs = make_tuple(b_block_buf_ping, b_block_buf_pong);
954
955 constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1Number, 0, 0);
956 constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock / BK1Number, 0, 0);
957
958 // gridwise GEMM pipeline
959 static_assert(std::is_default_constructible_v<GridwiseGemmPipe>);
960 // const auto gridwise_gemm_pipeline = GridwiseGemmPipe{};
961
962 const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(
963 (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) /
964 KPerBlock);
965
966 blockwise_gemm_pipeline.template Run<HasMainKBlockLoop, TailNum>(a_grid_desc_ak0_m_ak1,
967 a_block_desc_ak0_m_ak1,
968 a_blockwise_copy,
969 a_grid_buf,
970 a_block_bufs,
971 a_block_slice_copy_step,
972 b_grid_desc_bk0_n_bk1,
973 b_block_desc_bk0_n_bk1,
974 b_blockwise_copy,
975 b_grid_buf,
976 b_block_bufs,
977 b_block_slice_copy_step,
978 c_thread_buf,
979 num_k_block_main_loop);
980
981 // shuffle C and write out
982 {
983 static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 &&
984 NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0,
985 "wrong!");
986
987 constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
988 constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);
989
990 // TODO: hacky, fix it!
991 constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
992 blockwise_gemm_pipeline.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
993
994 // TODO: hacky, fix it!
995 // c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths
996 constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp =
997 blockwise_gemm_pipeline.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
998
999 constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0);
1000 constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1);
1001 constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2);
1002 constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3);
1003 constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4);
1004 constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5);
1005 constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6);
1006 constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7);
1007
1008 constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
1010
1011 auto c_shuffle_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
1012 static_cast<FloatCShuffle*>(p_shared_0),
1013 c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
1014
1015 constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor(
1016 c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
1017 make_tuple(
1020 Number<CShuffleMXdlPerWavePerShuffle>{}, // M0 (MXdlPerWave) per shuffle
1021 M1, // M1 = MWave
1022 M2, // M2 * M3 * M4 = MPerXdl
1023 M3,
1024 M4)),
1027 Number<CShuffleNXdlPerWavePerShuffle>{}, // N0 (NXdlPerWave) per shuffle
1028 N1, // N1 = NWave
1029 N2))), // N2 = NPerXdl
1031 make_tuple(
1033
1034 // calculate origin of thread output tensor on global memory
1035 // blockwise GEMM c matrix starting index
1036 const auto c_thread_mtx_on_block =
1037 blockwise_gemm_pipeline.CalculateCThreadOriginDataIndex(I0, I0, I0, I0);
1038
1039 const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0];
1040 const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1];
1041
1042 const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor =
1044 make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))),
1047
1048 const auto m_thread_data_on_block_idx =
1049 m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex(
1050 make_multi_index(m_thread_data_on_block));
1051
1052 const auto n_thread_data_on_block_to_n0_n1_n2_adaptor =
1057
1058 const auto n_thread_data_on_block_idx =
1059 n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex(
1060 make_multi_index(n_thread_data_on_block));
1061
1062 // shuffle: threadwise copy C from VGPR to LDS
1063 auto c_thread_copy_vgpr_to_lds =
1065 FloatCShuffle,
1066 decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2),
1067 decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2),
1069 Sequence<CShuffleMXdlPerWavePerShuffle,
1070 CShuffleNXdlPerWavePerShuffle,
1071 I1,
1072 I1,
1073 M2,
1074 I1,
1075 M4,
1076 I1>,
1078 7,
1079 1,
1081 1,
1082 true>{
1083 c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
1085 0,
1086 m_thread_data_on_block_idx[I1],
1087 n_thread_data_on_block_idx[I1],
1088 m_thread_data_on_block_idx[I2],
1089 m_thread_data_on_block_idx[I3],
1090 m_thread_data_on_block_idx[I4],
1091 n_thread_data_on_block_idx[I2]),
1093
1094 // shuffle: blockwise copy C from LDS to global
1095 auto c_shuffle_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r1<
1096 ThisThreadBlock, // ThreadGroup
1097 CElementwiseOperation, // ElementwiseOperation,
1098 CGlobalMemoryDataOperation, // DstInMemOp,
1099 Sequence<1,
1100 CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
1101 1,
1102 CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths,
1103 CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
1104 Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder,
1105 FloatCShuffle, // typename SrcData,
1106 FloatC, // typename DstData,
1107 decltype(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock),
1108 decltype(c_grid_desc_mblock_mperblock_nblock_nperblock),
1109 Sequence<0, 1, 2, 3>, // typename DimAccessOrder,
1110 3, // index_t VectorDim,
1111 CShuffleBlockTransferScalarPerVector_NPerBlock, // index_t ScalarPerVector,
1112 true, // bool ThreadTransferSrcResetCoordinateAfterRun,
1113 false> // bool ThreadTransferDstResetCoordinateAfterRun>
1114 {c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
1115 make_multi_index(0, 0, 0, 0),
1116 c_grid_desc_mblock_mperblock_nblock_nperblock,
1117 make_multi_index(block_work_idx[I0], 0, block_work_idx[I1], 0),
1118 c_element_op};
1119
1120 // space filling curve for threadwise C in VGPR
1121 constexpr auto sfc_c_vgpr =
1124 Sequence<CShuffleMXdlPerWavePerShuffle,
1125 CShuffleNXdlPerWavePerShuffle,
1126 1,
1127 1,
1128 M2,
1129 1,
1130 M4,
1131 1>>{};
1132
1133 // space filling curve for shuffled blockwise C in global mem
1134 constexpr auto sfc_c_global =
1137 Sequence<1,
1138 CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
1139 1,
1140 CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{};
1141
1142 constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess();
1143
1144 static_assert(num_access == sfc_c_global.GetNumOfAccess(), "wrong!");
1145
1146 static_for<0, num_access, 1>{}([&](auto access_id) {
1147 // make sure it's safe to write to LDS
1149
1150 // each thread write its data from VGPR to LDS
1151 c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2,
1152 sfc_c_vgpr.GetIndexTupleOfNumber(access_id),
1153 c_thread_buf,
1154 c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
1155 c_shuffle_block_buf);
1156
1157 // make sure it's safe to read from LDS
1159
1160 // each block copy its data from LDS to global
1161 c_shuffle_block_copy_lds_to_global.Run(
1162 c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
1163 c_shuffle_block_buf,
1164 c_grid_desc_mblock_mperblock_nblock_nperblock,
1165 c_grid_buf);
1166
1167 if constexpr(access_id < num_access - 1)
1168 {
1169 constexpr auto c_global_step = sfc_c_global.GetForwardStep(access_id);
1170
1171 // move on C
1172 c_shuffle_block_copy_lds_to_global.MoveDstSliceWindow(
1173 c_grid_desc_mblock_mperblock_nblock_nperblock, c_global_step);
1174 }
1175 });
1176 }
1177 }
1178};
1179
1180} // namespace ck
#define CK_MAX_THREAD_PER_BLOCK
Definition ck.hpp:30
__host__ __device__ constexpr auto integer_least_multiple(X x, Y y)
Definition utility/math.hpp:78
__host__ __device__ constexpr auto integer_divide_floor(X x, Y y)
Definition utility/math.hpp:66
__host__ __device__ constexpr T max(T x)
Definition utility/math.hpp:84
__host__ __device__ constexpr auto integer_divide_ceil(X x, Y y)
Definition utility/math.hpp:72
__host__ __device__ constexpr auto lcm(X x, Y y)
Definition utility/math.hpp:198
GemmSpecialization
Definition gemm_specialization.hpp:11
@ MKPadding
Definition gemm_specialization.hpp:18
@ KPadding
Definition gemm_specialization.hpp:16
@ NPadding
Definition gemm_specialization.hpp:15
@ MPadding
Definition gemm_specialization.hpp:14
@ MNKPadding
Definition gemm_specialization.hpp:20
@ MNPadding
Definition gemm_specialization.hpp:17
@ NKPadding
Definition gemm_specialization.hpp:19
Definition ck.hpp:268
__host__ __device__ constexpr auto make_multi_index(Xs &&... xs)
Definition array_multi_index.hpp:15
__host__ __device__ constexpr auto make_pass_through_transform(const LowLength &low_length)
Definition multi_index_transform_helper.hpp:12
constexpr auto GridwiseGemmPipeline_Selector()
Definition gridwise_gemm_pipeline_selector.hpp:31
int32_t index_t
Definition ck.hpp:299
__host__ __device__ constexpr auto make_naive_tensor_descriptor(const Tuple< Lengths... > &lengths, const Tuple< Strides... > &strides)
Definition tensor_descriptor_helper.hpp:49
InMemoryDataOperationEnum
Definition ck.hpp:277
@ Set
Definition ck.hpp:278
__host__ __device__ constexpr auto make_single_stage_tensor_adaptor(const Transforms &transforms, LowerDimensionOldTopIdss, UpperDimensionNewTopIdss)
Definition tensor_description/tensor_adaptor.hpp:425
remove_cv_t< remove_reference_t< T > > remove_cvref_t
Definition type.hpp:297
__host__ __device__ constexpr auto make_freeze_transform(const LowerIndex &low_idx)
Definition multi_index_transform_helper.hpp:151
__host__ __device__ constexpr auto make_right_pad_transform(const LowLength &low_length, const RightPadLength &right_pad, integral_constant< bool, SkipIsValidCheck >=integral_constant< bool, false >{})
Definition multi_index_transform_helper.hpp:37
integral_constant< index_t, N > Number
Definition number.hpp:12
__host__ __device__ constexpr auto make_merge_transform(const LowLengths &low_lengths)
Definition multi_index_transform_helper.hpp:55
constexpr detail::ignore_t ignore
Definition utility/ignore.hpp:20
__device__ index_t get_block_1d_id()
Definition get_id.hpp:47
constexpr bool is_same_v
Definition type.hpp:283
__global__ void kernel_gemm_xdl_cshuffle_v2(typename GridwiseGemm::Argument karg)
Definition gridwise_gemm_xdl_cshuffle_v2.hpp:26
__host__ __device__ constexpr auto make_merge_transform_v3_division_mod(const LowLengths &low_lengths)
Definition multi_index_transform_helper.hpp:84
__host__ __device__ constexpr auto make_naive_tensor_descriptor_packed(const Tuple< Lengths... > &lengths)
Definition tensor_descriptor_helper.hpp:101
__host__ __device__ constexpr auto make_tuple(Xs &&... xs)
Definition utility/tuple.hpp:211
LoopScheduler
Definition loop_scheduler.hpp:15
__host__ __device__ constexpr auto transform_tensor_descriptor(const OldTensorDescriptor &old_tensor_desc, const NewTransforms &new_transforms, NewLowerDimensionOldVisibleIdss, NewUpperDimensionNewVisibleIdss)
Definition tensor_description/tensor_descriptor.hpp:319
__device__ void block_sync_lds()
Definition synchronization.hpp:16
PipelineVersion
Definition gridwise_gemm_pipeline_selector.hpp:18
@ v1
Definition gridwise_gemm_pipeline_selector.hpp:19
__host__ __device__ constexpr auto make_unmerge_transform(const UpLengths &up_lengths, integral_constant< bool, Use24BitIntegerCalculation >=integral_constant< bool, false >{})
Definition multi_index_transform_helper.hpp:90
__host__ __device__ constexpr auto make_dynamic_buffer(T *p, ElementSpaceSize element_space_size)
Definition dynamic_buffer.hpp:472
Definition block_to_ctile_map.hpp:271
__host__ static __device__ constexpr index_t CalculateGridSize(index_t M, index_t N)
Definition block_to_ctile_map.hpp:283
Definition blockwise_gemm_pipeline_xdlops.hpp:103
const FloatB * p_b_grid
Definition gridwise_gemm_xdl_cshuffle_v2.hpp:526
__host__ Argument(const FloatA *p_a_grid_, const FloatB *p_b_grid_, FloatC *p_c_grid_, index_t M_, index_t N_, index_t K_, index_t StrideA_, index_t StrideB_, index_t StrideC_)
Definition gridwise_gemm_xdl_cshuffle_v2.hpp:509
FloatC * p_c_grid
Definition gridwise_gemm_xdl_cshuffle_v2.hpp:527
const FloatA * p_a_grid
Definition gridwise_gemm_xdl_cshuffle_v2.hpp:525
index_t StrideB
Definition gridwise_gemm_xdl_cshuffle_v2.hpp:495
index_t StrideC
Definition gridwise_gemm_xdl_cshuffle_v2.hpp:496
index_t KPadded
Definition gridwise_gemm_xdl_cshuffle_v2.hpp:499
index_t N
Definition gridwise_gemm_xdl_cshuffle_v2.hpp:492
index_t M
Definition gridwise_gemm_xdl_cshuffle_v2.hpp:491
index_t MBlock
Definition gridwise_gemm_xdl_cshuffle_v2.hpp:502
index_t NBlock
Definition gridwise_gemm_xdl_cshuffle_v2.hpp:503
__host__ void Print() const
Definition gridwise_gemm_xdl_cshuffle_v2.hpp:482
index_t BK0
Definition gridwise_gemm_xdl_cshuffle_v2.hpp:501
index_t MPadded
Definition gridwise_gemm_xdl_cshuffle_v2.hpp:497
index_t K
Definition gridwise_gemm_xdl_cshuffle_v2.hpp:493
index_t StrideA
Definition gridwise_gemm_xdl_cshuffle_v2.hpp:494
index_t AK0
Definition gridwise_gemm_xdl_cshuffle_v2.hpp:500
__host__ Problem(index_t M_, index_t N_, index_t K_, index_t StrideA_, index_t StrideB_, index_t StrideC_)
Definition gridwise_gemm_xdl_cshuffle_v2.hpp:460
index_t NPadded
Definition gridwise_gemm_xdl_cshuffle_v2.hpp:498
Definition gridwise_gemm_xdl_cshuffle_v2.hpp:126
Selects the appropriate MFMA instruction type and configuration for given data types and tile sizes o...
Definition xdlops_gemm.hpp:1208
Definition utility/sequence.hpp:43
Definition tensor_space_filling_curve.hpp:20
Blockwise data transfer.
Definition thread_group_tensor_slice_transfer_v4r1.hpp:46
Definition thread_group_tensor_slice_transfer_v6r1.hpp:34
Definition threadwise_tensor_slice_transfer.hpp:39
static constexpr value_type value
Definition utility/integral_constant.hpp:13
Definition functional2.hpp:33
Definition device_base.hpp:197
Definition tensor_operation/gpu/element/unary_element_wise_operation.hpp:340