device_gemm_xdl_cshuffle.hpp Source File

device_gemm_xdl_cshuffle.hpp Source File

Composable Kernel: device_gemm_xdl_cshuffle.hpp Source File
device_gemm_xdl_cshuffle.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
6#include <iostream>
7#include <sstream>
8
18
19namespace ck {
20namespace tensor_operation {
21namespace device {
22
23// Note: the inter-wave loop scheduler is rolled out to the c-shuffle version first, because the
24// non c-shuffle version currently has compiler issues with register spilling, which in turn
25// causes validation failures.
// Device-level GEMM operation. It derives from DeviceGemm instantiated with the layout, data-type
// and element-wise-operation parameters; the block-size, tile, transfer and c-shuffle parameters
// are passed on to the gridwise GEMM instantiation declared inside the struct.
26template <typename ALayout,
27 typename BLayout,
28 typename CLayout,
29 typename ADataType,
30 typename BDataType,
31 typename CDataType,
32 typename GemmAccDataType,
33 typename CShuffleDataType,
34 typename AElementwiseOperation,
35 typename BElementwiseOperation,
36 typename CElementwiseOperation,
37 GemmSpecialization GemmSpec,
38 index_t NumGemmKPrefetchStage,
39 index_t BlockSize,
40 index_t MPerBlock,
41 index_t NPerBlock,
42 index_t KPerBlock,
43 index_t AK1,
44 index_t BK1,
45 index_t MPerXDL,
46 index_t NPerXDL,
47 index_t MXdlPerWave,
48 index_t NXdlPerWave,
49 typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
50 typename ABlockTransferThreadClusterArrangeOrder,
51 typename ABlockTransferSrcAccessOrder,
52 index_t ABlockTransferSrcVectorDim,
53 index_t ABlockTransferSrcScalarPerVector,
54 index_t ABlockTransferDstScalarPerVector_AK1,
55 bool ABlockLdsExtraM,
56 typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
57 typename BBlockTransferThreadClusterArrangeOrder,
58 typename BBlockTransferSrcAccessOrder,
59 index_t BBlockTransferSrcVectorDim,
60 index_t BBlockTransferSrcScalarPerVector,
61 index_t BBlockTransferDstScalarPerVector_BK1,
62 bool BBlockLdsExtraN,
63 index_t CShuffleMXdlPerWavePerShuffle,
64 index_t CShuffleNXdlPerWavePerShuffle,
65 typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
66 index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
// NOTE(review): original lines 67-68 are missing from this listing. LoopSched and PipelineVer are
// used inside the struct but not declared in the visible parameter list, so they are presumably
// declared on those lines - confirm against the real header.
69 typename ComputeTypeA = CDataType,
70 typename ComputeTypeB = ComputeTypeA>
71struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
72 BLayout,
73 CLayout,
74 ADataType,
75 BDataType,
76 CDataType,
77 AElementwiseOperation,
78 BElementwiseOperation,
79 CElementwiseOperation>
80{
82
// Per-wave-size variants of the NXdlPerWave setting. IsSupportedArgument consults NXdlPerWave64
// when get_warp_size() == 64 and NXdlPerWave32 otherwise, and only proceeds when the value is > 0.
// NOTE(review): GetNXdlPerWave is not declared in the visible lines; the cross-reference list of
// this page points at a GET_NXDL_PER_WAVE_IMPL macro, which presumably provides it - confirm.
84 static constexpr auto NXdlPerWave64 = GetNXdlPerWave<true>();
85 static constexpr auto NXdlPerWave32 = GetNXdlPerWave<false>();
86
// Compile-time index constants (Number<0>, Number<1>, Number<2>); not referenced in the visible
// part of this struct.
87 static constexpr auto I0 = Number<0>{};
88 static constexpr auto I1 = Number<1>{};
89 static constexpr auto I2 = Number<2>{};
90
91 // GridwiseGemm
// Alias template that instantiates the gridwise GEMM for a given NXdlPerWave_ value, forwarding
// the struct's own template parameters for everything else.
// NOTE(review): the head of the alias (original line 93) and one template argument (original
// line 106) are missing from this listing. According to the cross-reference list at the end of
// this page the alias is GridwiseGemmBase = GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1<...> and
// the missing argument is InMemoryDataOperationEnum::Set - confirm against the real header.
92 template <index_t NXdlPerWave_>
94 ALayout,
95 BLayout,
96 CLayout,
97 ADataType,
98 BDataType,
99 GemmAccDataType,
100 CShuffleDataType,
101 CDataType,
102 AElementwiseOperation,
103 BElementwiseOperation,
104 CElementwiseOperation,
105 GemmSpec,
107 NumGemmKPrefetchStage,
108 BlockSize,
109 MPerBlock,
110 NPerBlock,
111 KPerBlock,
112 AK1,
113 BK1,
114 MPerXDL,
115 NPerXDL,
116 MXdlPerWave,
117 NXdlPerWave_,
118 ABlockTransferThreadClusterLengths_AK0_M_AK1,
119 ABlockTransferThreadClusterArrangeOrder,
120 ABlockTransferSrcAccessOrder,
121 ABlockTransferSrcVectorDim,
122 ABlockTransferSrcScalarPerVector,
123 ABlockTransferDstScalarPerVector_AK1,
124 false,
125 ABlockLdsExtraM,
126 BBlockTransferThreadClusterLengths_BK0_N_BK1,
127 BBlockTransferThreadClusterArrangeOrder,
128 BBlockTransferSrcAccessOrder,
129 BBlockTransferSrcVectorDim,
130 BBlockTransferSrcScalarPerVector,
131 BBlockTransferDstScalarPerVector_BK1,
132 false,
133 BBlockLdsExtraN,
134 CShuffleMXdlPerWavePerShuffle,
135 CShuffleNXdlPerWavePerShuffle,
136 CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
137 CShuffleBlockTransferScalarPerVector_NPerBlock,
138 LoopSched,
139 PipelineVer,
140 ComputeTypeA,
141 ComputeTypeB>;
// NOTE(review): original lines 142-143 are missing here; the cross-reference list shows
// GridwiseGemm64 = GridwiseGemmBase<math::max(NXdlPerWave64, 1)> and
// GridwiseGemm32 = GridwiseGemmBase<NXdlPerWave32>.
144
// The public argument type is the one belonging to the wave64 instantiation.
145 using Argument = typename GridwiseGemm64::Argument;
146
147 // Invoker
148 struct Invoker : public BaseInvoker
149 {
// Validates `arg` for the given GridwiseGemm instantiation, computes the launch grid and launches
// the kernel through launch_and_time_kernel, returning the float that call produces.
// Throws std::runtime_error when GridwiseGemm::CheckValidity(arg) is false.
// NOTE(review): `kernel` is used in both branches below, but its definition is not in this
// listing (original lines 173 and 180 are missing) - presumably each branch selects a different
// kernel instantiation; confirm against the real header.
150 template <typename GridwiseGemm>
151 float RunImp(const typename GridwiseGemm::Argument& arg,
152 const StreamConfig& stream_config = StreamConfig{})
153 {
// Dump the argument only when logging is enabled on the stream config.
154 if(stream_config.log_level_ > 0)
155 {
156 arg.Print();
157 }
158
159 if(!GridwiseGemm::CheckValidity(arg))
160 {
161 throw std::runtime_error("wrong! GridwiseGemm has invalid setting");
162 }
163
164 index_t gdx, gdy, gdz;
165 std::tie(gdx, gdy, gdz) = GridwiseGemm::CalculateGridSize(arg.M, arg.N);
166
// K is rebuilt as AK0 * AK1 rather than read directly from arg.K.
// NOTE(review): presumably this yields K adjusted to a multiple of AK1 - confirm CalculateAK0.
167 const auto K = GridwiseGemm::CalculateAK0(arg.K) * AK1;
168
169 float ave_time = 0;
170
// Both branches launch with the same grid, a 1-D block of BlockSize threads and 0 as the
// lds_byte argument of launch_and_time_kernel; only `kernel` can differ between them.
171 if(GridwiseGemm::CalculateHasMainKBlockLoop(K))
172 {
174
175 ave_time = launch_and_time_kernel(
176 stream_config, kernel, dim3(gdx, gdy, gdz), dim3(BlockSize), 0, arg);
177 }
178 else
179 {
181
182 ave_time = launch_and_time_kernel(
183 stream_config, kernel, dim3(gdx, gdy, gdz), dim3(BlockSize), 0, arg);
184 }
185
186 return ave_time;
187 }
188
190
191 // polymorphic
192 float Run(const BaseArgument* p_arg,
193 const StreamConfig& stream_config = StreamConfig{}) override
194 {
195 return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
196 }
197 };
198
// Compile-time hook reporting whether this template instantiation is a valid configuration.
// It currently accepts every instantiation unconditionally.
static constexpr bool IsValidCompilationParameter()
{
    // TODO: properly implement this check
    constexpr bool accepts_every_instantiation = true;
    return accepts_every_instantiation;
}
204
// Static support check for a typed argument.
// Visible rules: an early guard returns false; K must be divisible by both AK1 and BK1 unless
// GemmSpec is one of the K-padding specializations (MKPadding, NKPadding, MNKPadding, KPadding);
// after that the decision depends on get_warp_size() and the matching NXdlPerWave constant.
// NOTE(review): original lines 207, 223 and 230 are missing from this listing, so the condition
// of the first guard and the statements inside the two `if constexpr` bodies cannot be read here.
205 static bool IsSupportedArgument(const Argument& arg)
206 {
208 {
209 return false;
210 }
211 if((arg.K % AK1 != 0 || arg.K % BK1 != 0) && !(GemmSpec == GemmSpecialization::MKPadding ||
212 GemmSpec == GemmSpecialization::NKPadding ||
213 GemmSpec == GemmSpecialization::MNKPadding ||
214 GemmSpec == GemmSpecialization::KPadding))
215 {
216 return false;
217 }
218
219 if(get_warp_size() == 64)
220 {
221 if constexpr(NXdlPerWave64 > 0)
222 {
224 }
225 }
226 else
227 {
228 if constexpr(NXdlPerWave32 > 0)
229 {
// NOTE(review): `arg` is reinterpreted as the GridwiseGemm32 argument type here; this
// presumably relies on both instantiations' Argument types having the same layout - confirm.
231 reinterpret_cast<const typename GridwiseGemm32::Argument&>(arg));
232 }
233 }
// Fallback when the `if constexpr` body for the detected wave size was discarded.
// NOTE(review): presumably the missing lines inside those bodies return - confirm.
234 return false;
235 }
236
237 // polymorphic
238 bool IsSupportedArgument(const BaseArgument* p_arg) override
239 {
240 return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
241 }
242
// Builds an Argument by value from typed A/B/C pointers, the M/N/K sizes and the three strides.
// The three trailing element-wise operation parameters are unnamed and are not used.
243 static auto MakeArgument(const ADataType* p_a,
244 const BDataType* p_b,
245 CDataType* p_c,
246 index_t M,
247 index_t N,
248 index_t K,
249 index_t StrideA,
250 index_t StrideB,
251 index_t StrideC,
252 AElementwiseOperation,
253 BElementwiseOperation,
254 CElementwiseOperation)
255 {
256 return Argument{p_a, p_b, p_c, M, N, K, StrideA, StrideB, StrideC};
257 }
258
// Returns a value-initialized Invoker by value.
259 static auto MakeInvoker() { return Invoker{}; }
260
261 // polymorphic
262 std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a,
263 const void* p_b,
264 void* p_c,
265 index_t M,
266 index_t N,
267 index_t K,
268 index_t StrideA,
269 index_t StrideB,
270 index_t StrideC,
271 AElementwiseOperation,
272 BElementwiseOperation,
273 CElementwiseOperation) override
274 {
275 return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
276 static_cast<const BDataType*>(p_b),
277 static_cast<CDataType*>(p_c),
278 M,
279 N,
280 K,
281 StrideA,
282 StrideB,
283 StrideC);
284 }
285
286 // polymorphic
287 std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
288 {
289 return std::make_unique<Invoker>(Invoker{});
290 }
291
292 // polymorphic
293 std::string GetTypeString() const override
294 {
295 auto str = std::stringstream();
296
297 std::map<LoopScheduler, std::string> LoopSchedToString{
298 {LoopScheduler::Default, "Default"}, {LoopScheduler::Interwave, "Interwave"}};
299
300 std::map<PipelineVersion, std::string> PipelineVersionToString{{PipelineVersion::v1, "v1"},
301 {PipelineVersion::v2, "v2"}};
302
303 // clang-format off
304 str << "DeviceGemm_Xdl_CShuffle"
305 << "<"
306 << getGemmSpecializationString(GemmSpec) << ", "
307 << BlockSize << ", "
308 << MPerBlock << ", "
309 << NPerBlock << ", "
310 << KPerBlock << ", "
311 << AK1 << ", "
312 << BK1 << ", "
313 << MPerXDL << ", "
314 << NPerXDL << ", "
315 << MXdlPerWave << ", "
316 << NXdlPerWave << ", "
317 << ABlockTransferSrcScalarPerVector << ", "
318 << BBlockTransferSrcScalarPerVector << ", "
319 << CShuffleMXdlPerWavePerShuffle << ", "
320 << CShuffleNXdlPerWavePerShuffle
321 << ">"
322 << " LoopScheduler: "
323 << LoopSchedToString[LoopSched] << ", "
324 << "PipelineVersion: "
325 << PipelineVersionToString[PipelineVer];
326 // clang-format on
327
328 return str.str();
329 }
330};
331
332} // namespace device
333} // namespace tensor_operation
334} // namespace ck
#define INVOKER_RUN3_IMPL
Definition device_base.hpp:114
#define GET_NXDL_PER_WAVE_IMPL
Definition device_base.hpp:81
float launch_and_time_kernel(const StreamConfig &stream_config, F kernel, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, Args... args)
Definition host_utility/kernel_launch.hpp:14
__host__ __device__ constexpr T max(T x)
Definition utility/math.hpp:84
Definition convolution_backward_data_specialization.hpp:8
std::string getGemmSpecializationString(const GemmSpecialization &s)
Definition gemm_specialization.hpp:32
GemmSpecialization
Definition gemm_specialization.hpp:11
@ MKPadding
Definition gemm_specialization.hpp:18
@ KPadding
Definition gemm_specialization.hpp:16
@ MNKPadding
Definition gemm_specialization.hpp:20
@ NKPadding
Definition gemm_specialization.hpp:19
Definition convolution_backward_data_specialization.hpp:7
Definition ck.hpp:268
int32_t index_t
Definition ck.hpp:299
@ Set
Definition ck.hpp:278
__global__ void kernel_gemm_xdl_cshuffle_v1(typename GridwiseGemm::Argument karg)
Definition gridwise_gemm_xdl_cshuffle_v1.hpp:25
integral_constant< index_t, N > Number
Definition number.hpp:12
bool is_xdl_wmma_supported()
Definition host_utility/device_prop.hpp:76
__device__ constexpr index_t get_warp_size()
Definition get_id.hpp:10
LoopScheduler
Definition loop_scheduler.hpp:15
@ Default
Definition loop_scheduler.hpp:16
@ Interwave
Definition loop_scheduler.hpp:17
PipelineVersion
Definition gridwise_gemm_pipeline_selector.hpp:18
@ v2
Definition gridwise_gemm_pipeline_selector.hpp:20
@ v1
Definition gridwise_gemm_pipeline_selector.hpp:19
constexpr LoopScheduler make_default_loop_scheduler()
Definition loop_scheduler.hpp:20
Definition ck/stream_config.hpp:10
Definition gridwise_gemm_xdl_cshuffle_v1.hpp:121
Definition device_base.hpp:197
Definition device_gemm_xdl_cshuffle.hpp:149
float RunImp(const typename GridwiseGemm::Argument &arg, const StreamConfig &stream_config=StreamConfig{})
Definition device_gemm_xdl_cshuffle.hpp:151
INVOKER_RUN3_IMPL float Run(const BaseArgument *p_arg, const StreamConfig &stream_config=StreamConfig{}) override
Definition device_gemm_xdl_cshuffle.hpp:192
Definition device_gemm_xdl_cshuffle.hpp:80
std::unique_ptr< BaseInvoker > MakeInvokerPointer() override
Definition device_gemm_xdl_cshuffle.hpp:287
bool IsSupportedArgument(const BaseArgument *p_arg) override
Definition device_gemm_xdl_cshuffle.hpp:238
static auto MakeArgument(const ADataType *p_a, const BDataType *p_b, CDataType *p_c, index_t M, index_t N, index_t K, index_t StrideA, index_t StrideB, index_t StrideC, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation)
Definition device_gemm_xdl_cshuffle.hpp:243
std::string GetTypeString() const override
Definition device_gemm_xdl_cshuffle.hpp:293
DeviceGemm_Xdl_CShuffle DeviceOp
Definition device_gemm_xdl_cshuffle.hpp:81
static constexpr auto I1
Definition device_gemm_xdl_cshuffle.hpp:88
static bool IsSupportedArgument(const Argument &arg)
Definition device_gemm_xdl_cshuffle.hpp:205
static constexpr bool IsValidCompilationParameter()
Definition device_gemm_xdl_cshuffle.hpp:199
static constexpr auto I0
Definition device_gemm_xdl_cshuffle.hpp:87
std::unique_ptr< BaseArgument > MakeArgumentPointer(const void *p_a, const void *p_b, void *p_c, index_t M, index_t N, index_t K, index_t StrideA, index_t StrideB, index_t StrideC, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation) override
Definition device_gemm_xdl_cshuffle.hpp:262
GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1< ALayout, BLayout, CLayout, ADataType, BDataType, GemmAccDataType, CShuffleDataType, CDataType, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, GemmSpec, InMemoryDataOperationEnum::Set, NumGemmKPrefetchStage, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, LoopSched, PipelineVer, ComputeTypeA, ComputeTypeB > GridwiseGemmBase
Definition device_gemm_xdl_cshuffle.hpp:93
static auto MakeInvoker()
Definition device_gemm_xdl_cshuffle.hpp:259
static constexpr auto NXdlPerWave32
Definition device_gemm_xdl_cshuffle.hpp:85
typename GridwiseGemm64::Argument Argument
Definition device_gemm_xdl_cshuffle.hpp:145
static constexpr auto I2
Definition device_gemm_xdl_cshuffle.hpp:89
GridwiseGemmBase< NXdlPerWave32 > GridwiseGemm32
Definition device_gemm_xdl_cshuffle.hpp:143
static GET_NXDL_PER_WAVE_IMPL constexpr auto NXdlPerWave64
Definition device_gemm_xdl_cshuffle.hpp:84
GridwiseGemmBase< math::max(NXdlPerWave64, 1)> GridwiseGemm64
Definition device_gemm_xdl_cshuffle.hpp:142
Definition device_gemm.hpp:22