27#ifdef CK_EXPERIMENTAL_BUILDER
28#include "ck_tile/builder/reflect/instance_traits_device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp"
62template <
typename GridwiseGemm,
66 typename AElementwiseOperation,
67 typename BElementwiseOperation,
68 typename CDEElementwiseOperation,
69 typename AGridDesc_K0_M0_M1_K1,
70 typename BGridDesc_K0_N0_N1_K1,
71 typename DsGridDesc_M0_M10_M11_N0_N10_N11,
72 typename CGridDesc_M0_M10_M11_N0_N10_N11,
73 typename Block2CTileMap,
74 typename ComputePtrOffsetOfBatch,
75 bool HasMainKBlockLoop,
76 bool HasDoubleTailKBlockLoop>
78#if CK_USE_LAUNCH_BOUNDS
81 kernel_grouped_conv_fwd_dl_multiple_d(
82 const ABDataType* __restrict__ p_a_grid,
83 const ABDataType* __restrict__ p_b_grid,
85 EDataType* __restrict__ p_e_grid,
86 const AElementwiseOperation a_element_op,
87 const BElementwiseOperation b_element_op,
88 const CDEElementwiseOperation cde_element_op,
90 const AGridDesc_K0_M0_M1_K1 a_grid_desc_k0_m0_m1_k1,
91 const BGridDesc_K0_N0_N1_K1 b_grid_desc_k0_n0_n1_k1,
92 const DsGridDesc_M0_M10_M11_N0_N10_N11 ds_grid_desc_m0_m10_m11_n0_n10_n11,
93 const CGridDesc_M0_M10_M11_N0_N10_N11 e_grid_desc_m0_m10_m11_n0_n10_n11,
94 const Block2CTileMap block_2_ctile_map,
95 const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch)
97#if(defined(__gfx906__) || defined(__gfx103__) || defined(__gfx90a__) || defined(__gfx908__) || \
98 defined(__gfx94__) || defined(__gfx11__) || defined(__gfx12__))
100 const index_t num_blocks_per_batch =
101 __builtin_amdgcn_readfirstlane(
get_grid_size() / batch_count);
105 static_cast<long_index_t>(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx)));
107 static_cast<long_index_t>(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx)));
109 static_cast<long_index_t>(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx)));
111 const auto ds_batch_offset = compute_ptr_offset_of_batch.GetDsPtrOffset(g_idx);
113 constexpr index_t shared_block_size =
114 GridwiseGemm::GetSharedMemoryNumberOfByte() /
sizeof(ABDataType);
116 __shared__ ABDataType p_shared[shared_block_size];
118 DsPointer p_ds_grid_grp;
120 static constexpr index_t NumDTensor = DsGridDesc_M0_M10_M11_N0_N10_N11::Size();
122 static_for<0, NumDTensor, 1>{}(
123 [&](
auto i) { p_ds_grid_grp(i) = p_ds_grid[i] + ds_batch_offset[i]; });
125 GridwiseGemm::Run(p_a_grid + a_batch_offset,
126 p_b_grid + b_batch_offset,
128 p_e_grid + c_batch_offset,
133 a_grid_desc_k0_m0_m1_k1,
134 b_grid_desc_k0_n0_n1_k1,
135 ds_grid_desc_m0_m10_m11_n0_n10_n11,
136 e_grid_desc_m0_m10_m11_n0_n10_n11,
138 integral_constant<bool, HasMainKBlockLoop>{},
139 integral_constant<bool, HasDoubleTailKBlockLoop>{});
149 ignore = a_grid_desc_k0_m0_m1_k1;
150 ignore = b_grid_desc_k0_n0_n1_k1;
151 ignore = ds_grid_desc_m0_m10_m11_n0_n10_n11;
152 ignore = e_grid_desc_m0_m10_m11_n0_n10_n11;
153 ignore = compute_ptr_offset_of_batch;
154 ignore = block_2_ctile_map;
156 compute_ptr_offset_of_batch.GetAPtrOffset(0);
157 compute_ptr_offset_of_batch.GetBPtrOffset(0);
158 compute_ptr_offset_of_batch.GetEPtrOffset(0);
184 typename AccDataType,
189 typename AElementwiseOperation,
190 typename BElementwiseOperation,
191 typename CDEElementwiseOperation,
202 typename M1N1ThreadClusterM1Xs,
203 typename M1N1ThreadClusterN1Xs,
204 typename ABlockTransferThreadSliceLengths_K0_M0_M1_K1,
205 typename ABlockTransferThreadClusterLengths_K0_M0_M1_K1,
206 typename ABlockTransferThreadClusterArrangeOrder,
207 typename ABlockTransferSrcAccessOrder,
208 typename ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1,
209 typename ABlockTransferSrcVectorTensorContiguousDimOrder,
210 typename ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1,
211 typename BBlockTransferThreadSliceLengths_K0_N0_N1_K1,
212 typename BBlockTransferThreadClusterLengths_K0_N0_N1_K1,
213 typename BBlockTransferThreadClusterArrangeOrder,
214 typename BBlockTransferSrcAccessOrder,
215 typename BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1,
216 typename BBlockTransferSrcVectorTensorContiguousDimOrder,
217 typename BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1,
218 typename CThreadTransferSrcDstAccessOrder,
219 index_t CThreadTransferSrcDstVectorDim,
220 index_t CThreadTransferDstScalarPerVector>
231 AElementwiseOperation,
232 BElementwiseOperation,
233 CDEElementwiseOperation>
249 template <
typename ALay>
253 const auto in_gemmmraw_gemmkraw_desc =
254 conv_to_gemm_transformer.template MakeADescriptor_M_K<ALay>();
256 const auto in_gemmm_gemmk_desc =
259 const auto M = in_gemmm_gemmk_desc.GetLength(
I0);
260 const auto K = in_gemmm_gemmk_desc.GetLength(
I1);
261 const auto AK0 = K / K1;
270 template <
typename BLay>
274 const auto wei_gemmnraw_gemmkraw_desc =
275 conv_to_gemm_transformer.template MakeBDescriptor_N_K<BLay>();
277 const auto wei_gemmn_gemmk_desc =
278 matrix_padder.PadBDescriptor_N_K(wei_gemmnraw_gemmkraw_desc);
280 const auto N = wei_gemmn_gemmk_desc.GetLength(
I0);
281 const auto K = wei_gemmn_gemmk_desc.GetLength(
I1);
283 const auto BK0 = K / K1;
286 wei_gemmn_gemmk_desc,
292 template <
typename ELay>
295 const auto out_gemmmraw_gemmnraw_desc =
296 conv_to_gemm_transformer.template MakeCDescriptor_M_N<ELay>();
298 const auto out_gemmm_gemmn_desc =
299 matrix_padder.PadCDescriptor_M_N(out_gemmmraw_gemmnraw_desc);
301 return out_gemmm_gemmn_desc;
333 AElementwiseOperation,
334 BElementwiseOperation,
335 CDEElementwiseOperation,
347 M1N1ThreadClusterM1Xs,
348 M1N1ThreadClusterN1Xs,
349 ABlockTransferThreadSliceLengths_K0_M0_M1_K1,
350 ABlockTransferThreadClusterLengths_K0_M0_M1_K1,
351 ABlockTransferThreadClusterArrangeOrder,
352 ABlockTransferSrcAccessOrder,
353 ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1,
354 ABlockTransferSrcVectorTensorContiguousDimOrder,
355 ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1,
356 BBlockTransferThreadSliceLengths_K0_N0_N1_K1,
357 BBlockTransferThreadClusterLengths_K0_N0_N1_K1,
358 BBlockTransferThreadClusterArrangeOrder,
359 BBlockTransferSrcAccessOrder,
360 BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1,
361 BBlockTransferSrcVectorTensorContiguousDimOrder,
362 BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1,
363 CThreadTransferSrcDstAccessOrder,
364 CThreadTransferSrcDstVectorDim,
365 CThreadTransferDstScalarPerVector>;
383 const std::array<const void*, NumDTensor>& p_ds,
385 const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths,
386 const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
387 const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
388 const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_strides,
389 const std::array<std::array<index_t, NDimSpatial + 3>,
NumDTensor>&
390 ds_g_n_k_wos_lengths,
391 const std::array<std::array<index_t, NDimSpatial + 3>,
NumDTensor>&
392 ds_g_n_k_wos_strides,
393 const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths,
394 const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_strides,
395 const std::array<index_t, NDimSpatial>& conv_filter_strides,
396 const std::array<index_t, NDimSpatial>& conv_filter_dilations,
397 const std::array<index_t, NDimSpatial>& input_left_pads,
398 const std::array<index_t, NDimSpatial>& input_right_pads,
399 const AElementwiseOperation& a_element_op,
400 const BElementwiseOperation& b_element_op,
401 const CDEElementwiseOperation& cde_element_op)
402 :
p_a_grid_{static_cast<const ADataType*>(p_a)},
403 p_b_grid_{static_cast<const BDataType*>(p_b)},
414 conv_filter_dilations,
457 ds_g_n_k_wos_lengths[i],
458 ds_g_n_k_wos_strides[i],
460 conv_filter_dilations,
465 p_ds_grid_(i) =
static_cast<const DDataType*
>(p_ds[i]);
499 std::cout <<
"num_group: " <<
num_group_ << std::endl;
571 throw std::runtime_error(
572 "wrong! DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK has invalid setting");
580 auto launch_kernel = [&](
auto has_main_k_block_loop,
581 auto has_double_tail_k_block_loop) {
582 constexpr bool has_main_loop = has_main_k_block_loop.value;
583 constexpr bool has_double_loop = has_double_tail_k_block_loop;
585 const auto kernel = kernel_grouped_conv_fwd_dl_multiple_d<
590 AElementwiseOperation,
591 BElementwiseOperation,
592 CDEElementwiseOperation,
598 ComputePtrOffsetOfStridedBatch<I1, I1, NumDTensor>,
625 const bool has_double_tail_k_block_loop =
628 if(has_main_k_block_loop && has_double_tail_k_block_loop)
633 else if(has_main_k_block_loop && !has_double_tail_k_block_loop)
638 else if(!has_main_k_block_loop && has_double_tail_k_block_loop)
654 return Run(*
dynamic_cast<const Argument*
>(p_arg), stream_config);
670 if constexpr(ConvForwardSpecialization ==
674 for(
index_t i = 0; i < NDimSpatial; ++i)
683 std::cout <<
"Filter1x1Stride1Pad0 check: XY_index = " << i <<
" X = " << X
684 <<
" ConvStride = " << ConvStride <<
" LeftPad = " <<
LeftPad
685 <<
" RightPad = " <<
RightPad << std::endl;
690 else if constexpr(ConvForwardSpecialization ==
694 for(
index_t i = 0; i < NDimSpatial; ++i)
702 std::cout <<
"Filter1x1Stride1Pad0 check: XY_index = " << i <<
" X = " << X
718 auto srcVectorLengths = ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1{};
719 if(srcVectorLengths[
I1] != 1 || srcVectorLengths[
I2] != 1)
723 if(K1 % srcVectorLengths[
I3] != 0 || K0PerBlock % srcVectorLengths[
I0] != 0)
730 if(C % (srcVectorLengths[
I0] * srcVectorLengths[
I3]) != 0)
749 auto srcVectorLengths = BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1{};
750 if(srcVectorLengths[
I1] != 1 || srcVectorLengths[
I2] != 1)
754 if(K1 % srcVectorLengths[
I3] != 0 || K0PerBlock % srcVectorLengths[
I0] != 0)
761 if(C % (srcVectorLengths[
I0] * srcVectorLengths[
I3]) != 0)
780 if(!(K % CThreadTransferDstScalarPerVector == 0 && CThreadTransferSrcDstVectorDim == 5))
803 const std::array<const void*, NumDTensor>& p_ds,
805 const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths,
806 const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
807 const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
808 const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_strides,
809 const std::array<std::array<index_t, NDimSpatial + 3>,
NumDTensor>& ds_g_n_k_wos_lengths,
810 const std::array<std::array<index_t, NDimSpatial + 3>,
NumDTensor>& ds_g_n_k_wos_strides,
811 const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths,
812 const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_strides,
813 const std::array<index_t, NDimSpatial>& conv_filter_strides,
814 const std::array<index_t, NDimSpatial>& conv_filter_dilations,
815 const std::array<index_t, NDimSpatial>& input_left_pads,
816 const std::array<index_t, NDimSpatial>& input_right_pads,
817 const AElementwiseOperation& a_element_op,
818 const BElementwiseOperation& b_element_op,
819 const CDEElementwiseOperation& cde_element_op)
829 ds_g_n_k_wos_lengths,
830 ds_g_n_k_wos_strides,
834 conv_filter_dilations,
845 const std::array<const void*, NumDTensor>& p_ds,
847 const std::array<long_index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths,
848 const std::array<long_index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
849 const std::array<long_index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
850 const std::array<long_index_t, NDimSpatial + 3>& b_g_k_c_xs_strides,
851 const std::array<std::array<long_index_t, NDimSpatial + 3>,
NumDTensor>&
852 ds_g_n_k_wos_lengths,
853 const std::array<std::array<long_index_t, NDimSpatial + 3>,
NumDTensor>&
854 ds_g_n_k_wos_strides,
855 const std::array<long_index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths,
856 const std::array<long_index_t, NDimSpatial + 3>& e_g_n_k_wos_strides,
857 const std::array<long_index_t, NDimSpatial>& conv_filter_strides,
858 const std::array<long_index_t, NDimSpatial>& conv_filter_dilations,
859 const std::array<long_index_t, NDimSpatial>& input_left_pads,
860 const std::array<long_index_t, NDimSpatial>& input_right_pads,
861 const AElementwiseOperation& a_element_op,
862 const BElementwiseOperation& b_element_op,
863 const CDEElementwiseOperation& cde_element_op)
865 std::array<index_t, NDimSpatial + 3> a_g_n_c_wis_lengths_i32;
866 std::array<index_t, NDimSpatial + 3> a_g_n_c_wis_strides_i32;
867 std::array<index_t, NDimSpatial + 3> b_g_k_c_xs_lengths_i32;
868 std::array<index_t, NDimSpatial + 3> b_g_k_c_xs_strides_i32;
869 std::array<std::array<index_t, NDimSpatial + 3>,
NumDTensor> ds_g_n_k_wos_lengths_i32;
870 std::array<std::array<index_t, NDimSpatial + 3>,
NumDTensor> ds_g_n_k_wos_strides_i32;
871 std::array<index_t, NDimSpatial + 3> e_g_n_k_wos_lengths_i32;
872 std::array<index_t, NDimSpatial + 3> e_g_n_k_wos_strides_i32;
873 std::array<index_t, NDimSpatial> conv_filter_strides_i32;
874 std::array<index_t, NDimSpatial> conv_filter_dilations_i32;
875 std::array<index_t, NDimSpatial> input_left_pads_i32;
876 std::array<index_t, NDimSpatial> input_right_pads_i32;
884 array_convert(ds_g_n_k_wos_lengths_i32[d], ds_g_n_k_wos_lengths[d]);
885 array_convert(ds_g_n_k_wos_strides_i32[d], ds_g_n_k_wos_strides[d]);
890 array_convert(conv_filter_dilations_i32, conv_filter_dilations);
898 a_g_n_c_wis_lengths_i32,
899 a_g_n_c_wis_strides_i32,
900 b_g_k_c_xs_lengths_i32,
901 b_g_k_c_xs_strides_i32,
902 ds_g_n_k_wos_lengths_i32,
903 ds_g_n_k_wos_strides_i32,
904 e_g_n_k_wos_lengths_i32,
905 e_g_n_k_wos_strides_i32,
906 conv_filter_strides_i32,
907 conv_filter_dilations_i32,
909 input_right_pads_i32,
920 const std::array<const void*, NumDTensor>& p_ds,
922 const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths,
923 const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
924 const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
925 const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_strides,
926 const std::array<std::array<index_t, NDimSpatial + 3>,
NumDTensor>& ds_g_n_k_wos_lengths,
927 const std::array<std::array<index_t, NDimSpatial + 3>,
NumDTensor>& ds_g_n_k_wos_strides,
928 const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths,
929 const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_strides,
930 const std::array<index_t, NDimSpatial>& conv_filter_strides,
931 const std::array<index_t, NDimSpatial>& conv_filter_dilations,
932 const std::array<index_t, NDimSpatial>& input_left_pads,
933 const std::array<index_t, NDimSpatial>& input_right_pads,
934 const AElementwiseOperation& a_element_op,
935 const BElementwiseOperation& b_element_op,
936 const CDEElementwiseOperation& cde_element_op)
override
938 return std::make_unique<Argument>(p_a,
946 ds_g_n_k_wos_lengths,
947 ds_g_n_k_wos_strides,
951 conv_filter_dilations,
959 std::unique_ptr<BaseArgument>
962 const std::array<const void*, NumDTensor>& p_ds,
964 const std::array<long_index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths,
965 const std::array<long_index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
966 const std::array<long_index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
967 const std::array<long_index_t, NDimSpatial + 3>& b_g_k_c_xs_strides,
968 const std::array<std::array<long_index_t, NDimSpatial + 3>,
NumDTensor>&
969 ds_g_n_k_wos_lengths,
970 const std::array<std::array<long_index_t, NDimSpatial + 3>,
NumDTensor>&
971 ds_g_n_k_wos_strides,
972 const std::array<long_index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths,
973 const std::array<long_index_t, NDimSpatial + 3>& e_g_n_k_wos_strides,
974 const std::array<long_index_t, NDimSpatial>& conv_filter_strides,
975 const std::array<long_index_t, NDimSpatial>& conv_filter_dilations,
976 const std::array<long_index_t, NDimSpatial>& input_left_pads,
977 const std::array<long_index_t, NDimSpatial>& input_right_pads,
978 const AElementwiseOperation& a_element_op,
979 const BElementwiseOperation& b_element_op,
980 const CDEElementwiseOperation& cde_element_op)
override
982 std::array<index_t, NDimSpatial + 3> a_g_n_c_wis_lengths_i32;
983 std::array<index_t, NDimSpatial + 3> a_g_n_c_wis_strides_i32;
984 std::array<index_t, NDimSpatial + 3> b_g_k_c_xs_lengths_i32;
985 std::array<index_t, NDimSpatial + 3> b_g_k_c_xs_strides_i32;
986 std::array<std::array<index_t, NDimSpatial + 3>,
NumDTensor> ds_g_n_k_wos_lengths_i32;
987 std::array<std::array<index_t, NDimSpatial + 3>,
NumDTensor> ds_g_n_k_wos_strides_i32;
988 std::array<index_t, NDimSpatial + 3> e_g_n_k_wos_lengths_i32;
989 std::array<index_t, NDimSpatial + 3> e_g_n_k_wos_strides_i32;
990 std::array<index_t, NDimSpatial> conv_filter_strides_i32;
991 std::array<index_t, NDimSpatial> conv_filter_dilations_i32;
992 std::array<index_t, NDimSpatial> input_left_pads_i32;
993 std::array<index_t, NDimSpatial> input_right_pads_i32;
1001 array_convert(ds_g_n_k_wos_lengths_i32[d], ds_g_n_k_wos_lengths[d]);
1002 array_convert(ds_g_n_k_wos_strides_i32[d], ds_g_n_k_wos_strides[d]);
1004 array_convert(e_g_n_k_wos_lengths_i32, e_g_n_k_wos_lengths);
1005 array_convert(e_g_n_k_wos_strides_i32, e_g_n_k_wos_strides);
1006 array_convert(conv_filter_strides_i32, conv_filter_strides);
1007 array_convert(conv_filter_dilations_i32, conv_filter_dilations);
1011 return std::make_unique<Argument>(p_a,
1015 a_g_n_c_wis_lengths_i32,
1016 a_g_n_c_wis_strides_i32,
1017 b_g_k_c_xs_lengths_i32,
1018 b_g_k_c_xs_strides_i32,
1019 ds_g_n_k_wos_lengths_i32,
1020 ds_g_n_k_wos_strides_i32,
1021 e_g_n_k_wos_lengths_i32,
1022 e_g_n_k_wos_strides_i32,
1023 conv_filter_strides_i32,
1024 conv_filter_dilations_i32,
1025 input_left_pads_i32,
1026 input_right_pads_i32,
1034 return std::make_unique<Invoker>(
Invoker{});
1039 auto str = std::stringstream();
1042 str <<
"DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK"
1044 << BlockSize <<
", "
1045 << MPerBlock <<
", "
1046 << NPerBlock <<
", "
1047 << K0PerBlock <<
", "
1056#ifdef CK_EXPERIMENTAL_BUILDER
1059 static_assert(ck_tile::reflect::HasInstanceTraits<DeviceOp>,
1060 "Specialization of instance_traits not found. Please check that a "
1061 "specialization exists in file "
1062 "ck_tile/builder/reflect/"
1063 "instance_traits_device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp "
1064 "for the given template parameters.");
1065 return ck_tile::reflect::instance_string<DeviceOp>();
#define CK_MIN_BLOCK_PER_CU
Definition ck.hpp:31
#define CK_MAX_THREAD_PER_BLOCK
Definition ck.hpp:30
float launch_and_time_kernel(const StreamConfig &stream_config, F kernel, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, Args... args)
Definition host_utility/kernel_launch.hpp:14
Definition tensor_operation/gpu/device/tensor_layout.hpp:42
Definition convolution_backward_data_specialization.hpp:8
GemmSpecialization
Definition gemm_specialization.hpp:11
ConvolutionForwardSpecialization
Definition convolution_forward_specialization.hpp:15
@ Filter1x1Stride1Pad0
Definition convolution_forward_specialization.hpp:18
@ Filter1x1Pad0
Definition convolution_forward_specialization.hpp:17
std::string getConvForwardSpecializationString(const ConvolutionForwardSpecialization &s)
Definition convolution_forward_specialization.hpp:24
Definition convolution_backward_data_specialization.hpp:7
bool is_xdl_supported()
Definition host_utility/device_prop.hpp:68
__host__ __device__ constexpr auto make_pass_through_transform(const LowLength &low_length)
Definition multi_index_transform_helper.hpp:12
__device__ index_t get_grid_size()
Definition get_id.hpp:49
int32_t index_t
Definition ck.hpp:299
@ Set
Definition ck.hpp:278
remove_cv_t< remove_reference_t< T > > remove_cvref_t
Definition type.hpp:297
integral_constant< index_t, N > Number
Definition number.hpp:12
std::string get_device_name()
Definition host_utility/device_prop.hpp:19
typename tuple_element< I, TTuple >::type tuple_element_t
Definition utility/tuple.hpp:208
constexpr detail::ignore_t ignore
Definition utility/ignore.hpp:20
__device__ uint32_t amd_wave_read_first_lane(uint32_t value)
Definition amd_wave_read_first_lane.hpp:100
__device__ index_t get_block_1d_id()
Definition get_id.hpp:47
bool is_gfx12_supported()
Definition host_utility/device_prop.hpp:55
bool is_gfx103_supported()
Definition host_utility/device_prop.hpp:120
__host__ __device__ void array_convert(std::array< Y, NumElems > &y, const std::array< X, NumElems > &x)
Definition utility/type_convert.hpp:2466
constexpr bool is_same_v
Definition type.hpp:283
__host__ __device__ constexpr auto generate_tuple(F &&f, Number< N >)
Definition tuple_helper.hpp:21
__host__ __device__ constexpr auto make_tuple(Xs &&... xs)
Definition utility/tuple.hpp:211
int64_t long_index_t
Definition ck.hpp:300
__host__ __device__ constexpr auto transform_tensor_descriptor(const OldTensorDescriptor &old_tensor_desc, const NewTransforms &new_transforms, NewLowerDimensionOldVisibleIdss, NewUpperDimensionNewVisibleIdss)
Definition tensor_description/tensor_descriptor.hpp:319
__host__ __device__ constexpr auto make_unmerge_transform(const UpLengths &up_lengths, integral_constant< bool, Use24BitIntegerCalculation >=integral_constant< bool, false >{})
Definition multi_index_transform_helper.hpp:90
bool is_gfx11_supported()
Definition host_utility/device_prop.hpp:60
Definition ck/stream_config.hpp:10
int log_level_
Definition ck/stream_config.hpp:13
Definition gridwise_gemm_dl_multiple_d.hpp:60
ck::GridwiseGemmDlMultipleD_km_kn_mn< BlockSize, ADataType, AccDataType, DsDataType, EDataType, AElementwiseOperation, BElementwiseOperation, CDEElementwiseOperation, InMemoryDataOperationEnum::Set, AGridDesc_AK0_M_AK1, BGridDesc_BK0_N_BK1, EGridDesc_M_N, MPerBlock, NPerBlock, K0PerBlock, K1, M1PerThread, N1PerThread, KPerThread, M1N1ThreadClusterM1Xs, M1N1ThreadClusterN1Xs, ABlockTransferThreadSliceLengths_K0_M0_M1_K1, ABlockTransferThreadClusterLengths_K0_M0_M1_K1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1, ABlockTransferSrcVectorTensorContiguousDimOrder, ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1, BBlockTransferThreadSliceLengths_K0_N0_N1_K1, BBlockTransferThreadClusterLengths_K0_N0_N1_K1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1, BBlockTransferSrcVectorTensorContiguousDimOrder, BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1, CThreadTransferSrcDstAccessOrder, CThreadTransferSrcDstVectorDim, CThreadTransferDstScalarPerVector >::MakeDefaultBlock2CTileMap __host__ static __device__ constexpr auto MakeDefaultBlock2CTileMap(const EGridDesc_M_N &c_grid_desc_m_n)
Definition gridwise_gemm_dl_multiple_d.hpp:242
ck::GridwiseGemmDlMultipleD_km_kn_mn< BlockSize, ADataType, AccDataType, DsDataType, EDataType, AElementwiseOperation, BElementwiseOperation, CDEElementwiseOperation, InMemoryDataOperationEnum::Set, AGridDesc_AK0_M_AK1, BGridDesc_BK0_N_BK1, EGridDesc_M_N, MPerBlock, NPerBlock, K0PerBlock, K1, M1PerThread, N1PerThread, KPerThread, M1N1ThreadClusterM1Xs, M1N1ThreadClusterN1Xs, ABlockTransferThreadSliceLengths_K0_M0_M1_K1, ABlockTransferThreadClusterLengths_K0_M0_M1_K1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1, ABlockTransferSrcVectorTensorContiguousDimOrder, ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1, BBlockTransferThreadSliceLengths_K0_N0_N1_K1, BBlockTransferThreadClusterLengths_K0_N0_N1_K1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1, BBlockTransferSrcVectorTensorContiguousDimOrder, BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1, CThreadTransferSrcDstAccessOrder, CThreadTransferSrcDstVectorDim, CThreadTransferDstScalarPerVector >::MakeBGridDescriptor_K0_N0_N1_K1 __host__ static __device__ constexpr auto MakeBGridDescriptor_K0_N0_N1_K1(const BGridDesc_BK0_N_BK1 &b_grid_desc_k0_n_k1)
Definition gridwise_gemm_dl_multiple_d.hpp:178
ck::GridwiseGemmDlMultipleD_km_kn_mn< BlockSize, ADataType, AccDataType, DsDataType, EDataType, AElementwiseOperation, BElementwiseOperation, CDEElementwiseOperation, InMemoryDataOperationEnum::Set, AGridDesc_AK0_M_AK1, BGridDesc_BK0_N_BK1, EGridDesc_M_N, MPerBlock, NPerBlock, K0PerBlock, K1, M1PerThread, N1PerThread, KPerThread, M1N1ThreadClusterM1Xs, M1N1ThreadClusterN1Xs, ABlockTransferThreadSliceLengths_K0_M0_M1_K1, ABlockTransferThreadClusterLengths_K0_M0_M1_K1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1, ABlockTransferSrcVectorTensorContiguousDimOrder, ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1, BBlockTransferThreadSliceLengths_K0_N0_N1_K1, BBlockTransferThreadClusterLengths_K0_N0_N1_K1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1, BBlockTransferSrcVectorTensorContiguousDimOrder, BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1, CThreadTransferSrcDstAccessOrder, CThreadTransferSrcDstVectorDim, CThreadTransferDstScalarPerVector >::MakeAGridDescriptor_K0_M0_M1_K1 __host__ static __device__ constexpr auto MakeAGridDescriptor_K0_M0_M1_K1(const AGridDesc_AK0_M_AK1 &a_grid_desc_k0_m_k1)
Definition gridwise_gemm_dl_multiple_d.hpp:158
ck::GridwiseGemmDlMultipleD_km_kn_mn< BlockSize, ADataType, AccDataType, DsDataType, EDataType, AElementwiseOperation, BElementwiseOperation, CDEElementwiseOperation, InMemoryDataOperationEnum::Set, AGridDesc_AK0_M_AK1, BGridDesc_BK0_N_BK1, EGridDesc_M_N, MPerBlock, NPerBlock, K0PerBlock, K1, M1PerThread, N1PerThread, KPerThread, M1N1ThreadClusterM1Xs, M1N1ThreadClusterN1Xs, ABlockTransferThreadSliceLengths_K0_M0_M1_K1, ABlockTransferThreadClusterLengths_K0_M0_M1_K1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1, ABlockTransferSrcVectorTensorContiguousDimOrder, ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1, BBlockTransferThreadSliceLengths_K0_N0_N1_K1, BBlockTransferThreadClusterLengths_K0_N0_N1_K1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1, BBlockTransferSrcVectorTensorContiguousDimOrder, BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1, CThreadTransferSrcDstAccessOrder, CThreadTransferSrcDstVectorDim, CThreadTransferDstScalarPerVector >::DsGridPointer decltype(MakeDsGridPointer()) DsGridPointer
Definition gridwise_gemm_dl_multiple_d.hpp:253
ck::GridwiseGemmDlMultipleD_km_kn_mn< BlockSize, ADataType, AccDataType, DsDataType, EDataType, AElementwiseOperation, BElementwiseOperation, CDEElementwiseOperation, InMemoryDataOperationEnum::Set, AGridDesc_AK0_M_AK1, BGridDesc_BK0_N_BK1, EGridDesc_M_N, MPerBlock, NPerBlock, K0PerBlock, K1, M1PerThread, N1PerThread, KPerThread, M1N1ThreadClusterM1Xs, M1N1ThreadClusterN1Xs, ABlockTransferThreadSliceLengths_K0_M0_M1_K1, ABlockTransferThreadClusterLengths_K0_M0_M1_K1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1, ABlockTransferSrcVectorTensorContiguousDimOrder, ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1, BBlockTransferThreadSliceLengths_K0_N0_N1_K1, BBlockTransferThreadClusterLengths_K0_N0_N1_K1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1, BBlockTransferSrcVectorTensorContiguousDimOrder, BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1, CThreadTransferSrcDstAccessOrder, CThreadTransferSrcDstVectorDim, CThreadTransferDstScalarPerVector >::CalculateGridSize __host__ static __device__ constexpr index_t CalculateGridSize(index_t M, index_t N)
Definition gridwise_gemm_dl_multiple_d.hpp:136
ck::GridwiseGemmDlMultipleD_km_kn_mn< BlockSize, ADataType, AccDataType, DsDataType, EDataType, AElementwiseOperation, BElementwiseOperation, CDEElementwiseOperation, InMemoryDataOperationEnum::Set, AGridDesc_AK0_M_AK1, BGridDesc_BK0_N_BK1, EGridDesc_M_N, MPerBlock, NPerBlock, K0PerBlock, K1, M1PerThread, N1PerThread, KPerThread, M1N1ThreadClusterM1Xs, M1N1ThreadClusterN1Xs, ABlockTransferThreadSliceLengths_K0_M0_M1_K1, ABlockTransferThreadClusterLengths_K0_M0_M1_K1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1, ABlockTransferSrcVectorTensorContiguousDimOrder, ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1, BBlockTransferThreadSliceLengths_K0_N0_N1_K1, BBlockTransferThreadClusterLengths_K0_N0_N1_K1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1, BBlockTransferSrcVectorTensorContiguousDimOrder, BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1, CThreadTransferSrcDstAccessOrder, CThreadTransferSrcDstVectorDim, CThreadTransferDstScalarPerVector >::CalculateHasDoubleTailKBlockLoop __host__ static __device__ constexpr bool CalculateHasDoubleTailKBlockLoop(index_t K0)
Definition gridwise_gemm_dl_multiple_d.hpp:150
ck::GridwiseGemmDlMultipleD_km_kn_mn< BlockSize, ADataType, AccDataType, DsDataType, EDataType, AElementwiseOperation, BElementwiseOperation, CDEElementwiseOperation, InMemoryDataOperationEnum::Set, AGridDesc_AK0_M_AK1, BGridDesc_BK0_N_BK1, EGridDesc_M_N, MPerBlock, NPerBlock, K0PerBlock, K1, M1PerThread, N1PerThread, KPerThread, M1N1ThreadClusterM1Xs, M1N1ThreadClusterN1Xs, ABlockTransferThreadSliceLengths_K0_M0_M1_K1, ABlockTransferThreadClusterLengths_K0_M0_M1_K1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1, ABlockTransferSrcVectorTensorContiguousDimOrder, ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1, BBlockTransferThreadSliceLengths_K0_N0_N1_K1, BBlockTransferThreadClusterLengths_K0_N0_N1_K1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1, BBlockTransferSrcVectorTensorContiguousDimOrder, BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1, CThreadTransferSrcDstAccessOrder, CThreadTransferSrcDstVectorDim, CThreadTransferDstScalarPerVector >::CalculateHasMainKBlockLoop __host__ static __device__ constexpr bool CalculateHasMainKBlockLoop(index_t K0)
Definition gridwise_gemm_dl_multiple_d.hpp:143
ck::GridwiseGemmDlMultipleD_km_kn_mn< BlockSize, ADataType, AccDataType, DsDataType, EDataType, AElementwiseOperation, BElementwiseOperation, CDEElementwiseOperation, InMemoryDataOperationEnum::Set, AGridDesc_AK0_M_AK1, BGridDesc_BK0_N_BK1, EGridDesc_M_N, MPerBlock, NPerBlock, K0PerBlock, K1, M1PerThread, N1PerThread, KPerThread, M1N1ThreadClusterM1Xs, M1N1ThreadClusterN1Xs, ABlockTransferThreadSliceLengths_K0_M0_M1_K1, ABlockTransferThreadClusterLengths_K0_M0_M1_K1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1, ABlockTransferSrcVectorTensorContiguousDimOrder, ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1, BBlockTransferThreadSliceLengths_K0_N0_N1_K1, BBlockTransferThreadClusterLengths_K0_N0_N1_K1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1, BBlockTransferSrcVectorTensorContiguousDimOrder, BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1, CThreadTransferSrcDstAccessOrder, CThreadTransferSrcDstVectorDim, CThreadTransferDstScalarPerVector >::CheckValidity __host__ static __device__ constexpr bool CheckValidity(const AGridDesc_AK0_M_AK1 &a_grid_desc_k0_m_k1, const BGridDesc_BK0_N_BK1 &b_grid_desc_k0_n_k1, const EGridDesc_M_N &c_grid_desc_m_n)
Definition gridwise_gemm_dl_multiple_d.hpp:110
ck::GridwiseGemmDlMultipleD_km_kn_mn< BlockSize, ADataType, AccDataType, DsDataType, EDataType, AElementwiseOperation, BElementwiseOperation, CDEElementwiseOperation, InMemoryDataOperationEnum::Set, AGridDesc_AK0_M_AK1, BGridDesc_BK0_N_BK1, EGridDesc_M_N, MPerBlock, NPerBlock, K0PerBlock, K1, M1PerThread, N1PerThread, KPerThread, M1N1ThreadClusterM1Xs, M1N1ThreadClusterN1Xs, ABlockTransferThreadSliceLengths_K0_M0_M1_K1, ABlockTransferThreadClusterLengths_K0_M0_M1_K1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1, ABlockTransferSrcVectorTensorContiguousDimOrder, ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1, BBlockTransferThreadSliceLengths_K0_N0_N1_K1, BBlockTransferThreadClusterLengths_K0_N0_N1_K1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1, BBlockTransferSrcVectorTensorContiguousDimOrder, BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1, CThreadTransferSrcDstAccessOrder, CThreadTransferSrcDstVectorDim, CThreadTransferDstScalarPerVector >::MakeDsGridDescriptor_M0_M10_M11_N0_N10_N11 __host__ static __device__ constexpr auto MakeDsGridDescriptor_M0_M10_M11_N0_N10_N11(const DsGridDesc_M_N &ds_grid_desc_m_n)
Definition gridwise_gemm_dl_multiple_d.hpp:234
ck::GridwiseGemmDlMultipleD_km_kn_mn< BlockSize, ADataType, AccDataType, DsDataType, EDataType, AElementwiseOperation, BElementwiseOperation, CDEElementwiseOperation, InMemoryDataOperationEnum::Set, AGridDesc_AK0_M_AK1, BGridDesc_BK0_N_BK1, EGridDesc_M_N, MPerBlock, NPerBlock, K0PerBlock, K1, M1PerThread, N1PerThread, KPerThread, M1N1ThreadClusterM1Xs, M1N1ThreadClusterN1Xs, ABlockTransferThreadSliceLengths_K0_M0_M1_K1, ABlockTransferThreadClusterLengths_K0_M0_M1_K1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1, ABlockTransferSrcVectorTensorContiguousDimOrder, ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1, BBlockTransferThreadSliceLengths_K0_N0_N1_K1, BBlockTransferThreadClusterLengths_K0_N0_N1_K1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1, BBlockTransferSrcVectorTensorContiguousDimOrder, BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1, CThreadTransferSrcDstAccessOrder, CThreadTransferSrcDstVectorDim, CThreadTransferDstScalarPerVector >::MakeCGridDescriptor_M0_M10_M11_N0_N10_N11 __host__ static __device__ constexpr auto MakeCGridDescriptor_M0_M10_M11_N0_N10_N11(const CGridDesc_M_N_ &c_grid_desc_m_n)
Definition gridwise_gemm_dl_multiple_d.hpp:200
Definition multi_index_transform.hpp:196
Definition multi_index_transform.hpp:284
Definition utility/sequence.hpp:43
Definition utility/integral_constant.hpp:20
Definition functional2.hpp:33
Definition device_base.hpp:197
virtual std::string GetInstanceString() const
Definition device_base.hpp:230
Definition device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp:380
std::array< index_t, NDimSpatial+3 > b_g_k_c_xs_strides_
Definition device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp:545
std::array< index_t, NDimSpatial+3 > a_g_n_c_wis_lengths_
Definition device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp:542
void Print() const
Definition device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp:494
const ADataType * p_a_grid_
Definition device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp:509
ConvToGemmFwdTransformer conv_to_gemm_transformer_
Definition device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp:517
std::array< index_t, NDimSpatial > conv_filter_strides_
Definition device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp:550
EDataType * p_e_grid_
Definition device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp:512
AGridDesc_K0_M0_M1_K1 a_grid_desc_k0_m0_m1_k1_
Definition device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp:525
std::array< index_t, NDimSpatial+3 > a_g_n_c_wis_strides_
Definition device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp:543
std::array< index_t, NDimSpatial+3 > e_g_n_k_wos_strides_
Definition device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp:549
GridwiseGemm::DsGridPointer p_ds_grid_
Definition device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp:511
std::array< index_t, NDimSpatial > conv_filter_dilations_
Definition device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp:551
ComputePtrOffsetOfStridedBatch< I1, I1, NumDTensor > compute_ptr_offset_of_batch_
Definition device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp:534
BGridDesc_K0_N0_N1_K1 b_grid_desc_k0_n0_n1_k1_
Definition device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp:526
AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_
Definition device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp:519
std::array< std::array< index_t, NDimSpatial+3 >, NumDTensor > ds_g_n_k_wos_strides_
Definition device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp:547
const BDataType * p_b_grid_
Definition device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp:510
EGridDesc_M_N e_grid_desc_m_n_
Definition device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp:522
std::array< std::array< index_t, NDimSpatial+3 >, NumDTensor > ds_g_n_k_wos_lengths_
Definition device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp:546
BElementwiseOperation b_element_op_
Definition device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp:538
AElementwiseOperation a_element_op_
Definition device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp:537
DsGridDesc_M0_M10_M11_N0_N10_N11 ds_grid_desc_m0_m10_m11_n0_n10_n11_
Definition device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp:527
Argument(const void *p_a, const void *p_b, const std::array< const void *, NumDTensor > &p_ds, void *p_e, const std::array< index_t, NDimSpatial+3 > &a_g_n_c_wis_lengths, const std::array< index_t, NDimSpatial+3 > &a_g_n_c_wis_strides, const std::array< index_t, NDimSpatial+3 > &b_g_k_c_xs_lengths, const std::array< index_t, NDimSpatial+3 > &b_g_k_c_xs_strides, const std::array< std::array< index_t, NDimSpatial+3 >, NumDTensor > &ds_g_n_k_wos_lengths, const std::array< std::array< index_t, NDimSpatial+3 >, NumDTensor > &ds_g_n_k_wos_strides, const std::array< index_t, NDimSpatial+3 > &e_g_n_k_wos_lengths, const std::array< index_t, NDimSpatial+3 > &e_g_n_k_wos_strides, const std::array< index_t, NDimSpatial > &conv_filter_strides, const std::array< index_t, NDimSpatial > &conv_filter_dilations, const std::array< index_t, NDimSpatial > &input_left_pads, const std::array< index_t, NDimSpatial > &input_right_pads, const AElementwiseOperation &a_element_op, const BElementwiseOperation &b_element_op, const CDEElementwiseOperation &cde_element_op)
Definition device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp:381
DefaultBlock2CTileMap block_2_ctile_map_
Definition device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp:531
DsGridDesc_M_N ds_grid_desc_m_n_
Definition device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp:521
index_t num_group_
Definition device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp:515
std::array< index_t, NDimSpatial+3 > e_g_n_k_wos_lengths_
Definition device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp:548
std::array< index_t, NDimSpatial+3 > b_g_k_c_xs_lengths_
Definition device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp:544
BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_
Definition device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp:520
std::array< index_t, NDimSpatial > input_right_pads_
Definition device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp:553
CDEElementwiseOperation cde_element_op_
Definition device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp:539
CGridDesc_M0_M10_M11_N0_N10_N11 e_grid_desc_m0_m10_m11_n0_n10_n11_
Definition device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp:528
std::array< index_t, NDimSpatial > input_left_pads_
Definition device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp:552
Definition device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp:558
float Run(const BaseArgument *p_arg, const StreamConfig &stream_config=StreamConfig{}) override
Definition device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp:651
DeviceOp::Argument Argument
Definition device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp:559
float Run(const Argument &arg, const StreamConfig &stream_config)
Definition device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp:561
Definition device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp:234
static constexpr auto I1
Definition device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp:240
remove_cvref_t< decltype(MakeEGridDescriptor_M_N< ELayout >(dummy_conv_to_gemm_transformer))> EGridDesc_M_N
Definition device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp:323
remove_cvref_t< decltype(MakeDsGridDescriptor_M_N(dummy_conv_to_gemm_transformer))> DsGridDesc_M_N
Definition device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp:321
decltype(GridwiseGemm::MakeDsGridDescriptor_M0_M10_M11_N0_N10_N11(DsGridDesc_M_N{})) DsGridDesc_M0_M10_M11_N0_N10_N11
Definition device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp:371
static auto MakeBGridDescriptor_BK0_N_BK1(const ConvToGemmFwdTransformer &conv_to_gemm_transformer)
Definition device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp:272
TransformConvFwdToGemm< NDimSpatial, ConvForwardSpecialization > ConvToGemmFwdTransformer
Definition device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp:244
static constexpr index_t NumDTensor
Definition device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp:237
static constexpr auto I3
Definition device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp:242
static constexpr auto I2
Definition device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp:241
decltype(GridwiseGemm::MakeBGridDescriptor_K0_N0_N1_K1(BGridDesc_BK0_N_BK1{})) BGridDesc_K0_N0_N1_K1
Definition device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp:369
GridwiseGemmDlMultipleD_km_kn_mn< BlockSize, ADataType, AccDataType, DsDataType, EDataType, AElementwiseOperation, BElementwiseOperation, CDEElementwiseOperation, InMemoryDataOperationEnum::Set, AGridDesc_AK0_M_AK1, BGridDesc_BK0_N_BK1, EGridDesc_M_N, MPerBlock, NPerBlock, K0PerBlock, K1, M1PerThread, N1PerThread, KPerThread, M1N1ThreadClusterM1Xs, M1N1ThreadClusterN1Xs, ABlockTransferThreadSliceLengths_K0_M0_M1_K1, ABlockTransferThreadClusterLengths_K0_M0_M1_K1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1, ABlockTransferSrcVectorTensorContiguousDimOrder, ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1, BBlockTransferThreadSliceLengths_K0_N0_N1_K1, BBlockTransferThreadClusterLengths_K0_N0_N1_K1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1, BBlockTransferSrcVectorTensorContiguousDimOrder, BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1, CThreadTransferSrcDstAccessOrder, CThreadTransferSrcDstVectorDim, CThreadTransferDstScalarPerVector > GridwiseGemm
Definition device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp:327
DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK DeviceOp
Definition device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp:235
decltype(GridwiseGemm::MakeDefaultBlock2CTileMap(EGridDesc_M_N{})) DefaultBlock2CTileMap
Definition device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp:375
decltype(GridwiseGemm::MakeAGridDescriptor_K0_M0_M1_K1(AGridDesc_AK0_M_AK1{})) AGridDesc_K0_M0_M1_K1
Definition device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp:367
std::string GetTypeString() const override
Definition device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp:1037
static auto MakeAGridDescriptor_AK0_M_AK1(const ConvToGemmFwdTransformer &conv_to_gemm_transformer)
Definition device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp:251
static constexpr auto matrix_padder
Definition device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp:246
std::unique_ptr< BaseInvoker > MakeInvokerPointer() override
Definition device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp:1032
static bool IsSupportedArgument(const Argument &arg)
Definition device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp:658
decltype(GridwiseGemm::MakeCGridDescriptor_M0_M10_M11_N0_N10_N11(EGridDesc_M_N{})) CGridDesc_M0_M10_M11_N0_N10_N11
Definition device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp:373
static constexpr ConvToGemmFwdTransformer dummy_conv_to_gemm_transformer
Definition device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp:316
remove_cvref_t< decltype(MakeAGridDescriptor_AK0_M_AK1< ALayout >( dummy_conv_to_gemm_transformer))> AGridDesc_AK0_M_AK1
Definition device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp:317
std::unique_ptr< BaseArgument > MakeArgumentPointer(const void *p_a, const void *p_b, const std::array< const void *, NumDTensor > &p_ds, void *p_e, const std::array< long_index_t, NDimSpatial+3 > &a_g_n_c_wis_lengths, const std::array< long_index_t, NDimSpatial+3 > &a_g_n_c_wis_strides, const std::array< long_index_t, NDimSpatial+3 > &b_g_k_c_xs_lengths, const std::array< long_index_t, NDimSpatial+3 > &b_g_k_c_xs_strides, const std::array< std::array< long_index_t, NDimSpatial+3 >, NumDTensor > &ds_g_n_k_wos_lengths, const std::array< std::array< long_index_t, NDimSpatial+3 >, NumDTensor > &ds_g_n_k_wos_strides, const std::array< long_index_t, NDimSpatial+3 > &e_g_n_k_wos_lengths, const std::array< long_index_t, NDimSpatial+3 > &e_g_n_k_wos_strides, const std::array< long_index_t, NDimSpatial > &conv_filter_strides, const std::array< long_index_t, NDimSpatial > &conv_filter_dilations, const std::array< long_index_t, NDimSpatial > &input_left_pads, const std::array< long_index_t, NDimSpatial > &input_right_pads, const AElementwiseOperation &a_element_op, const BElementwiseOperation &b_element_op, const CDEElementwiseOperation &cde_element_op) override
Definition device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp:960
std::unique_ptr< BaseArgument > MakeArgumentPointer(const void *p_a, const void *p_b, const std::array< const void *, NumDTensor > &p_ds, void *p_e, const std::array< index_t, NDimSpatial+3 > &a_g_n_c_wis_lengths, const std::array< index_t, NDimSpatial+3 > &a_g_n_c_wis_strides, const std::array< index_t, NDimSpatial+3 > &b_g_k_c_xs_lengths, const std::array< index_t, NDimSpatial+3 > &b_g_k_c_xs_strides, const std::array< std::array< index_t, NDimSpatial+3 >, NumDTensor > &ds_g_n_k_wos_lengths, const std::array< std::array< index_t, NDimSpatial+3 >, NumDTensor > &ds_g_n_k_wos_strides, const std::array< index_t, NDimSpatial+3 > &e_g_n_k_wos_lengths, const std::array< index_t, NDimSpatial+3 > &e_g_n_k_wos_strides, const std::array< index_t, NDimSpatial > &conv_filter_strides, const std::array< index_t, NDimSpatial > &conv_filter_dilations, const std::array< index_t, NDimSpatial > &input_left_pads, const std::array< index_t, NDimSpatial > &input_right_pads, const AElementwiseOperation &a_element_op, const BElementwiseOperation &b_element_op, const CDEElementwiseOperation &cde_element_op) override
Definition device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp:917
static auto MakeInvoker()
Definition device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp:915
static auto MakeArgument(const void *p_a, const void *p_b, const std::array< const void *, NumDTensor > &p_ds, void *p_e, const std::array< index_t, NDimSpatial+3 > &a_g_n_c_wis_lengths, const std::array< index_t, NDimSpatial+3 > &a_g_n_c_wis_strides, const std::array< index_t, NDimSpatial+3 > &b_g_k_c_xs_lengths, const std::array< index_t, NDimSpatial+3 > &b_g_k_c_xs_strides, const std::array< std::array< index_t, NDimSpatial+3 >, NumDTensor > &ds_g_n_k_wos_lengths, const std::array< std::array< index_t, NDimSpatial+3 >, NumDTensor > &ds_g_n_k_wos_strides, const std::array< index_t, NDimSpatial+3 > &e_g_n_k_wos_lengths, const std::array< index_t, NDimSpatial+3 > &e_g_n_k_wos_strides, const std::array< index_t, NDimSpatial > &conv_filter_strides, const std::array< index_t, NDimSpatial > &conv_filter_dilations, const std::array< index_t, NDimSpatial > &input_left_pads, const std::array< index_t, NDimSpatial > &input_right_pads, const AElementwiseOperation &a_element_op, const BElementwiseOperation &b_element_op, const CDEElementwiseOperation &cde_element_op)
Definition device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp:800
static auto MakeEGridDescriptor_M_N(const ConvToGemmFwdTransformer &conv_to_gemm_transformer)
Definition device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp:293
remove_cvref_t< decltype(MakeBGridDescriptor_BK0_N_BK1< BLayout >( dummy_conv_to_gemm_transformer))> BGridDesc_BK0_N_BK1
Definition device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp:319
static auto MakeArgument(const void *p_a, const void *p_b, const std::array< const void *, NumDTensor > &p_ds, void *p_e, const std::array< long_index_t, NDimSpatial+3 > &a_g_n_c_wis_lengths, const std::array< long_index_t, NDimSpatial+3 > &a_g_n_c_wis_strides, const std::array< long_index_t, NDimSpatial+3 > &b_g_k_c_xs_lengths, const std::array< long_index_t, NDimSpatial+3 > &b_g_k_c_xs_strides, const std::array< std::array< long_index_t, NDimSpatial+3 >, NumDTensor > &ds_g_n_k_wos_lengths, const std::array< std::array< long_index_t, NDimSpatial+3 >, NumDTensor > &ds_g_n_k_wos_strides, const std::array< long_index_t, NDimSpatial+3 > &e_g_n_k_wos_lengths, const std::array< long_index_t, NDimSpatial+3 > &e_g_n_k_wos_strides, const std::array< long_index_t, NDimSpatial > &conv_filter_strides, const std::array< long_index_t, NDimSpatial > &conv_filter_dilations, const std::array< long_index_t, NDimSpatial > &input_left_pads, const std::array< long_index_t, NDimSpatial > &input_right_pads, const AElementwiseOperation &a_element_op, const BElementwiseOperation &b_element_op, const CDEElementwiseOperation &cde_element_op)
Definition device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp:843
static auto MakeDsGridDescriptor_M_N(const ConvToGemmFwdTransformer &conv_to_gemm_transformer)
Definition device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp:304
bool IsSupportedArgument(const BaseArgument *p_arg) override
Definition device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp:795
static constexpr auto I0
Definition device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp:239
Grouped Convolution Forward.
Definition device_grouped_conv_fwd_multiple_abd.hpp:73
Definition matrix_padder.hpp:180