device_normalization_fwd_impl.hpp Source File

device_normalization_fwd_impl.hpp Source File

Composable Kernel: device_normalization_fwd_impl.hpp Source File
device_normalization_fwd_impl.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
6#include <iostream>
7#include <sstream>
8
16
17namespace ck {
18namespace tensor_operation {
19namespace device {
20
21// Y = Normalization(X, Beta, Gamma)
22// M: Invariant length
23// K: Reduce length (Calculate mean and variance along K dimension)
24// eg. Length = [N, C, H, W], reduce dim = [C, H, W]
25// Then, M = N, K = C * H * W
26template <typename XDataType,
27 typename GammaDataType,
28 typename BetaDataType,
29 typename ComputeDataType,
30 typename YDataType,
31 typename SaveMeanInvStdDataType,
32 typename YElementwiseOperation,
33 index_t Rank,
34 index_t NumReduceDim,
35 index_t BlockSize,
36 index_t MThreadClusterSize,
37 index_t KThreadClusterSize,
38 index_t MThreadSliceSize,
39 index_t KThreadSliceSize,
40 index_t XYSrcVectorDim,
41 index_t XSrcVectorSize,
42 index_t GammaSrcVectorDim,
43 index_t GammaSrcVectorSize,
44 index_t BetaSrcVectorDim,
45 index_t BetaSrcVectorSize,
46 index_t YDstVectorSize,
47 index_t SaveMeanInvStdDstVectorSize,
48 bool UseWelford = true>
50 GammaDataType,
51 BetaDataType,
52 YDataType,
53 SaveMeanInvStdDataType,
54 YElementwiseOperation,
55 Rank,
56 NumReduceDim>
57{
58 static_assert(BlockSize == MThreadClusterSize * KThreadClusterSize);
59 static_assert(
60 ((GammaSrcVectorDim == 0 && MThreadSliceSize % GammaSrcVectorSize == 0) ||
61 (GammaSrcVectorDim == 1 && KThreadSliceSize % GammaSrcVectorSize == 0)),
62 "Invalid thread slice sizes and/or gamma vector sizes configuration, please check!");
63
64 static_assert(
65 ((BetaSrcVectorDim == 0 && MThreadSliceSize % BetaSrcVectorSize == 0) ||
66 (BetaSrcVectorDim == 1 && KThreadSliceSize % BetaSrcVectorSize == 0)),
67 "Invalid thread slice sizes and/or beta vector sizes configuration, please check!");
68
69 static_assert(MThreadSliceSize % SaveMeanInvStdDstVectorSize == 0,
70 "Invalid thread slice sizes and/or save mean and inverse std vector sizes "
71 "configuration, please check!");
72
74
75 static constexpr index_t NumInvariantDim = Rank - NumReduceDim;
76 static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize;
77 static constexpr index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize;
78
79 static constexpr bool reduceAllDim = (NumInvariantDim == 0);
80 static_assert(!reduceAllDim); // TODO
81
82 static auto MakeSrc2dDescriptor(const std::vector<index_t>& inLengths,
83 const std::vector<index_t>& inStrides,
84 int numBlockTileIteration)
85 {
86 static constexpr index_t numSrcDim = Rank;
87
88 const auto tupleSrcLengths = make_tuple_from_array(inLengths, Number<numSrcDim>{});
89 const auto tupleSrcStrides = make_tuple_from_array(inStrides, Number<numSrcDim>{});
90
91 const auto inDesc = make_naive_tensor_descriptor(tupleSrcLengths, tupleSrcStrides);
92
93 const auto in_grid_desc_m_k = [&]() {
94 if constexpr(reduceAllDim)
95 {
96 const auto one_dim_inDesc = transform_tensor_descriptor(
97 inDesc,
98 make_tuple(make_merge_transform(tupleSrcLengths)),
101
102 return transform_tensor_descriptor(one_dim_inDesc,
104 1, one_dim_inDesc.GetLength(Number<0>{})))),
107 }
108 else
109 {
110 using InvariantDims = typename arithmetic_sequence_gen<0, NumInvariantDim, 1>::type;
112
113 const auto reduceDimLengths =
114 make_tuple_from_array_and_index_seq(inLengths, ReduceDims{});
115 const auto invariantDimLengths =
116 make_tuple_from_array_and_index_seq(inLengths, InvariantDims{});
117
119 inDesc,
120 make_tuple(make_merge_transform(invariantDimLengths),
121 make_merge_transform(reduceDimLengths)),
122 make_tuple(InvariantDims{}, ReduceDims{}),
124 }
125 }();
126
127 const auto invariantLength = in_grid_desc_m_k.GetLength(Number<0>{});
128 const auto reduceLength = in_grid_desc_m_k.GetLength(Number<1>{});
129
130 const auto inPad_M =
131 math::integer_least_multiple(invariantLength, M_BlockTileSize) - invariantLength;
132 const auto inPad_K = K_BlockTileSize * numBlockTileIteration - reduceLength;
133
134 auto in_grid_desc_m_k_padded = transform_tensor_descriptor(
135 in_grid_desc_m_k,
136 make_tuple(make_right_pad_transform(invariantLength, inPad_M),
137 make_right_pad_transform(reduceLength, inPad_K)),
140
141 return (in_grid_desc_m_k_padded);
142 };
143
144 static auto MakeSaveMeanInvStdDescriptor_M(const std::vector<index_t>& lengths,
145 const std::vector<index_t>& strides)
146 {
147 using InvariantDims = typename arithmetic_sequence_gen<0, NumInvariantDim, 1>::type;
148
149 const auto tupleSrcLengths = make_tuple_from_array_and_index_seq(lengths, InvariantDims{});
150 const auto tupleSrcStrides = make_tuple_from_array_and_index_seq(strides, InvariantDims{});
151
152 const auto desc = make_naive_tensor_descriptor(tupleSrcLengths, tupleSrcStrides);
153
154 const auto grid_desc_m =
156 make_tuple(make_merge_transform(tupleSrcLengths)),
157 make_tuple(InvariantDims{}),
159
160 const auto invariantLength = grid_desc_m.GetLength(Number<0>{});
161 const auto pad_M =
162 math::integer_least_multiple(invariantLength, M_BlockTileSize) - invariantLength;
163
164 auto grid_desc_m_padded = transform_tensor_descriptor(
165 grid_desc_m,
166 make_tuple(make_right_pad_transform(invariantLength, pad_M)),
169
170 return grid_desc_m_padded;
171 }
172
173 using GridDesc_M_K = decltype(MakeSrc2dDescriptor({1}, {1}, 1));
174 using GridDesc_M = decltype(MakeSaveMeanInvStdDescriptor_M({1}, {1}));
175
176 struct Argument : public BaseArgument
177 {
178 Argument(const std::vector<index_t> lengths,
179 const std::vector<index_t> xStrides,
180 const std::vector<index_t> gammaStrides,
181 const std::vector<index_t> betaStrides,
182 const std::vector<index_t> yStrides,
183 const std::vector<index_t> saveMeanStrides,
184 const std::vector<index_t> saveInvStdStrides,
185 const std::vector<index_t> reduceDims,
186 YElementwiseOperation y_elementwise_op,
187 double epsilon,
188 const XDataType* p_x,
189 const GammaDataType* p_gamma,
190 const BetaDataType* p_beta,
191 YDataType* p_y,
192 SaveMeanInvStdDataType* p_saveMean,
193 SaveMeanInvStdDataType* p_saveInvStd)
194 : p_x_(p_x),
195 p_gamma_(p_gamma),
196 p_beta_(p_beta),
197 p_y_(p_y),
198 p_saveMean_(p_saveMean),
199 p_saveInvStd_(p_saveInvStd),
200 y_elementwise_op_(y_elementwise_op)
201 {
202 epsilon_ = static_cast<ComputeDataType>(epsilon);
203
209 saveMeanStrides_ = saveMeanStrides;
210 saveInvStdStrides_ = saveInvStdStrides;
211
213
215
217
226
228 x_grid_desc_m_k_.GetLength(Number<1>{}) <= KThreadClusterSize * KThreadSliceSize;
229
230 if constexpr(NumInvariantDim == 0)
232 else
234 }
235
236 ComputeDataType epsilon_;
237
238 const XDataType* p_x_;
239 const GammaDataType* p_gamma_;
240 const BetaDataType* p_beta_;
241 YDataType* p_y_;
242 SaveMeanInvStdDataType* p_saveMean_;
243 SaveMeanInvStdDataType* p_saveInvStd_;
244
245 std::vector<index_t> Lengths_;
246 std::vector<index_t> xStrides_;
247 std::vector<index_t> gammaStrides_;
248 std::vector<index_t> betaStrides_;
249 std::vector<index_t> yStrides_;
250 std::vector<index_t> saveMeanStrides_;
251 std::vector<index_t> saveInvStdStrides_;
252
253 YElementwiseOperation y_elementwise_op_;
254
256 size_t gridSize_;
257
265
266 index_t MRaw_; // Invariant length
267 index_t KRaw_; // reduce length
268
270 };
271
272 struct Invoker : public BaseInvoker
273 {
274 float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
275 {
276 auto kernel_main = NormalizationKernelSelector<XDataType,
277 GammaDataType,
278 BetaDataType,
279 YDataType,
280 SaveMeanInvStdDataType,
281 ComputeDataType,
282 YElementwiseOperation,
285 BlockSize,
286 MThreadClusterSize,
287 KThreadClusterSize,
288 MThreadSliceSize,
289 KThreadSliceSize,
290 XYSrcVectorDim,
291 XSrcVectorSize,
292 GammaSrcVectorDim,
293 GammaSrcVectorSize,
294 BetaSrcVectorDim,
295 BetaSrcVectorSize,
296 XYSrcVectorDim,
297 YDstVectorSize,
298 SaveMeanInvStdDstVectorSize,
299 UseWelford>(arg.isSweeponce_);
300
301 float avg_time = 0;
302 avg_time += launch_and_time_kernel(stream_config,
303 kernel_main,
304 dim3(arg.gridSize_),
305 dim3(BlockSize),
306 0,
314 arg.epsilon_,
315 arg.p_x_,
316 arg.p_gamma_,
317 arg.p_beta_,
318 arg.p_y_,
319 arg.p_saveMean_,
320 arg.p_saveInvStd_,
322
323 return (avg_time);
324 };
325
326 float Run(const BaseArgument* p_arg,
327 const StreamConfig& stream_config = StreamConfig{}) override
328 {
329 return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
330 };
331 };
332
333 bool IsSupportedArgument(const BaseArgument* p_arg) override
334 {
335 const Argument* p_arg_ = dynamic_cast<const Argument*>(p_arg);
336
337 if constexpr(XYSrcVectorDim == 0)
338 {
339 if constexpr(NumInvariantDim == 0)
340 {
341 return false;
342 }
343 else
344 {
345 if(p_arg_->xStrides_[NumInvariantDim - 1] != 1)
346 return false;
347
348 if(p_arg_->invariant_lowest_length_ % XSrcVectorSize != 0)
349 return false;
350
351 if(p_arg_->invariant_lowest_length_ % YDstVectorSize != 0)
352 return false;
353 };
354 }
355 else
356 {
357 if(p_arg_->xStrides_[Rank - 1] != 1)
358 return false;
359
360 if(p_arg_->Lengths_[Rank - 1] % XSrcVectorSize != 0)
361 return false;
362
363 if(p_arg_->Lengths_[Rank - 1] % YDstVectorSize != 0)
364 {
365 return false;
366 }
367 };
368
369 // if fastest dim is not reduced
370 if constexpr(GammaSrcVectorDim == 0)
371 {
372 if(p_arg_->gammaStrides_[NumInvariantDim - 1] != 1)
373 return (false);
374
375 if(p_arg_->Lengths_[Rank - 1] % GammaSrcVectorSize != 0)
376 return (false);
377 }
378 else // if fastest dim is reduced
379 {
380 if(p_arg_->gammaStrides_[Rank - 1] != 1)
381 return (false);
382
383 if(p_arg_->Lengths_[Rank - 1] % GammaSrcVectorSize != 0)
384 return (false);
385 }
386
387 // if fastest dim is not reduced
388 if constexpr(BetaSrcVectorDim == 0)
389 {
390 if(p_arg_->betaStrides_[NumInvariantDim - 1] != 1)
391 return (false);
392
393 if(p_arg_->invariant_lowest_length_ % BetaSrcVectorSize != 0)
394 return (false);
395 }
396 else // if fastest dim is reduced
397 {
398 if(p_arg_->betaStrides_[Rank - 1] != 1)
399 return (false);
400
401 if(p_arg_->Lengths_[Rank - 1] % BetaSrcVectorSize != 0)
402 return (false);
403 }
404
405 if(p_arg_->invariant_lowest_length_ % SaveMeanInvStdDstVectorSize != 0)
406 return false;
407
408 return true;
409 };
410
411 std::unique_ptr<BaseArgument>
412 MakeArgumentPointer(const std::vector<index_t> lengths,
413 const std::vector<index_t> xStrides,
414 const std::vector<index_t> gammaStrides,
415 const std::vector<index_t> betaStrides,
416 const std::vector<index_t> yStrides,
417 const std::vector<index_t> saveMeanStrides,
418 const std::vector<index_t> saveInvStdStrides,
419 const std::vector<index_t> reduceDims,
420 double epsilon,
421 const void* p_x,
422 const void* p_gamma,
423 const void* p_beta,
424 void* p_y,
425 void* p_saveMean,
426 void* p_saveInvStd,
427 YElementwiseOperation y_elementwise_op) override
428 {
429 if(lengths.size() != Rank || xStrides.size() != Rank || gammaStrides.size() != Rank ||
430 betaStrides.size() != Rank || yStrides.size() != Rank ||
431 saveMeanStrides.size() != NumInvariantDim || saveInvStdStrides.size() != NumInvariantDim)
432 throw std::runtime_error("dimension is incorrect");
433
434 return std::make_unique<Argument>(lengths,
435 xStrides,
436 gammaStrides,
437 betaStrides,
438 yStrides,
439 saveMeanStrides,
440 saveInvStdStrides,
441 reduceDims,
442 y_elementwise_op,
443 epsilon,
444 static_cast<const XDataType*>(p_x),
445 static_cast<const GammaDataType*>(p_gamma),
446 static_cast<const BetaDataType*>(p_beta),
447 static_cast<YDataType*>(p_y),
448 static_cast<SaveMeanInvStdDataType*>(p_saveMean),
449 static_cast<SaveMeanInvStdDataType*>(p_saveInvStd));
450 };
451
452 std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
453 {
454 return std::make_unique<Invoker>();
455 };
456
457 std::string GetTypeString() const override
458 {
459 auto str = std::stringstream();
460
461 // clang-format off
462 str << "DeviceNormalizationFwdImpl<" << BlockSize << ",";
463 str << "Cluster_MK_" << MThreadClusterSize << "_" << KThreadClusterSize << ",";
464 str << "Slice_MK_" << MThreadSliceSize << "_" << KThreadSliceSize << ",";
465 str << "XYSrcVectorDim_" << XYSrcVectorDim << ",";
466 str << "VectorSize_X" << XSrcVectorSize << "_Gamma" << GammaSrcVectorSize << "_Beta" << BetaSrcVectorSize << "_Y" << YDstVectorSize << ">";
467 // clang-format on
468
469 return str.str();
470 }
471};
472
473} // namespace device
474} // namespace tensor_operation
475} // namespace ck
float launch_and_time_kernel(const StreamConfig &stream_config, F kernel, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, Args... args)
Definition host_utility/kernel_launch.hpp:14
__host__ __device__ constexpr auto integer_least_multiple(X x, Y y)
Definition utility/math.hpp:78
__host__ __device__ constexpr auto integer_divide_ceil(X x, Y y)
Definition utility/math.hpp:72
Definition convolution_backward_data_specialization.hpp:8
auto make_tuple_from_array(const std::vector< index_t > &lengths, Number< arraySize >)
Definition device_reduce_common.hpp:65
std::pair< long_index_t, long_index_t > get_2d_lengths(const std::vector< index_t > &inLengths)
Definition device_reduce_common.hpp:20
std::vector< index_t > shuffle_tensor_dimensions(const std::vector< index_t > &origLengthsStrides, const std::vector< int > &reduceDims)
Definition device_reduce_common.hpp:75
auto make_tuple_from_array_and_index_seq(const std::vector< index_t > &lengths, Sequence< Ns... >)
Definition device_reduce_common.hpp:59
Definition convolution_backward_data_specialization.hpp:7
Definition ck.hpp:268
int32_t index_t
Definition ck.hpp:299
__host__ __device__ constexpr auto make_naive_tensor_descriptor(const Tuple< Lengths... > &lengths, const Tuple< Strides... > &strides)
Definition tensor_descriptor_helper.hpp:49
__host__ __device__ constexpr auto make_right_pad_transform(const LowLength &low_length, const RightPadLength &right_pad, integral_constant< bool, SkipIsValidCheck >=integral_constant< bool, false >{})
Definition multi_index_transform_helper.hpp:37
integral_constant< index_t, N > Number
Definition number.hpp:12
__host__ __device__ constexpr auto make_merge_transform(const LowLengths &low_lengths)
Definition multi_index_transform_helper.hpp:55
__host__ __device__ constexpr auto make_tuple(Xs &&... xs)
Definition utility/tuple.hpp:211
__host__ __device__ constexpr auto transform_tensor_descriptor(const OldTensorDescriptor &old_tensor_desc, const NewTransforms &new_transforms, NewLowerDimensionOldVisibleIdss, NewUpperDimensionNewVisibleIdss)
Definition tensor_description/tensor_descriptor.hpp:319
__host__ __device__ constexpr auto make_unmerge_transform(const UpLengths &up_lengths, integral_constant< bool, Use24BitIntegerCalculation >=integral_constant< bool, false >{})
Definition multi_index_transform_helper.hpp:90
auto NormalizationKernelSelector(bool isSweepOnce)
Definition gridwise_normalization_selector.hpp:78
Definition ck/stream_config.hpp:10
Definition utility/sequence.hpp:43
typename conditional< kHasContent, type0, type1 >::type type
Definition utility/sequence.hpp:271
Definition device_base.hpp:197
Definition device_normalization_fwd.hpp:23
Definition device_normalization_fwd_impl.hpp:177
const GammaDataType * p_gamma_
Definition device_normalization_fwd_impl.hpp:239
const XDataType * p_x_
Definition device_normalization_fwd_impl.hpp:238
YDataType * p_y_
Definition device_normalization_fwd_impl.hpp:241
GridDesc_M save_mean_grid_desc_m_
Definition device_normalization_fwd_impl.hpp:262
Argument(const std::vector< index_t > lengths, const std::vector< index_t > xStrides, const std::vector< index_t > gammaStrides, const std::vector< index_t > betaStrides, const std::vector< index_t > yStrides, const std::vector< index_t > saveMeanStrides, const std::vector< index_t > saveInvStdStrides, const std::vector< index_t > reduceDims, YElementwiseOperation y_elementwise_op, double epsilon, const XDataType *p_x, const GammaDataType *p_gamma, const BetaDataType *p_beta, YDataType *p_y, SaveMeanInvStdDataType *p_saveMean, SaveMeanInvStdDataType *p_saveInvStd)
Definition device_normalization_fwd_impl.hpp:178
index_t MRaw_
Definition device_normalization_fwd_impl.hpp:266
ComputeDataType epsilon_
Definition device_normalization_fwd_impl.hpp:236
std::vector< index_t > saveInvStdStrides_
Definition device_normalization_fwd_impl.hpp:251
GridDesc_M_K x_grid_desc_m_k_
Definition device_normalization_fwd_impl.hpp:258
SaveMeanInvStdDataType * p_saveInvStd_
Definition device_normalization_fwd_impl.hpp:243
std::vector< index_t > Lengths_
Definition device_normalization_fwd_impl.hpp:245
SaveMeanInvStdDataType * p_saveMean_
Definition device_normalization_fwd_impl.hpp:242
index_t invariant_lowest_length_
Definition device_normalization_fwd_impl.hpp:269
std::vector< index_t > betaStrides_
Definition device_normalization_fwd_impl.hpp:248
std::vector< index_t > saveMeanStrides_
Definition device_normalization_fwd_impl.hpp:250
bool isSweeponce_
Definition device_normalization_fwd_impl.hpp:264
std::vector< index_t > xStrides_
Definition device_normalization_fwd_impl.hpp:246
index_t KRaw_
Definition device_normalization_fwd_impl.hpp:267
std::vector< index_t > gammaStrides_
Definition device_normalization_fwd_impl.hpp:247
GridDesc_M_K y_grid_desc_m_k_
Definition device_normalization_fwd_impl.hpp:261
GridDesc_M_K gamma_grid_desc_m_k_
Definition device_normalization_fwd_impl.hpp:259
GridDesc_M save_inv_std_grid_desc_m_
Definition device_normalization_fwd_impl.hpp:263
YElementwiseOperation y_elementwise_op_
Definition device_normalization_fwd_impl.hpp:253
GridDesc_M_K beta_grid_desc_m_k_
Definition device_normalization_fwd_impl.hpp:260
int numBlockTileIteration_
Definition device_normalization_fwd_impl.hpp:255
std::vector< index_t > yStrides_
Definition device_normalization_fwd_impl.hpp:249
const BetaDataType * p_beta_
Definition device_normalization_fwd_impl.hpp:240
size_t gridSize_
Definition device_normalization_fwd_impl.hpp:256
Definition device_normalization_fwd_impl.hpp:273
float Run(const BaseArgument *p_arg, const StreamConfig &stream_config=StreamConfig{}) override
Definition device_normalization_fwd_impl.hpp:326
float Run(const Argument &arg, const StreamConfig &stream_config=StreamConfig{})
Definition device_normalization_fwd_impl.hpp:274
Definition device_normalization_fwd_impl.hpp:57
std::unique_ptr< BaseInvoker > MakeInvokerPointer() override
Definition device_normalization_fwd_impl.hpp:452
static constexpr index_t M_BlockTileSize
Definition device_normalization_fwd_impl.hpp:76
decltype(MakeSaveMeanInvStdDescriptor_M({1}, {1})) GridDesc_M
Definition device_normalization_fwd_impl.hpp:174
tensor_operation::element_wise::PassThrough PassThrough
Definition device_normalization_fwd_impl.hpp:73
static auto MakeSrc2dDescriptor(const std::vector< index_t > &inLengths, const std::vector< index_t > &inStrides, int numBlockTileIteration)
Definition device_normalization_fwd_impl.hpp:82
static constexpr index_t NumInvariantDim
Definition device_normalization_fwd_impl.hpp:75
bool IsSupportedArgument(const BaseArgument *p_arg) override
Definition device_normalization_fwd_impl.hpp:333
std::unique_ptr< BaseArgument > MakeArgumentPointer(const std::vector< index_t > lengths, const std::vector< index_t > xStrides, const std::vector< index_t > gammaStrides, const std::vector< index_t > betaStrides, const std::vector< index_t > yStrides, const std::vector< index_t > saveMeanStrides, const std::vector< index_t > saveInvStdStrides, const std::vector< index_t > reduceDims, double epsilon, const void *p_x, const void *p_gamma, const void *p_beta, void *p_y, void *p_saveMean, void *p_saveInvStd, YElementwiseOperation y_elementwise_op) override
Definition device_normalization_fwd_impl.hpp:412
static auto MakeSaveMeanInvStdDescriptor_M(const std::vector< index_t > &lengths, const std::vector< index_t > &strides)
Definition device_normalization_fwd_impl.hpp:144
static constexpr index_t K_BlockTileSize
Definition device_normalization_fwd_impl.hpp:77
decltype(MakeSrc2dDescriptor({1}, {1}, 1)) GridDesc_M_K
Definition device_normalization_fwd_impl.hpp:173
std::string GetTypeString() const override
Definition device_normalization_fwd_impl.hpp:457
static constexpr bool reduceAllDim
Definition device_normalization_fwd_impl.hpp:79
Definition tensor_operation/gpu/element/unary_element_wise_operation.hpp:340