gridwise_gemm_wmma_cshuffle_v3_common.hpp Source File

gridwise_gemm_wmma_cshuffle_v3_common.hpp Source File

Composable Kernel: gridwise_gemm_wmma_cshuffle_v3_common.hpp Source File
gridwise_gemm_wmma_cshuffle_v3_common.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
6#if !defined(__HIPCC_RTC__) || !defined(CK_CODE_GEN_RTC)
7#include <iostream>
8#include <ostream>
9#endif
10
11#include "ck/utility/env.hpp"
28
29namespace ck {
30
31template <typename GridwiseGemm,
32 bool HasMainKBlockLoop,
33 InMemoryDataOperationEnum EGlobalMemoryDataOperation,
34 index_t MinimumOccupancy = 1,
36__global__ void
37#if CK_USE_LAUNCH_BOUNDS
38__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
39#endif
40 kernel_gemm_wmma_cshuffle_v3(typename GridwiseGemm::Argument karg)
41{
42#if(defined(__gfx11__) || defined(__gfx12__))
43#if defined(__gfx11__)
44 // gfx11 does not support *_atomic_pk_add_f16/bf16 instructions
45 using e_data_type = remove_cvref_t<remove_pointer_t<decltype(karg.p_e_grid)>>;
46 if constexpr(!(EGlobalMemoryDataOperation == InMemoryDataOperationEnum::AtomicAdd &&
47 (std::is_same_v<e_data_type, ck::half_t> ||
48 std::is_same_v<e_data_type, ck::bhalf_t>)))
49 {
50#endif
51 constexpr index_t LDS_size = GridwiseGemm::template GetSharedMemoryNumberOfByte<
52 typename GridwiseGemm::EpilogueCShuffle>();
53 __shared__ char p_shared[LDS_size];
54
55 auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z);
56
57 auto epilogue_args = typename GridwiseGemm::EpilogueCShuffle{};
58
59 GridwiseGemm::template Run<HasMainKBlockLoop, EGlobalMemoryDataOperation, TailNum>(
60 p_shared, splitk_batch_offset, karg, epilogue_args);
61
62#if defined(__gfx11__)
63 }
64#endif
65#else
66 ignore = karg;
67#endif
68}
69
70template <typename ALayout,
71 typename BLayout,
72 typename DsLayout,
73 typename ELayout,
74 typename AsDataType,
75 typename BsDataType,
76 typename AccDataType,
77 typename CShuffleDataType,
78 typename DsDataType,
79 typename EDataType,
80 typename AElementwiseOperation,
81 typename BElementwiseOperation,
82 typename CDEElementwiseOperation,
84 index_t BlockSize,
85 index_t MPerBlock,
86 index_t NPerBlock,
87 index_t KPerBlock,
88 index_t AK1Value,
89 index_t BK1Value,
90 index_t MPerWmma,
91 index_t NPerWmma,
92 index_t MRepeat,
93 index_t NRepeat,
94 typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
95 typename ABlockTransferThreadClusterArrangeOrder,
96 typename ABlockTransferSrcAccessOrder,
97 index_t ABlockTransferSrcVectorDim,
98 index_t ABlockTransferSrcScalarPerVector,
99 index_t ABlockTransferDstScalarPerVector_AK1,
100 bool AThreadTransferSrcResetCoordinateAfterRun,
101 index_t ABlockLdsExtraM,
102 typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
103 typename BBlockTransferThreadClusterArrangeOrder,
104 typename BBlockTransferSrcAccessOrder,
105 index_t BBlockTransferSrcVectorDim,
106 index_t BBlockTransferSrcScalarPerVector,
107 index_t BBlockTransferDstScalarPerVector_BK1,
108 bool BThreadTransferSrcResetCoordinateAfterRun,
109 index_t BBlockLdsExtraN,
110 index_t CShuffleMRepeatPerShuffle,
111 index_t CShuffleNRepeatPerShuffle,
112 typename CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
113 typename CDEShuffleBlockTransferScalarPerVectors,
114 BlockGemmPipelineScheduler BlkGemmPipeSched,
115 BlockGemmPipelineVersion BlkGemmPipelineVer,
116 typename ComputeTypeA,
117 typename ComputeTypeB,
118 bool PermuteA,
119 bool PermuteB,
120 bool ForceThreadTileTransfer = false> // only needed for convolution (limitation)
122{
123
124 static constexpr auto I0 = Number<0>{};
125 static constexpr auto I1 = Number<1>{};
126 static constexpr auto I2 = Number<2>{};
127 static constexpr auto I3 = Number<3>{};
128 static constexpr auto I4 = Number<4>{};
129 static constexpr auto I5 = Number<5>{};
130 static constexpr auto I6 = Number<6>{};
131 static constexpr auto I7 = Number<7>{};
132
133 static constexpr index_t NumATensor = AsDataType::Size();
134 static constexpr index_t NumBTensor = BsDataType::Size();
135
136 using LDSTypeA =
137 typename std::conditional<(NumATensor > 1),
138 ComputeTypeA,
140 using LDSTypeB =
141 typename std::conditional<(NumBTensor > 1),
142 ComputeTypeB,
144
146 CDEShuffleBlockTransferScalarPerVectors{}[I0];
147
148 // K1 should be Number<...>
149 static constexpr auto AK0Number = Number<KPerBlock / AK1Value>{};
150 static constexpr auto BK0Number = Number<KPerBlock / BK1Value>{};
151 static constexpr auto AK1Number = Number<AK1Value>{};
152 static constexpr auto BK1Number = Number<BK1Value>{};
153
154 static constexpr index_t KPack = math::max(
157 .k_per_wmma);
158
160
161 static constexpr index_t APackedSize = []() {
163 return 2;
164 else
165 return 1;
166 }();
167
168 static constexpr index_t BPackedSize = []() {
170 return 2;
171 else
172 return 1;
173 }();
174
175 // Limitations of the current implementation:
176 // - no multiAB
177 // - GemmSpecialization Default
178 // - pipeline v1 because v3 is buggy (fixed in batched gemm gemm implementation)
179 // AK1Value == 8 is not really a limitation but a requirement for the method so
180 // it will stay
181#ifdef __gfx12__
182 static constexpr bool IsAWaveTransferApplicable =
183 !ForceThreadTileTransfer && NumATensor == 1 && APackedSize == 1 &&
185 BlkGemmPipelineVer == BlockGemmPipelineVersion::v1 && AK1Value == 8;
186
187 static constexpr bool IsBWaveTransferApplicable =
188 !ForceThreadTileTransfer && NumBTensor == 1 && BPackedSize == 1 &&
190 BlkGemmPipelineVer == BlockGemmPipelineVersion::v1 && BK1Value == 8;
191#else
192 static constexpr bool IsAWaveTransferApplicable = false;
193 static constexpr bool IsBWaveTransferApplicable = false;
194#endif
195
196 static constexpr index_t WaveSize =
198 .wave_size;
199 static constexpr bool UseBlockPaddingA =
200 ABlockLdsExtraM || BlkGemmPipelineVer == BlockGemmPipelineVersion::v4;
201 using ATransfer = typename std::conditional<
203 ABTransferWaveTiles<ALayout,
205 LDSTypeA,
206 BlockSize,
207 MPerBlock,
208 KPerBlock,
209 MPerWmma,
210 KPack,
211 AK1Value,
212 WaveSize>,
213 ABTransferThreadTiles<ALayout,
215 LDSTypeA,
216 BlockSize,
217 MPerBlock,
218 KPerBlock,
219 MPerWmma,
220 AK1Value,
222 PermuteA,
223 ABlockTransferThreadClusterLengths_AK0_M_AK1,
224 ABlockTransferThreadClusterArrangeOrder,
225 ABlockTransferSrcAccessOrder,
226 ABlockTransferSrcVectorDim,
227 ABlockTransferSrcScalarPerVector,
228 ABlockTransferDstScalarPerVector_AK1,
229 AThreadTransferSrcResetCoordinateAfterRun>>::type;
230
231 static constexpr bool UseBlockPaddingB =
232 BBlockLdsExtraN || BlkGemmPipelineVer == BlockGemmPipelineVersion::v4;
233
234 using BTransfer = typename std::conditional<
236 ABTransferWaveTiles<BLayout,
238 LDSTypeB,
239 BlockSize,
240 NPerBlock,
241 KPerBlock,
242 NPerWmma,
243 KPack,
244 BK1Value,
245 WaveSize>,
246 ABTransferThreadTiles<BLayout,
248 LDSTypeB,
249 BlockSize,
250 NPerBlock,
251 KPerBlock,
252 NPerWmma,
253 BK1Value,
255 PermuteB,
256 BBlockTransferThreadClusterLengths_BK0_N_BK1,
257 BBlockTransferThreadClusterArrangeOrder,
258 BBlockTransferSrcAccessOrder,
259 BBlockTransferSrcVectorDim,
260 BBlockTransferSrcScalarPerVector,
261 BBlockTransferDstScalarPerVector_BK1,
262 BThreadTransferSrcResetCoordinateAfterRun>>::type;
263
266 "pk_i4_t does not support padding");
267
268 static_assert(!PermuteA, "PermuteA is not supported");
269
270 // return block_id to C matrix tile idx (m0, n0) mapping
272
273 __host__ static auto CalculateGridSize(index_t M, index_t N, index_t KBatch)
274 {
275 return std::make_tuple(Block2CTileMap::CalculateGridSize(M, N), 1, KBatch);
276 }
277
278 __host__ static auto CalculateMPadded(index_t M)
279 {
280 return math::integer_least_multiple(M, MPerBlock);
281 }
282
283 __host__ static auto CalculateNPadded(index_t N)
284 {
285 return math::integer_least_multiple(N, NPerBlock);
286 }
287
288 __host__ static auto CalculateKPadded(index_t K)
289 {
290 return math::integer_divide_ceil(K, KPerBlock) * KPerBlock;
291 }
292
293 __host__ static auto CalculateAK0Padded(index_t K, index_t K_Batch = 1)
294 {
295 auto K_t = K_Batch * KPerBlock;
296 return (K + K_t - 1) / K_t * (KPerBlock / AK1Value);
297 }
298
299 __host__ static auto CalculateBK0Padded(index_t K, index_t K_Batch = 1)
300 {
301 auto K_t = K_Batch * KPerBlock;
302 return (K + K_t - 1) / K_t * (KPerBlock / BK1Value);
303 }
304
305 __host__ static auto CalculateKPadded(index_t K, index_t K_Batch = 1)
306 {
307 auto K_t = K_Batch * KPerBlock;
308 return (K + K_t - 1) / K_t * KPerBlock;
309 }
310
311 __host__ static auto CalculateKRead(index_t K, index_t K_Batch = 1)
312 {
313 constexpr auto KReadVec = math::lcm(AK1Number, BK1Number);
314 auto K_t = K_Batch * KReadVec;
315 return (K + K_t - 1) / K_t * KReadVec;
316 }
317
318 __host__ static auto CalculateMBlock(index_t M)
319 {
320 return math::integer_divide_ceil(M, MPerBlock);
321 }
322
323 __host__ static auto CalculateNBlock(index_t N)
324 {
325 return math::integer_divide_ceil(N, NPerBlock);
326 }
327
328 static constexpr auto MakeAsGridPointer()
329 {
330 return generate_tuple(
331 [&](auto i) {
332 using ADataType_ = remove_cvref_t<tuple_element_t<i.value, AsDataType>>;
333
334 return static_cast<const ADataType_*>(nullptr);
335 },
337 }
338
339 static constexpr auto MakeBsGridPointer()
340 {
341 return generate_tuple(
342 [&](auto i) {
343 using BDataType_ = remove_cvref_t<tuple_element_t<i.value, BsDataType>>;
344
345 return static_cast<const BDataType_*>(nullptr);
346 },
348 }
349
350 using AsGridPointer = decltype(MakeAsGridPointer());
351 using BsGridPointer = decltype(MakeBsGridPointer());
352
353 __host__ __device__ static auto MakeAGridDescriptor_M_K(index_t M, index_t K, index_t StrideA)
354 {
356 {
357 return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(StrideA, I1));
358 }
360 {
361 return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(I1, StrideA));
362 }
363 }
364
365 __host__ __device__ static auto MakeBGridDescriptor_N_K(index_t N, index_t K, index_t StrideB)
366 {
368 {
369 return make_naive_tensor_descriptor(make_tuple(N, K), make_tuple(I1, StrideB));
370 }
372 {
373 return make_naive_tensor_descriptor(make_tuple(N, K), make_tuple(StrideB, I1));
374 }
375 }
376
377 __host__ __device__ static auto
379 const index_t MPad,
380 const index_t K,
381 const index_t KPad,
382 const std::array<index_t, NumATensor>& StrideAs,
383 const index_t AK0)
384 {
385 using GemmSpecialization = tensor_operation::device::GemmSpecialization;
386 constexpr bool padM = GemmSpec == GemmSpecialization::MKPadding ||
387 GemmSpec == GemmSpecialization::MNKPadding ||
388 GemmSpec == GemmSpecialization::MPadding ||
389 GemmSpec == GemmSpecialization::MNPadding;
390 constexpr bool padK = GemmSpec == GemmSpecialization::MKPadding ||
391 GemmSpec == GemmSpecialization::MNKPadding ||
392 GemmSpec == GemmSpecialization::KPadding ||
393 GemmSpec == GemmSpecialization::NKPadding;
394 return generate_tuple(
395 [&](auto i) {
396 const auto base_desc = MakeAGridDescriptor_M_K(M, K, StrideAs[i]);
397
398 return ATransfer::template MakeGridDescriptor<padM, padK>(
399 base_desc, M, MPad, K, KPad, StrideAs[i], AK0);
400 },
402 }
403
404 __host__ __device__ static auto
406 const index_t KPad,
407 const index_t N,
408 const index_t NPad,
409 const std::array<index_t, NumBTensor>& StrideBs,
410 const index_t BK0)
411 {
412 using GemmSpecialization = tensor_operation::device::GemmSpecialization;
413 constexpr bool padN = GemmSpec == GemmSpecialization::NKPadding ||
414 GemmSpec == GemmSpecialization::MNKPadding ||
415 GemmSpec == GemmSpecialization::NPadding ||
416 GemmSpec == GemmSpecialization::MNPadding;
417 constexpr bool padK = GemmSpec == GemmSpecialization::NKPadding ||
418 GemmSpec == GemmSpecialization::MNKPadding ||
419 GemmSpec == GemmSpecialization::KPadding ||
420 GemmSpec == GemmSpecialization::MKPadding;
421 return generate_tuple(
422 [&](auto i) {
423 const auto base_desc = MakeBGridDescriptor_N_K(N, K, StrideBs[i]);
424 return BTransfer::template MakeGridDescriptor<padN, padK>(
425 base_desc, N, NPad, K, KPad, StrideBs[i], BK0);
426 },
428 }
429
430 __host__ __device__ static constexpr auto MakeAWmmaTileDescriptor()
431 {
432 constexpr index_t MWaves = MPerBlock / (MRepeat * MPerWmma);
433
434 return ATransfer::template MakeWmmaTileDescriptor<MRepeat, MWaves>();
435 }
436
437 __host__ __device__ static constexpr auto MakeBWmmaTileDescriptor()
438 {
439 constexpr index_t NWaves = NPerBlock / (NRepeat * NPerWmma);
440
441 return BTransfer::template MakeWmmaTileDescriptor<NRepeat, NWaves>();
442 }
443
444 template <typename DELayout>
445 __host__ __device__ static auto
447 {
448 const auto c_grid_desc_mraw_nraw = [&]() {
450 {
451 return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideDE, I1));
452 }
454 {
455 return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I1, StrideDE));
456 }
457 }();
458
459 // pad M and N
460 return transform_tensor_descriptor(c_grid_desc_mraw_nraw,
462 make_right_pad_transform(N, NPad - N)),
465 // TODO: Investigate why this path is not used in the original
466 // gridwise_gemm_xdl_cshuffle_v3.hpp
467#if 0
468 using GemmSpecialization = tensor_operation::device::GemmSpecialization;
469
470 if constexpr(GemmSpec == GemmSpecialization::MNPadding ||
471 GemmSpec == GemmSpecialization::MNKPadding)
472 {
473 // pad M and N
474 return transform_tensor_descriptor(c_grid_desc_mraw_nraw,
476 make_right_pad_transform(N, NPad - N)),
479 }
480 else if constexpr(GemmSpec == GemmSpecialization::MPadding ||
481 GemmSpec == GemmSpecialization::MKPadding)
482 {
483 // pad M, but not N
485 c_grid_desc_mraw_nraw,
489 }
490 else if constexpr(GemmSpec == GemmSpecialization::NPadding ||
491 GemmSpec == GemmSpecialization::NKPadding)
492 {
493 // pad N, but not M
495 c_grid_desc_mraw_nraw,
499 }
500 else
501 {
502 // not pad M or N
503 return c_grid_desc_mraw_nraw;
504 }
505#endif
506 }
507
508 static constexpr index_t NumDTensor = DsDataType::Size();
509
510 static constexpr auto MakeDsGridPointer()
511 {
512 return generate_tuple(
513 [&](auto i) {
514 using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
515
516 return static_cast<const DDataType*>(nullptr);
517 },
519 }
520
521 using DsGridPointer = decltype(MakeDsGridPointer());
522
523 __host__ __device__ static auto MakeDsGridDescriptor_M_N(
524 index_t M, index_t MPad, index_t N, index_t NPad, std::array<index_t, NumDTensor> StrideDs)
525 {
526 return generate_tuple(
527 [&](auto i) {
528 using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;
529 return MakeDEGridDescriptor_M_N<DLayout>(M, MPad, N, NPad, StrideDs[i]);
530 },
532 }
533
534 template <typename DsGridDesc>
536 const DsGridDesc& ds_grid_desc_m_n, index_t MBlock, index_t NBlock)
537 {
538 return generate_tuple(
539 [&](auto i) {
541 ds_grid_desc_m_n[i], MBlock, NBlock);
542 },
544 }
545
547 remove_cvref_t<decltype(BlockGemmPipeline_Selector<BlkGemmPipelineVer,
548 BlkGemmPipeSched,
549 BlockSize,
550 LDSTypeA,
551 LDSTypeB,
552 ComputeTypeA,
553 ComputeTypeB,
554 AccDataType,
555 decltype(MakeAWmmaTileDescriptor()),
556 decltype(MakeBWmmaTileDescriptor()),
557 ABlockTransferSrcScalarPerVector,
558 BBlockTransferSrcScalarPerVector,
559 MPerBlock,
560 NPerBlock,
561 KPerBlock,
562 MPerWmma,
563 NPerWmma,
564 MRepeat,
565 NRepeat,
566 KPack>())>;
567
568 // Used to create obj in global function and pass it to Run method
570 EpilogueCShuffle<DsDataType,
571 EDataType,
572 AccDataType,
573 CShuffleDataType,
574 MPerBlock,
575 NPerBlock,
576 MPerWmma,
577 NPerWmma,
578 MRepeat,
579 NRepeat,
580 CShuffleMRepeatPerShuffle,
581 CShuffleNRepeatPerShuffle,
582 CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
583 CDEShuffleBlockTransferScalarPerVectors,
584 CDEElementwiseOperation,
587
589 DsDataType,
590 EDataType,
591 AccDataType,
592 CShuffleDataType,
593 MPerBlock,
594 NPerBlock,
595 MPerWmma,
596 NPerWmma,
597 MRepeat,
598 NRepeat,
599 CShuffleMRepeatPerShuffle,
600 CShuffleNRepeatPerShuffle,
601 CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
602 CDEShuffleBlockTransferScalarPerVectors,
603 CDEElementwiseOperation,
606 BlockSize>;
607
608 template <typename DEGridDesc>
610 const DEGridDesc& de_grid_desc_m_n, index_t MBlock, index_t NBlock)
611 {
612 const auto de_grid_desc_mblock_mperblock_nblock_nperblock = transform_tensor_descriptor(
613 de_grid_desc_m_n,
618
619 return de_grid_desc_mblock_mperblock_nblock_nperblock;
620 }
621
622 // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01}
623 template <typename Argument>
624 __host__ static constexpr bool CheckValidity(const Argument& karg)
625 {
626 static_assert((MPerBlock % (MPerWmma * MRepeat) == 0) &&
627 (NPerBlock % (NPerWmma * NRepeat)) == 0,
628 "Invalid tuning param!");
629
635 {
636 if(!(karg.M % MPerBlock == 0))
637 {
638 if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
639 {
640 std::cout << "Arg M value is not a multiple of MPerBlock! M: " << karg.M << " "
641 << __FILE__ << ":" << __LINE__ << ", in function: " << __func__
642 << std::endl;
643 }
644 return false;
645 }
646 }
647
653 {
654 if(!(karg.N % NPerBlock == 0))
655 {
656 if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
657 {
658 std::cout << "Arg N value is not a multiple of NPerBlock! N: " << karg.N << " "
659 << __FILE__ << ":" << __LINE__ << ", in function: " << __func__
660 << std::endl;
661 }
662 return false;
663 }
664 }
665
670 {
671
672 auto K_t = karg.KBatch * KPerBlock;
673 if(!(karg.K % K_t == 0))
674 {
675 if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
676 {
677 std::cout << "Arg K value is not a multiple of K_Batch * K0PerBlock * K1! K: "
678 << karg.K << " " << __FILE__ << ":" << __LINE__
679 << ", in function: " << __func__ << std::endl;
680 }
681 return false;
682 }
683 }
684 else
685 {
686 constexpr auto KReadVec = math::lcm(AK1Number, BK1Number);
687 auto K_t = karg.KBatch * KReadVec;
688 auto KReadPadSplited = math::integer_divide_ceil(karg.K, K_t) * KReadVec;
689 if((KReadPadSplited * (karg.KBatch - 1)) >= karg.K)
690 {
691 return false;
692 }
693 }
694
696 {
697 if(karg.K % ABlockTransferSrcScalarPerVector != 0)
698 {
699 if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
700 {
701 std::cout << "Arg K (" << karg.K
702 << ") value is not a multiple of ABlockTransferSrcScalarPerVector ("
703 << ABlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
704 << __LINE__ << ", in function: " << __func__ << std::endl;
705 }
706 return false;
707 }
708 }
709 else
710 {
711 if(karg.M % ABlockTransferSrcScalarPerVector != 0)
712 {
713 if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
714 {
715 std::cout << "Arg M (" << karg.M
716 << ") value is not a multiple of ABlockTransferSrcScalarPerVector ("
717 << ABlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
718 << __LINE__ << ", in function: " << __func__ << std::endl;
719 }
720 return false;
721 }
722 }
723
725 {
726 if(karg.N % BBlockTransferSrcScalarPerVector != 0)
727 {
728 if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
729 {
730 std::cout << "Arg N (" << karg.N
731 << ") value is not a multiple of BBlockTransferSrcScalarPerVector ("
732 << BBlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
733 << __LINE__ << ", in function: " << __func__ << std::endl;
734 }
735 return false;
736 }
737 }
738 else
739 {
740 if(karg.K % BBlockTransferSrcScalarPerVector != 0)
741 {
742 if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
743 {
744 std::cout << "Arg K (" << karg.K
745 << ") value is not a multiple of BBlockTransferSrcScalarPerVector ("
746 << BBlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
747 << __LINE__ << ", in function: " << __func__ << std::endl;
748 }
749 return false;
750 }
751 }
752
754 {
756 {
757 if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
758 {
759 std::cout << "Arg N (" << karg.N
760 << ") value is not a multiple of "
761 "EShuffleBlockTransferScalarPerVector ("
762 << EShuffleBlockTransferScalarPerVector << " )! " << __FILE__ << ":"
763 << __LINE__ << ", in function: " << __func__ << std::endl;
764 }
765 return false;
766 }
767 }
768 else
769 {
771 {
772 if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
773 {
774 std::cout << "Arg M (" << karg.M
775 << ") value is not a multiple of "
776 "EShuffleBlockTransferScalarPerVector ("
777 << EShuffleBlockTransferScalarPerVector << " )! " << __FILE__ << ":"
778 << __LINE__ << ", in function: " << __func__ << std::endl;
779 }
780 return false;
781 }
782 }
783
788 {
789 if(karg.IsAtomicAdd() && karg.KBatch > 1)
790 {
791 if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
792 {
793 std::cout << " KBatch: " << karg.KBatch << " > 1 is not supported for this "
794 << "destination type (EDataType) " << __FILE__ << ":" << __LINE__
795 << ", in function: " << __func__ << std::endl;
796 }
797 return false;
798 }
799 }
800
801 // check gridwise gemm pipeline
802 const auto num_k_loop = karg.AK0 / (KPerBlock / AK1Value);
803
804 if constexpr(BlkGemmPipelineVer != BlockGemmPipelineVersion::v1)
805 {
806 if(num_k_loop <= BlockwiseGemmPipe::PrefetchStages)
807 {
808 if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
809 {
810 std::cout << "Pipeline validation failed: num_k_loop (" << num_k_loop
811 << ") <= PrefetchStages (" << BlockwiseGemmPipe::PrefetchStages
812 << ") for pipeline version != v1." << __FILE__ << ":" << __LINE__
813 << ", in function: " << __func__ << std::endl;
814 }
815 return false;
816 }
817 }
818
820 {
821 if(karg.KBatch > 1)
822 {
823 if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
824 {
825 std::cout << "int8_t does not support KBatch > 1. KBatch: " << karg.KBatch
826 << " " << __FILE__ << ":" << __LINE__ << ", in function: " << __func__
827 << std::endl;
828 }
829 return false;
830 }
831 }
832
833 // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
834 return true;
835 }
836
837 __host__ static constexpr bool CalculateHasMainKBlockLoop(index_t K)
838 {
839 const index_t num_loop = K / KPerBlock;
840
841 return BlockwiseGemmPipe::BlockHasHotloop(num_loop);
842 }
843
844 __host__ static constexpr TailNumber CalculateKBlockLoopTailNum(index_t K)
845 {
846 const index_t num_loop = K / KPerBlock;
847
848 return BlockwiseGemmPipe::BlockLoopTailNum(num_loop);
849 }
850
851 template <typename EpilogueType>
852 __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
853 {
854 // LDS allocation for A and B: be careful of alignment
855 constexpr auto a_block_desc_ak0_m_ak1 = ATransfer::GetBlockDescriptor();
856 constexpr auto b_block_desc_bk0_n_bk1 = BTransfer::GetBlockDescriptor();
857
858 // lds max alignment
859 constexpr auto max_lds_align = math::lcm(AK1Number, BK1Number);
860
861 constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
862 a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
863
864 constexpr auto b_block_space_size_aligned = math::integer_least_multiple(
865 b_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align);
866
867 // LDS allocation for C shuffle in LDS
868 constexpr auto c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat =
869 EpilogueType::
870 GetCShuffleBlockDescriptor_MShRepeat_MPerShRepeat_NShRepeat_NPerShRepeat();
871
872 constexpr auto c_block_size =
873 c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat
874 .GetElementSpaceSize();
875
876 return math::max((a_block_space_size_aligned * sizeof(LDSTypeA) / APackedSize +
877 b_block_space_size_aligned * sizeof(LDSTypeB) / BPackedSize),
878 c_block_size * sizeof(CShuffleDataType));
879 }
880
881 template <index_t numElements, typename Type>
882 __device__ __forceinline__ static auto get_first_element_workaround(Type& array)
883 {
884 if constexpr(numElements > 1)
885 {
886 return array;
887 }
888 else
889 {
890 return array[I0];
891 }
892 }
893
894 template <typename AGridDesc_AK0_M_K1,
895 typename BGridDesc_BK0_N_K1,
896 typename DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
897 typename EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
898 typename BScaleStruct,
899 typename EpilogueArgument,
900 bool HasMainKBlockLoop,
901 InMemoryDataOperationEnum EGlobalMemoryDataOperation,
902 TailNumber TailNum = TailNumber::Odd>
903 __device__ static void Run(AsGridPointer p_as_grid,
904 BsGridPointer p_bs_grid,
905 DsGridPointer p_ds_grid,
906 EDataType* p_e_grid,
907 void* p_shared,
908 const AGridDesc_AK0_M_K1& as_grid_desc_ak0_m_ak1,
909 const BGridDesc_BK0_N_K1& bs_grid_desc_bk0_n_bk1,
910 const DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock&
911 ds_grid_desc_mblock_mperblock_nblock_nperblock,
912 const EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock&
913 e_grid_desc_mblock_mperblock_nblock_nperblock,
914 AElementwiseOperation a_element_op,
915 BElementwiseOperation b_element_op,
916 CDEElementwiseOperation cde_element_op,
917 const index_t& block_m_id,
918 const index_t& block_n_id,
919 const index_t& num_k_block_per_scale,
920 BScaleStruct& b_scale_struct,
921 EpilogueArgument& epilogue_args)
922 {
923 const auto as_grid_buf = generate_tuple(
924 [&](auto i) {
926 p_as_grid[i], as_grid_desc_ak0_m_ak1[i].GetElementSpaceSize());
927 },
929
930 const auto bs_grid_buf = generate_tuple(
931 [&](auto i) {
933 p_bs_grid[i], bs_grid_desc_bk0_n_bk1[i].GetElementSpaceSize());
934 },
936
937 // lds max alignment
938 constexpr auto max_lds_align = math::lcm(AK1Number, BK1Number);
939
940 // A matrix in LDS memory, dst of blockwise copy
941 constexpr auto a_block_desc_ak0_m_ak1 = ATransfer::GetBlockDescriptor();
942
943 // B matrix in LDS memory, dst of blockwise copy
944 constexpr auto b_block_desc_bk0_n_bk1 = BTransfer::GetBlockDescriptor();
945
946 // A matrix blockwise copy
947 auto a_blockwise_copy =
948 ATransfer::template GetBlockTransfer<AGridDesc_AK0_M_K1,
949 decltype(a_block_desc_ak0_m_ak1),
950 AsDataType,
951 AElementwiseOperation,
952 BlockwiseGemmPipe::GlobalBufferNum>(
953 as_grid_desc_ak0_m_ak1, a_block_desc_ak0_m_ak1, a_element_op, block_m_id);
954
955 // B matrix blockwise copy
956 auto b_blockwise_copy =
957 BTransfer::template GetBlockTransfer<BGridDesc_BK0_N_K1,
958 decltype(b_block_desc_bk0_n_bk1),
959 BsDataType,
960 BElementwiseOperation,
961 BlockwiseGemmPipe::GlobalBufferNum>(
962 bs_grid_desc_bk0_n_bk1, b_block_desc_bk0_n_bk1, b_element_op, block_n_id);
963
964 // LDS allocation for A and B: be careful of alignment
965 constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
966 a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
967
968 // Cast after lds
970 static_cast<LDSTypeA*>(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
971
973 reinterpret_cast<LDSTypeB*>(static_cast<char*>(p_shared) + a_block_space_size_aligned *
974 sizeof(LDSTypeA) /
976 b_block_desc_bk0_n_bk1.GetElementSpaceSize());
977
978 constexpr auto a_block_slice_copy_step = ATransfer::GetBlockStep();
979 constexpr auto b_block_slice_copy_step = BTransfer::GetBlockStep();
980
981 // Blockwise GEMM pipeline
982 static_assert(std::is_default_constructible_v<BlockwiseGemmPipe>);
983 auto blockwise_gemm_pipeline = BlockwiseGemmPipe{};
984 auto c_thread_buf = blockwise_gemm_pipeline.GetCThreadBuffer();
985
986 const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(
987 ATransfer::GetKDimension(as_grid_desc_ak0_m_ak1[I0]) / KPerBlock);
988
989 blockwise_gemm_pipeline.template Run<HasMainKBlockLoop, TailNum>(
990 get_first_element_workaround<NumATensor>(as_grid_desc_ak0_m_ak1),
991 a_block_desc_ak0_m_ak1,
992 a_blockwise_copy,
994 a_block_buf,
995 a_block_slice_copy_step,
996 get_first_element_workaround<NumBTensor>(bs_grid_desc_bk0_n_bk1),
997 b_block_desc_bk0_n_bk1,
998 b_blockwise_copy,
1000 b_block_buf,
1001 b_block_slice_copy_step,
1002 c_thread_buf,
1003 b_scale_struct,
1004 num_k_block_main_loop,
1005 num_k_block_per_scale);
1006
1007 // shuffle C and write out
1008 epilogue_args.template Run<EGlobalMemoryDataOperation>(
1009 c_thread_buf,
1010 p_ds_grid,
1011 p_e_grid,
1012 p_shared,
1013 ds_grid_desc_mblock_mperblock_nblock_nperblock,
1014 e_grid_desc_mblock_mperblock_nblock_nperblock,
1015 cde_element_op,
1016 block_m_id,
1017 block_n_id);
1018 }
1019};
1020
1021} // namespace ck
#define CK_MAX_THREAD_PER_BLOCK
Definition ck.hpp:30
__host__ __device__ constexpr auto integer_least_multiple(X x, Y y)
Definition utility/math.hpp:78
__host__ __device__ constexpr T max(T x)
Definition utility/math.hpp:84
__host__ __device__ constexpr auto integer_divide_ceil(X x, Y y)
Definition utility/math.hpp:72
__host__ __device__ constexpr auto lcm(X x, Y y)
Definition utility/math.hpp:198
GemmSpecialization
Definition gemm_specialization.hpp:11
@ MKPadding
Definition gemm_specialization.hpp:18
@ KPadding
Definition gemm_specialization.hpp:16
@ Default
Definition gemm_specialization.hpp:13
@ NPadding
Definition gemm_specialization.hpp:15
@ MPadding
Definition gemm_specialization.hpp:14
@ MNKPadding
Definition gemm_specialization.hpp:20
@ MNPadding
Definition gemm_specialization.hpp:17
@ NKPadding
Definition gemm_specialization.hpp:19
Definition ck.hpp:268
ushort bhalf_t
Definition data_type.hpp:30
__host__ __device__ constexpr auto make_pass_through_transform(const LowLength &low_length)
Definition multi_index_transform_helper.hpp:12
int32_t index_t
Definition ck.hpp:299
__host__ __device__ constexpr auto make_naive_tensor_descriptor(const Tuple< Lengths... > &lengths, const Tuple< Strides... > &strides)
Definition tensor_descriptor_helper.hpp:49
InMemoryDataOperationEnum
Definition ck.hpp:277
@ AtomicAdd
Definition ck.hpp:279
remove_cv_t< remove_reference_t< T > > remove_cvref_t
Definition type.hpp:297
BlockGemmPipelineVersion
Definition blkgemmpipe_scheduler.hpp:12
@ v4
Definition blkgemmpipe_scheduler.hpp:17
@ v1
Definition blkgemmpipe_scheduler.hpp:14
__host__ __device__ constexpr auto make_right_pad_transform(const LowLength &low_length, const RightPadLength &right_pad, integral_constant< bool, SkipIsValidCheck >=integral_constant< bool, false >{})
Definition multi_index_transform_helper.hpp:37
_Float16 half_t
Definition data_type.hpp:31
integral_constant< index_t, N > Number
Definition number.hpp:12
TailNumber
Definition blkgemmpipe_scheduler.hpp:31
@ Odd
Definition blkgemmpipe_scheduler.hpp:33
@ Full
Definition blkgemmpipe_scheduler.hpp:49
constexpr auto BlockGemmPipeline_Selector()
Definition blockwise_gemm_pipeline_wmma_selector.hpp:32
typename tuple_element< I, TTuple >::type tuple_element_t
Definition utility/tuple.hpp:208
constexpr detail::ignore_t ignore
Definition utility/ignore.hpp:20
__global__ void kernel_gemm_wmma_cshuffle_v3(typename GridwiseGemm::Argument karg)
Definition gridwise_gemm_wmma_cshuffle_v3_common.hpp:40
bool EnvIsEnabled(EnvVar)
Definition utility/env.hpp:140
constexpr bool is_same_v
Definition type.hpp:283
BlockGemmPipelineScheduler
Definition blkgemmpipe_scheduler.hpp:25
__host__ __device__ constexpr auto generate_tuple(F &&f, Number< N >)
Definition tuple_helper.hpp:21
__host__ __device__ constexpr auto make_tuple(Xs &&... xs)
Definition utility/tuple.hpp:211
__host__ __device__ constexpr auto transform_tensor_descriptor(const OldTensorDescriptor &old_tensor_desc, const NewTransforms &new_transforms, NewLowerDimensionOldVisibleIdss, NewUpperDimensionNewVisibleIdss)
Definition tensor_description/tensor_descriptor.hpp:319
__host__ __device__ constexpr auto make_unmerge_transform(const UpLengths &up_lengths, integral_constant< bool, Use24BitIntegerCalculation >=integral_constant< bool, false >{})
Definition multi_index_transform_helper.hpp:90
typename remove_pointer< T >::type remove_pointer_t
Definition type.hpp:300
__host__ __device__ constexpr auto make_dynamic_buffer(T *p, ElementSpaceSize element_space_size)
Definition dynamic_buffer.hpp:472
const GenericPointer< typename T::ValueType > T2 value
Definition pointer.h:1697
Type
Type of JSON value.
Definition rapidjson.h:760
signed int int32_t
Definition stdint.h:123
signed char int8_t
Definition stdint.h:121
Definition gridwise_ab_transfer_thread_tiles.hpp:30
Definition gridwise_ab_transfer_wave_tiles.hpp:23
Definition block_to_ctile_map.hpp:271
__host__ static __device__ constexpr index_t CalculateGridSize(index_t M, index_t N)
Definition block_to_ctile_map.hpp:283
Definition gridwise_gemm_wmma_cshuffle_v3_common.hpp:122
static __host__ auto CalculateKPadded(index_t K)
Definition gridwise_gemm_wmma_cshuffle_v3_common.hpp:288
__host__ static __device__ constexpr auto MakeBWmmaTileDescriptor()
Definition gridwise_gemm_wmma_cshuffle_v3_common.hpp:437
static constexpr bool IsAWaveTransferApplicable
Definition gridwise_gemm_wmma_cshuffle_v3_common.hpp:192
remove_cvref_t< decltype(BlockGemmPipeline_Selector< BlkGemmPipelineVer, BlkGemmPipeSched, BlockSize, LDSTypeA, LDSTypeB, ComputeTypeA, ComputeTypeB, AccDataType, decltype(MakeAWmmaTileDescriptor()), decltype(MakeBWmmaTileDescriptor()), ABlockTransferSrcScalarPerVector, BBlockTransferSrcScalarPerVector, MPerBlock, NPerBlock, KPerBlock, MPerWmma, NPerWmma, MRepeat, NRepeat, KPack >())> BlockwiseGemmPipe
Definition gridwise_gemm_wmma_cshuffle_v3_common.hpp:546
EpilogueWelfordCShuffle< DsDataType, EDataType, AccDataType, CShuffleDataType, MPerBlock, NPerBlock, MPerWmma, NPerWmma, MRepeat, NRepeat, CShuffleMRepeatPerShuffle, CShuffleNRepeatPerShuffle, CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CDEShuffleBlockTransferScalarPerVectors, CDEElementwiseOperation, ThisThreadBlock, BlockwiseGemmPipe, BlockSize > EpilogueWelfordCShuffle
Definition gridwise_gemm_wmma_cshuffle_v3_common.hpp:588
static __host__ auto CalculateNBlock(index_t N)
Definition gridwise_gemm_wmma_cshuffle_v3_common.hpp:323
static __host__ auto CalculateAK0Padded(index_t K, index_t K_Batch=1)
Definition gridwise_gemm_wmma_cshuffle_v3_common.hpp:293
static __host__ constexpr TailNumber CalculateKBlockLoopTailNum(index_t K)
Definition gridwise_gemm_wmma_cshuffle_v3_common.hpp:844
static __host__ auto CalculateMBlock(index_t M)
Definition gridwise_gemm_wmma_cshuffle_v3_common.hpp:318
static __host__ constexpr bool CheckValidity(const Argument &karg)
Definition gridwise_gemm_wmma_cshuffle_v3_common.hpp:624
typename std::conditional< IsBWaveTransferApplicable, ABTransferWaveTiles< BLayout, tensor_layout::gemm::ColumnMajor, LDSTypeB, BlockSize, NPerBlock, KPerBlock, NPerWmma, KPack, BK1Value, WaveSize >, ABTransferThreadTiles< BLayout, tensor_layout::gemm::ColumnMajor, LDSTypeB, BlockSize, NPerBlock, KPerBlock, NPerWmma, BK1Value, UseBlockPaddingB, PermuteB, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, BThreadTransferSrcResetCoordinateAfterRun > >::type BTransfer
Definition gridwise_gemm_wmma_cshuffle_v3_common.hpp:234
static __host__ auto CalculateKRead(index_t K, index_t K_Batch=1)
Definition gridwise_gemm_wmma_cshuffle_v3_common.hpp:311
__host__ static __device__ auto MakeBsGridDescriptor_BK0_N_BK1(const index_t K, const index_t KPad, const index_t N, const index_t NPad, const std::array< index_t, NumBTensor > &StrideBs, const index_t BK0)
Definition gridwise_gemm_wmma_cshuffle_v3_common.hpp:405
__host__ static __device__ auto MakeDEGridDescriptor_M_N(index_t M, index_t MPad, index_t N, index_t NPad, index_t StrideDE)
Definition gridwise_gemm_wmma_cshuffle_v3_common.hpp:446
__host__ static __device__ auto MakeAsGridDescriptor_AK0_M_AK1(const index_t M, const index_t MPad, const index_t K, const index_t KPad, const std::array< index_t, NumATensor > &StrideAs, const index_t AK0)
Definition gridwise_gemm_wmma_cshuffle_v3_common.hpp:378
EpilogueCShuffle< DsDataType, EDataType, AccDataType, CShuffleDataType, MPerBlock, NPerBlock, MPerWmma, NPerWmma, MRepeat, NRepeat, CShuffleMRepeatPerShuffle, CShuffleNRepeatPerShuffle, CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CDEShuffleBlockTransferScalarPerVectors, CDEElementwiseOperation, ThisThreadBlock, BlockwiseGemmPipe > EpilogueCShuffle
Definition gridwise_gemm_wmma_cshuffle_v3_common.hpp:569
__host__ static __device__ auto MakeAGridDescriptor_M_K(index_t M, index_t K, index_t StrideA)
Definition gridwise_gemm_wmma_cshuffle_v3_common.hpp:353
__host__ static __device__ auto MakeDsGridDescriptor_M_N(index_t M, index_t MPad, index_t N, index_t NPad, std::array< index_t, NumDTensor > StrideDs)
Definition gridwise_gemm_wmma_cshuffle_v3_common.hpp:523
__host__ static __device__ constexpr auto MakeAWmmaTileDescriptor()
Definition gridwise_gemm_wmma_cshuffle_v3_common.hpp:430
static __device__ constexpr auto MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const DsGridDesc &ds_grid_desc_m_n, index_t MBlock, index_t NBlock)
Definition gridwise_gemm_wmma_cshuffle_v3_common.hpp:535
static constexpr auto MakeBsGridPointer()
Definition gridwise_gemm_wmma_cshuffle_v3_common.hpp:339
static constexpr auto MakeAsGridPointer()
Definition gridwise_gemm_wmma_cshuffle_v3_common.hpp:328
static __host__ auto CalculateGridSize(index_t M, index_t N, index_t KBatch)
Definition gridwise_gemm_wmma_cshuffle_v3_common.hpp:273
static __host__ auto CalculateKPadded(index_t K, index_t K_Batch=1)
Definition gridwise_gemm_wmma_cshuffle_v3_common.hpp:305
static __device__ void Run(AsGridPointer p_as_grid, BsGridPointer p_bs_grid, DsGridPointer p_ds_grid, EDataType *p_e_grid, void *p_shared, const AGridDesc_AK0_M_K1 &as_grid_desc_ak0_m_ak1, const BGridDesc_BK0_N_K1 &bs_grid_desc_bk0_n_bk1, const DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock &ds_grid_desc_mblock_mperblock_nblock_nperblock, const EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock &e_grid_desc_mblock_mperblock_nblock_nperblock, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CDEElementwiseOperation cde_element_op, const index_t &block_m_id, const index_t &block_n_id, const index_t &num_k_block_per_scale, BScaleStruct &b_scale_struct, EpilogueArgument &epilogue_args)
Definition gridwise_gemm_wmma_cshuffle_v3_common.hpp:903
static __device__ constexpr auto MakeDEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const DEGridDesc &de_grid_desc_m_n, index_t MBlock, index_t NBlock)
Definition gridwise_gemm_wmma_cshuffle_v3_common.hpp:609
__host__ static __device__ auto MakeBGridDescriptor_N_K(index_t N, index_t K, index_t StrideB)
Definition gridwise_gemm_wmma_cshuffle_v3_common.hpp:365
static constexpr auto MakeDsGridPointer()
Definition gridwise_gemm_wmma_cshuffle_v3_common.hpp:510
static __host__ constexpr bool CalculateHasMainKBlockLoop(index_t K)
Definition gridwise_gemm_wmma_cshuffle_v3_common.hpp:837
static constexpr bool IsBWaveTransferApplicable
Definition gridwise_gemm_wmma_cshuffle_v3_common.hpp:193
static __host__ auto CalculateMPadded(index_t M)
Definition gridwise_gemm_wmma_cshuffle_v3_common.hpp:278
static __device__ constexpr index_t GetSharedMemoryNumberOfByte()
Definition gridwise_gemm_wmma_cshuffle_v3_common.hpp:852
static __host__ auto CalculateNPadded(index_t N)
Definition gridwise_gemm_wmma_cshuffle_v3_common.hpp:283
typename std::conditional< IsAWaveTransferApplicable, ABTransferWaveTiles< ALayout, tensor_layout::gemm::RowMajor, LDSTypeA, BlockSize, MPerBlock, KPerBlock, MPerWmma, KPack, AK1Value, WaveSize >, ABTransferThreadTiles< ALayout, tensor_layout::gemm::RowMajor, LDSTypeA, BlockSize, MPerBlock, KPerBlock, MPerWmma, AK1Value, UseBlockPaddingA, PermuteA, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, AThreadTransferSrcResetCoordinateAfterRun > >::type ATransfer
Definition gridwise_gemm_wmma_cshuffle_v3_common.hpp:201
__device__ static __forceinline__ auto get_first_element_workaround(Type &array)
Definition gridwise_gemm_wmma_cshuffle_v3_common.hpp:882
static __host__ auto CalculateBK0Padded(index_t K, index_t K_Batch=1)
Definition gridwise_gemm_wmma_cshuffle_v3_common.hpp:299
Definition utility/sequence.hpp:43
static constexpr auto selected_wmma
Definition wmma_gemm.hpp:636
static constexpr value_type value
Definition utility/integral_constant.hpp:13
Definition type.hpp:177
Definition data_type.hpp:187
Definition tensor_operation/gpu/device/tensor_layout.hpp:31
Definition tensor_operation/gpu/device/tensor_layout.hpp:26
#define CK_ENV(name)
Definition utility/env.hpp:129