blockwise_gemm_pipeline_xdlops_v2_ab_scale.hpp Source File

blockwise_gemm_pipeline_xdlops_v2_ab_scale.hpp Source File

Composable Kernel: blockwise_gemm_pipeline_xdlops_v2_ab_scale.hpp Source File
blockwise_gemm_pipeline_xdlops_v2_ab_scale.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
7
8namespace ck {
9
10// Maximum Global Memory throughput pipeline with >=32KB data in flight
11// GlobalPrefetchStages: >=2
12// LocalPreFillStages: 1
13// LocalPreFetchStages: 0
14// LocalSharedMemoryBuffer: 1
15
16template <BlockGemmPipelineScheduler BlkGemmPipelineVer,
17 index_t BlockSize,
18 typename ADataType,
19 typename BDataType,
20 typename ComputeDataType,
21 typename AccDataType,
22 typename ATileDesc,
23 typename BTileDesc,
24 typename AMmaTileDesc,
25 typename BMmaTileDesc,
26 index_t ABlockTransferSrcScalarPerVector,
27 index_t BBlockTransferSrcScalarPerVector,
28 index_t MPerBlock,
29 index_t NPerBlock,
30 index_t KPerBlock,
31 index_t MPerXDL,
32 index_t NPerXDL,
33 index_t MRepeat,
34 index_t NRepeat,
35 index_t KPacks>
39
40template <index_t BlockSize,
41 typename ADataType,
42 typename BDataType,
43 typename ComputeDataType,
44 typename AccDataType,
45 typename ATileDesc,
46 typename BTileDesc,
47 typename AMmaTileDesc,
48 typename BMmaTileDesc,
49 index_t ABlockTransferSrcScalarPerVector,
50 index_t BBlockTransferSrcScalarPerVector,
51 index_t MPerBlock,
52 index_t NPerBlock,
53 index_t KPerBlock,
54 index_t MPerXDL,
55 index_t NPerXDL,
56 index_t MRepeat,
57 index_t NRepeat,
58 index_t KPack
59 // ,bool TransposeC //disable transposec right now...
60 >
62 BlockSize,
63 ADataType,
64 BDataType,
65 ComputeDataType,
66 AccDataType,
67 ATileDesc,
68 BTileDesc,
69 AMmaTileDesc,
70 BMmaTileDesc,
71 ABlockTransferSrcScalarPerVector,
72 BBlockTransferSrcScalarPerVector,
73 MPerBlock,
74 NPerBlock,
75 KPerBlock,
76 MPerXDL,
77 NPerXDL,
78 MRepeat,
79 NRepeat,
80 KPack>
82 ADataType,
83 BDataType,
84 ComputeDataType,
85 AccDataType,
86 ATileDesc,
87 BTileDesc,
88 AMmaTileDesc,
89 BMmaTileDesc,
90 ABlockTransferSrcScalarPerVector,
91 BBlockTransferSrcScalarPerVector,
92 MPerBlock,
93 NPerBlock,
94 KPerBlock,
95 MPerXDL,
96 NPerXDL,
97 MRepeat,
98 NRepeat,
99 KPack,
100 true>
101
102{
104 ADataType,
105 BDataType,
106 ComputeDataType,
107 AccDataType,
108 ATileDesc,
109 BTileDesc,
110 AMmaTileDesc,
111 BMmaTileDesc,
112 ABlockTransferSrcScalarPerVector,
113 BBlockTransferSrcScalarPerVector,
114 MPerBlock,
115 NPerBlock,
116 KPerBlock,
117 MPerXDL,
118 NPerXDL,
119 MRepeat,
120 NRepeat,
121 KPack,
122 true>;
123 using Base::I0;
124 using Base::KRepeat;
125 using Base::xdlops_gemm;
126
138
141
142 using Base::AMmaKStride;
143 using Base::BMmaKStride;
144 using Base::WaveSize;
145
147
148 static constexpr index_t WgpPerCU =
149 (4 * WaveSize / BlockSize) >= 1 ? 4 * WaveSize / BlockSize : 1;
151 32768 / WgpPerCU,
152 (MPerBlock * sizeof(ADataType) + NPerBlock * sizeof(BDataType)) * KPerBlock);
153 static constexpr index_t PrefetchStages =
156 : 2;
157
158 static constexpr index_t PrefillStages = 1;
160
// Reports whether the given loop count is large enough to need the main loop.
//
// num_loop: loop count supplied by the caller; it is compared against the
//           compile-time constant PrefetchStages.
// Returns:  true only when num_loop is strictly greater than PrefetchStages;
//           a count equal to PrefetchStages yields false.
// Declared __host__ and constexpr, so it is a host-side query.
// NOTE(review): presumably the caller uses this result to choose the
// HasMainLoop template argument of Run() - confirm against the call site.
161 __host__ static constexpr bool BlockHasHotloop(index_t num_loop)
162 {
163 return num_loop > PrefetchStages;
164 }
165
// Maps a loop count to the TailNumber enumerator used for the tail of the loop.
//
// num_loop: loop count supplied by the caller.
// Returns:  the enumerator matching (num_loop % PrefetchStages):
//           1 -> One, 2 -> Two, 3 -> Three, 4 -> Four, 5 -> Five, 6 -> Six,
//           7 -> Seven; every other remainder, including 0, -> Full.
// Run() in this same struct picks its tail handling from the TailNum template
// argument with `if constexpr`, using these same enumerators.
// NOTE(review): remainders above 7 would also be reported as Full. The full
// definition of PrefetchStages is not visible in this listing, so confirm
// that its maximum keeps the remainder within 0..7.
166 __host__ static constexpr TailNumber BlockLoopTailNum(index_t num_loop)
167 {
168 if(num_loop % PrefetchStages == 1)
169 {
170 return TailNumber::One;
171 }
172 else if(num_loop % PrefetchStages == 2)
173 {
174 return TailNumber::Two;
175 }
176 else if(num_loop % PrefetchStages == 3)
177 {
178 return TailNumber::Three;
179 }
180 else if(num_loop % PrefetchStages == 4)
181 {
182 return TailNumber::Four;
183 }
184 else if(num_loop % PrefetchStages == 5)
185 {
186 return TailNumber::Five;
187 }
188 else if(num_loop % PrefetchStages == 6)
189 {
190 return TailNumber::Six;
191 }
192 else if(num_loop % PrefetchStages == 7)
193 {
194 return TailNumber::Seven;
195 }
196 else
197 {
// Fallback: the remainder matched none of 1..7 (this includes a remainder of 0).
198 return TailNumber::Full;
199 }
200 }
201
202 template <bool HasMainLoop,
203 TailNumber TailNum,
204 typename AGridDesc,
205 typename ABlockDesc,
206 typename ABlockTransfer,
207 typename AGridBuffer,
208 typename ABlockBuffer,
209 typename ABlockTransferStep,
210 typename BGridDesc,
211 typename BBlockDesc,
212 typename BBlockTransfer,
213 typename BGridBuffer,
214 typename BBlockBuffer,
215 typename BBlockTransferStep,
216 typename CThreadBuffer,
217 typename AScaleGridBuffer,
218 typename AScaleGridDesc,
219 typename AScaleThreadDesc,
220 typename AScaleThreadTransfer,
221 typename AScaleThreadTransferStep,
222 typename BScaleGridBuffer,
223 typename BScaleGridDesc,
224 typename BScaleThreadDesc,
225 typename BScaleThreadTransfer,
226 typename BScaleThreadTransferStep>
227 __device__ void Run(
228 // ABlockCopy
229 const AGridDesc& a_grid_desc,
230 const ABlockDesc& a_block_desc,
231 ABlockTransfer& a_blockwise_copy,
232 const AGridBuffer& a_grid_buf,
233 ABlockBuffer& a_block_buf,
234 const ABlockTransferStep& a_block_copy_step,
235 // BBlockCopy
236 const BGridDesc& b_grid_desc,
237 const BBlockDesc& b_block_desc,
238 BBlockTransfer& b_blockwise_copy,
239 const BGridBuffer& b_grid_buf,
240 BBlockBuffer& b_block_buf,
241 const BBlockTransferStep& b_block_copy_step,
242 // CThread
243 CThreadBuffer& c_thread_buf,
244 // AScaleThreadCopy
245 const AScaleGridDesc& a_scale_grid_desc,
246 const AScaleThreadDesc& a_scale_thread_desc,
247 AScaleThreadTransfer& a_scale_thread_copy,
248 const AScaleGridBuffer& a_scale_grid_buf,
249 const AScaleThreadTransferStep& a_scale_thread_copy_step,
250 // BScaleThreadCopy
251 const BScaleGridDesc& b_scale_grid_desc,
252 const BScaleThreadDesc& b_scale_thread_desc,
253 BScaleThreadTransfer& b_scale_thread_copy,
254 const BScaleGridBuffer& b_scale_grid_buf,
255 const BScaleThreadTransferStep& b_scale_thread_copy_step,
256 // num_loop
257 index_t num_loop,
258 index_t num_loop_per_scale) const
259 {
260 // assume kperblock = scaleblockk
261 ignore = num_loop_per_scale;
263 a_thread_desc_.GetElementSpaceSize());
265 b_thread_desc_.GetElementSpaceSize());
267 a_scale_thread_desc.GetElementSpaceSize());
269 b_scale_thread_desc.GetElementSpaceSize());
270
271 // Global prefetch 1
272 a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, I0);
273 b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf, I0);
274
275 a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
276 b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
277
278 static_for<0, MRepeat, 1>{}([&](auto m0) {
279 a_scale_thread_copy.Run(a_scale_grid_desc,
280 a_scale_grid_buf,
281 a_scale_thread_desc,
282 make_tuple(m0, I0),
283 a_scale_thread_buf);
284 a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc,
285 a_scale_thread_copy_step.At(Number<0>{}));
286 });
287
288 if(num_loop_per_scale == 1)
289 {
290 a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc,
291 a_scale_thread_copy_step.At(Number<2>{}));
292 }
293 else
294 {
295 a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc,
296 a_scale_thread_copy_step.At(Number<1>{}));
297 }
298
299 b_scale_thread_copy.Run(b_scale_grid_desc,
300 b_scale_grid_buf,
301 b_scale_thread_desc,
302 make_tuple(I0, I0),
303 b_scale_thread_buf);
304
305 b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc, b_scale_thread_copy_step);
306
307 // Local prefill 1
308 a_blockwise_copy.RunWrite(a_block_desc, a_block_buf, I0);
309 b_blockwise_copy.RunWrite(b_block_desc, b_block_buf, I0);
310
311 // Initialize C
312 c_thread_buf.Clear();
313
314 // Global prefetch [2, PrefetchStages]
315 static_for<1, PrefetchStages, 1>{}([&](auto iprefetch) {
316 a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, iprefetch);
317 b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf, iprefetch);
318
319 a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
320 b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
321 });
322
323 auto c_thread_buf_per_scale = remove_cvref_t<decltype(c_thread_buf)>();
324
325 // main body
326 if constexpr(HasMainLoop)
327 {
328 index_t i = 0;
329 do
330 {
331 static_for<0, PrefetchStages, 1>{}([&](auto iprefetch) {
333 static_for<0, KRepeat, 1>{}([&](auto k) {
334 static_for<0, MRepeat, 1>{}([&](auto m0) {
337 a_block_buf,
339 make_tuple(m0, I0, k, I0),
340 a_thread_buf);
341 });
342 static_for<0, NRepeat, 1>{}([&](auto n0) {
345 b_block_buf,
347 make_tuple(n0, I0, k, I0),
348 b_thread_buf);
349 });
350 });
351
352 static_for<0, MRepeat, 1>{}([&](auto m0) {
353 static_for<0, NRepeat, 1>{}([&](auto n0) {
354 c_thread_buf_per_scale.Clear();
355 static_for<0, KRepeat, 1>{}([&](auto k0) {
358
359 static_for<0, KPack, 1>{}([&](auto ik) {
360 a_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
361 a_thread_buf[Number<a_thread_desc_.CalculateOffset(
362 make_tuple(m0, I0, k0, ik))>{}];
363 b_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
364 b_thread_buf[Number<b_thread_desc_.CalculateOffset(
365 make_tuple(n0, I0, k0, ik))>{}];
366 });
367
368 using mfma_input_type =
370 xdlops_gemm.K1PerXdlops>::type;
371
372 xdlops_gemm.template Run<>(
373 a_thread_vec.template AsType<mfma_input_type>(),
374 b_thread_vec.template AsType<mfma_input_type>(),
375 c_thread_buf_per_scale.GetVectorTypeReference(I0));
376 });
377 static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) {
378 constexpr index_t c_offset =
379 c_thread_desc_.CalculateOffset(make_tuple(m0, n0, t));
380 c_thread_buf(Number<c_offset>{}) +=
381 c_thread_buf_per_scale[Number<t>{}] *
382 type_convert<AccDataType>(a_scale_thread_buf[m0]) *
383 type_convert<AccDataType>(b_scale_thread_buf[I0]);
384 });
385 });
386 });
387
388 static_for<0, MRepeat, 1>{}([&](auto m0) {
389 a_scale_thread_copy.Run(a_scale_grid_desc,
390 a_scale_grid_buf,
391 a_scale_thread_desc,
392 make_tuple(m0, I0),
393 a_scale_thread_buf);
394 a_scale_thread_copy.MoveSrcSliceWindow(
395 a_scale_grid_desc, a_scale_thread_copy_step.At(Number<0>{}));
396 });
397
398 if(num_loop_per_scale == 1)
399 {
400 a_scale_thread_copy.MoveSrcSliceWindow(
401 a_scale_grid_desc, a_scale_thread_copy_step.At(Number<2>{}));
402 }
403 else
404 {
405 a_scale_thread_copy.MoveSrcSliceWindow(
406 a_scale_grid_desc, a_scale_thread_copy_step.At(Number<1>{}));
407 }
408
409 b_scale_thread_copy.Run(b_scale_grid_desc,
410 b_scale_grid_buf,
411 b_scale_thread_desc,
412 make_tuple(I0, I0),
413 b_scale_thread_buf);
414
415 b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc,
416 b_scale_thread_copy_step);
417
419 a_blockwise_copy.RunWrite(
420 a_block_desc, a_block_buf, Number<(iprefetch + 1) % PrefetchStages>{});
421 b_blockwise_copy.RunWrite(
422 b_block_desc, b_block_buf, Number<(iprefetch + 1) % PrefetchStages>{});
423
424 a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, iprefetch);
425 b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf, iprefetch);
426
427 a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
428 b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
429 });
430
431 i += PrefetchStages;
432 } while(i < (num_loop - PrefetchStages));
433 }
434
435 // tail
436 auto LoopTailFunc = [&](auto tail_num) {
437 static_for<1, tail_num, 1>{}([&](auto iprefetch) {
439 static_for<0, KRepeat, 1>{}([&](auto k) {
440 static_for<0, MRepeat, 1>{}([&](auto m0) {
443 a_block_buf,
445 make_tuple(m0, I0, k, I0),
446 a_thread_buf);
447 static_for<0, NRepeat, 1>{}([&](auto n0) {
450 b_block_buf,
452 make_tuple(n0, I0, k, I0),
453 b_thread_buf);
454 });
455 });
456 });
457
458 static_for<0, MRepeat, 1>{}([&](auto m0) {
459 static_for<0, NRepeat, 1>{}([&](auto n0) {
460 c_thread_buf_per_scale.Clear();
461 static_for<0, KRepeat, 1>{}([&](auto k0) {
464
465 static_for<0, KPack, 1>{}([&](auto ik) {
466 a_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
467 a_thread_buf[Number<a_thread_desc_.CalculateOffset(
468 make_tuple(m0, I0, k0, ik))>{}];
469 b_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
470 b_thread_buf[Number<b_thread_desc_.CalculateOffset(
471 make_tuple(n0, I0, k0, ik))>{}];
472 });
473
474 using mfma_input_type =
476 xdlops_gemm.K1PerXdlops>::type;
477
478 xdlops_gemm.template Run<>(
479 a_thread_vec.template AsType<mfma_input_type>(),
480 b_thread_vec.template AsType<mfma_input_type>(),
481 c_thread_buf_per_scale.GetVectorTypeReference(I0));
482 });
483 static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) {
484 constexpr index_t c_offset =
485 c_thread_desc_.CalculateOffset(make_tuple(m0, n0, t));
486 c_thread_buf(Number<c_offset>{}) +=
487 c_thread_buf_per_scale[Number<t>{}] *
488 type_convert<AccDataType>(a_scale_thread_buf[m0]) *
489 type_convert<AccDataType>(b_scale_thread_buf[I0]);
490 });
491 });
492 });
493
494 static_for<0, MRepeat, 1>{}([&](auto m0) {
495 a_scale_thread_copy.Run(a_scale_grid_desc,
496 a_scale_grid_buf,
497 a_scale_thread_desc,
498 make_tuple(m0, I0),
499 a_scale_thread_buf);
500 a_scale_thread_copy.MoveSrcSliceWindow(
501 a_scale_grid_desc, a_scale_thread_copy_step.At(Number<0>{}));
502 });
503
504 if(num_loop_per_scale == 1)
505 {
506 a_scale_thread_copy.MoveSrcSliceWindow(
507 a_scale_grid_desc, a_scale_thread_copy_step.At(Number<2>{}));
508 }
509 else
510 {
511 a_scale_thread_copy.MoveSrcSliceWindow(
512 a_scale_grid_desc, a_scale_thread_copy_step.At(Number<1>{}));
513 }
514
515 b_scale_thread_copy.Run(b_scale_grid_desc,
516 b_scale_grid_buf,
517 b_scale_thread_desc,
518 make_tuple(I0, I0),
519 b_scale_thread_buf);
520
521 b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc, b_scale_thread_copy_step);
522
524 a_blockwise_copy.RunWrite(a_block_desc, a_block_buf, iprefetch);
525 b_blockwise_copy.RunWrite(b_block_desc, b_block_buf, iprefetch);
526 });
527
529 static_for<0, KRepeat, 1>{}([&](auto k) {
530 static_for<0, MRepeat, 1>{}([&](auto m0) {
533 a_block_buf,
535 make_tuple(m0, I0, k, I0),
536 a_thread_buf);
537 static_for<0, NRepeat, 1>{}([&](auto n0) {
540 b_block_buf,
542 make_tuple(n0, I0, k, I0),
543 b_thread_buf);
544 });
545 });
546 });
547
548 static_for<0, MRepeat, 1>{}([&](auto m0) {
549 static_for<0, NRepeat, 1>{}([&](auto n0) {
550 c_thread_buf_per_scale.Clear();
551 static_for<0, KRepeat, 1>{}([&](auto k0) {
554
555 static_for<0, KPack, 1>{}([&](auto ik) {
556 a_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
557 a_thread_buf[Number<a_thread_desc_.CalculateOffset(
558 make_tuple(m0, I0, k0, ik))>{}];
559 b_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
560 b_thread_buf[Number<b_thread_desc_.CalculateOffset(
561 make_tuple(n0, I0, k0, ik))>{}];
562 });
563
564 using mfma_input_type =
565 typename vector_type<ComputeDataTypeBuf, xdlops_gemm.K1PerXdlops>::type;
566
567 xdlops_gemm.template Run<>(
568 a_thread_vec.template AsType<mfma_input_type>(),
569 b_thread_vec.template AsType<mfma_input_type>(),
570 c_thread_buf_per_scale.GetVectorTypeReference(I0));
571 });
572 static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) {
573 constexpr index_t c_offset =
574 c_thread_desc_.CalculateOffset(make_tuple(m0, n0, t));
575 c_thread_buf(Number<c_offset>{}) +=
576 c_thread_buf_per_scale[Number<t>{}] *
577 type_convert<AccDataType>(a_scale_thread_buf[m0]) *
578 type_convert<AccDataType>(b_scale_thread_buf[I0]);
579 });
580 });
581 });
582 };
583
584 if constexpr(TailNum == TailNumber::One)
585 {
587 static_for<0, KRepeat, 1>{}([&](auto k) {
588 static_for<0, MRepeat, 1>{}([&](auto m0) {
591 a_block_buf,
593 make_tuple(m0, I0, k, I0),
594 a_thread_buf);
595 static_for<0, NRepeat, 1>{}([&](auto n0) {
598 b_block_buf,
600 make_tuple(n0, I0, k, I0),
601 b_thread_buf);
602 });
603 });
604 });
605
606 static_for<0, MRepeat, 1>{}([&](auto m0) {
607 static_for<0, NRepeat, 1>{}([&](auto n0) {
608 c_thread_buf_per_scale.Clear();
609 static_for<0, KRepeat, 1>{}([&](auto k0) {
612
613 static_for<0, KPack, 1>{}([&](auto ik) {
614 a_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
615 a_thread_buf[Number<a_thread_desc_.CalculateOffset(
616 make_tuple(m0, I0, k0, ik))>{}];
617 b_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
618 b_thread_buf[Number<b_thread_desc_.CalculateOffset(
619 make_tuple(n0, I0, k0, ik))>{}];
620 });
621
622 using mfma_input_type =
623 typename vector_type<ComputeDataTypeBuf, xdlops_gemm.K1PerXdlops>::type;
624
625 xdlops_gemm.template Run<>(
626 a_thread_vec.template AsType<mfma_input_type>(),
627 b_thread_vec.template AsType<mfma_input_type>(),
628 c_thread_buf_per_scale.GetVectorTypeReference(I0));
629 });
630 static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) {
631 constexpr index_t c_offset =
632 c_thread_desc_.CalculateOffset(make_tuple(m0, n0, t));
633 c_thread_buf(Number<c_offset>{}) +=
634 c_thread_buf_per_scale[Number<t>{}] *
635 type_convert<AccDataType>(a_scale_thread_buf[m0]) *
636 type_convert<AccDataType>(b_scale_thread_buf[I0]);
637 });
638 });
639 });
640 }
641 else if constexpr(TailNum == TailNumber::Two)
642 {
643 LoopTailFunc(Number<2>{});
644 }
645 else if constexpr(TailNum == TailNumber::Three)
646 {
647 LoopTailFunc(Number<3>{});
648 }
649 else if constexpr(TailNum == TailNumber::Four)
650 {
651 LoopTailFunc(Number<4>{});
652 }
653 else if constexpr(TailNum == TailNumber::Five)
654 {
655 LoopTailFunc(Number<5>{});
656 }
657 else if constexpr(TailNum == TailNumber::Six)
658 {
659 LoopTailFunc(Number<6>{});
660 }
661 else if constexpr(TailNum == TailNumber::Seven)
662 {
663 LoopTailFunc(Number<7>{});
664 }
665 else if constexpr(TailNum == TailNumber::Full)
666 {
667 LoopTailFunc(Number<PrefetchStages>{});
668 }
669 }
670
671 protected:
672 using Base::a_thread_copy_;
673 using Base::a_thread_desc_;
674 using Base::b_thread_copy_;
675 using Base::b_thread_desc_;
676 using Base::c_thread_desc_;
677};
678
679} // namespace ck
__host__ __device__ constexpr auto integer_divide_ceil(X x, Y y)
Definition utility/math.hpp:72
Definition ck.hpp:268
__host__ __device__ constexpr auto make_static_buffer(Number< N >)
Definition static_buffer.hpp:186
int32_t index_t
Definition ck.hpp:299
remove_cv_t< remove_reference_t< T > > remove_cvref_t
Definition type.hpp:297
integral_constant< index_t, N > Number
Definition number.hpp:12
TailNumber
Definition blkgemmpipe_scheduler.hpp:31
@ One
Definition blkgemmpipe_scheduler.hpp:37
@ Seven
Definition blkgemmpipe_scheduler.hpp:43
@ Four
Definition blkgemmpipe_scheduler.hpp:40
@ Two
Definition blkgemmpipe_scheduler.hpp:38
@ Full
Definition blkgemmpipe_scheduler.hpp:49
@ Three
Definition blkgemmpipe_scheduler.hpp:39
@ Five
Definition blkgemmpipe_scheduler.hpp:41
@ Six
Definition blkgemmpipe_scheduler.hpp:42
constexpr detail::ignore_t ignore
Definition utility/ignore.hpp:20
__host__ __device__ constexpr Y type_convert(X x)
Definition utility/type_convert.hpp:98
BlockGemmPipelineScheduler
Definition blkgemmpipe_scheduler.hpp:25
@ Intrawave
Definition blkgemmpipe_scheduler.hpp:26
__host__ __device__ constexpr auto make_tuple(Xs &&... xs)
Definition utility/tuple.hpp:211
__device__ void block_sync_lds()
Definition synchronization.hpp:16
__device__ void Run(const AGridDesc &a_grid_desc, const ABlockDesc &a_block_desc, ABlockTransfer &a_blockwise_copy, const AGridBuffer &a_grid_buf, ABlockBuffer &a_block_buf, const ABlockTransferStep &a_block_copy_step, const BGridDesc &b_grid_desc, const BBlockDesc &b_block_desc, BBlockTransfer &b_blockwise_copy, const BGridBuffer &b_grid_buf, BBlockBuffer &b_block_buf, const BBlockTransferStep &b_block_copy_step, CThreadBuffer &c_thread_buf, const AScaleGridDesc &a_scale_grid_desc, const AScaleThreadDesc &a_scale_thread_desc, AScaleThreadTransfer &a_scale_thread_copy, const AScaleGridBuffer &a_scale_grid_buf, const AScaleThreadTransferStep &a_scale_thread_copy_step, const BScaleGridDesc &b_scale_grid_desc, const BScaleThreadDesc &b_scale_thread_desc, BScaleThreadTransfer &b_scale_thread_copy, const BScaleGridBuffer &b_scale_grid_buf, const BScaleThreadTransferStep &b_scale_thread_copy_step, index_t num_loop, index_t num_loop_per_scale) const
Definition blockwise_gemm_pipeline_xdlops_v2_ab_scale.hpp:227
BlockwiseGemmXdlops_pipeline_base< BlockSize, ADataType, BDataType, ComputeDataType, AccDataType, ATileDesc, BTileDesc, AMmaTileDesc, BMmaTileDesc, ABlockTransferSrcScalarPerVector, BBlockTransferSrcScalarPerVector, MPerBlock, NPerBlock, KPerBlock, MPerXDL, NPerXDL, MRepeat, NRepeat, KPack, true > Base
Definition blockwise_gemm_pipeline_xdlops_v2_ab_scale.hpp:103
Definition blockwise_gemm_pipeline_xdlops_v2_ab_scale.hpp:37
Definition functional2.hpp:33
Definition dtype_vector.hpp:10