gemm_pipeline_ag_bg_cr_comp_v3.hpp Source File

gemm_pipeline_ag_bg_cr_comp_v3.hpp Source File

Composable Kernel: gemm_pipeline_ag_bg_cr_comp_v3.hpp Source File
gemm_pipeline_ag_bg_cr_comp_v3.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
6#include "ck_tile/core.hpp"
10
11namespace ck_tile {
12
13// A Tile Window: global memory
14// B Tile Window: global memory
15// C Distributed tensor: register
16template <typename Problem>
18{
19 static constexpr index_t PrefetchStages = 2;
20 static constexpr index_t PrefillStages = 1;
21 static constexpr index_t GlobalBufferNum = 1;
22 static constexpr bool UsePersistentKernel = Problem::Traits::UsePersistentKernel;
23
24 CK_TILE_HOST_DEVICE static constexpr bool BlockHasHotloop(index_t num_loop)
25 {
26 return num_loop > PrefetchStages;
27 }
28
30 {
31 if(BlockHasHotloop(num_loop))
32 {
33 return TailNumber::Full;
34 }
35 else
36 {
37 if(num_loop == 1)
38 {
39 return TailNumber::Odd;
40 }
41 else
42 {
43 return TailNumber::Even;
44 }
45 }
46 }
47
48 template <typename RunFunction>
49 CK_TILE_HOST_DEVICE static auto
50 TailHandler(const RunFunction& run_func, bool has_hot_loop, TailNumber tail_number)
51 {
52 // Handle all the valid cases.
53 if(has_hot_loop)
54 {
55 if(tail_number == TailNumber::Full)
56 {
57 return run_func(bool_constant<true>{},
59 }
60 }
61 else
62 {
63 if(tail_number == TailNumber::Odd)
64 {
65 return run_func(bool_constant<false>{},
67 }
68 else if(tail_number == TailNumber::Even)
69 {
70 return run_func(bool_constant<false>{},
72 }
73 }
74#if defined(__HIP_DEVICE_COMPILE__)
75 // This path should be unreachable in device code if tail_number is valid.
76 __builtin_unreachable();
77#else
78 // If execution reaches here, it's an invalid combination of arguments.
79 if(has_hot_loop)
80 {
81 throw std::logic_error("Invalid TailNumber: If has_hot_loop is true, tail_number must "
82 "be TailNumber::Full.");
83 }
84 else
85 {
86 throw std::logic_error("Invalid TailNumber: If has_hot_loop is false, tail_number must "
87 "be TailNumber::Odd or TailNumber::Even.");
88 }
89#endif
90 }
91};
92
93// Compute optimized pipeline
94// GlobalPrefetchStages: 2
95// LocalPreFillStages: 1
96// LocalPreFetchStages: 1
97// LocalSharedMemoryBuffer: 1
98template <typename Problem, typename Policy = UniversalGemmPipelineAgBgCrPolicy>
100{
103
107
111
115
118
121
123 using I0 = number<0>;
124 using I1 = number<1>;
125 using I2 = number<2>;
126
127 static constexpr index_t BlockSize = Problem::kBlockSize;
128
129 static constexpr index_t MPerBlock = BlockGemmShape::kM;
130 static constexpr index_t NPerBlock = BlockGemmShape::kN;
131 static constexpr index_t KPerBlock = BlockGemmShape::kK;
132
133 template <bool IsWave32Host = false>
134 static constexpr index_t GetVectorSizeA()
135 {
136 return Policy::template GetVectorSizeA<Problem, IsWave32Host>();
137 }
138 template <bool IsWave32Host = false>
139 static constexpr index_t GetVectorSizeB()
140 {
141 return Policy::template GetVectorSizeB<Problem, IsWave32Host>();
142 }
143 static constexpr index_t GetVectorSizeC() { return Policy::template GetVectorSizeC<Problem>(); }
144
145 static constexpr index_t APackedSize =
147 static constexpr index_t BPackedSize =
149
150 static constexpr index_t GetSmemPackA() { return Policy::template GetSmemPackA<Problem>(); }
151 static constexpr index_t GetSmemPackB() { return Policy::template GetSmemPackB<Problem>(); }
152
153 static constexpr bool kPadM = Problem::kPadM;
154 static constexpr bool kPadN = Problem::kPadN;
155 static constexpr bool kPadK = Problem::kPadK;
156
157 static constexpr bool DoubleSmemBuffer = Problem::DoubleSmemBuffer;
158 static constexpr index_t NumWaveGroups = Problem::NumWaveGroups;
159 static constexpr index_t Preshuffle = Problem::Preshuffle;
160
161 static constexpr bool HasHotLoop =
162 Problem::HasHotLoop; // Base::BlockHasHotloop(Problem::num_loop);
163 static constexpr auto TailNum =
164 Problem::TailNum; // Base::GetBlockLoopTailNum(Problem::num_loop);
165 static constexpr auto Scheduler = Problem::Scheduler;
166
169
172
173 [[nodiscard]] CK_TILE_HOST static const std::string GetName()
174 {
175 // clang-format off
176 constexpr index_t WaveNumM = BlockGemmShape::BlockWarps::at(I0{});
177 constexpr index_t WaveNumN = BlockGemmShape::BlockWarps::at(I1{});
178 return concat('_', "pipeline_AgBgCrCompV3",
180 concat('x', WaveNumM, WaveNumN),
181 concat('x', kPadM, kPadN, kPadK));
182 // clang-format on
183 }
184
186 {
187 return Policy::template GetSmemSize<Problem>();
188 }
189
190 CK_TILE_HOST static std::string Print()
191 {
192 constexpr index_t MPerXDL = BlockGemm::WarpGemm::kM;
193 constexpr index_t NPerXDL = BlockGemm::WarpGemm::kN;
194 constexpr index_t KPerXDL = BlockGemm::WarpGemm::WarpGemmAttribute::Impl::kK;
195
196 constexpr index_t WaveSize = get_warp_size();
197 constexpr index_t WaveNumM = BlockGemmShape::BlockWarps::at(I0{});
198 constexpr index_t WaveNumN = BlockGemmShape::BlockWarps::at(I1{});
199
200 // Below should be equal to AK1|BK1
201 constexpr index_t A_LDS_Read_Width = GetSmemPackA();
202 constexpr index_t B_LDS_Read_Width = GetSmemPackB();
203
204 constexpr index_t A_LDS_Write_Width = GetSmemPackA();
205 constexpr index_t B_LDS_Write_Width = GetSmemPackB();
206
207 constexpr index_t A_Buffer_Load_Inst_Num =
209 constexpr index_t B_Buffer_Load_Inst_Num =
211
212 constexpr index_t A_LDS_Write_Inst_Num =
213 MPerBlock * KPerBlock / (BlockSize * A_LDS_Write_Width);
214 constexpr index_t B_LDS_Write_Inst_Num =
215 NPerBlock * KPerBlock / (BlockSize * B_LDS_Write_Width);
216
217 constexpr index_t A_LDS_Read_Inst_Num =
218 WaveNumN * MPerBlock * KPerBlock / (BlockSize * A_LDS_Read_Width);
219 constexpr index_t B_LDS_Read_Inst_Num =
220 WaveNumM * NPerBlock * KPerBlock / (BlockSize * B_LDS_Read_Width);
221
222 constexpr index_t C_MFMA_Inst_Num = MPerBlock * NPerBlock * KPerBlock /
223 (BlockSize / WaveSize) / (MPerXDL * NPerXDL * KPerXDL);
224
225 auto str = std::stringstream{};
226
227 str << "A/B vector size: " << GetVectorSizeA() << ", " << GetVectorSizeB() << "\n"
228 << "A/B LDS read/write width: " << A_LDS_Read_Width << ", " << B_LDS_Read_Width << "\n"
229 << "A/B buffer load inst: " << A_Buffer_Load_Inst_Num << ", " << B_Buffer_Load_Inst_Num
230 << "\n"
231 << "A/B LDS write inst: " << A_LDS_Write_Inst_Num << ", " << B_LDS_Write_Inst_Num
232 << "\n"
233 << "A/B LDS read inst: " << A_LDS_Read_Inst_Num << ", " << B_LDS_Read_Inst_Num << "\n"
234 << "C MFMA inst: " << C_MFMA_Inst_Num << "\n"
235 << "KPack: " << BlockGemm::Traits::KPack << "\n"
236 << "PrefetchStages: " << PrefetchStages << "\n";
237 return str.str();
238 }
239
240 template <GemmPipelineScheduler Scheduler>
242 {
243 };
244
245 template <>
247 {
249
250 CK_TILE_DEVICE static constexpr auto HotLoopScheduler()
251 {
252 constexpr index_t MPerXDL = BlockGemm::WarpGemm::kM;
253 constexpr index_t NPerXDL = BlockGemm::WarpGemm::kN;
254 constexpr index_t KPerXDL = BlockGemm::WarpGemm::WarpGemmAttribute::Impl::kK;
255
256 constexpr index_t WaveSize = get_warp_size();
257 constexpr index_t WaveNumM = BlockGemmShape::BlockWarps::at(I0{});
258 constexpr index_t WaveNumN = BlockGemmShape::BlockWarps::at(I1{});
259
260 // Below should be equal to AK1|BK1
261 constexpr index_t A_LDS_Read_Width = GetSmemPackA();
262 constexpr index_t B_LDS_Read_Width = GetSmemPackB();
263
264 constexpr index_t A_LDS_Write_Width = GetSmemPackA();
265 constexpr index_t B_LDS_Write_Width = GetSmemPackB();
266
267 constexpr index_t A_Buffer_Load_Inst_Num =
269 constexpr index_t B_Buffer_Load_Inst_Num =
271
272 constexpr index_t A_LDS_Write_Inst_Num =
273 MPerBlock * KPerBlock / (BlockSize * A_LDS_Write_Width);
274 constexpr index_t B_LDS_Write_Inst_Num =
275 NPerBlock * KPerBlock / (BlockSize * B_LDS_Write_Width);
276
277 constexpr index_t A_LDS_Read_Inst_Num =
278 WaveNumN * MPerBlock * KPerBlock / (BlockSize * A_LDS_Read_Width);
279 constexpr index_t B_LDS_Read_Inst_Num =
280 WaveNumM * NPerBlock * KPerBlock / (BlockSize * B_LDS_Read_Width);
281
282 constexpr index_t C_MFMA_Inst_Num = MPerBlock * NPerBlock * KPerBlock /
283 (BlockSize / WaveSize) /
284 (MPerXDL * NPerXDL * KPerXDL);
285
286 // A/B split schedule
287 // compiler is likely to use ds_read2 when instruction width smaller than 16bytes
288 constexpr auto num_ds_read_inst_a =
289 A_LDS_Read_Width * sizeof(ADataType) / APackedSize == 16 ? A_LDS_Read_Inst_Num
290 : A_LDS_Read_Inst_Num / 2;
291 constexpr auto num_ds_read_inst_b =
292 B_LDS_Read_Width * sizeof(BDataType) / BPackedSize == 16 ? B_LDS_Read_Inst_Num
293 : B_LDS_Read_Inst_Num / 2;
294
295 constexpr auto num_ds_write_inst_a = A_LDS_Write_Inst_Num;
296 constexpr auto num_ds_write_inst_b = B_LDS_Write_Inst_Num;
297
298 constexpr auto num_buffer_load_inst_a = A_Buffer_Load_Inst_Num;
299 constexpr auto num_buffer_load_inst_b = B_Buffer_Load_Inst_Num;
300
301 constexpr auto num_mfma_inst = C_MFMA_Inst_Num;
302
303 constexpr auto mfma_cycle = NPerXDL == 16 ? 16 : 32;
304 constexpr auto ds_read_a_issue_cycle =
305 A_LDS_Read_Width * sizeof(ADataType) / APackedSize == 16 ? 8 : 4;
306 constexpr auto ds_read_b_issue_cycle =
307 B_LDS_Read_Width * sizeof(BDataType) / BPackedSize == 16 ? 8 : 4;
308 constexpr auto ds_read_a_mfma_rate =
309 (mfma_cycle - 4 + 2 * ds_read_a_issue_cycle - 1) / (2 * ds_read_a_issue_cycle);
310 constexpr auto ds_read_b_mfma_rate =
311 (mfma_cycle - 4 + 2 * ds_read_b_issue_cycle - 1) / (2 * ds_read_b_issue_cycle);
312
313 constexpr auto num_dsread_a_mfma =
314 (num_ds_read_inst_a + ds_read_a_mfma_rate - 1) / ds_read_a_mfma_rate;
315 constexpr auto num_dsread_b_mfma =
316 (num_ds_read_inst_b + ds_read_b_mfma_rate - 1) / ds_read_b_mfma_rate;
317
318 // stage 1
319 // Separate this part?
320 // constexpr auto num_mfma_per_ds_read = sizeof(ComputeDataType) / sizeof(ADataType) >
321 // sizeof(ComputeDataType) /
322 // sizeof(BDataType)
323 // ? sizeof(ComputeDataType) /
324 // sizeof(ADataType) : sizeof(ComputeDataType)
325 // / sizeof(BDataType);
326 constexpr auto num_mfma_stage1 =
327 num_mfma_inst - (num_dsread_a_mfma + num_dsread_b_mfma);
328 constexpr auto num_mfma_per_issue =
329 num_mfma_stage1 / (num_buffer_load_inst_a + num_buffer_load_inst_b);
330 constexpr auto num_dswrite_per_issue_a = num_ds_write_inst_a / num_buffer_load_inst_a;
331 constexpr auto num_dswrite_per_issue_b = num_ds_write_inst_b / num_buffer_load_inst_b;
332
334 ignore = i;
335 static_for<0, num_dswrite_per_issue_a, 1>{}([&](auto idswrite) {
336 ignore = idswrite;
337 __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write
338 __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
339 });
340 __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
341 __builtin_amdgcn_sched_group_barrier(
342 0x008, num_mfma_per_issue - num_dswrite_per_issue_a, 0); // MFMA
343 });
345 ignore = i;
346 static_for<0, num_dswrite_per_issue_b, 1>{}([&](auto idswrite) {
347 ignore = idswrite;
348 __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write
349 __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
350 });
351 __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
352 __builtin_amdgcn_sched_group_barrier(
353 0x008, num_mfma_per_issue - num_dswrite_per_issue_b, 0); // MFMA
354 });
355
356 // stage 2
358 if constexpr((num_ds_read_inst_a - (i + 1) * ds_read_a_mfma_rate) >=
359 ds_read_a_mfma_rate)
360 {
361 __builtin_amdgcn_sched_group_barrier(0x100, ds_read_a_mfma_rate, 0); // DS read
362 }
363 else
364 {
365 __builtin_amdgcn_sched_group_barrier(
366 0x100,
367 num_ds_read_inst_a - (num_dsread_a_mfma - 1) * ds_read_a_mfma_rate,
368 0); // DS read
369 }
370 __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
371 });
372
374 if constexpr((num_ds_read_inst_b - (i + 1) * ds_read_b_mfma_rate) >=
375 ds_read_b_mfma_rate)
376 {
377 __builtin_amdgcn_sched_group_barrier(0x100, ds_read_b_mfma_rate, 0); // DS read
378 }
379 else
380 {
381 __builtin_amdgcn_sched_group_barrier(
382 0x100,
383 num_ds_read_inst_b - (num_dsread_b_mfma - 1) * ds_read_b_mfma_rate,
384 0); // DS read
385 }
386 __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
387 });
388 }
389
390 template <bool HasHotLoop,
392 typename AsDramBlockWindowTmp,
393 typename BsDramBlockWindowTmp,
394 typename AElementFunction,
395 typename BElementFunction,
396 typename std::enable_if_t<is_detected<is_tuple, AsDramBlockWindowTmp>::value &&
398 bool>* = nullptr>
399 CK_TILE_DEVICE auto operator()(const AsDramBlockWindowTmp& a_dram_block_window_tmp,
400 const AElementFunction& a_element_func,
401 const BsDramBlockWindowTmp& b_dram_block_window_tmp,
402 const BElementFunction& b_element_func,
403 index_t num_loop,
404 void* p_smem) const
405 {
406 using ADramBlockWindowTmp =
407 remove_cvref_t<std::tuple_element_t<number<0>{}, AsDramBlockWindowTmp>>;
408 using BDramBlockWindowTmp =
409 remove_cvref_t<std::tuple_element_t<number<0>{}, BsDramBlockWindowTmp>>;
410
411 static_assert(
412 std::is_same_v<ADataType, remove_cvref_t<typename ADramBlockWindowTmp::DataType>> &&
413 std::is_same_v<BDataType,
415 "A/B Dram block window should have the same data type as appropriate "
416 "([A|B]DataType) defined in Problem definition!");
417
418 constexpr bool is_a_col_major =
419 std::is_same_v<ALayout, tensor_layout::gemm::ColumnMajor>;
420 constexpr bool is_b_row_major = std::is_same_v<BLayout, tensor_layout::gemm::RowMajor>;
421
422 static_assert(is_a_col_major
423 ? (KPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
424 MPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I1{}])
425 : (MPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
426 KPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I1{}]),
427 "A block window has incorrect lengths for defined ALayout!");
428 static_assert(is_b_row_major
429 ? (KPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
430 NPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I1{}])
431 : (NPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
432 KPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I1{}]),
433 "B block window has incorrect lengths for defined BLayout!");
434
435 // ------------------------------------------------------------------------------------
436 // Definitions of all needed tiles
437
438 // A/B tiles in LDS
439 auto&& [a_lds_block, b_lds_block] = Base::GetABLdsTensorViews(p_smem);
440
441 // Tile distribution for load from lds
442 constexpr auto a_lds_load_tile_distr =
443 make_static_tile_distribution(BlockGemm::MakeABlockDistributionEncode());
444 constexpr auto b_lds_load_tile_distr =
445 make_static_tile_distribution(BlockGemm::MakeBBlockDistributionEncode());
446
447 // A DRAM tile window for load
448 // A LDS tile window for store
449 // A LDS tile for block GEMM
450 auto&& [a_copy_dram_window, a_copy_lds_window, a_lds_gemm_window] =
451 Base::GetAWindows(a_dram_block_window_tmp, a_lds_block, a_lds_load_tile_distr);
452
453 // B DRAM tile window for load
454 // B LDS tile window for store
455 // B LDS tile for block GEMM
456 auto&& [b_copy_dram_window, b_copy_lds_window, b_lds_gemm_window] =
457 Base::GetBWindows(b_dram_block_window_tmp, b_lds_block, b_lds_load_tile_distr);
458
459 // Block GEMM
460 auto block_gemm = BlockGemm();
461 auto c_block_tile = block_gemm.MakeCBlockTile();
462
463 using ADramTileWindowStep = typename ADramBlockWindowTmp::BottomTensorIndex;
464 using BDramTileWindowStep = typename BDramBlockWindowTmp::BottomTensorIndex;
465
466 constexpr ADramTileWindowStep a_dram_tile_window_step =
467 is_a_col_major ? make_array(KPerBlock, 0) : make_array(0, KPerBlock);
468 constexpr BDramTileWindowStep b_dram_tile_window_step =
469 is_b_row_major ? make_array(KPerBlock, 0) : make_array(0, KPerBlock);
470
471 // -----------------------------------------------------------------------------------------
472 // Gemm pipeline start
473 // initialize C
474 tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile);
475
476 // Load tile — during value loading, an elementwise function is executed for each A0,
477 // A1, … AN. The values A0, A1, … AN are read by the same thread.
478 auto elementwise_As_res =
479 load_tile_with_elementwise(a_copy_dram_window, a_element_func);
480
481 // Move each A — the enhanced function move_tile_window is executed, which takes a tuple
482 // as input.
483 move_tile_window(a_copy_dram_window, a_dram_tile_window_step);
484
485 // Load tile — during value loading, an elementwise function is executed for each B0,
486 // B1, … BN. The values B0, B1, … BN are read by the same thread.
487 auto elementwise_Bs_res =
488 load_tile_with_elementwise(b_copy_dram_window, b_element_func);
489
490 // Move each B — the enhanced function move_tile_window is executed, which takes a tuple
491 // as input.
492 move_tile_window(b_copy_dram_window, b_dram_tile_window_step);
493
494 // LDS write 0
495 if constexpr(is_a_col_major && !is_a_load_tr_v())
496 {
498 Policy::template MakeShuffledARegTileDistribution<Problem>());
499 transpose_tile2d(a_shuffle_tmp, elementwise_As_res);
500 Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp);
501 }
502 else
503 {
504 Base::LocalPrefill(a_copy_lds_window, elementwise_As_res);
505 }
506 if constexpr(is_b_row_major && !is_b_load_tr_v())
507 {
509 Policy::template MakeShuffledBRegTileDistribution<Problem>());
510 transpose_tile2d(b_shuffle_tmp, elementwise_Bs_res);
511 Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp);
512 }
513 else
514 {
515 Base::LocalPrefill(b_copy_lds_window, elementwise_Bs_res);
516 }
517
518 // global read 1
519
520 elementwise_As_res = load_tile_with_elementwise(a_copy_dram_window, a_element_func);
521 move_tile_window(a_copy_dram_window, a_dram_tile_window_step);
522
523 elementwise_Bs_res = load_tile_with_elementwise(b_copy_dram_window, b_element_func);
524 move_tile_window(b_copy_dram_window, b_dram_tile_window_step);
525
527 block_gemm.LocalPrefetch(
528 a_lds_gemm_window, b_lds_gemm_window, is_a_load_tr_v, is_b_load_tr_v);
529
530 __builtin_amdgcn_sched_barrier(0);
531
532 // main body
533 if constexpr(HasHotLoop)
534 {
535 index_t i = 0;
536 do
537 {
539
540 if constexpr(is_a_col_major && !is_a_load_tr_v())
541 {
543 Policy::template MakeShuffledARegTileDistribution<Problem>());
544 transpose_tile2d(a_shuffle_tmp, elementwise_As_res);
545 Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp);
546 }
547 else
548 {
549 Base::LocalPrefill(a_copy_lds_window, elementwise_As_res);
550 }
551 if constexpr(is_b_row_major && !is_b_load_tr_v())
552 {
554 Policy::template MakeShuffledBRegTileDistribution<Problem>());
555 transpose_tile2d(b_shuffle_tmp, elementwise_Bs_res);
556 Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp);
557 }
558 else
559 {
560 Base::LocalPrefill(b_copy_lds_window, elementwise_Bs_res);
561 }
562
563 elementwise_As_res =
564 load_tile_with_elementwise(a_copy_dram_window, a_element_func);
565 move_tile_window(a_copy_dram_window, a_dram_tile_window_step);
566
567 elementwise_Bs_res =
568 load_tile_with_elementwise(b_copy_dram_window, b_element_func);
569 move_tile_window(b_copy_dram_window, b_dram_tile_window_step);
570
571 block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window);
572
574
575 block_gemm.LocalPrefetch(
576 a_lds_gemm_window, b_lds_gemm_window, is_a_load_tr_v, is_b_load_tr_v);
578 __builtin_amdgcn_sched_barrier(0);
579
580 i += 1;
581 } while(i < (num_loop - 1));
582 }
583 // tail
584 if constexpr((TailNum == TailNumber::Full) || (TailNum == TailNumber::Odd))
585 {
586 // Leak last MFMA block to epilogue region, cover the potential lds-shuffle
587 // latency
588 block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window);
589 }
590 else
591 {
592 block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window);
594
595 if constexpr(is_a_col_major && !is_a_load_tr_v())
596 {
598 Policy::template MakeShuffledARegTileDistribution<Problem>());
599 transpose_tile2d(a_shuffle_tmp, elementwise_As_res);
600 Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp);
601 }
602 else
603 {
604 Base::LocalPrefill(a_copy_lds_window, elementwise_As_res);
605 }
606 if constexpr(is_b_row_major && !is_b_load_tr_v())
607 {
609 Policy::template MakeShuffledBRegTileDistribution<Problem>());
610 transpose_tile2d(b_shuffle_tmp, elementwise_Bs_res);
611 Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp);
612 }
613 else
614 {
615 Base::LocalPrefill(b_copy_lds_window, elementwise_Bs_res);
616 }
618 block_gemm.LocalPrefetch(
619 a_lds_gemm_window, b_lds_gemm_window, is_a_load_tr_v, is_b_load_tr_v);
620 block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window);
621 }
622 // __builtin_amdgcn_sched_barrier(0);
623 return c_block_tile;
624 }
625 };
626
627 template <typename AsDramBlockWindowTmp,
628 typename BsDramBlockWindowTmp,
629 typename AElementFunction,
630 typename BElementFunction,
631 typename std::enable_if_t<is_detected<is_tuple, AsDramBlockWindowTmp>::value &&
633 bool>* = nullptr>
634 CK_TILE_DEVICE auto operator()(const AsDramBlockWindowTmp& a_dram_block_window_tmp,
635 const AElementFunction& a_element_func,
636 const BsDramBlockWindowTmp& b_dram_block_window_tmp,
637 const BElementFunction& b_element_func,
638 index_t num_loop,
639 void* p_smem) const
640 {
641 return PipelineImpl<Scheduler>{}.template operator()<HasHotLoop, TailNum>(
642 a_dram_block_window_tmp,
643 a_element_func,
644 b_dram_block_window_tmp,
645 b_element_func,
646 num_loop,
647 p_smem);
648 }
649
656 template <typename AsDramBlockWindowTmp,
657 typename BsDramBlockWindowTmp,
658 typename std::enable_if_t<is_detected<is_tuple, AsDramBlockWindowTmp>::value &&
660 bool>* = nullptr>
661 CK_TILE_DEVICE auto operator()(const AsDramBlockWindowTmp& a_dram_block_window_tmp,
662 const BsDramBlockWindowTmp& b_dram_block_window_tmp,
663 index_t num_loop,
664 bool has_hot_loop,
665 TailNumber tail_number,
666 void* p_smem) const
667 {
668 const auto RunPipeline = [&](auto hot_loop_, auto tail_num_) {
669 constexpr bool hot_loop = hot_loop_.value;
670 constexpr auto tail_num = tail_num_.value;
671 constexpr auto PassThrough = [](auto& e, const auto& x) { e = x; };
672 return PipelineImpl<Scheduler>{}.template operator()<hot_loop, tail_num>(
673 a_dram_block_window_tmp,
675 b_dram_block_window_tmp,
677 num_loop,
678 p_smem);
679 };
680 return Base::TailHandler(RunPipeline, has_hot_loop, tail_number);
681 }
682
690 template <typename AsDramBlockWindowTmp,
691 typename BsDramBlockWindowTmp,
692 typename std::enable_if_t<is_detected<is_tuple, AsDramBlockWindowTmp>::value &&
694 bool>* = nullptr>
695 CK_TILE_DEVICE auto operator()(const AsDramBlockWindowTmp& a_dram_block_window_tmp,
696 const BsDramBlockWindowTmp& b_dram_block_window_tmp,
697 index_t num_loop,
698 void* p_smem) const
699 {
700 return PipelineImpl<Scheduler>{}.template operator()<HasHotLoop, TailNum>(
701 a_dram_block_window_tmp,
702 [](auto& e, const ADataType& a) { e = a; },
703 b_dram_block_window_tmp,
704 [](auto& e, const BDataType& b) { e = b; },
705 num_loop,
706 p_smem);
707 }
708
709 template <typename AsDramBlockWindowTmp,
710 typename BsDramBlockWindowTmp,
711 typename AElementFunction,
712 typename BElementFunction,
713 typename std::enable_if_t<!is_detected<is_tuple, AsDramBlockWindowTmp>::value &&
715 bool>* = nullptr>
716 CK_TILE_DEVICE auto operator()(const AsDramBlockWindowTmp& a_dram_block_window_tmp,
717 const AElementFunction& a_element_func,
718 const BsDramBlockWindowTmp& b_dram_block_window_tmp,
719 const BElementFunction& b_element_func,
720 index_t num_loop,
721 void* p_smem) const
722 {
723 return operator()(ck_tile::make_tuple(a_dram_block_window_tmp),
724 a_element_func,
725 ck_tile::make_tuple(b_dram_block_window_tmp),
726 b_element_func,
727 num_loop,
728 p_smem);
729 }
730
738 template <typename ADramBlockWindowTmp,
739 typename BDramBlockWindowTmp,
740 typename std::enable_if_t<!is_detected<is_tuple, ADramBlockWindowTmp>::value &&
742 bool>* = nullptr>
743 CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
744 const BDramBlockWindowTmp& b_dram_block_window_tmp,
745 index_t num_loop,
746 bool has_hot_loop,
747 TailNumber tail_number,
748 void* p_smem) const
749 {
750 return operator()(ck_tile::make_tuple(a_dram_block_window_tmp),
751 ck_tile::make_tuple(b_dram_block_window_tmp),
752 num_loop,
753 has_hot_loop,
754 tail_number,
755 p_smem);
756 }
757
766 template <typename ADramBlockWindowTmp,
767 typename BDramBlockWindowTmp,
768 typename std::enable_if_t<!is_detected<is_tuple, ADramBlockWindowTmp>::value &&
770 bool>* = nullptr>
771 CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
772 const BDramBlockWindowTmp& b_dram_block_window_tmp,
773 index_t num_loop,
774 void* p_smem) const
775 {
776 return operator()(ck_tile::make_tuple(a_dram_block_window_tmp),
777 ck_tile::make_tuple(b_dram_block_window_tmp),
778 num_loop,
779 p_smem);
780 }
781};
782
783} // namespace ck_tile
#define CK_TILE_DEVICE
Definition config.hpp:41
#define CK_TILE_HOST
Definition config.hpp:40
#define CK_TILE_HOST_DEVICE
Definition config.hpp:42
Definition tile/core/algorithm/cluster_descriptor.hpp:13
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition type_traits.hpp:21
CK_TILE_HOST_DEVICE constexpr index_t get_warp_size()
Definition arch.hpp:63
TailNumber
Definition gemm_pipeline_ag_bg_cr_scheduler.hpp:21
@ Even
Definition gemm_pipeline_ag_bg_cr_scheduler.hpp:24
@ Odd
Definition gemm_pipeline_ag_bg_cr_scheduler.hpp:23
@ Full
Definition gemm_pipeline_ag_bg_cr_scheduler.hpp:39
CK_TILE_DEVICE auto load_tile_with_elementwise(const TileWindow_ &tile_window, ElementWise_ elementwise, number< i_access >={}, bool_constant< oob_conditional_check >={})
Load tile with elementwise function.
Definition load_tile.hpp:41
constant< b > bool_constant
Definition tile/core/numeric/integral_constant.hpp:43
typename detail::detector< nonesuch, void, Op, Args... >::value_t is_detected
Definition type_traits.hpp:67
CK_TILE_DEVICE void tile_elementwise_inout(const InOutElementFunc &inout_element_func, InOutDstrTensors &... inout_dstr_tensors)
Definition tile_elementwise.hpp:23
ck_tile::element_wise::PassThrough PassThrough
Definition grouped_convolution_utils.hpp:47
CK_TILE_DEVICE void block_sync_lds()
Definition arch.hpp:282
auto concat(const Ts &... xs) -> std::enable_if_t<!AllConvertibleToStringView< Ts... >, std::string >
Definition concat.hpp:43
CK_TILE_DEVICE void transpose_tile2d(OutTensor &out, const InTensor &in)
Definition transpose_tile.hpp:195
CK_TILE_HOST_DEVICE constexpr auto make_static_distributed_tensor(const StaticTileDistribution &)
Definition static_distributed_tensor.hpp:142
constant< v > number
Definition tile/core/numeric/integral_constant.hpp:37
constexpr detail::ignore_t ignore
Definition tile/core/utility/ignore.hpp:20
CK_TILE_DEVICE void move_tile_window(null_tile_window< WindowLengths > &, const typename null_tile_window< WindowLengths >::BottomTensorIndex &)
Definition null_tile_window.hpp:95
int32_t index_t
Definition integer.hpp:9
CK_TILE_HOST_DEVICE constexpr auto make_static_tile_distribution(StaticTileDistributionEncoding_)
Definition tile_distribution.hpp:480
GemmPipelineScheduler
Definition gemm_pipeline_ag_bg_cr_scheduler.hpp:14
@ Intrawave
Definition gemm_pipeline_ag_bg_cr_scheduler.hpp:16
CK_TILE_HOST_DEVICE constexpr details::return_type< D, Ts... > make_array(Ts &&... ts)
Definition tile/core/container/array.hpp:242
CK_TILE_HOST_DEVICE constexpr auto make_tuple(Xs &&... xs)
Definition tile/core/container/tuple.hpp:360
const GenericPointer< typename T::ValueType > T2 T::AllocatorType & a
Definition pointer.h:1517
Definition gemm_pipeline_ag_bg_cr_comp_v3.hpp:18
static CK_TILE_HOST_DEVICE constexpr bool BlockHasHotloop(index_t num_loop)
Definition gemm_pipeline_ag_bg_cr_comp_v3.hpp:24
static CK_TILE_HOST_DEVICE constexpr TailNumber GetBlockLoopTailNum(index_t num_loop)
Definition gemm_pipeline_ag_bg_cr_comp_v3.hpp:29
static constexpr index_t PrefillStages
Definition gemm_pipeline_ag_bg_cr_comp_v3.hpp:20
static constexpr index_t PrefetchStages
Definition gemm_pipeline_ag_bg_cr_comp_v3.hpp:19
static constexpr index_t GlobalBufferNum
Definition gemm_pipeline_ag_bg_cr_comp_v3.hpp:21
static constexpr bool UsePersistentKernel
Definition gemm_pipeline_ag_bg_cr_comp_v3.hpp:22
static CK_TILE_HOST_DEVICE auto TailHandler(const RunFunction &run_func, bool has_hot_loop, TailNumber tail_number)
Definition gemm_pipeline_ag_bg_cr_comp_v3.hpp:50
static CK_TILE_DEVICE constexpr auto HotLoopScheduler()
Definition gemm_pipeline_ag_bg_cr_comp_v3.hpp:250
PipelineImplBase Base
Definition gemm_pipeline_ag_bg_cr_comp_v3.hpp:248
CK_TILE_DEVICE auto operator()(const AsDramBlockWindowTmp &a_dram_block_window_tmp, const AElementFunction &a_element_func, const BsDramBlockWindowTmp &b_dram_block_window_tmp, const BElementFunction &b_element_func, index_t num_loop, void *p_smem) const
Definition gemm_pipeline_ag_bg_cr_comp_v3.hpp:399
Definition gemm_pipeline_ag_bg_cr_comp_v3.hpp:242
Definition gemm_pipeline_ag_bg_cr_comp_v3.hpp:100
static constexpr index_t GetVectorSizeA()
Definition gemm_pipeline_ag_bg_cr_comp_v3.hpp:134
static CK_TILE_HOST std::string Print()
Definition gemm_pipeline_ag_bg_cr_comp_v3.hpp:190
static constexpr index_t Preshuffle
Definition gemm_pipeline_ag_bg_cr_comp_v3.hpp:159
CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp &a_dram_block_window_tmp, const BDramBlockWindowTmp &b_dram_block_window_tmp, index_t num_loop, void *p_smem) const
Single-input operator(): runs the pipeline using the compile-time known hot loop and tail number (HasHotLoop, TailNum).
Definition gemm_pipeline_ag_bg_cr_comp_v3.hpp:771
CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp &a_dram_block_window_tmp, const BDramBlockWindowTmp &b_dram_block_window_tmp, index_t num_loop, bool has_hot_loop, TailNumber tail_number, void *p_smem) const
Single-input operator(): runs the pipeline by wrapping it with the tail handler.
Definition gemm_pipeline_ag_bg_cr_comp_v3.hpp:743
static constexpr bool DoubleSmemBuffer
Definition gemm_pipeline_ag_bg_cr_comp_v3.hpp:157
remove_cvref_t< std::tuple_element_t< 0, BsLayout > > BLayout
Definition gemm_pipeline_ag_bg_cr_comp_v3.hpp:117
static constexpr bool HasHotLoop
Definition gemm_pipeline_ag_bg_cr_comp_v3.hpp:161
remove_cvref_t< typename Problem::AElementWise > AElementWise
Definition gemm_pipeline_ag_bg_cr_comp_v3.hpp:108
static constexpr index_t NPerBlock
Definition gemm_pipeline_ag_bg_cr_comp_v3.hpp:130
static constexpr index_t BlockSize
Definition gemm_pipeline_ag_bg_cr_comp_v3.hpp:127
remove_cvref_t< typename Problem::BsLayoutTuple > BsLayout
Definition gemm_pipeline_ag_bg_cr_comp_v3.hpp:113
static CK_TILE_HOST_DEVICE constexpr index_t GetSmemSize()
Definition gemm_pipeline_ag_bg_cr_comp_v3.hpp:185
remove_cvref_t< typename Problem::AsDataTypeTuple > AsDataType
Definition gemm_pipeline_ag_bg_cr_comp_v3.hpp:104
static constexpr index_t NumWaveGroups
Definition gemm_pipeline_ag_bg_cr_comp_v3.hpp:158
static constexpr index_t GetVectorSizeC()
Definition gemm_pipeline_ag_bg_cr_comp_v3.hpp:143
static constexpr index_t KPerBlock
Definition gemm_pipeline_ag_bg_cr_comp_v3.hpp:131
number< 0 > I0
Definition gemm_pipeline_ag_bg_cr_comp_v3.hpp:123
remove_cvref_t< typename Problem::BsDataTypeTuple > BsDataType
Definition gemm_pipeline_ag_bg_cr_comp_v3.hpp:105
static constexpr auto is_a_load_tr_v
Definition gemm_pipeline_ag_bg_cr_comp_v3.hpp:167
static constexpr bool kPadN
Definition gemm_pipeline_ag_bg_cr_comp_v3.hpp:154
CK_TILE_DEVICE auto operator()(const AsDramBlockWindowTmp &a_dram_block_window_tmp, const BsDramBlockWindowTmp &b_dram_block_window_tmp, index_t num_loop, bool has_hot_loop, TailNumber tail_number, void *p_smem) const
This function runs the pipeline by wrapping it with the tail handler.
Definition gemm_pipeline_ag_bg_cr_comp_v3.hpp:661
number< 1 > I1
Definition gemm_pipeline_ag_bg_cr_comp_v3.hpp:124
number< 2 > I2
Definition gemm_pipeline_ag_bg_cr_comp_v3.hpp:125
remove_cvref_t< typename Problem::AsLayoutTuple > AsLayout
Definition gemm_pipeline_ag_bg_cr_comp_v3.hpp:112
remove_cvref_t< typename Problem::BlockGemmShape > BlockGemmShape
Definition gemm_pipeline_ag_bg_cr_comp_v3.hpp:110
BaseGemmPipelineAgBgCrCompV3< Problem > Base
Definition gemm_pipeline_ag_bg_cr_comp_v3.hpp:101
remove_cvref_t< std::tuple_element_t< 0, BsDataType > > BDataType
Definition gemm_pipeline_ag_bg_cr_comp_v3.hpp:120
static constexpr index_t GetSmemPackB()
Definition gemm_pipeline_ag_bg_cr_comp_v3.hpp:151
static CK_TILE_HOST const std::string GetName()
Definition gemm_pipeline_ag_bg_cr_comp_v3.hpp:173
CK_TILE_DEVICE auto operator()(const AsDramBlockWindowTmp &a_dram_block_window_tmp, const AElementFunction &a_element_func, const BsDramBlockWindowTmp &b_dram_block_window_tmp, const BElementFunction &b_element_func, index_t num_loop, void *p_smem) const
Definition gemm_pipeline_ag_bg_cr_comp_v3.hpp:634
static constexpr index_t PrefetchStages
Definition gemm_pipeline_ag_bg_cr_comp_v3.hpp:19
static constexpr auto Scheduler
Definition gemm_pipeline_ag_bg_cr_comp_v3.hpp:165
static constexpr auto TailNum
Definition gemm_pipeline_ag_bg_cr_comp_v3.hpp:163
remove_cvref_t< typename Problem::CLayout > CLayout
Definition gemm_pipeline_ag_bg_cr_comp_v3.hpp:114
remove_cvref_t< typename Problem::CDataType > CDataType
Definition gemm_pipeline_ag_bg_cr_comp_v3.hpp:106
static constexpr index_t APackedSize
Definition gemm_pipeline_ag_bg_cr_comp_v3.hpp:145
CK_TILE_DEVICE auto operator()(const AsDramBlockWindowTmp &a_dram_block_window_tmp, const BsDramBlockWindowTmp &b_dram_block_window_tmp, index_t num_loop, void *p_smem) const
This function runs the pipeline using compile-time known hot loop and tail number.
Definition gemm_pipeline_ag_bg_cr_comp_v3.hpp:695
static constexpr index_t GetVectorSizeB()
Definition gemm_pipeline_ag_bg_cr_comp_v3.hpp:139
remove_cvref_t< decltype(Policy::template GetBlockGemm< Problem >())> BlockGemm
Definition gemm_pipeline_ag_bg_cr_comp_v3.hpp:122
static constexpr index_t GetSmemPackA()
Definition gemm_pipeline_ag_bg_cr_comp_v3.hpp:150
static constexpr bool kPadM
Definition gemm_pipeline_ag_bg_cr_comp_v3.hpp:153
static constexpr bool kPadK
Definition gemm_pipeline_ag_bg_cr_comp_v3.hpp:155
remove_cvref_t< typename Problem::BElementWise > BElementWise
Definition gemm_pipeline_ag_bg_cr_comp_v3.hpp:109
static constexpr auto is_b_load_tr_v
Definition gemm_pipeline_ag_bg_cr_comp_v3.hpp:168
remove_cvref_t< std::tuple_element_t< 0, AsDataType > > ADataType
Definition gemm_pipeline_ag_bg_cr_comp_v3.hpp:119
GemmPipelineAgBgCrImplBase< Problem, Policy > PipelineImplBase
Definition gemm_pipeline_ag_bg_cr_comp_v3.hpp:102
static constexpr index_t MPerBlock
Definition gemm_pipeline_ag_bg_cr_comp_v3.hpp:129
static constexpr index_t BPackedSize
Definition gemm_pipeline_ag_bg_cr_comp_v3.hpp:147
remove_cvref_t< std::tuple_element_t< 0, AsLayout > > ALayout
Definition gemm_pipeline_ag_bg_cr_comp_v3.hpp:116
Definition gemm_pipeline_ag_bg_cr_base.hpp:13
CK_TILE_DEVICE constexpr auto GetBWindows(const BDramBlockWindowTmp &b_dram_block_window_tmp, const BLdsTensorView &b_lds_block_view, const BLdsLoadTileDistr &, const array< index_t, 2 > &offset={0, 0}) const
Definition gemm_pipeline_ag_bg_cr_base.hpp:225
remove_cvref_t< std::tuple_element_t< number< 0 >{}, BsDataType > > BDataType
Definition gemm_pipeline_ag_bg_cr_base.hpp:22
CK_TILE_DEVICE auto GetABLdsTensorViews(void *p_smem) const
Definition gemm_pipeline_ag_bg_cr_base.hpp:83
static constexpr index_t NPerBlock
Definition gemm_pipeline_ag_bg_cr_base.hpp:26
static constexpr index_t MPerBlock
Definition gemm_pipeline_ag_bg_cr_base.hpp:25
CK_TILE_DEVICE void LocalPrefill(DstTileWindow &lds_tile_window, const SrcBlockTile &src_block_tile, const ElementFunction &element_func) const
Definition gemm_pipeline_ag_bg_cr_base.hpp:57
CK_TILE_DEVICE constexpr auto GetAWindows(const ADramBlockWindowTmp &a_dram_block_window_tmp, const ALdsTensorView &a_lds_block_view, const ALdsLoadTileDistr &, const array< index_t, 2 > &offset={0, 0}) const
Definition gemm_pipeline_ag_bg_cr_base.hpp:190
remove_cvref_t< std::tuple_element_t< number< 0 >{}, AsDataType > > ADataType
Definition gemm_pipeline_ag_bg_cr_base.hpp:20
static constexpr index_t KPerBlock
Definition gemm_pipeline_ag_bg_cr_base.hpp:27
Definition tile/core/numeric/integral_constant.hpp:30
Definition tile/core/numeric/numeric.hpp:81
Definition tile/core/utility/functional.hpp:43