device_grouped_conv_fwd_multiple_d_multiple_r_xdl_cshuffle.hpp Source File

device_grouped_conv_fwd_multiple_d_multiple_r_xdl_cshuffle.hpp Source File#

Composable Kernel: device_grouped_conv_fwd_multiple_d_multiple_r_xdl_cshuffle.hpp Source File
device_grouped_conv_fwd_multiple_d_multiple_r_xdl_cshuffle.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
6#include <functional>
7#include <iostream>
8#include <iterator>
9#include <numeric>
10#include <sstream>
11
26
27namespace ck {
28namespace tensor_operation {
29namespace device {
30
31namespace {
32
33template <index_t NumDTensor, index_t NumRTensor>
34struct ComputePtrOffsetOfStridedBatch
35{
36 ComputePtrOffsetOfStridedBatch() = default;
37
38 ComputePtrOffsetOfStridedBatch(index_t BatchStrideA,
39 index_t BatchStrideB,
40 Array<ck::index_t, NumDTensor> BatchStrideDs,
41 index_t BatchStrideE,
42 Array<ck::index_t, NumRTensor> BatchStrideRs)
43 : BatchStrideA_(BatchStrideA),
44 BatchStrideB_(BatchStrideB),
45 BatchStrideDs_(BatchStrideDs),
46 BatchStrideE_(BatchStrideE),
47 BatchStrideRs_(BatchStrideRs)
48 {
49 }
50
51 __host__ __device__ constexpr long_index_t GetAPtrOffset(index_t g_idx) const
52 {
53 return g_idx * static_cast<long_index_t>(BatchStrideA_);
54 }
55
56 __host__ __device__ constexpr long_index_t GetBPtrOffset(index_t g_idx) const
57 {
58 return g_idx * static_cast<long_index_t>(BatchStrideB_);
59 }
60
61 __host__ __device__ constexpr auto GetDsPtrOffset(index_t g_idx) const
62 {
63 Array<long_index_t, NumDTensor> ds_offset;
64 static_for<0, NumDTensor, 1>{}(
65 [&](auto i) { ds_offset(i) = g_idx * static_cast<long_index_t>(BatchStrideDs_[i]); });
66 return ds_offset;
67 }
68
69 __host__ __device__ constexpr long_index_t GetEPtrOffset(index_t g_idx) const
70 {
71 return g_idx * static_cast<long_index_t>(BatchStrideE_);
72 }
73
74 __host__ __device__ constexpr auto GetRsPtrOffset(index_t g_idx) const
75 {
76 Array<long_index_t, NumRTensor> rs_offset;
77 static_for<0, NumRTensor, 1>{}(
78 [&](auto i) { rs_offset(i) = g_idx * static_cast<long_index_t>(BatchStrideRs_[i]); });
79 return rs_offset;
80 }
81
82 index_t BatchStrideA_;
83 index_t BatchStrideB_;
84 Array<ck::index_t, NumDTensor> BatchStrideDs_;
85 index_t BatchStrideE_;
86 Array<ck::index_t, NumRTensor> BatchStrideRs_;
87};
88
89/*
90 * \brief Wrapper function of GridwiseGemm::Run to realize BatchedGEMM.
91 *
92 * \tparam ComputePtrOffsetOfBatch Class that computes the base pointer offsets of the A, B, Ds, E
93 * and Rs tensors for a given batch (group) index. For example, ComputePtrOffsetOfStridedBatch
94 * computes the offsets of evenly strided batches, but it can easily be extended to other layouts.
95 * The returned offset can be either \p index_t or \p long_index_t. If it returns \p long_index_t,
96 * we are not subject to the 2GB limitation.
97 *
98 * \tparam Block2ETileMap Block2ETileMap::CalculateBottomIndex() takes in id of a workgroup and
99 * returns the 2D index of the tile that it computes. \see
100 * GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3::Run().
101 *
102 * \note Using \p ComputePtrOffsetOfBatch gives us the flexibility that 2 workgroups can compute 2
103 * tiles from different matrices. Keep in mind that these 2 matrices can share the same grid
104 * descriptor (like in BatchedGEMM), or use their own grid descriptors (in GroupedGemm). \link
105 * impl/device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk.hpp kernel_gemm_xdlops_v2r3_for_conv3d \endlink for
106 * \link DeviceConv3d \endlink uses the same concept, but currently does NOT encapsulate the
107 * computing of pointer offset into \p ComputePtrOffsetOfStridedBatch.
108 *
109 * \note \p Block2ETileMap allows customized mapping between a workgroup and the C-tile it computes.
110 * Together with \p ComputePtrOffsetOfBatch, we can reuse GridwiseGemm (and GridwiseGemm fusion ) to
111 * realize BatchedGemm and GroupedGemm (and the corresponding GEMM fusion).
112 *
113 */
// Batched/grouped wrapper around GridwiseGemm::Run. A single 1-D launch covers all groups:
// the grid is divided into batch_count equal, contiguous ranges of workgroups, and every
// workgroup offsets the A/B/Ds/E/Rs base pointers to its own group before running the GEMM.
114template <typename GridwiseGemm,
115 typename ABDataType,
116 typename DsPointer,
117 typename EDataType,
118 typename RsPointer,
119 typename AElementwiseOperation,
120 typename BElementwiseOperation,
121 typename CDEElementwiseOperation,
122 typename QsElementwiseOperation,
123 typename RsElementwiseOperation,
124 typename AGridDesc_AK0_M_AK1,
125 typename BGridDesc_BK0_N_BK1,
126 typename DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
127 typename EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
128 typename RsGridDescriptor_MBlock_MPerBlock,
129 typename Block2ETileMap,
130 typename ComputePtrOffsetOfBatch,
131 bool HasMainKBlockLoop>
132__global__ void
133#if CK_USE_LAUNCH_BOUNDS
// NOTE(review): listing line 134 is absent here; presumably it carried the launch-bounds
// attribute guarded by this #if — confirm against the real header.
135#endif
136 kernel_batch_gemm_multiple_d_xdl_cshuffle(
137 const ABDataType* __restrict__ p_a_grid,
138 const ABDataType* __restrict__ p_b_grid,
139 DsPointer p_ds_grid,
140 EDataType* __restrict__ p_e_grid,
141 RsPointer p_rs_grid,
142 const AElementwiseOperation a_element_op,
143 const BElementwiseOperation b_element_op,
144 const CDEElementwiseOperation cde_element_op,
145 const QsElementwiseOperation qs_element_op,
146 const RsElementwiseOperation rs_element_op,
147 const index_t batch_count,
148 const AGridDesc_AK0_M_AK1 a_grid_desc_k0_m_k1,
149 const BGridDesc_BK0_N_BK1 b_grid_desc_k0_n_k1,
150 const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
151 ds_grid_desc_mblock_mperblock_nblock_nperblock,
152 const EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
153 e_grid_desc_mblock_mperblock_nblock_nperblock_,
154 const RsGridDescriptor_MBlock_MPerBlock rs_grid_desc_mblock_mperblock,
155 const Block2ETileMap block_2_ctile_map,
156 const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch)
157{
// The real body is only compiled for the gfx9 / gfx11 / gfx12 targets; see the #else branch.
158#if defined(__gfx9__) || defined(__gfx11__) || defined(__gfx12__)
// Skip the body entirely when GridwiseGemm reports these compile-time parameters as invalid.
159 if constexpr(GridwiseGemm::template IsValidCompilationParameter<>())
160 {
// Workgroups per group = grid size / batch_count (integer division), and the group handled
// by this workgroup is its 1-D block id divided by that count.
// readfirstlane takes the first lane's value — presumably to keep these wave-uniform values
// in scalar registers; confirm against the AMDGPU builtin documentation.
161 const index_t num_blocks_per_batch =
162 __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
163 const index_t g_idx =
164 __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch);
165
// Per-group offsets for A, B and E, widened to long_index_t before being added to the
// base pointers below.
166 const long_index_t a_batch_offset = amd_wave_read_first_lane(
167 static_cast<long_index_t>(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx)));
168 const long_index_t b_batch_offset = amd_wave_read_first_lane(
169 static_cast<long_index_t>(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx)));
170 const long_index_t e_batch_offset = amd_wave_read_first_lane(
171 static_cast<long_index_t>(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx)));
172
// One offset per D tensor and per R tensor.
173 const auto ds_batch_offset = compute_ptr_offset_of_batch.GetDsPtrOffset(g_idx);
174 const auto rs_batch_offset = compute_ptr_offset_of_batch.GetRsPtrOffset(g_idx);
175
// Statically sized shared-memory scratch buffer; its size is dictated by GridwiseGemm.
176 __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
177
178 DsPointer p_ds_grid_grp;
179
180 static constexpr index_t NumDTensor =
181 DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock::Size();
182
// Advance every D pointer to this workgroup's group.
183 static_for<0, NumDTensor, 1>{}(
184 [&](auto i) { p_ds_grid_grp(i) = p_ds_grid[i] + ds_batch_offset[i]; });
185
186 RsPointer p_rs_grid_grp;
187
188 static constexpr index_t NumRTensor = RsGridDescriptor_MBlock_MPerBlock::Size();
189
// Advance every R pointer to this workgroup's group.
190 static_for<0, NumRTensor, 1>{}(
191 [&](auto i) { p_rs_grid_grp(i) = p_rs_grid[i] + rs_batch_offset[i]; });
192
// Run the GEMM on the group-local pointers; the grid descriptors are passed through
// unchanged, so all groups share the same descriptors.
193 GridwiseGemm::template Run<HasMainKBlockLoop>(
194 p_a_grid + a_batch_offset,
195 p_b_grid + b_batch_offset,
196 p_ds_grid_grp,
197 p_e_grid + e_batch_offset,
198 p_rs_grid_grp,
199 p_shared,
200 a_element_op,
201 b_element_op,
202 cde_element_op,
203 qs_element_op,
204 rs_element_op,
205 a_grid_desc_k0_m_k1,
206 b_grid_desc_k0_n_k1,
207 ds_grid_desc_mblock_mperblock_nblock_nperblock,
208 e_grid_desc_mblock_mperblock_nblock_nperblock_,
209 rs_grid_desc_mblock_mperblock,
210 block_2_ctile_map);
211 }
212#else
// Other targets: the kernel does nothing. Each parameter is assigned to `ignore`,
// presumably to silence unused-parameter warnings.
213 ignore = p_a_grid;
214 ignore = p_b_grid;
215 ignore = p_ds_grid;
216 ignore = p_e_grid;
217 ignore = p_rs_grid;
218 ignore = batch_count;
219 ignore = a_grid_desc_k0_m_k1;
220 ignore = b_grid_desc_k0_n_k1;
221 ignore = ds_grid_desc_mblock_mperblock_nblock_nperblock;
222 ignore = e_grid_desc_mblock_mperblock_nblock_nperblock_;
223 ignore = rs_grid_desc_mblock_mperblock;
224 ignore = a_element_op;
225 ignore = b_element_op;
226 ignore = cde_element_op;
227 ignore = qs_element_op;
228 ignore = rs_element_op;
229 ignore = compute_ptr_offset_of_batch;
230 ignore = block_2_ctile_map;
231#endif
232}
233
234} // namespace
235
236template <index_t NDimSpatial,
237 typename ALayout,
238 typename BLayout,
239 typename DELayout,
240 typename RLayout,
241 typename ADataType,
242 typename BDataType,
243 typename AccDataType,
244 typename CShuffleDataType,
245 typename DsDataType,
246 typename EDataType,
247 typename ReduceAccDataType,
248 typename RsDataType,
249 typename AElementwiseOperation,
250 typename BElementwiseOperation,
251 typename CDEElementwiseOperation,
252 typename QsElementwiseOperation,
253 typename RsElementwiseOperation,
254 typename ThreadReduceOperations,
255 typename RsGlobalMemoryDataOperation,
256 ConvolutionForwardSpecialization ConvForwardSpecialization,
257 GemmSpecialization GemmSpec,
258 index_t NumGemmKPrefetchStage,
259 index_t BlockSize,
260 index_t MPerBlock,
261 index_t NPerBlock,
262 index_t KPerBlock,
263 index_t AK1,
264 index_t BK1,
265 index_t MPerXDL,
266 index_t NPerXDL,
267 index_t MXdlPerWave,
268 index_t NXdlPerWave,
269 typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
270 typename ABlockTransferThreadClusterArrangeOrder,
271 typename ABlockTransferSrcAccessOrder,
272 index_t ABlockTransferSrcVectorDim,
273 index_t ABlockTransferSrcScalarPerVector,
274 index_t ABlockTransferDstScalarPerVector_AK1,
275 index_t ABlockLdsExtraM,
276 typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
277 typename BBlockTransferThreadClusterArrangeOrder,
278 typename BBlockTransferSrcAccessOrder,
279 index_t BBlockTransferSrcVectorDim,
280 index_t BBlockTransferSrcScalarPerVector,
281 index_t BBlockTransferDstScalarPerVector_BK1,
282 index_t BBlockLdsExtraN,
283 index_t CShuffleMXdlPerWavePerShuffle,
284 index_t CShuffleNXdlPerWavePerShuffle,
285 typename CDRThreadTransferClusterLengths_MPerBlock_NPerBlock,
286 index_t CDEBlockTransferScalarPerVector_NPerBlock,
287 index_t RThreadTransferDstScalarPerVector_MPerBlock,
290 : public DeviceGroupedConvFwdMultipleDMultipleR<NDimSpatial,
291 ALayout,
292 BLayout,
293 DELayout,
294 RLayout,
295 ADataType,
296 BDataType,
297 DsDataType,
298 EDataType,
299 RsDataType,
300 AElementwiseOperation,
301 BElementwiseOperation,
302 CDEElementwiseOperation,
303 RsElementwiseOperation,
304 QsElementwiseOperation>
305{
308 static constexpr auto NXdlPerWave64 = GetNXdlPerWave<true>();
309 static constexpr auto NXdlPerWave32 = GetNXdlPerWave<false>();
310
311 static constexpr index_t NumDTensor = DsDataType::Size();
312 static constexpr index_t NumRTensor = RsDataType::Size();
313
314 static constexpr auto I0 = Number<0>{};
315 static constexpr auto I1 = Number<1>{};
316 static constexpr auto I2 = Number<2>{};
317 static constexpr auto I3 = Number<3>{};
318
320
321 static constexpr auto matrix_padder =
322 MatrixPadder<GemmSpec, index_t, index_t, index_t>{MPerBlock, NPerBlock, KPerBlock};
323
324 template <typename ALay>
325 static auto MakeAGridDescriptor_M_K(const ConvToGemmFwdTransformer& conv_to_gemm_transformer)
326 {
327 const auto in_gemmmraw_gemmkraw_desc =
328 conv_to_gemm_transformer.template MakeADescriptor_M_K<ALay>();
329
330 const auto in_gemmm_gemmk_desc =
331 matrix_padder.PadADescriptor_M_K(in_gemmmraw_gemmkraw_desc);
332
333 return in_gemmm_gemmk_desc;
334 }
335
336 template <typename BLay>
337 static auto MakeBGridDescriptor_N_K(const ConvToGemmFwdTransformer& conv_to_gemm_transformer)
338 {
339 const auto wei_gemmnraw_gemmkraw_desc =
340 conv_to_gemm_transformer.template MakeBDescriptor_N_K<BLay>();
341
342 const auto wei_gemmn_gemmk_desc =
343 matrix_padder.PadBDescriptor_N_K(wei_gemmnraw_gemmkraw_desc);
344
345 return wei_gemmn_gemmk_desc;
346 }
347
348 template <typename ELay>
349 static auto MakeEGridDescriptor_M_N(const ConvToGemmFwdTransformer& conv_to_gemm_transformer)
350 {
351 const auto out_gemmmraw_gemmnraw_desc =
352 conv_to_gemm_transformer.template MakeCDescriptor_M_N<ELay>();
353
354 const auto out_gemmm_gemmn_desc =
355 matrix_padder.PadCDescriptor_M_N(out_gemmmraw_gemmnraw_desc);
356
357 return out_gemmm_gemmn_desc;
358 }
359
360 template <typename Descriptor>
361 static auto GetPaddedRGridDescriptor(Descriptor descriptor, index_t MRaw)
362 {
363 const auto M = math::integer_divide_ceil(MRaw, MPerBlock) * MPerBlock;
364 const auto MPad = M - MRaw;
365
366 if constexpr(GemmSpec == GemmSpecialization::MPadding ||
367 GemmSpec == GemmSpecialization::MNPadding ||
368 GemmSpec == GemmSpecialization::MKPadding ||
370 {
371 // pad M
372 return transform_tensor_descriptor(descriptor,
376 }
377 else
378 {
379 // not pad M
380 return descriptor;
381 }
382 }
383
384 template <typename RLay,
385 typename std::enable_if<is_same_v<RLay, tensor_layout::convolution::GNW> ||
388 bool>::type = false>
389 static auto
390 MakeRGridDescriptor_M(const std::array<index_t, NDimSpatial + 2>& r_g_n_wos_lengths,
391 const std::array<index_t, NDimSpatial + 2>& /* r_g_n_wos_strides */)
392 {
393 const index_t N = r_g_n_wos_lengths[1];
394
395 const index_t NHoWo =
397 r_g_n_wos_lengths.begin() + 2, NDimSpatial, 1, std::multiplies<>());
398
399 const auto r_grid_desc_mraw = make_naive_tensor_descriptor_packed(make_tuple(NHoWo));
400
401 return GetPaddedRGridDescriptor(r_grid_desc_mraw, NHoWo);
402 }
403
// Builds the 1-D (GemmM) descriptor of an R tensor for the layouts selected by the enable_if
// below, then pads it through GetPaddedRGridDescriptor. Unlike the packed overload, this one
// reads a stride from r_g_n_wos_strides.
404 template <typename RLay,
405 typename std::enable_if<is_same_v<RLay, tensor_layout::convolution::G_NW> ||
411 bool>::type = false>
412 static auto MakeRGridDescriptor_M(const std::array<index_t, NDimSpatial + 2>& r_g_n_wos_lengths,
413 const std::array<index_t, NDimSpatial + 2>& r_g_n_wos_strides)
414 {
415 const index_t N = r_g_n_wos_lengths[1];
416
// r_g_n_wos_strides holds NDimSpatial + 2 entries, so its last valid index is
// NDimSpatial + 1. The previous index, NDimSpatial + 2, read one element past the end of the
// std::array (undefined behaviour).
417 const index_t WoStride = r_g_n_wos_strides[NDimSpatial + 1];
418
419 const index_t NHoWo =
421 r_g_n_wos_lengths.begin() + 2, NDimSpatial, 1, std::multiplies<>());
422
423 const auto r_grid_desc_mraw =
425
426 return GetPaddedRGridDescriptor(r_grid_desc_mraw, NHoWo);
427 }
428
437
438 // GridwiseGemm
439 template <index_t NXdlPerWave_>
441 ADataType, // TODO: distinguish A/B datatype
442 AccDataType,
443 CShuffleDataType,
444 DsDataType,
445 EDataType,
446 ReduceAccDataType,
447 RsDataType,
448 AElementwiseOperation,
449 BElementwiseOperation,
450 CDEElementwiseOperation,
451 QsElementwiseOperation,
452 RsElementwiseOperation,
453 ThreadReduceOperations,
455 RsGlobalMemoryDataOperation,
460 NumGemmKPrefetchStage,
461 BlockSize,
462 MPerBlock,
463 NPerBlock,
464 KPerBlock,
465 AK1,
466 BK1,
467 MPerXDL,
468 NPerXDL,
469 MXdlPerWave,
470 NXdlPerWave_,
471 ABlockTransferThreadClusterLengths_AK0_M_AK1,
472 ABlockTransferThreadClusterArrangeOrder,
473 ABlockTransferSrcAccessOrder,
474 ABlockTransferSrcVectorDim,
475 ABlockTransferSrcScalarPerVector,
476 ABlockTransferDstScalarPerVector_AK1,
477 false,
478 ABlockLdsExtraM,
479 BBlockTransferThreadClusterLengths_BK0_N_BK1,
480 BBlockTransferThreadClusterArrangeOrder,
481 BBlockTransferSrcAccessOrder,
482 BBlockTransferSrcVectorDim,
483 BBlockTransferSrcScalarPerVector,
484 BBlockTransferDstScalarPerVector_BK1,
485 false,
486 BBlockLdsExtraN,
487 CShuffleMXdlPerWavePerShuffle,
488 CShuffleNXdlPerWavePerShuffle,
489 CDRThreadTransferClusterLengths_MPerBlock_NPerBlock,
490 CDEBlockTransferScalarPerVector_NPerBlock,
491 RThreadTransferDstScalarPerVector_MPerBlock,
492 LoopSched>;
495
498 AGridDesc_M_K{}))>;
501 BGridDesc_N_K{}))>;
502
504
505 // Argument
506 struct Argument : public BaseArgument
507 {
508 Argument(const void* p_a,
509 const void* p_b,
510 const std::array<const void*, NumDTensor>& p_ds,
511 void* p_e,
512 std::array<void*, NumRTensor> p_rs,
513 const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths,
514 const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
515 const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
516 const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_strides,
517 const std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor>&
518 ds_g_n_k_wos_lengths,
519 const std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor>&
520 ds_g_n_k_wos_strides,
521 const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths,
522 const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_strides,
523 const std::array<index_t, NDimSpatial + 2>& r_g_n_wos_lengths,
524 const std::array<index_t, NDimSpatial + 2>& r_g_n_wos_strides,
525 const std::array<index_t, NDimSpatial>& conv_filter_strides,
526 const std::array<index_t, NDimSpatial>& conv_filter_dilations,
527 const std::array<index_t, NDimSpatial>& input_left_pads,
528 const std::array<index_t, NDimSpatial>& input_right_pads,
529 const AElementwiseOperation& a_element_op,
530 const BElementwiseOperation& b_element_op,
531 const CDEElementwiseOperation& cde_element_op,
532 const QsElementwiseOperation& qs_element_op,
533 const RsElementwiseOperation& rs_element_op)
534 : p_a_grid_{static_cast<const ADataType*>(p_a)},
535 p_b_grid_{static_cast<const BDataType*>(p_b)},
536 p_ds_grid_{},
537 p_e_grid_{static_cast<EDataType*>(p_e)},
538 p_rs_grid_{}, // FIXME
539 conv_to_gemm_transformer_{a_g_n_c_wis_lengths,
540 a_g_n_c_wis_strides,
541 b_g_k_c_xs_lengths,
542 b_g_k_c_xs_strides,
543 e_g_n_k_wos_lengths,
544 e_g_n_k_wos_strides,
545 conv_filter_strides,
546 conv_filter_dilations,
547 input_left_pads,
548 input_right_pads},
557 DeviceOp::MakeRGridDescriptor_M<RLayout>(r_g_n_wos_lengths, r_g_n_wos_strides)},
559 GridwiseGemm64::MakeDefaultAGridDescriptor_AK0_M_AK1(a_grid_desc_m_k_)},
561 GridwiseGemm64::MakeDefaultBGridDescriptor_BK0_N_BK1(b_grid_desc_n_k_)},
562 block_2_etile_map_{GridwiseGemm64::MakeDefaultBlock2ETileMap(e_grid_desc_m_n_)},
564 a_element_op_{a_element_op},
565 b_element_op_{b_element_op},
566 cde_element_op_{cde_element_op},
567 qs_element_op_{qs_element_op},
568 rs_element_op_{rs_element_op},
569 a_g_n_c_wis_lengths_{a_g_n_c_wis_lengths},
570 a_g_n_c_wis_strides_{a_g_n_c_wis_strides},
571 b_g_k_c_xs_lengths_{b_g_k_c_xs_lengths},
572 b_g_k_c_xs_strides_{b_g_k_c_xs_strides},
573 ds_g_n_k_wos_lengths_{ds_g_n_k_wos_lengths},
574 ds_g_n_k_wos_strides_{ds_g_n_k_wos_strides},
575 e_g_n_k_wos_lengths_{e_g_n_k_wos_lengths},
576 e_g_n_k_wos_strides_{e_g_n_k_wos_strides},
577 conv_filter_strides_{conv_filter_strides},
578 conv_filter_dilations_{conv_filter_dilations},
579 input_left_pads_{input_left_pads},
580 input_right_pads_{input_right_pads}
581 {
582 // A/B/E Batch Stride
583 compute_ptr_offset_of_batch_.BatchStrideA_ = a_g_n_c_wis_strides[0];
584 compute_ptr_offset_of_batch_.BatchStrideB_ = b_g_k_c_xs_strides[0];
585 compute_ptr_offset_of_batch_.BatchStrideE_ = e_g_n_k_wos_strides[0];
586
587 // populate pointer, batch stride, desc for Ds
588 static_for<0, NumDTensor, 1>{}([&](auto i) {
589 using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
590
591 // D pointer
592 p_ds_grid_(i) = static_cast<const DDataType*>(p_ds[i]);
593
594 // D batch stride
595 compute_ptr_offset_of_batch_.BatchStrideDs_(i) = ds_g_n_k_wos_strides[i][0];
596
597 ConvToGemmFwdTransformer conv_to_gemm_transformer_d{a_g_n_c_wis_lengths,
598 a_g_n_c_wis_strides,
599 b_g_k_c_xs_lengths,
600 b_g_k_c_xs_strides,
601 ds_g_n_k_wos_lengths[i],
602 ds_g_n_k_wos_strides[i],
603 conv_filter_strides,
604 conv_filter_dilations,
605 input_left_pads,
606 input_right_pads};
607
608 // D desc
610 DeviceOp::MakeEGridDescriptor_M_N<DELayout>(conv_to_gemm_transformer_d);
611 });
612
613 // populate pointer for Rs
614 static_for<0, NumRTensor, 1>{}([&](auto i) {
615 using RDataType = remove_cvref_t<tuple_element_t<i.value, RsDataType>>;
616
617 // R pointer
618 p_rs_grid_(i) = static_cast<RDataType*>(p_rs[i]);
619 compute_ptr_offset_of_batch_.BatchStrideRs_(i) = r_g_n_wos_strides[0];
620 });
621 }
622
623 void Print() const
624 {
625 std::cout << "A[M, K]: " << a_grid_desc_m_k_ << std::endl;
626 std::cout << "B[N, K]: " << b_grid_desc_n_k_ << std::endl;
628 [&](auto i) { std::cout << "Ds[M, N]: " << ds_grid_desc_m_n_[i] << std::endl; });
629 std::cout << "E[M, N]: " << e_grid_desc_m_n_ << std::endl;
630 }
631
632 // private:
633 // pointers
634 const ADataType* p_a_grid_;
635 const BDataType* p_b_grid_;
637 EDataType* p_e_grid_;
639
641
642 // tensor descriptors for problem definition
648
649 // tensor descriptors for block/thread-wise copy
652
653 // block-to-e-tile map
655
656 ComputePtrOffsetOfStridedBatch<NumDTensor, NumRTensor> compute_ptr_offset_of_batch_;
657
658 // element-wise op
659 AElementwiseOperation a_element_op_;
660 BElementwiseOperation b_element_op_;
661 CDEElementwiseOperation cde_element_op_;
662 QsElementwiseOperation qs_element_op_;
663 RsElementwiseOperation rs_element_op_;
664
665 // for checking IsSupportedArgument()
666 std::array<index_t, NDimSpatial + 3> a_g_n_c_wis_lengths_;
667 std::array<index_t, NDimSpatial + 3> a_g_n_c_wis_strides_;
668 std::array<index_t, NDimSpatial + 3> b_g_k_c_xs_lengths_;
669 std::array<index_t, NDimSpatial + 3> b_g_k_c_xs_strides_;
670 std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor> ds_g_n_k_wos_lengths_;
671 std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor> ds_g_n_k_wos_strides_;
672 std::array<index_t, NDimSpatial + 3> e_g_n_k_wos_lengths_;
673 std::array<index_t, NDimSpatial + 3> e_g_n_k_wos_strides_;
674 std::array<index_t, NDimSpatial> conv_filter_strides_;
675 std::array<index_t, NDimSpatial> conv_filter_dilations_;
676 std::array<index_t, NDimSpatial> input_left_pads_;
677 std::array<index_t, NDimSpatial> input_right_pads_;
678 };
679
680 // Invoker
681 struct Invoker : public BaseInvoker
682 {
684
685 template <typename GridwiseGemm>
686 float RunImp(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
687 {
688 if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_m_k_,
691 arg.r_grid_desc_m_,
693 {
694 throw std::runtime_error(
695 "wrong! GridwiseGemmMultipleD_xdl_cshuffle has invalid setting");
696 }
697
699 typename GridwiseGemm::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
701 ds_grid_desc_mblock_mperblock_nblock_nperblock = {};
702
705 rs_grid_desc_mblock_mperblock = {};
706
707 auto e_grid_desc_mblock_mperblock_nblock_nperblock =
709 arg.e_grid_desc_m_n_);
710
711 // populate pointer, batch stride, desc for Ds
712 static_for<0, NumDTensor, 1>{}([&](auto i) {
713 ds_grid_desc_mblock_mperblock_nblock_nperblock(i) =
715 arg.ds_grid_desc_m_n_(i));
716 });
717
718 // populate pointer for Rs
719 static_for<0, NumRTensor, 1>{}([&](auto i) {
720 rs_grid_desc_mblock_mperblock(i) =
722 });
723
724 const index_t grid_size =
725 arg.block_2_etile_map_.CalculateGridSize(arg.e_grid_desc_m_n_) *
726 arg.a_g_n_c_wis_lengths_[0]; // Group count
727
728 const auto K =
729 arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) * arg.a_grid_desc_ak0_m_ak1_.GetLength(I2);
730
731 auto launch_kernel = [&](auto has_main_k_block_loop) {
732 constexpr bool has_main_loop = has_main_k_block_loop.value;
733
734 const auto kernel = kernel_batch_gemm_multiple_d_xdl_cshuffle<
735 GridwiseGemm,
736 ADataType, // TODO: distiguish A/B datatype
737 typename GridwiseGemm::DsGridPointer,
738 EDataType,
739 typename GridwiseGemm::RsGridPointer,
740 AElementwiseOperation,
741 BElementwiseOperation,
742 CDEElementwiseOperation,
743 QsElementwiseOperation,
744 RsElementwiseOperation,
748 typename GridwiseGemm::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
749 NumDTensor>,
750 typename GridwiseGemm::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
752 typename GridwiseGemm::RGridDescriptor_MBlock_MPerBlock,
753 NumRTensor>,
755 ComputePtrOffsetOfStridedBatch<NumDTensor, NumRTensor>,
756 has_main_loop>;
757
758 return launch_and_time_kernel(stream_config,
759 kernel,
760 dim3(grid_size),
761 dim3(BlockSize),
762 0,
763 arg.p_a_grid_,
764 arg.p_b_grid_,
765 arg.p_ds_grid_,
766 arg.p_e_grid_,
767 arg.p_rs_grid_,
768 arg.a_element_op_,
769 arg.b_element_op_,
770 arg.cde_element_op_,
771 arg.qs_element_op_,
772 arg.rs_element_op_,
773 arg.a_g_n_c_wis_lengths_[0], // Group count
776 ds_grid_desc_mblock_mperblock_nblock_nperblock,
777 e_grid_desc_mblock_mperblock_nblock_nperblock,
778 rs_grid_desc_mblock_mperblock,
781 };
782
783 if(GridwiseGemm::CalculateHasMainKBlockLoop(K))
784 {
785 return launch_kernel(integral_constant<bool, true>{});
786 }
787 else
788 {
789 return launch_kernel(integral_constant<bool, false>{});
790 }
791 }
792
794
795 float Run(const BaseArgument* p_arg,
796 const StreamConfig& stream_config = StreamConfig{}) override
797 {
798 return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
799 }
800 };
801
802 static bool IsSupportedArgument(const Argument& arg)
803 {
804 namespace ctc = tensor_layout::convolution;
806 {
807 return false;
808 }
809 // check device
810 if(get_device_name() == "gfx908")
811 {
814 {
815 return false;
816 }
817 }
819 {
822 {
823 return false;
824 }
825 }
827 {
829 {
830 return false;
831 }
832 }
833 else
834 {
835 // return false;
836 }
837
838 // check ConvolutionForwardSpecialization
839 if constexpr(ConvForwardSpecialization ==
841 {
842 // check if it's 1x1, stride=1 conv
843 for(index_t i = 0; i < NDimSpatial; ++i)
844 {
845 const index_t X = arg.b_g_k_c_xs_lengths_[i + 3];
846 const index_t ConvStride = arg.conv_filter_strides_[i];
847 const index_t LeftPad = arg.input_left_pads_[i];
848 const index_t RightPad = arg.input_right_pads_[i];
849
850 if(!(X == 1 && ConvStride == 1 && LeftPad == 0 && RightPad == 0))
851 {
852 return false;
853 }
854 }
855 }
856 else if constexpr(ConvForwardSpecialization ==
858 {
859 // check if it's 1x1 conv
860 for(index_t i = 0; i < NDimSpatial; ++i)
861 {
862 const index_t X = arg.b_g_k_c_xs_lengths_[i + 3];
863 const index_t LeftPad = arg.input_left_pads_[i];
864 const index_t RightPad = arg.input_right_pads_[i];
865
866 if(!(X == 1 && LeftPad == 0 && RightPad == 0))
867 {
868 return false;
869 }
870 }
871 }
872
873 // check vector access of A
874 // FIXME: layout
880 {
881 const index_t C = arg.a_g_n_c_wis_lengths_[2];
882
883 if(!(ABlockTransferSrcVectorDim == 2 && C % ABlockTransferSrcScalarPerVector == 0))
884 {
885 return false;
886 }
887 }
888 else
889 {
890 return false;
891 }
892
893 // check vector access of B
894 // FIXME: layout
900
901 {
902 const index_t C = arg.b_g_k_c_xs_lengths_[2];
903
904 if(!(BBlockTransferSrcVectorDim == 2 && C % BBlockTransferSrcScalarPerVector == 0))
905 {
906 return false;
907 }
908 }
909 else
910 {
911 return false;
912 }
913
914 // check vector access of Ds
915 bool valid = true;
916
917 static_for<0, NumDTensor, 1>{}([&](auto i) {
918 // FIXME: layout
924 {
925 const index_t K = arg.ds_g_n_k_wos_lengths_[i][2];
926
927 if(!(K % CDEBlockTransferScalarPerVector_NPerBlock == 0))
928 {
929 valid = false;
930 }
931 }
932 else
933 {
934 valid = false;
935 }
936 });
937
938 if(!valid)
939 {
940 return false;
941 }
942
943 // check vector access of E
949 {
950 const index_t K = arg.e_g_n_k_wos_lengths_[2];
951
952 if(!(K % CDEBlockTransferScalarPerVector_NPerBlock == 0))
953 {
954 return false;
955 }
956 }
957 else
958 {
959 return false;
960 }
961
962 // check vector access of R
968 {
969 return false;
970 }
971
972 // check Gridwise GEMM
973 if(get_warp_size() == 64)
974 {
975 if constexpr(NXdlPerWave64 > 0)
976 {
980 arg.r_grid_desc_m_,
982 }
983 }
984 else
985 {
986 if constexpr(NXdlPerWave32 > 0)
987 {
991 arg.r_grid_desc_m_,
993 }
994 }
995 return false;
996 }
997
998 bool IsSupportedArgument(const BaseArgument* p_arg) override
999 {
1000 return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
1001 }
1002
1003 static auto MakeArgument(
1004 const void* p_a,
1005 const void* p_b,
1006 const std::array<const void*, NumDTensor>& p_ds,
1007 void* p_e,
1008 std::array<void*, NumRTensor> p_rs,
1009 const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths,
1010 const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
1011 const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
1012 const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_strides,
1013 const std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor>& ds_g_n_k_wos_lengths,
1014 const std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor>& ds_g_n_k_wos_strides,
1015 const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths,
1016 const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_strides,
1017 const std::array<index_t, NDimSpatial + 2>& r_g_n_wos_lengths,
1018 const std::array<index_t, NDimSpatial + 2>& r_g_n_wos_strides,
1019 const std::array<index_t, NDimSpatial>& conv_filter_strides,
1020 const std::array<index_t, NDimSpatial>& conv_filter_dilations,
1021 const std::array<index_t, NDimSpatial>& input_left_pads,
1022 const std::array<index_t, NDimSpatial>& input_right_pads,
1023 const AElementwiseOperation& a_element_op,
1024 const BElementwiseOperation& b_element_op,
1025 const CDEElementwiseOperation& cde_element_op,
1026 const QsElementwiseOperation& qs_element_op,
1027 const RsElementwiseOperation& rs_element_op)
1028 {
1029 return Argument{p_a,
1030 p_b,
1031 p_ds,
1032 p_e,
1033 p_rs,
1034 a_g_n_c_wis_lengths,
1035 a_g_n_c_wis_strides,
1036 b_g_k_c_xs_lengths,
1037 b_g_k_c_xs_strides,
1038 ds_g_n_k_wos_lengths,
1039 ds_g_n_k_wos_strides,
1040 e_g_n_k_wos_lengths,
1041 e_g_n_k_wos_strides,
1042 r_g_n_wos_lengths,
1043 r_g_n_wos_strides,
1044 conv_filter_strides,
1045 conv_filter_dilations,
1046 input_left_pads,
1047 input_right_pads,
1048 a_element_op,
1049 b_element_op,
1050 cde_element_op,
1051 qs_element_op,
1052 rs_element_op};
1053 }
1054
1055 static auto MakeInvoker() { return Invoker{}; }
1056
1057 std::unique_ptr<BaseArgument> MakeArgumentPointer(
1058 const void* p_a,
1059 const void* p_b,
1060 const std::array<const void*, NumDTensor>& p_ds,
1061 void* p_e,
1062 std::array<void*, NumRTensor> p_rs,
1063 const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths,
1064 const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
1065 const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
1066 const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_strides,
1067 const std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor>& ds_g_n_k_wos_lengths,
1068 const std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor>& ds_g_n_k_wos_strides,
1069 const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths,
1070 const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_strides,
1071 const std::array<index_t, NDimSpatial + 2>& r_g_n_wos_lengths,
1072 const std::array<index_t, NDimSpatial + 2>& r_g_n_wos_strides,
1073 const std::array<index_t, NDimSpatial>& conv_filter_strides,
1074 const std::array<index_t, NDimSpatial>& conv_filter_dilations,
1075 const std::array<index_t, NDimSpatial>& input_left_pads,
1076 const std::array<index_t, NDimSpatial>& input_right_pads,
1077 const AElementwiseOperation& a_element_op,
1078 const BElementwiseOperation& b_element_op,
1079 const CDEElementwiseOperation& cde_element_op,
1080 const QsElementwiseOperation& qs_element_op,
1081 const RsElementwiseOperation& rs_element_op) override
1082 {
1083 return std::make_unique<Argument>(p_a,
1084 p_b,
1085 p_ds,
1086 p_e,
1087 p_rs,
1088 a_g_n_c_wis_lengths,
1089 a_g_n_c_wis_strides,
1090 b_g_k_c_xs_lengths,
1091 b_g_k_c_xs_strides,
1092 ds_g_n_k_wos_lengths,
1093 ds_g_n_k_wos_strides,
1094 e_g_n_k_wos_lengths,
1095 e_g_n_k_wos_strides,
1096 r_g_n_wos_lengths,
1097 r_g_n_wos_strides,
1098 conv_filter_strides,
1099 conv_filter_dilations,
1100 input_left_pads,
1101 input_right_pads,
1102 a_element_op,
1103 b_element_op,
1104 cde_element_op,
1105 qs_element_op,
1106 rs_element_op);
1107 }
1108
1109 std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
1110 {
1111 return std::make_unique<Invoker>(Invoker{});
1112 }
1113
1114 std::string GetTypeString() const override
1115 {
1116 auto str = std::stringstream();
1117
1118 // clang-format off
1119 str << "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle"
1120 << "<"
1121 << BlockSize << ", "
1122 << MPerBlock << ", "
1123 << NPerBlock << ", "
1124 << KPerBlock << ", "
1125 << getConvForwardSpecializationString(ConvForwardSpecialization) << ", "
1126 << MPerXDL << ", "
1127 << NPerXDL << ", "
1128 << MXdlPerWave << ", "
1129 << NXdlPerWave << ", "
1130 << ABlockTransferSrcScalarPerVector << ", "
1131 << BBlockTransferSrcScalarPerVector << ", "
1132 << CShuffleMXdlPerWavePerShuffle << ", "
1133 << CShuffleNXdlPerWavePerShuffle
1134 << ">";
1135 // clang-format on
1136
1137 return str.str();
1138 }
1139};
1140
1141} // namespace device
1142} // namespace tensor_operation
1143} // namespace ck
#define CK_MIN_BLOCK_PER_CU
Definition ck.hpp:31
#define CK_MAX_THREAD_PER_BLOCK
Definition ck.hpp:30
#define GET_NXDL_PER_WAVE_IMPL
Definition device_base.hpp:81
#define INVOKER_RUN_IMPL
Definition device_base.hpp:94
float launch_and_time_kernel(const StreamConfig &stream_config, F kernel, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, Args... args)
Definition host_utility/kernel_launch.hpp:14
__host__ __device__ constexpr T max(T x)
Definition utility/math.hpp:84
__host__ __device__ constexpr auto integer_divide_ceil(X x, Y y)
Definition utility/math.hpp:72
Definition tensor_operation/gpu/device/tensor_layout.hpp:42
Definition convolution_backward_data_specialization.hpp:8
GemmSpecialization
Definition gemm_specialization.hpp:11
@ MKPadding
Definition gemm_specialization.hpp:18
@ MPadding
Definition gemm_specialization.hpp:14
@ MNKPadding
Definition gemm_specialization.hpp:20
@ MNPadding
Definition gemm_specialization.hpp:17
ConvolutionForwardSpecialization
Definition convolution_forward_specialization.hpp:15
@ Filter1x1Stride1Pad0
Definition convolution_forward_specialization.hpp:18
@ Filter1x1Pad0
Definition convolution_forward_specialization.hpp:17
std::string getConvForwardSpecializationString(const ConvolutionForwardSpecialization &s)
Definition convolution_forward_specialization.hpp:24
Definition convolution_backward_data_specialization.hpp:7
CK_TILE_HOST float launch_kernel(const stream_config &s, Callables &&... callables)
Definition tile/host/kernel_launch.hpp:173
Definition ck.hpp:268
bool is_lds_direct_load_supported()
Definition host_utility/device_prop.hpp:101
__device__ index_t get_grid_size()
Definition get_id.hpp:49
typename detail::StaticallyIndexedArrayImpl< T, N >::type StaticallyIndexedArray
Definition utility/statically_indexed_array.hpp:45
int32_t index_t
Definition ck.hpp:299
__host__ __device__ constexpr auto make_naive_tensor_descriptor(const Tuple< Lengths... > &lengths, const Tuple< Strides... > &strides)
Definition tensor_descriptor_helper.hpp:49
@ Set
Definition ck.hpp:278
remove_cv_t< remove_reference_t< T > > remove_cvref_t
Definition type.hpp:297
__host__ __device__ constexpr auto make_right_pad_transform(const LowLength &low_length, const RightPadLength &right_pad, integral_constant< bool, SkipIsValidCheck >=integral_constant< bool, false >{})
Definition multi_index_transform_helper.hpp:37
integral_constant< index_t, N > Number
Definition number.hpp:12
std::string get_device_name()
Definition host_utility/device_prop.hpp:19
typename tuple_element< I, TTuple >::type tuple_element_t
Definition utility/tuple.hpp:208
constexpr detail::ignore_t ignore
Definition utility/ignore.hpp:20
__device__ uint32_t amd_wave_read_first_lane(uint32_t value)
Definition amd_wave_read_first_lane.hpp:100
__device__ index_t get_block_1d_id()
Definition get_id.hpp:47
bool is_gfx12_supported()
Definition host_utility/device_prop.hpp:55
bool is_xdl_wmma_supported()
Definition host_utility/device_prop.hpp:76
__device__ constexpr index_t get_warp_size()
Definition get_id.hpp:10
constexpr bool is_same_v
Definition type.hpp:283
__host__ __device__ constexpr auto make_naive_tensor_descriptor_packed(const Tuple< Lengths... > &lengths)
Definition tensor_descriptor_helper.hpp:101
auto accumulate_n(ForwardIterator first, Size count, T init, BinaryOperation op) -> decltype(std::accumulate(first, std::next(first, count), init, op))
Definition library/utility/numeric.hpp:11
__host__ __device__ constexpr auto make_tuple(Xs &&... xs)
Definition utility/tuple.hpp:211
LoopScheduler
Definition loop_scheduler.hpp:15
int64_t long_index_t
Definition ck.hpp:300
__host__ __device__ constexpr auto transform_tensor_descriptor(const OldTensorDescriptor &old_tensor_desc, const NewTransforms &new_transforms, NewLowerDimensionOldVisibleIdss, NewUpperDimensionNewVisibleIdss)
Definition tensor_description/tensor_descriptor.hpp:319
bool is_gfx11_supported()
Definition host_utility/device_prop.hpp:60
constexpr LoopScheduler make_default_loop_scheduler()
Definition loop_scheduler.hpp:20
Definition ck/stream_config.hpp:10
Definition gridwise_gemm_multiple_d_multiple_r_xdl_cshuffle.hpp:74
__host__ static __device__ constexpr bool CheckValidity(const AGridDesc_M_K &a_grid_desc_m_k, const BGridDesc_N_K &b_grid_desc_n_k, const EGridDesc_M_N &e_grid_desc_m_n, const RGridDesc_M &r_grid_desc_m, const Block2ETileMap &block_2_etile_map)
Definition gridwise_gemm_multiple_d_multiple_r_xdl_cshuffle.hpp:208
Definition multi_index_transform.hpp:196
Definition multi_index_transform.hpp:284
Definition utility/sequence.hpp:43
Definition functional2.hpp:33
Definition tensor_operation/operator_transform/transform_conv_fwd_to_gemm.hpp:25
Definition device_base.hpp:197
Definition device_grouped_conv_fwd_multiple_d_multiple_r_xdl_cshuffle.hpp:507
BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_
Definition device_grouped_conv_fwd_multiple_d_multiple_r_xdl_cshuffle.hpp:651
void Print() const
Definition device_grouped_conv_fwd_multiple_d_multiple_r_xdl_cshuffle.hpp:623
const ADataType * p_a_grid_
Definition device_grouped_conv_fwd_multiple_d_multiple_r_xdl_cshuffle.hpp:634
Argument(const void *p_a, const void *p_b, const std::array< const void *, NumDTensor > &p_ds, void *p_e, std::array< void *, NumRTensor > p_rs, const std::array< index_t, NDimSpatial+3 > &a_g_n_c_wis_lengths, const std::array< index_t, NDimSpatial+3 > &a_g_n_c_wis_strides, const std::array< index_t, NDimSpatial+3 > &b_g_k_c_xs_lengths, const std::array< index_t, NDimSpatial+3 > &b_g_k_c_xs_strides, const std::array< std::array< index_t, NDimSpatial+3 >, NumDTensor > &ds_g_n_k_wos_lengths, const std::array< std::array< index_t, NDimSpatial+3 >, NumDTensor > &ds_g_n_k_wos_strides, const std::array< index_t, NDimSpatial+3 > &e_g_n_k_wos_lengths, const std::array< index_t, NDimSpatial+3 > &e_g_n_k_wos_strides, const std::array< index_t, NDimSpatial+2 > &r_g_n_wos_lengths, const std::array< index_t, NDimSpatial+2 > &r_g_n_wos_strides, const std::array< index_t, NDimSpatial > &conv_filter_strides, const std::array< index_t, NDimSpatial > &conv_filter_dilations, const std::array< index_t, NDimSpatial > &input_left_pads, const std::array< index_t, NDimSpatial > &input_right_pads, const AElementwiseOperation &a_element_op, const BElementwiseOperation &b_element_op, const CDEElementwiseOperation &cde_element_op, const QsElementwiseOperation &qs_element_op, const RsElementwiseOperation &rs_element_op)
Definition device_grouped_conv_fwd_multiple_d_multiple_r_xdl_cshuffle.hpp:508
std::array< index_t, NDimSpatial+3 > e_g_n_k_wos_strides_
Definition device_grouped_conv_fwd_multiple_d_multiple_r_xdl_cshuffle.hpp:673
std::array< index_t, NDimSpatial+3 > b_g_k_c_xs_strides_
Definition device_grouped_conv_fwd_multiple_d_multiple_r_xdl_cshuffle.hpp:669
std::array< index_t, NDimSpatial > conv_filter_dilations_
Definition device_grouped_conv_fwd_multiple_d_multiple_r_xdl_cshuffle.hpp:675
std::array< index_t, NDimSpatial+3 > a_g_n_c_wis_strides_
Definition device_grouped_conv_fwd_multiple_d_multiple_r_xdl_cshuffle.hpp:667
std::array< index_t, NDimSpatial > input_left_pads_
Definition device_grouped_conv_fwd_multiple_d_multiple_r_xdl_cshuffle.hpp:676
GridwiseGemm64::DsGridPointer p_ds_grid_
Definition device_grouped_conv_fwd_multiple_d_multiple_r_xdl_cshuffle.hpp:636
BGridDesc_N_K b_grid_desc_n_k_
Definition device_grouped_conv_fwd_multiple_d_multiple_r_xdl_cshuffle.hpp:644
std::array< index_t, NDimSpatial > input_right_pads_
Definition device_grouped_conv_fwd_multiple_d_multiple_r_xdl_cshuffle.hpp:677
std::array< index_t, NDimSpatial+3 > b_g_k_c_xs_lengths_
Definition device_grouped_conv_fwd_multiple_d_multiple_r_xdl_cshuffle.hpp:668
RGridDesc_M r_grid_desc_m_
Definition device_grouped_conv_fwd_multiple_d_multiple_r_xdl_cshuffle.hpp:647
ComputePtrOffsetOfStridedBatch< NumDTensor, NumRTensor > compute_ptr_offset_of_batch_
Definition device_grouped_conv_fwd_multiple_d_multiple_r_xdl_cshuffle.hpp:656
QsElementwiseOperation qs_element_op_
Definition device_grouped_conv_fwd_multiple_d_multiple_r_xdl_cshuffle.hpp:662
std::array< std::array< index_t, NDimSpatial+3 >, NumDTensor > ds_g_n_k_wos_lengths_
Definition device_grouped_conv_fwd_multiple_d_multiple_r_xdl_cshuffle.hpp:670
EGridDesc_M_N ds_grid_desc_m_n_
Definition device_grouped_conv_fwd_multiple_d_multiple_r_xdl_cshuffle.hpp:645
EGridDesc_M_N e_grid_desc_m_n_
Definition device_grouped_conv_fwd_multiple_d_multiple_r_xdl_cshuffle.hpp:646
EDataType * p_e_grid_
Definition device_grouped_conv_fwd_multiple_d_multiple_r_xdl_cshuffle.hpp:637
CDEElementwiseOperation cde_element_op_
Definition device_grouped_conv_fwd_multiple_d_multiple_r_xdl_cshuffle.hpp:661
std::array< index_t, NDimSpatial+3 > e_g_n_k_wos_lengths_
Definition device_grouped_conv_fwd_multiple_d_multiple_r_xdl_cshuffle.hpp:672
ConvToGemmFwdTransformer conv_to_gemm_transformer_
Definition device_grouped_conv_fwd_multiple_d_multiple_r_xdl_cshuffle.hpp:640
std::array< index_t, NDimSpatial > conv_filter_strides_
Definition device_grouped_conv_fwd_multiple_d_multiple_r_xdl_cshuffle.hpp:674
GridwiseGemm64::RsGridPointer p_rs_grid_
Definition device_grouped_conv_fwd_multiple_d_multiple_r_xdl_cshuffle.hpp:638
BElementwiseOperation b_element_op_
Definition device_grouped_conv_fwd_multiple_d_multiple_r_xdl_cshuffle.hpp:660
Block2ETileMap block_2_etile_map_
Definition device_grouped_conv_fwd_multiple_d_multiple_r_xdl_cshuffle.hpp:654
AElementwiseOperation a_element_op_
Definition device_grouped_conv_fwd_multiple_d_multiple_r_xdl_cshuffle.hpp:659
const BDataType * p_b_grid_
Definition device_grouped_conv_fwd_multiple_d_multiple_r_xdl_cshuffle.hpp:635
AGridDesc_M_K a_grid_desc_m_k_
Definition device_grouped_conv_fwd_multiple_d_multiple_r_xdl_cshuffle.hpp:643
std::array< index_t, NDimSpatial+3 > a_g_n_c_wis_lengths_
Definition device_grouped_conv_fwd_multiple_d_multiple_r_xdl_cshuffle.hpp:666
AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_
Definition device_grouped_conv_fwd_multiple_d_multiple_r_xdl_cshuffle.hpp:650
RsElementwiseOperation rs_element_op_
Definition device_grouped_conv_fwd_multiple_d_multiple_r_xdl_cshuffle.hpp:663
std::array< std::array< index_t, NDimSpatial+3 >, NumDTensor > ds_g_n_k_wos_strides_
Definition device_grouped_conv_fwd_multiple_d_multiple_r_xdl_cshuffle.hpp:671
Definition device_grouped_conv_fwd_multiple_d_multiple_r_xdl_cshuffle.hpp:682
float RunImp(const Argument &arg, const StreamConfig &stream_config=StreamConfig{})
Definition device_grouped_conv_fwd_multiple_d_multiple_r_xdl_cshuffle.hpp:686
DeviceOp::Argument Argument
Definition device_grouped_conv_fwd_multiple_d_multiple_r_xdl_cshuffle.hpp:683
INVOKER_RUN_IMPL float Run(const BaseArgument *p_arg, const StreamConfig &stream_config=StreamConfig{}) override
Definition device_grouped_conv_fwd_multiple_d_multiple_r_xdl_cshuffle.hpp:795
Definition device_grouped_conv_fwd_multiple_d_multiple_r_xdl_cshuffle.hpp:305
DeviceGroupedConvFwdMultipleDMultipleR_Xdl_CShuffle DeviceOp
Definition device_grouped_conv_fwd_multiple_d_multiple_r_xdl_cshuffle.hpp:306
remove_cvref_t< decltype(GridwiseGemm64::MakeDefaultAGridDescriptor_AK0_M_AK1( AGridDesc_M_K{}))> AGridDesc_AK0_M_AK1
Definition device_grouped_conv_fwd_multiple_d_multiple_r_xdl_cshuffle.hpp:496
static auto MakeRGridDescriptor_M(const std::array< index_t, NDimSpatial+2 > &r_g_n_wos_lengths, const std::array< index_t, NDimSpatial+2 > &)
Definition device_grouped_conv_fwd_multiple_d_multiple_r_xdl_cshuffle.hpp:390
GridwiseGemmMultipleDMultipleR_k0mk1_k0nk1_mn_xdl_cshuffle_v1< ADataType, AccDataType, CShuffleDataType, DsDataType, EDataType, ReduceAccDataType, RsDataType, AElementwiseOperation, BElementwiseOperation, CDEElementwiseOperation, QsElementwiseOperation, RsElementwiseOperation, ThreadReduceOperations, InMemoryDataOperationEnum::Set, RsGlobalMemoryDataOperation, AGridDesc_M_K, BGridDesc_N_K, EGridDesc_M_N, RGridDesc_M, NumGemmKPrefetchStage, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CDRThreadTransferClusterLengths_MPerBlock_NPerBlock, CDEBlockTransferScalarPerVector_NPerBlock, RThreadTransferDstScalarPerVector_MPerBlock, LoopSched > GridwiseGemmBase
Definition device_grouped_conv_fwd_multiple_d_multiple_r_xdl_cshuffle.hpp:440
bool IsSupportedArgument(const BaseArgument *p_arg) override
Definition device_grouped_conv_fwd_multiple_d_multiple_r_xdl_cshuffle.hpp:998
static auto MakeAGridDescriptor_M_K(const ConvToGemmFwdTransformer &conv_to_gemm_transformer)
Definition device_grouped_conv_fwd_multiple_d_multiple_r_xdl_cshuffle.hpp:325
static constexpr auto I1
Definition device_grouped_conv_fwd_multiple_d_multiple_r_xdl_cshuffle.hpp:315
TransformConvFwdToGemm< NDimSpatial, ConvForwardSpecialization > ConvToGemmFwdTransformer
Definition device_grouped_conv_fwd_multiple_d_multiple_r_xdl_cshuffle.hpp:319
static auto MakeRGridDescriptor_M(const std::array< index_t, NDimSpatial+2 > &r_g_n_wos_lengths, const std::array< index_t, NDimSpatial+2 > &r_g_n_wos_strides)
Definition device_grouped_conv_fwd_multiple_d_multiple_r_xdl_cshuffle.hpp:412
GridwiseGemmBase< math::max(NXdlPerWave64, 1)> GridwiseGemm64
Definition device_grouped_conv_fwd_multiple_d_multiple_r_xdl_cshuffle.hpp:493
remove_cvref_t< decltype(GridwiseGemm64::MakeDefaultBGridDescriptor_BK0_N_BK1( BGridDesc_N_K{}))> BGridDesc_BK0_N_BK1
Definition device_grouped_conv_fwd_multiple_d_multiple_r_xdl_cshuffle.hpp:499
static constexpr auto I3
Definition device_grouped_conv_fwd_multiple_d_multiple_r_xdl_cshuffle.hpp:317
remove_cvref_t< decltype(MakeEGridDescriptor_M_N< DELayout >(dummy_conv_to_gemm_transformer))> EGridDesc_M_N
Definition device_grouped_conv_fwd_multiple_d_multiple_r_xdl_cshuffle.hpp:434
static constexpr index_t NumDTensor
Definition device_grouped_conv_fwd_multiple_d_multiple_r_xdl_cshuffle.hpp:311
std::unique_ptr< BaseArgument > MakeArgumentPointer(const void *p_a, const void *p_b, const std::array< const void *, NumDTensor > &p_ds, void *p_e, std::array< void *, NumRTensor > p_rs, const std::array< index_t, NDimSpatial+3 > &a_g_n_c_wis_lengths, const std::array< index_t, NDimSpatial+3 > &a_g_n_c_wis_strides, const std::array< index_t, NDimSpatial+3 > &b_g_k_c_xs_lengths, const std::array< index_t, NDimSpatial+3 > &b_g_k_c_xs_strides, const std::array< std::array< index_t, NDimSpatial+3 >, NumDTensor > &ds_g_n_k_wos_lengths, const std::array< std::array< index_t, NDimSpatial+3 >, NumDTensor > &ds_g_n_k_wos_strides, const std::array< index_t, NDimSpatial+3 > &e_g_n_k_wos_lengths, const std::array< index_t, NDimSpatial+3 > &e_g_n_k_wos_strides, const std::array< index_t, NDimSpatial+2 > &r_g_n_wos_lengths, const std::array< index_t, NDimSpatial+2 > &r_g_n_wos_strides, const std::array< index_t, NDimSpatial > &conv_filter_strides, const std::array< index_t, NDimSpatial > &conv_filter_dilations, const std::array< index_t, NDimSpatial > &input_left_pads, const std::array< index_t, NDimSpatial > &input_right_pads, const AElementwiseOperation &a_element_op, const BElementwiseOperation &b_element_op, const CDEElementwiseOperation &cde_element_op, const QsElementwiseOperation &qs_element_op, const RsElementwiseOperation &rs_element_op) override
Definition device_grouped_conv_fwd_multiple_d_multiple_r_xdl_cshuffle.hpp:1057
static constexpr index_t NumRTensor
Definition device_grouped_conv_fwd_multiple_d_multiple_r_xdl_cshuffle.hpp:312
static constexpr auto NXdlPerWave32
Definition device_grouped_conv_fwd_multiple_d_multiple_r_xdl_cshuffle.hpp:309
static constexpr ConvToGemmFwdTransformer dummy_conv_to_gemm_transformer
Definition device_grouped_conv_fwd_multiple_d_multiple_r_xdl_cshuffle.hpp:429
static bool IsSupportedArgument(const Argument &arg)
Definition device_grouped_conv_fwd_multiple_d_multiple_r_xdl_cshuffle.hpp:802
static constexpr auto I2
Definition device_grouped_conv_fwd_multiple_d_multiple_r_xdl_cshuffle.hpp:316
static auto MakeBGridDescriptor_N_K(const ConvToGemmFwdTransformer &conv_to_gemm_transformer)
Definition device_grouped_conv_fwd_multiple_d_multiple_r_xdl_cshuffle.hpp:337
remove_cvref_t< decltype(MakeBGridDescriptor_N_K< BLayout >(dummy_conv_to_gemm_transformer))> BGridDesc_N_K
Definition device_grouped_conv_fwd_multiple_d_multiple_r_xdl_cshuffle.hpp:432
static constexpr auto matrix_padder
Definition device_grouped_conv_fwd_multiple_d_multiple_r_xdl_cshuffle.hpp:321
remove_cvref_t< decltype(MakeAGridDescriptor_M_K< ALayout >(dummy_conv_to_gemm_transformer))> AGridDesc_M_K
Definition device_grouped_conv_fwd_multiple_d_multiple_r_xdl_cshuffle.hpp:430
static GET_NXDL_PER_WAVE_IMPL constexpr auto NXdlPerWave64
Definition device_grouped_conv_fwd_multiple_d_multiple_r_xdl_cshuffle.hpp:308
GridwiseGemmBase< NXdlPerWave32 > GridwiseGemm32
Definition device_grouped_conv_fwd_multiple_d_multiple_r_xdl_cshuffle.hpp:494
static auto MakeEGridDescriptor_M_N(const ConvToGemmFwdTransformer &conv_to_gemm_transformer)
Definition device_grouped_conv_fwd_multiple_d_multiple_r_xdl_cshuffle.hpp:349
static auto MakeInvoker()
Definition device_grouped_conv_fwd_multiple_d_multiple_r_xdl_cshuffle.hpp:1055
remove_cvref_t< decltype(MakeRGridDescriptor_M< RLayout >({}, {}))> RGridDesc_M
Definition device_grouped_conv_fwd_multiple_d_multiple_r_xdl_cshuffle.hpp:436
std::unique_ptr< BaseInvoker > MakeInvokerPointer() override
Definition device_grouped_conv_fwd_multiple_d_multiple_r_xdl_cshuffle.hpp:1109
static constexpr auto I0
Definition device_grouped_conv_fwd_multiple_d_multiple_r_xdl_cshuffle.hpp:314
typename GridwiseGemm64::DefaultBlock2ETileMap Block2ETileMap
Definition device_grouped_conv_fwd_multiple_d_multiple_r_xdl_cshuffle.hpp:503
static auto MakeArgument(const void *p_a, const void *p_b, const std::array< const void *, NumDTensor > &p_ds, void *p_e, std::array< void *, NumRTensor > p_rs, const std::array< index_t, NDimSpatial+3 > &a_g_n_c_wis_lengths, const std::array< index_t, NDimSpatial+3 > &a_g_n_c_wis_strides, const std::array< index_t, NDimSpatial+3 > &b_g_k_c_xs_lengths, const std::array< index_t, NDimSpatial+3 > &b_g_k_c_xs_strides, const std::array< std::array< index_t, NDimSpatial+3 >, NumDTensor > &ds_g_n_k_wos_lengths, const std::array< std::array< index_t, NDimSpatial+3 >, NumDTensor > &ds_g_n_k_wos_strides, const std::array< index_t, NDimSpatial+3 > &e_g_n_k_wos_lengths, const std::array< index_t, NDimSpatial+3 > &e_g_n_k_wos_strides, const std::array< index_t, NDimSpatial+2 > &r_g_n_wos_lengths, const std::array< index_t, NDimSpatial+2 > &r_g_n_wos_strides, const std::array< index_t, NDimSpatial > &conv_filter_strides, const std::array< index_t, NDimSpatial > &conv_filter_dilations, const std::array< index_t, NDimSpatial > &input_left_pads, const std::array< index_t, NDimSpatial > &input_right_pads, const AElementwiseOperation &a_element_op, const BElementwiseOperation &b_element_op, const CDEElementwiseOperation &cde_element_op, const QsElementwiseOperation &qs_element_op, const RsElementwiseOperation &rs_element_op)
Definition device_grouped_conv_fwd_multiple_d_multiple_r_xdl_cshuffle.hpp:1003
static auto GetPaddedRGridDescriptor(Descriptor descriptor, index_t MRaw)
Definition device_grouped_conv_fwd_multiple_d_multiple_r_xdl_cshuffle.hpp:361
std::string GetTypeString() const override
Definition device_grouped_conv_fwd_multiple_d_multiple_r_xdl_cshuffle.hpp:1114
Definition device_grouped_conv_fwd_multiple_d_multiple_r.hpp:42
Definition matrix_padder.hpp:180