gridwise_welford_second_half_layernorm2d.hpp Source File
gridwise_welford_second_half_layernorm2d.hpp
Go to the documentation of this file.
Definition ck.hpp:268
__host__ __device__ constexpr auto make_multi_index(Xs &&... xs)
Definition array_multi_index.hpp:15
__host__ __device__ constexpr auto make_cluster_descriptor(const Lengths &lengths, ArrangeOrder order=typename arithmetic_sequence_gen< 0, Lengths::Size(), 1 >::type{})
Definition tensor_description/cluster_descriptor.hpp:13
__host__ __device__ constexpr Y type_convert(X x)
Definition utility/type_convert.hpp:98
__host__ __device__ constexpr auto make_naive_tensor_descriptor_packed(const Tuple< Lengths... > &lengths)
Definition tensor_descriptor_helper.hpp:101
__host__ __device__ constexpr auto make_tuple(Xs &&... xs)
Definition utility/tuple.hpp:211
__host__ __device__ constexpr auto make_dynamic_buffer(T *p, ElementSpaceSize element_space_size)
Definition dynamic_buffer.hpp:472
static __device__ void Run(T &mean_value, T &var_value, CountDataType &count)
Definition blockwise_welford.hpp:51
Definition gridwise_welford_second_half_layernorm2d.hpp:42
static constexpr index_t M_BlockTileSize
Definition gridwise_welford_second_half_layernorm2d.hpp:85
static __device__ void Run(const EMeanVarDataType *__restrict__ p_e_grid, const EMeanVarDataType *__restrict__ p_in_welford_mean_grid, const EMeanVarDataType *__restrict__ p_in_welford_var_grid, const int32_t *__restrict__ p_in_welford_count_grid, const GammaDataType *__restrict__ p_gamma_grid, const BetaDataType *__restrict__ p_beta_grid, HDataType *__restrict__ p_h_grid, const EHGridDesc_M_N &e_grid_desc_m_n, const EHGridDesc_M_N &h_grid_desc_m_n, const MeanVarGridDesc_M_NBlock &mean_var_grid_desc_m_nblock, const CountGridDesc_M_NBlock &count_grid_desc_m_nblock, const GammaBetaGridDesc_N &gamma_grid_desc_n, const GammaBetaGridDesc_N &beta_grid_desc_n, index_t numMeanVarCountBlockTileIteration_N, index_t NBlockClusterLength, ComputeDataType epsilon, HElementwiseOperation h_element_op)
Definition gridwise_welford_second_half_layernorm2d.hpp:88
ThreadwiseWelfordMerge< AccDataType, ThreadWelfordSrcDesc_M_1, ThreadWelfordDstDesc_M > ThreadwiseWelford
Definition gridwise_welford_second_half_layernorm2d.hpp:74
BlockwiseWelford< AccDataType, BlockSize, ThreadClusterLengths_M_N, ThreadClusterArrangeOrder > BlockwiseWelford
Definition gridwise_welford_second_half_layernorm2d.hpp:77
Sequence< NThreadSliceSize > ThreadBufferLengths_N
Definition gridwise_welford_second_half_layernorm2d.hpp:66
static constexpr auto thread_buffer_desc_m_1
Definition gridwise_welford_second_half_layernorm2d.hpp:63
Sequence< 0, 1 > ThreadClusterArrangeOrder
Definition gridwise_welford_second_half_layernorm2d.hpp:53
static constexpr auto thread_cluster_desc_m_n
Definition gridwise_welford_second_half_layernorm2d.hpp:55
Sequence< MThreadClusterSize, NThreadClusterSize > ThreadClusterLengths_M_N
Definition gridwise_welford_second_half_layernorm2d.hpp:51
Sequence< MThreadSliceSize, 1 > ThreadBufferLengths_M_1
Definition gridwise_welford_second_half_layernorm2d.hpp:62
static constexpr auto I1
Definition gridwise_welford_second_half_layernorm2d.hpp:83
decltype(thread_buffer_desc_m_1) ThreadWelfordSrcDesc_M_1
Definition gridwise_welford_second_half_layernorm2d.hpp:70
Sequence< 0, 1 > ThreadBufferDimAccessOrder
Definition gridwise_welford_second_half_layernorm2d.hpp:52
static constexpr auto thread_buffer_desc_n
Definition gridwise_welford_second_half_layernorm2d.hpp:67
static constexpr index_t N_BlockTileSize
Definition gridwise_welford_second_half_layernorm2d.hpp:86
Sequence< MThreadSliceSize, NThreadSliceSize > ThreadBufferLengths_M_N
Definition gridwise_welford_second_half_layernorm2d.hpp:58
static constexpr auto I0
Definition gridwise_welford_second_half_layernorm2d.hpp:82
decltype(make_naive_tensor_descriptor_packed(make_tuple(Number< MThreadSliceSize >{}))) ThreadWelfordDstDesc_M
Definition gridwise_welford_second_half_layernorm2d.hpp:71
static constexpr auto thread_buffer_desc_m_n
Definition gridwise_welford_second_half_layernorm2d.hpp:59
Definition utility/sequence.hpp:43
Definition static_buffer.hpp:16
Definition threadwise_tensor_slice_transfer.hpp:39
Helper structure that facilitates transfer of source (grid) data to destination threads.
Definition threadwise_tensor_slice_transfer.hpp:234
__device__ void Run(const SrcDesc &src_desc, const SrcBuffer &src_buf, const DstDesc &, const DstSliceOriginIdx &, DstBuffer &dst_buf)
Definition threadwise_tensor_slice_transfer.hpp:276
Definition threadwise_welford.hpp:83
ck::ThreadwiseWelfordMerge< ComputeDataType, ThreadWelfordSrcDesc_M_1, ThreadWelfordDstDesc_M >::Run
static __device__ void Run(const SrcMeanBufferType &src_mean_buf, const SrcVarBufferType &src_var_buf, const SrcCountBufferType &src_count_buf, DstMeanBufferType &dst_mean_buf, DstVarBufferType &dst_var_buf, DstCountBufferType &dst_count_buf)
Definition threadwise_welford.hpp:110
Definition functional2.hpp:33