thread_group_tensor_slice_transfer_v6r2.hpp Source File

thread_group_tensor_slice_transfer_v6r2.hpp Source File

Composable Kernel: thread_group_tensor_slice_transfer_v6r2.hpp Source File
thread_group_tensor_slice_transfer_v6r2.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
3
#pragma once

#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_description/cluster_descriptor.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v6r2.hpp"

12namespace ck {
13
14// this version does following things to avoid scratch memory issue
15// 1. Use StaticallyIndexedArray instead of C array for thread buffer
16// 2. It does not keep reference to tensor descriptor
17// 3. Run() does not construct new tensor coordinate
18template <typename ThreadGroup,
19 typename ElementwiseOperation,
21 typename SliceLengths,
22 typename ThreadClusterLengths,
23 typename ThreadClusterArrangeOrder,
24 typename Src0Data,
25 typename Src1Data,
26 typename DstData,
27 typename Src0Desc,
28 typename Src1Desc,
29 typename DstDesc,
30 typename DimAccessOrder,
31 index_t VectorDim,
32 index_t ScalarPerVector,
33 bool ThreadTransferSrc0ResetCoordinateAfterRun,
34 bool ThreadTransferSrc1ResetCoordinateAfterRun,
35 bool ThreadTransferDstResetCoordinateAfterRun>
37{
39
40 static constexpr auto thread_slice_lengths = SliceLengths{} / ThreadClusterLengths{};
41
43
44 __device__ constexpr ThreadGroupTensorSliceTransfer_v6r2(const Src0Desc& src0_desc,
45 const Index& src0_block_slice_origin,
46 const Src1Desc& src1_desc,
47 const Index& src1_block_slice_origin,
48 const DstDesc& dst_desc,
49 const Index& dst_block_slice_origin,
50 const ElementwiseOperation& element_op)
51 : threadwise_transfer_(src0_desc,
53 src1_desc,
55 dst_desc,
57 element_op)
58
59 {
63 nDim == ThreadClusterLengths::Size() &&
64 nDim == ThreadClusterArrangeOrder::Size() &&
65 nDim == DimAccessOrder::Size(),
66 "wrong! nDim not consistent");
67
68 static_assert(
69 is_same<SliceLengths, decltype(thread_slice_lengths * ThreadClusterLengths{})>{},
70 "wrong! threads should be mapped to cover entire slicing window");
71
72 static_assert(ThreadGroup::GetNumOfThread() >= thread_cluster_desc_.GetElementSize(),
73 "wrong! ThreadGroup::GetNumOfThread() too small");
74
75 if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
76 ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
77 {
78 const auto thread_cluster_idx = thread_cluster_desc_.CalculateBottomIndex(
79 make_multi_index(ThreadGroup::GetThreadId()));
80
81 const auto thread_data_idx_begin = thread_cluster_idx * thread_slice_lengths;
82
83 threadwise_transfer_.SetSrc0SliceOrigin(
84 src0_desc, src0_block_slice_origin + thread_data_idx_begin);
85 threadwise_transfer_.SetSrc1SliceOrigin(
86 src1_desc, src1_block_slice_origin + thread_data_idx_begin);
87 threadwise_transfer_.SetDstSliceOrigin(dst_desc,
88 dst_block_slice_origin + thread_data_idx_begin);
89 }
90 }
91
92 template <typename Src0Buffer, typename Src1Buffer, typename DstBuffer>
93 __device__ void Run(const Src0Desc& src0_desc,
94 const Src0Buffer& src0_buf,
95 const Src1Desc& src1_desc,
96 const Src1Buffer& src1_buf,
97 const DstDesc& dst_desc,
98 DstBuffer& dst_buf)
99 {
100 if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
101 ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
102 {
103 threadwise_transfer_.Run(src0_desc, src0_buf, src1_desc, src1_buf, dst_desc, dst_buf);
104 }
105 }
106
107 __device__ void MoveSrc0SliceWindow(const Src0Desc& src0_desc, const Index& step)
108 {
109 if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
110 ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
111 {
112 threadwise_transfer_.MoveSrc0SliceWindow(src0_desc, step);
113 }
114 }
115
116 __device__ void MoveSrc1SliceWindow(const Src1Desc& src1_desc, const Index& step)
117 {
118 if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
119 ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
120 {
121 threadwise_transfer_.MoveSrc1SliceWindow(src1_desc, step);
122 }
123 }
124
125 __device__ void MoveDstSliceWindow(const DstDesc& dst_desc, const Index& step)
126 {
127 if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
128 ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
129 {
130 threadwise_transfer_.MoveDstSliceWindow(dst_desc, step);
131 }
132 }
133
134 private:
135 static constexpr auto thread_cluster_desc_ =
136 make_cluster_descriptor(ThreadClusterLengths{}, ThreadClusterArrangeOrder{});
137
138 using ThreadwiseTransfer =
139 ThreadwiseTensorSliceTransfer_v6r2<Src0Data,
140 Src1Data,
141 DstData,
142 Src0Desc,
143 Src1Desc,
144 DstDesc,
145 ElementwiseOperation,
146 decltype(thread_slice_lengths),
147 DimAccessOrder,
148 VectorDim,
149 ScalarPerVector,
150 DstInMemOp,
151 ThreadTransferSrc0ResetCoordinateAfterRun,
152 ThreadTransferSrc1ResetCoordinateAfterRun,
153 ThreadTransferDstResetCoordinateAfterRun>;
154
155 ThreadwiseTransfer threadwise_transfer_;
156};
157
158} // namespace ck
Definition ck.hpp:268
int32_t index_t
Definition ck.hpp:299
InMemoryDataOperationEnum
Definition ck.hpp:277
remove_cv_t< remove_reference_t< T > > remove_cvref_t
Definition type.hpp:297
__host__ __device__ constexpr auto make_cluster_descriptor(const Lengths &lengths, ArrangeOrder order=typename arithmetic_sequence_gen< 0, Lengths::Size(), 1 >::type{})
Definition tensor_description/cluster_descriptor.hpp:13
__host__ __device__ constexpr auto make_zero_multi_index()
Definition array_multi_index.hpp:21
typename remove_reference< T >::type remove_reference_t
Definition type.hpp:292
Array< index_t, N > MultiIndex
Definition array_multi_index.hpp:12
__device__ void MoveDstSliceWindow(const DstDesc &dst_desc, const Index &step)
Definition thread_group_tensor_slice_transfer_v6r2.hpp:125
MultiIndex< nDim > Index
Definition thread_group_tensor_slice_transfer_v6r2.hpp:42
static constexpr index_t nDim
Definition thread_group_tensor_slice_transfer_v6r2.hpp:38
static constexpr auto thread_slice_lengths
Definition thread_group_tensor_slice_transfer_v6r2.hpp:40
__device__ constexpr ThreadGroupTensorSliceTransfer_v6r2(const Src0Desc &src0_desc, const Index &src0_block_slice_origin, const Src1Desc &src1_desc, const Index &src1_block_slice_origin, const DstDesc &dst_desc, const Index &dst_block_slice_origin, const ElementwiseOperation &element_op)
Definition thread_group_tensor_slice_transfer_v6r2.hpp:44
__device__ void MoveSrc1SliceWindow(const Src1Desc &src1_desc, const Index &step)
Definition thread_group_tensor_slice_transfer_v6r2.hpp:116
__device__ void MoveSrc0SliceWindow(const Src0Desc &src0_desc, const Index &step)
Definition thread_group_tensor_slice_transfer_v6r2.hpp:107
__device__ void Run(const Src0Desc &src0_desc, const Src0Buffer &src0_buf, const Src1Desc &src1_desc, const Src1Buffer &src1_buf, const DstDesc &dst_desc, DstBuffer &dst_buf)
Definition thread_group_tensor_slice_transfer_v6r2.hpp:93
Definition type.hpp:177