device_gemm_multiple_abd_wmma_cshuffle_v3.hpp Source File

device_gemm_multiple_abd_wmma_cshuffle_v3.hpp Source File

Composable Kernel: device_gemm_multiple_abd_wmma_cshuffle_v3.hpp Source File
device_gemm_multiple_abd_wmma_cshuffle_v3.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
6#include <iostream>
7#include <sstream>
8
20
21namespace ck {
22namespace tensor_operation {
23namespace device {
24
32// operations that could be applied on each tensor respectively. The CDE_op is an
33// elementwise operation applied to the C and all D tensors.
129template <typename AsLayout,
130 typename BsLayout,
131 typename DsLayout,
132 typename ELayout,
133 typename AsDataType,
134 typename BsDataType,
135 typename AccDataType,
136 typename CShuffleDataType,
137 typename DsDataType,
138 typename EDataType,
139 typename AElementwiseOperation,
140 typename BElementwiseOperation,
141 typename CDEElementwiseOperation,
142 GemmSpecialization GemmSpec,
143 index_t BlockSize,
144 index_t MPerBlock,
145 index_t NPerBlock,
146 index_t KPerBlock,
147 index_t AK1,
148 index_t BK1,
149 index_t MPerWmma,
150 index_t NPerWmma,
151 index_t MRepeat,
152 index_t NRepeat,
153 typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
154 typename ABlockTransferThreadClusterArrangeOrder,
155 typename ABlockTransferSrcAccessOrder,
156 index_t ABlockTransferSrcVectorDim,
157 index_t ABlockTransferSrcScalarPerVector,
158 index_t ABlockTransferDstScalarPerVector_AK1,
159 bool ABlockLdsExtraM,
160 typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
161 typename BBlockTransferThreadClusterArrangeOrder,
162 typename BBlockTransferSrcAccessOrder,
163 index_t BBlockTransferSrcVectorDim,
164 index_t BBlockTransferSrcScalarPerVector,
165 index_t BBlockTransferDstScalarPerVector_BK1,
166 bool BBlockLdsExtraN,
167 index_t CShuffleMRepeatPerShuffle,
168 index_t CShuffleNRepeatPerShuffle,
169 typename CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
170 typename CDEShuffleBlockTransferScalarPerVectors,
173 typename ComputeTypeA = EDataType,
174 typename ComputeTypeB = ComputeTypeA,
175 bool PermuteA = false,
176 bool PermuteB = false>
178 : public DeviceGemmMultipleABDSplitK<AsLayout,
179 BsLayout,
180 DsLayout,
181 ELayout,
182 AsDataType,
183 BsDataType,
184 DsDataType,
185 EDataType,
186 AElementwiseOperation,
187 BElementwiseOperation,
188 CDEElementwiseOperation>
189{
190 // Note: multiple layouts are passed in, but only the first one is used.
191 // This replicates the XDL functionality, but it should be extended.
194
196 ALayout,
197 BLayout,
198 DsLayout,
199 ELayout,
200 AsDataType,
201 BsDataType,
202 AccDataType,
203 CShuffleDataType,
204 DsDataType,
205 EDataType,
206 AElementwiseOperation,
207 BElementwiseOperation,
208 CDEElementwiseOperation,
209 GemmSpec,
210 BlockSize,
211 MPerBlock,
212 NPerBlock,
213 KPerBlock,
214 AK1,
215 BK1,
216 MPerWmma,
217 NPerWmma,
218 MRepeat,
219 NRepeat,
220 ABlockTransferThreadClusterLengths_AK0_M_AK1,
221 ABlockTransferThreadClusterArrangeOrder,
222 ABlockTransferSrcAccessOrder,
223 ABlockTransferSrcVectorDim,
224 ABlockTransferSrcScalarPerVector,
225 ABlockTransferDstScalarPerVector_AK1,
226 false,
227 ABlockLdsExtraM,
228 BBlockTransferThreadClusterLengths_BK0_N_BK1,
229 BBlockTransferThreadClusterArrangeOrder,
230 BBlockTransferSrcAccessOrder,
231 BBlockTransferSrcVectorDim,
232 BBlockTransferSrcScalarPerVector,
233 BBlockTransferDstScalarPerVector_BK1,
234 false,
235 BBlockLdsExtraN,
236 CShuffleMRepeatPerShuffle,
237 CShuffleNRepeatPerShuffle,
238 CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
239 CDEShuffleBlockTransferScalarPerVectors,
240 BlkGemmPipeSched,
241 BlkGemmPipelineVer,
242 ComputeTypeA,
243 ComputeTypeB,
244 PermuteA,
245 PermuteB>;
246
247 using Argument = typename GridwiseGemm::Argument;
248
251 AsDataType,
252 BsDataType,
253 DsDataType,
254 EDataType,
255 MPerBlock,
256 NPerBlock,
257 KPerBlock,
258 BlockSize,
259 AK1,
260 BK1,
261 GemmSpec,
262 CDEShuffleBlockTransferScalarPerVectors,
263 BlkGemmPipeSched,
264 BlkGemmPipelineVer,
265 ComputeTypeA,
266 ComputeTypeB>;
267
268 // Invoker
270
// Static overload of the support check; takes the concrete Argument type by
// const reference and returns a bool.
// NOTE(review): the statement between the braces is absent from this listing;
// presumably it forwards to DeviceGemmCommon::IsSupportedArgument(arg) — confirm
// against the real header.
271 static bool IsSupportedArgument(const Argument& arg)
272 {
274 }
275
276 // polymorphic
277 bool IsSupportedArgument(const BaseArgument* p_arg) override
278 {
279 return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
280 }
281
282 static auto MakeArgument(std::array<const void*, GridwiseGemm::NumATensor> p_as,
283 std::array<const void*, GridwiseGemm::NumBTensor> p_bs,
284 std::array<const void*, GridwiseGemm::NumDTensor> p_ds,
285 void* p_e,
286 index_t M,
287 index_t N,
288 index_t K,
289 std::array<ck::index_t, GridwiseGemm::NumATensor> StrideAs,
290 std::array<ck::index_t, GridwiseGemm::NumBTensor> StrideBs,
291 std::array<index_t, GridwiseGemm::NumDTensor> StrideDs,
292 index_t StrideE,
293 index_t KBatch,
294 AElementwiseOperation a_element_op,
295 BElementwiseOperation b_element_op,
296 CDEElementwiseOperation cde_element_op)
297 {
298 return Argument{p_as,
299 p_bs,
300 p_ds,
301 static_cast<EDataType*>(p_e),
302 M,
303 N,
304 K,
305 StrideAs,
306 StrideBs,
307 StrideDs,
308 StrideE,
309 KBatch,
310 a_element_op,
311 b_element_op,
312 cde_element_op};
313 }
314
315 static auto MakeInvoker() { return Invoker{}; }
316
317 // polymorphic
318 std::unique_ptr<BaseArgument>
319 MakeArgumentPointer(std::array<const void*, GridwiseGemm::NumATensor> p_as,
320 std::array<const void*, GridwiseGemm::NumBTensor> p_bs,
321 std::array<const void*, GridwiseGemm::NumDTensor> p_ds,
322 void* p_e,
323 index_t M,
324 index_t N,
325 index_t K,
326 std::array<ck::index_t, GridwiseGemm::NumATensor> StrideAs,
327 std::array<ck::index_t, GridwiseGemm::NumBTensor> StrideBs,
328 std::array<ck::index_t, GridwiseGemm::NumDTensor> StrideDs,
329 index_t StrideE,
330 index_t KBatch,
331 AElementwiseOperation a_element_op,
332 BElementwiseOperation b_element_op,
333 CDEElementwiseOperation cde_element_op) override
334 {
335 return std::make_unique<Argument>(p_as,
336 p_bs,
337 p_ds,
338 static_cast<EDataType*>(p_e),
339 M,
340 N,
341 K,
342 StrideAs,
343 StrideBs,
344 StrideDs,
345 StrideE,
346 KBatch,
347 a_element_op,
348 b_element_op,
349 cde_element_op);
350 }
351
352 // polymorphic
353 std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
354 {
355 return std::make_unique<Invoker>(Invoker{});
356 }
357
358 // polymorphic
// Builds and returns a human-readable description of this instantiation by
// streaming its template parameters into a std::stringstream.
// NOTE(review): several original lines are absent from this listing (the map
// initializers, the loop headers that introduce `i`, and the value streamed
// after "KPack: ") — consult the real header before relying on the exact format.
359 std::string GetTypeString() const override
360 {
361 auto str = std::stringstream();
362
// Lookup tables from the pipeline scheduler / version enum values to the text
// printed for them further down.
363 std::map<BlockGemmPipelineScheduler, std::string> BlkGemmPipelineSchedulerToString{
366
367 std::map<BlockGemmPipelineVersion, std::string> BlkGemmPipelineVersionToString{
373
374 // clang-format off
// Operation name followed by the GEMM specialization string.
375 str << "DeviceGemmMultipleABD_Wmma_CShuffleV3"
376 << "<"
377 << getGemmSpecializationString(GemmSpec) << ", ";
// Append the first character of the name of the i-th layout in AsLayout.
379 using ALayout_ = remove_cvref_t<tuple_element_t<i.value, AsLayout>>;
380
381 str << std::string(ALayout_::name)[0];
382 });
// Same for the i-th layout in BsLayout.
384 using BLayout_ = remove_cvref_t<tuple_element_t<i.value, BsLayout>>;
385
386 str << std::string(BLayout_::name)[0];
387 });
// Same for the i-th layout in DsLayout.
389 using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;
390
391 str << std::string(DLayout::name)[0];
392 });
// First character of the E layout name closes the "<...>" group; the remaining
// fields are labelled tuning parameters.
393 str << std::string(ELayout::name)[0]
394 << ">"
395 << " BlkSize: "
396 << BlockSize << ", "
397 << "BlkTile: "
398 << MPerBlock << "x" << NPerBlock << "x" << KPerBlock << ", "
399 << "WaveTile: "
400 << MPerWmma << "x"<<NPerWmma << ", "
401 << "WaveMap: "
402 << MRepeat << "x" << NRepeat << ", "
403 << "VmemReadVec: "
404 << ABlockTransferSrcScalarPerVector << "x" << BBlockTransferSrcScalarPerVector << ", "
405 << "BlkGemmPipelineScheduler: "
406 << BlkGemmPipelineSchedulerToString[BlkGemmPipeSched] << ", "
407 << "BlkGemmPipelineVersion: "
408 << BlkGemmPipelineVersionToString[BlkGemmPipelineVer] << ", "
409 << "BlkGemmPipelinePrefetchStages: "
410 << GridwiseGemm::BlockwiseGemmPipe::PrefetchStages << ", "
411 << "KPack: "
413 // clang-format on
414
415 return str.str();
416 }
418};
419
420} // namespace device
421} // namespace tensor_operation
422} // namespace ck
#define REGISTER_EXTRA_PRINTING_METHODS
Definition device_base.hpp:47
Definition convolution_backward_data_specialization.hpp:8
std::string getGemmSpecializationString(const GemmSpecialization &s)
Definition gemm_specialization.hpp:32
GemmSpecialization
Definition gemm_specialization.hpp:11
Definition convolution_backward_data_specialization.hpp:7
Definition ck.hpp:268
int32_t index_t
Definition ck.hpp:299
remove_cv_t< remove_reference_t< T > > remove_cvref_t
Definition type.hpp:297
BlockGemmPipelineVersion
Definition blkgemmpipe_scheduler.hpp:12
@ v2
Definition blkgemmpipe_scheduler.hpp:15
@ v3
Definition blkgemmpipe_scheduler.hpp:16
@ v5
Definition blkgemmpipe_scheduler.hpp:18
@ v4
Definition blkgemmpipe_scheduler.hpp:17
@ v1
Definition blkgemmpipe_scheduler.hpp:14
typename tuple_element< I, TTuple >::type tuple_element_t
Definition utility/tuple.hpp:208
BlockGemmPipelineScheduler
Definition blkgemmpipe_scheduler.hpp:25
@ Intrawave
Definition blkgemmpipe_scheduler.hpp:26
@ Interwave
Definition blkgemmpipe_scheduler.hpp:27
"Universal" GEMM kernel with SplitK support.
Definition gridwise_gemm_wmma_cshuffle_v3.hpp:233
Definition functional2.hpp:33
Definition device_base.hpp:197
Helper structure responsible for kernel invocation.
Definition device_gemm_wmma_cshuffle_v3_common.hpp:57
Definition device_gemm_wmma_cshuffle_v3_common.hpp:43
static bool IsSupportedArgument(const Argument &arg)
Definition device_gemm_wmma_cshuffle_v3_common.hpp:268
"Universal" GEMM operation with SplitK support and multiple D tensors.
Definition device_gemm_multiple_abd_wmma_cshuffle_v3.hpp:189
remove_cvref_t< tuple_element_t< 0, BsLayout > > BLayout
Definition device_gemm_multiple_abd_wmma_cshuffle_v3.hpp:193
static bool IsSupportedArgument(const Argument &arg)
Definition device_gemm_multiple_abd_wmma_cshuffle_v3.hpp:271
static auto MakeArgument(std::array< const void *, GridwiseGemm::NumATensor > p_as, std::array< const void *, GridwiseGemm::NumBTensor > p_bs, std::array< const void *, GridwiseGemm::NumDTensor > p_ds, void *p_e, index_t M, index_t N, index_t K, std::array< ck::index_t, GridwiseGemm::NumATensor > StrideAs, std::array< ck::index_t, GridwiseGemm::NumBTensor > StrideBs, std::array< index_t, GridwiseGemm::NumDTensor > StrideDs, index_t StrideE, index_t KBatch, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CDEElementwiseOperation cde_element_op)
Definition device_gemm_multiple_abd_wmma_cshuffle_v3.hpp:282
static auto MakeInvoker()
Definition device_gemm_multiple_abd_wmma_cshuffle_v3.hpp:315
DeviceGemm_Wmma_CShuffleV3_Common< GridwiseGemm, AsDataType, BsDataType, DsDataType, EDataType, MPerBlock, NPerBlock, KPerBlock, BlockSize, AK1, BK1, GemmSpec, CDEShuffleBlockTransferScalarPerVectors, BlkGemmPipeSched, BlkGemmPipelineVer, ComputeTypeA, ComputeTypeB > DeviceGemmCommon
Definition device_gemm_multiple_abd_wmma_cshuffle_v3.hpp:249
remove_cvref_t< tuple_element_t< 0, AsLayout > > ALayout
Definition device_gemm_multiple_abd_wmma_cshuffle_v3.hpp:192
std::unique_ptr< BaseInvoker > MakeInvokerPointer() override
Definition device_gemm_multiple_abd_wmma_cshuffle_v3.hpp:353
typename GridwiseGemm::Argument Argument
Definition device_gemm_multiple_abd_wmma_cshuffle_v3.hpp:247
std::unique_ptr< BaseArgument > MakeArgumentPointer(std::array< const void *, GridwiseGemm::NumATensor > p_as, std::array< const void *, GridwiseGemm::NumBTensor > p_bs, std::array< const void *, GridwiseGemm::NumDTensor > p_ds, void *p_e, index_t M, index_t N, index_t K, std::array< ck::index_t, GridwiseGemm::NumATensor > StrideAs, std::array< ck::index_t, GridwiseGemm::NumBTensor > StrideBs, std::array< ck::index_t, GridwiseGemm::NumDTensor > StrideDs, index_t StrideE, index_t KBatch, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CDEElementwiseOperation cde_element_op) override
Definition device_gemm_multiple_abd_wmma_cshuffle_v3.hpp:319
GridwiseGemm_wmma_cshuffle_v3< ALayout, BLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOperation, BElementwiseOperation, CDEElementwiseOperation, GemmSpec, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerWmma, NPerWmma, MRepeat, NRepeat, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMRepeatPerShuffle, CShuffleNRepeatPerShuffle, CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CDEShuffleBlockTransferScalarPerVectors, BlkGemmPipeSched, BlkGemmPipelineVer, ComputeTypeA, ComputeTypeB, PermuteA, PermuteB > GridwiseGemm
Definition device_gemm_multiple_abd_wmma_cshuffle_v3.hpp:195
std::string GetTypeString() const override
Definition device_gemm_multiple_abd_wmma_cshuffle_v3.hpp:359
typename DeviceGemmCommon::Invoker Invoker
Definition device_gemm_multiple_abd_wmma_cshuffle_v3.hpp:269
bool IsSupportedArgument(const BaseArgument *p_arg) override
Definition device_gemm_multiple_abd_wmma_cshuffle_v3.hpp:277
Definition device_gemm_multiple_abd.hpp:78