gridwise_welford_second_half_layernorm2d.hpp Source File

gridwise_welford_second_half_layernorm2d.hpp Source File

Composable Kernel: gridwise_welford_second_half_layernorm2d.hpp Source File
gridwise_welford_second_half_layernorm2d.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
19
20namespace ck {
21
22template <typename EMeanVarDataType,
23 typename HDataType,
24 typename GammaDataType,
25 typename BetaDataType,
26 typename ComputeDataType,
27 typename EHGridDesc_M_N,
28 typename MeanVarGridDesc_M_NBlock,
29 typename CountGridDesc_M_NBlock,
30 typename GammaBetaGridDesc_N,
31 typename HElementwiseOperation,
32 index_t BlockSize,
33 index_t MThreadClusterSize,
34 index_t NThreadClusterSize,
35 index_t MThreadSliceSize,
36 index_t NThreadSliceSize,
37 index_t ESrcVectorSize,
38 index_t HDstVectorSize,
39 index_t GammaSrcVectorSize,
40 index_t BetaSrcVectorSize>
42{
// Compile-time check: the per-thread slice length along N must be a whole
// multiple of each source vector width (E, gamma and beta loads).
43 static_assert(NThreadSliceSize % ESrcVectorSize == 0 &&
44 NThreadSliceSize % GammaSrcVectorSize == 0 &&
45 NThreadSliceSize % BetaSrcVectorSize == 0,
46 "Invalid thread slice sizes and/or vector sizes configuration, please check!");
47
// Same divisibility requirement for the destination (H) vector width.
48 static_assert(NThreadSliceSize % HDstVectorSize == 0,
49 "Invalid thread slice sizes and/or vector sizes configuration, please check!");
50
54
// NOTE(review): the listing numbers jump in this region (50 -> 54 -> 55 -> 57 ...),
// so the type aliases and the right-hand sides of the descriptor declarations
// below are not visible in this extract. Recover them from the original header
// before relying on this text.
//
// thread_cluster_desc_m_n: used by Run() to turn the 1D thread id into an
// (m, n) position via CalculateBottomIndex.
55 static constexpr auto thread_cluster_desc_m_n =
57
61
// thread_buffer_desc_m_1: its type is passed as the destination descriptor of
// the mean / var / count loads in Run().
63 static constexpr auto thread_buffer_desc_m_1 =
65
// thread_buffer_desc_n: its type is passed as the destination descriptor of
// the gamma / beta loads in Run().
67 static constexpr auto thread_buffer_desc_n =
69
73
76
// Alias for the block-level Welford helper. NOTE(review): the remaining
// template arguments (listing lines 79-80) are missing from this extract.
77 using BlockwiseWelford = BlockwiseWelford<ComputeDataType,
78 BlockSize,
81
// Compile-time index constants used to index tuples / multi-indices in Run().
82 static constexpr auto I0 = Number<0>{};
83 static constexpr auto I1 = Number<1>{};
84
// Extent of one block's tile along M and N: threads in the cluster times the
// slice each thread owns along that dimension.
85 static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize;
86 static constexpr index_t N_BlockTileSize = NThreadClusterSize * NThreadSliceSize;
87
// Run: per-thread device entry point of this gridwise operator. Judging by the
// file name this is presumably the second half of a split layernorm — confirm
// against the launcher.
//
// Two phases, as marked by the step comments below:
//   step1 - accumulate the incoming Welford partials (mean / var / count) for
//           each of this thread's rows into thread-local accumulators;
//   step2 - h[m, n] = (e[m, n] - mean[m]) / sqrt(var[m] + epsilon) * gamma[n] + beta[n],
//           then store h through threadwise_h_store_m_n.
//
// Parameters:
//   p_e_grid                  input E, read through e_grid_desc_m_n
//   p_in_welford_mean_grid    incoming partial means, read through mean_var_grid_desc_m_nblock
//   p_in_welford_var_grid     incoming partial variances, same descriptor as the means
//   p_in_welford_count_grid   incoming partial counts (int32_t), read through count_grid_desc_m_nblock
//   p_gamma_grid, p_beta_grid per-n scale and shift, read through gamma/beta_grid_desc_n
//   p_h_grid                  output H, written through h_grid_desc_m_n
//   numMeanVarCountBlockTileIteration_N  trip count of the step1 merge loop
//   NBlockClusterLength       divisor used to split the 1D block id into (M, N) block indices
//   epsilon                   added to the variance before the square root
//   h_element_op              forwarded to the H store transfer object
//
// NOTE(review): this listing is an extract with dropped lines (the listing
// numbers are not contiguous). Several declarations below show only their
// tails; the notes mark each gap.
88 __device__ static void Run(const EMeanVarDataType* __restrict__ p_e_grid,
89 const EMeanVarDataType* __restrict__ p_in_welford_mean_grid,
90 const EMeanVarDataType* __restrict__ p_in_welford_var_grid,
91 const int32_t* __restrict__ p_in_welford_count_grid,
92 const GammaDataType* __restrict__ p_gamma_grid,
93 const BetaDataType* __restrict__ p_beta_grid,
94 HDataType* __restrict__ p_h_grid,
95 const EHGridDesc_M_N& e_grid_desc_m_n,
96 const EHGridDesc_M_N& h_grid_desc_m_n,
97 const MeanVarGridDesc_M_NBlock& mean_var_grid_desc_m_nblock,
98 const CountGridDesc_M_NBlock& count_grid_desc_m_nblock,
99 const GammaBetaGridDesc_N& gamma_grid_desc_n,
100 const GammaBetaGridDesc_N& beta_grid_desc_n,
101 index_t numMeanVarCountBlockTileIteration_N,
102 index_t NBlockClusterLength,
103 ComputeDataType epsilon,
104 HElementwiseOperation h_element_op)
105 {
106 // Thread/Block id
107 const index_t thread_local_id = get_thread_local_1d_id();
108 const index_t block_global_id = get_block_1d_id();
// Split the 1D block id into (block index along M, block index along N).
109 const auto block_work_idx = make_tuple(block_global_id / NBlockClusterLength,
110 block_global_id % NBlockClusterLength);
111
// Map the 1D thread id to this thread's (m, n) position in the thread cluster.
112 const auto thread_cluster_idx =
113 thread_cluster_desc_m_n.CalculateBottomIndex(make_multi_index(thread_local_id));
114 const auto thread_m_cluster_id = thread_cluster_idx[I0];
115 const auto thread_n_cluster_id = thread_cluster_idx[I1];
116
117 // Global Memory
// Wrap each raw pointer in a Global-address-space dynamic buffer sized by the
// element space of its descriptor. Only h_global_val_buf is non-const.
118 const auto e_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
119 p_e_grid, e_grid_desc_m_n.GetElementSpaceSize());
120
121 const auto welford_mean_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
122 p_in_welford_mean_grid, mean_var_grid_desc_m_nblock.GetElementSpaceSize());
123
124 const auto welford_var_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
125 p_in_welford_var_grid, mean_var_grid_desc_m_nblock.GetElementSpaceSize());
126
127 const auto welford_count_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
128 p_in_welford_count_grid, count_grid_desc_m_nblock.GetElementSpaceSize());
129
130 const auto gamma_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
131 p_gamma_grid, gamma_grid_desc_n.GetElementSpaceSize());
132
133 const auto beta_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
134 p_beta_grid, beta_grid_desc_n.GetElementSpaceSize());
135
136 auto h_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
137 p_h_grid, h_grid_desc_m_n.GetElementSpaceSize());
138
139 // VGPR
// NOTE(review): the first line of every thread-buffer declaration in this
// section (listing lines 140, 142, 144, 147, 149, 151, 154, 159, 164, 169) is
// missing from the extract; only the buffer names and, for the last four, the
// trailing template arguments are visible. Presumably StaticBuffer
// instantiations — confirm against the original header.
//
// in_welford_*_thread_buf: staging buffers filled by the step1 loads.
141 in_welford_mean_thread_buf;
143 in_welford_var_thread_buf;
145 in_welford_count_thread_buf;
146
// welford_*_thread_buf: running accumulators the staged values are merged into.
148 welford_mean_thread_buf;
150 welford_var_thread_buf;
152 welford_count_thread_buf;
153
// Per-thread tile buffers of ComputeDataType.
// NOTE(review): gamma_thread_buf and beta_thread_buf are sized
// MThreadSliceSize * NThreadSliceSize here but are only indexed by n in step2;
// this looks larger than needed — confirm before changing.
155 ComputeDataType,
156 MThreadSliceSize * NThreadSliceSize,
157 true>
158 e_thread_buf;
160 ComputeDataType,
161 MThreadSliceSize * NThreadSliceSize,
162 true>
163 gamma_thread_buf;
165 ComputeDataType,
166 MThreadSliceSize * NThreadSliceSize,
167 true>
168 beta_thread_buf;
170 ComputeDataType,
171 MThreadSliceSize * NThreadSliceSize,
172 true>
173 h_thread_buf;
174
175 // IO
// Loaders for the incoming mean / var / count. All three start at
//   row    = block_work_idx[I0] * M_BlockTileSize + thread_m_cluster_id * MThreadSliceSize
//   column = thread_n_cluster_id
// and are advanced along the second dimension inside the step1 loop.
// NOTE(review): two template arguments (listing lines 181-182, and the
// matching lines in the next two declarations) are missing from the extract.
176 auto threadwise_mean_load_m_nblock =
177 ThreadwiseTensorSliceTransfer_v2<EMeanVarDataType,
178 ComputeDataType,
179 MeanVarGridDesc_M_NBlock,
180 decltype(thread_buffer_desc_m_1),
183 1,
184 1,
185 1,
186 true>(
187 mean_var_grid_desc_m_nblock,
188 make_multi_index(block_work_idx[I0] * M_BlockTileSize +
189 thread_m_cluster_id * MThreadSliceSize,
190 thread_n_cluster_id));
191
192 auto threadwise_var_load_m_nblock =
193 ThreadwiseTensorSliceTransfer_v2<EMeanVarDataType,
194 ComputeDataType,
195 MeanVarGridDesc_M_NBlock,
196 decltype(thread_buffer_desc_m_1),
199 1,
200 1,
201 1,
202 true>(
203 mean_var_grid_desc_m_nblock,
204 make_multi_index(block_work_idx[I0] * M_BlockTileSize +
205 thread_m_cluster_id * MThreadSliceSize,
206 thread_n_cluster_id));
207
// Count loader. NOTE(review): the line naming the transfer class and its
// first template argument (listing line 209) is missing from the extract.
208 auto threadwise_count_load_m_nblock =
210 int32_t,
211 CountGridDesc_M_NBlock,
212 decltype(thread_buffer_desc_m_1),
215 1,
216 1,
217 1,
218 true>(
219 count_grid_desc_m_nblock,
220 make_multi_index(block_work_idx[I0] * M_BlockTileSize +
221 thread_m_cluster_id * MThreadSliceSize,
222 thread_n_cluster_id));
223
// E tile loader: vectorised along dimension 1 with width ESrcVectorSize. The
// origin is this thread's (row, column) inside the block's M x N tile.
// NOTE(review): listing line 236 (presumably the make_multi_index( opener) is
// missing from the extract.
224 auto threadwise_e_load_m_n =
225 ThreadwiseTensorSliceTransfer_v2<EMeanVarDataType,
226 ComputeDataType,
227 decltype(e_grid_desc_m_n),
228 decltype(thread_buffer_desc_m_n),
231 1, // SrcVectorDim
232 ESrcVectorSize,
233 1,
234 true>(
235 e_grid_desc_m_n,
237 block_work_idx[I0] * M_BlockTileSize + thread_m_cluster_id * MThreadSliceSize,
238 block_work_idx[I1] * N_BlockTileSize + thread_n_cluster_id * NThreadSliceSize));
239
// Gamma loader: 1D along N, starting at this thread's column offset.
// NOTE(review): the transfer class name line (listing line 241) is missing.
240 auto threadwise_gamma_load_n =
242 ComputeDataType,
243 decltype(gamma_grid_desc_n),
244 decltype(thread_buffer_desc_n),
246 Sequence<0>, // DimAccessOrder,
247 0, // SrcVectorDim,
248 GammaSrcVectorSize,
249 1,
250 true>(
251 gamma_grid_desc_n,
252 make_multi_index(block_work_idx[I1] * N_BlockTileSize +
253 thread_n_cluster_id * NThreadSliceSize));
254
// Beta loader: same layout as the gamma loader, using BetaSrcVectorSize.
// NOTE(review): the transfer class name line (listing line 256) is missing.
255 auto threadwise_beta_load_n =
257 ComputeDataType,
258 decltype(beta_grid_desc_n),
259 decltype(thread_buffer_desc_n),
261 Sequence<0>, // DimAccessOrder,
262 0, // SrcVectorDim,
263 BetaSrcVectorSize,
264 1,
265 true>(
266 beta_grid_desc_n,
267 make_multi_index(block_work_idx[I1] * N_BlockTileSize +
268 thread_n_cluster_id * NThreadSliceSize));
269
// H store: writes h_thread_buf to the same (row, column) origin as the E load,
// vectorised along dimension 1 with width HDstVectorSize; h_element_op is
// handed to the transfer object.
// NOTE(review): the transfer class name line (271) and several arguments
// (276-277, 280, 284) are missing from the extract.
270 auto threadwise_h_store_m_n =
272 HDataType,
273 decltype(thread_buffer_desc_m_n),
274 decltype(h_grid_desc_m_n),
275 HElementwiseOperation,
278 1, // DstVectorDim
279 HDstVectorSize,
281 1,
282 true>(
283 h_grid_desc_m_n,
285 block_work_idx[I0] * M_BlockTileSize + thread_m_cluster_id * MThreadSliceSize,
286 block_work_idx[I1] * N_BlockTileSize + thread_n_cluster_id * NThreadSliceSize),
287 h_element_op);
288
289 // step1: Merge mean and variance
// Window step for the three statistic loaders: no movement along M,
// NThreadClusterSize entries along the second dimension per iteration.
290 constexpr auto mean_var_count_thread_copy_step_I0_n =
291 make_multi_index(I0, NThreadClusterSize);
292
// Zero the running accumulators. NOTE(review): the opening line of the
// enclosing compile-time loop (listing line 293) is missing from the extract.
294 welford_mean_thread_buf(I) = type_convert<ComputeDataType>(0.0f);
295 welford_var_thread_buf(I) = type_convert<ComputeDataType>(0.0f);
296 welford_count_thread_buf(I) = 0;
297 });
298
// Each iteration loads one set of mean / var / count values into the staging
// buffers, merges them into the running accumulators with
// ThreadwiseWelford::Run, then advances all three source windows by the step
// defined above.
// NOTE(review): one argument of each load call (listing lines 303, 309, 315)
// is missing from the extract.
299 for(index_t n = 0; n < numMeanVarCountBlockTileIteration_N; ++n)
300 {
301 threadwise_mean_load_m_nblock.Run(mean_var_grid_desc_m_nblock,
302 welford_mean_global_val_buf,
304 make_tuple(I0, I0),
305 in_welford_mean_thread_buf);
306
307 threadwise_var_load_m_nblock.Run(mean_var_grid_desc_m_nblock,
308 welford_var_global_val_buf,
310 make_tuple(I0, I0),
311 in_welford_var_thread_buf);
312
313 threadwise_count_load_m_nblock.Run(count_grid_desc_m_nblock,
314 welford_count_global_val_buf,
316 make_tuple(I0, I0),
317 in_welford_count_thread_buf);
318
319 ThreadwiseWelford::Run(in_welford_mean_thread_buf,
320 in_welford_var_thread_buf,
321 in_welford_count_thread_buf,
322 welford_mean_thread_buf,
323 welford_var_thread_buf,
324 welford_count_thread_buf);
325
326 threadwise_mean_load_m_nblock.MoveSrcSliceWindow(mean_var_grid_desc_m_nblock,
327 mean_var_count_thread_copy_step_I0_n);
328 threadwise_var_load_m_nblock.MoveSrcSliceWindow(mean_var_grid_desc_m_nblock,
329 mean_var_count_thread_copy_step_I0_n);
330 threadwise_count_load_m_nblock.MoveSrcSliceWindow(count_grid_desc_m_nblock,
331 mean_var_count_thread_copy_step_I0_n);
332 }
333
// Per-row call taking (mean, var, count) of row I.
// NOTE(review): listing lines 334, 336 and 338 are missing, so the loop
// opener, the statement guarded by `if constexpr(I > 0)` and the callee name
// are not visible. Presumably a block-level Welford reduction with an LDS
// sync between rows — confirm against the original header.
335 if constexpr(I > 0)
337
339 welford_mean_thread_buf(I), welford_var_thread_buf(I), welford_count_thread_buf(I));
340 });
341
342 // step2: normalization
343 // h[m, n] = [(e[m, n] - mean[m]) / sqrt(var[m] + eps)] * gamma[n] + beta[n]
344 threadwise_e_load_m_n.Run(e_grid_desc_m_n,
345 e_global_val_buf,
347 make_tuple(I0, I0),
348 e_thread_buf);
349
// Normalise: one reciprocal standard deviation per row m, applied to every
// element (m, n) of the thread tile. m_n is the flat offset of (m, n) in the
// thread buffer.
// NOTE(review): the loop opener lines (listing 350 and 352) are missing.
351 auto divisor = 1 / ck::math::sqrt(welford_var_thread_buf(m) + epsilon);
353 constexpr auto m_n = thread_buffer_desc_m_n.CalculateOffset(make_tuple(m, n));
354 h_thread_buf(Number<m_n>{}) =
355 (e_thread_buf(Number<m_n>{}) - welford_mean_thread_buf(m)) * divisor;
356 });
357 });
358
359 threadwise_gamma_load_n.Run(gamma_grid_desc_n,
360 gamma_global_val_buf,
362 make_tuple(I0),
363 gamma_thread_buf);
364
// Scale: multiply every element of row m by gamma[n].
// NOTE(review): the loop opener lines (listing 365-366) are missing.
367 constexpr auto m_n = thread_buffer_desc_m_n.CalculateOffset(make_tuple(m, n));
368 h_thread_buf(Number<m_n>{}) = h_thread_buf(Number<m_n>{}) * gamma_thread_buf(n);
369 });
370 });
371
372 threadwise_beta_load_n.Run(beta_grid_desc_n,
373 beta_global_val_buf,
375 make_tuple(I0),
376 beta_thread_buf);
377
// Shift: add beta[n] to every element.
// NOTE(review): the loop opener lines (listing 378-379) are missing.
380 constexpr auto m_n = thread_buffer_desc_m_n.CalculateOffset(make_tuple(m, n));
381 h_thread_buf(Number<m_n>{}) = h_thread_buf(Number<m_n>{}) + beta_thread_buf(n);
382 });
383 });
384
// Write the finished thread tile to the H grid through the store transfer
// object constructed above.
385 threadwise_h_store_m_n.Run(thread_buffer_desc_m_n,
386 make_tuple(I0, I0),
387 h_thread_buf,
388 h_grid_desc_m_n,
389 h_global_val_buf);
390
391 } // run
392};
393
394} // namespace ck
Definition ck.hpp:268
__host__ __device__ constexpr auto make_multi_index(Xs &&... xs)
Definition array_multi_index.hpp:15
int32_t index_t
Definition ck.hpp:299
@ Set
Definition ck.hpp:278
__host__ __device__ constexpr auto make_cluster_descriptor(const Lengths &lengths, ArrangeOrder order=typename arithmetic_sequence_gen< 0, Lengths::Size(), 1 >::type{})
Definition tensor_description/cluster_descriptor.hpp:13
integral_constant< index_t, N > Number
Definition number.hpp:12
@ Vgpr
Definition amd_address_space.hpp:20
__device__ index_t get_block_1d_id()
Definition get_id.hpp:47
__host__ __device__ constexpr Y type_convert(X x)
Definition utility/type_convert.hpp:98
__host__ __device__ constexpr auto make_naive_tensor_descriptor_packed(const Tuple< Lengths... > &lengths)
Definition tensor_descriptor_helper.hpp:101
__host__ __device__ constexpr auto make_tuple(Xs &&... xs)
Definition utility/tuple.hpp:211
__device__ index_t get_thread_local_1d_id()
Definition get_id.hpp:41
__device__ void block_sync_lds()
Definition synchronization.hpp:16
__host__ __device__ constexpr auto make_dynamic_buffer(T *p, ElementSpaceSize element_space_size)
Definition dynamic_buffer.hpp:472
signed int int32_t
Definition stdint.h:123
static __device__ void Run(T &mean_value, T &var_value, CountDataType &count)
Definition blockwise_welford.hpp:51
Definition gridwise_welford_second_half_layernorm2d.hpp:42
static __device__ void Run(const EMeanVarDataType *__restrict__ p_e_grid, const EMeanVarDataType *__restrict__ p_in_welford_mean_grid, const EMeanVarDataType *__restrict__ p_in_welford_var_grid, const int32_t *__restrict__ p_in_welford_count_grid, const GammaDataType *__restrict__ p_gamma_grid, const BetaDataType *__restrict__ p_beta_grid, HDataType *__restrict__ p_h_grid, const EHGridDesc_M_N &e_grid_desc_m_n, const EHGridDesc_M_N &h_grid_desc_m_n, const MeanVarGridDesc_M_NBlock &mean_var_grid_desc_m_nblock, const CountGridDesc_M_NBlock &count_grid_desc_m_nblock, const GammaBetaGridDesc_N &gamma_grid_desc_n, const GammaBetaGridDesc_N &beta_grid_desc_n, index_t numMeanVarCountBlockTileIteration_N, index_t NBlockClusterLength, ComputeDataType epsilon, HElementwiseOperation h_element_op)
Definition gridwise_welford_second_half_layernorm2d.hpp:88
Definition utility/sequence.hpp:43
Definition static_buffer.hpp:16
Definition threadwise_tensor_slice_transfer.hpp:39
Helper structure that facilitates transfer of source (grid) data to destination threads.
Definition threadwise_tensor_slice_transfer.hpp:234
__device__ void Run(const SrcDesc &src_desc, const SrcBuffer &src_buf, const DstDesc &, const DstSliceOriginIdx &, DstBuffer &dst_buf)
Definition threadwise_tensor_slice_transfer.hpp:276
Definition threadwise_welford.hpp:83
static __device__ void Run(const SrcMeanBufferType &src_mean_buf, const SrcVarBufferType &src_var_buf, const SrcCountBufferType &src_count_buf, DstMeanBufferType &dst_mean_buf, DstVarBufferType &dst_var_buf, DstCountBufferType &dst_count_buf)
Definition threadwise_welford.hpp:110
Definition functional2.hpp:33