gridwise_gemm_dl_v1r3.hpp Source File

gridwise_gemm_dl_v1r3.hpp Source File#

Composable Kernel: gridwise_gemm_dl_v1r3.hpp Source File
gridwise_gemm_dl_v1r3.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
17
18namespace ck {
19
// Device entry point for the v1r3 DL GEMM. It sizes a static __shared__ array
// from GridwiseGemm::GetSharedMemoryNumberOfByte() and forwards every argument
// to GridwiseGemm::Run.
// NOTE(review): this listing skips source lines 31 and 54-55, so the body of
// the CK_USE_LAUNCH_BOUNDS guard and the tail of the Run(...) argument list
// are not visible here -- check the real file before editing this block.
20template <typename GridwiseGemm,
21 typename FloatAB,
22 typename FloatC,
23 typename AGridDesc_K0_M0_M1_K1,
24 typename BGridDesc_K0_N0_N1_K1,
25 typename CGridDesc_M0_M10_M11_N0_N10_N11,
26 typename Block2CTileMap,
27 bool HasMainKBlockLoop,
28 bool HasDoubleTailKBlockLoop>
29__global__ void
30#if CK_USE_LAUNCH_BOUNDS
32#endif
33 kernel_gemm_dl_v1r3(const FloatAB* __restrict__ p_a_grid,
34 const FloatAB* __restrict__ p_b_grid,
35 FloatC* __restrict__ p_c_grid,
36 const AGridDesc_K0_M0_M1_K1 a_grid_desc_k0_m0_m1_k1,
37 const BGridDesc_K0_N0_N1_K1 b_grid_desc_k0_n0_n1_k1,
38 const CGridDesc_M0_M10_M11_N0_N10_N11 c_grid_desc_m0_m10_m11_n0_n10_n11,
39 const Block2CTileMap block_2_ctile_map)
40{
// The gridwise GEMM reports its LDS requirement in bytes; convert that to a
// FloatAB element count so it can be used as the array extent below.
41 constexpr index_t shared_block_size =
42 GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB);
43
// Block-shared scratch buffer handed to GridwiseGemm::Run.
44 __shared__ FloatAB p_shared_block[shared_block_size];
45
46 GridwiseGemm::Run(p_a_grid,
47 p_b_grid,
48 p_c_grid,
49 p_shared_block,
50 a_grid_desc_k0_m0_m1_k1,
51 b_grid_desc_k0_n0_n1_k1,
52 c_grid_desc_m0_m10_m11_n0_n10_n11,
53 block_2_ctile_map,
56}
57
58template <index_t BlockSize,
59 typename FloatAB,
60 typename FloatAcc,
61 typename FloatC,
62 InMemoryDataOperationEnum CGlobalMemoryDataOperation,
63 typename AGridDesc_K0_M_K1,
64 typename BGridDesc_K0_N_K1,
65 typename CGridDesc_M_N,
66 index_t MPerBlock,
67 index_t NPerBlock,
68 index_t K0PerBlock,
69 index_t K1Value,
70 index_t M1PerThreadM111,
71 index_t N1PerThreadN111,
72 index_t KPerThread,
73 typename M11N11ThreadClusterM110Xs,
74 typename M11N11ThreadClusterN110Xs,
75 typename ABlockTransferThreadSliceLengths_K0_M0_M1_K1,
76 typename ABlockTransferThreadClusterLengths_K0_M0_M1_K1,
77 typename ABlockTransferThreadClusterArrangeOrder,
78 typename ABlockTransferSrcAccessOrder,
79 typename ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1,
80 typename ABlockTransferSrcVectorTensorContiguousDimOrder,
81 typename ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1,
82 typename BBlockTransferThreadSliceLengths_K0_N0_N1_K1,
83 typename BBlockTransferThreadClusterLengths_K0_N0_N1_K1,
84 typename BBlockTransferThreadClusterArrangeOrder,
85 typename BBlockTransferSrcAccessOrder,
86 typename BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1,
87 typename BBlockTransferSrcVectorTensorContiguousDimOrder,
88 typename BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1,
89 typename CThreadTransferSrcDstAccessOrder,
90 index_t CThreadTransferSrcDstVectorDim,
91 index_t CThreadTransferDstScalarPerVector>
93{
// Compile-time index constants; the methods below pass them to GetLength(...)
// and use them to index tuples/sequences.
94 static constexpr auto I0 = Number<0>{};
95 static constexpr auto I1 = Number<1>{};
96 static constexpr auto I2 = Number<2>{};
97 static constexpr auto I3 = Number<3>{};
98
99 // K1 should be Number<...>
// Number<> wrapper around the K1Value template parameter. It is used below as
// the innermost length of the LDS tile descriptors and as max_lds_align.
100 static constexpr auto K1 = Number<K1Value>{};
101
102 __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
103 {
104 // TODO: change this. I think it needs multi-dimensional alignment
105 constexpr auto max_lds_align = K1;
106
107 // TODO: check alignment
108 // A matrix in LDS memory, dst of blockwise copy
109 constexpr auto a_block_desc_k_m = make_naive_tensor_descriptor_aligned(
110 make_tuple(Number<K0PerBlock>{}, Number<MPerBlock>{}, K1), max_lds_align);
111
112 // TODO: check alignment
113 // B matrix in LDS memory, dst of blockwise copy
114 constexpr auto b_block_desc_k_n = make_naive_tensor_descriptor_aligned(
115 make_tuple(Number<K0PerBlock>{}, Number<NPerBlock>{}, K1), max_lds_align);
116
117 // TODO: check alignment
118 // LDS allocation for A and B: be careful of alignment
119 constexpr auto a_block_aligned_space_size =
120 math::integer_least_multiple(a_block_desc_k_m.GetElementSpaceSize(), max_lds_align);
121
122 constexpr auto b_block_aligned_space_size =
123 math::integer_least_multiple(b_block_desc_k_n.GetElementSpaceSize(), max_lds_align);
124
125 return 2 * (a_block_aligned_space_size + b_block_aligned_space_size) * sizeof(FloatAB);
126 }
127
128 __host__ __device__ static constexpr bool
129 CheckValidity(const AGridDesc_K0_M_K1& a_grid_desc_k0_m_k1,
130 const BGridDesc_K0_N_K1& b_grid_desc_k0_n_k1,
131 const CGridDesc_M_N& c_grid_desc_m_n)
132 {
133 const auto M = a_grid_desc_k0_m_k1.GetLength(I1);
134 const auto N = b_grid_desc_k0_n_k1.GetLength(I1);
135 const auto K0 = a_grid_desc_k0_m_k1.GetLength(I0);
136
137 // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
138
139 return (M == c_grid_desc_m_n.GetLength(I0) && N == c_grid_desc_m_n.GetLength(I1) &&
140 K0 == b_grid_desc_k0_n_k1.GetLength(I0) &&
141 K1 == a_grid_desc_k0_m_k1.GetLength(I2) &&
142 K1 == b_grid_desc_k0_n_k1.GetLength(I2)) &&
143 (M % MPerBlock == 0 && N % NPerBlock == 0 && K0 % K0PerBlock == 0);
144 }
145
146 __host__ __device__ static constexpr index_t CalculateGridSize(index_t M, index_t N)
147 {
148 const index_t grid_size = (M / MPerBlock) * (N / NPerBlock);
149
150 return grid_size;
151 }
152
153 __host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t K0)
154 {
155 const bool has_main_k_block_loop = (K0 + K0PerBlock) / (2 * K0PerBlock) > 1;
156
157 return has_main_k_block_loop;
158 }
159
160 __host__ __device__ static constexpr bool CalculateHasDoubleTailKBlockLoop(index_t K0)
161 {
162 const bool has_double_tail_k_block_loop = (K0 / K0PerBlock) % 2 == 0;
163
164 return has_double_tail_k_block_loop;
165 }
166
// Derives the A grid descriptor consumed by Run from the caller's [K0, M, K1]
// descriptor by calling transform_tensor_descriptor on it.
// NOTE(review): source lines 178-182 (the transform arguments) are missing
// from this listing; the name suggests M is split into (M0, M1) with
// M1 = MPerBlock -- confirm against the real file.
167 __host__ __device__ static constexpr auto
168 MakeAGridDescriptor_K0_M0_M1_K1(const AGridDesc_K0_M_K1& a_grid_desc_k0_m_k1)
169 {
170 const auto K0 = a_grid_desc_k0_m_k1.GetLength(I0);
171 const auto M = a_grid_desc_k0_m_k1.GetLength(I1);
172
// M1 is the per-block tile length; M0 is how many such tiles fit in M.
173 const auto M1 = Number<MPerBlock>{};
174 const auto M0 = M / M1;
175
176 const auto a_grid_desc_k0_m0_m1_k1 =
177 transform_tensor_descriptor(a_grid_desc_k0_m_k1,
183
184 return a_grid_desc_k0_m0_m1_k1;
185 }
186
// Derives the B grid descriptor consumed by Run from the caller's [K0, N, K1]
// descriptor by calling transform_tensor_descriptor on it.
// NOTE(review): source lines 198-202 (the transform arguments) are missing
// from this listing; the name suggests N is split into (N0, N1) with
// N1 = NPerBlock -- confirm against the real file.
187 __host__ __device__ static constexpr auto
188 MakeBGridDescriptor_K0_N0_N1_K1(const BGridDesc_K0_N_K1& b_grid_desc_k0_n_k1)
189 {
190 const auto K0 = b_grid_desc_k0_n_k1.GetLength(I0);
191 const auto N = b_grid_desc_k0_n_k1.GetLength(I1);
192
// N1 is the per-block tile length; N0 is how many such tiles fit in N.
193 const auto N1 = Number<NPerBlock>{};
194 const auto N0 = N / N1;
195
196 const auto b_grid_desc_k0_n0_n1_k1 =
197 transform_tensor_descriptor(b_grid_desc_k0_n_k1,
203
204 return b_grid_desc_k0_n0_n1_k1;
205 }
206
// Derives the six-dimensional C grid descriptor consumed by Run from the
// caller's [M, N] descriptor via transform_tensor_descriptor.
// NOTE(review): source lines 231 and 233-234 are missing from this listing;
// only the N-side unmerge into (N0, N10, N11) is visible. The M side
// presumably mirrors it -- confirm against the real file.
207 __host__ __device__ static constexpr auto
208 MakeCGridDescriptor_M0_M10_M11_N0_N10_N11(const CGridDesc_M_N& c_grid_desc_m_n)
209 {
210 const auto M = c_grid_desc_m_n.GetLength(I0);
211 const auto N = c_grid_desc_m_n.GetLength(I1);
212
// Per-block tile lengths.
213 constexpr auto M1 = Number<MPerBlock>{};
214 constexpr auto N1 = Number<NPerBlock>{};
215
// Number of block tiles along each dimension.
216 const auto M0 = M / M1;
217 const auto N0 = N / N1;
218
// M11 / N11: the thread-cluster sequence reduced with math::multiplies
// (starting from I1), times the per-thread length.
219 constexpr auto M11 =
220 Number<container_reduce(M11N11ThreadClusterM110Xs{}, math::multiplies{}, I1) *
221 M1PerThreadM111>{};
222 constexpr auto N11 =
223 Number<container_reduce(M11N11ThreadClusterN110Xs{}, math::multiplies{}, I1) *
224 N1PerThreadN111>{};
225
// M10 / N10: how many M11 / N11 chunks make up one block tile.
226 constexpr auto M10 = M1 / M11;
227 constexpr auto N10 = N1 / N11;
228
229 const auto c_grid_desc_m0_m10_m11_n0_n10_n11 = transform_tensor_descriptor(
230 c_grid_desc_m_n,
232 make_unmerge_transform(make_tuple(N0, N10, N11))),
235
236 return c_grid_desc_m0_m10_m11_n0_n10_n11;
237 }
238
239 // return block_id to C matrix tile idx (m0, n0) mapping
// NOTE(review): source line 243 (the start of the return expression) is
// missing from this listing, so the concrete map type built from
// c_grid_desc_m_n is not visible here.
240 __host__ __device__ static constexpr auto
241 MakeDefaultBlock2CTileMap(const CGridDesc_M_N& c_grid_desc_m_n)
242 {
244 c_grid_desc_m_n);
245 }
246
// Types of the derived descriptors / tile map, obtained by applying decltype
// to the maker functions above with default-constructed inputs. Run takes its
// descriptor arguments by these types.
// NOTE(review): source line 249 (the left-hand side of the third alias) is
// missing from this listing.
247 using AGridDesc_K0_M0_M1_K1 = decltype(MakeAGridDescriptor_K0_M0_M1_K1(AGridDesc_K0_M_K1{}));
248 using BGridDesc_K0_N0_N1_K1 = decltype(MakeBGridDescriptor_K0_N0_N1_K1(BGridDesc_K0_N_K1{}));
250 decltype(MakeCGridDescriptor_M0_M10_M11_N0_N10_N11(CGridDesc_M_N{}));
251 using Block2CTileMap = decltype(MakeDefaultBlock2CTileMap(CGridDesc_M_N{}));
252
// Per-block GEMM body. Outline of what the visible code does:
//   1. Wrap the A/B global pointers in dynamic buffers and map this block's id
//      to a C tile index (im0, in0); return early if the tile index is invalid.
//   2. Build LDS descriptors for the A and B tiles and two blockwise copies
//      (global -> LDS), positioned at tile im0 / in0.
//   3. Split p_shared_block into an A region followed by a B region, each
//      holding an "even" and an "odd" buffer.
//   4. Preload the first K0PerBlock slice into the even buffers, then run the
//      main loop (two slices per iteration, only if HasMainKBlockLoop) and the
//      tail (two slices if HasDoubleTailKBlockLoop, otherwise one).
//   5. Write the per-thread accumulators to the C grid buffer.
// NOTE(review): this listing skips many source lines (e.g. 263-264, 270, 323,
// 333, 347, 357, 375, 407, 457, 479, 499, 513, 531-537, 541-543, 548, 561,
// 567). Among other things the definitions of c_grid_buf and c_thread_buf and
// the blockwise GEMM type name are not visible -- check the real file.
253 template <bool HasMainKBlockLoop, bool HasDoubleTailKBlockLoop>
254 __device__ static void
255 Run(const FloatAB* __restrict__ p_a_grid,
256 const FloatAB* __restrict__ p_b_grid,
257 FloatC* __restrict__ p_c_grid,
258 FloatAB* __restrict__ p_shared_block,
259 const AGridDesc_K0_M0_M1_K1& a_grid_desc_k0_m0_m1_k1,
260 const BGridDesc_K0_N0_N1_K1& b_grid_desc_k0_n0_n1_k1,
261 const CGridDesc_M0_M10_M11_N0_N10_N11& c_grid_desc_m0_m10_m11_n0_n10_n11,
262 const Block2CTileMap& block_2_ctile_map,
265 {
266 const auto a_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
267 p_a_grid, a_grid_desc_k0_m0_m1_k1.GetElementSpaceSize());
268 const auto b_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
269 p_b_grid, b_grid_desc_k0_n0_n1_k1.GetElementSpaceSize());
271 p_c_grid, c_grid_desc_m0_m10_m11_n0_n10_n11.GetElementSpaceSize());
272
273 // divide block work by [M, N]
274 const auto c_m0_n0_block_cluster_idx =
275 block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));
276
277 // HACK: this forces index data into SGPR
278 const index_t im0 = __builtin_amdgcn_readfirstlane(c_m0_n0_block_cluster_idx[I0]);
279 const index_t in0 = __builtin_amdgcn_readfirstlane(c_m0_n0_block_cluster_idx[I1]);
280
// Bail out if (im0, in0) is not a valid tile for the M0 x N0 tile grid
// (lengths I0 and I3 of the C descriptor).
281 if(!block_2_ctile_map.ValidCTileIndex(
282 make_tuple(im0, in0),
283 make_tuple(c_grid_desc_m0_m10_m11_n0_n10_n11.GetLength(I0),
284 c_grid_desc_m0_m10_m11_n0_n10_n11.GetLength(I3))))
285 {
286 return;
287 }
288
289 // TODO: change this. I think it needs multi-dimensional alignment
290 constexpr auto max_lds_align = K1;
291
292 // TODO: check alignment
293 // A matrix in LDS memory, dst of blockwise copy
294 // be careful of LDS alignment
295 constexpr auto a_block_desc_k0_m0_m1_k1 = make_naive_tensor_descriptor_aligned(
296 make_tuple(Number<K0PerBlock>{}, I1, Number<MPerBlock>{}, K1), max_lds_align);
297
298 // TODO: check alignment
299 // B matrix in LDS memory, dst of blockwise copy
300 // be careful of LDS alignment
301 constexpr auto b_block_desc_k0_n0_n1_k1 = make_naive_tensor_descriptor_aligned(
302 make_tuple(Number<K0PerBlock>{}, I1, Number<NPerBlock>{}, K1), max_lds_align);
303
304 // TODO: check alignment
305 // A matrix in LDS memory, for blockwise GEMM
306 constexpr auto a_k0_m_k1_block_desc = make_naive_tensor_descriptor_aligned(
307 make_tuple(Number<K0PerBlock>{}, Number<MPerBlock>{}, K1), max_lds_align);
308
309 // TODO: check alignment
310 // B matrix in LDS memory, for blockwise GEMM
311 constexpr auto b_k0_n_k1_block_desc = make_naive_tensor_descriptor_aligned(
312 make_tuple(Number<K0PerBlock>{}, Number<NPerBlock>{}, K1), max_lds_align);
313
// The 4-d (copy) and 3-d (GEMM) views of each tile address the same LDS
// storage, so their element-space sizes must match.
314 static_assert(a_block_desc_k0_m0_m1_k1.GetElementSpaceSize() ==
315 a_k0_m_k1_block_desc.GetElementSpaceSize() &&
316 b_block_desc_k0_n0_n1_k1.GetElementSpaceSize() ==
317 b_k0_n_k1_block_desc.GetElementSpaceSize() &&
318 "wrong!");
319
320 // A matrix blockwise copy
321 auto a_blockwise_copy = BlockwiseTensorSliceTransfer_v5r1<
322 BlockSize,
324 Sequence<K0PerBlock, 1, MPerBlock, K1.value>,
325 ABlockTransferThreadSliceLengths_K0_M0_M1_K1,
326 ABlockTransferThreadClusterLengths_K0_M0_M1_K1,
327 ABlockTransferThreadClusterArrangeOrder,
328 FloatAB,
329 FloatAB,
330 remove_reference_t<decltype(a_grid_desc_k0_m0_m1_k1)>,
331 decltype(a_block_desc_k0_m0_m1_k1),
332 ABlockTransferSrcAccessOrder,
334 ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1, // SrcVectorTensorLengths
335 ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1, // DstVectorTensorLengths
336 ABlockTransferSrcVectorTensorContiguousDimOrder, // SrcVectorTensorContiguousDimOrder
337 Sequence<0, 1, 2, 3>, // DstVectorTensorContiguousDimOrder
338 false,
339 true>(a_grid_desc_k0_m0_m1_k1,
340 make_multi_index(0, im0, 0, 0),
341 a_block_desc_k0_m0_m1_k1,
342 make_multi_index(0, 0, 0, 0));
343
344 // B matrix blockwise copy
345 auto b_blockwise_copy = BlockwiseTensorSliceTransfer_v5r1<
346 BlockSize,
348 Sequence<K0PerBlock, 1, NPerBlock, K1.value>,
349 BBlockTransferThreadSliceLengths_K0_N0_N1_K1,
350 BBlockTransferThreadClusterLengths_K0_N0_N1_K1,
351 BBlockTransferThreadClusterArrangeOrder,
352 FloatAB,
353 FloatAB,
354 remove_reference_t<decltype(b_grid_desc_k0_n0_n1_k1)>,
355 decltype(b_block_desc_k0_n0_n1_k1),
356 BBlockTransferSrcAccessOrder,
358 BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1, // SrcVectorTensorLengths
359 BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1, // DstVectorTensorLengths
360 BBlockTransferSrcVectorTensorContiguousDimOrder, // SrcVectorTensorContiguousDimOrder
361 Sequence<0, 1, 2, 3>, // DstVectorTensorContiguousDimOrder
362 false,
363 true>(b_grid_desc_k0_n0_n1_k1,
364 make_multi_index(0, in0, 0, 0),
365 b_block_desc_k0_n0_n1_k1,
366 make_multi_index(0, 0, 0, 0));
367
368 // GEMM definition
369 // c_mtx += transpose(a_mtx) * b_mtx
370 // a_mtx[K0PerBlock, MPerBlock] is in LDS
371 // b_mtx[K0PerBlock, NPerBlock] is in LDS
372 // c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in
373 // register
374 const auto blockwise_gemm =
376 BlockSize,
377 FloatAB,
378 FloatAB,
379 FloatAcc,
380 decltype(a_k0_m_k1_block_desc),
381 decltype(b_k0_n_k1_block_desc),
382 M1PerThreadM111,
383 N1PerThreadN111,
384 KPerThread,
385 M11N11ThreadClusterM110Xs,
386 M11N11ThreadClusterN110Xs,
387 M1PerThreadM111,
388 N1PerThreadN111>{};
389
390 constexpr auto c_m10_m11_n10_n11_thread_tensor_lengths =
391 decltype(blockwise_gemm)::GetCThreadTensorLengths_BM0_BM1_BN0_BN1();
392
393 constexpr auto c_thread_desc_m10_m11_n10_n11 = make_naive_tensor_descriptor_packed(
394 sequence_to_tuple_of_number(c_m10_m11_n10_n11_thread_tensor_lengths));
395
396 // LDS allocation for A and B: be careful of alignment
397 constexpr auto a_block_aligned_space_size = math::integer_least_multiple(
398 a_block_desc_k0_m0_m1_k1.GetElementSpaceSize(), max_lds_align);
399
400 constexpr auto b_block_aligned_space_size = math::integer_least_multiple(
401 b_block_desc_k0_n0_n1_k1.GetElementSpaceSize(), max_lds_align);
402
// Layout of p_shared_block: [A even | A odd | B even | B odd]; the B region
// starts after two aligned A tiles.
403 FloatAB* p_a_block_double = p_shared_block;
404 FloatAB* p_b_block_double = p_shared_block + 2 * a_block_aligned_space_size;
405
406 // register allocation for output
408 c_thread_desc_m10_m11_n10_n11.GetElementSpaceSize());
409
410 // Initialize C
411 c_thread_buf.Clear();
412
// Each window move advances the source slice by K0PerBlock along dimension 0.
413 constexpr auto a_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0, 0);
414 constexpr auto b_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0, 0);
415
416 auto a_block_even_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
417 p_a_block_double, a_block_desc_k0_m0_m1_k1.GetElementSpaceSize());
418 auto b_block_even_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
419 p_b_block_double, b_block_desc_k0_n0_n1_k1.GetElementSpaceSize());
420
421 auto a_block_odd_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
422 p_a_block_double + a_block_aligned_space_size,
423 a_block_desc_k0_m0_m1_k1.GetElementSpaceSize());
424 auto b_block_odd_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
425 p_b_block_double + b_block_aligned_space_size,
426 b_block_desc_k0_n0_n1_k1.GetElementSpaceSize());
427
428 // LDS double buffer: preload data into LDS
429 {
430 a_blockwise_copy.RunRead(a_grid_desc_k0_m0_m1_k1, a_global_buf);
431 b_blockwise_copy.RunRead(b_grid_desc_k0_n0_n1_k1, b_global_buf);
432
433 a_blockwise_copy.RunWrite(a_block_desc_k0_m0_m1_k1, a_block_even_buf);
434 b_blockwise_copy.RunWrite(b_block_desc_k0_n0_n1_k1, b_block_even_buf);
435 }
436
437 if constexpr(HasMainKBlockLoop)
438 {
439 const auto K0 = a_grid_desc_k0_m0_m1_k1.GetLength(I0);
440
441 index_t k_block_data_begin = 0;
442
443 // LDS double buffer: main body
444 // use Do-While loop instead of For loop to simplify control flow
445 do
446 {
447 // even iteration
448 a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_k0_m0_m1_k1,
449 a_block_slice_copy_step);
450 b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_n0_n1_k1,
451 b_block_slice_copy_step);
452
453 // LDS double buffer: load next data from device mem
454 a_blockwise_copy.RunRead(a_grid_desc_k0_m0_m1_k1, a_global_buf);
455 b_blockwise_copy.RunRead(b_grid_desc_k0_n0_n1_k1, b_global_buf);
456
// NOTE(review): source line 457 is missing from this listing; a barrier is
// expected here before the even buffers are read -- confirm in the real file.
458
459 // LDS double buffer: GEMM on current data
460 blockwise_gemm.Run(c_thread_desc_m10_m11_n10_n11,
461 a_block_even_buf,
462 b_block_even_buf,
463 c_thread_buf);
464
465 // LDS double buffer: store next data to LDS
466 a_blockwise_copy.RunWrite(a_block_desc_k0_m0_m1_k1, a_block_odd_buf);
467 b_blockwise_copy.RunWrite(b_block_desc_k0_n0_n1_k1, b_block_odd_buf);
468
469 // odd iteration
470 a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_k0_m0_m1_k1,
471 a_block_slice_copy_step);
472 b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_n0_n1_k1,
473 b_block_slice_copy_step);
474
475 // LDS double buffer: load next data from device mem
476 a_blockwise_copy.RunRead(a_grid_desc_k0_m0_m1_k1, a_global_buf);
477 b_blockwise_copy.RunRead(b_grid_desc_k0_n0_n1_k1, b_global_buf);
478
// NOTE(review): source line 479 is missing from this listing; a barrier is
// expected here before the odd buffers are read -- confirm in the real file.
480
481 // LDS double buffer: GEMM on current data
482 blockwise_gemm.Run(
483 c_thread_desc_m10_m11_n10_n11, a_block_odd_buf, b_block_odd_buf, c_thread_buf);
484
485 // LDS double buffer: store next data to LDS
486 a_blockwise_copy.RunWrite(a_block_desc_k0_m0_m1_k1, a_block_even_buf);
487 b_blockwise_copy.RunWrite(b_block_desc_k0_n0_n1_k1, b_block_even_buf);
488
// Two K0PerBlock slices were consumed this iteration; the loop stops while
// the tail (one or two slices) is still left to process.
489 k_block_data_begin += 2 * K0PerBlock;
490 } while(k_block_data_begin < K0 - 2 * K0PerBlock);
491 }
492
493 // LDS double buffer: tail
494 if constexpr(HasDoubleTailKBlockLoop) // if has 2 iteration left
495 {
496 a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_k0_m0_m1_k1, a_block_slice_copy_step);
497 b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_n0_n1_k1, b_block_slice_copy_step);
498
// NOTE(review): source line 499 is missing from this listing (likely a
// barrier) -- confirm in the real file.
500
501 // LDS double buffer: load last data from device mem
502 a_blockwise_copy.RunRead(a_grid_desc_k0_m0_m1_k1, a_global_buf);
503 b_blockwise_copy.RunRead(b_grid_desc_k0_n0_n1_k1, b_global_buf);
504
505 // LDS double buffer: GEMM on 2nd-last data
506 blockwise_gemm.Run(
507 c_thread_desc_m10_m11_n10_n11, a_block_even_buf, b_block_even_buf, c_thread_buf);
508
509 // LDS double buffer: store last data to LDS
510 a_blockwise_copy.RunWrite(a_block_desc_k0_m0_m1_k1, a_block_odd_buf);
511 b_blockwise_copy.RunWrite(b_block_desc_k0_n0_n1_k1, b_block_odd_buf);
512
// NOTE(review): source line 513 is missing from this listing (likely a
// barrier between the LDS store above and the read below) -- confirm.
514
515 // LDS double buffer: GEMM on last data
516 blockwise_gemm.Run(
517 c_thread_desc_m10_m11_n10_n11, a_block_odd_buf, b_block_odd_buf, c_thread_buf);
518 }
519 else // if has 1 iteration left
520 {
// Make the preceding LDS writes to the even buffers visible to the whole
// block before they are read.
521 __syncthreads();
522
523 // LDS double buffer: GEMM on last data
524 blockwise_gemm.Run(
525 c_thread_desc_m10_m11_n10_n11, a_block_even_buf, b_block_even_buf, c_thread_buf);
526 }
527
528 // output: register to global memory
529 {
// NOTE(review): several lines of this output section are missing from the
// listing (descriptor construction, the thread-origin call arguments and the
// transfer type name); only the visible fragments are kept below.
530 constexpr auto c_thread_desc_m0_m10_m11_n0_n10_n11 =
533 Number<c_m10_m11_n10_n11_thread_tensor_lengths[I0]>{},
535 I1,
538
539 const auto c_m10_m11_n10_n11_thread_origin_idx_on_block =
540 blockwise_gemm.CalculateCThreadOriginOnBlock_BM0_BM1_BN0_BN1(
542
544 FloatAcc,
545 FloatC,
546 decltype(c_thread_desc_m0_m10_m11_n0_n10_n11),
547 decltype(c_grid_desc_m0_m10_m11_n0_n10_n11),
549 Sequence<1,
550 c_m10_m11_n10_n11_thread_tensor_lengths[I0],
551 c_m10_m11_n10_n11_thread_tensor_lengths[I1],
552 1,
553 c_m10_m11_n10_n11_thread_tensor_lengths[I2],
554 c_m10_m11_n10_n11_thread_tensor_lengths[I3]>,
555 CThreadTransferSrcDstAccessOrder,
556 CThreadTransferSrcDstVectorDim,
557 CThreadTransferDstScalarPerVector,
558 CGlobalMemoryDataOperation,
559 1,
560 true>{c_grid_desc_m0_m10_m11_n0_n10_n11,
562 c_m10_m11_n10_n11_thread_origin_idx_on_block[I0],
563 c_m10_m11_n10_n11_thread_origin_idx_on_block[I1],
564 in0,
565 c_m10_m11_n10_n11_thread_origin_idx_on_block[I2],
566 c_m10_m11_n10_n11_thread_origin_idx_on_block[I3]),
568 .Run(c_thread_desc_m0_m10_m11_n0_n10_n11,
569 make_tuple(I0, I0, I0, I0, I0, I0),
570 c_thread_buf,
571 c_grid_desc_m0_m10_m11_n0_n10_n11,
572 c_grid_buf);
573 }
574 }
575};
576
577template <index_t BlockSize,
578 typename FloatAB,
579 typename FloatAcc,
580 typename FloatC,
581 InMemoryDataOperationEnum CGlobalMemoryDataOperation,
582 typename AGridDesc_B_K0_M_K1,
583 typename BGridDesc_B_K0_N_K1,
584 typename CGridDesc_M_N,
585 index_t MPerBlock,
586 index_t NPerBlock,
587 index_t K0PerBlock,
588 index_t K1Value,
589 index_t M1PerThreadM111,
590 index_t N1PerThreadN111,
591 index_t KPerThread,
592 typename M11N11ThreadClusterM110Xs,
593 typename M11N11ThreadClusterN110Xs,
594 typename ABlockTransferThreadSliceLengths_K0_M0_M1_K1,
595 typename ABlockTransferThreadClusterLengths_K0_M0_M1_K1,
596 typename ABlockTransferThreadClusterArrangeOrder,
597 typename ABlockTransferSrcAccessOrder,
598 typename ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1,
599 typename ABlockTransferSrcVectorTensorContiguousDimOrder,
600 typename ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1,
601 typename BBlockTransferThreadSliceLengths_K0_N0_N1_K1,
602 typename BBlockTransferThreadClusterLengths_K0_N0_N1_K1,
603 typename BBlockTransferThreadClusterArrangeOrder,
604 typename BBlockTransferSrcAccessOrder,
605 typename BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1,
606 typename BBlockTransferSrcVectorTensorContiguousDimOrder,
607 typename BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1,
608 typename CThreadTransferSrcDstAccessOrder,
609 index_t CThreadTransferSrcDstVectorDim,
610 index_t CThreadTransferDstScalarPerVector>
612{
// Compile-time index constants; the methods below pass them to GetLength(...)
// and use them to index tuples/sequences.
613 static constexpr auto I0 = Number<0>{};
614 static constexpr auto I1 = Number<1>{};
615 static constexpr auto I2 = Number<2>{};
616 static constexpr auto I3 = Number<3>{};
617
618 // K1 should be Number<...>
// Number<> wrapper around the K1Value template parameter. It is used below as
// max_lds_align and compared against the K1 length of the grid descriptors.
619 static constexpr auto K1 = Number<K1Value>{};
620
// Returns the number of bytes of shared memory (LDS) this GEMM needs: twice
// the sum of the aligned A-tile and B-tile element counts, times
// sizeof(FloatAB).
// NOTE(review): source lines 629 and 634 (the arguments of the two descriptor
// constructions) are missing from this listing -- check the real file.
621 __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
622 {
623 // TODO: change this. I think it needs multi-dimensional alignment
624 constexpr auto max_lds_align = K1;
625
626 // TODO: check alignment
627 // A matrix in LDS memory, dst of blockwise copy
628 constexpr auto a_block_desc_b_k0_m_k1 = make_naive_tensor_descriptor_aligned(
630
631 // TODO: check alignment
632 // B matrix in LDS memory, dst of blockwise copy
633 constexpr auto b_block_desc_b_k0_n_k1 = make_naive_tensor_descriptor_aligned(
635
636 // TODO: check alignment
637 // LDS allocation for A and B: be careful of alignment
638 constexpr auto a_block_aligned_space_size = math::integer_least_multiple(
639 a_block_desc_b_k0_m_k1.GetElementSpaceSize(), max_lds_align);
640
641 constexpr auto b_block_aligned_space_size = math::integer_least_multiple(
642 b_block_desc_b_k0_n_k1.GetElementSpaceSize(), max_lds_align);
643
// Factor 2: one copy of each tile per half of the double buffer.
644 return 2 * (a_block_aligned_space_size + b_block_aligned_space_size) * sizeof(FloatAB);
645 }
646
647 __host__ __device__ static constexpr bool
648 CheckValidity(const AGridDesc_B_K0_M_K1& a_grid_desc_b_k0_m_k1,
649 const BGridDesc_B_K0_N_K1& b_grid_desc_b_k0_n_k1,
650 const CGridDesc_M_N& c_grid_desc_m_n)
651 {
652 constexpr long_index_t TwoGB = (long_index_t{1} << 31);
653
654 if(!(a_grid_desc_b_k0_m_k1.GetElementSpaceSize() * sizeof(FloatAB) <= TwoGB &&
655 b_grid_desc_b_k0_n_k1.GetElementSpaceSize() * sizeof(FloatAB) <= TwoGB &&
656 c_grid_desc_m_n.GetElementSpaceSize() * sizeof(FloatC) <= TwoGB))
657 {
658 return false;
659 }
660
661 const auto M = a_grid_desc_b_k0_m_k1.GetLength(I2);
662 const auto N = b_grid_desc_b_k0_n_k1.GetLength(I2);
663 const auto K0 = a_grid_desc_b_k0_m_k1.GetLength(I1);
664 const auto KBatch = a_grid_desc_b_k0_m_k1.GetLength(I0);
665
666 // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
667
668 return (M == c_grid_desc_m_n.GetLength(I0) && N == c_grid_desc_m_n.GetLength(I1) &&
669 K0 == b_grid_desc_b_k0_n_k1.GetLength(I1) &&
670 K1 == a_grid_desc_b_k0_m_k1.GetLength(I3) &&
671 K1 == b_grid_desc_b_k0_n_k1.GetLength(I3)) &&
672 KBatch == b_grid_desc_b_k0_n_k1.GetLength(I0) &&
673 (M % MPerBlock == 0 && N % NPerBlock == 0 && K0 % K0PerBlock == 0);
674 }
675
676 __host__ __device__ static constexpr index_t CalculateGridSize(index_t M, index_t N)
677 {
678 const index_t grid_size = (M / MPerBlock) * (N / NPerBlock);
679
680 return grid_size;
681 }
682
683 __host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t K0)
684 {
685 const bool has_main_k_block_loop = (K0 + K0PerBlock) / (2 * K0PerBlock) > 1;
686
687 return has_main_k_block_loop;
688 }
689
690 __host__ __device__ static constexpr bool CalculateHasDoubleTailKBlockLoop(index_t K0)
691 {
692 const bool has_double_tail_k_block_loop = (K0 / K0PerBlock) % 2 == 0;
693
694 return has_double_tail_k_block_loop;
695 }
696
// Derives the batched A grid descriptor consumed by Run from the caller's
// [KBatch, K0, M, K1] descriptor by calling transform_tensor_descriptor on it.
// NOTE(review): source lines 709-714 (the transform arguments) are missing
// from this listing; the name suggests M is split into (M0, M1) with
// M1 = MPerBlock -- confirm against the real file.
697 __host__ __device__ static constexpr auto
698 MakeAGridDescriptor_B_K0_M0_M1_K1(const AGridDesc_B_K0_M_K1& a_grid_desc_b_k0_m_k1)
699 {
700 const auto KBatch = a_grid_desc_b_k0_m_k1.GetLength(I0);
701 const auto K0 = a_grid_desc_b_k0_m_k1.GetLength(I1);
702 const auto M = a_grid_desc_b_k0_m_k1.GetLength(I2);
703
// M1 is the per-block tile length; M0 is how many such tiles fit in M.
704 const auto M1 = Number<MPerBlock>{};
705 const auto M0 = M / M1;
706
707 const auto a_grid_desc_b_k0_m0_m1_k1 = transform_tensor_descriptor(
708 a_grid_desc_b_k0_m_k1,
715
716 return a_grid_desc_b_k0_m0_m1_k1;
717 }
718
// Derives the batched B grid descriptor consumed by Run from the caller's
// [KBatch, K0, N, K1] descriptor by calling transform_tensor_descriptor on it.
// NOTE(review): source lines 731-736 (the transform arguments) are missing
// from this listing; the name suggests N is split into (N0, N1) with
// N1 = NPerBlock -- confirm against the real file.
719 __host__ __device__ static constexpr auto
720 MakeBGridDescriptor_B_K0_N0_N1_K1(const BGridDesc_B_K0_N_K1& b_grid_desc_b_k0_n_k1)
721 {
722 const auto KBatch = b_grid_desc_b_k0_n_k1.GetLength(I0);
723 const auto K0 = b_grid_desc_b_k0_n_k1.GetLength(I1);
724 const auto N = b_grid_desc_b_k0_n_k1.GetLength(I2);
725
// N1 is the per-block tile length; N0 is how many such tiles fit in N.
726 const auto N1 = Number<NPerBlock>{};
727 const auto N0 = N / N1;
728
729 const auto b_grid_desc_b_k0_n0_n1_k1 = transform_tensor_descriptor(
730 b_grid_desc_b_k0_n_k1,
737
738 return b_grid_desc_b_k0_n0_n1_k1;
739 }
740
// Derives the six-dimensional C grid descriptor consumed by Run from the
// caller's [M, N] descriptor via transform_tensor_descriptor.
// NOTE(review): source lines 765 and 767-768 are missing from this listing;
// only the N-side unmerge into (N0, N10, N11) is visible. The M side
// presumably mirrors it -- confirm against the real file.
741 __host__ __device__ static constexpr auto
742 MakeCGridDescriptor_M0_M10_M11_N0_N10_N11(const CGridDesc_M_N& c_grid_desc_m_n)
743 {
744 const auto M = c_grid_desc_m_n.GetLength(I0);
745 const auto N = c_grid_desc_m_n.GetLength(I1);
746
// Per-block tile lengths.
747 constexpr auto M1 = Number<MPerBlock>{};
748 constexpr auto N1 = Number<NPerBlock>{};
749
// Number of block tiles along each dimension.
750 const auto M0 = M / M1;
751 const auto N0 = N / N1;
752
// M11 / N11: the thread-cluster sequence reduced with math::multiplies
// (starting from I1), times the per-thread length.
753 constexpr auto M11 =
754 Number<container_reduce(M11N11ThreadClusterM110Xs{}, math::multiplies{}, I1) *
755 M1PerThreadM111>{};
756 constexpr auto N11 =
757 Number<container_reduce(M11N11ThreadClusterN110Xs{}, math::multiplies{}, I1) *
758 N1PerThreadN111>{};
759
// M10 / N10: how many M11 / N11 chunks make up one block tile.
760 constexpr auto M10 = M1 / M11;
761 constexpr auto N10 = N1 / N11;
762
763 const auto c_grid_desc_m0_m10_m11_n0_n10_n11 = transform_tensor_descriptor(
764 c_grid_desc_m_n,
766 make_unmerge_transform(make_tuple(N0, N10, N11))),
769
770 return c_grid_desc_m0_m10_m11_n0_n10_n11;
771 }
772
773 // return block_id to C matrix tile idx (m0, n0) mapping
// NOTE(review): source line 777 (the start of the return expression) is
// missing from this listing, so the concrete adaptor type built from
// (c_m_n_grid_desc, M01, N01, KBatch) is not visible here.
774 __host__ __device__ static constexpr auto MakeCBlockClusterAdaptor(
775 const CGridDesc_M_N& c_m_n_grid_desc, index_t M01, index_t N01, index_t KBatch)
776 {
778 c_m_n_grid_desc, M01, N01, KBatch);
779 }
780
// Types of the derived descriptors / block-cluster adaptor, obtained by
// applying decltype to the maker functions above with default-constructed
// inputs. Run takes its descriptor arguments by these types.
// NOTE(review): source lines 781, 783 and 785 (the left-hand sides of the
// first three aliases) are missing from this listing.
782 decltype(MakeAGridDescriptor_B_K0_M0_M1_K1(AGridDesc_B_K0_M_K1{}));
784 decltype(MakeBGridDescriptor_B_K0_N0_N1_K1(BGridDesc_B_K0_N_K1{}));
786 decltype(MakeCGridDescriptor_M0_M10_M11_N0_N10_N11(CGridDesc_M_N{}));
787 using CBlockClusterAdaptor = decltype(MakeCBlockClusterAdaptor(CGridDesc_M_N{}, 1, 1, 1));
788
789 template <bool HasMainKBlockLoop, bool HasDoubleTailKBlockLoop>
790 __device__ static void
791 Run(const FloatAB* __restrict__ p_a_grid,
792 const FloatAB* __restrict__ p_b_grid,
793 FloatC* __restrict__ p_c_grid,
794 FloatAB* __restrict__ p_shared_block,
795 const AGridDesc_B_K0_M0_M1_K1& a_grid_desc_b_k0_m0_m1_k1,
796 const BGridDesc_B_K0_N0_N1_K1& b_grid_desc_b_k0_n0_n1_k1,
797 const CGridDesc_M0_M10_M11_N0_N10_N11& c_grid_desc_m0_m10_m11_n0_n10_n11,
798 const CBlockClusterAdaptor& c_block_cluster_adaptor,
801 {
802 const auto a_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
803 p_a_grid, a_grid_desc_b_k0_m0_m1_k1.GetElementSpaceSize());
804 const auto b_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
805 p_b_grid, b_grid_desc_b_k0_n0_n1_k1.GetElementSpaceSize());
807 p_c_grid, c_grid_desc_m0_m10_m11_n0_n10_n11.GetElementSpaceSize());
808
809 // divide block work by [M, N]
810 const auto block_work_idx =
811 c_block_cluster_adaptor.CalculateBottomIndex(make_multi_index(get_block_1d_id()));
812
813 const index_t k_batch_id = block_work_idx[I0];
814
815 if(!c_block_cluster_adaptor.ValidCTileIndex(
816 make_tuple(block_work_idx[I1], block_work_idx[I2]),
817 make_tuple(c_grid_desc_m0_m10_m11_n0_n10_n11.GetLength(I0),
818 c_grid_desc_m0_m10_m11_n0_n10_n11.GetLength(I3))))
819 {
820 return;
821 }
822
823 // HACK: this force m/n_block_data_idx_on_grid into SGPR
824 const index_t m_block_data_idx_on_grid = __builtin_amdgcn_readfirstlane(block_work_idx[I1]);
825
826 const index_t n_block_data_idx_on_grid = __builtin_amdgcn_readfirstlane(block_work_idx[I2]);
827
828 // TODO: change this. I think it needs multi-dimensional alignment
829 constexpr auto max_lds_align = K1;
830
831 // TODO: check alignment
832 // A matrix in LDS memory, dst of blockwise copy
833 // be careful of LDS alignment
834 constexpr auto a_block_desc_b_k0_m0_m1_k1 = make_naive_tensor_descriptor_aligned(
835 make_tuple(I1, Number<K0PerBlock>{}, I1, Number<MPerBlock>{}, K1), max_lds_align);
836
837 // TODO: check alignment
838 // B matrix in LDS memory, dst of blockwise copy
839 // be careful of LDS alignment
840 constexpr auto b_block_desc_b_k0_n0_n1_k1 = make_naive_tensor_descriptor_aligned(
841 make_tuple(I1, Number<K0PerBlock>{}, I1, Number<NPerBlock>{}, K1), max_lds_align);
842
843 // TODO: check alignment
844 // A matrix in LDS memory, dst of blockwise copy
845 // be careful of LDS alignment
846 constexpr auto a_block_desc_k0_m0_m1_k1 = make_naive_tensor_descriptor_aligned(
847 make_tuple(Number<K0PerBlock>{}, I1, Number<MPerBlock>{}, K1), max_lds_align);
848
849 // TODO: check alignment
850 // B matrix in LDS memory, dst of blockwise copy
851 // be careful of LDS alignment
852 constexpr auto b_block_desc_k0_n0_n1_k1 = make_naive_tensor_descriptor_aligned(
853 make_tuple(Number<K0PerBlock>{}, I1, Number<NPerBlock>{}, K1), max_lds_align);
854
855 // TODO: check alignment
856 // A matrix in LDS memory, for blockwise GEMM
857 constexpr auto a_k0_m_k1_block_desc = make_naive_tensor_descriptor_aligned(
858 make_tuple(Number<K0PerBlock>{}, Number<MPerBlock>{}, K1), max_lds_align);
859
860 // TODO: check alignment
861 // B matrix in LDS memory, for blockwise GEMM
862 constexpr auto b_k0_n_k1_block_desc = make_naive_tensor_descriptor_aligned(
863 make_tuple(Number<K0PerBlock>{}, Number<NPerBlock>{}, K1), max_lds_align);
864
865 static_assert(a_block_desc_k0_m0_m1_k1.GetElementSpaceSize() ==
866 a_k0_m_k1_block_desc.GetElementSpaceSize() &&
867 b_block_desc_k0_n0_n1_k1.GetElementSpaceSize() ==
868 b_k0_n_k1_block_desc.GetElementSpaceSize() &&
869 "wrong!");
870
871 // A matrix blockwise copy
872 auto a_blockwise_copy = BlockwiseTensorSliceTransfer_v5r1<
873 BlockSize,
875 Sequence<1, K0PerBlock, 1, MPerBlock, K1.value>,
876 ABlockTransferThreadSliceLengths_K0_M0_M1_K1,
877 ABlockTransferThreadClusterLengths_K0_M0_M1_K1,
878 ABlockTransferThreadClusterArrangeOrder,
879 FloatAB,
880 FloatAB,
881 remove_reference_t<decltype(a_grid_desc_b_k0_m0_m1_k1)>,
882 decltype(a_block_desc_b_k0_m0_m1_k1),
883 ABlockTransferSrcAccessOrder,
885 ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1, // SrcVectorTensorLengths
886 ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1, // DstVectorTensorLengths
887 ABlockTransferSrcVectorTensorContiguousDimOrder, // SrcVectorTensorContiguousDimOrder
888 Sequence<0, 1, 2, 3, 4>, // DstVectorTensorContiguousDimOrder
889 false,
890 true>(a_grid_desc_b_k0_m0_m1_k1,
891 make_multi_index(k_batch_id, 0, m_block_data_idx_on_grid, 0, 0),
892 a_block_desc_b_k0_m0_m1_k1,
893 make_multi_index(0, 0, 0, 0, 0));
894
895 // B matrix blockwise copy
896 auto b_blockwise_copy = BlockwiseTensorSliceTransfer_v5r1<
897 BlockSize,
899 Sequence<1, K0PerBlock, 1, NPerBlock, K1.value>,
900 BBlockTransferThreadSliceLengths_K0_N0_N1_K1,
901 BBlockTransferThreadClusterLengths_K0_N0_N1_K1,
902 BBlockTransferThreadClusterArrangeOrder,
903 FloatAB,
904 FloatAB,
905 remove_reference_t<decltype(b_grid_desc_b_k0_n0_n1_k1)>,
906 decltype(b_block_desc_b_k0_n0_n1_k1),
907 BBlockTransferSrcAccessOrder,
909 BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1, // SrcVectorTensorLengths
910 BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1, // DstVectorTensorLengths
911 BBlockTransferSrcVectorTensorContiguousDimOrder, // SrcVectorTensorContiguousDimOrder
912 Sequence<0, 1, 2, 3, 4>, // DstVectorTensorContiguousDimOrder
913 false,
914 true>(b_grid_desc_b_k0_n0_n1_k1,
915 make_multi_index(k_batch_id, 0, n_block_data_idx_on_grid, 0, 0),
916 b_block_desc_b_k0_n0_n1_k1,
917 make_multi_index(0, 0, 0, 0, 0));
918
919 // GEMM definition
920 // c_mtx += transpose(a_mtx) * b_mtx
921 // a_mtx[K0PerBlock, MPerBlock] is in LDS
922 // b_mtx[K0PerBlock, NPerBlock] is in LDS
923 // c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in
924 // register
925 const auto blockwise_gemm =
927 BlockSize,
928 FloatAB,
929 FloatAB,
930 FloatAcc,
931 decltype(a_k0_m_k1_block_desc),
932 decltype(b_k0_n_k1_block_desc),
933 M1PerThreadM111,
934 N1PerThreadN111,
935 KPerThread,
936 M11N11ThreadClusterM110Xs,
937 M11N11ThreadClusterN110Xs,
938 M1PerThreadM111,
939 N1PerThreadN111>{};
940
941 constexpr auto c_m10_m11_n10_n11_thread_tensor_lengths =
942 decltype(blockwise_gemm)::GetCThreadTensorLengths_BM0_BM1_BN0_BN1();
943
944 constexpr auto c_thread_desc_m10_m11_n10_n11 = make_naive_tensor_descriptor_packed(
945 sequence_to_tuple_of_number(c_m10_m11_n10_n11_thread_tensor_lengths));
946
947 // LDS allocation for A and B: be careful of alignment
948 constexpr auto a_block_aligned_space_size = math::integer_least_multiple(
949 a_block_desc_k0_m0_m1_k1.GetElementSpaceSize(), max_lds_align);
950
951 constexpr auto b_block_aligned_space_size = math::integer_least_multiple(
952 b_block_desc_k0_n0_n1_k1.GetElementSpaceSize(), max_lds_align);
953
954 FloatAB* p_a_block_double = p_shared_block;
955 FloatAB* p_b_block_double = p_shared_block + 2 * a_block_aligned_space_size;
956
957 // register allocation for output
959 c_thread_desc_m10_m11_n10_n11.GetElementSpaceSize());
960
961 // Initialize C
962 c_thread_buf.Clear();
963
964 constexpr auto a_block_slice_copy_step = make_multi_index(0, K0PerBlock, 0, 0, 0);
965 constexpr auto b_block_slice_copy_step = make_multi_index(0, K0PerBlock, 0, 0, 0);
966
967 auto a_block_even_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
968 p_a_block_double, a_block_desc_k0_m0_m1_k1.GetElementSpaceSize());
969 auto b_block_even_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
970 p_b_block_double, b_block_desc_k0_n0_n1_k1.GetElementSpaceSize());
971
972 auto a_block_odd_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
973 p_a_block_double + a_block_aligned_space_size,
974 a_block_desc_k0_m0_m1_k1.GetElementSpaceSize());
975 auto b_block_odd_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
976 p_b_block_double + b_block_aligned_space_size,
977 b_block_desc_k0_n0_n1_k1.GetElementSpaceSize());
978
979 // LDS double buffer: preload data into LDS
980 {
981 a_blockwise_copy.RunRead(a_grid_desc_b_k0_m0_m1_k1, a_global_buf);
982 b_blockwise_copy.RunRead(b_grid_desc_b_k0_n0_n1_k1, b_global_buf);
983
984 a_blockwise_copy.RunWrite(a_block_desc_b_k0_m0_m1_k1, a_block_even_buf);
985 b_blockwise_copy.RunWrite(b_block_desc_b_k0_n0_n1_k1, b_block_even_buf);
986 }
987
988 if constexpr(HasMainKBlockLoop)
989 {
990 const auto K0 = a_grid_desc_b_k0_m0_m1_k1.GetLength(I1);
991
992 index_t k_block_data_begin = 0;
993
994 // LDS double buffer: main body
995 // use Do-While loop instead of For loop to simplify control flow
996 do
997 {
998 // even iteration
999 a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_b_k0_m0_m1_k1,
1000 a_block_slice_copy_step);
1001 b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc_b_k0_n0_n1_k1,
1002 b_block_slice_copy_step);
1003
1004 // LDS double buffer: load next data from device mem
1005 a_blockwise_copy.RunRead(a_grid_desc_b_k0_m0_m1_k1, a_global_buf);
1006 b_blockwise_copy.RunRead(b_grid_desc_b_k0_n0_n1_k1, b_global_buf);
1007
1009
1010 // LDS double buffer: GEMM on current data
1011 blockwise_gemm.Run(c_thread_desc_m10_m11_n10_n11,
1012 a_block_even_buf,
1013 b_block_even_buf,
1014 c_thread_buf);
1015
1016 // LDS double buffer: store next data to LDS
1017 a_blockwise_copy.RunWrite(a_block_desc_b_k0_m0_m1_k1, a_block_odd_buf);
1018 b_blockwise_copy.RunWrite(b_block_desc_b_k0_n0_n1_k1, b_block_odd_buf);
1019
1020 // odd iteration
1021 a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_b_k0_m0_m1_k1,
1022 a_block_slice_copy_step);
1023 b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc_b_k0_n0_n1_k1,
1024 b_block_slice_copy_step);
1025
1026 // LDS double buffer: load next data from device mem
1027 a_blockwise_copy.RunRead(a_grid_desc_b_k0_m0_m1_k1, a_global_buf);
1028 b_blockwise_copy.RunRead(b_grid_desc_b_k0_n0_n1_k1, b_global_buf);
1029
1031
1032 // LDS double buffer: GEMM on current data
1033 blockwise_gemm.Run(
1034 c_thread_desc_m10_m11_n10_n11, a_block_odd_buf, b_block_odd_buf, c_thread_buf);
1035
1036 // LDS double buffer: store next data to LDS
1037 a_blockwise_copy.RunWrite(a_block_desc_b_k0_m0_m1_k1, a_block_even_buf);
1038 b_blockwise_copy.RunWrite(b_block_desc_b_k0_n0_n1_k1, b_block_even_buf);
1039
1040 k_block_data_begin += 2 * K0PerBlock;
1041 } while(k_block_data_begin < K0 - 2 * K0PerBlock);
1042 }
1043
1044 // LDS double buffer: tail
1045 if constexpr(HasDoubleTailKBlockLoop) // if has 2 iteration left
1046 {
1047 a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_b_k0_m0_m1_k1, a_block_slice_copy_step);
1048 b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc_b_k0_n0_n1_k1, b_block_slice_copy_step);
1049
1051
1052 // LDS double buffer: load last data from device mem
1053 a_blockwise_copy.RunRead(a_grid_desc_b_k0_m0_m1_k1, a_global_buf);
1054 b_blockwise_copy.RunRead(b_grid_desc_b_k0_n0_n1_k1, b_global_buf);
1055
1056 // LDS double buffer: GEMM on 2nd-last data
1057 blockwise_gemm.Run(
1058 c_thread_desc_m10_m11_n10_n11, a_block_even_buf, b_block_even_buf, c_thread_buf);
1059
1060 // LDS double buffer: store last data to LDS
1061 a_blockwise_copy.RunWrite(a_block_desc_b_k0_m0_m1_k1, a_block_odd_buf);
1062 b_blockwise_copy.RunWrite(b_block_desc_b_k0_n0_n1_k1, b_block_odd_buf);
1063
1065
1066 // LDS double buffer: GEMM on last data
1067 blockwise_gemm.Run(
1068 c_thread_desc_m10_m11_n10_n11, a_block_odd_buf, b_block_odd_buf, c_thread_buf);
1069 }
1070 else // if has 1 iteration left
1071 {
1072 __syncthreads();
1073
1074 // LDS double buffer: GEMM on last data
1075 blockwise_gemm.Run(
1076 c_thread_desc_m10_m11_n10_n11, a_block_even_buf, b_block_even_buf, c_thread_buf);
1077 }
1078
1079 // output: register to global memory
1080 {
1081 constexpr auto c_thread_desc_m0_m10_m11_n0_n10_n11 =
1083 make_tuple(I1,
1084 Number<c_m10_m11_n10_n11_thread_tensor_lengths[I0]>{},
1086 I1,
1089
1090 const auto c_m10_m11_n10_n11_thread_origin_idx_on_block =
1091 blockwise_gemm.CalculateCThreadOriginOnBlock_BM0_BM1_BN0_BN1(
1093
1095 FloatAcc,
1096 FloatC,
1097 decltype(c_thread_desc_m0_m10_m11_n0_n10_n11),
1098 decltype(c_grid_desc_m0_m10_m11_n0_n10_n11),
1100 Sequence<1,
1101 c_m10_m11_n10_n11_thread_tensor_lengths[I0],
1102 c_m10_m11_n10_n11_thread_tensor_lengths[I1],
1103 1,
1104 c_m10_m11_n10_n11_thread_tensor_lengths[I2],
1105 c_m10_m11_n10_n11_thread_tensor_lengths[I3]>,
1106 CThreadTransferSrcDstAccessOrder,
1107 CThreadTransferSrcDstVectorDim,
1108 CThreadTransferDstScalarPerVector,
1109 CGlobalMemoryDataOperation,
1110 1,
1111 true>{c_grid_desc_m0_m10_m11_n0_n10_n11,
1112 make_multi_index(m_block_data_idx_on_grid,
1113 c_m10_m11_n10_n11_thread_origin_idx_on_block[I0],
1114 c_m10_m11_n10_n11_thread_origin_idx_on_block[I1],
1115 n_block_data_idx_on_grid,
1116 c_m10_m11_n10_n11_thread_origin_idx_on_block[I2],
1117 c_m10_m11_n10_n11_thread_origin_idx_on_block[I3]),
1119 .Run(c_thread_desc_m0_m10_m11_n0_n10_n11,
1120 make_tuple(I0, I0, I0, I0, I0, I0),
1121 c_thread_buf,
1122 c_grid_desc_m0_m10_m11_n0_n10_n11,
1123 c_grid_buf);
1124 }
1125 }
1126};
1127
1128} // namespace ck
#define CK_MIN_BLOCK_PER_CU
Definition ck.hpp:31
#define CK_MAX_THREAD_PER_BLOCK
Definition ck.hpp:30
__host__ __device__ constexpr auto integer_least_multiple(X x, Y y)
Definition utility/math.hpp:78
Definition ck.hpp:268
__host__ __device__ constexpr auto make_multi_index(Xs &&... xs)
Definition array_multi_index.hpp:15
__host__ __device__ constexpr auto make_static_buffer(Number< N >)
Definition static_buffer.hpp:186
__host__ __device__ constexpr auto make_pass_through_transform(const LowLength &low_length)
Definition multi_index_transform_helper.hpp:12
int32_t index_t
Definition ck.hpp:299
InMemoryDataOperationEnum
Definition ck.hpp:277
@ Set
Definition ck.hpp:278
integral_constant< index_t, N > Number
Definition number.hpp:12
__host__ __device__ constexpr auto sequence_to_tuple_of_number(Sequence< Is... >)
Definition utility/container_helper.hpp:380
__device__ index_t get_block_1d_id()
Definition get_id.hpp:47
__global__ void kernel_gemm_dl_v1r3(const FloatAB *__restrict__ p_a_grid, const FloatAB *__restrict__ p_b_grid, FloatC *__restrict__ p_c_grid, const AGridDesc_K0_M0_M1_K1 a_grid_desc_k0_m0_m1_k1, const BGridDesc_K0_N0_N1_K1 b_grid_desc_k0_n0_n1_k1, const CGridDesc_M0_M10_M11_N0_N10_N11 c_grid_desc_m0_m10_m11_n0_n10_n11, const Block2CTileMap block_2_ctile_map)
Definition gridwise_gemm_dl_v1r3.hpp:33
__host__ __device__ constexpr auto container_reduce(const Container &x, Reduce reduce, Init init, Number< IBegin >=Number< 0 >{}, Number< IEnd >=Number< Container::Size()>{}, Number< IStep >=Number< 1 >{})
Definition utility/container_helper.hpp:111
__host__ __device__ constexpr auto make_naive_tensor_descriptor_aligned(const Tuple< Lengths... > &lengths, Align align)
Definition tensor_descriptor_helper.hpp:132
__host__ __device__ constexpr auto make_naive_tensor_descriptor_packed(const Tuple< Lengths... > &lengths)
Definition tensor_descriptor_helper.hpp:101
typename remove_reference< T >::type remove_reference_t
Definition type.hpp:292
__host__ __device__ constexpr auto make_tuple(Xs &&... xs)
Definition utility/tuple.hpp:211
__device__ index_t get_thread_local_1d_id()
Definition get_id.hpp:41
int64_t long_index_t
Definition ck.hpp:300
__host__ __device__ constexpr auto transform_tensor_descriptor(const OldTensorDescriptor &old_tensor_desc, const NewTransforms &new_transforms, NewLowerDimensionOldVisibleIdss, NewUpperDimensionNewVisibleIdss)
Definition tensor_description/tensor_descriptor.hpp:319
__device__ void block_sync_lds()
Definition synchronization.hpp:16
__host__ __device__ constexpr auto make_unmerge_transform(const UpLengths &up_lengths, integral_constant< bool, Use24BitIntegerCalculation >=integral_constant< bool, false >{})
Definition multi_index_transform_helper.hpp:90
__host__ __device__ constexpr auto make_dynamic_buffer(T *p, ElementSpaceSize element_space_size)
Definition dynamic_buffer.hpp:472
Definition block_to_ctile_map.hpp:720
Definition block_to_ctile_map.hpp:617
Definition blockwise_tensor_slice_transfer_v5r1.hpp:37
Definition gridwise_gemm_dl_v1r3.hpp:612
__host__ static __device__ constexpr index_t CalculateGridSize(index_t M, index_t N)
Definition gridwise_gemm_dl_v1r3.hpp:676
__host__ static __device__ constexpr auto MakeCBlockClusterAdaptor(const CGridDesc_M_N &c_m_n_grid_desc, index_t M01, index_t N01, index_t KBatch)
Definition gridwise_gemm_dl_v1r3.hpp:774
__host__ static __device__ constexpr auto MakeAGridDescriptor_B_K0_M0_M1_K1(const AGridDesc_B_K0_M_K1 &a_grid_desc_b_k0_m_k1)
Definition gridwise_gemm_dl_v1r3.hpp:698
__host__ static __device__ constexpr bool CalculateHasMainKBlockLoop(index_t K0)
Definition gridwise_gemm_dl_v1r3.hpp:683
__host__ static __device__ constexpr index_t GetSharedMemoryNumberOfByte()
Definition gridwise_gemm_dl_v1r3.hpp:621
__host__ static __device__ constexpr auto MakeCGridDescriptor_M0_M10_M11_N0_N10_N11(const CGridDesc_M_N &c_grid_desc_m_n)
Definition gridwise_gemm_dl_v1r3.hpp:742
__host__ static __device__ constexpr bool CheckValidity(const AGridDesc_B_K0_M_K1 &a_grid_desc_b_k0_m_k1, const BGridDesc_B_K0_N_K1 &b_grid_desc_b_k0_n_k1, const CGridDesc_M_N &c_grid_desc_m_n)
Definition gridwise_gemm_dl_v1r3.hpp:648
__host__ static __device__ constexpr bool CalculateHasDoubleTailKBlockLoop(index_t K0)
Definition gridwise_gemm_dl_v1r3.hpp:690
__host__ static __device__ constexpr auto MakeBGridDescriptor_B_K0_N0_N1_K1(const BGridDesc_B_K0_N_K1 &b_grid_desc_b_k0_n_k1)
Definition gridwise_gemm_dl_v1r3.hpp:720
static __device__ void Run(const FloatAB *__restrict__ p_a_grid, const FloatAB *__restrict__ p_b_grid, FloatC *__restrict__ p_c_grid, FloatAB *__restrict__ p_shared_block, const AGridDesc_B_K0_M0_M1_K1 &a_grid_desc_b_k0_m0_m1_k1, const BGridDesc_B_K0_N0_N1_K1 &b_grid_desc_b_k0_n0_n1_k1, const CGridDesc_M0_M10_M11_N0_N10_N11 &c_grid_desc_m0_m10_m11_n0_n10_n11, const CBlockClusterAdaptor &c_block_cluster_adaptor, integral_constant< bool, HasMainKBlockLoop >, integral_constant< bool, HasDoubleTailKBlockLoop >)
Definition gridwise_gemm_dl_v1r3.hpp:791
Definition gridwise_gemm_dl_v1r3.hpp:93
static __device__ void Run(const FloatAB *__restrict__ p_a_grid, const FloatAB *__restrict__ p_b_grid, FloatC *__restrict__ p_c_grid, FloatAB *__restrict__ p_shared_block, const AGridDesc_K0_M0_M1_K1 &a_grid_desc_k0_m0_m1_k1, const BGridDesc_K0_N0_N1_K1 &b_grid_desc_k0_n0_n1_k1, const CGridDesc_M0_M10_M11_N0_N10_N11 &c_grid_desc_m0_m10_m11_n0_n10_n11, const Block2CTileMap &block_2_ctile_map, integral_constant< bool, HasMainKBlockLoop >, integral_constant< bool, HasDoubleTailKBlockLoop >)
Definition gridwise_gemm_dl_v1r3.hpp:255
__host__ static __device__ constexpr auto MakeCGridDescriptor_M0_M10_M11_N0_N10_N11(const CGridDesc_M_N &c_grid_desc_m_n)
Definition gridwise_gemm_dl_v1r3.hpp:208
__host__ static __device__ constexpr bool CalculateHasMainKBlockLoop(index_t K0)
Definition gridwise_gemm_dl_v1r3.hpp:153
__host__ static __device__ constexpr index_t CalculateGridSize(index_t M, index_t N)
Definition gridwise_gemm_dl_v1r3.hpp:146
__host__ static __device__ constexpr bool CheckValidity(const AGridDesc_K0_M_K1 &a_grid_desc_k0_m_k1, const BGridDesc_K0_N_K1 &b_grid_desc_k0_n_k1, const CGridDesc_M_N &c_grid_desc_m_n)
Definition gridwise_gemm_dl_v1r3.hpp:129
__host__ static __device__ constexpr bool CalculateHasDoubleTailKBlockLoop(index_t K0)
Definition gridwise_gemm_dl_v1r3.hpp:160
__host__ static __device__ constexpr index_t GetSharedMemoryNumberOfByte()
Definition gridwise_gemm_dl_v1r3.hpp:102
__host__ static __device__ constexpr auto MakeDefaultBlock2CTileMap(const CGridDesc_M_N &c_grid_desc_m_n)
Definition gridwise_gemm_dl_v1r3.hpp:241
__host__ static __device__ constexpr auto MakeAGridDescriptor_K0_M0_M1_K1(const AGridDesc_K0_M_K1 &a_grid_desc_k0_m_k1)
Definition gridwise_gemm_dl_v1r3.hpp:168
__host__ static __device__ constexpr auto MakeBGridDescriptor_K0_N0_N1_K1(const BGridDesc_K0_N_K1 &b_grid_desc_k0_n_k1)
Definition gridwise_gemm_dl_v1r3.hpp:188
Definition utility/sequence.hpp:43
Definition threadwise_tensor_slice_transfer.hpp:39
__device__ void Run(const SrcDesc &, const SrcSliceOriginIdx &, const SrcBuffer &src_buf, const DstDesc &dst_desc, DstBuffer &dst_buf)
Definition threadwise_tensor_slice_transfer.hpp:66
Definition utility/integral_constant.hpp:20
Definition utility/math.hpp:34
Definition tensor_operation/gpu/element/unary_element_wise_operation.hpp:340