device_gemm_xdl_waveletmodel_cshuffle.hpp Source File

device_gemm_xdl_waveletmodel_cshuffle.hpp Source File

Composable Kernel: device_gemm_xdl_waveletmodel_cshuffle.hpp Source File
device_gemm_xdl_waveletmodel_cshuffle.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
6#include <iostream>
7#include <sstream>
8
19
20namespace ck {
21
22template <typename GridwiseGemm,
23 typename ABDataType,
24 typename EDataType,
25 typename AElementwiseOperation,
26 typename BElementwiseOperation,
27 typename EElementwiseOperation,
28 typename AGridDesc_AK0_M_AK1,
29 typename BGridDesc_BK0_N_BK1,
30 typename EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
31 typename Block2ETileMap,
32 bool HasMainKBlockLoop>
33__global__ void
34#if CK_USE_LAUNCH_BOUNDS
36#endif
37 kernel_gemm_xdl_waveletmodel_cshuffle(const ABDataType* __restrict__ p_a_grid,
38 const ABDataType* __restrict__ p_b_grid,
39 EDataType* __restrict__ p_e_grid,
40 const AElementwiseOperation a_element_op,
41 const BElementwiseOperation b_element_op,
42 const EElementwiseOperation e_element_op,
43 const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1,
44 const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1,
45 const EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
46 e_grid_desc_mblock_mperblock_nblock_nperblock,
47 const Block2ETileMap block_2_etile_map)
48{
49#if defined(__gfx9__) || defined(__gfx11__) || defined(__gfx12__)
50 if constexpr(GridwiseGemm::template IsValidCompilationParameter<>())
51 {
52 __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
53
54 GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid,
55 p_b_grid,
56 p_e_grid,
57 p_shared,
58 a_element_op,
59 b_element_op,
60 e_element_op,
61 a_grid_desc_ak0_m_ak1,
62 b_grid_desc_bk0_n_bk1,
63 e_grid_desc_mblock_mperblock_nblock_nperblock,
64 block_2_etile_map);
65 }
66#else
67 ignore = p_a_grid;
68 ignore = p_b_grid;
69 ignore = p_e_grid;
70 ignore = a_element_op;
71 ignore = b_element_op;
72 ignore = e_element_op;
73 ignore = a_grid_desc_ak0_m_ak1;
74 ignore = b_grid_desc_bk0_n_bk1;
75 ignore = e_grid_desc_mblock_mperblock_nblock_nperblock;
76 ignore = block_2_etile_map;
77#endif
78}
79
80} // namespace ck
81
82namespace ck {
83namespace tensor_operation {
84namespace device {
85
86template <typename ALayout,
87 typename BLayout,
88 typename ELayout,
89 typename ADataType,
90 typename BDataType,
91 typename GemmAcEDataType,
92 typename CShuffleDataType,
93 typename EDataType,
94 typename AElementwiseOperation,
95 typename BElementwiseOperation,
96 typename CDEElementwiseOperation,
97 GemmSpecialization GemmSpec,
98 index_t NumGemmKPrefetchStage,
99 index_t TileLoadThreadGroupSize,
100 index_t TileMathThreadGroupSize,
101 index_t MPerBlock,
102 index_t NPerBlock,
103 index_t KPerBlock,
104 index_t AK1,
105 index_t BK1,
106 index_t MPerXDL,
107 index_t NPerXDL,
108 index_t MXdlPerWave,
109 index_t NXdlPerWave,
110 typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
111 typename ABlockTransferThreadClusterArrangeOrder,
112 typename ABlockTransferSrcAccessOrder,
113 index_t ABlockTransferSrcVectorDim,
114 index_t ABlockTransferSrcScalarPerVector,
115 index_t ABlockTransferDstScalarPerVector_AK1,
116 bool ABlockLdsExtraM,
117 typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
118 typename BBlockTransferThreadClusterArrangeOrder,
119 typename BBlockTransferSrcAccessOrder,
120 index_t BBlockTransferSrcVectorDim,
121 index_t BBlockTransferSrcScalarPerVector,
122 index_t BBlockTransferDstScalarPerVector_BK1,
123 bool BBlockLdsExtraN,
124 index_t CShuffleMXdlPerWavePerShuffle,
125 index_t CShuffleNXdlPerWavePerShuffle,
126 typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
127 index_t CShuffleBlockTransferScalarPerVector_NPerBlock>
129 BLayout,
130 ELayout,
131 ADataType,
132 BDataType,
133 EDataType,
134 AElementwiseOperation,
135 BElementwiseOperation,
136 CDEElementwiseOperation>
137{
138 static constexpr auto BlockSize = math::max(TileLoadThreadGroupSize, TileMathThreadGroupSize);
140 static constexpr auto NXdlPerWave64 = GetNXdlPerWave<true>();
141 static constexpr auto NXdlPerWave32 = GetNXdlPerWave<false>();
142
144
145 static constexpr auto I0 = Number<0>{};
146 static constexpr auto I1 = Number<1>{};
147 static constexpr auto I2 = Number<2>{};
148
149 static constexpr auto matrix_padder =
150 MatrixPadder<GemmSpec, index_t, index_t, index_t>{MPerBlock, NPerBlock, KPerBlock};
151
152 static auto MakeAGridDescriptor_M_K(index_t MRaw, index_t KRaw, index_t StrideA)
153 {
154 const auto a_grid_desc_mraw_kraw = [&]() {
156 {
157 return make_naive_tensor_descriptor(make_tuple(MRaw, KRaw),
158 make_tuple(StrideA, I1));
159 }
161 {
162 return make_naive_tensor_descriptor(make_tuple(MRaw, KRaw),
163 make_tuple(I1, StrideA));
164 }
165 }();
166
167 return matrix_padder.PadADescriptor_M_K(a_grid_desc_mraw_kraw);
168 }
169
170 static auto MakeBGridDescriptor_N_K(index_t KRaw, index_t NRaw, index_t StrideB)
171 {
172 const auto b_grid_desc_nraw_kraw = [&]() {
174 {
175 return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw),
176 make_tuple(I1, StrideB));
177 }
179 {
180 return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw),
181 make_tuple(StrideB, I1));
182 }
183 }();
184
185 return matrix_padder.PadBDescriptor_N_K(b_grid_desc_nraw_kraw);
186 }
187
188 template <typename ELay>
189 static auto MakeEGridDescriptor_M_N(index_t MRaw, index_t NRaw, index_t StrideE)
190 {
191 const auto e_grid_desc_mraw_nraw = [&]() {
193 {
194 return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw),
195 make_tuple(StrideE, I1));
196 }
198 {
199 return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw),
200 make_tuple(I1, StrideE));
201 }
202 }();
203
204 return matrix_padder.PadCDescriptor_M_N(e_grid_desc_mraw_nraw);
205 }
206
207 using AGridDesc_M_K = decltype(MakeAGridDescriptor_M_K(1, 1, 1));
208 using BGridDesc_N_K = decltype(MakeBGridDescriptor_N_K(1, 1, 1));
210
211 // GridwiseGemm
212 template <index_t NXdlPerWave_>
214 ADataType, // TODO: distinguish A/B datatype
215 GemmAcEDataType,
216 CShuffleDataType,
217 EDataType,
218 AElementwiseOperation,
219 BElementwiseOperation,
220 CDEElementwiseOperation,
225 NumGemmKPrefetchStage,
226 TileLoadThreadGroupSize,
227 TileMathThreadGroupSize,
228 MPerBlock,
229 NPerBlock,
230 KPerBlock,
231 AK1,
232 BK1,
233 MPerXDL,
234 NPerXDL,
235 MXdlPerWave,
236 NXdlPerWave_,
237 ABlockTransferThreadClusterLengths_AK0_M_AK1,
238 ABlockTransferThreadClusterArrangeOrder,
239 ABlockTransferSrcAccessOrder,
240 ABlockTransferSrcVectorDim,
241 ABlockTransferSrcScalarPerVector,
242 ABlockTransferDstScalarPerVector_AK1,
243 false,
244 ABlockLdsExtraM,
245 BBlockTransferThreadClusterLengths_BK0_N_BK1,
246 BBlockTransferThreadClusterArrangeOrder,
247 BBlockTransferSrcAccessOrder,
248 BBlockTransferSrcVectorDim,
249 BBlockTransferSrcScalarPerVector,
250 BBlockTransferDstScalarPerVector_BK1,
251 false,
252 BBlockLdsExtraN,
253 CShuffleMXdlPerWavePerShuffle,
254 CShuffleNXdlPerWavePerShuffle,
255 CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
256 CShuffleBlockTransferScalarPerVector_NPerBlock>;
259
262 AGridDesc_M_K{}))>;
265 BGridDesc_N_K{}))>;
266
268
269 // Argument
270 struct Argument : public BaseArgument
271 {
272 Argument(const ADataType* p_a_grid,
273 const BDataType* p_b_grid,
274 EDataType* p_e_grid,
275 index_t MRaw,
276 index_t NRaw,
277 index_t KRaw,
278 index_t StrideA,
279 index_t StrideB,
280 index_t StrideE,
281 AElementwiseOperation a_element_op,
282 BElementwiseOperation b_element_op,
283 CDEElementwiseOperation cde_element_op)
284 : p_a_grid_{static_cast<const ADataType*>(p_a_grid)},
285 p_b_grid_{static_cast<const BDataType*>(p_b_grid)},
286 p_e_grid_{static_cast<EDataType*>(p_e_grid)},
289 e_grid_desc_m_n_{DeviceOp::MakeEGridDescriptor_M_N<ELayout>(MRaw, NRaw, StrideE)},
291 GridwiseGemm64::MakeDefaultAGridDescriptor_AK0_M_AK1(a_grid_desc_m_k_)},
293 GridwiseGemm64::MakeDefaultBGridDescriptor_BK0_N_BK1(b_grid_desc_n_k_)},
294 block_2_etile_map_{GridwiseGemm64::MakeDefaultBlock2ETileMap(e_grid_desc_m_n_)},
295 a_element_op_{a_element_op},
296 b_element_op_{b_element_op},
297 cde_element_op_{cde_element_op}
298 {
299 }
300
301 void Print() const
302 {
303 std::cout << "A[M, K]: " << a_grid_desc_m_k_ << std::endl;
304 std::cout << "B[N, K]: " << b_grid_desc_n_k_ << std::endl;
305 std::cout << "E[M, N]: " << e_grid_desc_m_n_ << std::endl;
306 }
307
308 // private:
309 // pointers
310 const ADataType* p_a_grid_;
311 const BDataType* p_b_grid_;
312 EDataType* p_e_grid_;
313
314 // tensor descriptors for problem definiton
318
319 // tensor descriptors for block/thread-wise copy
322
323 // block-to-e-tile map
325
326 // element-wise op
327 AElementwiseOperation a_element_op_;
328 BElementwiseOperation b_element_op_;
329 CDEElementwiseOperation cde_element_op_;
330 };
331
332 // Invoker
333 struct Invoker : public BaseInvoker
334 {
336
337 template <typename GridwiseGemm>
338 float RunImp(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
339 {
340#if 0
341 {
342 std::cout << "arg.a_grid_desc_ak0_m_ak1_{"
343 << arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) << ", "
344 << arg.a_grid_desc_ak0_m_ak1_.GetLength(I1) << ", "
345 << arg.a_grid_desc_ak0_m_ak1_.GetLength(I2) << "}" << std::endl;
346
347 std::cout << "arg.b_grid_desc_bk0_n_bk1_{"
348 << arg.b_grid_desc_bk0_n_bk1_.GetLength(I0) << ", "
349 << arg.b_grid_desc_bk0_n_bk1_.GetLength(I1) << ", "
350 << arg.b_grid_desc_bk0_n_bk1_.GetLength(I2) << "}" << std::endl;
351
352 std::cout << "arg.e_grid_desc_m_n_{ " << arg.e_grid_desc_m_n_.GetLength(I0) << ", "
353 << arg.e_grid_desc_m_n_.GetLength(I1) << "}" << std::endl;
354 }
355#endif
356
357 if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_m_k_,
361 {
362 throw std::runtime_error("wrong! GridwiseGemm has invalid setting");
363 }
364 auto e_grid_desc_mblock_mperblock_nblock_nperblock =
365 GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
366 arg.e_grid_desc_m_n_);
367
368 const index_t grid_size = GridwiseGemm::CalculateGridSize(arg.e_grid_desc_m_n_);
369 const auto K = arg.a_grid_desc_m_k_.GetLength(I1);
370
371 auto launch_kernel = [&](auto has_main_k_block_loop) {
372 constexpr bool has_main_loop = has_main_k_block_loop.value;
373
374 const auto kernel = kernel_gemm_xdl_waveletmodel_cshuffle<
375 GridwiseGemm,
376 ADataType, // TODO: distiguish A/B datatype
377 EDataType,
378 AElementwiseOperation,
379 BElementwiseOperation,
380 CDEElementwiseOperation,
383 typename GridwiseGemm::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
384 typename GridwiseGemm::DefaultBlock2ETileMap,
385 has_main_loop>;
386
388 stream_config,
389 kernel,
390 dim3(grid_size),
391 dim3(TileLoadThreadGroupSize + TileMathThreadGroupSize),
392 0,
393 arg.p_a_grid_,
394 arg.p_b_grid_,
395 arg.p_e_grid_,
396 arg.a_element_op_,
397 arg.b_element_op_,
398 arg.cde_element_op_,
401 e_grid_desc_mblock_mperblock_nblock_nperblock,
403 };
404
405 if(GridwiseGemm::CalculateHasMainKBlockLoop(K))
406 {
407 return launch_kernel(integral_constant<bool, true>{});
408 }
409 else
410 {
411 return launch_kernel(integral_constant<bool, false>{});
412 }
413 }
414
416
417 // polymorphic
418 float Run(const BaseArgument* p_arg,
419 const StreamConfig& stream_config = StreamConfig{}) override
420 {
421 return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
422 }
423 };
424
425 static bool IsSupportedArgument(const Argument& arg)
426 {
428 {
429 return false;
430 }
431 if(get_warp_size() == 64)
432 {
433 if constexpr(NXdlPerWave64 > 0)
434 {
439 }
440 }
441 else
442 {
443 if constexpr(NXdlPerWave32 > 0)
444 {
449 }
450 }
451 return false;
452 }
453
454 // polymorphic
455 bool IsSupportedArgument(const BaseArgument* p_arg) override
456 {
457 return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
458 }
459
460 static auto MakeArgument(const ADataType* p_a,
461 const BDataType* p_b,
462 EDataType* p_e,
463 index_t MRaw,
464 index_t NRaw,
465 index_t KRaw,
466 index_t StrideA,
467 index_t StrideB,
468 index_t StrideE,
469 AElementwiseOperation a_element_op,
470 BElementwiseOperation b_element_op,
471 CDEElementwiseOperation cde_element_op)
472 {
473 return Argument{p_a,
474 p_b,
475 p_e,
476 MRaw,
477 NRaw,
478 KRaw,
479 StrideA,
480 StrideB,
481 StrideE,
482 a_element_op,
483 b_element_op,
484 cde_element_op};
485 }
486
487 static auto MakeInvoker() { return Invoker{}; }
488
489 // polymorphic
490 std::unique_ptr<BaseArgument>
491 MakeArgumentPointer(const void* p_a,
492 const void* p_b,
493 void* p_e,
494 index_t MRaw,
495 index_t NRaw,
496 index_t KRaw,
497 index_t StrideA,
498 index_t StrideB,
499 index_t StrideE,
500 AElementwiseOperation a_element_op,
501 BElementwiseOperation b_element_op,
502 CDEElementwiseOperation cde_element_op) override
503 {
504 return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
505 static_cast<const BDataType*>(p_b),
506 static_cast<EDataType*>(p_e),
507 MRaw,
508 NRaw,
509 KRaw,
510 StrideA,
511 StrideB,
512 StrideE,
513 a_element_op,
514 b_element_op,
515 cde_element_op);
516 }
517
518 // polymorphic
519 std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
520 {
521 return std::make_unique<Invoker>(Invoker{});
522 }
523
524 // polymorphic
525 std::string GetTypeString() const override
526 {
527 auto str = std::stringstream();
528
529 // clang-format off
530 str << "DeviceGemm_Xdl_WaveletModel_CShuffle"
531 << "<"
532 << TileLoadThreadGroupSize << ", "
533 << TileMathThreadGroupSize << ", "
534 << MPerBlock << ", "
535 << NPerBlock << ", "
536 << KPerBlock << ", "
537 << AK1 << ", "
538 << BK1
539 << ">";
540 // clang-format on
541
542 return str.str();
543 }
544};
545
546} // namespace device
547} // namespace tensor_operation
548} // namespace ck
#define CK_WAVELET_MIN_BLOCK_PER_CU
Definition ck.hpp:35
#define CK_WAVELET_MAX_THREAD_PER_BLOCK
Definition ck.hpp:34
#define GET_NXDL_PER_WAVE_IMPL
Definition device_base.hpp:81
#define INVOKER_RUN_IMPL
Definition device_base.hpp:94
float launch_and_time_kernel(const StreamConfig &stream_config, F kernel, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, Args... args)
Definition host_utility/kernel_launch.hpp:14
__host__ __device__ constexpr T max(T x)
Definition utility/math.hpp:84
Definition convolution_backward_data_specialization.hpp:8
GemmSpecialization
Definition gemm_specialization.hpp:11
Definition convolution_backward_data_specialization.hpp:7
CK_TILE_HOST float launch_kernel(const stream_config &s, Callables &&... callables)
Definition tile/host/kernel_launch.hpp:173
Definition ck.hpp:268
int32_t index_t
Definition ck.hpp:299
__host__ __device__ constexpr auto make_naive_tensor_descriptor(const Tuple< Lengths... > &lengths, const Tuple< Strides... > &strides)
Definition tensor_descriptor_helper.hpp:49
@ Set
Definition ck.hpp:278
__global__ void kernel_gemm_xdl_waveletmodel_cshuffle(const ABDataType *__restrict__ p_a_grid, const ABDataType *__restrict__ p_b_grid, EDataType *__restrict__ p_e_grid, const AElementwiseOperation a_element_op, const BElementwiseOperation b_element_op, const EElementwiseOperation e_element_op, const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1, const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1, const EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock e_grid_desc_mblock_mperblock_nblock_nperblock, const Block2ETileMap block_2_etile_map)
Definition device_gemm_xdl_waveletmodel_cshuffle.hpp:37
remove_cv_t< remove_reference_t< T > > remove_cvref_t
Definition type.hpp:297
integral_constant< index_t, N > Number
Definition number.hpp:12
constexpr detail::ignore_t ignore
Definition utility/ignore.hpp:20
bool is_xdl_wmma_supported()
Definition host_utility/device_prop.hpp:76
__device__ constexpr index_t get_warp_size()
Definition get_id.hpp:10
constexpr bool is_same_v
Definition type.hpp:283
__host__ __device__ constexpr auto make_tuple(Xs &&... xs)
Definition utility/tuple.hpp:211
Definition ck/stream_config.hpp:10
Definition gridwise_gemm_xdl_waveletmodel_cshuffle.hpp:64
static constexpr value_type value
Definition utility/integral_constant.hpp:13
Definition utility/integral_constant.hpp:20
Definition device_base.hpp:197
Definition device_gemm_xdl_waveletmodel_cshuffle.hpp:271
AElementwiseOperation a_element_op_
Definition device_gemm_xdl_waveletmodel_cshuffle.hpp:327
BGridDesc_N_K b_grid_desc_n_k_
Definition device_gemm_xdl_waveletmodel_cshuffle.hpp:316
const BDataType * p_b_grid_
Definition device_gemm_xdl_waveletmodel_cshuffle.hpp:311
Block2ETileMap block_2_etile_map_
Definition device_gemm_xdl_waveletmodel_cshuffle.hpp:324
CDEElementwiseOperation cde_element_op_
Definition device_gemm_xdl_waveletmodel_cshuffle.hpp:329
const ADataType * p_a_grid_
Definition device_gemm_xdl_waveletmodel_cshuffle.hpp:310
BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_
Definition device_gemm_xdl_waveletmodel_cshuffle.hpp:321
BElementwiseOperation b_element_op_
Definition device_gemm_xdl_waveletmodel_cshuffle.hpp:328
EDataType * p_e_grid_
Definition device_gemm_xdl_waveletmodel_cshuffle.hpp:312
void Print() const
Definition device_gemm_xdl_waveletmodel_cshuffle.hpp:301
AGridDesc_M_K a_grid_desc_m_k_
Definition device_gemm_xdl_waveletmodel_cshuffle.hpp:315
Argument(const ADataType *p_a_grid, const BDataType *p_b_grid, EDataType *p_e_grid, index_t MRaw, index_t NRaw, index_t KRaw, index_t StrideA, index_t StrideB, index_t StrideE, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CDEElementwiseOperation cde_element_op)
Definition device_gemm_xdl_waveletmodel_cshuffle.hpp:272
AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_
Definition device_gemm_xdl_waveletmodel_cshuffle.hpp:320
EGridDesc_M_N e_grid_desc_m_n_
Definition device_gemm_xdl_waveletmodel_cshuffle.hpp:317
Definition device_gemm_xdl_waveletmodel_cshuffle.hpp:334
float Run(const BaseArgument *p_arg, const StreamConfig &stream_config=StreamConfig{}) override
Definition device_gemm_xdl_waveletmodel_cshuffle.hpp:418
float RunImp(const Argument &arg, const StreamConfig &stream_config=StreamConfig{})
Definition device_gemm_xdl_waveletmodel_cshuffle.hpp:338
DeviceOp::Argument Argument
Definition device_gemm_xdl_waveletmodel_cshuffle.hpp:335
Definition device_gemm_xdl_waveletmodel_cshuffle.hpp:137
static constexpr auto matrix_padder
Definition device_gemm_xdl_waveletmodel_cshuffle.hpp:149
remove_cvref_t< decltype(GridwiseGemm64::MakeDefaultBGridDescriptor_BK0_N_BK1( BGridDesc_N_K{}))> BGridDesc_BK0_N_BK1
Definition device_gemm_xdl_waveletmodel_cshuffle.hpp:263
static constexpr auto I1
Definition device_gemm_xdl_waveletmodel_cshuffle.hpp:146
static auto MakeInvoker()
Definition device_gemm_xdl_waveletmodel_cshuffle.hpp:487
static auto MakeBGridDescriptor_N_K(index_t KRaw, index_t NRaw, index_t StrideB)
Definition device_gemm_xdl_waveletmodel_cshuffle.hpp:170
remove_cvref_t< decltype(GridwiseGemm64::MakeDefaultAGridDescriptor_AK0_M_AK1( AGridDesc_M_K{}))> AGridDesc_AK0_M_AK1
Definition device_gemm_xdl_waveletmodel_cshuffle.hpp:260
GridwiseGemm_k0mk1_k0nk1_mn_xdl_waveletmodel_cshuffle< ADataType, GemmAcEDataType, CShuffleDataType, EDataType, AElementwiseOperation, BElementwiseOperation, CDEElementwiseOperation, InMemoryDataOperationEnum::Set, AGridDesc_M_K, BGridDesc_N_K, EGridDesc_M_N, NumGemmKPrefetchStage, TileLoadThreadGroupSize, TileMathThreadGroupSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock > GridwiseGemmBase
Definition device_gemm_xdl_waveletmodel_cshuffle.hpp:213
static constexpr auto I2
Definition device_gemm_xdl_waveletmodel_cshuffle.hpp:147
static constexpr auto I0
Definition device_gemm_xdl_waveletmodel_cshuffle.hpp:145
std::unique_ptr< BaseArgument > MakeArgumentPointer(const void *p_a, const void *p_b, void *p_e, index_t MRaw, index_t NRaw, index_t KRaw, index_t StrideA, index_t StrideB, index_t StrideE, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CDEElementwiseOperation cde_element_op) override
Definition device_gemm_xdl_waveletmodel_cshuffle.hpp:491
static constexpr auto NXdlPerWave32
Definition device_gemm_xdl_waveletmodel_cshuffle.hpp:141
static constexpr auto NXdlPerWave64
Definition device_gemm_xdl_waveletmodel_cshuffle.hpp:140
decltype(MakeEGridDescriptor_M_N< ELayout >(1, 1, 1)) EGridDesc_M_N
Definition device_gemm_xdl_waveletmodel_cshuffle.hpp:209
bool IsSupportedArgument(const BaseArgument *p_arg) override
Definition device_gemm_xdl_waveletmodel_cshuffle.hpp:455
static auto MakeAGridDescriptor_M_K(index_t MRaw, index_t KRaw, index_t StrideA)
Definition device_gemm_xdl_waveletmodel_cshuffle.hpp:152
std::unique_ptr< BaseInvoker > MakeInvokerPointer() override
Definition device_gemm_xdl_waveletmodel_cshuffle.hpp:519
typename GridwiseGemm64::DefaultBlock2ETileMap Block2ETileMap
Definition device_gemm_xdl_waveletmodel_cshuffle.hpp:267
decltype(MakeAGridDescriptor_M_K(1, 1, 1)) AGridDesc_M_K
Definition device_gemm_xdl_waveletmodel_cshuffle.hpp:207
static bool IsSupportedArgument(const Argument &arg)
Definition device_gemm_xdl_waveletmodel_cshuffle.hpp:425
static auto MakeEGridDescriptor_M_N(index_t MRaw, index_t NRaw, index_t StrideE)
Definition device_gemm_xdl_waveletmodel_cshuffle.hpp:189
static constexpr auto BlockSize
Definition device_gemm_xdl_waveletmodel_cshuffle.hpp:138
GridwiseGemmBase< NXdlPerWave32 > GridwiseGemm32
Definition device_gemm_xdl_waveletmodel_cshuffle.hpp:258
DeviceGemm_Xdl_WaveletModel_CShuffle DeviceOp
Definition device_gemm_xdl_waveletmodel_cshuffle.hpp:143
static auto MakeArgument(const ADataType *p_a, const BDataType *p_b, EDataType *p_e, index_t MRaw, index_t NRaw, index_t KRaw, index_t StrideA, index_t StrideB, index_t StrideE, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CDEElementwiseOperation cde_element_op)
Definition device_gemm_xdl_waveletmodel_cshuffle.hpp:460
decltype(MakeBGridDescriptor_N_K(1, 1, 1)) BGridDesc_N_K
Definition device_gemm_xdl_waveletmodel_cshuffle.hpp:208
std::string GetTypeString() const override
Definition device_gemm_xdl_waveletmodel_cshuffle.hpp:525
GridwiseGemmBase< math::max(NXdlPerWave64, 1)> GridwiseGemm64
Definition device_gemm_xdl_waveletmodel_cshuffle.hpp:257
Definition device_gemm.hpp:22
Definition matrix_padder.hpp:180