14template <
typename TensorShape,
typename WindowShape>
20 void* output_index_ptr_,
21 TensorShape input_shape_,
22 TensorShape output_shape_,
23 TensorShape input_strides_,
24 TensorShape output_strides_,
25 WindowShape window_lengths_,
26 WindowShape window_strides_,
27 WindowShape window_dilations_,
28 WindowShape input_left_pads_,
29 WindowShape input_right_pads_)
61template <
typename TensorShape,
typename WindowShape>
78template <
typename Problem_,
typename Policy_ = PoolDefaultPolicy>
96 template <
typename TensorShape,
typename WindowShape>
99 using S =
typename Problem::BlockShape;
102 static_assert(TensorShape::size() == 4,
"2D pooling requires 4D input tensor (N,H,W,C)");
103 static_assert(WindowShape::size() == 2,
"2D pooling requires 2D window shape (Y,X)");
131 const index_t MRaw = N * Ho * Wo * C;
136 auto reduce_op =
typename Problem::ReduceOp{};
163 const auto merged_embed_in_desc =
171 merged_embed_in_desc,
185 const auto out_desc_padded =
199 in_desc.get_element_space_size(),
201 const auto in_tensor_padded =
202 tensor_view<
decltype(in_buffer_view),
decltype(in_desc_padded)>{in_buffer_view,
207 out_desc.get_element_space_size(),
209 const auto out_tensor_padded =
210 tensor_view<
decltype(out_buffer_view),
decltype(out_desc_padded)>{out_buffer_view,
213 if constexpr(Problem::kOutputIndex)
217 out_desc.get_element_space_size(),
219 const auto out_index_tensor_padded =
220 tensor_view<
decltype(out_index_buffer_view),
decltype(out_desc_padded)>{
221 out_index_buffer_view, out_desc_padded};
223 return make_tuple(in_tensor_padded, out_tensor_padded, out_index_tensor_padded);
232 template <
typename TensorShape,
typename WindowShape>
235 using S =
typename Problem::BlockShape;
238 static_assert(TensorShape::size() == 5,
"3D pooling requires 5D input tensor (N,D,H,W,C)");
239 static_assert(WindowShape::size() == 3,
"3D pooling requires 3D window shape (Z,Y,X)");
274 const index_t MRaw = N * Do * Ho * Wo * C;
275 const index_t KRaw = Z * Y * X;
279 auto reduce_op =
typename Problem::ReduceOp{};
320 merged_embed_in_desc,
334 const auto out_desc_padded =
348 in_desc.get_element_space_size(),
350 const auto in_tensor_padded =
351 tensor_view<
decltype(in_buffer_view),
decltype(in_desc_padded)>{in_buffer_view,
356 out_desc.get_element_space_size(),
358 const auto out_tensor_padded =
359 tensor_view<
decltype(out_buffer_view),
decltype(out_desc_padded)>{out_buffer_view,
362 if constexpr(Problem::kOutputIndex)
366 out_desc.get_element_space_size(),
368 const auto out_index_tensor_padded =
369 tensor_view<
decltype(out_index_buffer_view),
decltype(out_desc_padded)>{
370 out_index_buffer_view, out_desc_padded};
372 return make_tuple(in_tensor_padded, out_tensor_padded, out_index_tensor_padded);
382 template <
typename TensorShape,
typename WindowShape>
385 using S =
typename Problem::BlockShape;
388 static_assert(WindowShape::size() == 2 || WindowShape::size() == 3,
389 "Only 2D and 3D pooling operations are supported");
394 auto [in_tensor_padded, out_tensor_padded, out_index_tensor_padded] = [&]() {
395 if constexpr(WindowShape::size() == 2)
397 else if constexpr(WindowShape::size() == 3)
400 static_assert(WindowShape::size() == 2 || WindowShape::size() == 3,
401 "Unsupported WindowShape rank: only 2D or 3D pooling is supported");
404 auto reduce_op =
typename Problem::ReduceOp{};
409 Policy::template MakeXBlockTileDistribution<Problem>());
412 __shared__
char smem[Policy::template GetSmemSize<Problem>()];
414 const auto reduce_len =
415 in_tensor_padded.get_tensor_descriptor().get_lengths().at(
number<1>{});
419 auto block_reduce2d = Policy::template GetBlockReduce2d<Problem>();
420 auto block_reduce2d_sync = Policy::template GetBlockReduce2dSync<Problem>();
421 auto block_reduce2d_cross_warp = Policy::template GetBlockReduce2dCrossWarpSync<Problem>();
423 using XTensorTile =
decltype(
load_tile(x_window));
424 auto y_tile = block_reduce2d.template MakeYBlockTile<XTensorTile>();
425 set_tile(y_tile, reduce_op.template GetIdentityValue<ComputeDataType>());
427 if constexpr(Problem::kOutputIndex)
429 auto y_index_window =
433 block_reduce2d.template MakeYIndexBlockTile<XTensorTile, IndexDataType>();
440 auto index_calculator = [&](
const auto& x_indices) {
442 const auto global_M = x_indices.at(
number<0>{}) + iM;
443 const auto global_N = (k_tile * S::Block_N) + x_indices.at(
number<1>{});
444 return in_tensor_padded.get_tensor_descriptor().calculate_offset(
448 block_reduce2d(x_tile, y_tile, y_index_tile, reduce_op, index_calculator);
452 block_reduce2d_sync(y_tile, y_index_tile, reduce_op);
453 if constexpr(Problem::kNeedCrossWarpSync)
455 __shared__
char smem_indices[Policy::template GetIndicesSmemSize<Problem>()];
457 block_reduce2d_cross_warp(y_tile, y_index_tile, smem, smem_indices, reduce_op);
466 for(
int k_tile = __builtin_amdgcn_readfirstlane(0); k_tile < num_k_tiles; ++k_tile)
469 block_reduce2d(x_tile, y_tile, reduce_op);
473 block_reduce2d_sync(y_tile, reduce_op);
474 block_reduce2d_cross_warp(y_tile, smem, reduce_op);
490 template <
typename TensorShape,
typename WindowShape>
493 constexpr index_t InputRank = TensorShape::size();
494 constexpr index_t OutputRank = TensorShape::size();
495 constexpr index_t WindowRank = WindowShape::size();
498 if constexpr(WindowRank != 2 && WindowRank != 3)
508 if constexpr((WindowRank == 2 && InputRank != 4) || (WindowRank == 3 && InputRank != 5))
512 CK_TILE_ERROR(
"Input tensor rank doesn't match window dimensions!");
522 CK_TILE_ERROR(
"Input tensor's channel dimension must have stride 1!");
531 CK_TILE_ERROR(
"Output tensor's channel dimension must have stride 1!");
541 template <
typename TensorShape,
typename WindowShape>
545 using S =
typename Problem::BlockShape;
552 return (M + S::Block_M - 1) / S::Block_M;
556 template <
typename TensorShape,
typename WindowShape>
#define CK_TILE_DEVICE
Definition config.hpp:41
#define CK_TILE_HOST
Definition config.hpp:40
Definition tile/core/algorithm/cluster_descriptor.hpp:13
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition type_traits.hpp:21
bool EnvIsEnabled(EnvVar)
Definition tile/core/utility/env.hpp:156
CK_TILE_HOST_DEVICE constexpr auto make_pass_through_transform(const LowLength &low_length)
Definition coordinate_transform.hpp:1558
CK_TILE_DEVICE void set_tile(DstrTensors &dstr_tensor, const T &value)
Definition tile_elementwise.hpp:95
void CK_TILE_ERROR(Args &&... args) noexcept
Definition tile/core/utility/env.hpp:12
__device__ uint32_t amd_wave_read_first_lane(uint16_t v)
Definition tile/core/arch/amd_buffer_addressing.hpp:35
CK_TILE_HOST_DEVICE constexpr auto make_naive_tensor_descriptor(const tuple< Lengths... > &lengths, const tuple< Strides... > &strides, number< GuaranteedLastDimensionVectorLength >=number<-1 >{}, number< GuaranteedLastDimensionVectorStride >=number<-1 >{})
Definition tile/core/tensor/tensor_descriptor.hpp:274
CK_TILE_HOST_DEVICE constexpr auto make_merge_transform(const LowLengths &low_lengths)
Definition coordinate_transform.hpp:1615
CK_TILE_HOST_DEVICE constexpr auto make_pad_transform(const LowLength &low_length, const LeftPad &left_pad, const RightPad &right_pad, bool_constant< SkipIsValidCheck >=bool_constant< false >{})
Definition coordinate_transform.hpp:1565
CK_TILE_HOST_DEVICE constexpr auto transform_tensor_descriptor(const OldTensorDescriptor &old_tensor_desc, const NewTransforms &new_transforms, NewLowerDimensionOldTopIdss, NewUpperDimensionNewTopIdss)
Definition tile/core/tensor/tensor_descriptor.hpp:203
constant< v > number
Definition tile/core/numeric/integral_constant.hpp:37
CK_TILE_DEVICE constexpr auto make_tile_window(null_tensor_view, const WindowLengths &window_lengths, const multi_index< WindowLengths::size()> &, Ts &&...)
Definition null_tile_window.hpp:75
CK_TILE_DEVICE auto cast_tile(const SrcTensor &src_tensor)
Definition tile_elementwise.hpp:327
CK_TILE_HOST_DEVICE constexpr auto integer_divide_ceil(X x, Y y)
Definition tile/core/numeric/math.hpp:149
CK_TILE_HOST_DEVICE constexpr auto integer_least_multiple(X x, Y y)
Definition tile/core/numeric/math.hpp:155
CK_TILE_DEVICE void move_tile_window(null_tile_window< WindowLengths > &, const typename null_tile_window< WindowLengths >::BottomTensorIndex &)
Definition null_tile_window.hpp:95
CK_TILE_DEVICE index_t get_block_id()
Definition arch.hpp:119
CK_TILE_DEVICE void store_tile(tile_window_with_static_lengths< BottomTensorView_, WindowLengths_ > &tile_window_tmp, const static_distributed_tensor< DataType_, TileDistribution_ > &dstr_tensor)
Definition store_tile.hpp:23
int32_t index_t
Definition integer.hpp:9
CK_TILE_HOST_DEVICE constexpr Y type_convert(X x)
Definition tile/core/numeric/type_convert.hpp:29
CK_TILE_HOST_DEVICE constexpr auto make_right_pad_transform(const LowLength &low_length, const RightPadLength &right_pad_, bool_constant< SkipIsValidCheck >=bool_constant< false >{})
Definition coordinate_transform.hpp:1584
CK_TILE_HOST_DEVICE constexpr auto make_buffer_view(T *__restrict__ p, BufferSizeType buffer_size)
Definition buffer_view.hpp:1262
CK_TILE_DEVICE auto load_tile(const TileWindow_ &tile_window, number< i_access >={}, bool_constant< oob_conditional_check >={})
Definition load_tile.hpp:22
CK_TILE_HOST_DEVICE constexpr auto make_tuple(Xs &&... xs)
Definition tile/core/container/tuple.hpp:360
CK_TILE_HOST bool is_wave32()
Definition arch.hpp:72
CK_TILE_HOST_DEVICE constexpr auto make_embed_transform(const UpLengths &up_lengths, const Coefficients &coefficients)
Definition coordinate_transform.hpp:1594
Host arguments for pooling operations.
Definition pool_kernel.hpp:16
TensorShape input_strides
Definition pool_kernel.hpp:51
void * output_ptr
Definition pool_kernel.hpp:46
WindowShape input_left_pads
Definition pool_kernel.hpp:56
const void * input_ptr
Definition pool_kernel.hpp:45
WindowShape window_lengths
Definition pool_kernel.hpp:53
WindowShape window_strides
Definition pool_kernel.hpp:54
TensorShape input_shape
Definition pool_kernel.hpp:49
TensorShape output_strides
Definition pool_kernel.hpp:52
CK_TILE_HOST PoolHostArgs(const void *input_ptr_, void *output_ptr_, void *output_index_ptr_, TensorShape input_shape_, TensorShape output_shape_, TensorShape input_strides_, TensorShape output_strides_, WindowShape window_lengths_, WindowShape window_strides_, WindowShape window_dilations_, WindowShape input_left_pads_, WindowShape input_right_pads_)
Definition pool_kernel.hpp:18
TensorShape output_shape
Definition pool_kernel.hpp:50
WindowShape input_right_pads
Definition pool_kernel.hpp:57
WindowShape window_dilations
Definition pool_kernel.hpp:55
void * output_index_ptr
Definition pool_kernel.hpp:47
Kernel arguments for pooling operations.
Definition pool_kernel.hpp:63
TensorShape output_shape
Definition pool_kernel.hpp:68
WindowShape input_right_pads
Definition pool_kernel.hpp:75
WindowShape window_lengths
Definition pool_kernel.hpp:71
WindowShape window_dilations
Definition pool_kernel.hpp:73
TensorShape input_strides
Definition pool_kernel.hpp:69
const void * input_ptr
Definition pool_kernel.hpp:64
WindowShape input_left_pads
Definition pool_kernel.hpp:74
TensorShape input_shape
Definition pool_kernel.hpp:67
WindowShape window_strides
Definition pool_kernel.hpp:72
void * output_ptr
Definition pool_kernel.hpp:65
TensorShape output_strides
Definition pool_kernel.hpp:70
void * output_index_ptr
Definition pool_kernel.hpp:66
Definition pool_kernel.hpp:80
static CK_TILE_HOST constexpr index_t CalculateGridSize(PoolKernelArgs< TensorShape, WindowShape > kargs)
Definition pool_kernel.hpp:543
ck_tile::remove_cvref_t< Policy_ > Policy
Definition pool_kernel.hpp:82
ck_tile::remove_cvref_t< typename Problem::OutDataType > OutDataType
Definition pool_kernel.hpp:86
ck_tile::remove_cvref_t< typename Problem::ComputeDataType > ComputeDataType
Definition pool_kernel.hpp:85
static constexpr index_t kBlockSize
Definition pool_kernel.hpp:89
static CK_TILE_HOST bool IsSupportedArgument(PoolKernelArgs< TensorShape, WindowShape > kargs)
Validates if the given arguments are supported by the pooling kernel.
Definition pool_kernel.hpp:491
static CK_TILE_HOST constexpr auto BlockSize()
Definition pool_kernel.hpp:91
static CK_TILE_DEVICE auto MakeTensorView2D(PoolKernelArgs< TensorShape, WindowShape > kargs)
Definition pool_kernel.hpp:97
static CK_TILE_DEVICE auto MakeTensorView3D(PoolKernelArgs< TensorShape, WindowShape > kargs)
Definition pool_kernel.hpp:233
ck_tile::remove_cvref_t< typename Problem::InDataType > InDataType
Definition pool_kernel.hpp:84
ck_tile::remove_cvref_t< typename Problem::IndexDataType > IndexDataType
Definition pool_kernel.hpp:87
CK_TILE_DEVICE void operator()(PoolKernelArgs< TensorShape, WindowShape > kargs) const
Definition pool_kernel.hpp:383
static CK_TILE_HOST constexpr auto MakeKernelArgs(PoolHostArgs< TensorShape, WindowShape > &host_args)
Create kernel arguments from host arguments.
Definition pool_kernel.hpp:558
ck_tile::remove_cvref_t< Problem_ > Problem
Definition pool_kernel.hpp:81
Definition null_tensor.hpp:9
Definition tile/core/container/sequence.hpp:49
Definition tile/core/utility/functional.hpp:43
Definition tensor_view.hpp:41
#define CK_TILE_ENV(name)
Definition tile/core/utility/env.hpp:145