layernorm2d_fwd_kernel.hpp Source File#
layernorm2d_fwd_kernel.hpp
Go to the documentation of this file.
164 if (kXbias != Layernorm2dXBiasEnum::NO_BIAS) n += _SS_("_") + Layernorm2dXBiasEnumName<kXbias>::name;
165 if (kFusedAdd != Layernorm2dFusedAddEnum::NO_ADD) n += _SS_("_") + Layernorm2dFusedAddEnumName<kFusedAdd>::name;
166 if (kFusedQuant != Layernorm2dFusedQuantEnum::NO_SWEEP) n += _SS_("_") + Layernorm2dFusedQuantEnumName<kFusedQuant>::name;
188 _TS_(S_::Block_M) + "x" + _TS_(S_::Block_N) + "_" + _TS_(S_::WarpPerBlock_M) + "x" + _TS_(S_::WarpPerBlock_N) + "_" +
#define _TS_
#define _SS_
Definition tile/core/algorithm/cluster_descriptor.hpp:13
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition type_traits.hpp:21
CK_TILE_HOST_DEVICE constexpr auto make_naive_tensor_view(DataType *__restrict__ p, const tuple< Lengths... > &lengths, const tuple< Strides... > &strides, number< GuaranteedLastDimensionVectorLength >=number<-1 >{}, number< GuaranteedLastDimensionVectorStride >=number<-1 >{})
Definition tensor_view.hpp:471
@ SMOOTH_DYNAMIC_QUANT
Definition layernorm2d_fwd_traits.hpp:42
@ DYNAMIC_QUANT
Definition layernorm2d_fwd_traits.hpp:43
CK_TILE_DEVICE constexpr auto make_tile_window(null_tensor_view, const WindowLengths &window_lengths, const multi_index< WindowLengths::size()> &, Ts &&...)
Definition null_tile_window.hpp:75
CK_TILE_HOST_DEVICE constexpr auto integer_divide_ceil(X x, Y y)
Definition tile/core/numeric/math.hpp:149
CK_TILE_HOST_DEVICE constexpr auto pad_tensor_view(const TensorView &tensor_view, const TileLengths &tile_lengths, DoPads)
Definition tensor_view.hpp:530
@ PRE_ADD_STORE
Definition layernorm2d_fwd_traits.hpp:27
CK_TILE_DEVICE constexpr auto make_null_tile_window(const WindowLengths &window_lengths)
Definition null_tile_window.hpp:66
CK_TILE_HOST_DEVICE constexpr auto make_tuple(Xs &&... xs)
Definition tile/core/container/tuple.hpp:360
CK_TILE_HOST_DEVICE constexpr auto make_naive_tensor_view_packed(DataType *__restrict__ p, const tuple< Lengths... > &lengths, number< GuaranteedLastDimensionVectorLength >=number<-1 >{})
Definition tensor_view.hpp:494
Definition layernorm2d_fwd_traits.hpp:33
Definition layernorm2d_fwd_traits.hpp:47
Definition layernorm2d_fwd_kernel.hpp:84
const void * p_x_bias
Definition layernorm2d_fwd_kernel.hpp:88
void * p_y_residual
Definition layernorm2d_fwd_kernel.hpp:93
const void * p_sm_scale
Definition layernorm2d_fwd_kernel.hpp:87
const void * p_x_residual
Definition layernorm2d_fwd_kernel.hpp:86
static constexpr const char * name
Definition layernorm2d_fwd_kernel.hpp:147
static constexpr const char * name
Definition layernorm2d_fwd_kernel.hpp:149
static constexpr const char * name
Definition layernorm2d_fwd_kernel.hpp:146
static constexpr const char * name
Definition layernorm2d_fwd_kernel.hpp:148
static constexpr const char * name
Definition layernorm2d_fwd_kernel.hpp:150
static constexpr const char * name
Definition layernorm2d_fwd_kernel.hpp:145
Definition layernorm2d_fwd_kernel.hpp:144
Definition layernorm2d_fwd_kernel.hpp:14
void * p_y_residual
Definition layernorm2d_fwd_kernel.hpp:23
const void * p_x_bias
Definition layernorm2d_fwd_kernel.hpp:18
const void * p_gamma
Definition layernorm2d_fwd_kernel.hpp:19
const void * p_x_residual
Definition layernorm2d_fwd_kernel.hpp:16
const void * p_sm_scale
Definition layernorm2d_fwd_kernel.hpp:17
Definition layernorm2d_fwd_kernel.hpp:41
typename Pipeline::Problem Problem
Definition layernorm2d_fwd_kernel.hpp:44
remove_cvref_t< typename Problem::BetaDataType > BetaDataType
Definition layernorm2d_fwd_kernel.hpp:49
remove_cvref_t< Pipeline_ > Pipeline
Definition layernorm2d_fwd_kernel.hpp:42
remove_cvref_t< typename Problem::XDataType > XDataType
Definition layernorm2d_fwd_kernel.hpp:46
static constexpr bool kHasBeta
Definition layernorm2d_fwd_kernel.hpp:62
static CK_TILE_HOST constexpr auto GridSize(const Hargs &hargs)
Definition layernorm2d_fwd_kernel.hpp:132
static CK_TILE_HOST std::string GetName()
Definition layernorm2d_fwd_kernel.hpp:156
static constexpr auto kFusedAdd
Definition layernorm2d_fwd_kernel.hpp:73
static constexpr index_t Repeat_N
Definition layernorm2d_fwd_kernel.hpp:78
static constexpr index_t Block_N
Definition layernorm2d_fwd_kernel.hpp:68
remove_cvref_t< typename Problem::XBiasDataType > XBiasDataType
Definition layernorm2d_fwd_kernel.hpp:47
static constexpr bool kHasGamma
Definition layernorm2d_fwd_kernel.hpp:61
static constexpr index_t ThreadPerWarp_N
Definition layernorm2d_fwd_kernel.hpp:76
static CK_TILE_HOST constexpr auto BlockSize()
Definition layernorm2d_fwd_kernel.hpp:137
static constexpr bool kSaveMeanInvStd
Definition layernorm2d_fwd_kernel.hpp:63
CK_TILE_DEVICE void operator()(Kargs kargs) const
Definition layernorm2d_fwd_kernel.hpp:196
remove_cvref_t< typename Problem::ComputeDataType > ComputeDataType
Definition layernorm2d_fwd_kernel.hpp:50
remove_cvref_t< typename Problem::YScaleDataType > YScaleDataType
Definition layernorm2d_fwd_kernel.hpp:55
static constexpr index_t Vector_N
Definition layernorm2d_fwd_kernel.hpp:77
static constexpr bool kTwoPass
Definition layernorm2d_fwd_kernel.hpp:71
XDataType YResidualDataType
Definition layernorm2d_fwd_kernel.hpp:59
static constexpr auto kFusedQuant
Definition layernorm2d_fwd_kernel.hpp:74
static constexpr bool kSaveMean
Definition layernorm2d_fwd_kernel.hpp:64
remove_cvref_t< typename Problem::YDataType > YDataType
Definition layernorm2d_fwd_kernel.hpp:51
static constexpr index_t Block_M
Definition layernorm2d_fwd_kernel.hpp:67
static CK_TILE_HOST_DEVICE constexpr index_t GetSmemSize()
Definition layernorm2d_fwd_kernel.hpp:154
XDataType XResidualDataType
Definition layernorm2d_fwd_kernel.hpp:58
remove_cvref_t< typename Problem::InvStdDataType > InvStdDataType
Definition layernorm2d_fwd_kernel.hpp:53
remove_cvref_t< typename Problem::MeanDataType > MeanDataType
Definition layernorm2d_fwd_kernel.hpp:52
static constexpr bool kSaveInvStd
Definition layernorm2d_fwd_kernel.hpp:65
static constexpr index_t kBlockSize
Definition layernorm2d_fwd_kernel.hpp:79
remove_cvref_t< typename Problem::SmoothScaleDataType > SmoothScaleDataType
Definition layernorm2d_fwd_kernel.hpp:54
remove_cvref_t< Epilogue_ > Epilogue
Definition layernorm2d_fwd_kernel.hpp:43
remove_cvref_t< typename Problem::GammaDataType > GammaDataType
Definition layernorm2d_fwd_kernel.hpp:48
static CK_TILE_HOST constexpr Kargs MakeKargs(const Hargs &hargs)
Definition layernorm2d_fwd_kernel.hpp:110
Definition layernorm2d_fwd_traits.hpp:18
Definition tile/core/container/sequence.hpp:49