32template <
typename... Ts,
typename... Ls>
33__host__ __device__
constexpr auto CalculateLocalPartitionShape(
const Tuple<Ts...>&
shape,
34 const Tuple<Ls...>& thread_lengths)
36 static_assert(Tuple<Ts...>::Size() == Tuple<Ls...>::Size(),
"Wrong thread_lengths shape.");
39 constexpr auto num_i = Number<i>{};
40 const auto slice_len =
44 Number<Tuple<Ls...>::Size()>{});
56template <
typename MultiIndex,
typename ProjectionTuple>
57__host__ __device__
constexpr auto
58ApplyProjection([[maybe_unused]]
const MultiIndex& base_tuple,
59 [[maybe_unused]]
const ProjectionTuple& projection)
61 if constexpr(is_same_v<ProjectionTuple, Tuple<>>)
71 is_detected<is_slice, tuple_element_t<i_num, ProjectionTuple>>
::value ||
72 is_same_v<tuple_element_t<i_num, ProjectionTuple>, Number<1>>);
73 if constexpr(is_detected<is_slice, tuple_element_t<i_num, ProjectionTuple>>
::value)
84 Number<MultiIndex::Size()>{});
86 return UnrollNestedTuple<0, 1>(base_tuple_after_projection);
99template <
typename... Ts,
typename... Ps>
100__host__ __device__
constexpr auto CalculateShapeWithProjection(
const Tuple<Ts...>&
shape,
101 const Tuple<Ps...>& projection)
107 return size<i>(projection).to_;
114 detail::ApplyProjection(TupleSlice<0, i>(Tuple<Ts...>{}),
115 TupleSlice<0, i>(Tuple<Ps...>{}))
117 return size<shape_i>(
shape);
120 Number<Tuple<Ps...>::Size()>{});
130template <
typename... Ts,
typename... Ls,
typename... Ps>
131__host__ __device__
constexpr auto CalculateGridSize(
const Tuple<Ts...>&
shape,
132 const Tuple<Ls...>& tile_shape)
136 Number<Tuple<Ls...>::Size()>{});
147template <
typename ThreadIdxs,
typename PartitionLengthsSeq,
typename OldOffsetIdxs>
148__host__ __device__
constexpr auto
149CalculateOffsetMultiIdxs(
const ThreadIdxs& thread_idxs,
150 const PartitionLengthsSeq& partition_lengths_seq,
151 const OldOffsetIdxs& old_offset_idxs)
153 return thread_idxs * partition_lengths_seq + old_offset_idxs;
162template <
typename BlockIdxs>
163__host__ __device__
constexpr auto GetDimsToPartition([[maybe_unused]]
const BlockIdxs& block_idxs)
167 if constexpr(!is_detected<is_slice, tuple_element_t<i, BlockIdxs>>
::value)
176 Number<BlockIdxs::Size()>{});
178 return UnrollNestedTuple<0, 1>(dims_to_partition);
187template <
typename BlockIdxs>
188__host__ __device__
constexpr auto ReplaceSlicesWithZeros(
const BlockIdxs& block_idxs)
192 if constexpr(!is_detected<is_slice, tuple_element_t<i, BlockIdxs>>
::value)
194 return block_idxs.At(i);
201 Number<BlockIdxs::Size()>{});
210template <
typename TileShape>
211__host__ __device__
constexpr auto
212GenerateDefaultProjection([[maybe_unused]]
const TileShape tile_shape)
224template <
typename ThreadShape,
typename ThreadUnrolledDesc>
225__host__ __device__
constexpr auto CalculateThreadMultiIdx(
226 [[maybe_unused]]
const Layout<ThreadShape, ThreadUnrolledDesc>& thread_layout,
227 const index_t thread_id)
229 static_assert(ThreadUnrolledDesc::GetNumOfTransform() == 1,
230 "Thread layout should not be transformed.");
231 constexpr auto embed_transform = ThreadUnrolledDesc{}.GetTransforms().At(Number<0>{});
232 constexpr auto shape = ThreadShape{};
233 constexpr auto strides = embed_transform.coefficients_;
237 constexpr auto num_i = Number<i>{};
238 return (thread_id / strides.At(num_i)) %
shape.At(num_i);
240 Number<ThreadShape::Size()>{});
258template <
typename TensorType,
259 typename ThreadShape,
260 typename ThreadUnrolledDesc,
261 typename ProjectionTuple>
262__host__ __device__
constexpr auto
265 const index_t thread_id,
266 const ProjectionTuple& projection)
268 static_assert(!IsNestedTuple(ThreadShape{}));
270 const auto& tensor_shape =
shape(tensor);
272 constexpr auto projected_thread_lengths =
273 detail::ApplyProjection(ThreadShape{}, ProjectionTuple{});
274 constexpr auto partition_shape =
275 detail::CalculateLocalPartitionShape(
decltype(tensor_shape){}, projected_thread_lengths);
276 constexpr auto partition_shape_seq =
277 generate_sequence_v2([&](
auto I) {
return size<I>(partition_shape); },
278 Number<
decltype(partition_shape)::Size()>{});
280 const auto thread_idxs = detail::CalculateThreadMultiIdx(thread_layout, thread_id);
282 const auto projected_thread_idxs = detail::ApplyProjection(thread_idxs, projection);
283 const auto offset_multi_idxs = detail::CalculateOffsetMultiIdxs(
284 projected_thread_idxs, partition_shape_seq, tensor.GetMultiIdxOffsets());
286 auto& unrolled_desc =
layout(tensor).GetUnrolledDescriptor();
288 const auto transforms = generate_tuple(
290 return make_slice_transform(partition_shape.At(i),
291 offset_multi_idxs.At(i),
292 partition_shape.At(i) + offset_multi_idxs.At(i));
294 Number<remove_reference_t<
decltype(tensor_shape)>::Size()>{});
295 const auto lower_upper_dims =
296 generate_tuple([&](
auto i) {
return Sequence<i.value>{}; },
297 Number<remove_reference_t<
decltype(tensor_shape)>::Size()>{});
299 transform_tensor_descriptor(unrolled_desc, transforms, lower_upper_dims, lower_upper_dims);
301 const auto partition_layout =
302 Layout<remove_reference_t<
decltype(partition_shape)>,
decltype(sliced_desc)>(
303 partition_shape, sliced_desc);
304 auto partition_tensor =
307 return partition_tensor;
319template <
typename TensorType,
typename ThreadShape,
typename ThreadUnrolledDesc>
320__host__ __device__
constexpr auto
323 const index_t thread_id)
325 const auto projection = detail::GenerateDefaultProjection(ThreadShape{});
346template <
typename TensorType,
347 typename BlockShapeTuple,
349 typename ProjectionTuple>
351 const BlockShapeTuple& tile_shape,
352 const BlockIdxs& block_idxs,
353 const ProjectionTuple& projection)
355 static_assert(!IsNestedTuple(BlockShapeTuple{}));
356 static_assert(!IsNestedTuple(BlockIdxs{}));
358 constexpr auto I0 = Number<0>{};
359 constexpr auto I1 = Number<1>{};
360 constexpr auto I2 = Number<2>{};
362 auto& aligned_desc =
layout(tensor).GetMergedNestingDescriptor();
364 constexpr auto projected_tile_shape =
365 detail::ApplyProjection(BlockShapeTuple{}, ProjectionTuple{});
367 constexpr auto dims_to_partition = detail::GetDimsToPartition(BlockIdxs{});
368 const auto parsed_block_idxs = detail::ReplaceSlicesWithZeros(block_idxs);
369 if constexpr(
decltype(dims_to_partition)::Size() == I2)
371 const auto shape_with_projection_dims =
372 detail::CalculateShapeWithProjection(
shape(tensor), projection);
374 const auto M = shape_with_projection_dims.At(dims_to_partition.At(I0));
375 const auto N = shape_with_projection_dims.At(dims_to_partition.At(I1));
376 constexpr auto MPerBlock = BlockShapeTuple{}.At(dims_to_partition.At(I0));
377 constexpr auto NPerBlock = BlockShapeTuple{}.At(dims_to_partition.At(I1));
378 auto m_n_desc = make_naive_tensor_descriptor_packed(make_tuple(M, N));
380 const auto grid_size = detail::CalculateGridSize(shape_with_projection_dims, tile_shape);
381 const auto block_lengths_desc = make_naive_tensor_descriptor_packed(grid_size);
382 const index_t block_id_1d = block_lengths_desc.CalculateOffset(parsed_block_idxs);
384 const auto block_2_tile_map =
385 BlockToCTileMap_M00_N0_M01Adapt<MPerBlock,
387 remove_cvref_t<
decltype(m_n_desc)>>(m_n_desc);
388 const auto block_work_idx =
389 block_2_tile_map.CalculateBottomIndex(make_multi_index(block_id_1d));
390 const index_t m_block_data_idx_on_grid =
391 __builtin_amdgcn_readfirstlane(block_work_idx[I0] * MPerBlock);
392 const index_t n_block_data_idx_on_grid =
393 __builtin_amdgcn_readfirstlane(block_work_idx[I1] * NPerBlock);
395 const auto offset_multi_idxs = generate_tuple(
397 if constexpr(i == dims_to_partition.At(I0))
399 return m_block_data_idx_on_grid;
401 else if constexpr(i == dims_to_partition.At(I1))
403 return n_block_data_idx_on_grid;
410 Number<BlockShapeTuple::Size()>{});
411 const auto projected_offset_multi_idxs =
412 detail::ApplyProjection(offset_multi_idxs, projection);
414 const auto tile_layout =
415 Layout<remove_reference_t<
decltype(projected_tile_shape)>,
decltype(aligned_desc)>(
416 projected_tile_shape, aligned_desc);
420 tile_tensor.SetMultiIdxOffset(to_multi_index(projected_offset_multi_idxs));
427 using ProjectedTileShapeTuple =
decltype(projected_tile_shape);
428 constexpr auto projected_tile_shape_seq =
429 generate_sequence_v2([](
auto I) {
return ProjectedTileShapeTuple{}.At(I); },
430 Number<ProjectedTileShapeTuple::Size()>{});
432 const auto projected_block_idxs =
433 to_multi_index(detail::ApplyProjection(parsed_block_idxs, projection));
434 const auto offset_multi_idxs = detail::CalculateOffsetMultiIdxs(
435 projected_block_idxs, projected_tile_shape_seq, tensor.GetMultiIdxOffsets());
437 const auto tile_layout =
439 projected_tile_shape, aligned_desc);
443 tile_tensor.SetMultiIdxOffset(to_multi_index(offset_multi_idxs));
460template <
typename TensorType,
typename BlockShapeTuple,
typename BlockIdxs>
462 const BlockShapeTuple& tile_shape,
463 const BlockIdxs& block_idxs)
465 const auto projection = detail::GenerateDefaultProjection(BlockShapeTuple{});
__host__ __device__ constexpr const auto & shape(const LayoutType &layout)
Get Layout shape.
Definition layout_utils.hpp:431
__host__ __device__ constexpr auto integer_divide_ceil(X x, Y y)
Definition utility/math.hpp:72
int32_t index_t
Definition ck.hpp:299
typename detail::detector< nonesuch, void, Op, Args... >::value_t is_detected
Definition is_detected.hpp:34
integral_constant< index_t, N > Number
Definition number.hpp:12
typename tuple_element< I, TTuple >::type tuple_element_t
Definition utility/tuple.hpp:208
__host__ __device__ constexpr auto generate_tuple(F &&f, Number< N >)
Definition tuple_helper.hpp:21
__host__ __device__ constexpr auto make_tuple(Xs &&... xs)
Definition utility/tuple.hpp:211
const GenericPointer< typename T::ValueType > T2 value
Definition pointer.h:1697
Layout wrapper that performs the tensor descriptor logic.
Definition layout.hpp:24
static constexpr T value
Definition utility/integral_constant.hpp:21
__host__ __device__ constexpr auto make_local_partition(TensorType &tensor, const Layout< ThreadShape, ThreadUnrolledDesc > &thread_layout, const index_t thread_id, const ProjectionTuple &projection)
Create a local partition for a thread (currently only packed partitions are supported).
Definition tensor_partition.hpp:263
__host__ __device__ constexpr auto make_local_tile(const TensorType &tensor, const BlockShapeTuple &tile_shape, const BlockIdxs &block_idxs, const ProjectionTuple &projection)
Create a local tile for a thread block (currently only packed tiles are supported).
Definition tensor_partition.hpp:350
__host__ __device__ constexpr const auto & layout(const Tensor< BufferAddressSpace, ElementType, Shape, UnrolledDescriptorType > &tensor)
Get Tensor Layout.
Definition tensor_utils.hpp:162
constexpr auto make_tensor(ElementType *pointer, const Layout< Shape, UnrolledDescriptorType > &layout)
Make tensor function.
Definition tensor_utils.hpp:112