gridwise_gemm_xdlops_v2r4.hpp Source File

gridwise_gemm_xdlops_v2r4.hpp Source File

Composable Kernel: gridwise_gemm_xdlops_v2r4.hpp Source File
gridwise_gemm_xdlops_v2r4.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
16
17namespace ck {
18
19template <typename GridwiseGemm,
20 typename FloatAB,
21 typename FloatC,
22 typename ABK0MK1GridDesc,
23 typename BBK0NK1GridDesc,
24 typename CM0N0M1N1M2M3M4N2GridDesc,
25 typename AElementwiseOperation,
26 typename BElementwiseOperation,
27 typename CElementwiseOperation,
28 typename CBlockClusterAdaptor,
29 bool HasMainKBlockLoop>
30__global__ void
31#if CK_USE_LAUNCH_BOUNDS
33#endif
34 kernel_gemm_xdlops_v2r4(const FloatAB* __restrict__ p_a_grid,
35 const FloatAB* __restrict__ p_b_grid,
36 FloatC* __restrict__ p_c_grid,
37 const ABK0MK1GridDesc a_b_k0_m_k1_grid_desc,
38 const BBK0NK1GridDesc b_b_k0_n_k1_grid_desc,
39 const CM0N0M1N1M2M3M4N2GridDesc c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc,
40 const AElementwiseOperation a_element_op,
41 const BElementwiseOperation b_element_op,
42 const CElementwiseOperation c_element_op,
43 const CBlockClusterAdaptor c_block_cluster_adaptor)
44{
45#ifdefined(__gfx908__) || defined(__gfx90a__) || defined(__gfx94__) || defined(__gfx11__) || \
46 defined(__gfx12__)
47 if constexpr(GridwiseGemm::template IsValidCompilationParameter<>())
48 {
49 constexpr index_t shared_block_size =
50 GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB);
51
52 __shared__ FloatAB p_shared_block[shared_block_size];
53
54 GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid,
55 p_b_grid,
56 p_c_grid,
57 p_shared_block,
58 a_b_k0_m_k1_grid_desc,
59 b_b_k0_n_k1_grid_desc,
60 c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc,
61 a_element_op,
62 b_element_op,
63 c_element_op,
64 c_block_cluster_adaptor);
65 }
66#else
67 ignore = p_a_grid;
68 ignore = p_b_grid;
69 ignore = p_c_grid;
70 ignore = a_b_k0_m_k1_grid_desc;
71 ignore = b_b_k0_n_k1_grid_desc;
72 ignore = c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc;
73 ignore = a_element_op;
74 ignore = b_element_op;
75 ignore = c_element_op;
76 ignore = c_block_cluster_adaptor;
77#endif // end of if (defined(__gfx908__) || defined(__gfx90a__))
78}
79
80template <index_t BlockSize,
81 typename FloatAB,
82 typename FloatAcc,
83 typename FloatC,
84 InMemoryDataOperationEnum CGlobalMemoryDataOperation,
85 typename ABK0MK1GridDesc,
86 typename BBK0NK1GridDesc,
87 typename CMNGridDesc,
88 typename AElementwiseOperation,
89 typename BElementwiseOperation,
90 typename CElementwiseOperation,
91 index_t MPerBlock,
92 index_t NPerBlock,
93 index_t K0PerBlock,
94 index_t MPerXDL,
95 index_t NPerXDL,
96 index_t K1Value,
97 index_t MRepeat,
98 index_t NRepeat,
99 typename ABlockTransferThreadClusterLengths_K0_M_K1,
100 typename ABlockTransferThreadClusterArrangeOrder,
101 typename ABlockTransferSrcAccessOrder,
102 index_t ABlockTransferSrcVectorDim,
103 index_t ABlockTransferSrcScalarPerVector,
104 index_t ABlockTransferDstScalarPerVector_K1,
105 bool AThreadTransferSrcResetCoordinateAfterRun,
106 bool ABlockLdsExtraM,
107 typename BBlockTransferThreadClusterLengths_K0_N_K1,
108 typename BBlockTransferThreadClusterArrangeOrder,
109 typename BBlockTransferSrcAccessOrder,
110 index_t BBlockTransferSrcVectorDim,
111 index_t BBlockTransferSrcScalarPerVector,
112 index_t BBlockTransferDstScalarPerVector_K1,
113 bool BThreadTransferSrcResetCoordinateAfterRun,
114 bool BBlockLdsExtraN,
115 typename CThreadTransferSrcDstAccessOrder,
116 index_t CThreadTransferSrcDstVectorDim,
117 index_t CThreadTransferDstScalarPerVector>
119{
120 static constexpr auto I0 = Number<0>{};
121 static constexpr auto I1 = Number<1>{};
122 static constexpr auto I2 = Number<2>{};
123 static constexpr auto I3 = Number<3>{};
124 static constexpr auto I4 = Number<4>{};
125 static constexpr auto I5 = Number<5>{};
126 static constexpr auto I6 = Number<6>{};
127 static constexpr auto I7 = Number<7>{};
128
129 // K1 should be Number<...>
130 static constexpr auto K1 = Number<K1Value>{};
131
133
// Returns the LDS footprint, in bytes, of one A tile plus one B tile.
// Each tile's element-space size is first rounded up to a multiple of K1 elements
// (max_lds_align). The two sizes are then summed and multiplied by sizeof(FloatAB).
// The kernel entry point divides this value by sizeof(FloatAB) to size its __shared__ array.
// NOTE(review): the ABlockLdsExtraM / BBlockLdsExtraN branch bodies (source lines 142-144,
// 148, 157-159, 163) are omitted from this listing. Presumably the "extra" variants pad the
// M/N stride of the LDS tile; confirm against the original header.
134 __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
135 {
136 constexpr auto max_lds_align = K1;
137
138 // A matrix in LDS memory, dst of blockwise copy
139 constexpr auto a_k0_m_k1_block_desc = [&]() {
140 if constexpr(ABlockLdsExtraM)
141 {
145 }
146 else
147 {
// Unpadded [K0PerBlock, MPerBlock, K1] tile aligned to K1.
149 make_tuple(Number<K0PerBlock>{}, Number<MPerBlock>{}, K1), max_lds_align);
150 }
151 }();
152
153 // B matrix in LDS memory, dst of blockwise copy
154 constexpr auto b_k0_n_k1_block_desc = [&]() {
155 if constexpr(BBlockLdsExtraN)
156 {
160 }
161 else
162 {
// Unpadded [K0PerBlock, NPerBlock, K1] tile aligned to K1.
164 make_tuple(Number<K0PerBlock>{}, Number<NPerBlock>{}, K1), max_lds_align);
165 }
166 }();
167
168 // LDS allocation for A and B: be careful of alignment
// Round each tile up so that B (placed right after A in Run) starts on a K1 boundary.
169 constexpr auto a_block_space_size =
170 math::integer_least_multiple(a_k0_m_k1_block_desc.GetElementSpaceSize(), max_lds_align);
171
172 constexpr auto b_block_space_size =
173 math::integer_least_multiple(b_k0_n_k1_block_desc.GetElementSpaceSize(), max_lds_align);
174
175 return (a_block_space_size + b_block_space_size) * sizeof(FloatAB);
176 }
177
178 template <
179 InMemoryDataOperationEnum CGlobalMemoryDataOperation_ = InMemoryDataOperationEnum::Set>
180 __device__ static bool constexpr IsValidCompilationParameter()
181 {
182 return ck::tensor_operation::device::IsValidGemmCompilationParameter<
183 BlockSize,
184 MPerBlock,
185 NPerBlock,
186 MPerXdl,
187 NPerXdl,
188 MXdlPerWave,
189 NXdlPerWave,
190 FloatC,
191 CGlobalMemoryDataOperation>();
192 }
193
194 // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01}
195 template <typename Block2CTileMap>
196 __host__ __device__ static constexpr bool
197 CheckValidity(const ABK0MK1GridDesc& a_b_k0_m_k1_grid_desc,
198 const BBK0NK1GridDesc& b_b_k0_n_k1_grid_desc,
199 const CMNGridDesc& c_m_n_grid_desc,
200 const Block2CTileMap& block_2_ctile_map)
201 {
202 static_assert(is_known_at_compile_time<remove_cv_t<decltype(K1)>>::value,
203 "wrong! K1 need to be known at compile-time");
204
205 static_assert((MPerBlock % (MPerXDL * MRepeat) == 0) &&
206 (NPerBlock % (NRepeat * NPerXDL)) == 0,
207 "Invalid tuning param!");
208
209 const auto M = a_b_k0_m_k1_grid_desc.GetLength(I2);
210 const auto N = b_b_k0_n_k1_grid_desc.GetLength(I2);
211 const auto K0 = a_b_k0_m_k1_grid_desc.GetLength(I1);
212 const auto KBatch = a_b_k0_m_k1_grid_desc.GetLength(I0);
213
214 if(!(M == c_m_n_grid_desc.GetLength(I0) && N == c_m_n_grid_desc.GetLength(I1) &&
215 K0 == b_b_k0_n_k1_grid_desc.GetLength(I1) &&
216 K1 == a_b_k0_m_k1_grid_desc.GetLength(I3) &&
217 K1 == b_b_k0_n_k1_grid_desc.GetLength(I3) &&
218 KBatch == b_b_k0_n_k1_grid_desc.GetLength(I0)))
219 return false;
220
221 if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K0 % K0PerBlock == 0))
222 return false;
223
224 if(!block_2_ctile_map.CheckValidity(c_m_n_grid_desc))
225 {
226 return false;
227 }
228
229 // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
230 return true;
231 }
232
233 __host__ __device__ static constexpr bool CalculateHasMainK0BlockLoop(index_t K0)
234 {
235 const bool has_main_k0_block_loop = K0 > K0PerBlock;
236
237 return has_main_k0_block_loop;
238 }
239
// Builds the 8-D (M0, N0, M1, N1, M2, M3, M4, N2) view of the C grid that matches the
// blockwise GEMM's per-thread output layout. The work is delegated to
// BlockwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2.
// The LDS tile descriptors are rebuilt here only to instantiate BlockwiseGemm with the same
// template arguments as the blockwise GEMM in Run(); keep the two in sync.
// NOTE(review): source lines 249-251, 255, 264-266, 270 and 276 (including the blockwise
// GEMM type name) are omitted from this listing.
240 __host__ __device__ static constexpr auto
241 MakeCM0N0M1N1M2M3M4N2GridDescriptor(const CMNGridDesc& c_m_n_grid_desc)
242 {
243 constexpr auto max_lds_align = K1;
244
245 // A matrix in LDS memory, dst of blockwise copy
246 constexpr auto a_k0_m_k1_block_desc = [&]() {
247 if constexpr(ABlockLdsExtraM)
248 {
252 }
253 else
254 {
256 make_tuple(Number<K0PerBlock>{}, Number<MPerBlock>{}, K1), max_lds_align);
257 }
258 }();
259
260 // B matrix in LDS memory, dst of blockwise copy
261 constexpr auto b_k0_n_k1_block_desc = [&]() {
262 if constexpr(BBlockLdsExtraN)
263 {
267 }
268 else
269 {
271 make_tuple(Number<K0PerBlock>{}, Number<NPerBlock>{}, K1), max_lds_align);
272 }
273 }();
274
275 using BlockwiseGemm =
277 FloatAB,
278 FloatAcc,
279 decltype(a_k0_m_k1_block_desc),
280 decltype(b_k0_n_k1_block_desc),
281 MPerXDL,
282 NPerXDL,
283 MRepeat,
284 NRepeat,
285 K1>;
286
287 return BlockwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(c_m_n_grid_desc);
288 }
289
290 // return block_id to C matrix tile idx (m0, n0) mapping
// The M01 / N01 arguments are accepted for interface compatibility but ignored (their names
// are commented out). KBatch is forwarded to the adaptor. Run() decodes the adaptor's output
// as (k_batch, m-tile, n-tile).
// NOTE(review): source line 294 (the adaptor type being constructed) is omitted from this
// listing. The literal 8 is presumably the adaptor's fixed M01 grouping factor; confirm
// against the original header.
291 __host__ __device__ static constexpr auto MakeCBlockClusterAdaptor(
292 const CMNGridDesc& c_m_n_grid_desc, index_t /* M01 */, index_t /* N01 */, index_t KBatch)
293 {
295 c_m_n_grid_desc, 8, KBatch);
296 }
297
// Adaptor type deduced from a default-constructed C descriptor; used as Run()'s parameter type.
299 using CBlockClusterAdaptor = decltype(MakeCBlockClusterAdaptor(CMNGridDesc{}, 1, 1, 1));
300
// Device-side body of the K-batched GEMM for one workgroup.
//  1. Decodes get_block_1d_id() through c_block_cluster_adaptor into
//     (k_batch_id, m-tile, n-tile). Workgroups outside the C tile grid return immediately.
//  2. Copies the first K0PerBlock slice of A and B from global memory into p_shared_block.
//     A sits at offset 0 and B follows A's K1-aligned footprint.
//  3. If HasMainKBlockLoop, repeatedly: reads the next slice from global memory, runs the
//     blockwise GEMM on the slice currently in LDS, then writes the new slice into LDS.
//  4. Runs the blockwise GEMM once more on the last slice (tail).
//  5. Writes the per-thread accumulators to C through c_element_op using
//     CGlobalMemoryDataOperation.
// p_shared_block must hold at least GetSharedMemoryNumberOfByte() bytes; the kernel entry
// point sizes it that way.
// NOTE(review): this listing omits several source lines, e.g. 318, 413, 415-417, 480, 532,
// 538 and 549. The cross-reference index lists block_sync_lds(), which is presumably called
// at some of the loop/tail gaps; confirm against the original header before relying on this
// text.
301 template <bool HasMainKBlockLoop>
302 __device__ static void Run(const FloatAB* __restrict__ p_a_grid,
303 const FloatAB* __restrict__ p_b_grid,
304 FloatC* __restrict__ p_c_grid,
305 FloatAB* __restrict__ p_shared_block,
306 const ABK0MK1GridDesc& a_b_k0_m_k1_grid_desc,
307 const BBK0NK1GridDesc& b_b_k0_n_k1_grid_desc,
308 const CM0N0M1N1M2M3M4N2GridDesc& c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc,
309 const AElementwiseOperation& a_element_op,
310 const BElementwiseOperation& b_element_op,
311 const CElementwiseOperation& c_element_op,
312 const CBlockClusterAdaptor& c_block_cluster_adaptor)
313 {
314 const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
315 p_a_grid, a_b_k0_m_k1_grid_desc.GetElementSpaceSize());
316 const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
317 p_b_grid, b_b_k0_n_k1_grid_desc.GetElementSpaceSize());
319 p_c_grid, c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc.GetElementSpaceSize());
320
321 const auto K0 = a_b_k0_m_k1_grid_desc.GetLength(I1);
322
323 // divide block work by [M, N]
// block_work_idx = (k_batch_id, m-tile index, n-tile index) for this workgroup.
324 const auto block_work_idx =
325 c_block_cluster_adaptor.CalculateBottomIndex(make_multi_index(get_block_1d_id()));
326
// Out-of-range tiles exit before touching LDS. The condition depends only on the block id,
// so it is uniform across the workgroup.
327 if(!c_block_cluster_adaptor.ValidCTileIndex(
328 make_tuple(block_work_idx[I1], block_work_idx[I2]),
329 make_tuple(c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc.GetLength(I0),
330 c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc.GetLength(I1))))
331 {
332 return;
333 }
334
335 const index_t k_batch_id = block_work_idx[I0];
336
337 // HACK: this force m/n_block_data_idx_on_grid into SGPR
338 const index_t m_block_data_idx_on_grid =
339 __builtin_amdgcn_readfirstlane(block_work_idx[I1] * MPerBlock);
340
341 const index_t n_block_data_idx_on_grid =
342 __builtin_amdgcn_readfirstlane(block_work_idx[I2] * NPerBlock);
343
344 // lds max alignment
345 constexpr auto max_lds_align = K1;
346
347 // A matrix in LDS memory, dst of blockwise copy
// 3-D [K0, M, K1] view of the A LDS tile, as consumed by the blockwise GEMM.
348 constexpr auto a_k0_m_k1_block_desc = [&]() {
349 if constexpr(ABlockLdsExtraM)
350 {
354 }
355 else
356 {
358 make_tuple(Number<K0PerBlock>{}, Number<MPerBlock>{}, K1), max_lds_align);
359 }
360 }();
361
// 4-D (b, k0, m, k1) view of the A LDS tile, used as the blockwise-copy destination.
// Its defining lines are partly omitted from this listing.
362 constexpr auto a_b_k0_m_k1_block_desc = [&]() {
363 if constexpr(ABlockLdsExtraM)
364 {
369 K1,
370 I1));
371 }
372 else
373 {
376 max_lds_align);
377 }
378 }();
379 // B matrix in LDS memory, dst of blockwise copy
380 constexpr auto b_k0_n_k1_block_desc = [&]() {
381 if constexpr(BBlockLdsExtraN)
382 {
386 }
387 else
388 {
390 make_tuple(Number<K0PerBlock>{}, Number<NPerBlock>{}, K1), max_lds_align);
391 }
392 }();
393
// 4-D (b, k0, n, k1) view of the B LDS tile, used as the blockwise-copy destination.
394 constexpr auto b_b_k0_n_k1_block_desc = [&]() {
395 if constexpr(BBlockLdsExtraN)
396 {
401 K1,
402 I1));
403 }
404 else
405 {
408 max_lds_align);
409 }
410 }();
411 // A matrix blockwise copy
// Source window starts at (k_batch_id, 0, m tile origin, 0); destination at LDS (0, 0, 0, 0).
412 auto a_blockwise_copy =
414 AElementwiseOperation,
418 ABlockTransferThreadClusterLengths_K0_M_K1,
419 ABlockTransferThreadClusterArrangeOrder,
420 FloatAB,
421 FloatAB,
422 decltype(a_b_k0_m_k1_grid_desc),
423 decltype(a_b_k0_m_k1_block_desc),
424 ABlockTransferSrcAccessOrder,
426 ABlockTransferSrcVectorDim,
427 3,
428 ABlockTransferSrcScalarPerVector,
429 ABlockTransferDstScalarPerVector_K1,
430 1,
431 1,
432 AThreadTransferSrcResetCoordinateAfterRun,
433 true>(
434 a_b_k0_m_k1_grid_desc,
435 make_multi_index(k_batch_id, 0, m_block_data_idx_on_grid, 0),
436 a_element_op,
437 a_b_k0_m_k1_block_desc,
438 make_multi_index(0, 0, 0, 0),
440
441 // B matrix blockwise copy
// Source window starts at (k_batch_id, 0, n tile origin, 0); destination at LDS (0, 0, 0, 0).
442 auto b_blockwise_copy =
444 BElementwiseOperation,
448 BBlockTransferThreadClusterLengths_K0_N_K1,
449 BBlockTransferThreadClusterArrangeOrder,
450 FloatAB,
451 FloatAB,
452 decltype(b_b_k0_n_k1_grid_desc),
453 decltype(b_b_k0_n_k1_block_desc),
454 BBlockTransferSrcAccessOrder,
456 BBlockTransferSrcVectorDim,
457 3,
458 BBlockTransferSrcScalarPerVector,
459 BBlockTransferDstScalarPerVector_K1,
460 1,
461 1,
462 BThreadTransferSrcResetCoordinateAfterRun,
463 true>(
464 b_b_k0_n_k1_grid_desc,
465 make_multi_index(k_batch_id, 0, n_block_data_idx_on_grid, 0),
466 b_element_op,
467 b_b_k0_n_k1_block_desc,
468 make_multi_index(0, 0, 0, 0),
470
471 // GEMM definition
472 // c_mtx += transpose(a_mtx) * b_mtx
473 // a_mtx[K0PerBlock, MPerBlock] is in LDS
474 // b_mtx[K0PerBlock, NPerBlock] is in LDS
475 // c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in
476 // register
477 // sanity check
478
479 auto blockwise_gemm =
481 FloatAB,
482 FloatAcc,
483 decltype(a_k0_m_k1_block_desc),
484 decltype(b_k0_n_k1_block_desc),
485 MPerXDL,
486 NPerXDL,
487 MRepeat,
488 NRepeat,
489 K1>{};
490
491 auto c_thread_buf = blockwise_gemm.GetCThreadBuffer();
492
493 // LDS allocation for A and B: be careful of alignment
// B starts right after A's footprint, which is rounded up to a multiple of K1 elements.
494 constexpr auto a_block_space_size =
495 math::integer_least_multiple(a_k0_m_k1_block_desc.GetElementSpaceSize(), max_lds_align);
496
497 FloatAB* p_a_block = p_shared_block;
498 FloatAB* p_b_block = p_shared_block + a_block_space_size;
499
// Each main-loop iteration advances the source windows by K0PerBlock along dim 1 (K0).
500 constexpr auto a_block_slice_copy_step = make_multi_index(0, K0PerBlock, 0, 0);
501 constexpr auto b_block_slice_copy_step = make_multi_index(0, K0PerBlock, 0, 0);
502
504 p_a_block, a_k0_m_k1_block_desc.GetElementSpaceSize());
506 p_b_block, b_k0_n_k1_block_desc.GetElementSpaceSize());
507
508 // preload data into LDS
509 {
510 a_blockwise_copy.RunRead(a_b_k0_m_k1_grid_desc, a_grid_buf);
511 b_blockwise_copy.RunRead(b_b_k0_n_k1_grid_desc, b_grid_buf);
512
513 a_blockwise_copy.RunWrite(a_b_k0_m_k1_block_desc, a_block_buf);
514 b_blockwise_copy.RunWrite(b_b_k0_n_k1_block_desc, b_block_buf);
515 }
516
517 // Initialize C
518 c_thread_buf.Clear();
519
520 // main body
// The do-while body always runs at least once, so this branch requires K0 > K0PerBlock.
// With K0 a multiple of K0PerBlock (as CheckValidity requires), it runs
// K0 / K0PerBlock - 1 times; the tail below consumes the final slice.
521 if constexpr(HasMainKBlockLoop)
522 {
523 index_t k0_block_data_begin = 0;
524
525 do
526 {
527 a_blockwise_copy.MoveSrcSliceWindow(a_b_k0_m_k1_grid_desc, a_block_slice_copy_step);
528 b_blockwise_copy.MoveSrcSliceWindow(b_b_k0_n_k1_grid_desc, b_block_slice_copy_step);
529
// Fetch the next slice from global memory while the current slice is still in LDS.
530 a_blockwise_copy.RunRead(a_b_k0_m_k1_grid_desc, a_grid_buf);
531
// NOTE(review): source line 532 is omitted here (presumably an LDS barrier) - confirm.
533
534 b_blockwise_copy.RunRead(b_b_k0_n_k1_grid_desc, b_grid_buf);
535
536 blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
537
// NOTE(review): source line 538 is omitted here. An LDS barrier is needed before the
// RunWrite calls below overwrite the tile that blockwise_gemm just read - confirm.
539
540 a_blockwise_copy.RunWrite(a_b_k0_m_k1_block_desc, a_block_buf);
541 b_blockwise_copy.RunWrite(b_b_k0_n_k1_block_desc, b_block_buf);
542
543 k0_block_data_begin += K0PerBlock;
544 } while(k0_block_data_begin < (K0 - K0PerBlock));
545 }
546
547 // tail
548 {
// NOTE(review): source line 549 is omitted here (presumably an LDS barrier) - confirm.
550
551 blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
552 }
553
554 // output: register to global memory
555 {
556 constexpr auto c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc =
557 blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
558
559 constexpr auto M0 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I0);
560 constexpr auto N0 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I1);
561 constexpr auto M1 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I2);
562 constexpr auto N1 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I3);
563 constexpr auto M2 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I4);
564 constexpr auto M3 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I5);
565 constexpr auto M4 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I6);
566 constexpr auto N2 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I7);
567
568 constexpr auto c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc =
571
572 // calculate origin of thread output tensor on global memory
573 // blockwise GEMM c matrix starting index
// Thread origin in global (m, n) = block tile origin + thread origin inside the block tile.
574 const auto c_thread_mtx_on_block =
575 blockwise_gemm.CalculateCThreadOriginDataIndex(I0, I0, I0, I0);
576
577 const index_t m_thread_data_on_grid =
578 m_block_data_idx_on_grid + c_thread_mtx_on_block[I0];
579
580 const index_t n_thread_data_on_grid =
581 n_block_data_idx_on_grid + c_thread_mtx_on_block[I1];
582
// Split the flat m coordinate into (M0, M1, M2, M3, M4) via a merge-transform adaptor.
583 const auto m_thread_data_on_grid_to_m0_m1_m2_m3_m4_adaptor =
585 make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))),
588
589 const auto m_thread_data_on_grid_idx =
590 m_thread_data_on_grid_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex(
591 make_multi_index(m_thread_data_on_grid));
592
// Likewise split the flat n coordinate into (N0, N1, N2).
593 const auto n_thread_data_on_grid_to_n0_n1_n2_adaptor = make_single_stage_tensor_adaptor(
597
598 const auto n_thread_data_on_grid_idx =
599 n_thread_data_on_grid_to_n0_n1_n2_adaptor.CalculateBottomIndex(
600 make_multi_index(n_thread_data_on_grid));
601
// Threadwise copy from the accumulator registers into the 8-D C grid view, applying
// c_element_op and CGlobalMemoryDataOperation.
602 auto c_thread_copy =
604 FloatC,
605 decltype(c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc),
606 decltype(c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc),
607 CElementwiseOperation,
609 CThreadTransferSrcDstAccessOrder,
610 CThreadTransferSrcDstVectorDim,
611 CThreadTransferDstScalarPerVector,
612 CGlobalMemoryDataOperation,
613 1,
614 true>{
615
616 c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc,
617 make_multi_index(m_thread_data_on_grid_idx[I0],
618 n_thread_data_on_grid_idx[I0],
619 m_thread_data_on_grid_idx[I1],
620 n_thread_data_on_grid_idx[I1],
621 m_thread_data_on_grid_idx[I2],
622 m_thread_data_on_grid_idx[I3],
623 m_thread_data_on_grid_idx[I4],
624 n_thread_data_on_grid_idx[I2]),
625 c_element_op};
626
627 c_thread_copy.Run(c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc,
628 make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
629 c_thread_buf,
630 c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc,
631 c_grid_buf);
632 }
633 }
634};
635
636} // namespace ck
#define CK_MIN_BLOCK_PER_CU
Definition ck.hpp:31
#define CK_MAX_THREAD_PER_BLOCK
Definition ck.hpp:30
__host__ __device__ constexpr auto integer_least_multiple(X x, Y y)
Definition utility/math.hpp:78
Definition ck.hpp:268
__host__ __device__ constexpr auto make_multi_index(Xs &&... xs)
Definition array_multi_index.hpp:15
int32_t index_t
Definition ck.hpp:299
__host__ __device__ constexpr auto make_naive_tensor_descriptor(const Tuple< Lengths... > &lengths, const Tuple< Strides... > &strides)
Definition tensor_descriptor_helper.hpp:49
InMemoryDataOperationEnum
Definition ck.hpp:277
@ Set
Definition ck.hpp:278
__host__ __device__ constexpr auto make_single_stage_tensor_adaptor(const Transforms &transforms, LowerDimensionOldTopIdss, UpperDimensionNewTopIdss)
Definition tensor_description/tensor_adaptor.hpp:425
__global__ void kernel_gemm_xdlops_v2r4(const FloatAB *__restrict__ p_a_grid, const FloatAB *__restrict__ p_b_grid, FloatC *__restrict__ p_c_grid, const ABK0MK1GridDesc a_b_k0_m_k1_grid_desc, const BBK0NK1GridDesc b_b_k0_n_k1_grid_desc, const CM0N0M1N1M2M3M4N2GridDesc c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc, const AElementwiseOperation a_element_op, const BElementwiseOperation b_element_op, const CElementwiseOperation c_element_op, const CBlockClusterAdaptor c_block_cluster_adaptor)
Definition gridwise_gemm_xdlops_v2r4.hpp:34
integral_constant< index_t, N > Number
Definition number.hpp:12
__host__ __device__ constexpr auto make_merge_transform(const LowLengths &low_lengths)
Definition multi_index_transform_helper.hpp:55
constexpr detail::ignore_t ignore
Definition utility/ignore.hpp:20
__device__ index_t get_block_1d_id()
Definition get_id.hpp:47
__host__ __device__ constexpr auto make_naive_tensor_descriptor_aligned(const Tuple< Lengths... > &lengths, Align align)
Definition tensor_descriptor_helper.hpp:132
__host__ __device__ constexpr auto make_naive_tensor_descriptor_packed(const Tuple< Lengths... > &lengths)
Definition tensor_descriptor_helper.hpp:101
__host__ __device__ constexpr auto make_tuple(Xs &&... xs)
Definition utility/tuple.hpp:211
__device__ void block_sync_lds()
Definition synchronization.hpp:16
typename remove_cv< T >::type remove_cv_t
Definition type.hpp:295
__host__ __device__ constexpr auto make_dynamic_buffer(T *p, ElementSpaceSize element_space_size)
Definition dynamic_buffer.hpp:472
const GenericPointer< typename T::ValueType > T2 value
Definition pointer.h:1697
Definition block_to_ctile_map.hpp:541
Definition blockwise_gemm_smfmac_xdlops.hpp:44
Definition gridwise_gemm_xdlops_v2r4.hpp:119
__host__ static __device__ constexpr bool CheckValidity(const ABK0MK1GridDesc &a_b_k0_m_k1_grid_desc, const BBK0NK1GridDesc &b_b_k0_n_k1_grid_desc, const CMNGridDesc &c_m_n_grid_desc, const Block2CTileMap &block_2_ctile_map)
Definition gridwise_gemm_xdlops_v2r4.hpp:197
static constexpr auto I3
Definition gridwise_gemm_xdlops_v2r4.hpp:123
__host__ static __device__ constexpr bool CalculateHasMainK0BlockLoop(index_t K0)
Definition gridwise_gemm_xdlops_v2r4.hpp:233
ThisThreadBlock< BlockSize > ThisThreadBlock
Definition gridwise_gemm_xdlops_v2r4.hpp:132
__host__ static __device__ constexpr auto MakeCBlockClusterAdaptor(const CMNGridDesc &c_m_n_grid_desc, index_t, index_t, index_t KBatch)
Definition gridwise_gemm_xdlops_v2r4.hpp:291
static constexpr auto I6
Definition gridwise_gemm_xdlops_v2r4.hpp:126
static constexpr auto K1
Definition gridwise_gemm_xdlops_v2r4.hpp:130
static constexpr auto I5
Definition gridwise_gemm_xdlops_v2r4.hpp:125
__host__ static __device__ constexpr auto MakeCM0N0M1N1M2M3M4N2GridDescriptor(const CMNGridDesc &c_m_n_grid_desc)
Definition gridwise_gemm_xdlops_v2r4.hpp:241
static constexpr auto I7
Definition gridwise_gemm_xdlops_v2r4.hpp:127
static __device__ bool constexpr IsValidCompilationParameter()
Definition gridwise_gemm_xdlops_v2r4.hpp:180
static __device__ void Run(const FloatAB *__restrict__ p_a_grid, const FloatAB *__restrict__ p_b_grid, FloatC *__restrict__ p_c_grid, FloatAB *__restrict__ p_shared_block, const ABK0MK1GridDesc &a_b_k0_m_k1_grid_desc, const BBK0NK1GridDesc &b_b_k0_n_k1_grid_desc, const CM0N0M1N1M2M3M4N2GridDesc &c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc, const AElementwiseOperation &a_element_op, const BElementwiseOperation &b_element_op, const CElementwiseOperation &c_element_op, const CBlockClusterAdaptor &c_block_cluster_adaptor)
Definition gridwise_gemm_xdlops_v2r4.hpp:302
static constexpr auto I4
Definition gridwise_gemm_xdlops_v2r4.hpp:124
decltype(MakeCBlockClusterAdaptor(CMNGridDesc{}, 1, 1, 1)) CBlockClusterAdaptor
Definition gridwise_gemm_xdlops_v2r4.hpp:299
static constexpr auto I1
Definition gridwise_gemm_xdlops_v2r4.hpp:121
static constexpr auto I0
Definition gridwise_gemm_xdlops_v2r4.hpp:120
static constexpr auto I2
Definition gridwise_gemm_xdlops_v2r4.hpp:122
__host__ static __device__ constexpr index_t GetSharedMemoryNumberOfByte()
Definition gridwise_gemm_xdlops_v2r4.hpp:134
decltype(MakeCM0N0M1N1M2M3M4N2GridDescriptor(CMNGridDesc{})) CM0N0M1N1M2M3M4N2GridDesc
Definition gridwise_gemm_xdlops_v2r4.hpp:298
Definition utility/sequence.hpp:43
Blockwise data transfer.
Definition thread_group_tensor_slice_transfer_v4r1.hpp:46
__device__ void Run(const SrcDesc &src_desc, const SrcBuffer &src_buf, const DstDesc &dst_desc, DstBuffer &dst_buf, Number< ThreadScratchId > thread_scratch_id)
Definition thread_group_tensor_slice_transfer_v4r1.hpp:143
__device__ void RunRead(const SrcDesc &src_desc, const SrcBuffer &src_buf, Number< ThreadScratchId > thread_scratch_id=Number< ThreadScratchId >{})
Definition thread_group_tensor_slice_transfer_v4r1.hpp:119
__device__ void MoveSrcSliceWindow(const SrcDesc &src_desc, const Index &step)
Definition thread_group_tensor_slice_transfer_v4r1.hpp:153
__device__ void RunWrite(const DstDesc &dst_desc, DstBuffer &dst_buf, Number< ThreadScratchId > thread_scratch_id=Number< ThreadScratchId >{})
Definition thread_group_tensor_slice_transfer_v4r1.hpp:131
Definition threadwise_tensor_slice_transfer.hpp:39
Definition is_known_at_compile_time.hpp:14
Definition tensor_operation/gpu/element/unary_element_wise_operation.hpp:340