device_gemm_xdl_cshuffle_v3.hpp Source File

device_gemm_xdl_cshuffle_v3.hpp Source File

Composable Kernel: device_gemm_xdl_cshuffle_v3.hpp Source File
device_gemm_xdl_cshuffle_v3.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
6#include <iostream>
7#include <sstream>
8
19
20namespace ck {
21namespace tensor_operation {
22namespace device {
23
123template <typename ALayout,
124 typename BLayout,
125 typename CLayout,
126 typename ADataType,
127 typename BDataType,
128 typename CDataType,
129 typename GemmAccDataType,
130 typename CShuffleDataType,
131 typename AElementwiseOperation,
132 typename BElementwiseOperation,
133 typename CElementwiseOperation,
134 GemmSpecialization GemmSpec,
135 index_t BlockSize,
136 index_t MPerBlock,
137 index_t NPerBlock,
138 index_t KPerBlock,
139 index_t AK1,
140 index_t BK1,
141 index_t MPerXDL,
142 index_t NPerXDL,
143 index_t MXdlPerWave,
144 index_t NXdlPerWave,
145 typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
146 typename ABlockTransferThreadClusterArrangeOrder,
147 typename ABlockTransferSrcAccessOrder,
148 index_t ABlockTransferSrcVectorDim,
149 index_t ABlockTransferSrcScalarPerVector,
150 index_t ABlockTransferDstScalarPerVector_AK1,
151 bool ABlockLdsExtraM,
152 typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
153 typename BBlockTransferThreadClusterArrangeOrder,
154 typename BBlockTransferSrcAccessOrder,
155 index_t BBlockTransferSrcVectorDim,
156 index_t BBlockTransferSrcScalarPerVector,
157 index_t BBlockTransferDstScalarPerVector_BK1,
158 bool BBlockLdsExtraN,
159 index_t CShuffleMXdlPerWavePerShuffle,
160 index_t CShuffleNXdlPerWavePerShuffle,
161 typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
162 index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
165 typename ComputeTypeA = CDataType,
166 typename ComputeTypeB = ComputeTypeA,
167 bool PermuteA = false,
168 bool PermuteB = false>
170 BLayout,
171 CLayout,
172 ADataType,
173 BDataType,
174 CDataType,
175 AElementwiseOperation,
176 BElementwiseOperation,
177 CElementwiseOperation>
178{
179 // GridwiseGemm
181 static constexpr auto NXdlPerWave64 = GetNXdlPerWave<true>();
182 static constexpr auto NXdlPerWave32 = GetNXdlPerWave<false>();
183
184 template <index_t NXdlPerWave_>
186 ALayout,
187 BLayout,
188 CLayout,
189 ADataType,
190 BDataType,
191 GemmAccDataType,
192 CShuffleDataType,
193 CDataType,
194 AElementwiseOperation,
195 BElementwiseOperation,
196 CElementwiseOperation,
197 GemmSpec,
198 BlockSize,
199 MPerBlock,
200 NPerBlock,
201 KPerBlock,
202 AK1,
203 BK1,
204 MPerXDL,
205 NPerXDL,
206 MXdlPerWave,
207 NXdlPerWave_,
208 ABlockTransferThreadClusterLengths_AK0_M_AK1,
209 ABlockTransferThreadClusterArrangeOrder,
210 ABlockTransferSrcAccessOrder,
211 ABlockTransferSrcVectorDim,
212 ABlockTransferSrcScalarPerVector,
213 ABlockTransferDstScalarPerVector_AK1,
214 false,
215 ABlockLdsExtraM,
216 BBlockTransferThreadClusterLengths_BK0_N_BK1,
217 BBlockTransferThreadClusterArrangeOrder,
218 BBlockTransferSrcAccessOrder,
219 BBlockTransferSrcVectorDim,
220 BBlockTransferSrcScalarPerVector,
221 BBlockTransferDstScalarPerVector_BK1,
222 false,
223 BBlockLdsExtraN,
224 CShuffleMXdlPerWavePerShuffle,
225 CShuffleNXdlPerWavePerShuffle,
226 CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
227 CShuffleBlockTransferScalarPerVector_NPerBlock,
228 BlkGemmPipeSched,
229 BlkGemmPipelineVer,
230 ComputeTypeA,
231 ComputeTypeB,
232 PermuteA,
233 PermuteB>;
236
237 using Argument = typename GridwiseGemm64::Argument;
238
// Compile-time packing factor for the A operand; the immediately-invoked lambda yields 2 or 1.
// RunImp divides the A buffer byte size by this value.
// NOTE(review): the `if constexpr(...)` condition (source line 240) is absent from this listing,
// so which data types select 2 cannot be confirmed here.
239 static constexpr index_t APackedSize = []() {
241 return 2;
242 else
243 return 1;
244 }();
245
// Compile-time packing factor for the B operand; the immediately-invoked lambda yields 2 or 1.
// RunImp divides the B buffer byte size by this value.
// NOTE(review): the `if constexpr(...)` condition (source line 247) is absent from this listing,
// so which data types select 2 cannot be confirmed here.
246 static constexpr index_t BPackedSize = []() {
248 return 2;
249 else
250 return 1;
251 }();
252
262 struct Invoker : public BaseInvoker
263 {
// Selects a kernel instantiation for GridwiseGemm and launches it through the local `Run` lambda.
//  - Prints the argument and the hot-loop instruction list when stream_config.log_level_ > 0.
//  - Throws std::runtime_error when GridwiseGemm::CheckValidity(arg) is false.
//  - Returns ave_time, which is initialised to 0 and is still 0 if no dispatch branch calls Run(kernel).
// NOTE(review): this listing has gaps in its embedded line numbers; the dropped lines are mostly
// kernel template arguments and TailNumber comparisons, so the exact instantiation chosen by each
// branch cannot be read from here.
269 template <typename GridwiseGemm>
270 float RunImp(const typename GridwiseGemm::Argument& arg,
271 const StreamConfig& stream_config = StreamConfig{})
272 {
273 if(stream_config.log_level_ > 0)
274 {
275 arg.Print();
276 GridwiseGemm::BlockwiseGemmPipe::HotLoopInstList::Print();
277 }
278
279 if(!GridwiseGemm::CheckValidity(arg))
280 {
281 throw std::runtime_error("wrong! GridwiseGemm has invalid setting");
282 }
283
284 index_t gdx, gdy, gdz;
285 std::tie(gdx, gdy, gdz) = GridwiseGemm::CalculateGridSize(arg.M, arg.N, arg.KBatch);
286
287 float ave_time = 0;
288
// K_split is ceil(K / (KBatch * KPerBlock)) * KPerBlock: the K extent handled per split,
// rounded up to a whole number of KPerBlock tiles.
289 index_t k_grain = arg.KBatch * KPerBlock;
290 index_t K_split = (arg.K + k_grain - 1) / k_grain * KPerBlock;
291
292 const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K_split);
293
// Run: launches `kernel` on a (gdx, gdy, gdz) grid with BlockSize threads per block.
// Captures by reference, so it writes ave_time in the enclosing scope.
294 const auto Run = [&](const auto& kernel) {
295 if(stream_config.flush_cache)
296 {
// Work on a copy of the argument in the flush-cache path.
297 auto arg_ = arg;
298
299 const auto a_grid_desc_ak0_m_ak1 = GridwiseGemm::MakeAGridDescriptor_AK0_M_AK1(
300 arg_.M, arg_.MPadded, arg_.K, arg_.KPadded, arg_.StrideA, arg_.AK0);
301 const auto b_grid_desc_bk0_n_bk1 = GridwiseGemm::MakeBGridDescriptor_BK0_N_BK1(
302 arg_.K, arg_.KPadded, arg_.N, arg_.NPadded, arg_.StrideB, arg_.BK0);
303
// Byte sizes of the A and B buffers, scaled down by the packing factors.
304 auto size_a_buffer = a_grid_desc_ak0_m_ak1.GetElementSpaceSize() *
305 sizeof(ADataType) / APackedSize;
306 auto size_b_buffer = b_grid_desc_bk0_n_bk1.GetElementSpaceSize() *
307 sizeof(BDataType) / BPackedSize;
308
// NOTE(review): source line 309 (start of the rotating_mem declaration) is absent from this listing.
310 arg_, stream_config.rotating_count, size_a_buffer, size_b_buffer);
311 rotating_mem.Print();
312
// Pre-launch hook passed to the launcher below.
313 auto run_flush_cache = [&]() {
314 // flush icache
// NOTE(review): source line 315 is absent from this listing.
316 // rotating mem
317 rotating_mem.Next();
318 // clear c mem
// NOTE(review): the string returned by hipGetErrorString is discarded, so a failing
// hipMemsetAsync goes unreported. Also `arg_.M * arg_.N` is multiplied before being widened
// by sizeof(CDataType); if M and N are 32-bit index_t this can overflow - confirm their type.
319 if(arg_.KBatch > 1)
320 hipGetErrorString(hipMemsetAsync(arg_.p_c_grid,
321 0,
322 arg_.M * arg_.N * sizeof(CDataType),
323 stream_config.stream_id_));
324 };
325
// NOTE(review): source line 326 (the call that consumes these arguments, presumably assigning
// ave_time) is absent from this listing.
327 stream_config,
328 run_flush_cache,
329 kernel,
330 dim3(gdx, gdy, gdz),
331 dim3(BlockSize),
332 0,
333 arg_);
334 }
335 else
336 {
// With split-K (KBatch > 1) the C buffer is zero-filled before the launch.
// NOTE(review): same discarded error string and M * N width concern as in the flush-cache path.
337 if(arg.KBatch > 1)
338 hipGetErrorString(hipMemsetAsync(arg.p_c_grid,
339 0,
340 arg.M * arg.N * sizeof(CDataType),
341 stream_config.stream_id_));
342
343 ave_time = launch_and_time_kernel(
344 stream_config, kernel, dim3(gdx, gdy, gdz), dim3(BlockSize), 0, arg);
345 }
346 };
347
// Occupancy value passed as a kernel template argument: 2 for the Interwave scheduler,
// 2 or 1 for pipeline v3 depending on MPerBlock * NPerBlock / BlockSize <= 128, otherwise 1.
348 constexpr index_t minimum_occupancy = []() {
349 if constexpr(BlkGemmPipeSched == BlockGemmPipelineScheduler::Interwave)
350 {
351 return 2;
352 }
353 else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v3)
354 {
355 return (MPerBlock * NPerBlock / BlockSize <= 128) ? 2 : 1;
356 }
357 else
358 {
359 return 1;
360 }
361 }();
362
// Dispatch: first on whether a main K loop exists, then on pipeline version, then on
// KBatch > 1, then (for v2/v4/other) on the tail number computed from K_split.
363 if(has_main_k_block_loop)
364 {
365 // Tail number always full
366 if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1 ||
367 BlkGemmPipelineVer == BlockGemmPipelineVersion::v3)
368 {
369 if(arg.KBatch > 1)
370 {
371 const auto kernel =
372 kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
373 true,
375 minimum_occupancy>;
376 Run(kernel);
377 }
378 else
379 {
380 const auto kernel =
381 kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
382 true,
384 minimum_occupancy>;
385 Run(kernel);
386 }
387 }
388 // Tail number could be One to Seven
389 else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v2)
390 {
391 if(arg.KBatch > 1)
392 {
393 if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::One)
394 {
395 const auto kernel =
396 kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
397 true,
399 minimum_occupancy,
401 Run(kernel);
402 }
403 else if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
405 {
406 const auto kernel =
407 kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
408 true,
410 minimum_occupancy,
412 Run(kernel);
413 }
414
// The remaining tail cases are only compiled in when the pipeline has enough prefetch stages.
415 if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 2)
416 {
417 if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Two)
418 {
419 const auto kernel = kernel_gemm_xdl_cshuffle_v3<
420 GridwiseGemm,
421 true,
423 minimum_occupancy,
425 Run(kernel);
426 }
427 }
428
429 if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 3)
430 {
431 if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
433 {
434 const auto kernel = kernel_gemm_xdl_cshuffle_v3<
435 GridwiseGemm,
436 true,
438 minimum_occupancy,
440 Run(kernel);
441 }
442 }
443
444 if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 4)
445 {
446 if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
448 {
449 const auto kernel = kernel_gemm_xdl_cshuffle_v3<
450 GridwiseGemm,
451 true,
453 minimum_occupancy,
455 Run(kernel);
456 }
457 }
458
459 if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 5)
460 {
461 if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
463 {
464 const auto kernel = kernel_gemm_xdl_cshuffle_v3<
465 GridwiseGemm,
466 true,
468 minimum_occupancy,
470 Run(kernel);
471 }
472 }
473
474 if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 6)
475 {
476 if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Six)
477 {
478 const auto kernel = kernel_gemm_xdl_cshuffle_v3<
479 GridwiseGemm,
480 true,
482 minimum_occupancy,
484 Run(kernel);
485 }
486 }
487
488 if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 7)
489 {
490 if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
492 {
493 const auto kernel = kernel_gemm_xdl_cshuffle_v3<
494 GridwiseGemm,
495 true,
497 minimum_occupancy,
499 Run(kernel);
500 }
501 }
502 }
503 else
504 {
// KBatch <= 1: same tail-number ladder as above.
505 if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::One)
506 {
507 const auto kernel =
508 kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
509 true,
511 minimum_occupancy,
513 Run(kernel);
514 }
515 else if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
517 {
518 const auto kernel =
519 kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
520 true,
522 minimum_occupancy,
524 Run(kernel);
525 }
526
527 if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 2)
528 {
529 if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Two)
530 {
531 const auto kernel =
532 kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
533 true,
535 minimum_occupancy,
537 Run(kernel);
538 }
539 }
540
541 if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 3)
542 {
543 if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
545 {
546 const auto kernel =
547 kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
548 true,
550 minimum_occupancy,
552 Run(kernel);
553 }
554 }
555
556 if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 4)
557 {
558 if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
560 {
561 const auto kernel =
562 kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
563 true,
565 minimum_occupancy,
567 Run(kernel);
568 }
569 }
570
571 if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 5)
572 {
573 if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
575 {
576 const auto kernel =
577 kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
578 true,
580 minimum_occupancy,
582 Run(kernel);
583 }
584 }
585
586 if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 6)
587 {
588 if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Six)
589 {
590 const auto kernel =
591 kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
592 true,
594 minimum_occupancy,
596 Run(kernel);
597 }
598 }
599
600 if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 7)
601 {
602 if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
604 {
605 const auto kernel =
606 kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
607 true,
609 minimum_occupancy,
611 Run(kernel);
612 }
613 }
614 }
615 }
616 // Tail number could be Odd or Even
617 else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v4)
618 {
// Pipeline v4 uses the kernel_gemm_xdl_cshuffle_v3_2lds entry point where the kernel name is visible.
619 if(arg.KBatch > 1)
620 {
621 if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
622 {
623 const auto kernel = kernel_gemm_xdl_cshuffle_v3_2lds<
624 GridwiseGemm,
625 true,
627 minimum_occupancy,
629 Run(kernel);
630 }
631 else
632 {
633 const auto kernel = kernel_gemm_xdl_cshuffle_v3_2lds<
634 GridwiseGemm,
635 true,
637 minimum_occupancy,
639 Run(kernel);
640 }
641 }
642 else
643 {
644 if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
645 {
// NOTE(review): the kernel name line (source line 647) is absent from this listing.
646 const auto kernel =
648 true,
650 minimum_occupancy,
652 Run(kernel);
653 }
654 else
655 {
// NOTE(review): the kernel name line (source line 657) is absent from this listing.
656 const auto kernel =
658 true,
660 minimum_occupancy,
662 Run(kernel);
663 }
664 }
665 }
666 else
667 {
// Remaining pipeline versions: choose between an Odd-tail kernel and the alternative.
668 if(arg.KBatch > 1)
669 {
670 if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
671 {
672 const auto kernel =
673 kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
674 true,
676 minimum_occupancy,
678 Run(kernel);
679 }
680 else
681 {
682 const auto kernel =
683 kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
684 true,
686 minimum_occupancy,
688 Run(kernel);
689 }
690 }
691 else
692 {
693 if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
694 {
695 const auto kernel =
696 kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
697 true,
699 minimum_occupancy,
701 Run(kernel);
702 }
703 else
704 {
705 const auto kernel =
706 kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
707 true,
709 minimum_occupancy,
711 Run(kernel);
712 }
713 }
714 }
715 }
716 else
717 {
// NOTE(review): without a main K loop only pipeline v1 launches a kernel in this listing;
// for any other pipeline version nothing runs and the function returns 0.
718 // Tail number always 1
719 if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1)
720 {
721 if(arg.KBatch > 1)
722 {
723 const auto kernel =
724 kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
725 false,
727 minimum_occupancy>;
728 Run(kernel);
729 }
730 else
731 {
732 const auto kernel =
733 kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
734 false,
736 minimum_occupancy>;
737 Run(kernel);
738 }
739 }
740 }
741
742 return ave_time;
743 }
744
746 // polymorphic
747 float Run(const BaseArgument* p_arg,
748 const StreamConfig& stream_config = StreamConfig{}) override
749 {
750 return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
751 }
752 };
753
/// Compile-time hook reporting whether this instantiation's template parameters are acceptable.
/// Currently a placeholder: every configuration is reported as valid.
static constexpr bool IsValidCompilationParameter()
{
    // TODO: properly implement this check
    constexpr bool is_valid = true;
    return is_valid;
}
759
// Returns whether `arg` can be run by this instantiation on the current device.
// Each visible guard below rejects the argument by returning false.
// NOTE(review): several condition and return lines are absent from this listing (gaps in the
// embedded line numbers at 762, 768, 784, 805 and 811-812), so the first guard, the first
// split-K guard, and the calls whose results are returned in the warp-size branches cannot be
// confirmed here.
760 static bool IsSupportedArgument(const Argument& arg)
761 {
// NOTE(review): condition line 762 missing.
763 {
764 return false;
765 }
// Extra restrictions that only apply to split-K (KBatch > 1).
766 if(arg.KBatch > 1)
767 {
// NOTE(review): condition line 768 missing.
769 {
770 return false;
771 }
772
// bhalf_t output needs bf16 atomic support on the device.
773 if(!is_bf16_atomic_supported() && std::is_same_v<CDataType, ck::bhalf_t>)
774 {
775 return false;
776 }
777
// One-byte C element types are rejected for split-K.
778 if(sizeof(CDataType) == 1)
779 {
780 return false;
781 }
782 }
783
// NOTE(review): condition line 784 missing; under it, f8_t / bf8_t A data types are rejected.
785 {
786 if constexpr(std::is_same_v<ADataType, ck::f8_t> ||
787 std::is_same_v<ADataType, ck::bf8_t>)
788 {
789 return false;
790 }
791 }
792
// K must be divisible by both AK1 and BK1 unless GemmSpec is one of the K-padding specializations.
793 if((arg.K % AK1 != 0 || arg.K % BK1 != 0) && !(GemmSpec == GemmSpecialization::MKPadding ||
794 GemmSpec == GemmSpecialization::NKPadding ||
795 GemmSpec == GemmSpecialization::MNKPadding ||
796 GemmSpec == GemmSpecialization::KPadding))
797 {
798 return false;
799 }
800
// Branch on warp size; each branch is compiled only when its NXdlPerWave value is positive.
801 if(get_warp_size() == 64)
802 {
803 if constexpr(NXdlPerWave64 > 0)
804 {
// NOTE(review): the return statement (line 805) is missing from this listing.
806 }
807 }
808 else
809 {
810 if constexpr(NXdlPerWave32 > 0)
811 {
// NOTE(review): the start of this statement (line 812) is missing; the visible tail
// reinterprets `arg` as a GridwiseGemm32::Argument.
813 reinterpret_cast<const typename GridwiseGemm32::Argument&>(arg));
814 }
815 }
// Reached when the applicable branch above did not return.
816 return false;
817 }
818
819 // polymorphic
820 bool IsSupportedArgument(const BaseArgument* p_arg) override
821 {
822 return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
823 }
824
825 index_t GetKPerBlock() override { return KPerBlock; }
826
827 bool GetPermuteA() override { return PermuteA; }
828 bool GetPermuteB() override { return PermuteB; }
829
830 static auto MakeArgument(const ADataType* p_a,
831 const BDataType* p_b,
832 CDataType* p_c,
833 index_t M,
834 index_t N,
835 index_t K,
836 index_t StrideA,
837 index_t StrideB,
838 index_t StrideC,
839 index_t KBatch,
840 AElementwiseOperation,
841 BElementwiseOperation,
842 CElementwiseOperation)
843 {
844 return Argument{p_a, p_b, p_c, M, N, K, StrideA, StrideB, StrideC, KBatch};
845 }
846
847 static auto MakeInvoker() { return Invoker{}; }
848
849 // polymorphic
850 std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a,
851 const void* p_b,
852 void* p_c,
853 index_t M,
854 index_t N,
855 index_t K,
856 index_t StrideA,
857 index_t StrideB,
858 index_t StrideC,
859 index_t KBatch,
860 AElementwiseOperation,
861 BElementwiseOperation,
862 CElementwiseOperation) override
863 {
864 return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
865 static_cast<const BDataType*>(p_b),
866 static_cast<CDataType*>(p_c),
867 M,
868 N,
869 K,
870 StrideA,
871 StrideB,
872 StrideC,
873 KBatch);
874 }
875
876 // polymorphic
877 std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
878 {
879 return std::make_unique<Invoker>(Invoker{});
880 }
881
// Builds a one-line, human-readable description of this instantiation: GEMM specialization,
// first letter of each layout name, block size, block/wave tile sizes, vector read widths,
// pipeline scheduler/version names, prefetch stages and the A MMA K stride (labelled "Kpack").
// NOTE(review): the initializer lines of both enum-to-string maps (source lines 888-889 and
// 892-896) are absent from this listing, so the exact names they produce cannot be read here.
882 // polymorphic
883 std::string GetTypeString() const override
884 {
885 auto str = std::stringstream();
886
887 std::map<BlockGemmPipelineScheduler, std::string> BlkGemmPipelineSchedulerToString{
890
891 std::map<BlockGemmPipelineVersion, std::string> BlkGemmPipelineVersionToString{
897
// Both values stay 0 when the NXdlPerWave value for the current warp size is not positive.
898 index_t PrefetchStages = 0;
899 index_t AMmaKStride = 0;
900 if(get_warp_size() == 64)
901 {
902 if constexpr(NXdlPerWave64 > 0)
903 {
904 PrefetchStages = GridwiseGemm64::BlockwiseGemmPipe::PrefetchStages;
905 AMmaKStride = GridwiseGemm64::BlockwiseGemmPipe::AMmaKStride;
906 }
907 }
908 else
909 {
910 if constexpr(NXdlPerWave32 > 0)
911 {
912 PrefetchStages = GridwiseGemm32::BlockwiseGemmPipe::PrefetchStages;
913 AMmaKStride = GridwiseGemm32::BlockwiseGemmPipe::AMmaKStride;
914 }
915 }
916
// The map lookups below use operator[], which yields an empty string for a key that is not present.
917 // clang-format off
918 str << "DeviceGemmXdlUniversal"
919 << "<"
920 << getGemmSpecializationString(GemmSpec) << ", "
921 << std::string(ALayout::name)[0]
922 << std::string(BLayout::name)[0]
923 << std::string(CLayout::name)[0]
924 << ">"
925 << " BlkSize: "
926 << BlockSize << ", "
927 << "BlkTile: "
928 << MPerBlock<<"x"<<NPerBlock<<"x"<<KPerBlock << ", "
929 << "WaveTile: "
930 << MPerXDL<<"x"<<NPerXDL << ", "
931 << "WaveMap: "
932 << MXdlPerWave<<"x" << NXdlPerWave<<", "
933 << "VmemReadVec: "
934 << ABlockTransferSrcScalarPerVector<<"x"<<BBlockTransferSrcScalarPerVector<<", "
935 << "BlkGemmPipelineScheduler: "
936 << BlkGemmPipelineSchedulerToString[BlkGemmPipeSched] << ", "
937 << "BlkGemmPipelineVersion: "
938 << BlkGemmPipelineVersionToString[BlkGemmPipelineVer] << ", "
939 << "BlkGemmPipelinePrefetchStages: "
940 << PrefetchStages << ", "
941 << "Kpack: "
942 << AMmaKStride;
943 // clang-format on
944
945 return str.str();
946 }
948};
949
950} // namespace device
951} // namespace tensor_operation
952} // namespace ck
#define INVOKER_RUN3_IMPL
Definition device_base.hpp:114
#define REGISTER_EXTRA_PRINTING_METHODS
Definition device_base.hpp:47
#define GET_NXDL_PER_WAVE_IMPL
Definition device_base.hpp:81
float launch_and_time_kernel(const StreamConfig &stream_config, F kernel, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, Args... args)
Definition host_utility/kernel_launch.hpp:14
__host__ __device__ constexpr T max(T x)
Definition utility/math.hpp:84
Definition convolution_backward_data_specialization.hpp:8
std::string getGemmSpecializationString(const GemmSpecialization &s)
Definition gemm_specialization.hpp:32
GemmSpecialization
Definition gemm_specialization.hpp:11
@ MKPadding
Definition gemm_specialization.hpp:18
@ KPadding
Definition gemm_specialization.hpp:16
@ MNKPadding
Definition gemm_specialization.hpp:20
@ NKPadding
Definition gemm_specialization.hpp:19
Definition convolution_backward_data_specialization.hpp:7
void flush_icache()
Definition flush_cache.hpp:383
float launch_and_time_kernel_with_preprocess(const StreamConfig &stream_config, PreProcessFunc preprocess, F kernel, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, GemmArgs &gemm_args, Args... args)
Definition flush_cache.hpp:398
Definition ck.hpp:268
int32_t index_t
Definition ck.hpp:299
@ Set
Definition ck.hpp:278
@ AtomicAdd
Definition ck.hpp:279
BlockGemmPipelineVersion
Definition blkgemmpipe_scheduler.hpp:12
@ v2
Definition blkgemmpipe_scheduler.hpp:15
@ v3
Definition blkgemmpipe_scheduler.hpp:16
@ v5
Definition blkgemmpipe_scheduler.hpp:18
@ v4
Definition blkgemmpipe_scheduler.hpp:17
@ v1
Definition blkgemmpipe_scheduler.hpp:14
@ One
Definition blkgemmpipe_scheduler.hpp:37
@ Seven
Definition blkgemmpipe_scheduler.hpp:43
@ Even
Definition blkgemmpipe_scheduler.hpp:34
@ Odd
Definition blkgemmpipe_scheduler.hpp:33
@ Four
Definition blkgemmpipe_scheduler.hpp:40
@ Two
Definition blkgemmpipe_scheduler.hpp:38
@ Full
Definition blkgemmpipe_scheduler.hpp:49
@ Three
Definition blkgemmpipe_scheduler.hpp:39
@ Five
Definition blkgemmpipe_scheduler.hpp:41
@ Six
Definition blkgemmpipe_scheduler.hpp:42
__global__ void kernel_gemm_xdl_cshuffle_v3_2lds(typename GridwiseGemm::Argument karg)
Definition gridwise_gemm_xdl_cshuffle_streamk_v3.hpp:62
bool is_xdl_wmma_supported()
Definition host_utility/device_prop.hpp:76
__device__ constexpr index_t get_warp_size()
Definition get_id.hpp:10
constexpr bool is_same_v
Definition type.hpp:283
BlockGemmPipelineScheduler
Definition blkgemmpipe_scheduler.hpp:25
@ Intrawave
Definition blkgemmpipe_scheduler.hpp:26
@ Interwave
Definition blkgemmpipe_scheduler.hpp:27
__global__ void kernel_gemm_xdl_cshuffle_v3(typename GridwiseGemm::Argument karg)
Definition gridwise_gemm_xdl_cshuffle_streamk_v3.hpp:38
bool is_gfx11_supported()
Definition host_utility/device_prop.hpp:60
bool is_bf16_atomic_supported()
Definition host_utility/device_prop.hpp:108
Definition ck/stream_config.hpp:10
"Universal" GEMM kernel with SplitK support.
Definition gridwise_gemm_xdl_cshuffle_v3.hpp:247
Definition data_type.hpp:187
Definition device_base.hpp:197
Helper structure responsible for kernel invocation.
Definition device_gemm_xdl_cshuffle_v3.hpp:263
float RunImp(const typename GridwiseGemm::Argument &arg, const StreamConfig &stream_config=StreamConfig{})
This function issues GPU kernel execution.
Definition device_gemm_xdl_cshuffle_v3.hpp:270
INVOKER_RUN3_IMPL float Run(const BaseArgument *p_arg, const StreamConfig &stream_config=StreamConfig{}) override
Definition device_gemm_xdl_cshuffle_v3.hpp:747
"Universal" GEMM operation with SplitK support.
Definition device_gemm_xdl_cshuffle_v3.hpp:178
std::unique_ptr< BaseArgument > MakeArgumentPointer(const void *p_a, const void *p_b, void *p_c, index_t M, index_t N, index_t K, index_t StrideA, index_t StrideB, index_t StrideC, index_t KBatch, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation) override
Definition device_gemm_xdl_cshuffle_v3.hpp:850
GridwiseGemmBase< NXdlPerWave32 > GridwiseGemm32
Definition device_gemm_xdl_cshuffle_v3.hpp:235
bool IsSupportedArgument(const BaseArgument *p_arg) override
Definition device_gemm_xdl_cshuffle_v3.hpp:820
bool GetPermuteA() override
Definition device_gemm_xdl_cshuffle_v3.hpp:827
index_t GetKPerBlock() override
Definition device_gemm_xdl_cshuffle_v3.hpp:825
std::unique_ptr< BaseInvoker > MakeInvokerPointer() override
Definition device_gemm_xdl_cshuffle_v3.hpp:877
static bool IsSupportedArgument(const Argument &arg)
Definition device_gemm_xdl_cshuffle_v3.hpp:760
static constexpr bool IsValidCompilationParameter()
Definition device_gemm_xdl_cshuffle_v3.hpp:754
GridwiseGemmBase< math::max(NXdlPerWave64, 1)> GridwiseGemm64
Definition device_gemm_xdl_cshuffle_v3.hpp:234
std::string GetTypeString() const override
Definition device_gemm_xdl_cshuffle_v3.hpp:883
static GET_NXDL_PER_WAVE_IMPL constexpr auto NXdlPerWave64
Definition device_gemm_xdl_cshuffle_v3.hpp:181
GridwiseGemm_xdl_cshuffle_v3< ALayout, BLayout, CLayout, ADataType, BDataType, GemmAccDataType, CShuffleDataType, CDataType, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, GemmSpec, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, BlkGemmPipeSched, BlkGemmPipelineVer, ComputeTypeA, ComputeTypeB, PermuteA, PermuteB > GridwiseGemmBase
Definition device_gemm_xdl_cshuffle_v3.hpp:185
static auto MakeArgument(const ADataType *p_a, const BDataType *p_b, CDataType *p_c, index_t M, index_t N, index_t K, index_t StrideA, index_t StrideB, index_t StrideC, index_t KBatch, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation)
Definition device_gemm_xdl_cshuffle_v3.hpp:830
bool GetPermuteB() override
Definition device_gemm_xdl_cshuffle_v3.hpp:828
static auto MakeInvoker()
Definition device_gemm_xdl_cshuffle_v3.hpp:847
static constexpr index_t BPackedSize
Definition device_gemm_xdl_cshuffle_v3.hpp:246
typename GridwiseGemm64::Argument Argument
Definition device_gemm_xdl_cshuffle_v3.hpp:237
static constexpr auto NXdlPerWave32
Definition device_gemm_xdl_cshuffle_v3.hpp:182
static constexpr index_t APackedSize
Definition device_gemm_xdl_cshuffle_v3.hpp:239
Definition device_gemm_v2.hpp:22
Definition flush_cache.hpp:299