transform_contraction_to_gemm_arraybase.hpp Source File

transform_contraction_to_gemm_arraybase.hpp Source File

Composable Kernel: transform_contraction_to_gemm_arraybase.hpp Source File
transform_contraction_to_gemm_arraybase.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
10
11namespace ck {
12namespace tensor_operation {
13
14// assume C[G0, G1, ..., M0, M1, M2, ..., N0, N1, N2...]
15template <index_t NumDimG,
16 index_t NumDimM,
17 index_t NumDimN,
19__host__ __device__ static auto
20MakeGridDescriptorPair(const std::array<index_t, NumDimG + NumDimM + NumDimN>& gs_ms_ns_lengths_vec,
21 const std::array<index_t, NumDimG + NumDimM + NumDimN>& gs_ms_ns_strides_vec)
22{
23 // if(!(gs_ms_ns_lengths_vec.size() == NumDimG + NumDimM + NumDimN &&
24 // gs_ms_ns_strides_vec.size() == NumDimG + NumDimM + NumDimN))
25 // {
26 // throw std::runtime_error("wrong! dimension must match input lengths");
27 // }
28
29 const auto to_tuple = [&](auto& vec, auto start, auto end) {
30 return generate_tuple([&](auto i) { return vec[start + i]; }, Number<end - start>{});
31 };
32
33 const auto gs_ms_ns_lengths =
34 to_tuple(gs_ms_ns_lengths_vec, Number<0>{}, Number<NumDimG + NumDimM + NumDimN>{});
35 const auto gs_ms_ns_strides =
36 to_tuple(gs_ms_ns_strides_vec, Number<0>{}, Number<NumDimG + NumDimM + NumDimN>{});
37
38 // dimension Ids for G0, G1, ...
39 constexpr auto gDimIds = typename arithmetic_sequence_gen<0, NumDimG, 1>::type{};
40
41 // dimension Ids for M0, M1, ...
42 constexpr auto mDimIds =
44
45 // dimension Ids for N0, N1, ...
46 constexpr auto nDimIds =
48
49 // lengths for G0, G1, ...
50 const auto gLengths = get_container_subset(gs_ms_ns_lengths, gDimIds);
51
52 // lengths for M0, M1, ...
53 const auto mLengths = get_container_subset(gs_ms_ns_lengths, mDimIds);
54
55 // lengths for N0, N1, ...
56 const auto nLengths = get_container_subset(gs_ms_ns_lengths, nDimIds);
57
58 if constexpr(TensorSpec == device::TensorSpecialization::Packed)
59 {
60 auto G = container_reduce(gLengths, math::multiplies{}, Number<1>{});
61 auto M = container_reduce(mLengths, math::multiplies{}, Number<1>{});
62 auto N = container_reduce(nLengths, math::multiplies{}, Number<1>{});
63 const auto grid_desc_g_mraw_nraw = make_naive_tensor_descriptor(
64 make_tuple(G, M, N),
65 make_tuple(gs_ms_ns_strides[Number<NumDimG - 1>{}],
66 gs_ms_ns_strides[Number<NumDimG + NumDimM - 1>{}],
67 gs_ms_ns_strides[Number<NumDimG + NumDimM + NumDimN - 1>{}]));
68
69 const auto grid_desc_mraw_nraw = make_naive_tensor_descriptor(
70 make_tuple(M, N),
71 make_tuple(gs_ms_ns_strides[Number<NumDimG + NumDimM - 1>{}],
72 gs_ms_ns_strides[Number<NumDimG + NumDimM + NumDimN - 1>{}]));
73
74 return std::make_pair(grid_desc_g_mraw_nraw, grid_desc_mraw_nraw);
75 }
76 else
77 {
78 // naive tensor C[G0, G1, ..., M0, M1, M2, ..., N0, N1, N2...]
79 const auto grid_desc_gs_ms_ns =
80 make_naive_tensor_descriptor(gs_ms_ns_lengths, gs_ms_ns_strides);
81
82 // transformed tensor C[G = G0 * G1 * ..., MRaw = M0 * M1 * M2 * ... , NRaw = N0 * N1 *
83 // N2 * ...]
84 // Note: This does not require padding as it only provides G offset calculation. Technically
85 // descriptor for only G is needed. Here we opt for backward compatibility purpose to return
86 // G_M_N
87 const auto grid_desc_g_mraw_nraw =
88 transform_tensor_descriptor(grid_desc_gs_ms_ns,
90 make_merge_transform(mLengths),
91 make_merge_transform(nLengths)),
92 make_tuple(gDimIds, mDimIds, nDimIds),
93 make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
94
95 const auto c_ms_ns_lengths = to_tuple(
96 gs_ms_ns_lengths_vec, Number<NumDimG>{}, Number<NumDimG + NumDimM + NumDimN>{});
97 const auto c_ms_ns_strides = to_tuple(
98 gs_ms_ns_strides_vec, Number<NumDimG>{}, Number<NumDimG + NumDimM + NumDimN>{});
99
100 // transformed tensor C[MRaw = M0 * M1 * M2 * ... , NRaw = N0 * N1 *
101 // N2 * ...]
102 const auto grid_desc_ms_ns = make_naive_tensor_descriptor(c_ms_ns_lengths, c_ms_ns_strides);
103
104 const auto grid_desc_mraw_nraw = transform_tensor_descriptor(
105 grid_desc_ms_ns,
107 make_tuple(mDimIds - Number<NumDimG>{}, nDimIds - Number<NumDimG>{}),
108 make_tuple(Sequence<0>{}, Sequence<1>{}));
109
110 return std::make_pair(grid_desc_g_mraw_nraw, grid_desc_mraw_nraw);
111 }
112}
113
114template <typename NumDims_G_M_N_K_O, // Sequence<>
115 typename PerBlock_M_N_K_O, // Sequence<>
122{
// Compile-time index constants; used below to read positions 0..4 of the
// Sequence<> template arguments and to index descriptor lengths.
123 static constexpr auto I0 = Number<0>{};
124 static constexpr auto I1 = Number<1>{};
125 static constexpr auto I2 = Number<2>{};
126 static constexpr auto I3 = Number<3>{};
127 static constexpr auto I4 = Number<4>{};
128
// Dimension counts for G, M, N, K and O, read from positions 0..4 of
// NumDims_G_M_N_K_O in that order.
129 static constexpr index_t NumDimG = NumDims_G_M_N_K_O::At(I0);
130 static constexpr index_t NumDimM = NumDims_G_M_N_K_O::At(I1);
131 static constexpr index_t NumDimN = NumDims_G_M_N_K_O::At(I2);
132 static constexpr index_t NumDimK = NumDims_G_M_N_K_O::At(I3);
133 static constexpr index_t NumDimO = NumDims_G_M_N_K_O::At(I4);
134
// Per-block M, N, K and O values, read from positions 0..3 of
// PerBlock_M_N_K_O in that order.
135 static constexpr index_t MPerBlock = PerBlock_M_N_K_O::At(I0);
136 static constexpr index_t NPerBlock = PerBlock_M_N_K_O::At(I1);
137 static constexpr index_t KPerBlock = PerBlock_M_N_K_O::At(I2);
138 static constexpr index_t OPerBlock = PerBlock_M_N_K_O::At(I3);
139
143
144 //
145 // A
146 //
// Builds the descriptor pair for tensor A by forwarding the length and stride
// arrays to MakeGridDescriptorPair<NumDimG, NumDimM, NumDimK, ASpec>.
// The result is the pair returned by that helper: .first is the G/M/K
// descriptor, .second is the M/K descriptor.
// ASpec is declared outside the lines shown in this listing.
// NOTE(review): the parameters are arrays of NumDimG + NumDimM + NumDimN
// elements, but the helper instantiated with <NumDimG, NumDimM, NumDimK>
// takes arrays of NumDimG + NumDimM + NumDimK elements. std::array types of
// different sizes do not convert, so this call only type-checks when
// NumDimN == NumDimK -- confirm whether that restriction is intended.
147 __host__ __device__ static auto MakeAGridDescriptorPair(
148 const std::array<index_t, NumDimG + NumDimM + NumDimN>& a_gs_ms_ks_lengths_vec,
149 const std::array<index_t, NumDimG + NumDimM + NumDimN>& a_gs_ms_ks_strides_vec)
150 {
151 return MakeGridDescriptorPair<NumDimG, NumDimM, NumDimK, ASpec>(a_gs_ms_ks_lengths_vec,
152 a_gs_ms_ks_strides_vec);
153 }
154
// Returns the first element of MakeAGridDescriptorPair(lengths, strides),
// i.e. the A descriptor that still carries the G dimension. No padding is
// applied here.
155 // TODO: rename to G_MRaw_KRaw
156 __host__ __device__ static auto MakeAGridDescriptor_G_M_K(
157 const std::array<index_t, NumDimG + NumDimM + NumDimN>& a_gs_ms_ks_lengths_vec,
158 const std::array<index_t, NumDimG + NumDimM + NumDimN>& a_gs_ms_ks_strides_vec)
159 {
160 return MakeAGridDescriptorPair(a_gs_ms_ks_lengths_vec, a_gs_ms_ks_strides_vec).first;
161 }
// Returns the second element of MakeAGridDescriptorPair(lengths, strides)
// after passing it through matrix_padder.PadADescriptor_M_K.
// matrix_padder is declared outside the lines shown in this listing.
162 __host__ __device__ static auto MakeAGridDescriptor_M_K(
163 const std::array<index_t, NumDimG + NumDimM + NumDimN>& a_gs_ms_ks_lengths_vec,
164 const std::array<index_t, NumDimG + NumDimM + NumDimN>& a_gs_ms_ks_strides_vec)
165 {
166 return matrix_padder.PadADescriptor_M_K(
167 MakeAGridDescriptorPair(a_gs_ms_ks_lengths_vec, a_gs_ms_ks_strides_vec).second);
168 }
169
170 template <typename AGridDesc_M_K, typename Number>
171 __host__ __device__ static constexpr auto
172 MakeAGridDescriptor_AK0_M_AK1(const AGridDesc_M_K& a_grid_desc_m_k, const Number& AK1)
173 {
174 const auto M = a_grid_desc_m_k.GetLength(I0);
175 const auto K = a_grid_desc_m_k.GetLength(I1);
176
177 const auto AK0 = K / AK1;
178
179 return transform_tensor_descriptor(a_grid_desc_m_k,
184 }
185
186 template <typename AGridDesc_M_K,
187 typename WmmaK,
188 typename MRepeat,
189 typename MWaves,
190 typename MPerWmma,
191 typename AK1>
192 __host__ __device__ static constexpr auto
194 const AGridDesc_M_K& a_grid_desc_m_k,
195 const WmmaK&,
196 const MRepeat&,
197 const MWaves&,
198 const MPerWmma&,
199 const AK1&)
200 {
201 const auto M0 = a_grid_desc_m_k.GetLength(I0) / MPerBlock;
202 const auto K = a_grid_desc_m_k.GetLength(I1);
203 const auto AKWmma = K / WmmaK{};
204 constexpr auto AKRow = 2;
205 constexpr auto AK0PerWmma = WmmaK{} / AKRow / AK1{};
206
208 a_grid_desc_m_k,
210 make_tuple(AKWmma, Number<AK0PerWmma>{}, Number<AKRow>{}, AK1{})),
211 make_unmerge_transform(make_tuple(M0 * MRepeat{}, MWaves{}, MPerWmma{}))),
214 }
215
216 //
217 // B (alias of B0)
218 //
// Builds the descriptor pair for tensor B0 by forwarding the length and
// stride arrays to MakeGridDescriptorPair<NumDimG, NumDimN, NumDimK, B0Spec>.
// The result is the pair returned by that helper: .first is the G/N/K
// descriptor, .second is the N/K descriptor.
// B0Spec is declared outside the lines shown in this listing.
// NOTE(review): the parameters are arrays of NumDimG + NumDimM + NumDimN
// elements, but the helper instantiated with <NumDimG, NumDimN, NumDimK>
// takes arrays of NumDimG + NumDimN + NumDimK elements. std::array types of
// different sizes do not convert, so this call only type-checks when
// NumDimM == NumDimK -- confirm whether that restriction is intended.
219 __host__ __device__ static auto MakeB0GridDescriptorPair(
220 const std::array<index_t, NumDimG + NumDimM + NumDimN>& b0_gs_ns_ks_lengths_vec,
221 const std::array<index_t, NumDimG + NumDimM + NumDimN>& b0_gs_ns_ks_strides_vec)
222 {
223 return MakeGridDescriptorPair<NumDimG, NumDimN, NumDimK, B0Spec>(b0_gs_ns_ks_lengths_vec,
224 b0_gs_ns_ks_strides_vec);
225 }
226
// Returns the first element of MakeB0GridDescriptorPair(lengths, strides),
// i.e. the B0 descriptor that still carries the G dimension. No padding is
// applied here.
// NOTE(review): the TODO previously read "G_MRaw_NRaw", which does not match
// this function's _G_N_K suffix; adjusted to follow the A-side naming
// pattern -- confirm the intended target name.
227 // TODO: rename to G_NRaw_KRaw
228 __host__ __device__ static auto MakeB0GridDescriptor_G_N_K(
229 const std::array<index_t, NumDimG + NumDimM + NumDimN>& b0_gs_ns_ks_lengths_vec,
230 const std::array<index_t, NumDimG + NumDimM + NumDimN>& b0_gs_ns_ks_strides_vec)
231 {
232 return MakeB0GridDescriptorPair(b0_gs_ns_ks_lengths_vec, b0_gs_ns_ks_strides_vec).first;
233 }
// Returns the second element of MakeB0GridDescriptorPair(lengths, strides)
// after passing it through matrix_padder.PadBDescriptor_N_K.
// matrix_padder is declared outside the lines shown in this listing.
234 __host__ __device__ static auto MakeB0GridDescriptor_N_K(
235 const std::array<index_t, NumDimG + NumDimM + NumDimN>& b0_gs_ns_ks_lengths_vec,
236 const std::array<index_t, NumDimG + NumDimM + NumDimN>& b0_gs_ns_ks_strides_vec)
237 {
238 // alias of matrix_padder.PadB0Descriptor_N_K
239 return matrix_padder.PadBDescriptor_N_K(
240 MakeB0GridDescriptorPair(b0_gs_ns_ks_lengths_vec, b0_gs_ns_ks_strides_vec).second);
241 }
242
243 template <typename BGridDesc_N_K, typename Number>
244 __host__ __device__ static constexpr auto
245 MakeB0GridDescriptor_BK0_N_BK1(const BGridDesc_N_K& b_grid_desc_n_k, const Number& BK1)
246 {
247 const auto N = b_grid_desc_n_k.GetLength(I0);
248 const auto K = b_grid_desc_n_k.GetLength(I1);
249
250 const auto BK0 = K / BK1;
251
252 return transform_tensor_descriptor(b_grid_desc_n_k,
257 }
258
259 template <typename BGridDesc_L_K,
260 typename WmmaK,
261 typename LRepeat,
262 typename LWaves,
263 typename LPerWmma,
264 typename BK1>
265 __host__ __device__ static constexpr auto
267 const BGridDesc_L_K& b_grid_desc_l_k,
268 const WmmaK&,
269 const LRepeat&,
270 const LWaves&,
271 const LPerWmma&,
272 const BK1&)
273 {
274 const auto L0 = b_grid_desc_l_k.GetLength(I0) / NPerBlock;
275 const auto K = b_grid_desc_l_k.GetLength(I1);
276 const auto BKWmma = K / WmmaK{};
277 constexpr auto BKRow = 2;
278 constexpr auto BK0PerWmma = WmmaK{} / BKRow / BK1{};
279
281 b_grid_desc_l_k,
283 make_tuple(BKWmma, Number<BK0PerWmma>{}, Number<BKRow>{}, BK1{})),
284 make_unmerge_transform(make_tuple(L0 * LRepeat{}, LWaves{}, LPerWmma{}))),
287 }
288
289 //
290 // B1
291 //
// Builds the descriptor pair for tensor B1 by forwarding the length and
// stride arrays to MakeGridDescriptorPair<NumDimG, NumDimO, NumDimN, B1Spec>.
// The result is the pair returned by that helper: .first is the G/O/N
// descriptor, .second is the O/N descriptor.
// B1Spec is declared outside the lines shown in this listing.
// NOTE(review): the parameters are arrays of NumDimG + NumDimM + NumDimN
// elements, but the helper instantiated with <NumDimG, NumDimO, NumDimN>
// takes arrays of NumDimG + NumDimO + NumDimN elements. std::array types of
// different sizes do not convert, so this call only type-checks when
// NumDimM == NumDimO -- confirm whether that restriction is intended.
292 __host__ __device__ static auto MakeB1GridDescriptorPair(
293 const std::array<index_t, NumDimG + NumDimM + NumDimN>& b1_gs_os_ns_lengths_vec,
294 const std::array<index_t, NumDimG + NumDimM + NumDimN>& b1_gs_os_ns_strides_vec)
295 {
296 return MakeGridDescriptorPair<NumDimG, NumDimO, NumDimN, B1Spec>(b1_gs_os_ns_lengths_vec,
297 b1_gs_os_ns_strides_vec);
298 }
299
// Returns the first element of MakeB1GridDescriptorPair(lengths, strides),
// i.e. the B1 descriptor that still carries the G dimension. No padding is
// applied here.
// The "N"/"K" in this function's name correspond to B1's O and N dimension
// groups respectively (see the <NumDimG, NumDimO, NumDimN> instantiation in
// MakeB1GridDescriptorPair).
300 // TODO: rename to G_NRaw_KRaw
301 __host__ __device__ static auto MakeB1GridDescriptor_G_N_K(
302 const std::array<index_t, NumDimG + NumDimM + NumDimN>& b1_gs_os_ns_lengths_vec,
303 const std::array<index_t, NumDimG + NumDimM + NumDimN>& b1_gs_os_ns_strides_vec)
304 {
305 return MakeB1GridDescriptorPair(b1_gs_os_ns_lengths_vec, b1_gs_os_ns_strides_vec).first;
306 }
// Returns the second element of MakeB1GridDescriptorPair(lengths, strides)
// after passing it through matrix_padder.PadB1Descriptor_N_K.
// matrix_padder is declared outside the lines shown in this listing.
307 __host__ __device__ static auto MakeB1GridDescriptor_N_K(
308 const std::array<index_t, NumDimG + NumDimM + NumDimN>& b1_gs_os_ns_lengths_vec,
309 const std::array<index_t, NumDimG + NumDimM + NumDimN>& b1_gs_os_ns_strides_vec)
310 {
311 // alias of matrix_padder.PadB1Descriptor_O_N
312 return matrix_padder.PadB1Descriptor_N_K(
313 MakeB1GridDescriptorPair(b1_gs_os_ns_lengths_vec, b1_gs_os_ns_strides_vec).second);
314 }
315
316 template <typename B1GridDesc_N_K, typename Number>
317 __host__ __device__ static constexpr auto
318 MakeB1GridDescriptor_BK0_N_BK1(const B1GridDesc_N_K& b1_grid_desc_n_k, const Number& B1K1)
319 {
320 const auto N = b1_grid_desc_n_k.GetLength(I0);
321 const auto K = b1_grid_desc_n_k.GetLength(I1);
322
323 const auto B1K0 = K / B1K1;
324
326 b1_grid_desc_n_k,
331 }
332
333 template <typename BGridDesc_N_L,
334 typename WmmaL,
335 typename NRepeat,
336 typename NWaves,
337 typename NPerWmma,
338 typename BL1>
339 __host__ __device__ static constexpr auto
341 const BGridDesc_N_L& b_grid_desc_n_l,
342 const WmmaL&,
343 const NRepeat&,
344 const NWaves&,
345 const NPerWmma&,
346 const BL1&)
347 {
348 const auto N0 = b_grid_desc_n_l.GetLength(I0) / OPerBlock;
349 const auto L = b_grid_desc_n_l.GetLength(I1);
350 const auto BLWmma = L / WmmaL{};
351 constexpr auto BLRow = 2;
352 constexpr auto BL0PerWmma = WmmaL{} / BLRow / BL1{};
353
355 b_grid_desc_n_l,
357 make_tuple(BLWmma, Number<BL0PerWmma>{}, Number<BLRow>{}, BL1{})),
358 make_unmerge_transform(make_tuple(N0 * NRepeat{}, NWaves{}, NPerWmma{}))),
361 }
362
363 //
364 // C
365 //
// Builds the descriptor pair for tensor C by forwarding the length and
// stride arrays to MakeGridDescriptorPair<NumDimG, NumDimM, NumDimO, CSpec>.
// The result is the pair returned by that helper: .first is the G/M/O
// descriptor, .second is the M/O descriptor.
// CSpec is declared outside the lines shown in this listing.
// NOTE(review): the parameters are arrays of NumDimG + NumDimM + NumDimN
// elements, but the helper instantiated with <NumDimG, NumDimM, NumDimO>
// takes arrays of NumDimG + NumDimM + NumDimO elements. std::array types of
// different sizes do not convert, so this call only type-checks when
// NumDimN == NumDimO -- confirm whether that restriction is intended.
366 __host__ __device__ static auto MakeCGridDescriptorPair(
367 const std::array<index_t, NumDimG + NumDimM + NumDimN>& c_gs_ms_os_lengths_vec,
368 const std::array<index_t, NumDimG + NumDimM + NumDimN>& c_gs_ms_os_strides_vec)
369 {
370 return MakeGridDescriptorPair<NumDimG, NumDimM, NumDimO, CSpec>(c_gs_ms_os_lengths_vec,
371 c_gs_ms_os_strides_vec);
372 }
373
// Returns the first element of MakeCGridDescriptorPair(lengths, strides),
// i.e. the C descriptor that still carries the G dimension. No padding is
// applied here.
374 // TODO: rename to G_MRaw_NRaw
375 __host__ __device__ static auto MakeCGridDescriptor_G_M_N(
376 const std::array<index_t, NumDimG + NumDimM + NumDimN>& c_gs_ms_os_lengths_vec,
377 const std::array<index_t, NumDimG + NumDimM + NumDimN>& c_gs_ms_os_strides_vec)
378 {
379 return MakeCGridDescriptorPair(c_gs_ms_os_lengths_vec, c_gs_ms_os_strides_vec).first;
380 }
// Returns the second element of MakeCGridDescriptorPair(lengths, strides)
// after passing it through matrix_padder.PadCDescriptor_M_N.
// matrix_padder is declared outside the lines shown in this listing.
381 __host__ __device__ static auto MakeCGridDescriptor_M_N(
382 const std::array<index_t, NumDimG + NumDimM + NumDimN>& c_gs_ms_os_lengths_vec,
383 const std::array<index_t, NumDimG + NumDimM + NumDimN>& c_gs_ms_os_strides_vec)
384 {
385 return matrix_padder.PadCDescriptor_M_N(
386 MakeCGridDescriptorPair(c_gs_ms_os_lengths_vec, c_gs_ms_os_strides_vec).second);
387 }
388};
389
390} // namespace tensor_operation
391} // namespace ck
TensorSpecialization
Definition tensor_specialization.hpp:11
@ Packed
Definition tensor_specialization.hpp:13
GemmSpecialization
Definition gemm_specialization.hpp:11
Definition convolution_backward_data_specialization.hpp:7
Definition ck.hpp:268
__host__ __device__ constexpr auto make_pass_through_transform(const LowLength &low_length)
Definition multi_index_transform_helper.hpp:12
int32_t index_t
Definition ck.hpp:299
__host__ __device__ constexpr auto make_naive_tensor_descriptor(const Tuple< Lengths... > &lengths, const Tuple< Strides... > &strides)
Definition tensor_descriptor_helper.hpp:49
integral_constant< index_t, N > Number
Definition number.hpp:12
__host__ __device__ constexpr auto make_merge_transform(const LowLengths &low_lengths)
Definition multi_index_transform_helper.hpp:55
__host__ __device__ constexpr auto get_container_subset(const Array< T, N > &arr, Sequence< Is... >)
Definition utility/container_helper.hpp:346
__host__ __device__ constexpr auto container_reduce(const Container &x, Reduce reduce, Init init, Number< IBegin >=Number< 0 >{}, Number< IEnd >=Number< Container::Size()>{}, Number< IStep >=Number< 1 >{})
Definition utility/container_helper.hpp:111
__host__ __device__ constexpr auto generate_tuple(F &&f, Number< N >)
Definition tuple_helper.hpp:21
__host__ __device__ constexpr auto make_tuple(Xs &&... xs)
Definition utility/tuple.hpp:211
__host__ __device__ constexpr auto transform_tensor_descriptor(const OldTensorDescriptor &old_tensor_desc, const NewTransforms &new_transforms, NewLowerDimensionOldVisibleIdss, NewUpperDimensionNewVisibleIdss)
Definition tensor_description/tensor_descriptor.hpp:319
__host__ __device__ constexpr auto make_unmerge_transform(const UpLengths &up_lengths, integral_constant< bool, Use24BitIntegerCalculation >=integral_constant< bool, false >{})
Definition multi_index_transform_helper.hpp:90
Definition utility/sequence.hpp:43
typename conditional< kHasContent, type0, type1 >::type type
Definition utility/sequence.hpp:271
Definition transform_contraction_to_gemm_arraybase.hpp:122
__host__ static __device__ auto MakeCGridDescriptorPair(const std::array< index_t, NumDimG+NumDimM+NumDimN > &c_gs_ms_os_lengths_vec, const std::array< index_t, NumDimG+NumDimM+NumDimN > &c_gs_ms_os_strides_vec)
Definition transform_contraction_to_gemm_arraybase.hpp:366
__host__ static __device__ constexpr auto MakeB0GridDescriptor_BK0_N_BK1(const BGridDesc_N_K &b_grid_desc_n_k, const Number &BK1)
Definition transform_contraction_to_gemm_arraybase.hpp:245
__host__ static __device__ auto MakeCGridDescriptor_G_M_N(const std::array< index_t, NumDimG+NumDimM+NumDimN > &c_gs_ms_os_lengths_vec, const std::array< index_t, NumDimG+NumDimM+NumDimN > &c_gs_ms_os_strides_vec)
Definition transform_contraction_to_gemm_arraybase.hpp:375
__host__ static __device__ constexpr auto MakeAGridDescriptor_AK0_M_AK1(const AGridDesc_M_K &a_grid_desc_m_k, const Number &AK1)
Definition transform_contraction_to_gemm_arraybase.hpp:172
__host__ static __device__ auto MakeB1GridDescriptorPair(const std::array< index_t, NumDimG+NumDimM+NumDimN > &b1_gs_os_ns_lengths_vec, const std::array< index_t, NumDimG+NumDimM+NumDimN > &b1_gs_os_ns_strides_vec)
Definition transform_contraction_to_gemm_arraybase.hpp:292
__host__ static __device__ auto MakeB1GridDescriptor_N_K(const std::array< index_t, NumDimG+NumDimM+NumDimN > &b1_gs_os_ns_lengths_vec, const std::array< index_t, NumDimG+NumDimM+NumDimN > &b1_gs_os_ns_strides_vec)
Definition transform_contraction_to_gemm_arraybase.hpp:307
__host__ static __device__ auto MakeB1GridDescriptor_G_N_K(const std::array< index_t, NumDimG+NumDimM+NumDimN > &b1_gs_os_ns_lengths_vec, const std::array< index_t, NumDimG+NumDimM+NumDimN > &b1_gs_os_ns_strides_vec)
Definition transform_contraction_to_gemm_arraybase.hpp:301
__host__ static __device__ constexpr auto MakeB0GridDescriptor_BKWmma_LBlockRepeat_LWaves_BK0PerWmma_BKRow_LPerWmma_BK1(const BGridDesc_L_K &b_grid_desc_l_k, const WmmaK &, const LRepeat &, const LWaves &, const LPerWmma &, const BK1 &)
Definition transform_contraction_to_gemm_arraybase.hpp:266
__host__ static __device__ auto MakeB0GridDescriptor_G_N_K(const std::array< index_t, NumDimG+NumDimM+NumDimN > &b0_gs_ns_ks_lengths_vec, const std::array< index_t, NumDimG+NumDimM+NumDimN > &b0_gs_ns_ks_strides_vec)
Definition transform_contraction_to_gemm_arraybase.hpp:228
__host__ static __device__ auto MakeB0GridDescriptorPair(const std::array< index_t, NumDimG+NumDimM+NumDimN > &b0_gs_ns_ks_lengths_vec, const std::array< index_t, NumDimG+NumDimM+NumDimN > &b0_gs_ns_ks_strides_vec)
Definition transform_contraction_to_gemm_arraybase.hpp:219
__host__ static __device__ constexpr auto MakeB1GridDescriptor_BK0_N_BK1(const B1GridDesc_N_K &b1_grid_desc_n_k, const Number &B1K1)
Definition transform_contraction_to_gemm_arraybase.hpp:318
__host__ static __device__ auto MakeAGridDescriptor_G_M_K(const std::array< index_t, NumDimG+NumDimM+NumDimN > &a_gs_ms_ks_lengths_vec, const std::array< index_t, NumDimG+NumDimM+NumDimN > &a_gs_ms_ks_strides_vec)
Definition transform_contraction_to_gemm_arraybase.hpp:156
__host__ static __device__ constexpr auto MakeAGridDescriptor_AKWmma_MBlockRepeat_MWaves_AK0PerWmma_AKRow_MPerWmma_AK1(const AGridDesc_M_K &a_grid_desc_m_k, const WmmaK &, const MRepeat &, const MWaves &, const MPerWmma &, const AK1 &)
Definition transform_contraction_to_gemm_arraybase.hpp:193
__host__ static __device__ auto MakeCGridDescriptor_M_N(const std::array< index_t, NumDimG+NumDimM+NumDimN > &c_gs_ms_os_lengths_vec, const std::array< index_t, NumDimG+NumDimM+NumDimN > &c_gs_ms_os_strides_vec)
Definition transform_contraction_to_gemm_arraybase.hpp:381
__host__ static __device__ constexpr auto MakeB1GridDescriptor_BLWmma_NBlockRepeat_NWaves__BL0PerWmma_BLRow_NPerWmma_BL1(const BGridDesc_N_L &b_grid_desc_n_l, const WmmaL &, const NRepeat &, const NWaves &, const NPerWmma &, const BL1 &)
Definition transform_contraction_to_gemm_arraybase.hpp:340
__host__ static __device__ auto MakeAGridDescriptor_M_K(const std::array< index_t, NumDimG+NumDimM+NumDimN > &a_gs_ms_ks_lengths_vec, const std::array< index_t, NumDimG+NumDimM+NumDimN > &a_gs_ms_ks_strides_vec)
Definition transform_contraction_to_gemm_arraybase.hpp:162
__host__ static __device__ auto MakeAGridDescriptorPair(const std::array< index_t, NumDimG+NumDimM+NumDimN > &a_gs_ms_ks_lengths_vec, const std::array< index_t, NumDimG+NumDimM+NumDimN > &a_gs_ms_ks_strides_vec)
Definition transform_contraction_to_gemm_arraybase.hpp:147
__host__ static __device__ auto MakeB0GridDescriptor_N_K(const std::array< index_t, NumDimG+NumDimM+NumDimN > &b0_gs_ns_ks_lengths_vec, const std::array< index_t, NumDimG+NumDimM+NumDimN > &b0_gs_ns_ks_strides_vec)
Definition transform_contraction_to_gemm_arraybase.hpp:234
Definition matrix_padder.hpp:63