blockwise_gemm_wmma.hpp Source File

blockwise_gemm_wmma.hpp Source File

Composable Kernel: blockwise_gemm_wmma.hpp Source File
blockwise_gemm_wmma.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
11
12#define CK_MNK_LOOP
13
14namespace ck {
15
16#ifdef __gfx12__
// NOTE(review): this listing comes from a rendered Doxygen page. The number at the start of
// each line is the original file's line number, and some original lines (includes, several
// signatures, initializers and member declarations) are missing from this view.
//
// BlockwiseGemmWMMA -- gfx12 variant (compiled under `#ifdef __gfx12__`).
// Block-level GEMM built on WMMA: for every KPack-sized K step and every (m0, n0) repeat
// tile, each thread copies its A/B fragments from the block buffers into small per-thread
// buffers, packs them into vectors and issues one wmma_gemm.Run, which accumulates into the
// matching slice of the VGPR-resident C thread buffer.
17template <index_t BlockSize,
18 typename FloatA,
19 typename FloatB,
20 typename FloatAcc,
21 typename ABlockDesc,
22 typename BBlockDesc,
23 index_t MPerBlock,
24 index_t NPerBlock,
25 index_t KPerBlock,
26 index_t MPerWMMA,
27 index_t NPerWMMA,
28 index_t MRepeat,
29 index_t NRepeat,
30 index_t KPack,
    // AEnableLds / BEnableLds select the per-thread copy type (A/BThreadCopySelector) and
    // whether the thread read origin carries wave/lane offsets.
31 bool AEnableLds = true,
32 bool BEnableLds = true,
    // TransposeC is forwarded to WmmaGemm.
33 bool TransposeC = false>
34/* Option: Read from LDS, big buffer hold all threads required data
35 * Source
36 * A: K0PerBlock x MPerBlock x K1
37 * B: K0PerBlock x NPerBlock x K1
38 * Destination
39 * C, non-transpose
40 * thread level: MRepeat x NRepeat x MAccVgprs
41 * block level: MRepeat x MWave x MSubGroup x NRepeat x NWave x NThreadPerSubGroup x MAccVgprs
42 * KPACK == WMMA_K = 16
43 *
44 * Option: Read from VMEM, small buffer hold each thread own required data (Skip LDS)
45 * Source:
46 * A(if skip LDS): MRepeat x KPack
47 * B(if skip LDS): NRepeat x KPack
48 * Destination
49 * C, non-transpose
50 * block level: MRepeat x MWave x MSubGroup x NRepeat x NWave x NThreadPerSubGroup x MAccVgprs
51 */
52struct BlockwiseGemmWMMA
53{
54 static constexpr auto I0 = Number<0>{};
55 static constexpr auto I1 = Number<1>{};
56 static constexpr auto I2 = Number<2>{};
57 static constexpr auto I3 = Number<3>{};
58 static constexpr auto I4 = Number<4>{};
59 static constexpr auto I5 = Number<5>{};
    // K extent of one WMMA instruction; Run() sizes the WMMA input vectors as WmmaK / KRow.
60 static constexpr auto WmmaK = Number<16>{};
61
63
64 // WaveSize is hard-coded because HIP Runtime 5.4.0-10984 did not report the correct value.
65 static constexpr index_t WaveSize = 32;
66
67 // When use LDS, each Row(16 consecutive lanes) read whole data from source buffer
68 // When not use LDS, each Row read half of whole data from source buffer, exchange the data via
69 // permutation
    // NOTE(review): the three lines above describe the non-gfx12 variant. On this gfx12 path
    // KRow is fixed at 2 for both the LDS and the non-LDS case. With LDS enabled, the thread
    // origin's KRow coordinate is wmma_gemm.GetSubGroupId() (see
    // CalculateAThreadOriginDataIndex), so each row reads its own KRow slice.
70 static constexpr index_t A_KRow = 2;
71 static constexpr index_t B_KRow = 2;
72
    // K1 = innermost length (dimension 5) of the 6-D A/B block descriptors.
73 static constexpr index_t A_K1 = ABlockDesc{}.GetLength(I5);
74 static constexpr index_t B_K1 = BBlockDesc{}.GetLength(I5);
75
    // Per-wave WMMA helper for the configured operand types and WMMA tile shape.
76 static constexpr auto wmma_gemm =
77 WmmaGemm<FloatA, FloatB, FloatAcc, MPerWMMA, NPerWMMA, KPack, TransposeC>{};
78
    // Number of waves tiling the block along M and N (even division is asserted in the ctor).
79 static constexpr index_t MWaves = MPerBlock / (MRepeat * MPerWMMA);
80 static constexpr index_t NWaves = NPerBlock / (NRepeat * NPerWMMA);
81
    // Per-thread C accumulator: MRepeat * NRepeat vectors of GetRegSizePerWmma() FloatAcc
    // values in VGPRs (the member-name line is missing from this view).
82 StaticBufferTupleOfVector<AddressSpaceEnum::Vgpr,
83 FloatAcc,
84 MRepeat * NRepeat,
85 wmma_gemm.GetRegSizePerWmma(),
86 true>
88
89 __host__ __device__ constexpr auto& GetCThreadBuffer() { return c_thread_buf_; }
90
    // Maps the flat thread id to a 3-component wave index. Callers use [I0] as the wave index
    // along M and [I1] as the wave index along N.
91 __device__ static auto GetWaveIdx()
92 {
93 const index_t thread_id = ThisThreadBlock::GetThreadId();
94
95 constexpr auto threadid_to_wave_idx_adaptor = make_single_stage_tensor_adaptor(
97 make_tuple(Sequence<0, 1, 2>{}),
98 make_tuple(Sequence<0>{}));
99
100 return threadid_to_wave_idx_adaptor.CalculateBottomIndex(make_multi_index(thread_id));
101 }
102
103 // Default, Block buffer in LDS, thread level offset enabled
    // Returns this thread's read origin in the 6-D A block descriptor
    // (KRepeat, MRepeat, MWave, KRow, MLane, KPack).
    // - With LDS, the origin carries the wave's M index, this lane's subgroup id (as KRow) and
    //   the WMMA lane offset.
    // - Without LDS, it is all zeros.
104 __device__ static auto CalculateAThreadOriginDataIndex()
105 {
106 if constexpr(AEnableLds)
107 {
108 const auto wave_idx = GetWaveIdx();
109 const auto waveId_m = wave_idx[I0];
110 const auto WMMA_a_idx = wmma_gemm.CalculateAThreadOriginDataIndex();
111
112 // |KRepeat |MRepeat|MWave |KRow |MLane |KPack
113 return make_tuple(0, 0, waveId_m, wmma_gemm.GetSubGroupId(), WMMA_a_idx, 0);
114 }
115 else
116 {
117 return make_tuple(0, 0, 0, 0, 0, 0);
118 }
119 }
120
    // B counterpart of CalculateAThreadOriginDataIndex, using the wave index along N:
    // (KRepeat, NRepeat, NWave, KRow, NLane, KPack).
121 __device__ static auto CalculateBThreadOriginDataIndex()
122 {
123 if constexpr(BEnableLds)
124 {
125 const auto wave_idx = GetWaveIdx();
126 const auto waveId_n = wave_idx[I1];
127 const auto WMMA_b_idx = wmma_gemm.CalculateBThreadOriginDataIndex();
128
129 // |KRepeat |NRepeat|Nwave |KRow |NLane |KPack
130 return make_tuple(0, 0, waveId_n, wmma_gemm.GetSubGroupId(), WMMA_b_idx, 0);
131 }
132 else
133 {
134 return make_tuple(0, 0, 0, 0, 0, 0);
135 }
136 }
137
    // Returns the flat (m, n) coordinate of this thread's first C element for repeat tile
    // (m0, n0). It maps (m0, waveId_m, blk_idx[I0]) and (n0, waveId_n, blk_idx[I1]) through
    // the 3-D -> 1-D repeat/wave/per-WMMA adaptors. The adaptors' transform lines are not
    // visible here; they are presumably merges over (MRepeat, MWaves, MPerWMMA) and
    // (NRepeat, NWaves, NPerWMMA).
138 template <index_t m0, index_t n0>
140 {
141 const auto wave_idx = GetWaveIdx();
142
143 const auto waveId_m = wave_idx[I0];
144 const auto waveId_n = wave_idx[I1];
145
146 const auto blk_idx = wmma_gemm.GetBeginOfThreadBlk();
147
148 constexpr auto mrepeat_mwave_mperWMMA_to_m_adaptor = make_single_stage_tensor_adaptor(
150 make_tuple(Sequence<0>{}),
151 make_tuple(Sequence<0, 1, 2>{}));
152
153 constexpr auto nrepeat_nwave_nperWMMA_to_n_adaptor = make_single_stage_tensor_adaptor(
155 make_tuple(Sequence<0>{}),
156 make_tuple(Sequence<0, 1, 2>{}));
157
158 const index_t c_thread_m = mrepeat_mwave_mperWMMA_to_m_adaptor.CalculateBottomIndex(
159 make_tuple(m0, waveId_m, blk_idx[I0]))[I0];
160 const index_t c_thread_n = nrepeat_nwave_nperWMMA_to_n_adaptor.CalculateBottomIndex(
161 make_tuple(n0, waveId_n, blk_idx[I1]))[I0];
162
163 return make_tuple(c_thread_m, c_thread_n);
164 }
165
    // Returns this thread's C origin for repeat tile (m0, n0) as the 7-D index
    // (m0, waveId_m, blk_idx[0], n0, waveId_n, blk_idx[1], blk_idx[2]).
    // blk_idx comes from wmma_gemm.GetBeginOfThreadBlk3D().
166 template <index_t m0, index_t n0>
168 {
169 const auto wave_idx = GetWaveIdx();
170
171 const auto waveId_m = wave_idx[I0];
172 const auto waveId_n = wave_idx[I1];
173
174 const auto blk_idx = wmma_gemm.GetBeginOfThreadBlk3D();
175
176 return make_tuple(
177 Number<m0>{}, waveId_m, blk_idx[I0], Number<n0>{}, waveId_n, blk_idx[I1], blk_idx[I2]);
178 }
179
    // The A/B thread-copy objects are constructed at the given origins (by default the ones
    // computed above). The ctor checks that:
    // - both block descriptors are known at compile time;
    // - (per the surviving assertion message) the thread count equals MWaves * NWaves * WaveSize;
    // - MPerBlock and NPerBlock divide evenly into repeat * per-WMMA tiles.
180 using Tuple6 = decltype(CalculateAThreadOriginDataIndex());
181 __host__ __device__ BlockwiseGemmWMMA(Tuple6 a_origin = CalculateAThreadOriginDataIndex(),
183 : a_thread_copy_(a_origin), b_thread_copy_(b_origin)
184 {
185 static_assert(ABlockDesc::IsKnownAtCompileTime() && BBlockDesc::IsKnownAtCompileTime(),
186 "wrong! Desc should be known at compile-time");
187
189 "ThisThreadBlock::GetNumOfThread() != MWaves * NWaves * WaveSize\n");
190
191 static_assert(MPerBlock % (MPerWMMA * MRepeat) == 0 &&
192 NPerBlock % (NPerWMMA * NRepeat) == 0,
193 "wrong!");
194 }
195
196 // transposed WMMA output C' = B' * A'
    // Thread-level C descriptor for the transposed output, with lengths
    // (MRepeat, 1, 1, NRepeat, 1, 1, NAccVgprs).
197 __host__ __device__ static constexpr auto
199 {
200 constexpr auto c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens =
201 wmma_gemm.GetCMSubGroupNThreadPerSubGroupMAccVgprsThreadBlkLengths();
202
203 constexpr auto NAccVgprs = c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens[I2];
204
206 // |MRepeat |MWave |MSubGroup |NRepeat |NWave
207 // |NThreadPerSubGroup |MAccVgprs
    // NOTE(review): the labels above look copied from the non-transposed descriptor. For
    // this transposed layout the dims are presumably MThreadPerSubGroup / NSubGroup /
    // NAccVgprs — confirm against the function name.
208 make_tuple(Number<MRepeat>{}, I1, I1, Number<NRepeat>{}, I1, I1, NAccVgprs));
209 }
210
211 // Thread-level register descriptor (vector write).
    // Lengths: (MRepeat, 1, 1, NRepeat, 1, 1, MAccVgprs). Strides:
    // - the three M-side dims step by NRepeat * MAccVgprs * AccStride;
    // - the three N-side dims step by MAccVgprs * AccStride;
    // - the innermost MAccVgprs dim steps by AccStride.
212 __host__ __device__ static constexpr auto
214 {
215 constexpr auto c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens =
216 wmma_gemm.GetCMSubGroupNThreadPerSubGroupMAccVgprsThreadBlkLengths();
217
218 constexpr auto MAccVgprs = c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens[I2];
219 constexpr auto AccStride = c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens[I3];
221 // |MRepeat |MWave |MSubGroup |NRepeat |NWave
222 // |NThreadPerSubGroup |MAccVgprs
223 make_tuple(Number<MRepeat>{}, I1, I1, Number<NRepeat>{}, I1, I1, MAccVgprs),
224 make_tuple(Number<NRepeat>{} * MAccVgprs * AccStride,
225 Number<NRepeat>{} * MAccVgprs * AccStride,
226 Number<NRepeat>{} * MAccVgprs * AccStride,
227 MAccVgprs * AccStride,
228 MAccVgprs * AccStride,
229 MAccVgprs * AccStride,
230 AccStride));
231 }
232
    // Splits a grid-level (M, N) C descriptor with unmerge transforms into
    // (M / (MWaves * MPerWMMA), MWaves, MPerWMMA, N / (NWaves * NPerWMMA), NWaves, NPerWMMA).
    // wmma_gemm then expands the per-WMMA dims into MSubGroup / NThreadPerSubGroup / MAccVgprs.
233 template <typename CGridDesc_M_N>
234 __host__ __device__ static constexpr auto
236 const CGridDesc_M_N& c_grid_desc_m_n)
237 {
238 const auto M = c_grid_desc_m_n.GetLength(I0);
239 const auto N = c_grid_desc_m_n.GetLength(I1);
240
241 const auto c_grid_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma =
243 c_grid_desc_m_n,
245 make_unmerge_transform(make_tuple(M / (MWaves * MPerWMMA), MWaves, MPerWMMA)),
246 make_unmerge_transform(make_tuple(N / (NWaves * NPerWMMA), NWaves, NPerWMMA))),
247 make_tuple(Sequence<0>{}, Sequence<1>{}),
248 make_tuple(Sequence<0, 1, 2>{}, Sequence<3, 4, 5>{}));
249
250 return wmma_gemm
251 .MakeCDesc_MBlockxRepeat_MWave_MSubGroup_NBlockxRepeat_NWave_NThreadPerSubGroup_MAccVgprs(
252 c_grid_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma);
253 }
254
255 // transposed WMMA output C' = B' * A'
    // Block-level C descriptor for the transposed output. wmma_gemm expands it into
    // MThreadPerSubGroup / NSubGroup / NAccVgprs dims (per the called function's name).
256 __host__ __device__ static constexpr auto
258 {
259 constexpr auto c_block_desc_mrepeat_mwave_mperwmma_nrepeat_nwave_nperwmma =
266
267 return wmma_gemm
268 .MakeCDesc_MBlockxRepeat_MWave_MThreadPerSubGroup_NBlockxRepeat_NWave_NSubGroup_NAccVgprs(
269 c_block_desc_mrepeat_mwave_mperwmma_nrepeat_nwave_nperwmma);
270 }
271
272 // Provide dimension size
    // Block-level C descriptor (non-transposed). wmma_gemm expands it into
    // MSubGroup / NThreadPerSubGroup / MAccVgprs dims.
273 __host__ __device__ static constexpr auto
275 {
276 constexpr auto c_block_desc_mrepeat_mwave_mperwmma_nrepeat_nwave_nperwmma =
283
284 return wmma_gemm
285 .MakeCDesc_MBlockxRepeat_MWave_MSubGroup_NBlockxRepeat_NWave_NThreadPerSubGroup_MAccVgprs(
286 c_block_desc_mrepeat_mwave_mperwmma_nrepeat_nwave_nperwmma);
287 }
288
289 // Describe how data allocated in thread copy src buffer
290 // M0_M1_M2 = MRepeat_MWave_MPerWmma, N0_N1_N2 = NRepeat_NWave_NPerWmma
291 static constexpr ABlockDesc a_block_desc_k0_m0_m1_m2_k1;
292 static constexpr BBlockDesc b_block_desc_k0_n0_n1_n2_k1;
293
    // Accumulates this block's A * B contribution into c_thread_buf. The loop order is chosen
    // at compile time:
    // - MRepeat < NRepeat: the K step is the outer loop. Each A fragment is read once per
    //   (k, m0) and reused across all n0.
    // - Otherwise: n0 is outermost, then m0, then k. Both A and B are re-read for every
    //   (n0, m0, k).
    // Each operand vector holds KPack / KRow elements and is viewed as WmmaK / KRow-sized
    // WMMA inputs.
294 template <typename ABlockBuffer, typename BBlockBuffer, typename CThreadBuffer>
295 __device__ void Run(const ABlockBuffer& a_block_buf,
296 const BBlockBuffer& b_block_buf,
297 CThreadBuffer& c_thread_buf) const
298 {
300 a_thread_desc_.GetElementSpaceSize());
302 b_thread_desc_.GetElementSpaceSize());
303
    // Each thread's KPack slice must be a whole number of (K1 x KRow) chunks.
304 static_assert(KPack % (A_K1 * A_KRow) == 0, "");
305 static_assert(KPack % (B_K1 * B_KRow) == 0, "");
306
307 // Heuristic: choose the loop nesting order from MRepeat vs NRepeat
308 if constexpr(MRepeat < NRepeat)
309 {
310 static_for<0, KPerBlock / KPack, 1>{}(
311 [&](auto k) { // k=0,1,2 instead of k=0,kpack*1, ...
312 static_for<0, MRepeat, 1>{}([&](auto m0) {
313 // read A
314 a_thread_copy_.Run(
317 a_block_buf,
319 make_tuple(I0, m0, I0, I0, I0, I0),
320 a_thread_buf);
321
322 static_for<0, NRepeat, 1>{}([&](auto n0) {
323 // read B
324 b_thread_copy_.Run(
327 b_block_buf,
329 make_tuple(I0, n0, I0, I0, I0, I0),
330 b_thread_buf);
331
332 vector_type<FloatA, KPack / A_KRow> a_thread_vec;
333 vector_type<FloatB, KPack / B_KRow> b_thread_vec;
334
    // Gather element i from the thread buffer: chunk i / K1 (dim 0), element i % K1 (dim 5).
335 static_for<0, KPack / A_KRow, 1>{}([&](auto i) {
336 a_thread_vec.template AsType<FloatA>()(i) =
337 a_thread_buf[Number<a_thread_desc_.CalculateOffset(
338 make_tuple(i / A_K1, m0, 0, 0, 0, i % A_K1))>{}];
339 });
340
341 static_for<0, KPack / B_KRow, 1>{}([&](auto i) {
342 b_thread_vec.template AsType<FloatB>()(i) =
343 b_thread_buf[Number<b_thread_desc_.CalculateOffset(
344 make_tuple(i / B_K1, n0, 0, 0, 0, i % B_K1))>{}];
345 });
346
347 using wmma_input_type_a =
348 typename vector_type<FloatA, WmmaK / A_KRow>::type;
349 using wmma_input_type_b =
350 typename vector_type<FloatB, WmmaK / B_KRow>::type;
351
    // Start of the (m0, n0) accumulator vector in the C thread buffer.
352 constexpr index_t c_offset =
353 c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
354
355 wmma_gemm.template Run<>(
356 a_thread_vec.template AsType<wmma_input_type_a>(),
357 b_thread_vec.template AsType<wmma_input_type_b>(),
358 c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
359 });
360 });
361 });
362 }
363 else
364 {
365 static_for<0, NRepeat, 1>{}([&](auto n0) {
366 static_for<0, MRepeat, 1>{}([&](auto m0) {
367 static_for<0, KPerBlock / KPack, 1>{}([&](auto k) { // k=0,1,2 instead of
368 // k=0,kpack*1, ..
369 // read B
370 b_thread_copy_.Run(
373 b_block_buf,
375 make_tuple(I0, n0, I0, I0, I0, I0),
376 b_thread_buf);
377 // read A
378 a_thread_copy_.Run(
381 a_block_buf,
383 make_tuple(I0, m0, I0, I0, I0, I0),
384 a_thread_buf);
385
386 vector_type<FloatA, KPack / A_KRow> a_thread_vec;
387 vector_type<FloatB, KPack / B_KRow> b_thread_vec;
388
389 static_for<0, KPack / A_KRow, 1>{}([&](auto i) {
390 a_thread_vec.template AsType<FloatA>()(i) =
391 a_thread_buf[Number<a_thread_desc_.CalculateOffset(
392 make_tuple(i / A_K1, m0, 0, 0, 0, i % A_K1))>{}];
393 });
394
395 static_for<0, KPack / B_KRow, 1>{}([&](auto i) {
396 b_thread_vec.template AsType<FloatB>()(i) =
397 b_thread_buf[Number<b_thread_desc_.CalculateOffset(
398 make_tuple(i / B_K1, n0, 0, 0, 0, i % B_K1))>{}];
399 });
400
401 using wmma_input_type_a =
402 typename vector_type<FloatA, WmmaK / A_KRow>::type;
403 using wmma_input_type_b =
404 typename vector_type<FloatB, WmmaK / B_KRow>::type;
405
406 constexpr index_t c_offset =
407 c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
408
409 wmma_gemm.template Run<>(
410 a_thread_vec.template AsType<wmma_input_type_a>(),
411 b_thread_vec.template AsType<wmma_input_type_b>(),
412 c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
413 });
414 });
415 });
416 }
417 }
418
419 protected:
    // Per-thread A/B buffer layouts. They are 6-D, indexed like the block descriptors, with
    // innermost stride 1; their lengths and leading stride are missing from this view.
420 static constexpr auto a_thread_desc_ = make_naive_tensor_descriptor(
423 Number<KPack / A_KRow>{},
424 Number<A_K1>{},
425 Number<A_K1>{},
426 Number<A_K1>{},
427 Number<1>{}));
428
429 static constexpr auto b_thread_desc_ = make_naive_tensor_descriptor(
432 Number<KPack / B_KRow>{},
433 Number<B_K1>{},
434 Number<B_K1>{},
435 Number<B_K1>{},
436 Number<1>{}));
437
438 // C[M, N, NumRegWMMA]
440 make_tuple(Number<MRepeat>{}, Number<NRepeat>{}, wmma_gemm.GetRegSizePerWmma()));
441
    // Selects the A thread-copy type used by Run() to move data from a_block_buf into the
    // per-thread A buffer:
    // - with LDS: ThreadwiseTensorSliceTransfer_v4;
    // - without LDS: ThreadwiseTensorSliceTransfer_StaticToStatic_IntraRow.
    // Both copy a (KPack / A_K1 / A_KRow, 1, 1, 1, 1, A_K1) slice, with dimension 5 and A_K1
    // as the vector-access parameters.
442 template <bool EnableLds>
443 struct AThreadCopySelector;
444
445 template <>
446 struct AThreadCopySelector<true>
447 {
448 using type =
449 ThreadwiseTensorSliceTransfer_v4<FloatA,
450 FloatA,
452 decltype(a_thread_desc_),
453 Sequence<KPack / A_K1 / A_KRow, 1, 1, 1, 1, A_K1>,
454 Sequence<0, 1, 2, 3, 4, 5>,
455 5,
456 A_K1,
457 A_K1>;
458 };
459
460 template <>
461 struct AThreadCopySelector<false>
462 {
463 using type = ThreadwiseTensorSliceTransfer_StaticToStatic_IntraRow<
464 FloatA,
465 FloatA,
467 decltype(a_thread_desc_),
468 tensor_operation::element_wise::PassThrough,
469 Sequence<KPack / A_K1 / A_KRow, 1, 1, 1, 1, A_K1>,
470 Sequence<0, 1, 2, 3, 4, 5>,
471 5,
472 A_K1,
473 false>;
474 };
475
    // B counterpart of AThreadCopySelector.
476 template <bool EnableLds>
477 struct BThreadCopySelector;
478
479 template <>
480 struct BThreadCopySelector<true>
481 {
482 using type =
483 ThreadwiseTensorSliceTransfer_v4<FloatB,
484 FloatB,
486 decltype(b_thread_desc_),
487 Sequence<KPack / B_K1 / B_KRow, 1, 1, 1, 1, B_K1>,
488 Sequence<0, 1, 2, 3, 4, 5>,
489 5,
490 B_K1,
491 B_K1>;
492 };
493
494 template <>
495 struct BThreadCopySelector<false>
496 {
497 using type = ThreadwiseTensorSliceTransfer_StaticToStatic_IntraRow<
498 FloatB,
499 FloatB,
501 decltype(b_thread_desc_),
502 tensor_operation::element_wise::PassThrough,
503 Sequence<KPack / B_K1 / B_KRow, 1, 1, 1, 1, B_K1>,
504 Sequence<0, 1, 2, 3, 4, 5>,
505 5,
506 B_K1,
507 false>;
508 };
509
510 typename AThreadCopySelector<AEnableLds>::type a_thread_copy_;
511 typename BThreadCopySelector<BEnableLds>::type b_thread_copy_;
512};
513#else
// NOTE(review): this listing comes from a rendered Doxygen page. The number at the start of
// each line is the original file's line number, and some original lines (including this
// struct's `struct BlockwiseGemmWMMA` line, several signatures, initializers and the
// a_thread_copy_/b_thread_copy_ members) are missing from this view.
//
// BlockwiseGemmWMMA -- variant compiled when __gfx12__ is NOT defined (the `#else` branch).
// Block-level GEMM built on WMMA: for every KPack-sized K step and every (m0, n0) repeat
// tile, each thread copies its A/B fragments from the block buffers into small per-thread
// buffers, packs KPack elements per operand and issues one wmma_gemm.Run, which accumulates
// into the matching slice of the VGPR-resident C thread buffer.
514template <index_t BlockSize,
515 typename FloatA,
516 typename FloatB,
517 typename FloatAcc,
518 typename ABlockDesc,
519 typename BBlockDesc,
520 index_t MPerBlock,
521 index_t NPerBlock,
522 index_t KPerBlock,
523 index_t MPerWMMA,
524 index_t NPerWMMA,
525 index_t MRepeat,
526 index_t NRepeat,
527 index_t KPack,
    // AEnableLds / BEnableLds set KRow (1 with LDS, 2 without), select the per-thread copy
    // type (A/BThreadCopySelector) and decide whether the thread origin carries wave/lane offsets.
528 bool AEnableLds = true,
529 bool BEnableLds = true,
530 bool TransposeC = false>
531/* Option: Read from LDS, big buffer hold all threads required data
532 * Source
533 * A: K0PerBlock x MPerBlock x K1
534 * B: K0PerBlock x NPerBlock x K1
535 * Destination
536 * C, non-transpose
537 * thread level: MRepeat x NRepeat x MAccVgprs
538 * block level: MRepeat x MWave x MSubGroup x NRepeat x NWave x NThreadPerSubGroup x MAccVgprs
539 * KPACK == WMMA_K = 16
540 *
541 * Option: Read from VMEM, small buffer hold each thread own required data (Skip LDS)
542 * Source:
543 * A(if skip LDS): MRepeat x KPack
544 * B(if skip LDS): NRepeat x KPack
545 * Destination
546 * C, non-transpose
547 * block level: MRepeat x MWave x MSubGroup x NRepeat x NWave x NThreadPerSubGroup x MAccVgprs
548 */
550{
551 static constexpr auto I0 = Number<0>{};
552 static constexpr auto I1 = Number<1>{};
553 static constexpr auto I2 = Number<2>{};
554 static constexpr auto I3 = Number<3>{};
555 static constexpr auto I4 = Number<4>{};
556 static constexpr auto I5 = Number<5>{};
    // K extent of one WMMA instruction; Run() views each KPack operand vector as WmmaK-sized inputs.
557 static constexpr auto WmmaK = Number<16>{};
558
560
561 // WaveSize is hard-coded because HIP Runtime 5.4.0-10984 did not report the correct value.
562 static constexpr index_t WaveSize = 32;
563
564 // With LDS, each row (16 consecutive lanes) reads the whole K slice from the source buffer.
565 // Without LDS, each row reads half of it and the halves are exchanged between rows via a
566 // permutation; KRow (below) is therefore 1 with LDS and 2 without.
567 static constexpr index_t A_KRow = AEnableLds ? 1 : 2;
568 static constexpr index_t B_KRow = BEnableLds ? 1 : 2;
    // K1 = innermost length (dimension 5) of the 6-D A/B block descriptors.
569 static constexpr index_t A_K1 = ABlockDesc{}.GetLength(I5);
570 static constexpr index_t B_K1 = BBlockDesc{}.GetLength(I5);
571
574
    // Number of waves tiling the block along M and N (even division is asserted in the ctor).
575 static constexpr index_t MWaves = MPerBlock / (MRepeat * MPerWMMA);
576 static constexpr index_t NWaves = NPerBlock / (NRepeat * NPerWMMA);
577
    // Per-thread C accumulator (c_thread_buf_): MRepeat * NRepeat vectors of
    // GetRegSizePerWmma() FloatAcc values in VGPRs.
579 FloatAcc,
580 MRepeat * NRepeat,
581 wmma_gemm.GetRegSizePerWmma(),
582 true>
584
585 __host__ __device__ constexpr auto& GetCThreadBuffer() { return c_thread_buf_; }
586
    // Maps the flat thread id to a wave index. Callers use [I0] as the wave index along M and
    // [I1] as the wave index along N.
587 __device__ static auto GetWaveIdx()
588 {
589 const index_t thread_id = ThisThreadBlock::GetThreadId();
590
591 constexpr auto threadid_to_wave_idx_adaptor = make_single_stage_tensor_adaptor(
595
596 return threadid_to_wave_idx_adaptor.CalculateBottomIndex(make_multi_index(thread_id));
597 }
598
599 // Default, Block buffer in LDS, thread level offset enabled
    // Returns this thread's read origin in the 6-D A block descriptor
    // (KRepeat, MRepeat, MWave, KRow, MLane, KPack).
    // - With LDS, the origin carries the wave's M index and the WMMA lane offset; KRow is 0.
    // - Without LDS, it is all zeros.
600 __device__ static auto CalculateAThreadOriginDataIndex()
601 {
602 if constexpr(AEnableLds)
603 {
604 const auto wave_idx = GetWaveIdx();
605 const auto waveId_m = wave_idx[I0];
606 const auto WMMA_a_idx = wmma_gemm.CalculateAThreadOriginDataIndex();
607
608 // |KRepeat |MRepeat|MWave |KRow |MLane |KPack
609 return make_tuple(0, 0, waveId_m, 0, WMMA_a_idx, 0);
610 }
611 else
612 {
613 return make_tuple(0, 0, 0, 0, 0, 0);
614 }
615 }
616
    // B counterpart of CalculateAThreadOriginDataIndex, using the wave index along N:
    // (KRepeat, NRepeat, NWave, KRow, NLane, KPack).
617 __device__ static auto CalculateBThreadOriginDataIndex()
618 {
619 if constexpr(BEnableLds)
620 {
621 const auto wave_idx = GetWaveIdx();
622 const auto waveId_n = wave_idx[I1];
623 const auto WMMA_b_idx = wmma_gemm.CalculateBThreadOriginDataIndex();
624
625 // |KRepeat |NRepeat|Nwave |KRow |NLane |KPack
626 return make_tuple(0, 0, waveId_n, 0, WMMA_b_idx, 0);
627 }
628 else
629 {
630 return make_tuple(0, 0, 0, 0, 0, 0);
631 }
632 }
633
    // Returns the flat (m, n) coordinate of this thread's first C element for repeat tile
    // (m0, n0). It maps (m0, waveId_m, blk_idx[I0]) and (n0, waveId_n, blk_idx[I1]) through
    // the repeat/wave/per-WMMA -> m and -> n adaptors, whose transform lines are missing from
    // this view.
634 template <index_t m0, index_t n0>
636 {
637 const auto wave_idx = GetWaveIdx();
638
639 const auto waveId_m = wave_idx[I0];
640 const auto waveId_n = wave_idx[I1];
641
642 const auto blk_idx = wmma_gemm.GetBeginOfThreadBlk();
643
644 constexpr auto mrepeat_mwave_mperWMMA_to_m_adaptor = make_single_stage_tensor_adaptor(
648
649 constexpr auto nrepeat_nwave_nperWMMA_to_n_adaptor = make_single_stage_tensor_adaptor(
653
654 const index_t c_thread_m = mrepeat_mwave_mperWMMA_to_m_adaptor.CalculateBottomIndex(
655 make_tuple(m0, waveId_m, blk_idx[I0]))[I0];
656 const index_t c_thread_n = nrepeat_nwave_nperWMMA_to_n_adaptor.CalculateBottomIndex(
657 make_tuple(n0, waveId_n, blk_idx[I1]))[I0];
658
659 return make_tuple(c_thread_m, c_thread_n);
660 }
661
    // Returns this thread's C origin for repeat tile (m0, n0) as the 7-D index
    // (m0, waveId_m, blk_idx[0], n0, waveId_n, blk_idx[1], blk_idx[2]).
    // blk_idx comes from wmma_gemm.GetBeginOfThreadBlk3D().
662 template <index_t m0, index_t n0>
664 {
665 const auto wave_idx = GetWaveIdx();
666
667 const auto waveId_m = wave_idx[I0];
668 const auto waveId_n = wave_idx[I1];
669
670 const auto blk_idx = wmma_gemm.GetBeginOfThreadBlk3D();
671
672 return make_tuple(
673 Number<m0>{}, waveId_m, blk_idx[I0], Number<n0>{}, waveId_n, blk_idx[I1], blk_idx[I2]);
674 }
675
    // Constructor: builds the A/B thread-copy objects at the given origins. It checks that:
    // - both block descriptors are known at compile time;
    // - (per the surviving assertion message) the thread count equals MWaves * NWaves * WaveSize;
    // - MPerBlock and NPerBlock divide evenly into repeat * per-WMMA tiles.
679 : a_thread_copy_(a_origin), b_thread_copy_(b_origin)
680 {
681 static_assert(ABlockDesc::IsKnownAtCompileTime() && BBlockDesc::IsKnownAtCompileTime(),
682 "wrong! Desc should be known at compile-time");
683
685 "ThisThreadBlock::GetNumOfThread() != MWaves * NWaves * WaveSize\n");
686
687 static_assert(MPerBlock % (MPerWMMA * MRepeat) == 0 &&
688 NPerBlock % (NPerWMMA * NRepeat) == 0,
689 "wrong!");
690 }
691
692 // transposed WMMA output C' = B' * A'
    // Thread-level C descriptor for the transposed output, with lengths
    // (MRepeat, 1, 1, NRepeat, 1, 1, NAccVgprs).
693 __host__ __device__ static constexpr auto
695 {
696 constexpr auto c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens =
697 wmma_gemm.GetCMSubGroupNThreadPerSubGroupMAccVgprsThreadBlkLengths();
698
699 constexpr auto NAccVgprs = c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens[I2];
700
702 // |MRepeat |MWave |MThreadPerSubGroup |NRepeat |NWave
703 // |NSubGroup |NAccVgprs
704 make_tuple(Number<MRepeat>{}, I1, I1, Number<NRepeat>{}, I1, I1, NAccVgprs));
705 }
706
707 // Thread-level register descriptor (vector write).
    // Lengths: (MRepeat, 1, 1, NRepeat, 1, 1, MAccVgprs). Strides:
    // - the three M-side dims step by NRepeat * MAccVgprs * AccStride;
    // - the three N-side dims step by MAccVgprs * AccStride;
    // - the innermost MAccVgprs dim steps by AccStride.
708 __host__ __device__ static constexpr auto
710 {
711 constexpr auto c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens =
712 wmma_gemm.GetCMSubGroupNThreadPerSubGroupMAccVgprsThreadBlkLengths();
713
714 constexpr auto MAccVgprs = c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens[I2];
715 constexpr auto AccStride = c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens[I3];
717 // |MRepeat |MWave |MSubGroup |NRepeat |NWave
718 // |NThreadPerSubGroup |MAccVgprs
719 make_tuple(Number<MRepeat>{}, I1, I1, Number<NRepeat>{}, I1, I1, MAccVgprs),
720 make_tuple(Number<NRepeat>{} * MAccVgprs * AccStride,
721 Number<NRepeat>{} * MAccVgprs * AccStride,
722 Number<NRepeat>{} * MAccVgprs * AccStride,
723 MAccVgprs * AccStride,
724 MAccVgprs * AccStride,
725 MAccVgprs * AccStride,
726 AccStride));
727 }
728
    // Splits a grid-level (M, N) C descriptor with unmerge transforms into
    // (M / (MWaves * MPerWMMA), MWaves, MPerWMMA) and (N / (NWaves * NPerWMMA), NWaves, NPerWMMA).
    // wmma_gemm then expands the per-WMMA dims into MSubGroup / NThreadPerSubGroup / MAccVgprs.
729 template <typename CGridDesc_M_N>
730 __host__ __device__ static constexpr auto
732 const CGridDesc_M_N& c_grid_desc_m_n)
733 {
734 const auto M = c_grid_desc_m_n.GetLength(I0);
735 const auto N = c_grid_desc_m_n.GetLength(I1);
736
737 const auto c_grid_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma =
739 c_grid_desc_m_n,
741 make_unmerge_transform(make_tuple(M / (MWaves * MPerWMMA), MWaves, MPerWMMA)),
742 make_unmerge_transform(make_tuple(N / (NWaves * NPerWMMA), NWaves, NPerWMMA))),
745
746 return wmma_gemm
747 .MakeCDesc_MBlockxRepeat_MWave_MSubGroup_NBlockxRepeat_NWave_NThreadPerSubGroup_MAccVgprs(
748 c_grid_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma);
749 }
750
751 // transposed WMMA output C' = B' * A'
    // Block-level C descriptor for the transposed output. wmma_gemm expands it into
    // MThreadPerSubGroup / NSubGroup / NAccVgprs dims.
752 __host__ __device__ static constexpr auto
754 {
755 constexpr auto c_block_desc_mrepeat_mwave_mperwmma_nrepeat_nwave_nperwmma =
762
763 return wmma_gemm
764 .MakeCDesc_MBlockxRepeat_MWave_MThreadPerSubGroup_NBlockxRepeat_NWave_NSubGroup_NAccVgprs(
765 c_block_desc_mrepeat_mwave_mperwmma_nrepeat_nwave_nperwmma);
766 }
767
768 // Provide dimension size
    // Block-level C descriptor (non-transposed). wmma_gemm expands it into
    // MSubGroup / NThreadPerSubGroup / MAccVgprs dims.
769 __host__ __device__ static constexpr auto
771 {
772 constexpr auto c_block_desc_mrepeat_mwave_mperwmma_nrepeat_nwave_nperwmma =
779
780 return wmma_gemm
781 .MakeCDesc_MBlockxRepeat_MWave_MSubGroup_NBlockxRepeat_NWave_NThreadPerSubGroup_MAccVgprs(
782 c_block_desc_mrepeat_mwave_mperwmma_nrepeat_nwave_nperwmma);
783 }
784
785 // Describe how data allocated in thread copy src buffer
786 // M0_M1_M2 = MRepeat_MWave_MPerWmma, N0_N1_N2 = NRepeat_NWave_NPerWmma
787 static constexpr ABlockDesc a_block_desc_k0_m0_m1_m2_k1;
788 static constexpr BBlockDesc b_block_desc_k0_n0_n1_n2_k1;
789
    // Accumulates this block's A * B contribution into c_thread_buf. The loop order is chosen
    // at compile time:
    // - MRepeat < NRepeat: the K step is the outer loop. Each A fragment is read once per
    //   (k, m0) and reused across all n0.
    // - Otherwise: n0 is outermost, then m0, then k. Both A and B are re-read for every
    //   (n0, m0, k).
    // Each operand vector holds KPack elements and is viewed as WmmaK-sized WMMA inputs.
    // NOTE(review): unlike the gfx12 variant, nothing here asserts that KPack is a multiple of
    // A_K1 * A_KRow (resp. B_K1 * B_KRow). The copy slice length KPack / A_K1 / A_KRow would
    // silently truncate otherwise.
790 template <typename ABlockBuffer, typename BBlockBuffer, typename CThreadBuffer>
791 __device__ void Run(const ABlockBuffer& a_block_buf,
792 const BBlockBuffer& b_block_buf,
793 CThreadBuffer& c_thread_buf) const
794 {
796 a_thread_desc_.GetElementSpaceSize());
798 b_thread_desc_.GetElementSpaceSize());
799
800 // Heuristic: choose the loop nesting order from MRepeat vs NRepeat
801 if constexpr(MRepeat < NRepeat)
802 {
803 static_for<0, KPerBlock / KPack, 1>{}(
804 [&](auto k) { // k=0,1,2 instead of k=0,kpack*1, ...
805 static_for<0, MRepeat, 1>{}([&](auto m0) {
806 // read A
807 a_thread_copy_.Run(
810 a_block_buf,
812 make_tuple(I0, m0, I0, I0, I0, I0),
813 a_thread_buf);
814
815 static_for<0, NRepeat, 1>{}([&](auto n0) {
816 // read B
817 b_thread_copy_.Run(
820 b_block_buf,
822 make_tuple(I0, n0, I0, I0, I0, I0),
823 b_thread_buf);
824
825 vector_type<FloatA, KPack> a_thread_vec;
826 vector_type<FloatB, KPack> b_thread_vec;
827
    // Gather element i from the thread buffer at
    // (i / K1 / KRow, repeat, 0, (i / K1) % KRow, 0, i % K1).
828 static_for<0, KPack, 1>{}([&](auto i) {
829 a_thread_vec.template AsType<FloatA>()(i) =
830 a_thread_buf[Number<a_thread_desc_.CalculateOffset(
831 make_tuple(i / A_K1 / A_KRow,
832 m0,
833 0,
834 (i / A_K1) % A_KRow,
835 0,
836 i % A_K1))>{}];
837 b_thread_vec.template AsType<FloatB>()(i) =
838 b_thread_buf[Number<b_thread_desc_.CalculateOffset(
839 make_tuple(i / B_K1 / B_KRow,
840 n0,
841 0,
842 (i / B_K1) % B_KRow,
843 0,
844 i % B_K1))>{}];
845 });
846
847 using wmma_input_type_a = typename vector_type<FloatA, WmmaK>::type;
848 using wmma_input_type_b = typename vector_type<FloatB, WmmaK>::type;
849
    // Start of the (m0, n0) accumulator vector in the C thread buffer.
850 constexpr index_t c_offset =
851 c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
852
853 wmma_gemm.Run(a_thread_vec.template AsType<wmma_input_type_a>(),
854 b_thread_vec.template AsType<wmma_input_type_b>(),
855 c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
856 });
857 });
858 });
859 }
860 else
861 {
862 static_for<0, NRepeat, 1>{}([&](auto n0) {
863 static_for<0, MRepeat, 1>{}([&](auto m0) {
864 static_for<0, KPerBlock / KPack, 1>{}([&](auto k) { // k=0,1,2 instead of
865 // k=0,kpack*1, ..
866 // read B
867 b_thread_copy_.Run(
870 b_block_buf,
872 make_tuple(I0, n0, I0, I0, I0, I0),
873 b_thread_buf);
874 // read A
875 a_thread_copy_.Run(
878 a_block_buf,
880 make_tuple(I0, m0, I0, I0, I0, I0),
881 a_thread_buf);
882
883 vector_type<FloatA, KPack> a_thread_vec;
884 vector_type<FloatB, KPack> b_thread_vec;
885
886 static_for<0, KPack, 1>{}([&](auto i) {
887 b_thread_vec.template AsType<FloatB>()(i) =
888 b_thread_buf[Number<b_thread_desc_.CalculateOffset(
889 make_tuple(i / B_K1 / B_KRow,
890 n0,
891 0,
892 (i / B_K1) % B_KRow,
893 0,
894 i % B_K1))>{}];
895 a_thread_vec.template AsType<FloatA>()(i) =
896 a_thread_buf[Number<a_thread_desc_.CalculateOffset(
897 make_tuple(i / A_K1 / A_KRow,
898 m0,
899 0,
900 (i / A_K1) % A_KRow,
901 0,
902 i % A_K1))>{}];
903 });
904
905 using wmma_input_type_a = typename vector_type<FloatA, WmmaK>::type;
906 using wmma_input_type_b = typename vector_type<FloatB, WmmaK>::type;
907
908 constexpr index_t c_offset =
909 c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
910
911 wmma_gemm.Run(a_thread_vec.template AsType<wmma_input_type_a>(),
912 b_thread_vec.template AsType<wmma_input_type_b>(),
913 c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
914 });
915 });
916 });
917 }
918 }
919
920 protected:
    // Per-thread A/B buffer layouts. They are 6-D naive descriptors, indexed like the block
    // descriptors; the innermost dim has length K1 and stride 1. Some of their lengths and
    // strides are missing from this view.
921 static constexpr auto a_thread_desc_ =
924 I1,
926 I1,
927 Number<A_K1>{}),
931 Number<A_K1>{},
932 Number<A_K1>{},
933 Number<1>{}));
934
935 static constexpr auto b_thread_desc_ =
938 I1,
940 I1,
941 Number<B_K1>{}),
945 Number<B_K1>{},
946 Number<B_K1>{},
947 Number<1>{}));
948
949 // C[M, N, NumRegWMMA]
951 make_tuple(Number<MRepeat>{}, Number<NRepeat>{}, wmma_gemm.GetRegSizePerWmma()));
952
    // Selects the A thread-copy type used by Run() to move data from a_block_buf into the
    // per-thread A buffer:
    // - With LDS: ThreadwiseTensorSliceTransfer_v4 copies a
    //   (KPack / A_K1 / A_KRow, 1, 1, A_KRow, 1, A_K1) slice.
    // - Without LDS: ThreadwiseTensorSliceTransfer_StaticToStatic_InterRow copies a
    //   (KPack / A_K1 / A_KRow, 1, 1, 1, 1, A_K1) slice.
    // The constants 0x76543210 / 0xfedcba98 are presumably the lane-select patterns for the
    // row-exchange permutation described at the top of the struct. A and B pass complementary
    // TransposeC-derived flags.
953 template <bool EnableLds>
955
956 template <>
958 {
959 using type =
961 FloatA,
963 decltype(a_thread_desc_),
964 Sequence<KPack / A_K1 / A_KRow, 1, 1, A_KRow, 1, A_K1>,
966 5,
967 A_K1,
968 A_K1>;
969 };
970
971 template <>
973 {
975 FloatA,
976 FloatA,
978 decltype(a_thread_desc_),
980 Sequence<KPack / A_K1 / A_KRow, 1, 1, 1, 1, A_K1>,
982 5,
983 A_K1,
984 0x76543210,
985 0xfedcba98,
986 TransposeC ? false : true>;
987 };
988
    // B counterpart of AThreadCopySelector.
989 template <bool EnableLds>
991
992 template <>
994 {
995 using type =
997 FloatB,
999 decltype(b_thread_desc_),
1000 Sequence<KPack / B_K1 / B_KRow, 1, 1, B_KRow, 1, B_K1>,
1002 5,
1003 B_K1,
1004 B_K1>;
1005 };
1006
1007 template <>
1009 {
1011 FloatB,
1012 FloatB,
1014 decltype(b_thread_desc_),
1016 Sequence<KPack / B_K1 / B_KRow, 1, 1, 1, 1, B_K1>,
1018 5,
1019 B_K1,
1020 0x76543210,
1021 0xfedcba98,
1022 TransposeC ? true : false>;
1023 };
1024
1027};
1028#endif
1029
1030} // namespace ck
Definition ck.hpp:268
__host__ __device__ constexpr auto make_multi_index(Xs &&... xs)
Definition array_multi_index.hpp:15
__host__ __device__ constexpr auto make_static_buffer(Number< N >)
Definition static_buffer.hpp:186
int32_t index_t
Definition ck.hpp:299
__host__ __device__ constexpr auto make_naive_tensor_descriptor(const Tuple< Lengths... > &lengths, const Tuple< Strides... > &strides)
Definition tensor_descriptor_helper.hpp:49
__host__ __device__ constexpr auto make_single_stage_tensor_adaptor(const Transforms &transforms, LowerDimensionOldTopIdss, UpperDimensionNewTopIdss)
Definition tensor_description/tensor_adaptor.hpp:425
integral_constant< index_t, N > Number
Definition number.hpp:12
@ Vgpr
Definition amd_address_space.hpp:20
__host__ __device__ constexpr auto make_merge_transform(const LowLengths &low_lengths)
Definition multi_index_transform_helper.hpp:55
__host__ __device__ constexpr auto make_naive_tensor_descriptor_packed(const Tuple< Lengths... > &lengths)
Definition tensor_descriptor_helper.hpp:101
__host__ __device__ constexpr auto make_tuple(Xs &&... xs)
Definition utility/tuple.hpp:211
__host__ __device__ constexpr auto transform_tensor_descriptor(const OldTensorDescriptor &old_tensor_desc, const NewTransforms &new_transforms, NewLowerDimensionOldVisibleIdss, NewUpperDimensionNewVisibleIdss)
Definition tensor_description/tensor_descriptor.hpp:319
__host__ __device__ constexpr auto make_unmerge_transform(const UpLengths &up_lengths, integral_constant< bool, Use24BitIntegerCalculation >=integral_constant< bool, false >{})
Definition multi_index_transform_helper.hpp:90
ThreadwiseTensorSliceTransfer_StaticToStatic_InterRow< FloatA, FloatA, decltype(a_block_desc_k0_m0_m1_m2_k1), decltype(a_thread_desc_), tensor_operation::element_wise::PassThrough, Sequence< KPack/A_K1/A_KRow, 1, 1, 1, 1, A_K1 >, Sequence< 0, 1, 2, 3, 4, 5 >, 5, A_K1, 0x76543210, 0xfedcba98, TransposeC ? false :true > type
Definition blockwise_gemm_wmma.hpp:974
ThreadwiseTensorSliceTransfer_v4< FloatA, FloatA, decltype(a_block_desc_k0_m0_m1_m2_k1), decltype(a_thread_desc_), Sequence< KPack/A_K1/A_KRow, 1, 1, A_KRow, 1, A_K1 >, Sequence< 0, 1, 2, 3, 4, 5 >, 5, A_K1, A_K1 > type
Definition blockwise_gemm_wmma.hpp:959
Definition blockwise_gemm_wmma.hpp:954
ThreadwiseTensorSliceTransfer_StaticToStatic_InterRow< FloatB, FloatB, decltype(b_block_desc_k0_n0_n1_n2_k1), decltype(b_thread_desc_), tensor_operation::element_wise::PassThrough, Sequence< KPack/B_K1/B_KRow, 1, 1, 1, 1, B_K1 >, Sequence< 0, 1, 2, 3, 4, 5 >, 5, B_K1, 0x76543210, 0xfedcba98, TransposeC ? true :false > type
Definition blockwise_gemm_wmma.hpp:1010
ThreadwiseTensorSliceTransfer_v4< FloatB, FloatB, decltype(b_block_desc_k0_n0_n1_n2_k1), decltype(b_thread_desc_), Sequence< KPack/B_K1/B_KRow, 1, 1, B_KRow, 1, B_K1 >, Sequence< 0, 1, 2, 3, 4, 5 >, 5, B_K1, B_K1 > type
Definition blockwise_gemm_wmma.hpp:995
Definition blockwise_gemm_wmma.hpp:990
__host__ static __device__ constexpr auto GetCBlockDescriptor_MRepeat_MWave_MThreadPerSubGroup_NRepeat_NWave_NSubGroup_NAccVgprs()
Definition blockwise_gemm_wmma.hpp:753
StaticBufferTupleOfVector< AddressSpaceEnum::Vgpr, FloatAcc, MRepeat *NRepeat, wmma_gemm.GetRegSizePerWmma(), true > c_thread_buf_
Definition blockwise_gemm_wmma.hpp:583
static constexpr index_t NWaves
Definition blockwise_gemm_wmma.hpp:576
static constexpr index_t A_KRow
Definition blockwise_gemm_wmma.hpp:567
static constexpr auto b_thread_desc_
Definition blockwise_gemm_wmma.hpp:935
static constexpr index_t B_K1
Definition blockwise_gemm_wmma.hpp:570
ThisThreadBlock< BlockSize > ThisThreadBlock
Definition blockwise_gemm_wmma.hpp:559
__host__ static __device__ constexpr auto GetCBlockDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs()
Definition blockwise_gemm_wmma.hpp:770
__device__ void Run(const ABlockBuffer &a_block_buf, const BBlockBuffer &b_block_buf, CThreadBuffer &c_thread_buf) const
Definition blockwise_gemm_wmma.hpp:791
static constexpr index_t A_K1
Definition blockwise_gemm_wmma.hpp:569
static constexpr auto I0
Definition blockwise_gemm_wmma.hpp:551
__host__ __device__ constexpr auto & GetCThreadBuffer()
Definition blockwise_gemm_wmma.hpp:585
static constexpr auto I5
Definition blockwise_gemm_wmma.hpp:556
static constexpr index_t B_KRow
Definition blockwise_gemm_wmma.hpp:568
BThreadCopySelector< BEnableLds >::type b_thread_copy_
Definition blockwise_gemm_wmma.hpp:1026
static __device__ auto CalculateCThreadOriginDataIndex(Number< m0 >, Number< n0 >)
Definition blockwise_gemm_wmma.hpp:635
static constexpr auto I1
Definition blockwise_gemm_wmma.hpp:552
static constexpr auto I3
Definition blockwise_gemm_wmma.hpp:554
static constexpr index_t MWaves
Definition blockwise_gemm_wmma.hpp:575
decltype(CalculateAThreadOriginDataIndex()) Tuple6
Definition blockwise_gemm_wmma.hpp:676
static constexpr auto a_thread_desc_
Definition blockwise_gemm_wmma.hpp:921
static constexpr index_t WaveSize
Definition blockwise_gemm_wmma.hpp:562
static __device__ auto CalculateAThreadOriginDataIndex()
Definition blockwise_gemm_wmma.hpp:600
static constexpr auto c_thread_desc_
Definition blockwise_gemm_wmma.hpp:950
static constexpr auto WmmaK
Definition blockwise_gemm_wmma.hpp:557
__host__ static __device__ constexpr auto GetCThreadDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs()
Definition blockwise_gemm_wmma.hpp:709
static constexpr auto I4
Definition blockwise_gemm_wmma.hpp:555
static constexpr ABlockDesc a_block_desc_k0_m0_m1_m2_k1
Definition blockwise_gemm_wmma.hpp:787
__host__ static __device__ constexpr auto GetCThreadDescriptor_MRepeat_MWave_MThreadPerSubGroup_NRepeat_NWave_NSubGroup_NAccVgprs()
Definition blockwise_gemm_wmma.hpp:694
static constexpr BBlockDesc b_block_desc_k0_n0_n1_n2_k1
Definition blockwise_gemm_wmma.hpp:788
__host__ static __device__ constexpr auto MakeCGridDescriptor_MBlockxRepeat_MWave_MSubGroup_NBlockxRepeat_NWave_NThreadPerSubGroup_MAccVgprs(const CGridDesc_M_N &c_grid_desc_m_n)
Definition blockwise_gemm_wmma.hpp:731
static __device__ auto GetWaveIdx()
Definition blockwise_gemm_wmma.hpp:587
static constexpr auto wmma_gemm
Definition blockwise_gemm_wmma.hpp:572
static __device__ auto CalculateBThreadOriginDataIndex()
Definition blockwise_gemm_wmma.hpp:617
__host__ __device__ BlockwiseGemmWMMA(Tuple6 a_origin=CalculateAThreadOriginDataIndex(), Tuple6 b_origin=CalculateBThreadOriginDataIndex())
Definition blockwise_gemm_wmma.hpp:677
static constexpr auto I2
Definition blockwise_gemm_wmma.hpp:553
AThreadCopySelector< AEnableLds >::type a_thread_copy_
Definition blockwise_gemm_wmma.hpp:1025
static __device__ auto CalculateCThreadOriginDataIndex7D(Number< m0 >, Number< n0 >)
Definition blockwise_gemm_wmma.hpp:663
Definition utility/sequence.hpp:43
Definition static_buffer.hpp:75
static __device__ constexpr index_t GetNumOfThread()
Definition thread_group.hpp:15
static __device__ index_t GetThreadId()
Definition thread_group.hpp:19
Definition threadwise_tensor_slice_transfer.hpp:1877
Definition threadwise_tensor_slice_transfer.hpp:1260
Definition wmma_gemm.hpp:663
Definition functional2.hpp:33
Definition tensor_operation/gpu/element/unary_element_wise_operation.hpp:340
Definition dtype_vector.hpp:10