23template <
typename GridwiseGemmWelford,
26 typename EMeanVarDataType,
27 typename AElementwiseOperation,
28 typename BElementwiseOperation,
29 typename CDEElementwiseOperation,
30 typename AGridDesc_AK0_M_AK1,
31 typename BGridDesc_BK0_N_BK1,
32 typename DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
33 typename EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
34 typename MeanVarGridDescriptor_MBlock_MPerBlock_NBlock,
35 typename CountGridDescriptor_MBlock_MPerBlock_NBlock,
36 typename Block2ETileMap,
37 bool HasMainKBlockLoop>
39#if CK_USE_LAUNCH_BOUNDS
43 const ABDataType* __restrict__ p_a_grid,
44 const ABDataType* __restrict__ p_b_grid,
46 EMeanVarDataType* __restrict__ p_e_grid,
47 EMeanVarDataType* __restrict__ p_welford_mean_grid,
48 EMeanVarDataType* __restrict__ p_welford_var_grid,
49 int32_t* __restrict__ p_welford_count_grid,
50 const AElementwiseOperation a_element_op,
51 const BElementwiseOperation b_element_op,
52 const CDEElementwiseOperation cde_element_op,
53 const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1,
54 const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1,
55 const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
56 ds_grid_desc_mblock_mperblock_nblock_nperblock,
57 const EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
58 e_grid_desc_mblock_mperblock_nblock_nperblock,
59 const MeanVarGridDescriptor_MBlock_MPerBlock_NBlock
60 mean_var_grid_desc_mblock_mperblock_nblock,
61 const CountGridDescriptor_MBlock_MPerBlock_NBlock count_grid_desc_mblock_mperblock_nblock,
62 const Block2ETileMap block_2_etile_map,
65#if defined(__gfx9__) || defined(__gfx12__)
66 if constexpr(GridwiseGemmWelford::template IsValidCompilationParameter<>())
68 __shared__
char p_shared[GridwiseGemmWelford::GetSharedMemoryNumberOfByte()];
70 GridwiseGemmWelford::template Run<HasMainKBlockLoop>(
82 a_grid_desc_ak0_m_ak1,
83 b_grid_desc_bk0_n_bk1,
84 ds_grid_desc_mblock_mperblock_nblock_nperblock,
85 e_grid_desc_mblock_mperblock_nblock_nperblock,
86 mean_var_grid_desc_mblock_mperblock_nblock,
87 count_grid_desc_mblock_mperblock_nblock,
96 ignore = p_welford_mean_grid;
97 ignore = p_welford_var_grid;
98 ignore = p_welford_count_grid;
102 ignore = a_grid_desc_ak0_m_ak1;
103 ignore = b_grid_desc_bk0_n_bk1;
104 ignore = ds_grid_desc_mblock_mperblock_nblock_nperblock;
105 ignore = e_grid_desc_mblock_mperblock_nblock_nperblock;
106 ignore = mean_var_grid_desc_mblock_mperblock_nblock;
107 ignore = count_grid_desc_mblock_mperblock_nblock;
108 ignore = block_2_etile_map;
113template <
typename GridwiseWelfordLayernorm,
114 typename EMeanVarDataType,
116 typename GammaDataType,
117 typename BetaDataType,
118 typename ComputeDataType,
119 typename EHGridDesc_M_N,
120 typename LayernormMeanVarGridDesc_M_NBlock,
121 typename LayernormCountGridDesc_M_NBlock,
122 typename GammaBetaGridDesc_N,
123 typename HElementwiseOperation>
125#if CK_USE_LAUNCH_BOUNDS
129 const EMeanVarDataType* __restrict__ p_e_grid,
130 const EMeanVarDataType* __restrict__ p_in_welford_mean_grid,
131 const EMeanVarDataType* __restrict__ p_in_welford_var_grid,
132 const int32_t* __restrict__ p_in_welford_count_grid,
133 const GammaDataType* __restrict__ p_gamma_grid,
134 const BetaDataType* __restrict__ p_beta_grid,
135 HDataType* __restrict__ p_h_grid,
136 const EHGridDesc_M_N e_grid_desc_m_n,
137 const EHGridDesc_M_N h_grid_desc_m_n,
138 const LayernormMeanVarGridDesc_M_NBlock mean_var_grid_desc_m_nblock,
139 const LayernormCountGridDesc_M_NBlock count_grid_desc_m_nblock,
140 const GammaBetaGridDesc_N gamma_grid_desc_n,
141 const GammaBetaGridDesc_N beta_grid_desc_n,
142 index_t numMeanVarCountBlockTileIteration_N,
144 ComputeDataType epsilon,
145 HElementwiseOperation h_element_op)
147 GridwiseWelfordLayernorm::Run(p_e_grid,
148 p_in_welford_mean_grid,
149 p_in_welford_var_grid,
150 p_in_welford_count_grid,
156 mean_var_grid_desc_m_nblock,
157 count_grid_desc_m_nblock,
160 numMeanVarCountBlockTileIteration_N,
184template <
typename ALayout,
190 typename AccDataType,
191 typename CShuffleDataType,
193 typename EMeanVarDataType,
194 typename GammaDataType,
195 typename BetaDataType,
197 typename AElementwiseOperation,
198 typename BElementwiseOperation,
199 typename CDEElementwiseOperation,
200 typename HElementwiseOperation,
213 typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
214 typename ABlockTransferThreadClusterArrangeOrder,
215 typename ABlockTransferSrcAccessOrder,
216 index_t ABlockTransferSrcVectorDim,
217 index_t ABlockTransferSrcScalarPerVector,
218 index_t ABlockTransferDstScalarPerVector_AK1,
219 bool ABlockLdsExtraM,
220 typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
221 typename BBlockTransferThreadClusterArrangeOrder,
222 typename BBlockTransferSrcAccessOrder,
223 index_t BBlockTransferSrcVectorDim,
224 index_t BBlockTransferSrcScalarPerVector,
225 index_t BBlockTransferDstScalarPerVector_BK1,
226 bool BBlockLdsExtraN,
227 index_t CShuffleMXdlPerWavePerShuffle,
228 index_t CShuffleNXdlPerWavePerShuffle,
229 typename PostShuffleThreadClusterSize_M_N,
230 index_t PostShuffleScalarPerVector,
231 typename LayernormThreadClusterSize_M_N,
232 index_t LayernormThreadSliceSize_M,
246 AElementwiseOperation,
247 BElementwiseOperation,
248 CDEElementwiseOperation,
249 HElementwiseOperation>
271 Sequence<LayernormThreadClusterSize_M_N::At(0) * LayernormThreadSliceSize_M,
283 const auto a_grid_desc_mraw_kraw = [&]() {
296 return matrix_padder.PadADescriptor_M_K(a_grid_desc_mraw_kraw);
301 const auto b_grid_desc_nraw_kraw = [&]() {
314 return matrix_padder.PadBDescriptor_N_K(b_grid_desc_nraw_kraw);
317 template <
typename DoPads, index_t MPerTile, index_t NPerTile>
321 const auto grid_desc_m_n =
327 const std::array<index_t, NumDTensor>& NRaws,
328 const std::array<index_t, NumDTensor>& DsStride)
336 MakeEHGridDescriptor_M_N<Sequence<true, true>, MPerBlock, NPerBlock>(
337 MRaws[i], NRaws[i], DsStride[i]);
342 template <
typename DoPads, index_t MPerTile, index_t NPerTile>
345 const auto grid_desc_m_n =
350 template <
typename DoPads, index_t MPerTile, index_t NPerTile>
355 const auto grid_desc_m_n =
360 template <index_t XPerTile>
391 template <index_t NXdlPerWave_>
398 AElementwiseOperation,
399 BElementwiseOperation,
400 CDEElementwiseOperation,
408 NumGemmKPrefetchStage,
419 ABlockTransferThreadClusterLengths_AK0_M_AK1,
420 ABlockTransferThreadClusterArrangeOrder,
421 ABlockTransferSrcAccessOrder,
422 ABlockTransferSrcVectorDim,
423 ABlockTransferSrcScalarPerVector,
424 ABlockTransferDstScalarPerVector_AK1,
427 BBlockTransferThreadClusterLengths_BK0_N_BK1,
428 BBlockTransferThreadClusterArrangeOrder,
429 BBlockTransferSrcAccessOrder,
430 BBlockTransferSrcVectorDim,
431 BBlockTransferSrcScalarPerVector,
432 BBlockTransferDstScalarPerVector_BK1,
435 CShuffleMXdlPerWavePerShuffle,
436 CShuffleNXdlPerWavePerShuffle,
437 PostShuffleThreadClusterSize_M_N,
438 PostShuffleScalarPerVector,
456 HElementwiseOperation,
458 LayernormThreadClusterSize_M_N::At(
I0),
459 LayernormThreadClusterSize_M_N::At(
I1),
460 LayernormThreadSliceSize_M,
471 const void* p_b_grid,
472 std::array<const void*, NumDTensor> p_ds_grid,
473 const void* p_gamma_grid,
474 const void* p_beta_grid,
481 std::array<index_t, NumDTensor> StrideDs,
484 AElementwiseOperation a_element_op,
485 BElementwiseOperation b_element_op,
486 CDEElementwiseOperation cde_element_op,
487 HElementwiseOperation h_element_op)
488 :
p_a_grid_{static_cast<const ADataType*>(p_a_grid)},
489 p_b_grid_{static_cast<const BDataType*>(p_b_grid)},
495 p_gamma_grid_{static_cast<const GammaDataType*>(p_gamma_grid)},
496 p_beta_grid_{static_cast<const BetaDataType*>(p_beta_grid)},
497 p_h_grid_{static_cast<HDataType*>(p_h_grid)},
503 MRaw, NRaw, StrideH)},
508 MRaw, NRaw, StrideH)},
521 MRaw, NRaw, StrideH)},
536 epsilon_{static_cast<AccDataType>(epsilon)}
564 p_ds_grid_(i) =
static_cast<const DDataType*
>(p_ds_grid[i]);
569 MRaw, NRaw, StrideDs[i]);
636 [&](
auto i) { std::cout <<
"Ds[M, N]: " <<
ds_grid_desc_m_n_[i] << std::endl; });
700 template <
typename Gr
idwiseGemmWelford>
711 throw std::runtime_error(
"wrong! GridwiseGemmWelford has invalid setting");
715 throw std::runtime_error(
"wrong! WorkSpace pointer has not been set");
724 auto launch_kernel = [&](
auto has_main_k_block_loop) {
725 constexpr bool has_main_loop = has_main_k_block_loop.value;
727 const auto kernel_gemm_welford =
731 typename GridwiseGemmWelford::DsGridPointer,
733 AElementwiseOperation,
734 BElementwiseOperation,
735 CDEElementwiseOperation,
736 typename GridwiseGemmWelford::DefaultAGridDesc_AK0_M_AK1,
737 typename GridwiseGemmWelford::DefaultBGridDesc_BK0_N_BK1,
738 typename GridwiseGemmWelford::
739 DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
740 typename GridwiseGemmWelford::
741 EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
742 typename GridwiseGemmWelford::MeanVarGridDescriptor_MBlock_MPerBlock_NBlock,
743 typename GridwiseGemmWelford::CountGridDescriptor_MBlock_MPerBlock_NBlock,
744 typename GridwiseGemmWelford::DefaultBlock2ETileMap,
747 const auto kernel_welford_layernorm =
758 HElementwiseOperation>;
789 grid_size = MBlockClusterLength * NBlockClusterLength;
796 kernel_welford_layernorm,
813 numMeanVarCountBlockTileIteration_N,
821 if(GridwiseGemmWelford::CalculateHasMainKBlockLoop(K))
827 return launch_kernel(integral_constant<bool, false>{});
837 return Run(*
dynamic_cast<const Argument*
>(p_arg), stream_config);
845 size_t workspace_size = 0;
850 workspace_size += gemm_welford_size *
sizeof(EMeanVarDataType) + 64;
853 workspace_size += gemm_welford_size *
sizeof(EMeanVarDataType) + 64;
859 workspace_size += pArg_->
MRaw_ * pArg_->
NRaw_ *
sizeof(EMeanVarDataType);
861 return (workspace_size);
868 Argument* pArg_ =
dynamic_cast<Argument*
>(pArg);
870 pArg_->p_workspace_ = p_workspace;
872 int gemm_welford_size = pArg_->MRaw_ * pArg_->gemm_nblock_;
875 pArg_->p_workspace_mean_ =
static_cast<char*
>(pArg_->p_workspace_);
877 index_t mean_space_sz = gemm_welford_size *
sizeof(EMeanVarDataType);
881 pArg_->p_workspace_var_ =
reinterpret_cast<char*
>(pArg_->p_workspace_mean_) + mean_space_sz;
883 index_t variance_space_sz = gemm_welford_size *
sizeof(EMeanVarDataType);
887 pArg_->p_workspace_count_ =
888 reinterpret_cast<char*
>(pArg_->p_workspace_var_) + variance_space_sz;
894 pArg_->p_workspace_e_grid_ =
895 reinterpret_cast<char*
>(pArg_->p_workspace_count_) + count_space_sz;
897 pArg_->p_workspace_e_grid_ =
static_cast<void*
>(pArg_->p_h_grid_);
914 if(arg.
KRaw_ % ABlockTransferSrcScalarPerVector != 0)
922 if(arg.
MRaw_ % ABlockTransferSrcScalarPerVector != 0)
935 if(arg.
KRaw_ % BBlockTransferSrcScalarPerVector != 0)
943 if(arg.
NRaw_ % BBlockTransferSrcScalarPerVector != 0)
955 bool all_valid =
true;
975 if(arg.
NRaw_ % PostShuffleScalarPerVector != 0 ||
1028 std::array<const void*, NumDTensor> p_ds,
1029 const void* p_gamma,
1037 std::array<index_t, NumDTensor> StrideDs,
1040 AElementwiseOperation a_element_op,
1041 BElementwiseOperation b_element_op,
1042 CDEElementwiseOperation cde_element_op,
1043 HElementwiseOperation h_element_op)
1070 std::array<const void*, NumDTensor> p_ds,
1071 const void* p_gamma,
1079 std::array<index_t, NumDTensor> StrideDs,
1082 AElementwiseOperation a_element_op,
1083 BElementwiseOperation b_element_op,
1084 CDEElementwiseOperation cde_element_op,
1085 HElementwiseOperation h_element_op)
override
1087 return std::make_unique<Argument>(p_a,
1110 return std::make_unique<Invoker>(
Invoker{});
1116 auto str = std::stringstream();
1118 std::map<LoopScheduler, std::string> LoopSchedToString{
1121 std::map<PipelineVersion, std::string> PipelineVersionToString{{
PipelineVersion::v1,
"v1"},
1125 str <<
"DeviceGemmMultipleDLayernorm_Xdl_CShuffle"
1127 << BlockSize <<
", "
1128 << MPerBlock <<
", "
1129 << NPerBlock <<
", "
1130 << KPerBlock <<
", "
1134 << PostShuffleThreadClusterSize_M_N::At(
I0) <<
", "
1135 << PostShuffleThreadClusterSize_M_N::At(
I1) <<
", "
1136 << LayernormThreadClusterSize_M_N::At(
I0) <<
", "
1137 << LayernormThreadClusterSize_M_N::At(
I1) <<
", "
1138 << LayernormThreadSliceSize_M
1140 <<
" LoopScheduler: "
1141 << LoopSchedToString[LoopSched] <<
", "
1142 <<
"PipelineVersion: "
1143 << PipelineVersionToString[PipelineVer];
#define CK_MIN_BLOCK_PER_CU
Definition ck.hpp:31
#define CK_MAX_THREAD_PER_BLOCK
Definition ck.hpp:30
#define GET_NXDL_PER_WAVE_IMPL
Definition device_base.hpp:81
#define INVOKER_RUN_IMPL
Definition device_base.hpp:94
float launch_and_time_kernel(const StreamConfig &stream_config, F kernel, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, Args... args)
Definition host_utility/kernel_launch.hpp:14
Definition utility/math.hpp:13
__host__ __device__ constexpr auto integer_least_multiple(X x, Y y)
Definition utility/math.hpp:78
__host__ __device__ constexpr T max(T x)
Definition utility/math.hpp:84
__host__ __device__ constexpr auto integer_divide_ceil(X x, Y y)
Definition utility/math.hpp:72
Definition convolution_backward_data_specialization.hpp:8
std::string getGemmSpecializationString(const GemmSpecialization &s)
Definition gemm_specialization.hpp:32
__host__ __device__ constexpr auto PadTensorDescriptor(const TensorDesc &desc, const TileLengths &tile_lengths, DoPads)
Definition matrix_padder.hpp:19
GemmSpecialization
Definition gemm_specialization.hpp:11
Definition convolution_backward_data_specialization.hpp:7
__global__ void kernel_gemm_multiple_d_welford_first_half_xdl_cshuffle(const ABDataType *__restrict__ p_a_grid, const ABDataType *__restrict__ p_b_grid, DsPointer p_ds_grid, EMeanVarDataType *__restrict__ p_e_grid, EMeanVarDataType *__restrict__ p_welford_mean_grid, EMeanVarDataType *__restrict__ p_welford_var_grid, int32_t *__restrict__ p_welford_count_grid, const AElementwiseOperation a_element_op, const BElementwiseOperation b_element_op, const CDEElementwiseOperation cde_element_op, const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1, const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1, const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock ds_grid_desc_mblock_mperblock_nblock_nperblock, const EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock e_grid_desc_mblock_mperblock_nblock_nperblock, const MeanVarGridDescriptor_MBlock_MPerBlock_NBlock mean_var_grid_desc_mblock_mperblock_nblock, const CountGridDescriptor_MBlock_MPerBlock_NBlock count_grid_desc_mblock_mperblock_nblock, const Block2ETileMap block_2_etile_map, index_t NRaw)
Definition device_gemm_multiple_d_layernorm_xdl_cshuffle.hpp:42
__global__ void kernel_welford_layernorm2d_second_half(const EMeanVarDataType *__restrict__ p_e_grid, const EMeanVarDataType *__restrict__ p_in_welford_mean_grid, const EMeanVarDataType *__restrict__ p_in_welford_var_grid, const int32_t *__restrict__ p_in_welford_count_grid, const GammaDataType *__restrict__ p_gamma_grid, const BetaDataType *__restrict__ p_beta_grid, HDataType *__restrict__ p_h_grid, const EHGridDesc_M_N e_grid_desc_m_n, const EHGridDesc_M_N h_grid_desc_m_n, const LayernormMeanVarGridDesc_M_NBlock mean_var_grid_desc_m_nblock, const LayernormCountGridDesc_M_NBlock count_grid_desc_m_nblock, const GammaBetaGridDesc_N gamma_grid_desc_n, const GammaBetaGridDesc_N beta_grid_desc_n, index_t numMeanVarCountBlockTileIteration_N, index_t NBlockClusterLength, ComputeDataType epsilon, HElementwiseOperation h_element_op)
Definition device_gemm_multiple_d_layernorm_wmma_cshuffle_v3.hpp:87
int32_t index_t
Definition ck.hpp:299
__host__ __device__ constexpr auto make_naive_tensor_descriptor(const Tuple< Lengths... > &lengths, const Tuple< Strides... > &strides)
Definition tensor_descriptor_helper.hpp:49
@ Set
Definition ck.hpp:278
remove_cv_t< remove_reference_t< T > > remove_cvref_t
Definition type.hpp:297
integral_constant< index_t, N > Number
Definition number.hpp:12
typename tuple_element< I, TTuple >::type tuple_element_t
Definition utility/tuple.hpp:208
constexpr detail::ignore_t ignore
Definition utility/ignore.hpp:20
bool is_xdl_wmma_supported()
Definition host_utility/device_prop.hpp:76
__device__ constexpr index_t get_warp_size()
Definition get_id.hpp:10
constexpr bool is_same_v
Definition type.hpp:283
__host__ __device__ constexpr auto generate_tuple(F &&f, Number< N >)
Definition tuple_helper.hpp:21
__host__ __device__ constexpr auto make_naive_tensor_descriptor_packed(const Tuple< Lengths... > &lengths)
Definition tensor_descriptor_helper.hpp:101
__host__ __device__ constexpr auto make_tuple(Xs &&... xs)
Definition utility/tuple.hpp:211
LoopScheduler
Definition loop_scheduler.hpp:15
@ Default
Definition loop_scheduler.hpp:16
@ Interwave
Definition loop_scheduler.hpp:17
PipelineVersion
Definition gridwise_gemm_pipeline_selector.hpp:18
@ v2
Definition gridwise_gemm_pipeline_selector.hpp:20
@ v1
Definition gridwise_gemm_pipeline_selector.hpp:19
constexpr LoopScheduler make_default_loop_scheduler()
Definition loop_scheduler.hpp:20
signed int int32_t
Definition stdint.h:123
Definition ck/stream_config.hpp:10
Definition gridwise_gemm_multiple_d_welford_first_half_xdl_cshuffle.hpp:84
ck::GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle< ADataType, AccDataType, CShuffleDataType, DsDataType, EMeanVarDataType, AElementwiseOperation, BElementwiseOperation, CDEElementwiseOperation, InMemoryDataOperationEnum::Set, AGridDesc_M_K, BGridDesc_N_K, DsGridDesc_M_N, EHGridDesc_M_N, GemmMeanVarGridDesc_M_NBlock, GemmCountGridDesc_M_NBlock, NumGemmKPrefetchStage, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, PostShuffleThreadClusterSize_M_N, PostShuffleScalarPerVector, LoopSched, PipelineVer >::CountGridDescriptor_MBlock_MPerBlock_NBlock remove_cvref_t< decltype(MakeMeanVarCountGridDescriptor_MBlock_MPerBlock_NBlock(GemmCountGridDesc_M_NBlock{}))> CountGridDescriptor_MBlock_MPerBlock_NBlock
Definition gridwise_gemm_multiple_d_welford_first_half_xdl_cshuffle.hpp:357
ck::GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle< ADataType, AccDataType, CShuffleDataType, DsDataType, EMeanVarDataType, AElementwiseOperation, BElementwiseOperation, CDEElementwiseOperation, InMemoryDataOperationEnum::Set, AGridDesc_M_K, BGridDesc_N_K, DsGridDesc_M_N, EHGridDesc_M_N, GemmMeanVarGridDesc_M_NBlock, GemmCountGridDesc_M_NBlock, NumGemmKPrefetchStage, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, PostShuffleThreadClusterSize_M_N, PostShuffleScalarPerVector, LoopSched, PipelineVer >::DsGridPointer decltype(MakeDsGridPointer()) DsGridPointer
Definition gridwise_gemm_multiple_d_welford_first_half_xdl_cshuffle.hpp:367
ck::GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle< ADataType, AccDataType, CShuffleDataType, DsDataType, EMeanVarDataType, AElementwiseOperation, BElementwiseOperation, CDEElementwiseOperation, InMemoryDataOperationEnum::Set, AGridDesc_M_K, BGridDesc_N_K, DsGridDesc_M_N, EHGridDesc_M_N, GemmMeanVarGridDesc_M_NBlock, GemmCountGridDesc_M_NBlock, NumGemmKPrefetchStage, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, PostShuffleThreadClusterSize_M_N, PostShuffleScalarPerVector, LoopSched, PipelineVer >::DefaultBGridDesc_BK0_N_BK1 remove_cvref_t< decltype(MakeDefaultBGridDescriptor_BK0_N_BK1(BGridDesc_N_K{}))> DefaultBGridDesc_BK0_N_BK1
Definition gridwise_gemm_multiple_d_welford_first_half_xdl_cshuffle.hpp:349
ck::GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle< ADataType, AccDataType, CShuffleDataType, DsDataType, EMeanVarDataType, AElementwiseOperation, BElementwiseOperation, CDEElementwiseOperation, InMemoryDataOperationEnum::Set, AGridDesc_M_K, BGridDesc_N_K, DsGridDesc_M_N, EHGridDesc_M_N, GemmMeanVarGridDesc_M_NBlock, GemmCountGridDesc_M_NBlock, NumGemmKPrefetchStage, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, PostShuffleThreadClusterSize_M_N, PostShuffleScalarPerVector, LoopSched, PipelineVer >::DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock remove_cvref_t< decltype(MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(DsGridDesc_M_N{}))> DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
Definition gridwise_gemm_multiple_d_welford_first_half_xdl_cshuffle.hpp:360
ck::GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle< ADataType, AccDataType, CShuffleDataType, DsDataType, EMeanVarDataType, AElementwiseOperation, BElementwiseOperation, CDEElementwiseOperation, InMemoryDataOperationEnum::Set, AGridDesc_M_K, BGridDesc_N_K, DsGridDesc_M_N, EHGridDesc_M_N, GemmMeanVarGridDesc_M_NBlock, GemmCountGridDesc_M_NBlock, NumGemmKPrefetchStage, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, PostShuffleThreadClusterSize_M_N, PostShuffleScalarPerVector, LoopSched, PipelineVer >::MeanVarGridDescriptor_MBlock_MPerBlock_NBlock remove_cvref_t< decltype(MakeMeanVarCountGridDescriptor_MBlock_MPerBlock_NBlock(GemmMeanVarGridDesc_M_NBlock{}))> MeanVarGridDescriptor_MBlock_MPerBlock_NBlock
Definition gridwise_gemm_multiple_d_welford_first_half_xdl_cshuffle.hpp:354
ck::GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle< ADataType, AccDataType, CShuffleDataType, DsDataType, EMeanVarDataType, AElementwiseOperation, BElementwiseOperation, CDEElementwiseOperation, InMemoryDataOperationEnum::Set, AGridDesc_M_K, BGridDesc_N_K, DsGridDesc_M_N, EHGridDesc_M_N, GemmMeanVarGridDesc_M_NBlock, GemmCountGridDesc_M_NBlock, NumGemmKPrefetchStage, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, PostShuffleThreadClusterSize_M_N, PostShuffleScalarPerVector, LoopSched, PipelineVer >::CheckValidity __host__ static __device__ constexpr bool CheckValidity(const AGridDesc_M_K &a_grid_desc_m_k, const BGridDesc_N_K &b_grid_desc_n_k, const DsGridDesc_M_N &ds_grid_desc_m_n, const EHGridDesc_M_N &e_grid_desc_m_n, const Block2ETileMap &block_2_etile_map)
Definition gridwise_gemm_multiple_d_welford_first_half_xdl_cshuffle.hpp:274
ck::GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle< ADataType, AccDataType, CShuffleDataType, DsDataType, EMeanVarDataType, AElementwiseOperation, BElementwiseOperation, CDEElementwiseOperation, InMemoryDataOperationEnum::Set, AGridDesc_M_K, BGridDesc_N_K, DsGridDesc_M_N, EHGridDesc_M_N, GemmMeanVarGridDesc_M_NBlock, GemmCountGridDesc_M_NBlock, NumGemmKPrefetchStage, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, PostShuffleThreadClusterSize_M_N, PostShuffleScalarPerVector, LoopSched, PipelineVer >::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock __host__ static __device__ constexpr auto MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const EGridDescriptor_M_N &e_grid_desc_m_n)
Definition gridwise_gemm_multiple_d_welford_first_half_xdl_cshuffle.hpp:212
ck::GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle< ADataType, AccDataType, CShuffleDataType, DsDataType, EMeanVarDataType, AElementwiseOperation, BElementwiseOperation, CDEElementwiseOperation, InMemoryDataOperationEnum::Set, AGridDesc_M_K, BGridDesc_N_K, DsGridDesc_M_N, EHGridDesc_M_N, GemmMeanVarGridDesc_M_NBlock, GemmCountGridDesc_M_NBlock, NumGemmKPrefetchStage, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, PostShuffleThreadClusterSize_M_N, PostShuffleScalarPerVector, LoopSched, PipelineVer >::DefaultBlock2ETileMap remove_cvref_t< decltype(MakeDefaultBlock2ETileMap(EHGridDesc_M_N{}))> DefaultBlock2ETileMap
Definition gridwise_gemm_multiple_d_welford_first_half_xdl_cshuffle.hpp:364
ck::GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle< ADataType, AccDataType, CShuffleDataType, DsDataType, EMeanVarDataType, AElementwiseOperation, BElementwiseOperation, CDEElementwiseOperation, InMemoryDataOperationEnum::Set, AGridDesc_M_K, BGridDesc_N_K, DsGridDesc_M_N, EHGridDesc_M_N, GemmMeanVarGridDesc_M_NBlock, GemmCountGridDesc_M_NBlock, NumGemmKPrefetchStage, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, PostShuffleThreadClusterSize_M_N, PostShuffleScalarPerVector, LoopSched, PipelineVer >::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock __host__ static __device__ constexpr auto MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const DsGridDescriptor_M_N &ds_grid_desc_m_n)
Definition gridwise_gemm_multiple_d_welford_first_half_xdl_cshuffle.hpp:234
ck::GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle< ADataType, AccDataType, CShuffleDataType, DsDataType, EMeanVarDataType, AElementwiseOperation, BElementwiseOperation, CDEElementwiseOperation, InMemoryDataOperationEnum::Set, AGridDesc_M_K, BGridDesc_N_K, DsGridDesc_M_N, EHGridDesc_M_N, GemmMeanVarGridDesc_M_NBlock, GemmCountGridDesc_M_NBlock, NumGemmKPrefetchStage, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, PostShuffleThreadClusterSize_M_N, PostShuffleScalarPerVector, LoopSched, PipelineVer >::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock remove_cvref_t< decltype(MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(EHGridDesc_M_N{}))> EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
Definition gridwise_gemm_multiple_d_welford_first_half_xdl_cshuffle.hpp:351
ck::GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle< ADataType, AccDataType, CShuffleDataType, DsDataType, EMeanVarDataType, AElementwiseOperation, BElementwiseOperation, CDEElementwiseOperation, InMemoryDataOperationEnum::Set, AGridDesc_M_K, BGridDesc_N_K, DsGridDesc_M_N, EHGridDesc_M_N, GemmMeanVarGridDesc_M_NBlock, GemmCountGridDesc_M_NBlock, NumGemmKPrefetchStage, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, PostShuffleThreadClusterSize_M_N, PostShuffleScalarPerVector, LoopSched, PipelineVer >::DefaultAGridDesc_AK0_M_AK1 remove_cvref_t< decltype(MakeDefaultAGridDescriptor_AK0_M_AK1(AGridDesc_M_K{}))> DefaultAGridDesc_AK0_M_AK1
Definition gridwise_gemm_multiple_d_welford_first_half_xdl_cshuffle.hpp:347
ck::GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle< ADataType, AccDataType, CShuffleDataType, DsDataType, EMeanVarDataType, AElementwiseOperation, BElementwiseOperation, CDEElementwiseOperation, InMemoryDataOperationEnum::Set, AGridDesc_M_K, BGridDesc_N_K, DsGridDesc_M_N, EHGridDesc_M_N, GemmMeanVarGridDesc_M_NBlock, GemmCountGridDesc_M_NBlock, NumGemmKPrefetchStage, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, PostShuffleThreadClusterSize_M_N, PostShuffleScalarPerVector, LoopSched, PipelineVer >::MakeMeanVarCountGridDescriptor_MBlock_MPerBlock_NBlock __host__ static __device__ constexpr auto MakeMeanVarCountGridDescriptor_MBlock_MPerBlock_NBlock(const GridDescriptor_M_N &grid_desc_m_n)
Definition gridwise_gemm_multiple_d_welford_first_half_xdl_cshuffle.hpp:246
Definition gridwise_welford_second_half_layernorm2d.hpp:42
Definition utility/sequence.hpp:43
__host__ static __device__ constexpr index_t At(index_t I)
Definition utility/sequence.hpp:53
static constexpr value_type value
Definition utility/integral_constant.hpp:13
Definition utility/integral_constant.hpp:20
Definition functional2.hpp:33
Definition tensor_operation/gpu/device/tensor_layout.hpp:31
Definition tensor_operation/gpu/device/tensor_layout.hpp:26
Definition device_base.hpp:197
Definition device_gemm_multiple_d_layernorm_xdl_cshuffle.hpp:469
GridwiseGemmWelford64::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock e_grid_desc_mblock_mperblock_nblock_nperblock_
Definition device_gemm_multiple_d_layernorm_xdl_cshuffle.hpp:674
AElementwiseOperation a_element_op_
Definition device_gemm_multiple_d_layernorm_xdl_cshuffle.hpp:684
CDEElementwiseOperation cde_element_op_
Definition device_gemm_multiple_d_layernorm_xdl_cshuffle.hpp:686
index_t gemm_nblock_
Definition device_gemm_multiple_d_layernorm_xdl_cshuffle.hpp:692
const ADataType * p_a_grid_
Definition device_gemm_multiple_d_layernorm_xdl_cshuffle.hpp:643
void * p_workspace_var_
Definition device_gemm_multiple_d_layernorm_xdl_cshuffle.hpp:648
EHGridDesc_M_N gemm_e_grid_desc_m_n_
Definition device_gemm_multiple_d_layernorm_xdl_cshuffle.hpp:658
EHGridDesc_M_N h_grid_desc_m_n_
Definition device_gemm_multiple_d_layernorm_xdl_cshuffle.hpp:666
GemmMeanVarGridDesc_M_NBlock gemm_mean_var_grid_desc_m_nblock_
Definition device_gemm_multiple_d_layernorm_xdl_cshuffle.hpp:660
GridwiseGemmWelford64::DefaultAGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_
Definition device_gemm_multiple_d_layernorm_xdl_cshuffle.hpp:669
Argument(const void *p_a_grid, const void *p_b_grid, std::array< const void *, NumDTensor > p_ds_grid, const void *p_gamma_grid, const void *p_beta_grid, void *p_h_grid, index_t MRaw, index_t NRaw, index_t KRaw, index_t StrideA, index_t StrideB, std::array< index_t, NumDTensor > StrideDs, index_t StrideH, double epsilon, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CDEElementwiseOperation cde_element_op, HElementwiseOperation h_element_op)
Definition device_gemm_multiple_d_layernorm_xdl_cshuffle.hpp:470
index_t MRaw_
Definition device_gemm_multiple_d_layernorm_xdl_cshuffle.hpp:689
const BetaDataType * p_beta_grid_
Definition device_gemm_multiple_d_layernorm_xdl_cshuffle.hpp:651
void Print() const
Definition device_gemm_multiple_d_layernorm_xdl_cshuffle.hpp:631
HElementwiseOperation h_element_op_
Definition device_gemm_multiple_d_layernorm_xdl_cshuffle.hpp:687
EHGridDesc_M_N layernorm_e_grid_desc_m_n_
Definition device_gemm_multiple_d_layernorm_xdl_cshuffle.hpp:659
index_t NRaw_
Definition device_gemm_multiple_d_layernorm_xdl_cshuffle.hpp:690
AccDataType epsilon_
Definition device_gemm_multiple_d_layernorm_xdl_cshuffle.hpp:693
DsGridDesc_M_N ds_grid_desc_m_n_
Definition device_gemm_multiple_d_layernorm_xdl_cshuffle.hpp:657
BElementwiseOperation b_element_op_
Definition device_gemm_multiple_d_layernorm_xdl_cshuffle.hpp:685
GridwiseGemmWelford64::CountGridDescriptor_MBlock_MPerBlock_NBlock gemm_count_grid_desc_mblock_mperblock_nblock_
Definition device_gemm_multiple_d_layernorm_xdl_cshuffle.hpp:678
const GammaDataType * p_gamma_grid_
Definition device_gemm_multiple_d_layernorm_xdl_cshuffle.hpp:650
index_t KRaw_
Definition device_gemm_multiple_d_layernorm_xdl_cshuffle.hpp:691
LayernormCountGridDesc_M_NBlock layernorm_count_grid_desc_m_nblock_
Definition device_gemm_multiple_d_layernorm_xdl_cshuffle.hpp:663
GemmCountGridDesc_M_NBlock gemm_count_grid_desc_m_nblock_
Definition device_gemm_multiple_d_layernorm_xdl_cshuffle.hpp:661
BGridDesc_N_K b_grid_desc_n_k_
Definition device_gemm_multiple_d_layernorm_xdl_cshuffle.hpp:656
void * p_workspace_count_
Definition device_gemm_multiple_d_layernorm_xdl_cshuffle.hpp:649
const BDataType * p_b_grid_
Definition device_gemm_multiple_d_layernorm_xdl_cshuffle.hpp:644
GammaBetaGridDesc_N beta_grid_desc_n_
Definition device_gemm_multiple_d_layernorm_xdl_cshuffle.hpp:665
GammaBetaGridDesc_N gamma_grid_desc_n_
Definition device_gemm_multiple_d_layernorm_xdl_cshuffle.hpp:664
GridwiseGemmWelford64::MeanVarGridDescriptor_MBlock_MPerBlock_NBlock gemm_mean_var_grid_desc_mblock_mperblock_nblock_
Definition device_gemm_multiple_d_layernorm_xdl_cshuffle.hpp:676
Block2ETileMap block_2_etile_map_
Definition device_gemm_multiple_d_layernorm_xdl_cshuffle.hpp:681
GridwiseGemmWelford64::DsGridPointer p_ds_grid_
Definition device_gemm_multiple_d_layernorm_xdl_cshuffle.hpp:645
void * p_workspace_mean_
Definition device_gemm_multiple_d_layernorm_xdl_cshuffle.hpp:647
GridwiseGemmWelford64::DefaultBGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_
Definition device_gemm_multiple_d_layernorm_xdl_cshuffle.hpp:670
HDataType * p_h_grid_
Definition device_gemm_multiple_d_layernorm_xdl_cshuffle.hpp:652
AGridDesc_M_K a_grid_desc_m_k_
Definition device_gemm_multiple_d_layernorm_xdl_cshuffle.hpp:655
void * p_workspace_e_grid_
Definition device_gemm_multiple_d_layernorm_xdl_cshuffle.hpp:646
LayernormMeanVarGridDesc_M_NBlock layernorm_mean_var_grid_desc_m_nblock_
Definition device_gemm_multiple_d_layernorm_xdl_cshuffle.hpp:662
GridwiseGemmWelford64::DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock ds_grid_desc_mblock_mperblock_nblock_nperblock_
Definition device_gemm_multiple_d_layernorm_xdl_cshuffle.hpp:672
Definition device_gemm_multiple_d_layernorm_xdl_cshuffle.hpp:698
GridwiseGemmWelford32 GridwiseGemm32
Definition device_gemm_multiple_d_layernorm_xdl_cshuffle.hpp:830
float RunImp(const Argument &arg, const StreamConfig &stream_config=StreamConfig{})
Definition device_gemm_multiple_d_layernorm_xdl_cshuffle.hpp:701
DeviceOp::Argument Argument
Definition device_gemm_multiple_d_layernorm_xdl_cshuffle.hpp:699
GridwiseGemmWelford64 GridwiseGemm64
Definition device_gemm_multiple_d_layernorm_xdl_cshuffle.hpp:831
INVOKER_RUN_IMPL float Run(const BaseArgument *p_arg, const StreamConfig &stream_config=StreamConfig{}) override
Definition device_gemm_multiple_d_layernorm_xdl_cshuffle.hpp:834
Definition device_gemm_multiple_d_layernorm_xdl_cshuffle.hpp:250
GridwiseGemmWelfordBase< NXdlPerWave32 > GridwiseGemmWelford32
Definition device_gemm_multiple_d_layernorm_xdl_cshuffle.hpp:442
static GET_NXDL_PER_WAVE_IMPL constexpr auto NXdlPerWave64
Definition device_gemm_multiple_d_layernorm_xdl_cshuffle.hpp:261
static auto MakeDescriptor_X(index_t X)
Definition device_gemm_multiple_d_layernorm_xdl_cshuffle.hpp:361
static constexpr index_t LayernormGammaSrcVectorSize
Definition device_gemm_multiple_d_layernorm_xdl_cshuffle.hpp:266
GridwiseWelfordSecondHalfLayernorm2d< EMeanVarDataType, HDataType, GammaDataType, BetaDataType, AccDataType, EHGridDesc_M_N, LayernormMeanVarGridDesc_M_NBlock, LayernormCountGridDesc_M_NBlock, GammaBetaGridDesc_N, HElementwiseOperation, BlockSize, LayernormThreadClusterSize_M_N::At(I0), LayernormThreadClusterSize_M_N::At(I1), LayernormThreadSliceSize_M, LayernormThreadSliceSize_N, LayernormESrcVectorSize, LayernormHDstVectorSize, LayernormGammaSrcVectorSize, LayernormBetaSrcVectorSize > GridwiseWelfordLayernorm
Definition device_gemm_multiple_d_layernorm_xdl_cshuffle.hpp:446
static constexpr auto NXdlPerWave32
Definition device_gemm_multiple_d_layernorm_xdl_cshuffle.hpp:262
DeviceGemmMultipleDLayernorm_Xdl_CShuffle DeviceOp
Definition device_gemm_multiple_d_layernorm_xdl_cshuffle.hpp:257
bool IsSupportedArgument(const BaseArgument *p_arg) override
Definition device_gemm_multiple_d_layernorm_xdl_cshuffle.hpp:1021
std::string GetTypeString() const override
Definition device_gemm_multiple_d_layernorm_xdl_cshuffle.hpp:1114
decltype(MakeBGridDescriptor_N_K(1, 1, 1)) BGridDesc_N_K
Definition device_gemm_multiple_d_layernorm_xdl_cshuffle.hpp:368
typename GridwiseGemmWelford64::DefaultBlock2ETileMap Block2ETileMap
Definition device_gemm_multiple_d_layernorm_xdl_cshuffle.hpp:444
Sequence< LayernormThreadClusterSize_M_N::At(0) *LayernormThreadSliceSize_M, LayernormThreadClusterSize_M_N::At(1) *LayernormThreadSliceSize_N > LayernormBlockTileSize_M_N
Definition device_gemm_multiple_d_layernorm_xdl_cshuffle.hpp:270
decltype(MakeMeanVarDescriptor_M_N< Sequence< true, true >, LayernormBlockTileSize_M_N::At(0), LayernormBlockTileSize_M_N::At(1)>(1, 1)) LayernormMeanVarGridDesc_M_NBlock
Definition device_gemm_multiple_d_layernorm_xdl_cshuffle.hpp:378
static constexpr auto I0
Definition device_gemm_multiple_d_layernorm_xdl_cshuffle.hpp:274
GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle< ADataType, AccDataType, CShuffleDataType, DsDataType, EMeanVarDataType, AElementwiseOperation, BElementwiseOperation, CDEElementwiseOperation, InMemoryDataOperationEnum::Set, AGridDesc_M_K, BGridDesc_N_K, DsGridDesc_M_N, EHGridDesc_M_N, GemmMeanVarGridDesc_M_NBlock, GemmCountGridDesc_M_NBlock, NumGemmKPrefetchStage, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, PostShuffleThreadClusterSize_M_N, PostShuffleScalarPerVector, LoopSched, PipelineVer > GridwiseGemmWelfordBase
Definition device_gemm_multiple_d_layernorm_xdl_cshuffle.hpp:392
static auto MakeMeanVarDescriptor_M_N(index_t M, index_t N)
Definition device_gemm_multiple_d_layernorm_xdl_cshuffle.hpp:343
static auto MakeInvoker()
Definition device_gemm_multiple_d_layernorm_xdl_cshuffle.hpp:1065
decltype(MakeCountDescriptor_M_N< Sequence< true, false >, MPerBlock, NPerBlock >(1, 1)) GemmCountGridDesc_M_NBlock
Definition device_gemm_multiple_d_layernorm_xdl_cshuffle.hpp:375
static constexpr auto matrix_padder
Definition device_gemm_multiple_d_layernorm_xdl_cshuffle.hpp:278
static constexpr index_t LayernormESrcVectorSize
Definition device_gemm_multiple_d_layernorm_xdl_cshuffle.hpp:268
static constexpr auto I1
Definition device_gemm_multiple_d_layernorm_xdl_cshuffle.hpp:275
static constexpr index_t NumDTensor
Definition device_gemm_multiple_d_layernorm_xdl_cshuffle.hpp:264
remove_cvref_t< decltype(MakeDsGridDescriptor_M_N({}, {}, {}))> DsGridDesc_M_N
Definition device_gemm_multiple_d_layernorm_xdl_cshuffle.hpp:369
static constexpr index_t LayernormThreadSliceSize_N
Definition device_gemm_multiple_d_layernorm_xdl_cshuffle.hpp:269
static auto MakeAGridDescriptor_M_K(index_t MRaw, index_t KRaw, index_t StrideA)
Definition device_gemm_multiple_d_layernorm_xdl_cshuffle.hpp:281
static auto MakeBGridDescriptor_N_K(index_t KRaw, index_t NRaw, index_t StrideB)
Definition device_gemm_multiple_d_layernorm_xdl_cshuffle.hpp:299
decltype(MakeEHGridDescriptor_M_N< Sequence< true, true >, 1, 1 >(1, 1, 1)) EHGridDesc_M_N
Definition device_gemm_multiple_d_layernorm_xdl_cshuffle.hpp:389
static constexpr auto I2
Definition device_gemm_multiple_d_layernorm_xdl_cshuffle.hpp:276
static auto MakeEHGridDescriptor_M_N(index_t M, index_t N, index_t Stride)
Definition device_gemm_multiple_d_layernorm_xdl_cshuffle.hpp:318
static auto MakeArgument(const void *p_a, const void *p_b, std::array< const void *, NumDTensor > p_ds, const void *p_gamma, const void *p_beta, void *p_h, index_t MRaw, index_t NRaw, index_t KRaw, index_t StrideA, index_t StrideB, std::array< index_t, NumDTensor > StrideDs, index_t StrideH, double epsilon, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CDEElementwiseOperation cde_element_op, HElementwiseOperation h_element_op)
Definition device_gemm_multiple_d_layernorm_xdl_cshuffle.hpp:1026
decltype(MakeDescriptor_X< LayernormBlockTileSize_M_N::At(1)>(1)) GammaBetaGridDesc_N
Definition device_gemm_multiple_d_layernorm_xdl_cshuffle.hpp:388
static constexpr index_t LayernormBetaSrcVectorSize
Definition device_gemm_multiple_d_layernorm_xdl_cshuffle.hpp:267
static bool IsSupportedArgument(const Argument &arg)
Definition device_gemm_multiple_d_layernorm_xdl_cshuffle.hpp:900
HLayout ELayout
Definition device_gemm_multiple_d_layernorm_xdl_cshuffle.hpp:258
static auto MakeCountDescriptor_M_N(index_t M, index_t N)
Definition device_gemm_multiple_d_layernorm_xdl_cshuffle.hpp:351
std::unique_ptr< BaseInvoker > MakeInvokerPointer() override
Definition device_gemm_multiple_d_layernorm_xdl_cshuffle.hpp:1108
GridwiseGemmWelfordBase< math::max(NXdlPerWave64, 1)> GridwiseGemmWelford64
Definition device_gemm_multiple_d_layernorm_xdl_cshuffle.hpp:441
void SetWorkSpacePointer(BaseArgument *pArg, void *p_workspace, const StreamConfig &=StreamConfig{}) const override
Definition device_gemm_multiple_d_layernorm_xdl_cshuffle.hpp:864
decltype(MakeCountDescriptor_M_N< Sequence< true, true >, LayernormBlockTileSize_M_N::At(0), LayernormBlockTileSize_M_N::At(1)>(1, 1)) LayernormCountGridDesc_M_NBlock
Definition device_gemm_multiple_d_layernorm_xdl_cshuffle.hpp:383
size_t GetWorkSpaceSize(const BaseArgument *pArg) const override
Definition device_gemm_multiple_d_layernorm_xdl_cshuffle.hpp:841
std::unique_ptr< BaseArgument > MakeArgumentPointer(const void *p_a, const void *p_b, std::array< const void *, NumDTensor > p_ds, const void *p_gamma, const void *p_beta, void *p_h, index_t MRaw, index_t NRaw, index_t KRaw, index_t StrideA, index_t StrideB, std::array< index_t, NumDTensor > StrideDs, index_t StrideH, double epsilon, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CDEElementwiseOperation cde_element_op, HElementwiseOperation h_element_op) override
Definition device_gemm_multiple_d_layernorm_xdl_cshuffle.hpp:1068
static constexpr index_t LayernormHDstVectorSize
Definition device_gemm_multiple_d_layernorm_xdl_cshuffle.hpp:265
decltype(MakeAGridDescriptor_M_K(1, 1, 1)) AGridDesc_M_K
Definition device_gemm_multiple_d_layernorm_xdl_cshuffle.hpp:367
decltype(MakeMeanVarDescriptor_M_N< Sequence< true, false >, MPerBlock, NPerBlock >(1, 1)) GemmMeanVarGridDesc_M_NBlock
Definition device_gemm_multiple_d_layernorm_xdl_cshuffle.hpp:372
static auto MakeDsGridDescriptor_M_N(const std::array< index_t, NumDTensor > &MRaws, const std::array< index_t, NumDTensor > &NRaws, const std::array< index_t, NumDTensor > &DsStride)
Definition device_gemm_multiple_d_layernorm_xdl_cshuffle.hpp:326
Definition device_gemm_multiple_d_layernorm.hpp:40
Definition matrix_padder.hpp:180