device_conv2d_backward_weight_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp Source File#
device_conv2d_backward_weight_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp
Go to the documentation of this file.
float launch_and_time_kernel(const StreamConfig &stream_config, F kernel, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, Args... args)
Definition host_utility/kernel_launch.hpp:14
__host__ __device__ constexpr auto integer_divide_ceil(X x, Y y)
Definition utility/math.hpp:72
Definition convolution_backward_data_specialization.hpp:8
Definition convolution_backward_data_specialization.hpp:7
Definition ck.hpp:268
__host__ __device__ constexpr auto make_pass_through_transform(const LowLength &low_length)
Definition multi_index_transform_helper.hpp:12
__global__ void kernel_gemm_xdlops_bwd_weight(const FloatA *__restrict__ p_a_grid, const FloatB *__restrict__ p_b_grid, FloatC *__restrict__ p_c_grid, const AGridDesc_B_K0_M_K1 a_b_k0_m_k1_grid_desc, const BGridDesc_B_K0_N_K1 b_b_k0_n_k1_grid_desc, const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock c_grid_desc_mblock_mperblock_nblock_nperblock, const AElementwiseOperation a_element_op, const BElementwiseOperation b_element_op, const CElementwiseOperation c_element_op, const CBlockClusterAdaptor c_block_cluster_adaptor)
Definition gridwise_gemm_xdlops_bwd_weight.hpp:157
__host__ __device__ constexpr auto make_right_pad_transform(const LowLength &low_length, const RightPadLength &right_pad, integral_constant< bool, SkipIsValidCheck >=integral_constant< bool, false >{})
Definition multi_index_transform_helper.hpp:37
__host__ __device__ constexpr auto make_pad_transform(const LowLength &low_length, const LeftPad &left_pad, const RightPad &right_pad, integral_constant< bool, SkipIsValidCheck >=integral_constant< bool, false >{})
Definition multi_index_transform_helper.hpp:19
__host__ __device__ constexpr auto make_embed_transform(const UpLengths &up_lengths, const Coefficients &coefficients)
Definition multi_index_transform_helper.hpp:48
__host__ __device__ constexpr auto make_merge_transform(const LowLengths &low_lengths)
Definition multi_index_transform_helper.hpp:55
__host__ __device__ constexpr auto make_naive_tensor_descriptor_packed(const Tuple< Lengths... > &lengths)
Definition tensor_descriptor_helper.hpp:101
typename remove_reference< T >::type remove_reference_t
Definition type.hpp:292
__host__ __device__ constexpr auto make_tuple(Xs &&... xs)
Definition utility/tuple.hpp:211
__host__ __device__ constexpr auto transform_tensor_descriptor(const OldTensorDescriptor &old_tensor_desc, const NewTransforms &new_transforms, NewLowerDimensionOldVisibleIdss, NewUpperDimensionNewVisibleIdss)
Definition tensor_description/tensor_descriptor.hpp:319
__host__ __device__ constexpr auto make_unmerge_transform(const UpLengths &up_lengths, integral_constant< bool, Use24BitIntegerCalculation >=integral_constant< bool, false >{})
Definition multi_index_transform_helper.hpp:90
Definition ck/stream_config.hpp:10
Definition gridwise_gemm_xdlops_bwd_weight.hpp:254
Definition utility/sequence.hpp:43
Definition device_base.hpp:197
BaseArgument()=default
BaseInvoker()=default
Definition device_conv2d_backward_weight_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp:69
static constexpr auto ABlockLdsM1Padding
Definition device_conv2d_backward_weight_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp:109
GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight< BlockSize, ADataType, AccDataType, CDataType, InMemoryDataOperationEnum::AtomicAdd, AGridDesc_K0_M_K1, BGridDesc_K0_N_K1, CGridDesc_M_N, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, MPerBlock, NPerBlock, K0PerBlock, MPerXdl, NPerXdl, K1, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_K0_M_K1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_K1, false, ABlockLdsAddExtraM, ABlockLdsM1PerBlock, ABlockLdsM0PerBlock, ABlockLdsM1Padding, BBlockTransferThreadClusterLengths_K0_N_K1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_K1, false, BBlockLdsAddExtraN, BBlockLdsN1PerBlock, BBlockLdsN0PerBlock, BBlockLdsN1Padding, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CBlockTransferScalarPerVector_NWaveNPerXdl, CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, true, true > GridwiseGemmAtomicAddBase
Definition device_conv2d_backward_weight_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp:341
static constexpr auto I5
Definition device_conv2d_backward_weight_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp:95
static constexpr auto I2
Definition device_conv2d_backward_weight_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp:92
static constexpr bool IsValidCompilationParameter()
Definition device_conv2d_backward_weight_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp:660
static bool IsSupportedArgument(const Argument &arg)
Definition device_conv2d_backward_weight_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp:666
std::unique_ptr< BaseArgument > MakeArgumentPointer(const void *p_in_grid, void *p_wei_grid, const void *p_out_grid, ck::index_t N, ck::index_t K, ck::index_t C, std::array< ck::index_t, NDimSpatial > input_spatial_lengths, std::array< ck::index_t, NDimSpatial > filter_spatial_lengths, std::array< ck::index_t, NDimSpatial > output_spatial_lengths, std::array< ck::index_t, NDimSpatial > conv_filter_strides, std::array< ck::index_t, NDimSpatial > conv_filter_dilations, std::array< ck::index_t, NDimSpatial > input_left_pads, std::array< ck::index_t, NDimSpatial > input_right_pads, InElementwiseOperation in_element_op, WeiElementwiseOperation wei_element_op, OutElementwiseOperation out_element_op, ck::index_t split_k) override
Definition device_conv2d_backward_weight_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp:752
static constexpr auto NXdlPerWave32
Definition device_conv2d_backward_weight_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp:77
remove_cvref_t< decltype(ABCGridDescs{}[I0])> AGridDesc_K0_M_K1
Definition device_conv2d_backward_weight_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp:283
static constexpr auto I3
Definition device_conv2d_backward_weight_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp:93
static constexpr auto N1Number
Definition device_conv2d_backward_weight_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp:100
GridwiseGemmAtomicAddBase< NXdlPerWave32 > GridwiseGemmAtomicAdd32
Definition device_conv2d_backward_weight_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp:390
WeiDataType CDataType
Definition device_conv2d_backward_weight_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp:81
static constexpr auto K1Number
Definition device_conv2d_backward_weight_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp:97
static auto MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(ck::index_t N, ck::index_t K, ck::index_t C, std::array< ck::index_t, NDimSpatial > input_spatial_lengths, std::array< ck::index_t, NDimSpatial > filter_spatial_lengths, std::array< ck::index_t, NDimSpatial > output_spatial_lengths, std::array< ck::index_t, NDimSpatial > conv_filter_strides, std::array< ck::index_t, NDimSpatial > conv_filter_dilations, std::array< ck::index_t, NDimSpatial > input_left_pads, std::array< ck::index_t, NDimSpatial > input_right_pads, ck::index_t batch_k)
Definition device_conv2d_backward_weight_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp:116
GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight< BlockSize, ADataType, AccDataType, CDataType, InMemoryDataOperationEnum::Set, AGridDesc_K0_M_K1, BGridDesc_K0_N_K1, CGridDesc_M_N, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, MPerBlock, NPerBlock, K0PerBlock, MPerXdl, NPerXdl, K1, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_K0_M_K1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_K1, false, ABlockLdsAddExtraM, ABlockLdsM1PerBlock, ABlockLdsM0PerBlock, ABlockLdsM1Padding, BBlockTransferThreadClusterLengths_K0_N_K1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_K1, false, BBlockLdsAddExtraN, BBlockLdsN1PerBlock, BBlockLdsN0PerBlock, BBlockLdsN1Padding, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CBlockTransferScalarPerVector_NWaveNPerXdl, CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, true, true > GridwiseGemmBase
Definition device_conv2d_backward_weight_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp:289
static constexpr auto I0
Definition device_conv2d_backward_weight_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp:90
GridwiseGemmBase< NXdlPerWave32 > GridwiseGemm32
Definition device_conv2d_backward_weight_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp:338
std::string GetTypeString() const override
Definition device_conv2d_backward_weight_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp:796
static constexpr auto GemmK1Number
Definition device_conv2d_backward_weight_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp:98
static constexpr auto BBlockLdsN1PerBlock
Definition device_conv2d_backward_weight_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp:112
static constexpr auto BBlockLdsN0PerBlock
Definition device_conv2d_backward_weight_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp:113
DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K DeviceOp
Definition device_conv2d_backward_weight_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp:72
static auto MakeInvoker()
Definition device_conv2d_backward_weight_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp:749
static GET_NXDL_PER_WAVE_IMPL constexpr auto NXdlPerWave64
Definition device_conv2d_backward_weight_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp:76
decltype(GridwiseGemm::MakeCGridDesc_MBlock_MPerBlock_NBlock_NPerBlock(CGridDesc_M_N{})) CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
Definition device_conv2d_backward_weight_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp:393
bool IsSupportedArgument(const BaseArgument *p_arg) override
Definition device_conv2d_backward_weight_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp:705
decltype(MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N( 1, 1, 1, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, 1)) ABCGridDescs
Definition device_conv2d_backward_weight_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp:280
static constexpr auto I1
Definition device_conv2d_backward_weight_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp:91
remove_cvref_t< decltype(ABCGridDescs{}[I1])> BGridDesc_K0_N_K1
Definition device_conv2d_backward_weight_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp:284
remove_cvref_t< decltype(ABCGridDescs{}[I2])> CGridDesc_M_N
Definition device_conv2d_backward_weight_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp:285
static constexpr auto ABlockLdsM0PerBlock
Definition device_conv2d_backward_weight_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp:108
OutDataType ADataType
Definition device_conv2d_backward_weight_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp:79
std::unique_ptr< BaseInvoker > MakeInvokerPointer() override
Definition device_conv2d_backward_weight_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp:791
static constexpr auto BankLength
Definition device_conv2d_backward_weight_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp:103
InDataType ABDataType
Definition device_conv2d_backward_weight_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp:88
static constexpr auto ABlockLdsM1PerBlock
Definition device_conv2d_backward_weight_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp:107
static auto MakeArgument(const InDataType *p_in_grid, WeiDataType *p_wei_grid, const OutDataType *p_out_grid, ck::index_t N, ck::index_t K, ck::index_t C, std::array< ck::index_t, NDimSpatial > input_spatial_lengths, std::array< ck::index_t, NDimSpatial > filter_spatial_lengths, std::array< ck::index_t, NDimSpatial > output_spatial_lengths, std::array< ck::index_t, NDimSpatial > conv_filter_strides, std::array< ck::index_t, NDimSpatial > conv_filter_dilations, std::array< ck::index_t, NDimSpatial > input_left_pads, std::array< ck::index_t, NDimSpatial > input_right_pads, InElementwiseOperation in_element_op, WeiElementwiseOperation wei_element_op, OutElementwiseOperation out_element_op, ck::index_t split_k)
Definition device_conv2d_backward_weight_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp:710
static constexpr auto I4
Definition device_conv2d_backward_weight_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp:94
OutElementwiseOperation AElementwiseOperation
Definition device_conv2d_backward_weight_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp:83
InDataType BDataType
Definition device_conv2d_backward_weight_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp:80
static constexpr auto BBlockLdsN1Padding
Definition device_conv2d_backward_weight_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp:114
decltype(GridwiseGemm::MakeCBlockClusterAdaptor(CGridDesc_M_N{}, 1, 1, 1)) Block2CTileMap
Definition device_conv2d_backward_weight_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp:396
GridwiseGemmBase< math::max(NXdlPerWave64, 1)> GridwiseGemm64
Definition device_conv2d_backward_weight_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp:337
static constexpr ck::index_t NDimSpatial
Definition device_conv2d_backward_weight_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp:70
WeiElementwiseOperation CElementwiseOperation
Definition device_conv2d_backward_weight_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp:85
InElementwiseOperation BElementwiseOperation
Definition device_conv2d_backward_weight_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp:84
static constexpr auto ElePerBank
Definition device_conv2d_backward_weight_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp:104
GridwiseGemmAtomicAddBase< math::max(NXdlPerWave64, 1)> GridwiseGemmAtomicAdd64
Definition device_conv2d_backward_weight_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp:389
Definition device_conv2d_backward_weight_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp:499
float RunImp(const typename GridwiseGemm::Argument &arg, const StreamConfig &stream_config=StreamConfig{})
Definition device_conv2d_backward_weight_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp:521
INVOKER_RUN_IMPL float Run(const BaseArgument *p_arg, const StreamConfig &stream_config=StreamConfig{}) override
Definition device_conv2d_backward_weight_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp:653
void Print(const Argument &arg)
Definition device_conv2d_backward_weight_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp:502
DeviceOp::Argument Argument
Definition device_conv2d_backward_weight_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp:500
Definition device_conv2d_backward_weight_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp:399
index_t N01_
Definition device_conv2d_backward_weight_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp:481
index_t Conv_N_
Definition device_conv2d_backward_weight_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp:486
const ADataType * p_a_grid_
Definition device_conv2d_backward_weight_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp:472
std::array< index_t, NDimSpatial > input_right_pads_
Definition device_conv2d_backward_weight_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp:493
index_t Conv_K_
Definition device_conv2d_backward_weight_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp:487
WeiElementwiseOperation c_element_op_
Definition device_conv2d_backward_weight_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp:484
std::array< index_t, NDimSpatial > input_left_pads_
Definition device_conv2d_backward_weight_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp:492
CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock c_grid_desc_mblock_mperblock_nblock_nperblock_
Definition device_conv2d_backward_weight_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp:478
AGridDesc_K0_M_K1 a_grid_desc_kbatch_k0_m_k1_
Definition device_conv2d_backward_weight_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp:475
index_t k_batch_
Definition device_conv2d_backward_weight_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp:494
BGridDesc_K0_N_K1 b_grid_desc_kbatch_k0_n_k1_
Definition device_conv2d_backward_weight_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp:476
std::array< index_t, NDimSpatial > filter_spatial_lengths_
Definition device_conv2d_backward_weight_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp:490
Block2CTileMap block_2_ctile_map_
Definition device_conv2d_backward_weight_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp:479
index_t M01_
Definition device_conv2d_backward_weight_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp:480
index_t Conv_C_
Definition device_conv2d_backward_weight_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp:488
CGridDesc_M_N c_grid_desc_m_n_
Definition device_conv2d_backward_weight_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp:477
Argument(const InDataType *p_in_grid, WeiDataType *p_wei_grid, const OutDataType *p_out_grid, ck::index_t N, ck::index_t K, ck::index_t C, std::array< ck::index_t, NDimSpatial > input_spatial_lengths, std::array< ck::index_t, NDimSpatial > filter_spatial_lengths, std::array< ck::index_t, NDimSpatial > output_spatial_lengths, std::array< ck::index_t, NDimSpatial > conv_filter_strides, std::array< ck::index_t, NDimSpatial > conv_filter_dilations, std::array< ck::index_t, NDimSpatial > input_left_pads, std::array< ck::index_t, NDimSpatial > input_right_pads, ck::index_t M01, ck::index_t N01, InElementwiseOperation in_element_op, WeiElementwiseOperation wei_element_op, OutElementwiseOperation out_element_op, ck::index_t split_k)
Definition device_conv2d_backward_weight_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp:400
const BDataType * p_b_grid_
Definition device_conv2d_backward_weight_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp:473
std::array< index_t, NDimSpatial > conv_filter_strides_
Definition device_conv2d_backward_weight_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp:491
OutElementwiseOperation b_element_op_
Definition device_conv2d_backward_weight_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp:483
InElementwiseOperation a_element_op_
Definition device_conv2d_backward_weight_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp:482
CDataType * p_c_grid_
Definition device_conv2d_backward_weight_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp:474
std::array< index_t, NDimSpatial > output_spatial_lengths_
Definition device_conv2d_backward_weight_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp:489