device_grouped_conv_bwd_weight_multiple_d_xdl_cshuffle.hpp Source File

device_grouped_conv_bwd_weight_multiple_d_xdl_cshuffle.hpp Source File

Composable Kernel: device_grouped_conv_bwd_weight_multiple_d_xdl_cshuffle.hpp Source File
device_grouped_conv_bwd_weight_multiple_d_xdl_cshuffle.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
6#include <iostream>
7#include <numeric>
8#include <sstream>
9
11
25
28
29namespace ck {
30namespace tensor_operation {
31namespace device {
32
// Batched GEMM kernel used for the backward-weight convolution.
//
// The launch grid is split evenly across `batch_count` batches: each block derives its batch
// index as block_id / (grid_size / batch_count), asks `compute_ptr_offset_of_batch` for the
// A/B/C element offsets of that batch, and forwards the offset pointers to
// GridwiseGemm::Run<HasMainKBlockLoop>.
//
// The body is only compiled for gfx9 / gfx11 / gfx12 targets and only when
// GridwiseGemm::IsValidCompilationParameter<>() holds; on any other target every argument is
// assigned to `ignore` and no GEMM work is done.
//
// NOTE(review): in this listing nothing sits between `#if CK_USE_LAUNCH_BOUNDS` and `#endif`
// (listing line 48 is absent) - confirm against the real header.
33template <typename GridwiseGemm,
34 typename FloatA,
35 typename FloatB,
36 typename FloatC,
37 typename AElementwiseOperation,
38 typename BElementwiseOperation,
39 typename CElementwiseOperation,
40 typename AGridDesc_B_K0_M_K1,
41 typename BGridDesc_B_K0_N_K1,
42 typename CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
43 typename Block2CTileMap,
44 typename ComputePtrOffsetOfBatch,
45 bool HasMainKBlockLoop>
46__global__ void
47#if CK_USE_LAUNCH_BOUNDS
49#endif
50 kernel_batched_gemm_xdlops_bwd_weight(const FloatA* __restrict__ p_a_grid,
51 const FloatB* __restrict__ p_b_grid,
52 FloatC* __restrict__ p_c_grid,
53 const AElementwiseOperation a_element_op,
54 const BElementwiseOperation b_element_op,
55 const CElementwiseOperation c_element_op,
56 const index_t batch_count,
57 const AGridDesc_B_K0_M_K1 a_b_k0_m_k1_grid_desc,
58 const BGridDesc_B_K0_N_K1 b_b_k0_n_k1_grid_desc,
59 const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
60 c_grid_desc_mblock_mperblock_nblock_nperblock,
61 const Block2CTileMap block_2_ctile_map,
62 const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch)
63{
64#if defined(__gfx9__) || defined(__gfx11__) || defined(__gfx12__)
65 if constexpr(GridwiseGemm::template IsValidCompilationParameter<>())
66 {
// Blocks are assigned to batches in contiguous runs of `num_blocks_per_batch`.
// Both values go through readfirstlane (presumably to keep them wave-uniform - confirm).
67 const index_t num_blocks_per_batch =
68 __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
69 const index_t g_idx =
70 __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch);
71
// Offsets of this batch's A, B and C data; added to the base pointers below.
72 const long_index_t a_batch_offset =
73 amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx));
74 const long_index_t b_batch_offset =
75 amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx));
76 const long_index_t c_batch_offset =
77 amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetCPtrOffset(g_idx));
78
// Shared scratch buffer; the byte count comes from GridwiseGemm and is expressed in FloatA
// elements.
79 __shared__ FloatA p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatA)];
80
81 GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid + a_batch_offset,
82 p_b_grid + b_batch_offset,
83 p_c_grid + c_batch_offset,
84 p_shared,
85 a_b_k0_m_k1_grid_desc,
86 b_b_k0_n_k1_grid_desc,
87 c_grid_desc_mblock_mperblock_nblock_nperblock,
88 a_element_op,
89 b_element_op,
90 c_element_op,
91 block_2_ctile_map);
92 }
93#else
// Unsupported target: mark every parameter as used so the empty kernel still compiles cleanly.
94 ignore = p_a_grid;
95 ignore = p_b_grid;
96 ignore = p_c_grid;
97 ignore = a_b_k0_m_k1_grid_desc;
98 ignore = b_b_k0_n_k1_grid_desc;
99 ignore = c_grid_desc_mblock_mperblock_nblock_nperblock;
100 ignore = a_element_op;
101 ignore = b_element_op;
102 ignore = c_element_op;
103 ignore = batch_count;
104 ignore = block_2_ctile_map;
105 ignore = compute_ptr_offset_of_batch;
106
// The offset getters are still called once with index 0; their results are discarded.
107 compute_ptr_offset_of_batch.GetAPtrOffset(0);
108 compute_ptr_offset_of_batch.GetBPtrOffset(0);
109 compute_ptr_offset_of_batch.GetCPtrOffset(0);
110#endif // end of if (defined(__gfx9__))
111}
112
113template <index_t NDimSpatial,
114 typename InLayout,
115 typename WeiLayout,
116 typename OutLayout,
117 typename DsLayout,
118 typename InDataType,
119 typename WeiDataType,
120 typename OutDataType,
121 typename AccDataType,
122 typename DsDataType,
123 typename InElementwiseOperation,
124 typename WeiElementwiseOperation,
125 typename OutElementwiseOperation,
126 ConvolutionBackwardWeightSpecialization ConvBackwardWeightSpecialization,
127 ck::index_t BlockSize,
128 ck::index_t MPerBlock,
129 ck::index_t NPerBlock,
130 ck::index_t K0PerBlock,
131 ck::index_t K1,
132 ck::index_t MPerXDL,
133 ck::index_t NPerXDL,
134 ck::index_t MXdlPerWave,
135 ck::index_t NXdlPerWave,
136 typename ABlockTransferThreadClusterLengths_K0_M_K1,
137 typename ABlockTransferThreadClusterArrangeOrder,
138 typename ABlockTransferSrcAccessOrder,
139 ck::index_t ABlockTransferSrcVectorDim,
140 ck::index_t ABlockTransferSrcScalarPerVector,
141 ck::index_t ABlockTransferDstScalarPerVector_K1,
142 bool ABlockLdsAddExtraM,
143 typename BBlockTransferThreadClusterLengths_K0_N_K1,
144 typename BBlockTransferThreadClusterArrangeOrder,
145 typename BBlockTransferSrcAccessOrder,
146 ck::index_t BBlockTransferSrcVectorDim,
147 ck::index_t BBlockTransferSrcScalarPerVector,
148 ck::index_t BBlockTransferDstScalarPerVector_K1,
149 bool BBlockLdsAddExtraN,
150 index_t CShuffleMXdlPerWavePerShuffle,
151 index_t CShuffleNXdlPerWavePerShuffle,
152 typename CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
153 index_t CBlockTransferScalarPerVector_NWaveNPerXdl,
154 typename ComputeTypeA = InDataType,
155 typename ComputeTypeB = ComputeTypeA>
157 : public DeviceGroupedConvBwdWeightMultipleD<NDimSpatial,
158 InLayout,
159 WeiLayout,
160 OutLayout,
161 DsLayout,
162 InDataType,
163 WeiDataType,
164 OutDataType,
165 DsDataType,
166 InElementwiseOperation,
167 WeiElementwiseOperation,
168 OutElementwiseOperation,
169 ComputeTypeA,
170 ComputeTypeB>
171{
174 static constexpr auto NXdlPerWave64 = GetNXdlPerWave<true>();
175 static constexpr auto NXdlPerWave32 = GetNXdlPerWave<false>();
176
177 using ADataType = OutDataType;
178 using BDataType = InDataType;
179 using EDataType = WeiDataType;
180
181 static constexpr index_t NumDTensor = DsLayout::Size();
182
183 using AElementwiseOperation = OutElementwiseOperation;
184 using BElementwiseOperation = InElementwiseOperation;
185 using CDEElementwiseOperation = WeiElementwiseOperation;
186
187 // TODO make A/B datatype different
188 using ABDataType = InDataType;
189
190 static constexpr auto I0 = Number<0>{};
191 static constexpr auto I1 = Number<1>{};
192 static constexpr auto I2 = Number<2>{};
193 static constexpr auto I3 = Number<3>{};
194 static constexpr auto I4 = Number<4>{};
195 static constexpr auto I5 = Number<5>{};
196
197 static constexpr auto K1Number = Number<K1>{};
198
199 static constexpr auto conv_to_gemm_transformer =
201 MPerBlock,
202 NPerBlock,
203 K1Number,
204 K0PerBlock,
205 ConvBackwardWeightSpecialization>{};
206
207 static constexpr index_t MaxScalarPerVectorFP32 = 4;
210 ? math::min(CBlockTransferScalarPerVector_NWaveNPerXdl, MaxScalarPerVectorFP32)
211 : CBlockTransferScalarPerVector_NWaveNPerXdl;
212
213 // Bytes per 32 lds bank: 32 * 4 bytes
214 static constexpr auto BankLength = 128;
215 static constexpr auto ElePerBank = BankLength / sizeof(ADataType);
216
217 // M1 & M0
218 static constexpr auto ABlockLdsM1PerBlock = ElePerBank / K1;
219 static constexpr auto ABlockLdsM0PerBlock = MPerBlock / ABlockLdsM1PerBlock;
220 static constexpr auto ABlockLdsM1Padding = 4;
221
222 // N1 & N0
223 static constexpr auto BBlockLdsN1PerBlock = ElePerBank / K1;
224 static constexpr auto BBlockLdsN0PerBlock = NPerBlock / BBlockLdsN1PerBlock;
225 static constexpr auto BBlockLdsN1Padding = 4;
226
227 template <ck::index_t NDim, typename ck::enable_if<NDim == 1, bool>::type = false>
228 static auto GetABCGridDesc()
229 {
230 const ck::index_t dim = 1;
231 const ck::index_t batch = 1;
232 const std::array<ck::index_t, NDimSpatial> lengths{1};
233 const std::array<ck::index_t, NDimSpatial + 3> strides{1, 1, 1, 1};
234 const std::array<ck::index_t, NDimSpatial> params{1};
235 return conv_to_gemm_transformer.template MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<1>(
236 dim,
237 dim,
238 dim,
239 lengths,
240 lengths,
241 lengths,
242 strides,
243 strides,
244 strides,
245 params,
246 params,
247 params,
248 params,
249 batch);
250 }
251
252 template <ck::index_t NDim, typename ck::enable_if<NDim == 2, bool>::type = false>
253 static auto GetABCGridDesc()
254 {
255 const ck::index_t dim = 1;
256 const ck::index_t batch = 1;
257 const std::array<ck::index_t, NDimSpatial> lengths{1, 1};
258 const std::array<ck::index_t, NDimSpatial + 3> strides{1, 1, 1, 1, 1};
259 const std::array<ck::index_t, NDimSpatial> params{1, 1};
260 return conv_to_gemm_transformer.template MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<2>(
261 dim,
262 dim,
263 dim,
264 lengths,
265 lengths,
266 lengths,
267 strides,
268 strides,
269 strides,
270 params,
271 params,
272 params,
273 params,
274 batch);
275 }
276
277 template <ck::index_t NDim, typename ck::enable_if<NDim == 3, bool>::type = false>
278 static auto GetABCGridDesc()
279 {
280 const ck::index_t dim = 1;
281 const ck::index_t batch = 1;
282 const std::array<ck::index_t, NDimSpatial> lengths{1, 1, 1};
283 const std::array<ck::index_t, NDimSpatial + 3> strides{1, 1, 1, 1, 1, 1};
284 const std::array<ck::index_t, NDimSpatial> params{1, 1, 1};
285 return conv_to_gemm_transformer.template MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<3>(
286 dim,
287 dim,
288 dim,
289 lengths,
290 lengths,
291 lengths,
292 strides,
293 strides,
294 strides,
295 params,
296 params,
297 params,
298 params,
299 batch);
300 }
301
303
307
308 template <index_t NXdlPerWave_>
310 BlockSize,
311 ADataType,
312 BDataType,
313 AccDataType,
314 AccDataType,
322 MPerBlock,
323 NPerBlock,
324 K0PerBlock,
325 MPerXDL,
326 NPerXDL,
327 K1,
328 MXdlPerWave,
329 NXdlPerWave_,
330 ABlockTransferThreadClusterLengths_K0_M_K1,
331 ABlockTransferThreadClusterArrangeOrder,
332 ABlockTransferSrcAccessOrder,
333 ABlockTransferSrcVectorDim,
334 ABlockTransferSrcScalarPerVector,
335 ABlockTransferDstScalarPerVector_K1,
336 false, // AThreadTransferSrcResetCoordinateAfterRun,
337 ABlockLdsAddExtraM,
341 BBlockTransferThreadClusterLengths_K0_N_K1,
342 BBlockTransferThreadClusterArrangeOrder,
343 BBlockTransferSrcAccessOrder,
344 BBlockTransferSrcVectorDim,
345 BBlockTransferSrcScalarPerVector,
346 BBlockTransferDstScalarPerVector_K1,
347 false, // BThreadTransferSrcResetCoordinateAfterRun,
348 BBlockLdsAddExtraN,
352 CShuffleMXdlPerWavePerShuffle,
353 CShuffleNXdlPerWavePerShuffle,
355 CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
356 true,
357 true,
358 1,
360 ComputeTypeA,
361 ComputeTypeB>;
364
365 static constexpr auto MakeElementwiseInputSequence()
366 {
368 [&](auto) constexpr { return Number<WorkspaceInOutScalarPerVector>{}; },
370 }
371
372 static constexpr auto GetDsGridPointerTuple()
373 {
374 return generate_tuple(
375 [&](auto i) {
376 using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
377 return static_cast<const DDataType*>(nullptr);
378 },
380 }
381
// 1D overload (NDim == 1): builds one 2-D descriptor per D tensor.
//
// For D tensor i, K / C / X are read from ds_g_k_c_xs_lengths[i][1..3] and a naive
// (K, X * C) descriptor is created with strides (KStride, CStride). Depending on
// ConvBackwardWeightSpecialization the descriptor is returned as is, or with both
// dimensions right-padded up to multiples of MPerBlock / NPerBlock.
//
// Fix: the K and C strides are now taken from the i-th D tensor's own stride array
// (ds_g_k_c_xs_strides[i][...]). Indexing the outer array directly selected a whole
// per-tensor stride array rather than a single stride, and did not match the 2D/3D
// overloads, which pass ds_g_k_c_xs_strides[i].
//
// NOTE(review): several lines of this listing are absent (signature name line, the
// specialization being compared against, and the tail of the padded return); they are left
// untouched here.
382 template <index_t NDim, typename ck::enable_if<NDim == 1, bool>::type = false>
384 const std::array<std::array<index_t, NDim + 3>, NumDTensor>& ds_g_k_c_xs_lengths,
385 const std::array<std::array<index_t, NDim + 3>, NumDTensor>& ds_g_k_c_xs_strides)
386 {
387 return generate_tuple(
388 [&](auto i) {
389 const index_t K = ds_g_k_c_xs_lengths[i][I1];
390 const index_t C = ds_g_k_c_xs_lengths[i][I2];
391 const index_t X = ds_g_k_c_xs_lengths[i][I3];
// Per-tensor strides: index the D tensor first, then the dimension.
392 const index_t CStride = ds_g_k_c_xs_strides[i][I2];
393 const index_t KStride = ds_g_k_c_xs_strides[i][I1];
394
395 const auto wei_grid_desc = make_naive_tensor_descriptor(
396 make_tuple(K, X * C), make_tuple(KStride, CStride));
397
398 if constexpr(ConvBackwardWeightSpecialization ==
400 {
401 return wei_grid_desc;
402 }
403 else
404 {
405 const index_t GemmM = K;
406 const index_t GemmN = C * X;
// Amount needed to round each GEMM dimension up to a whole number of block tiles.
407 const auto PadGemmM =
408 GemmM % MPerBlock == 0 ? 0 : MPerBlock - GemmM % MPerBlock;
409 const auto PadGemmN =
410 GemmN % NPerBlock == 0 ? 0 : NPerBlock - GemmN % NPerBlock;
411
413 wei_grid_desc,
414 make_tuple(make_right_pad_transform(GemmM, PadGemmM),
415 make_right_pad_transform(GemmN, PadGemmN)),
418 }
419 },
421 }
422
// 2D overload (NDim == 2): builds one descriptor per D tensor.
//
// For D tensor i, K / C / Y / X are read from ds_g_k_c_xs_lengths[i][1..4] and the weight
// descriptor is obtained from conv_to_gemm_transformer.make_wei_grid_desc<NDim> using that
// tensor's stride array. Depending on ConvBackwardWeightSpecialization the descriptor is
// returned as is, or with both dimensions right-padded up to multiples of
// MPerBlock / NPerBlock.
//
// NOTE(review): several lines of this listing are absent (signature name line, the
// specialization being compared against, and the tail of the padded return).
423 template <index_t NDim, typename ck::enable_if<NDim == 2, bool>::type = false>
425 const std::array<std::array<index_t, NDim + 3>, NumDTensor>& ds_g_k_c_xs_lengths,
426 const std::array<std::array<index_t, NDim + 3>, NumDTensor>& ds_g_k_c_xs_strides)
427 {
428 return generate_tuple(
429 [&](auto i) {
430 const index_t K = ds_g_k_c_xs_lengths[i][I1];
431 const index_t C = ds_g_k_c_xs_lengths[i][I2];
432 const index_t Y = ds_g_k_c_xs_lengths[i][I3];
433 const index_t X = ds_g_k_c_xs_lengths[i][I4];
434
435 const auto wei_grid_desc =
436 conv_to_gemm_transformer.template make_wei_grid_desc<NDim>(
437 K, Y, X, C, ds_g_k_c_xs_strides[i]);
438
439 if constexpr(ConvBackwardWeightSpecialization ==
441 {
442 return wei_grid_desc;
443 }
444 else
445 {
446 const index_t GemmM = K;
447 const index_t GemmN = C * X * Y;
// Amount needed to round each GEMM dimension up to a whole number of block tiles.
448 const auto PadGemmM =
449 GemmM % MPerBlock == 0 ? 0 : MPerBlock - GemmM % MPerBlock;
450 const auto PadGemmN =
451 GemmN % NPerBlock == 0 ? 0 : NPerBlock - GemmN % NPerBlock;
452
454 wei_grid_desc,
455 make_tuple(make_right_pad_transform(GemmM, PadGemmM),
456 make_right_pad_transform(GemmN, PadGemmN)),
459 }
460 },
462 }
463
// 3D overload (NDim == 3): builds one descriptor per D tensor.
//
// For D tensor i, K / C / Z / Y / X are read from ds_g_k_c_xs_lengths[i][1..5] and the weight
// descriptor is obtained from conv_to_gemm_transformer.make_wei_grid_desc<NDim> using that
// tensor's stride array. Depending on ConvBackwardWeightSpecialization the descriptor is
// returned as is, or with both dimensions right-padded up to multiples of
// MPerBlock / NPerBlock.
//
// NOTE(review): several lines of this listing are absent (signature name line, the
// specialization being compared against, and the tail of the padded return).
464 template <index_t NDim, typename ck::enable_if<NDim == 3, bool>::type = false>
466 const std::array<std::array<index_t, NDim + 3>, NumDTensor>& ds_g_k_c_xs_lengths,
467 const std::array<std::array<index_t, NDim + 3>, NumDTensor>& ds_g_k_c_xs_strides)
468 {
469 return generate_tuple(
470 [&](auto i) {
471 const index_t K = ds_g_k_c_xs_lengths[i][I1];
472 const index_t C = ds_g_k_c_xs_lengths[i][I2];
473 const index_t Z = ds_g_k_c_xs_lengths[i][I3];
474 const index_t Y = ds_g_k_c_xs_lengths[i][I4];
475 const index_t X = ds_g_k_c_xs_lengths[i][I5];
476
477 const auto wei_grid_desc =
478 conv_to_gemm_transformer.template make_wei_grid_desc<NDim>(
479 K, Z, Y, X, C, ds_g_k_c_xs_strides[i]);
480
481 if constexpr(ConvBackwardWeightSpecialization ==
483 {
484 return wei_grid_desc;
485 }
486 else
487 {
488 const index_t GemmM = K;
489 const index_t GemmN = C * X * Y * Z;
// Amount needed to round each GEMM dimension up to a whole number of block tiles.
490 const auto PadGemmM =
491 GemmM % MPerBlock == 0 ? 0 : MPerBlock - GemmM % MPerBlock;
492 const auto PadGemmN =
493 GemmN % NPerBlock == 0 ? 0 : NPerBlock - GemmN % NPerBlock;
494
496 wei_grid_desc,
497 make_tuple(make_right_pad_transform(GemmM, PadGemmM),
498 make_right_pad_transform(GemmN, PadGemmN)),
501 }
502 },
504 }
505
506 template <typename ComputePtrOffsetOfBatch>
507 static void
508 InitElementwiseBatchStrides(const ComputePtrOffsetOfBatch& compute_ptr_offset_of_batch_,
509 std::array<index_t, NumDTensor + I1>& input_batch_strides,
510 std::array<index_t, I1>& output_batch_strides)
511 {
512 input_batch_strides[I0] = compute_ptr_offset_of_batch_.BatchStrideC_;
513 output_batch_strides[I0] = compute_ptr_offset_of_batch_.BatchStrideC_;
514
515 // input_batch_strides = {C, Ds...}
516 static_for<0, NumDTensor, 1>{}([&](auto i) {
517 input_batch_strides[i + 1] = compute_ptr_offset_of_batch_.BatchStrideDs_[i];
518 });
519 }
520
527 CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock::At(1);
529 CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock::At(3);
531
539 BlockSize,
540 MPerBlock,
541 NPerBlock,
542 MPerBlock / ClusterLengthMPerBlock,
543 NPerBlock / ClusterLengthNPerBlock,
547 I1,
548 I1>;
549
550 // Argument
553
556
558 {
559 template <typename GridwiseGemm>
561 {
562 constexpr int dynamic_smem_size = 0;
563 int max_occupancy = 0;
564 hip_check_error(hipOccupancyMaxActiveBlocksPerMultiprocessor(
565 &max_occupancy,
567 GridwiseGemm,
568 ADataType,
569 BDataType,
570 AccDataType,
571 OutElementwiseOperation,
572 InElementwiseOperation,
578 ComputePtrOffsetOfStridedBatch<I1, I1, NumDTensor>,
579 true>,
580 BlockSize,
581 dynamic_smem_size));
582 return std::max(1, max_occupancy);
583 }
584
586 {
587 max_occupancy_ = 1;
588 if(get_warp_size() == 64)
589 {
590 if constexpr(NXdlPerWave64 > 0)
591 {
593 }
594 }
595 else
596 {
597 if constexpr(NXdlPerWave32 > 0)
598 {
600 }
601 }
602 }
604 };
605
606 struct Argument : public BaseArgument, public ArgumentSplitK
607 {
609 const InDataType* p_in_grid,
610 WeiDataType* p_wei_grid,
611 const OutDataType* p_out_grid,
612 const std::array<const void*, NumDTensor>& p_ds,
613 const std::array<index_t, NDimSpatial + 3>& b_g_n_c_wis_lengths, // input
614 const std::array<index_t, NDimSpatial + 3>& b_g_n_c_wis_strides,
615 const std::array<index_t, NDimSpatial + 3>& e_g_k_c_xs_lengths, // weight
616 const std::array<index_t, NDimSpatial + 3>& e_g_k_c_xs_strides,
617 const std::array<index_t, NDimSpatial + 3>& a_g_n_k_wos_lengths, // output
618 const std::array<index_t, NDimSpatial + 3>& a_g_n_k_wos_strides,
619 const std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor>& ds_g_k_c_xs_lengths,
620 const std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor>& ds_g_k_c_xs_strides,
621 const std::array<ck::index_t, NDimSpatial>& conv_filter_strides,
622 const std::array<ck::index_t, NDimSpatial>& conv_filter_dilations,
623 const std::array<ck::index_t, NDimSpatial>& input_left_pads,
624 const std::array<ck::index_t, NDimSpatial>& input_right_pads,
625 const ck::index_t M01,
626 const ck::index_t N01,
627 InElementwiseOperation in_element_op,
628 WeiElementwiseOperation wei_element_op,
629 OutElementwiseOperation out_element_op,
630 ck::index_t split_k)
631 : p_a_grid_{p_out_grid},
632 p_b_grid_{p_in_grid},
633 p_ds_grid_{},
634 p_e_grid_{p_wei_grid},
640 M01_{M01},
641 N01_{N01},
642 a_element_op_{out_element_op},
643 b_element_op_{in_element_op},
644 cde_element_op_{wei_element_op},
645 Conv_G_{b_g_n_c_wis_lengths[0]},
646 Conv_N_{b_g_n_c_wis_lengths[1]},
647 Conv_K_{e_g_k_c_xs_lengths[1]},
648 Conv_C_{b_g_n_c_wis_lengths[2]},
652 conv_filter_strides_{conv_filter_strides},
653 input_left_pads_{input_left_pads},
654 input_right_pads_{input_right_pads}
655 {
656 static ActiveWorkgroupsPerCU active_workgroups_per_cu;
657
660 e_g_k_c_xs_lengths.begin(), NDimSpatial + I3, 1, std::multiplies<>()) *
661 sizeof(AccDataType);
662
663 constexpr index_t spatial_offset = 3;
664 std::copy(begin(b_g_n_c_wis_lengths) + spatial_offset,
665 end(b_g_n_c_wis_lengths),
667 std::copy(begin(e_g_k_c_xs_lengths) + spatial_offset,
668 end(e_g_k_c_xs_lengths),
670 std::copy(begin(a_g_n_k_wos_lengths) + spatial_offset,
671 end(a_g_n_k_wos_lengths),
673
674#if !DISABLE_SPLIT_K_AUTODEDUCE_FOR_ONE_STAGE_KERNELS
675 if(split_k < 0)
676 {
677 ck::index_t gemmM, gemmN;
678 std::tie(gemmM, gemmN, std::ignore) =
679 get_bwd_weight_gemm_sizes<NDimSpatial>(a_g_n_k_wos_lengths, e_g_k_c_xs_lengths);
680
681 const auto grid_size =
684 grid_size);
685 }
686 else
687#endif
688 {
689 k_batch_ = split_k;
690 }
691
692 const auto descs =
694 .template MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<NDimSpatial>(
695 Conv_N_,
696 Conv_K_,
697 Conv_C_,
701 b_g_n_c_wis_strides,
702 e_g_k_c_xs_strides,
703 a_g_n_k_wos_strides,
704 conv_filter_strides,
705 conv_filter_dilations,
706 input_left_pads,
707 input_right_pads,
708 k_batch_);
709
710 static_for<0, NumDTensor, 1>{}([&](auto i) {
711 using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;
712 using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
713
714 static_assert(is_same_v<DLayout, WeiLayout>, "Not supported D data layout");
715
716 // D pointer
717 p_ds_grid_(i) = static_cast<const DDataType*>(p_ds[i]);
718 compute_ptr_offset_of_batch_.BatchStrideDs_(i) = ds_g_k_c_xs_strides[i][0];
719 });
720
723 ce_grid_desc_m_n_ = descs[I2];
724
726 MakeDsGridDescriptor_M_N<NDimSpatial>(ds_g_k_c_xs_lengths, ds_g_k_c_xs_strides);
727
731 ce_grid_desc_m_n_.GetLength(I0), ce_grid_desc_m_n_.GetLength(I1)};
732
733 // A/B/C Batch Stride
734 compute_ptr_offset_of_batch_.BatchStrideA_ = a_g_n_k_wos_strides[0];
735 compute_ptr_offset_of_batch_.BatchStrideB_ = b_g_n_c_wis_strides[0];
736 compute_ptr_offset_of_batch_.BatchStrideC_ =
737 Conv_K_ * Conv_C_ *
738 std::accumulate(begin(filter_spatial_lengths_),
740 index_t{1},
741 std::multiplies<>{});
742 }
743
744 std::size_t GetWorkspaceSizeBytes() const
745 {
746 return sizeof(AccDataType) * ce_grid_desc_m_n_.GetElementSpaceSize() * Conv_G_;
747 }
748
753
758
761
762 // for computing batch offset
763 ComputePtrOffsetOfStridedBatch<I1, I1, NumDTensor> compute_ptr_offset_of_batch_;
764
767
768 OutElementwiseOperation a_element_op_;
769 InElementwiseOperation b_element_op_;
770 WeiElementwiseOperation cde_element_op_;
771
772 // for checking IsSupportedArgument()
777 std::array<ck::index_t, NDimSpatial> input_spatial_lengths_;
778 std::array<ck::index_t, NDimSpatial> filter_spatial_lengths_;
779 std::array<ck::index_t, NDimSpatial> output_spatial_lengths_;
780 const std::array<ck::index_t, NDimSpatial>& conv_filter_strides_;
781 const std::array<ck::index_t, NDimSpatial>& input_left_pads_;
782 const std::array<ck::index_t, NDimSpatial>& input_right_pads_;
784 };
785
786 // Invoker
787 struct Invoker : public BaseInvoker
788 {
790
791 void ShowInfo(const Argument& arg)
792 {
793 std::cout << "arg.a_grid_desc_kbatch_k0_m_k1_{"
794 << arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I0) << ", "
795 << arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I1) << ", "
796 << arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I2) << ", "
797 << arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I3) << "}" << std::endl;
798
799 std::cout << "arg.b_grid_desc_kbatch_k0_n_k1_{"
800 << arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I0) << ", "
801 << arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I1) << ", "
802 << arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I2) << ", "
803 << arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I3) << "}" << std::endl;
804
805 std::cout << "arg.ce_grid_desc_m_n_{" << arg.ce_grid_desc_m_n_.GetLength(I0) << ", "
806 << arg.ce_grid_desc_m_n_.GetLength(I1) << "}" << std::endl;
807 }
808
809 template <typename GridwiseGemm>
810 float RunImp(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
811 {
812 if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_kbatch_k0_m_k1_,
816 {
817 throw std::runtime_error(
818 "wrong! GridwiseGemm_km_kn_m0m1n0n1_xdlops_v3r1 has invalid setting");
819 }
820
821 auto c_grid_desc_mblock_mperblock_nblock_nperblock =
822 GridwiseGemm::MakeCGridDesc_MBlock_MPerBlock_NBlock_NPerBlock(
824 const auto K0 = arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I1);
825 const bool has_main_k0_block_loop = GridwiseGemm::CalculateHasMainK0BlockLoop(K0);
826
827 auto launch_gemm_kernel = [&](auto has_main_k_block_loop) {
828 AccDataType* p_c_grid = type_convert<AccDataType*>(arg.p_workspace_);
829 const index_t grid_size =
830 arg.block_2_ctile_map_.CalculateGridSize(arg.ce_grid_desc_m_n_) * arg.Conv_G_;
831
832 constexpr bool has_main_loop = has_main_k_block_loop.value;
833
834 auto preprocess = [&]() {
835 hip_check_error(hipMemsetAsync(
836 p_c_grid, 0, arg.c_space_size_bytes, stream_config.stream_id_));
837 };
838
839 const auto kernel = kernel_batched_gemm_xdlops_bwd_weight<
840 GridwiseGemm,
841 ADataType,
842 BDataType,
843 AccDataType,
844 OutElementwiseOperation,
845 InElementwiseOperation,
851 ComputePtrOffsetOfStridedBatch<I1, I1, NumDTensor>,
852 has_main_loop>;
853
855 stream_config,
856 preprocess,
857 kernel,
858 dim3(grid_size),
859 dim3(BlockSize),
860 0,
861 arg.p_a_grid_,
862 arg.p_b_grid_,
863 p_c_grid,
864 arg.a_element_op_,
865 arg.b_element_op_,
867 arg.Conv_G_,
870 c_grid_desc_mblock_mperblock_nblock_nperblock,
873 };
874
875 auto launch_elementwise_kernel = [&]() {
876 const AccDataType* p_c_grid = type_convert<const AccDataType*>(arg.p_workspace_);
877 const index_t grid_size =
878 arg.elementwise_block_2_ctile_map_.CalculateGridSize(arg.ce_grid_desc_m_n_) *
879 arg.Conv_G_;
880
881 std::array<index_t, NumDTensor + I1> input_batch_strides;
882 std::array<index_t, I1> output_batch_strides;
884 arg.compute_ptr_offset_of_batch_, input_batch_strides, output_batch_strides);
885
893 NumDTensor + I1,
894 I1>;
895
897 stream_config,
898 kernel,
899 dim3(grid_size),
900 dim3(BlockSize),
901 0,
904 concat_tuple(make_tuple(p_c_grid), arg.p_ds_grid_),
905 arg.p_e_grid_,
907 arg.cde_element_op_,
908 arg.Conv_G_,
909 input_batch_strides,
910 output_batch_strides);
911 };
912
913 float avg_time = 0;
914 if(has_main_k0_block_loop)
915 {
916 avg_time = launch_gemm_kernel(integral_constant<bool, true>{});
917 }
918 else
919 {
920 avg_time = launch_gemm_kernel(integral_constant<bool, false>{});
921 }
922
923 avg_time += launch_elementwise_kernel();
924 return avg_time;
925 }
926
928
929 float Run(const BaseArgument* p_arg,
930 const StreamConfig& stream_config = StreamConfig{}) override
931 {
932 return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
933 }
934 };
935
// Compile-time validity gate for this instantiation.
// Currently a placeholder: it unconditionally reports true (see TODO below).
936 static constexpr bool IsValidCompilationParameter()
937 {
938 // TODO: properly implement this check
939 return true;
940 }
941
942 static bool IsSupportedArgument(const Argument& arg)
943 {
944#if DISABLE_SPLIT_K_AUTODEDUCE_FOR_ONE_STAGE_KERNELS
945 if(arg.k_batch_ < 0)
946 {
947 return false;
948 }
949#endif
951 {
952 return false;
953 }
955 {
956 if(!is_tf32_supported())
957 {
958 return false;
959 }
961 {
962 if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
963 {
964 std::cout << "ComputeDataType for A and B should be same while using TF32"
965 << std::endl;
966 }
967 return false;
968 }
969 }
970 if constexpr(NDimSpatial == 1)
971 {
973 {
974 return false;
975 }
976 }
977 else if constexpr(NDimSpatial == 2)
978 {
981 {
982 return false;
983 }
984 }
985 else if constexpr(NDimSpatial == 3)
986 {
989 {
990 return false;
991 }
992 }
993 else
994 {
995 return false;
996 }
997
998 if constexpr(ConvBackwardWeightSpecialization ==
1000 {
1001 // check if it's 1x1, stride=1 pad = 0 conv
1002 for(int i = 0; i < NDimSpatial; i++)
1003 {
1004 if(!(arg.filter_spatial_lengths_[i] == 1 && arg.conv_filter_strides_[i] == 1 &&
1005 arg.input_left_pads_[i] == 0 && arg.input_right_pads_[i] == 0))
1006 {
1007 return false;
1008 }
1009 }
1010 }
1011
1012 // vector load A/B matrix from global memory
1013 if(!(ABlockTransferSrcVectorDim == 2 && BBlockTransferSrcVectorDim == 2 &&
1014 arg.Conv_K_ % ABlockTransferSrcScalarPerVector == 0 &&
1015 arg.Conv_C_ % BBlockTransferSrcScalarPerVector == 0))
1016 {
1017 return false;
1018 }
1019
1020 // vector store C matrix into global memory
1021 if(!(arg.Conv_C_ % CBlockTransferScalarPerVector_NWaveNPerXdl == 0 &&
1023 {
1024 return false;
1025 }
1026
1027 // Gridwise GEMM size
1028 if(get_warp_size() == 64)
1029 {
1030 if constexpr(NXdlPerWave64 > 0)
1031 {
1035 arg.block_2_ctile_map_);
1036 }
1037 }
1038 else
1039 {
1040 if constexpr(NXdlPerWave32 > 0)
1041 {
1045 arg.block_2_ctile_map_);
1046 }
1047 }
1048 return false;
1049 }
1050
1051 bool IsSupportedArgument(const BaseArgument* p_arg) override
1052 {
1053 return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
1054 }
1055
1056 static auto MakeArgument(
1057 const InDataType* p_in_grid,
1058 WeiDataType* p_wei_grid,
1059 const OutDataType* p_out_grid,
1060 const std::array<const void*, NumDTensor>& p_ds,
1061 const std::array<index_t, NDimSpatial + 3>& b_g_n_c_wis_lengths, // input
1062 const std::array<index_t, NDimSpatial + 3>& b_g_n_c_wis_strides,
1063 const std::array<index_t, NDimSpatial + 3>& e_g_k_c_xs_lengths, // weight
1064 const std::array<index_t, NDimSpatial + 3>& e_g_k_c_xs_strides,
1065 const std::array<index_t, NDimSpatial + 3>& a_g_n_k_wos_lengths, // output
1066 const std::array<index_t, NDimSpatial + 3>& a_g_n_k_wos_strides,
1067 const std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor>& ds_g_k_c_xs_lengths,
1068 const std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor>& ds_g_k_c_xs_strides,
1069 const std::array<ck::index_t, NDimSpatial>& conv_filter_strides,
1070 const std::array<ck::index_t, NDimSpatial>& conv_filter_dilations,
1071 const std::array<ck::index_t, NDimSpatial>& input_left_pads,
1072 const std::array<ck::index_t, NDimSpatial>& input_right_pads,
1073 InElementwiseOperation in_element_op,
1074 WeiElementwiseOperation wei_element_op,
1075 OutElementwiseOperation out_element_op,
1076 const ck::index_t split_k)
1077 {
1078 return Argument{p_in_grid,
1079 p_wei_grid,
1080 p_out_grid,
1081 p_ds,
1082 b_g_n_c_wis_lengths, // input
1083 b_g_n_c_wis_strides,
1084 e_g_k_c_xs_lengths, // weight
1085 e_g_k_c_xs_strides,
1086 a_g_n_k_wos_lengths, // output
1087 a_g_n_k_wos_strides,
1088 ds_g_k_c_xs_lengths,
1089 ds_g_k_c_xs_strides,
1090 conv_filter_strides,
1091 conv_filter_dilations,
1092 input_left_pads,
1093 input_right_pads,
1094 1,
1095 1,
1096 in_element_op,
1097 wei_element_op,
1098 out_element_op,
1099 split_k};
1100 }
1101
1102 static auto MakeInvoker() { return Invoker{}; }
1103
1104 std::unique_ptr<BaseArgument> MakeArgumentPointer(
1105 const void* p_in_grid,
1106 void* p_wei_grid,
1107 const void* p_out_grid,
1108 const std::array<const void*, NumDTensor>& p_ds,
1109 const std::array<index_t, NDimSpatial + 3>& b_g_n_c_wis_lengths, // input
1110 const std::array<index_t, NDimSpatial + 3>& b_g_n_c_wis_strides,
1111 const std::array<index_t, NDimSpatial + 3>& e_g_k_c_xs_lengths, // weight
1112 const std::array<index_t, NDimSpatial + 3>& e_g_k_c_xs_strides,
1113 const std::array<index_t, NDimSpatial + 3>& a_g_n_k_wos_lengths, // output
1114 const std::array<index_t, NDimSpatial + 3>& a_g_n_k_wos_strides,
1115 const std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor>& ds_g_k_c_xs_lengths,
1116 const std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor>& ds_g_k_c_xs_strides,
1117 const std::array<ck::index_t, NDimSpatial>& conv_filter_strides,
1118 const std::array<ck::index_t, NDimSpatial>& conv_filter_dilations,
1119 const std::array<ck::index_t, NDimSpatial>& input_left_pads,
1120 const std::array<ck::index_t, NDimSpatial>& input_right_pads,
1121 InElementwiseOperation in_element_op,
1122 WeiElementwiseOperation wei_element_op,
1123 OutElementwiseOperation out_element_op,
1124 const ck::index_t split_k) override
1125 {
1126 return std::make_unique<Argument>(static_cast<const InDataType*>(p_in_grid),
1127 static_cast<WeiDataType*>(p_wei_grid),
1128 static_cast<const OutDataType*>(p_out_grid),
1129 p_ds,
1130 b_g_n_c_wis_lengths, // input
1131 b_g_n_c_wis_strides,
1132 e_g_k_c_xs_lengths, // weight
1133 e_g_k_c_xs_strides,
1134 a_g_n_k_wos_lengths, // output
1135 a_g_n_k_wos_strides,
1136 ds_g_k_c_xs_lengths,
1137 ds_g_k_c_xs_strides,
1138 conv_filter_strides,
1139 conv_filter_dilations,
1140 input_left_pads,
1141 input_right_pads,
1142 1,
1143 1,
1144 in_element_op,
1145 wei_element_op,
1146 out_element_op,
1147 split_k);
1148 }
1149
1150 std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
1151 {
1152 return std::make_unique<Invoker>(Invoker{});
1153 }
1154
1155 std::string GetTypeString() const override
1156 {
1157 auto str = std::stringstream();
1158
1159 // clang-format off
1160 str << "DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle"
1161 << "<"
1162 << BlockSize << ", "
1163 << MPerBlock << ", "
1164 << NPerBlock << ", "
1165 << K0PerBlock << ", "
1166 << getConvBackwardWeightSpecializationString(ConvBackwardWeightSpecialization) << ", "
1167 << K1 << ", "
1168 << MXdlPerWave << ", "
1169 << NXdlPerWave << ", "
1170 << ABlockTransferSrcScalarPerVector << ", "
1171 << ABlockTransferDstScalarPerVector_K1 << ", "
1172 << BBlockTransferSrcScalarPerVector << ", "
1173 << BBlockTransferDstScalarPerVector_K1 << ", "
1174 << CShuffleMXdlPerWavePerShuffle << ", "
1175 << CShuffleNXdlPerWavePerShuffle << ", "
1176 << CBlockTransferScalarPerVector_NWaveNPerXdl
1177 << ">";
1178 // clang-format on
1179
1180 return str.str();
1181 }
1182
1183 size_t GetWorkSpaceSize(const BaseArgument* p_arg) const override
1184 {
1185 auto arg = dynamic_cast<const Argument*>(p_arg);
1186 if(arg)
1187 {
1188 return arg->GetWorkspaceSizeBytes();
1189 }
1190 else
1191 throw std::runtime_error(
1192 "The argument pointer is not an object of "
1193 "DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle::Argument structure!");
1194 }
1195
1197 void* p_workspace,
1198 const StreamConfig& = StreamConfig{}) const override
1199 {
1200 auto p_arg_ = dynamic_cast<Argument*>(p_arg);
1201 if(p_arg_)
1202 {
1203 p_arg_->p_workspace_ = p_workspace;
1204 }
1205 else
1206 throw std::runtime_error(
1207 "The argument pointer is not an object of "
1208 "DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle::Argument structure!");
1209 }
1210};
1211
1212} // namespace device
1213} // namespace tensor_operation
1214} // namespace ck
#define CK_MIN_BLOCK_PER_CU
Definition ck.hpp:31
#define CK_MAX_THREAD_PER_BLOCK
Definition ck.hpp:30
#define GET_NXDL_PER_WAVE_IMPL
Definition device_base.hpp:81
#define INVOKER_RUN_IMPL
Definition device_base.hpp:94
void hip_check_error(hipError_t x)
Definition host_utility/hip_check_error.hpp:10
float launch_and_time_kernel(const StreamConfig &stream_config, F kernel, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, Args... args)
Definition host_utility/kernel_launch.hpp:14
float launch_and_time_kernel_with_preprocess(const StreamConfig &stream_config, PreProcessFunc preprocess, F kernel, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, Args... args)
Definition host_utility/kernel_launch.hpp:91
__host__ __device__ constexpr T max(T x)
Definition utility/math.hpp:84
__host__ __device__ constexpr T min(T x)
Definition utility/math.hpp:116
Definition convolution_backward_data_specialization.hpp:8
auto get_bwd_weight_gemm_sizes(const std::array< index_t, NDimSpatial+3 > &a_g_n_k_wos_lengths, const std::array< index_t, NDimSpatial+3 > &e_g_k_c_xs_lengths)
Definition split_k_utils.hpp:55
ConvolutionBackwardWeightSpecialization
Definition convolution_backward_weight_specialization.hpp:13
@ Filter1x1Stride1Pad0
Definition convolution_backward_weight_specialization.hpp:15
constexpr bool is_GNWC_GKXC_GNWK()
Definition device_grouped_conv_utils.hpp:23
__global__ void kernel_batched_gemm_xdlops_bwd_weight(const FloatA *__restrict__ p_a_grid, const FloatB *__restrict__ p_b_grid, FloatC *__restrict__ p_c_grid, const AElementwiseOperation a_element_op, const BElementwiseOperation b_element_op, const CElementwiseOperation c_element_op, const index_t batch_count, const AGridDesc_B_K0_M_K1 a_b_k0_m_k1_grid_desc, const BGridDesc_B_K0_N_K1 b_b_k0_n_k1_grid_desc, const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock c_grid_desc_mblock_mperblock_nblock_nperblock, const Block2CTileMap block_2_ctile_map, const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch)
Definition device_grouped_conv_bwd_weight_multiple_d_xdl_cshuffle.hpp:50
constexpr bool is_GNDHWC_GKZYXC_GNDHWK()
Definition device_grouped_conv_utils.hpp:88
constexpr bool is_NHWGC_GKYXC_NHWGK()
Definition device_grouped_conv_utils.hpp:40
ck::index_t get_best_occupancy_k_batch_value(int max_occupancy, ck::index_t grid_size)
Definition split_k_utils.hpp:30
constexpr bool is_NDHWGC_GKZYXC_NDHWGK()
Definition device_grouped_conv_utils.hpp:80
std::string getConvBackwardWeightSpecializationString(const ConvolutionBackwardWeightSpecialization &s)
Definition convolution_backward_weight_specialization.hpp:21
ck::index_t calculate_mn_grid_size(ck::index_t gemmM, ck::index_t gemmN)
Definition split_k_utils.hpp:84
constexpr bool is_GNHWC_GKYXC_GNHWK()
Definition device_grouped_conv_utils.hpp:48
Definition convolution_backward_data_specialization.hpp:7
Definition ck.hpp:268
__global__ void kernel_batched_elementwise(const InGridDescTuple in_grid_desc_tuple, const OutGridDescTuple out_grid_desc_tuple, const InDataTypePointerTuple p_in_global_tuple, const OutDataTypePointerTuple p_out_global_tuple, const Block2TileMap block_2_tile_map, const ElementwiseOperation elementwise_op, const index_t batch_count, const std::array< index_t, NumInputs > input_batch_strides, const std::array< index_t, NumOutputs > output_batch_strides)
Definition gridwise_elementwise_2d.hpp:221
__host__ __device__ constexpr auto concat_tuple(const Tuple< X... > &tx, const Tuple< Y... > &ty)
Definition tuple_helper.hpp:52
__device__ index_t get_grid_size()
Definition get_id.hpp:49
int32_t index_t
Definition ck.hpp:299
__host__ __device__ constexpr auto make_naive_tensor_descriptor(const Tuple< Lengths... > &lengths, const Tuple< Strides... > &strides)
Definition tensor_descriptor_helper.hpp:49
@ AtomicAdd
Definition ck.hpp:279
remove_cv_t< remove_reference_t< T > > remove_cvref_t
Definition type.hpp:297
__host__ __device__ constexpr auto make_right_pad_transform(const LowLength &low_length, const RightPadLength &right_pad, integral_constant< bool, SkipIsValidCheck >=integral_constant< bool, false >{})
Definition multi_index_transform_helper.hpp:37
integral_constant< index_t, N > Number
Definition number.hpp:12
typename tuple_element< I, TTuple >::type tuple_element_t
Definition utility/tuple.hpp:208
constexpr detail::ignore_t ignore
Definition utility/ignore.hpp:20
__device__ uint32_t amd_wave_read_first_lane(uint32_t value)
Definition amd_wave_read_first_lane.hpp:100
__device__ index_t get_block_1d_id()
Definition get_id.hpp:47
__host__ __device__ constexpr Y type_convert(X x)
Definition utility/type_convert.hpp:98
bool is_xdl_wmma_supported()
Definition host_utility/device_prop.hpp:76
__host__ __device__ constexpr auto generate_sequence_v2(F &&f, Number< N >)
Definition sequence_helper.hpp:25
__device__ constexpr index_t get_warp_size()
Definition get_id.hpp:10
bool EnvIsEnabled(EnvVar)
Definition utility/env.hpp:140
bool is_tf32_supported()
Definition host_utility/device_prop.hpp:132
constexpr bool is_same_v
Definition type.hpp:283
__host__ __device__ constexpr auto generate_tuple(F &&f, Number< N >)
Definition tuple_helper.hpp:21
typename remove_reference< T >::type remove_reference_t
Definition type.hpp:292
auto accumulate_n(ForwardIterator first, Size count, T init, BinaryOperation op) -> decltype(std::accumulate(first, std::next(first, count), init, op))
Definition library/utility/numeric.hpp:11
__host__ __device__ constexpr auto make_tuple(Xs &&... xs)
Definition utility/tuple.hpp:211
int64_t long_index_t
Definition ck.hpp:300
__host__ __device__ constexpr auto transform_tensor_descriptor(const OldTensorDescriptor &old_tensor_desc, const NewTransforms &new_transforms, NewLowerDimensionOldVisibleIdss, NewUpperDimensionNewVisibleIdss)
Definition tensor_description/tensor_descriptor.hpp:319
@ v1
Definition gridwise_gemm_pipeline_selector.hpp:19
Definition ck/stream_config.hpp:10
Definition block_to_ctile_map.hpp:261
Definition gridwise_elementwise_2d.hpp:278
Definition gridwise_gemm_xdlops_bwd_weight.hpp:254
Definition utility/sequence.hpp:43
Definition utility/tuple.hpp:117
Definition functional2.hpp:33
Definition tensor_operation/operator_transform/transform_conv_bwd_weight_to_gemm.hpp:24
index_t k_batch_
Definition split_k_arg.hpp:12
Definition device_base.hpp:197
void * p_workspace_
Definition device_base.hpp:204
Definition device_grouped_conv_bwd_weight_multiple_d_xdl_cshuffle.hpp:558
int GetMaxOccupancy()
Definition device_grouped_conv_bwd_weight_multiple_d_xdl_cshuffle.hpp:560
int max_occupancy_
Definition device_grouped_conv_bwd_weight_multiple_d_xdl_cshuffle.hpp:603
ActiveWorkgroupsPerCU()
Definition device_grouped_conv_bwd_weight_multiple_d_xdl_cshuffle.hpp:585
Definition device_grouped_conv_bwd_weight_multiple_d_xdl_cshuffle.hpp:607
std::array< ck::index_t, NDimSpatial > input_spatial_lengths_
Definition device_grouped_conv_bwd_weight_multiple_d_xdl_cshuffle.hpp:777
InElementwiseOperation b_element_op_
Definition device_grouped_conv_bwd_weight_multiple_d_xdl_cshuffle.hpp:769
BGridDesc_K0_N_K1 b_grid_desc_kbatch_k0_n_k1_
Definition device_grouped_conv_bwd_weight_multiple_d_xdl_cshuffle.hpp:755
const index_t Conv_G_
Definition device_grouped_conv_bwd_weight_multiple_d_xdl_cshuffle.hpp:773
std::size_t GetWorkspaceSizeBytes() const
Definition device_grouped_conv_bwd_weight_multiple_d_xdl_cshuffle.hpp:744
CGridDesc_M_N ce_grid_desc_m_n_
Definition device_grouped_conv_bwd_weight_multiple_d_xdl_cshuffle.hpp:756
long_index_t c_space_size_bytes
Definition device_grouped_conv_bwd_weight_multiple_d_xdl_cshuffle.hpp:783
const BDataType * p_b_grid_
Definition device_grouped_conv_bwd_weight_multiple_d_xdl_cshuffle.hpp:750
const std::array< ck::index_t, NDimSpatial > & input_right_pads_
Definition device_grouped_conv_bwd_weight_multiple_d_xdl_cshuffle.hpp:782
index_t N01_
Definition device_grouped_conv_bwd_weight_multiple_d_xdl_cshuffle.hpp:766
ComputePtrOffsetOfStridedBatch< I1, I1, NumDTensor > compute_ptr_offset_of_batch_
Definition device_grouped_conv_bwd_weight_multiple_d_xdl_cshuffle.hpp:763
Argument(const InDataType *p_in_grid, WeiDataType *p_wei_grid, const OutDataType *p_out_grid, const std::array< const void *, NumDTensor > &p_ds, const std::array< index_t, NDimSpatial+3 > &b_g_n_c_wis_lengths, const std::array< index_t, NDimSpatial+3 > &b_g_n_c_wis_strides, const std::array< index_t, NDimSpatial+3 > &e_g_k_c_xs_lengths, const std::array< index_t, NDimSpatial+3 > &e_g_k_c_xs_strides, const std::array< index_t, NDimSpatial+3 > &a_g_n_k_wos_lengths, const std::array< index_t, NDimSpatial+3 > &a_g_n_k_wos_strides, const std::array< std::array< index_t, NDimSpatial+3 >, NumDTensor > &ds_g_k_c_xs_lengths, const std::array< std::array< index_t, NDimSpatial+3 >, NumDTensor > &ds_g_k_c_xs_strides, const std::array< ck::index_t, NDimSpatial > &conv_filter_strides, const std::array< ck::index_t, NDimSpatial > &conv_filter_dilations, const std::array< ck::index_t, NDimSpatial > &input_left_pads, const std::array< ck::index_t, NDimSpatial > &input_right_pads, const ck::index_t M01, const ck::index_t N01, InElementwiseOperation in_element_op, WeiElementwiseOperation wei_element_op, OutElementwiseOperation out_element_op, ck::index_t split_k)
Definition device_grouped_conv_bwd_weight_multiple_d_xdl_cshuffle.hpp:608
std::array< ck::index_t, NDimSpatial > output_spatial_lengths_
Definition device_grouped_conv_bwd_weight_multiple_d_xdl_cshuffle.hpp:779
OutElementwiseOperation a_element_op_
Definition device_grouped_conv_bwd_weight_multiple_d_xdl_cshuffle.hpp:768
const std::array< ck::index_t, NDimSpatial > & input_left_pads_
Definition device_grouped_conv_bwd_weight_multiple_d_xdl_cshuffle.hpp:781
AGridDesc_K0_M_K1 a_grid_desc_kbatch_k0_m_k1_
Definition device_grouped_conv_bwd_weight_multiple_d_xdl_cshuffle.hpp:754
DsGridPointerTuple p_ds_grid_
Definition device_grouped_conv_bwd_weight_multiple_d_xdl_cshuffle.hpp:751
const index_t Conv_K_
Definition device_grouped_conv_bwd_weight_multiple_d_xdl_cshuffle.hpp:775
const ADataType * p_a_grid_
Definition device_grouped_conv_bwd_weight_multiple_d_xdl_cshuffle.hpp:749
EDataType * p_e_grid_
Definition device_grouped_conv_bwd_weight_multiple_d_xdl_cshuffle.hpp:752
index_t M01_
Definition device_grouped_conv_bwd_weight_multiple_d_xdl_cshuffle.hpp:765
DsGridDesc_M_N ds_grid_descs_tuple_
Definition device_grouped_conv_bwd_weight_multiple_d_xdl_cshuffle.hpp:757
std::array< ck::index_t, NDimSpatial > filter_spatial_lengths_
Definition device_grouped_conv_bwd_weight_multiple_d_xdl_cshuffle.hpp:778
const std::array< ck::index_t, NDimSpatial > & conv_filter_strides_
Definition device_grouped_conv_bwd_weight_multiple_d_xdl_cshuffle.hpp:780
Block2CTileMap block_2_ctile_map_
Definition device_grouped_conv_bwd_weight_multiple_d_xdl_cshuffle.hpp:759
const index_t Conv_N_
Definition device_grouped_conv_bwd_weight_multiple_d_xdl_cshuffle.hpp:774
WeiElementwiseOperation cde_element_op_
Definition device_grouped_conv_bwd_weight_multiple_d_xdl_cshuffle.hpp:770
Block2TileMapElementwise elementwise_block_2_ctile_map_
Definition device_grouped_conv_bwd_weight_multiple_d_xdl_cshuffle.hpp:760
const index_t Conv_C_
Definition device_grouped_conv_bwd_weight_multiple_d_xdl_cshuffle.hpp:776
Definition device_grouped_conv_bwd_weight_multiple_d_xdl_cshuffle.hpp:788
INVOKER_RUN_IMPL float Run(const BaseArgument *p_arg, const StreamConfig &stream_config=StreamConfig{}) override
Definition device_grouped_conv_bwd_weight_multiple_d_xdl_cshuffle.hpp:929
float RunImp(const Argument &arg, const StreamConfig &stream_config=StreamConfig{})
Definition device_grouped_conv_bwd_weight_multiple_d_xdl_cshuffle.hpp:810
void ShowInfo(const Argument &arg)
Definition device_grouped_conv_bwd_weight_multiple_d_xdl_cshuffle.hpp:791
DeviceOp::Argument Argument
Definition device_grouped_conv_bwd_weight_multiple_d_xdl_cshuffle.hpp:789
Definition device_grouped_conv_bwd_weight_multiple_d_xdl_cshuffle.hpp:171
InDataType BDataType
Definition device_grouped_conv_bwd_weight_multiple_d_xdl_cshuffle.hpp:178
static GET_NXDL_PER_WAVE_IMPL constexpr auto NXdlPerWave64
Definition device_grouped_conv_bwd_weight_multiple_d_xdl_cshuffle.hpp:174
decltype(concat_tuple(Tuple< const AccDataType * >{}, DsGridPointerTuple{})) CDDataTypes
Definition device_grouped_conv_bwd_weight_multiple_d_xdl_cshuffle.hpp:524
static constexpr auto BBlockLdsN1PerBlock
Definition device_grouped_conv_bwd_weight_multiple_d_xdl_cshuffle.hpp:223
static constexpr auto conv_to_gemm_transformer
Definition device_grouped_conv_bwd_weight_multiple_d_xdl_cshuffle.hpp:199
decltype(MakeDsGridDescriptor_M_N< NDimSpatial >({}, {})) DsGridDesc_M_N
Definition device_grouped_conv_bwd_weight_multiple_d_xdl_cshuffle.hpp:521
static constexpr index_t WorkspaceInOutScalarPerVector
Definition device_grouped_conv_bwd_weight_multiple_d_xdl_cshuffle.hpp:208
static constexpr auto NXdlPerWave32
Definition device_grouped_conv_bwd_weight_multiple_d_xdl_cshuffle.hpp:175
static constexpr auto I5
Definition device_grouped_conv_bwd_weight_multiple_d_xdl_cshuffle.hpp:195
static void InitElementwiseBatchStrides(const ComputePtrOffsetOfBatch &compute_ptr_offset_of_batch_, std::array< index_t, NumDTensor+I1 > &input_batch_strides, std::array< index_t, I1 > &output_batch_strides)
Definition device_grouped_conv_bwd_weight_multiple_d_xdl_cshuffle.hpp:508
static constexpr index_t ClusterLengthNPerBlock
Definition device_grouped_conv_bwd_weight_multiple_d_xdl_cshuffle.hpp:528
static constexpr auto I0
Definition device_grouped_conv_bwd_weight_multiple_d_xdl_cshuffle.hpp:190
bool IsSupportedArgument(const BaseArgument *p_arg) override
Definition device_grouped_conv_bwd_weight_multiple_d_xdl_cshuffle.hpp:1051
static auto MakeArgument(const InDataType *p_in_grid, WeiDataType *p_wei_grid, const OutDataType *p_out_grid, const std::array< const void *, NumDTensor > &p_ds, const std::array< index_t, NDimSpatial+3 > &b_g_n_c_wis_lengths, const std::array< index_t, NDimSpatial+3 > &b_g_n_c_wis_strides, const std::array< index_t, NDimSpatial+3 > &e_g_k_c_xs_lengths, const std::array< index_t, NDimSpatial+3 > &e_g_k_c_xs_strides, const std::array< index_t, NDimSpatial+3 > &a_g_n_k_wos_lengths, const std::array< index_t, NDimSpatial+3 > &a_g_n_k_wos_strides, const std::array< std::array< index_t, NDimSpatial+3 >, NumDTensor > &ds_g_k_c_xs_lengths, const std::array< std::array< index_t, NDimSpatial+3 >, NumDTensor > &ds_g_k_c_xs_strides, const std::array< ck::index_t, NDimSpatial > &conv_filter_strides, const std::array< ck::index_t, NDimSpatial > &conv_filter_dilations, const std::array< ck::index_t, NDimSpatial > &input_left_pads, const std::array< ck::index_t, NDimSpatial > &input_right_pads, InElementwiseOperation in_element_op, WeiElementwiseOperation wei_element_op, OutElementwiseOperation out_element_op, const ck::index_t split_k)
Definition device_grouped_conv_bwd_weight_multiple_d_xdl_cshuffle.hpp:1056
static auto MakeDsGridDescriptor_M_N(const std::array< std::array< index_t, NDim+3 >, NumDTensor > &ds_g_k_c_xs_lengths, const std::array< std::array< index_t, NDim+3 >, NumDTensor > &ds_g_k_c_xs_strides)
Definition device_grouped_conv_bwd_weight_multiple_d_xdl_cshuffle.hpp:383
static constexpr auto K1Number
Definition device_grouped_conv_bwd_weight_multiple_d_xdl_cshuffle.hpp:197
std::unique_ptr< BaseArgument > MakeArgumentPointer(const void *p_in_grid, void *p_wei_grid, const void *p_out_grid, const std::array< const void *, NumDTensor > &p_ds, const std::array< index_t, NDimSpatial+3 > &b_g_n_c_wis_lengths, const std::array< index_t, NDimSpatial+3 > &b_g_n_c_wis_strides, const std::array< index_t, NDimSpatial+3 > &e_g_k_c_xs_lengths, const std::array< index_t, NDimSpatial+3 > &e_g_k_c_xs_strides, const std::array< index_t, NDimSpatial+3 > &a_g_n_k_wos_lengths, const std::array< index_t, NDimSpatial+3 > &a_g_n_k_wos_strides, const std::array< std::array< index_t, NDimSpatial+3 >, NumDTensor > &ds_g_k_c_xs_lengths, const std::array< std::array< index_t, NDimSpatial+3 >, NumDTensor > &ds_g_k_c_xs_strides, const std::array< ck::index_t, NDimSpatial > &conv_filter_strides, const std::array< ck::index_t, NDimSpatial > &conv_filter_dilations, const std::array< ck::index_t, NDimSpatial > &input_left_pads, const std::array< ck::index_t, NDimSpatial > &input_right_pads, InElementwiseOperation in_element_op, WeiElementwiseOperation wei_element_op, OutElementwiseOperation out_element_op, const ck::index_t split_k) override
Definition device_grouped_conv_bwd_weight_multiple_d_xdl_cshuffle.hpp:1104
static constexpr auto ABlockLdsM0PerBlock
Definition device_grouped_conv_bwd_weight_multiple_d_xdl_cshuffle.hpp:219
static constexpr auto ElePerBank
Definition device_grouped_conv_bwd_weight_multiple_d_xdl_cshuffle.hpp:215
remove_cvref_t< decltype(ABCGridDescs{}[I0])> AGridDesc_K0_M_K1
Definition device_grouped_conv_bwd_weight_multiple_d_xdl_cshuffle.hpp:304
static auto MakeInvoker()
Definition device_grouped_conv_bwd_weight_multiple_d_xdl_cshuffle.hpp:1102
BlockToCTileMap_M00_N0_M01Adapt< MPerBlock, NPerBlock > Block2TileMapElementwise
Definition device_grouped_conv_bwd_weight_multiple_d_xdl_cshuffle.hpp:530
decltype(GetDsGridPointerTuple()) DsGridPointerTuple
Definition device_grouped_conv_bwd_weight_multiple_d_xdl_cshuffle.hpp:523
std::string GetTypeString() const override
Definition device_grouped_conv_bwd_weight_multiple_d_xdl_cshuffle.hpp:1155
remove_cvref_t< decltype(ABCGridDescs{}[I1])> BGridDesc_K0_N_K1
Definition device_grouped_conv_bwd_weight_multiple_d_xdl_cshuffle.hpp:305
decltype(concat_tuple(Tuple< CGridDesc_M_N >{}, DsGridDesc_M_N{})) CDGridDesc_M_N
Definition device_grouped_conv_bwd_weight_multiple_d_xdl_cshuffle.hpp:522
static constexpr auto ABlockLdsM1Padding
Definition device_grouped_conv_bwd_weight_multiple_d_xdl_cshuffle.hpp:220
static constexpr bool IsValidCompilationParameter()
Definition device_grouped_conv_bwd_weight_multiple_d_xdl_cshuffle.hpp:936
static constexpr auto ABlockLdsM1PerBlock
Definition device_grouped_conv_bwd_weight_multiple_d_xdl_cshuffle.hpp:218
static constexpr auto I2
Definition device_grouped_conv_bwd_weight_multiple_d_xdl_cshuffle.hpp:192
InDataType ABDataType
Definition device_grouped_conv_bwd_weight_multiple_d_xdl_cshuffle.hpp:188
GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight< BlockSize, ADataType, BDataType, AccDataType, AccDataType, InMemoryDataOperationEnum::AtomicAdd, AGridDesc_K0_M_K1, BGridDesc_K0_N_K1, CGridDesc_M_N, AElementwiseOperation, BElementwiseOperation, element_wise::PassThrough, MPerBlock, NPerBlock, K0PerBlock, MPerXDL, NPerXDL, K1, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_K0_M_K1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_K1, false, ABlockLdsAddExtraM, ABlockLdsM1PerBlock, ABlockLdsM0PerBlock, ABlockLdsM1Padding, BBlockTransferThreadClusterLengths_K0_N_K1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_K1, false, BBlockLdsAddExtraN, BBlockLdsN1PerBlock, BBlockLdsN0PerBlock, BBlockLdsN1Padding, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, WorkspaceInOutScalarPerVector, CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, true, true, 1, PipelineVersion::v1, ComputeTypeA, ComputeTypeB > GridwiseGemmBase
Definition device_grouped_conv_bwd_weight_multiple_d_xdl_cshuffle.hpp:309
static constexpr index_t MaxScalarPerVectorFP32
Definition device_grouped_conv_bwd_weight_multiple_d_xdl_cshuffle.hpp:207
static constexpr auto MakeElementwiseInputSequence()
Definition device_grouped_conv_bwd_weight_multiple_d_xdl_cshuffle.hpp:365
static constexpr auto I1
Definition device_grouped_conv_bwd_weight_multiple_d_xdl_cshuffle.hpp:191
static constexpr auto I4
Definition device_grouped_conv_bwd_weight_multiple_d_xdl_cshuffle.hpp:194
std::unique_ptr< BaseInvoker > MakeInvokerPointer() override
Definition device_grouped_conv_bwd_weight_multiple_d_xdl_cshuffle.hpp:1150
static constexpr index_t NumDTensor
Definition device_grouped_conv_bwd_weight_multiple_d_xdl_cshuffle.hpp:181
OutElementwiseOperation AElementwiseOperation
Definition device_grouped_conv_bwd_weight_multiple_d_xdl_cshuffle.hpp:183
decltype(GridwiseGemm64::MakeCBlockClusterAdaptor(CGridDesc_M_N{}, 1, 1, 1)) Block2CTileMap
Definition device_grouped_conv_bwd_weight_multiple_d_xdl_cshuffle.hpp:554
size_t GetWorkSpaceSize(const BaseArgument *p_arg) const override
Definition device_grouped_conv_bwd_weight_multiple_d_xdl_cshuffle.hpp:1183
static constexpr auto GetDsGridPointerTuple()
Definition device_grouped_conv_bwd_weight_multiple_d_xdl_cshuffle.hpp:372
static constexpr index_t ClusterLengthMPerBlock
Definition device_grouped_conv_bwd_weight_multiple_d_xdl_cshuffle.hpp:526
decltype(GridwiseGemm64::MakeCGridDesc_MBlock_MPerBlock_NBlock_NPerBlock(CGridDesc_M_N{})) CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
Definition device_grouped_conv_bwd_weight_multiple_d_xdl_cshuffle.hpp:551
WeiElementwiseOperation CDEElementwiseOperation
Definition device_grouped_conv_bwd_weight_multiple_d_xdl_cshuffle.hpp:185
DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle DeviceOp
Definition device_grouped_conv_bwd_weight_multiple_d_xdl_cshuffle.hpp:172
WeiDataType EDataType
Definition device_grouped_conv_bwd_weight_multiple_d_xdl_cshuffle.hpp:179
OutDataType ADataType
Definition device_grouped_conv_bwd_weight_multiple_d_xdl_cshuffle.hpp:177
static auto GetABCGridDesc()
Definition device_grouped_conv_bwd_weight_multiple_d_xdl_cshuffle.hpp:228
InElementwiseOperation BElementwiseOperation
Definition device_grouped_conv_bwd_weight_multiple_d_xdl_cshuffle.hpp:184
static bool IsSupportedArgument(const Argument &arg)
Definition device_grouped_conv_bwd_weight_multiple_d_xdl_cshuffle.hpp:942
static constexpr auto BBlockLdsN1Padding
Definition device_grouped_conv_bwd_weight_multiple_d_xdl_cshuffle.hpp:225
GridwiseElementwise< CDGridDesc_M_N, Tuple< EGridDesc_M_N >, CDDataTypes, Tuple< EDataType * >, Block2TileMapElementwise, CDEElementwiseOperation, BlockSize, MPerBlock, NPerBlock, MPerBlock/ClusterLengthMPerBlock, NPerBlock/ClusterLengthNPerBlock, Sequence< 0, 1 >, decltype(MakeElementwiseInputSequence()), Sequence< CBlockTransferScalarPerVector_NWaveNPerXdl >, I1, I1 > GridwiseElementwise
Definition device_grouped_conv_bwd_weight_multiple_d_xdl_cshuffle.hpp:532
GridwiseGemmBase< NXdlPerWave32 > GridwiseGemm32
Definition device_grouped_conv_bwd_weight_multiple_d_xdl_cshuffle.hpp:363
remove_cvref_t< decltype(ABCGridDescs{}[I2])> CGridDesc_M_N
Definition device_grouped_conv_bwd_weight_multiple_d_xdl_cshuffle.hpp:306
static constexpr auto BBlockLdsN0PerBlock
Definition device_grouped_conv_bwd_weight_multiple_d_xdl_cshuffle.hpp:224
decltype(GetABCGridDesc< NDimSpatial >()) ABCGridDescs
Definition device_grouped_conv_bwd_weight_multiple_d_xdl_cshuffle.hpp:302
CGridDesc_M_N EGridDesc_M_N
Definition device_grouped_conv_bwd_weight_multiple_d_xdl_cshuffle.hpp:525
GridwiseGemmBase< math::max(NXdlPerWave64, 1)> GridwiseGemm64
Definition device_grouped_conv_bwd_weight_multiple_d_xdl_cshuffle.hpp:362
void SetWorkSpacePointer(BaseArgument *p_arg, void *p_workspace, const StreamConfig &=StreamConfig{}) const override
Definition device_grouped_conv_bwd_weight_multiple_d_xdl_cshuffle.hpp:1196
static constexpr auto BankLength
Definition device_grouped_conv_bwd_weight_multiple_d_xdl_cshuffle.hpp:214
static constexpr auto I3
Definition device_grouped_conv_bwd_weight_multiple_d_xdl_cshuffle.hpp:193
Definition device_grouped_conv_bwd_weight_multiple_d.hpp:31
Definition tensor_operation/gpu/element/unary_element_wise_operation.hpp:340
#define CK_ENV(name)
Definition utility/env.hpp:129