device_grouped_conv_bwd_weight_xdl_cshuffle.hpp Source File
device_grouped_conv_bwd_weight_xdl_cshuffle.hpp
Go to the documentation of this file.
28 #include "ck_tile/builder/reflect/instance_traits_device_grouped_conv_bwd_weight_xdl_cshuffle.hpp"
1222 << "TransposeTransferDstScalarPerVectorAligned: " << TransposeTransferDstScalarPerVectorAligned;
float launch_and_time_kernel(const StreamConfig &stream_config, F kernel, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, Args... args)
Definition host_utility/kernel_launch.hpp:14
float launch_and_time_kernel_with_preprocess(const StreamConfig &stream_config, PreProcessFunc preprocess, F kernel, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, Args... args)
Definition host_utility/kernel_launch.hpp:91
__host__ __device__ constexpr auto integer_divide_ceil(X x, Y y)
Definition utility/math.hpp:72
Definition convolution_backward_data_specialization.hpp:8
auto get_bwd_weight_gemm_sizes(const std::array< index_t, NDimSpatial+3 > &a_g_n_k_wos_lengths, const std::array< index_t, NDimSpatial+3 > &e_g_k_c_xs_lengths)
Definition split_k_utils.hpp:55
ConvolutionBackwardWeightSpecialization
Definition convolution_backward_weight_specialization.hpp:13
@ Filter1x1Stride1Pad0
Definition convolution_backward_weight_specialization.hpp:15
constexpr bool is_GNWC_GKXC_GNWK()
Definition device_grouped_conv_utils.hpp:23
__global__ void kernel_batched_gemm_xdlops_bwd_weight(const FloatA *__restrict__ p_a_grid, const FloatB *__restrict__ p_b_grid, FloatC *__restrict__ p_c_grid, const AElementwiseOperation a_element_op, const BElementwiseOperation b_element_op, const CElementwiseOperation c_element_op, const index_t batch_count, const AGridDesc_B_K0_M_K1 a_b_k0_m_k1_grid_desc, const BGridDesc_B_K0_N_K1 b_b_k0_n_k1_grid_desc, const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock c_grid_desc_mblock_mperblock_nblock_nperblock, const Block2CTileMap block_2_ctile_map, const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch)
Definition device_grouped_conv_bwd_weight_multiple_d_xdl_cshuffle.hpp:50
constexpr bool is_GNDHWC_GKZYXC_GNDHWK()
Definition device_grouped_conv_utils.hpp:88
constexpr bool is_NHWGC_GKYXC_NHWGK()
Definition device_grouped_conv_utils.hpp:40
ck::index_t get_best_occupancy_k_batch_value(int max_occupancy, ck::index_t grid_size)
Definition split_k_utils.hpp:30
constexpr bool is_NDHWGC_GKZYXC_NDHWGK()
Definition device_grouped_conv_utils.hpp:80
constexpr bool is_NGCDHW_NGKDHW()
Definition device_grouped_conv_utils.hpp:112
constexpr bool is_NGCHW_GKCYX_NGKHW()
Definition device_grouped_conv_utils.hpp:64
std::string getConvBackwardWeightSpecializationString(const ConvolutionBackwardWeightSpecialization &s)
Definition convolution_backward_weight_specialization.hpp:21
ck::index_t calculate_mn_grid_size(ck::index_t gemmM, ck::index_t gemmN)
Definition split_k_utils.hpp:84
constexpr bool is_GNHWC_GKYXC_GNHWK()
Definition device_grouped_conv_utils.hpp:48
constexpr bool is_NGCDHW_GKCZYX_NGKDHW()
Definition device_grouped_conv_utils.hpp:104
constexpr bool is_NGCHW_NGKHW()
Definition device_grouped_conv_utils.hpp:72
Definition convolution_backward_data_specialization.hpp:7
CK_TILE_HOST float launch_kernel(const stream_config &s, Callables &&... callables)
Definition tile/host/kernel_launch.hpp:173
Definition ck.hpp:268
__host__ __device__ constexpr Y type_convert(X x)
Definition utility/type_convert.hpp:98
typename remove_reference< T >::type remove_reference_t
Definition type.hpp:292
auto accumulate_n(ForwardIterator first, Size count, T init, BinaryOperation op) -> decltype(std::accumulate(first, std::next(first, count), init, op))
Definition library/utility/numeric.hpp:11
__host__ __device__ constexpr auto make_tuple(Xs &&... xs)
Definition utility/tuple.hpp:211
__global__ void kernel_elementwise_dual(const InAGridDescTuple in_grid_desc_tuple_a, const InBGridDescTuple in_grid_desc_tuple_b, const OutAGridDescTuple out_grid_desc_tuple_a, const OutBGridDescTuple out_grid_desc_tuple_b, const InADataTypePointerTuple p_in_global_tuple_a, const InBDataTypePointerTuple p_in_global_tuple_b, const OutADataTypePointerTuple p_out_global_tuple_a, const OutBDataTypePointerTuple p_out_global_tuple_b, const Block2TileMapA block_2_tile_map_a, const Block2TileMapB block_2_tile_map_b, const ElementwiseOperation elementwise_op, const index_t a_grid_size)
Definition gridwise_elementwise_2d.hpp:61
__global__ void kernel_elementwise(const InGridDescTuple in_grid_desc_tuple, const OutGridDescTuple out_grid_desc_tuple, const InDataTypePointerTuple p_in_global_tuple, const OutDataTypePointerTuple p_out_global_tuple, const Block2TileMap block_2_tile_map, const ElementwiseOperation elementwise_op)
Definition gridwise_elementwise_2d.hpp:29
bool is_bf16_atomic_supported()
Definition host_utility/device_prop.hpp:108
Definition ck/stream_config.hpp:10
Definition block_to_ctile_map.hpp:261
Definition gridwise_elementwise_2d.hpp:278
Definition gridwise_gemm_xdlops_bwd_weight.hpp:254
__host__ static __device__ constexpr bool CheckValidity(const AGridDesc_K0_M_K1 &a_b_k0_m_k1_grid_desc, const BGridDesc_K0_N_K1 &b_b_k0_n_k1_grid_desc, const CGridDesc_M_N &c_m_n_grid_desc, const Block2CTileMap &block_2_ctile_map)
Definition gridwise_gemm_xdlops_bwd_weight.hpp:544
__host__ static __device__ constexpr auto MakeCBlockClusterAdaptor(const CGridDesc_M_N &c_m_n_grid_desc, index_t M01, index_t N01, index_t KBatch)
Definition gridwise_gemm_xdlops_bwd_weight.hpp:625
__host__ static __device__ constexpr auto MakeCGridDesc_MBlock_MPerBlock_NBlock_NPerBlock(const CGridDesc_M_N &c_m_n_grid_desc)
Definition gridwise_gemm_xdlops_bwd_weight.hpp:608
Definition utility/sequence.hpp:43
Definition utility/tuple.hpp:117
Definition tensor_operation/operator_transform/transform_conv_bwd_weight_to_gemm.hpp:24
Definition transform_conv_ngchw_to_nhwgc.hpp:31
Definition split_k_arg.hpp:11
Definition device_base.hpp:197
void * p_workspace_
Definition device_base.hpp:204
BaseArgument()=default
BaseInvoker()=default
virtual std::string GetInstanceString() const
Definition device_base.hpp:230
Definition device_grouped_conv_bwd_weight_xdl_cshuffle.hpp:440
int max_occupancy_
Definition device_grouped_conv_bwd_weight_xdl_cshuffle.hpp:484
static int GetMaxOccupancy()
Definition device_grouped_conv_bwd_weight_xdl_cshuffle.hpp:442
ActiveWorkgroupsPerCU()
Definition device_grouped_conv_bwd_weight_xdl_cshuffle.hpp:466
Definition device_grouped_conv_bwd_weight_xdl_cshuffle.hpp:488
WeiElementwiseOperation c_element_op_
Definition device_grouped_conv_bwd_weight_xdl_cshuffle.hpp:720
ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_batch_
Definition device_grouped_conv_bwd_weight_xdl_cshuffle.hpp:713
const std::array< ck::index_t, NDimSpatial > & input_left_pads_
Definition device_grouped_conv_bwd_weight_xdl_cshuffle.hpp:731
Argument(const InDataType *p_in_grid, WeiDataType *p_wei_grid, const OutDataType *p_out_grid, const std::array< index_t, NDimSpatial+3 > &b_g_n_c_wis_lengths, const std::array< index_t, NDimSpatial+3 > &b_g_n_c_wis_strides, const std::array< index_t, NDimSpatial+3 > &e_g_k_c_xs_lengths, const std::array< index_t, NDimSpatial+3 > &e_g_k_c_xs_strides, const std::array< index_t, NDimSpatial+3 > &a_g_n_k_wos_lengths, const std::array< index_t, NDimSpatial+3 > &a_g_n_k_wos_strides, const std::array< ck::index_t, NDimSpatial > &conv_filter_strides, const std::array< ck::index_t, NDimSpatial > &conv_filter_dilations, const std::array< ck::index_t, NDimSpatial > &input_left_pads, const std::array< ck::index_t, NDimSpatial > &input_right_pads, const ck::index_t M01, const ck::index_t N01, InElementwiseOperation in_element_op, WeiElementwiseOperation wei_element_op, OutElementwiseOperation out_element_op, ck::index_t split_k)
Definition device_grouped_conv_bwd_weight_xdl_cshuffle.hpp:489
BGridDesc_K0_N_K1 b_grid_desc_kbatch_k0_n_k1_
Definition device_grouped_conv_bwd_weight_xdl_cshuffle.hpp:698
NHWGCTransposeDescType b_out_transpose_desc_
Definition device_grouped_conv_bwd_weight_xdl_cshuffle.hpp:707
OutElementwiseOperation a_element_op_
Definition device_grouped_conv_bwd_weight_xdl_cshuffle.hpp:718
CGridDesc_M_N c_grid_desc_m_n_
Definition device_grouped_conv_bwd_weight_xdl_cshuffle.hpp:699
std::size_t GetWorkspaceATensorSizeBytes() const
Definition device_grouped_conv_bwd_weight_xdl_cshuffle.hpp:643
std::array< ck::index_t, NDimSpatial > output_spatial_lengths_
Definition device_grouped_conv_bwd_weight_xdl_cshuffle.hpp:729
std::array< ck::index_t, NDimSpatial > input_spatial_lengths_
Definition device_grouped_conv_bwd_weight_xdl_cshuffle.hpp:727
InElementwiseOperation b_element_op_
Definition device_grouped_conv_bwd_weight_xdl_cshuffle.hpp:719
const index_t Conv_G_
Definition device_grouped_conv_bwd_weight_xdl_cshuffle.hpp:723
Block2CTileMap block_2_ctile_map_
Definition device_grouped_conv_bwd_weight_xdl_cshuffle.hpp:701
NGCHWTransposeDescType b_in_transpose_desc_
Definition device_grouped_conv_bwd_weight_xdl_cshuffle.hpp:706
long_index_t c_space_size_bytes
Definition device_grouped_conv_bwd_weight_xdl_cshuffle.hpp:733
std::size_t GetWorkspaceSizeBytes() const
Definition device_grouped_conv_bwd_weight_xdl_cshuffle.hpp:688
const ADataType * p_a_grid_
Definition device_grouped_conv_bwd_weight_xdl_cshuffle.hpp:694
const std::array< ck::index_t, NDimSpatial > & conv_filter_strides_
Definition device_grouped_conv_bwd_weight_xdl_cshuffle.hpp:730
index_t N01_
Definition device_grouped_conv_bwd_weight_xdl_cshuffle.hpp:716
const index_t Conv_K_
Definition device_grouped_conv_bwd_weight_xdl_cshuffle.hpp:725
NGCHWTransposeDescType a_in_transpose_desc_
Definition device_grouped_conv_bwd_weight_xdl_cshuffle.hpp:706
CDataType * p_c_grid_
Definition device_grouped_conv_bwd_weight_xdl_cshuffle.hpp:696
index_t M01_
Definition device_grouped_conv_bwd_weight_xdl_cshuffle.hpp:715
NHWGCTransposeDescType a_out_transpose_desc_
Definition device_grouped_conv_bwd_weight_xdl_cshuffle.hpp:707
Block2TileMapTranspose elementwise_block_2_ctile_map_transpose_a_
Definition device_grouped_conv_bwd_weight_xdl_cshuffle.hpp:703
GKCYXTransposeDescType e_out_transpose_desc_
Definition device_grouped_conv_bwd_weight_xdl_cshuffle.hpp:710
std::array< ck::index_t, NDimSpatial > filter_spatial_lengths_
Definition device_grouped_conv_bwd_weight_xdl_cshuffle.hpp:728
Block2TileMapTranspose elementwise_block_2_ctile_map_transpose_b_
Definition device_grouped_conv_bwd_weight_xdl_cshuffle.hpp:704
const std::array< ck::index_t, NDimSpatial > & input_right_pads_
Definition device_grouped_conv_bwd_weight_xdl_cshuffle.hpp:732
std::size_t GetWorkspaceBTensorSizeBytes() const
Definition device_grouped_conv_bwd_weight_xdl_cshuffle.hpp:659
Block2TileMapTranspose elementwise_block_2_ctile_map_transpose_e_
Definition device_grouped_conv_bwd_weight_xdl_cshuffle.hpp:704
GKYXCTransposeDescType e_in_transpose_desc_
Definition device_grouped_conv_bwd_weight_xdl_cshuffle.hpp:709
AGridDesc_K0_M_K1 a_grid_desc_kbatch_k0_m_k1_
Definition device_grouped_conv_bwd_weight_xdl_cshuffle.hpp:697
const BDataType * p_b_grid_
Definition device_grouped_conv_bwd_weight_xdl_cshuffle.hpp:695
const index_t Conv_C_
Definition device_grouped_conv_bwd_weight_xdl_cshuffle.hpp:726
const index_t Conv_N_
Definition device_grouped_conv_bwd_weight_xdl_cshuffle.hpp:724
std::size_t GetWorkspaceETensorSizeBytes() const
Definition device_grouped_conv_bwd_weight_xdl_cshuffle.hpp:675
Definition device_grouped_conv_bwd_weight_xdl_cshuffle.hpp:738
DeviceOp::Argument Argument
Definition device_grouped_conv_bwd_weight_xdl_cshuffle.hpp:739
void ShowInfo(const Argument &arg)
Definition device_grouped_conv_bwd_weight_xdl_cshuffle.hpp:741
INVOKER_RUN_IMPL float Run(const BaseArgument *p_arg, const StreamConfig &stream_config=StreamConfig{}) override
Definition device_grouped_conv_bwd_weight_xdl_cshuffle.hpp:928
float RunImp(const Argument &arg, const StreamConfig &stream_config=StreamConfig{})
Definition device_grouped_conv_bwd_weight_xdl_cshuffle.hpp:760
Definition device_grouped_conv_bwd_weight_xdl_cshuffle.hpp:170
static constexpr auto BBlockLdsN0PerBlock
Definition device_grouped_conv_bwd_weight_xdl_cshuffle.hpp:220
void SetWorkSpacePointer(BaseArgument *p_arg, void *p_workspace, const StreamConfig &=StreamConfig{}) const override
Definition device_grouped_conv_bwd_weight_xdl_cshuffle.hpp:1258
GridwiseElementwise< Tuple< GKYXCTransposeDescType >, Tuple< GKCYXTransposeDescType >, Tuple< const CDataType * >, Tuple< CDataType * >, Block2TileMapTranspose, element_wise::PassThrough, BlockSize, MPerBlock, NPerBlock, MPerBlock/ClusterLengthMPerBlock, NPerBlock/ClusterLengthNPerBlock, Sequence< 1, 0 >, Sequence< CBlockTransferScalarPerVector_NWaveNPerXdl >, Sequence< 1 >, I1, I0 > GridwiseElementwiseWeightTranspose
Definition device_grouped_conv_bwd_weight_xdl_cshuffle.hpp:357
OutElementwiseOperation AElementwiseOperation
Definition device_grouped_conv_bwd_weight_xdl_cshuffle.hpp:185
decltype(GetABCGridDesc< NDimSpatial >()) ABCGridDescs
Definition device_grouped_conv_bwd_weight_xdl_cshuffle.hpp:298
InElementwiseOperation BElementwiseOperation
Definition device_grouped_conv_bwd_weight_xdl_cshuffle.hpp:186
static constexpr auto I2
Definition device_grouped_conv_bwd_weight_xdl_cshuffle.hpp:194
static constexpr auto ABlockLdsM1PerBlock
Definition device_grouped_conv_bwd_weight_xdl_cshuffle.hpp:214
static constexpr index_t TransposeTransferSrcScalarPerVectorAligned
Definition device_grouped_conv_bwd_weight_xdl_cshuffle.hpp:319
remove_cvref_t< decltype(conv_ngchw_to_nhwgc_transformer .template MakeNHWGCTransposeDesc< NDimSpatial >({}, {}))> NHWGCTransposeDescType
Definition device_grouped_conv_bwd_weight_xdl_cshuffle.hpp:327
static constexpr auto NXdlPerWave32
Definition device_grouped_conv_bwd_weight_xdl_cshuffle.hpp:174
bool IsSupportedArgument(const BaseArgument *p_arg) override
Definition device_grouped_conv_bwd_weight_xdl_cshuffle.hpp:1103
static constexpr index_t TransposeTransferDstScalarPerVectorAligned
Definition device_grouped_conv_bwd_weight_xdl_cshuffle.hpp:321
BlockToCTileMap_M00_N0_M01Adapt< MPerBlock, NPerBlock > Block2TileMapTranspose
Definition device_grouped_conv_bwd_weight_xdl_cshuffle.hpp:317
static constexpr auto I3
Definition device_grouped_conv_bwd_weight_xdl_cshuffle.hpp:195
InDataType BDataType
Definition device_grouped_conv_bwd_weight_xdl_cshuffle.hpp:177
static constexpr auto ElePerBank
Definition device_grouped_conv_bwd_weight_xdl_cshuffle.hpp:211
static constexpr index_t ClusterLengthMPerBlock
Definition device_grouped_conv_bwd_weight_xdl_cshuffle.hpp:304
static constexpr auto I5
Definition device_grouped_conv_bwd_weight_xdl_cshuffle.hpp:197
GridwiseGemmBase< NXdlPerWave32 > GridwiseGemm32
Definition device_grouped_conv_bwd_weight_xdl_cshuffle.hpp:430
remove_cvref_t< decltype(conv_ngchw_to_nhwgc_transformer .template MakeGKYXCTransposeDesc< NDimSpatial >({}, {}))> GKYXCTransposeDescType
Definition device_grouped_conv_bwd_weight_xdl_cshuffle.hpp:333
static constexpr auto I0
Definition device_grouped_conv_bwd_weight_xdl_cshuffle.hpp:192
GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight< BlockSize, ADataType, BDataType, AccDataType, CDataType, InMemoryDataOperationEnum::AtomicAdd, AGridDesc_K0_M_K1, BGridDesc_K0_N_K1, CGridDesc_M_N, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, MPerBlock, NPerBlock, K0PerBlock, MPerXDL, NPerXDL, K1, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_K0_M_K1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_K1, false, ABlockLdsAddExtraM, ABlockLdsM1PerBlock, ABlockLdsM0PerBlock, ABlockLdsM1Padding, BBlockTransferThreadClusterLengths_K0_N_K1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_K1, false, BBlockLdsAddExtraN, BBlockLdsN1PerBlock, BBlockLdsN0PerBlock, BBlockLdsN1Padding, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CBlockTransferScalarPerVector_NWaveNPerXdl, CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, true, true, 1, PipelineVersion::v1, ComputeTypeA, ComputeTypeB > GridwiseGemmBase
Definition device_grouped_conv_bwd_weight_xdl_cshuffle.hpp:376
decltype(GridwiseGemm64::MakeCGridDesc_MBlock_MPerBlock_NBlock_NPerBlock(CGridDesc_M_N{})) CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
Definition device_grouped_conv_bwd_weight_xdl_cshuffle.hpp:433
decltype(GridwiseGemm64::MakeCBlockClusterAdaptor(CGridDesc_M_N{}, 1, 1, 1)) Block2CTileMap
Definition device_grouped_conv_bwd_weight_xdl_cshuffle.hpp:436
static constexpr bool IsValidCompilationParameter()
Definition device_grouped_conv_bwd_weight_xdl_cshuffle.hpp:935
remove_cvref_t< decltype(ABCGridDescs{}[I2])> CGridDesc_M_N
Definition device_grouped_conv_bwd_weight_xdl_cshuffle.hpp:302
remove_cvref_t< decltype(conv_ngchw_to_nhwgc_transformer .template MakeNGCHWTransposeDesc< NDimSpatial >({}, {}))> NGCHWTransposeDescType
Definition device_grouped_conv_bwd_weight_xdl_cshuffle.hpp:324
GridwiseElementwise< Tuple< NGCHWTransposeDescType >, Tuple< NHWGCTransposeDescType >, Tuple< const ADataType * >, Tuple< ADataType * >, Block2TileMapTranspose, element_wise::PassThrough, BlockSize, MPerBlock, NPerBlock, MPerBlock/ClusterLengthMPerBlock, NPerBlock/ClusterLengthNPerBlock, Sequence< 1, 0 >, Sequence< TransposeTransferSrcScalarPerVectorAligned >, Sequence< TransposeTransferDstScalarPerVectorAligned >, I1, I0 > GridwiseInOutTranspose
Definition device_grouped_conv_bwd_weight_xdl_cshuffle.hpp:337
GridwiseGemmBase< math::max(NXdlPerWave64, 1)> GridwiseGemm64
Definition device_grouped_conv_bwd_weight_xdl_cshuffle.hpp:429
static constexpr auto conv_ngchw_to_nhwgc_transformer
Definition device_grouped_conv_bwd_weight_xdl_cshuffle.hpp:309
static constexpr auto I1
Definition device_grouped_conv_bwd_weight_xdl_cshuffle.hpp:193
static constexpr auto I4
Definition device_grouped_conv_bwd_weight_xdl_cshuffle.hpp:196
remove_cvref_t< decltype(conv_ngchw_to_nhwgc_transformer .template MakeGKCYXTransposeDesc< NDimSpatial >({}, {}))> GKCYXTransposeDescType
Definition device_grouped_conv_bwd_weight_xdl_cshuffle.hpp:330
static constexpr auto BBlockLdsN1Padding
Definition device_grouped_conv_bwd_weight_xdl_cshuffle.hpp:221
size_t GetWorkSpaceSize(const BaseArgument *p_arg) const override
Definition device_grouped_conv_bwd_weight_xdl_cshuffle.hpp:1245
static auto MakeInvoker()
Definition device_grouped_conv_bwd_weight_xdl_cshuffle.hpp:1148
static constexpr auto BBlockLdsN1PerBlock
Definition device_grouped_conv_bwd_weight_xdl_cshuffle.hpp:219
static bool IsSupportedArgument(const Argument &arg)
Definition device_grouped_conv_bwd_weight_xdl_cshuffle.hpp:941
remove_cvref_t< decltype(ABCGridDescs{}[I1])> BGridDesc_K0_N_K1
Definition device_grouped_conv_bwd_weight_xdl_cshuffle.hpp:301
std::unique_ptr< BaseInvoker > MakeInvokerPointer() override
Definition device_grouped_conv_bwd_weight_xdl_cshuffle.hpp:1190
static constexpr auto ABlockLdsM1Padding
Definition device_grouped_conv_bwd_weight_xdl_cshuffle.hpp:216
WeiDataType CDataType
Definition device_grouped_conv_bwd_weight_xdl_cshuffle.hpp:178
std::string GetTypeString() const override
Definition device_grouped_conv_bwd_weight_xdl_cshuffle.hpp:1195
static constexpr auto BankLength
Definition device_grouped_conv_bwd_weight_xdl_cshuffle.hpp:210
std::unique_ptr< BaseArgument > MakeArgumentPointer(const void *p_in_grid, void *p_wei_grid, const void *p_out_grid, const std::array< index_t, NDimSpatial+3 > &b_g_n_c_wis_lengths, const std::array< index_t, NDimSpatial+3 > &b_g_n_c_wis_strides, const std::array< index_t, NDimSpatial+3 > &e_g_k_c_xs_lengths, const std::array< index_t, NDimSpatial+3 > &e_g_k_c_xs_strides, const std::array< index_t, NDimSpatial+3 > &a_g_n_k_wos_lengths, const std::array< index_t, NDimSpatial+3 > &a_g_n_k_wos_strides, const std::array< ck::index_t, NDimSpatial > &conv_filter_strides, const std::array< ck::index_t, NDimSpatial > &conv_filter_dilations, const std::array< ck::index_t, NDimSpatial > &input_left_pads, const std::array< ck::index_t, NDimSpatial > &input_right_pads, InElementwiseOperation in_element_op, WeiElementwiseOperation wei_element_op, OutElementwiseOperation out_element_op, const ck::index_t split_k) override
Definition device_grouped_conv_bwd_weight_xdl_cshuffle.hpp:1151
static constexpr auto K1Number
Definition device_grouped_conv_bwd_weight_xdl_cshuffle.hpp:199
InDataType ABDataType
Definition device_grouped_conv_bwd_weight_xdl_cshuffle.hpp:190
static constexpr index_t ClusterLengthNPerBlock
Definition device_grouped_conv_bwd_weight_xdl_cshuffle.hpp:306
remove_cvref_t< decltype(ABCGridDescs{}[I0])> AGridDesc_K0_M_K1
Definition device_grouped_conv_bwd_weight_xdl_cshuffle.hpp:300
static constexpr auto ABlockLdsM0PerBlock
Definition device_grouped_conv_bwd_weight_xdl_cshuffle.hpp:215
WeiElementwiseOperation CElementwiseOperation
Definition device_grouped_conv_bwd_weight_xdl_cshuffle.hpp:187
static GET_NXDL_PER_WAVE_IMPL constexpr auto NXdlPerWave64
Definition device_grouped_conv_bwd_weight_xdl_cshuffle.hpp:173
DeviceGroupedConvBwdWeight_Xdl_CShuffle DeviceOp
Definition device_grouped_conv_bwd_weight_xdl_cshuffle.hpp:171
static constexpr auto conv_to_gemm_transformer
Definition device_grouped_conv_bwd_weight_xdl_cshuffle.hpp:201
static auto MakeArgument(const InDataType *p_in_grid, WeiDataType *p_wei_grid, const OutDataType *p_out_grid, const std::array< index_t, NDimSpatial+3 > &b_g_n_c_wis_lengths, const std::array< index_t, NDimSpatial+3 > &b_g_n_c_wis_strides, const std::array< index_t, NDimSpatial+3 > &e_g_k_c_xs_lengths, const std::array< index_t, NDimSpatial+3 > &e_g_k_c_xs_strides, const std::array< index_t, NDimSpatial+3 > &a_g_n_k_wos_lengths, const std::array< index_t, NDimSpatial+3 > &a_g_n_k_wos_strides, const std::array< ck::index_t, NDimSpatial > &conv_filter_strides, const std::array< ck::index_t, NDimSpatial > &conv_filter_dilations, const std::array< ck::index_t, NDimSpatial > &input_left_pads, const std::array< ck::index_t, NDimSpatial > &input_right_pads, InElementwiseOperation in_element_op, WeiElementwiseOperation wei_element_op, OutElementwiseOperation out_element_op, const ck::index_t split_k)
Definition device_grouped_conv_bwd_weight_xdl_cshuffle.hpp:1109
OutDataType ADataType
Definition device_grouped_conv_bwd_weight_xdl_cshuffle.hpp:176
static auto GetABCGridDesc()
Definition device_grouped_conv_bwd_weight_xdl_cshuffle.hpp:224
Definition device_grouped_conv_bwd_weight.hpp:29
Definition tensor_operation/gpu/element/unary_element_wise_operation.hpp:340