tensor_layout.hpp Source File

tensor_layout.hpp Source File

Composable Kernel: tensor_layout.hpp Source File
tile/ops/common/tensor_layout.hpp
Go to the documentation of this file.
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

// operator<< below names std::ostream, std::enable_if and std::is_base_of; include their
// headers directly so this header does not depend on what an includer pulled in first.
#include <ostream>
#include <type_traits>

// TODO: this folder does not match the single namespace rule. need to refactor in the future
namespace ck_tile {
namespace tensor_layout {

/// @brief Empty tag base shared by every layout type in this header.
///
/// Carries no data. Deriving from it is what makes a type printable through the
/// `operator<<` at the end of this namespace; each derived layout provides a
/// `static constexpr const char* name` that the operator writes out.
struct BaseTensorLayout
{
};

namespace gemm {

/// GEMM operand layout tag; printed as "RowMajor".
struct RowMajor : public BaseTensorLayout
{
    static constexpr const char* name = "RowMajor";
};

/// GEMM operand layout tag; printed as "ColumnMajor".
struct ColumnMajor : public BaseTensorLayout
{
    static constexpr const char* name = "ColumnMajor";
};

} // namespace gemm

namespace convolution {

// Layout tags for convolution tensors. Every tag is an empty type whose `name` spells its
// dimension string; the section comments below give the tensor role (input / weight /
// output / bias) and whether the layout is packed or strided.
// NOTE(review): the dimension letters presumably follow the usual convolution convention
// (G groups, N batch, C input channels, K output channels, D/H/W spatial, Z/Y/X filter
// spatial) -- confirm against the kernels that consume these tags.

// input tensor
// packed NCW/NCHW/NCDHW
struct NCW : public BaseTensorLayout
{
    static constexpr const char* name = "NCW";
};

struct NCHW : public BaseTensorLayout
{
    static constexpr const char* name = "NCHW";
};

struct NCDHW : public BaseTensorLayout
{
    static constexpr const char* name = "NCDHW";
};

// packed GNCW/GNCHW/GNCDHW
struct GNCW : public BaseTensorLayout
{
    static constexpr const char* name = "GNCW";
};

struct GNCHW : public BaseTensorLayout
{
    static constexpr const char* name = "GNCHW";
};

struct GNCDHW : public BaseTensorLayout
{
    static constexpr const char* name = "GNCDHW";
};

// input tensor
// packed NWC/NHWC/NDHWC
struct NWC : public BaseTensorLayout
{
    static constexpr const char* name = "NWC";
};

struct NHWC : public BaseTensorLayout
{
    static constexpr const char* name = "NHWC";
};

struct NDHWC : public BaseTensorLayout
{
    static constexpr const char* name = "NDHWC";
};

// input tensor
// packed GNWC/GNHWC/GNDHWC
struct GNWC : public BaseTensorLayout
{
    static constexpr const char* name = "GNWC";
};

struct GNHWC : public BaseTensorLayout
{
    static constexpr const char* name = "GNHWC";
};

struct GNDHWC : public BaseTensorLayout
{
    static constexpr const char* name = "GNDHWC";
};

// for input bias
struct GC : public BaseTensorLayout
{
    static constexpr const char* name = "GC";
};

// input tensor
// packed NWGC/NHWGC/NDHWGC
struct NWGC : public BaseTensorLayout
{
    static constexpr const char* name = "NWGC";
};

struct NHWGC : public BaseTensorLayout
{
    static constexpr const char* name = "NHWGC";
};

struct NDHWGC : public BaseTensorLayout
{
    static constexpr const char* name = "NDHWGC";
};

// input tensor
// strided layout
struct G_NW_C : public BaseTensorLayout
{
    static constexpr const char* name = "G_NW_C";
};

struct G_NHW_C : public BaseTensorLayout
{
    static constexpr const char* name = "G_NHW_C";
};

struct G_NDHW_C : public BaseTensorLayout
{
    static constexpr const char* name = "G_NDHW_C";
};

// for input bias
struct G_C : public BaseTensorLayout
{
    static constexpr const char* name = "G_C";
};

// weight tensor
// packed KCX/KCYX/KCZYX
struct KCX : public BaseTensorLayout
{
    static constexpr const char* name = "KCX";
};

struct KCYX : public BaseTensorLayout
{
    static constexpr const char* name = "KCYX";
};

struct KCZYX : public BaseTensorLayout
{
    static constexpr const char* name = "KCZYX";
};

// weight tensor
// packed GKCX/GKCYX/GKCZYX
struct GKCX : public BaseTensorLayout
{
    static constexpr const char* name = "GKCX";
};

struct GKCYX : public BaseTensorLayout
{
    static constexpr const char* name = "GKCYX";
};

struct GKCZYX : public BaseTensorLayout
{
    static constexpr const char* name = "GKCZYX";
};

// weight tensor
// packed KXC/KYXC/KZYXC
struct KXC : public BaseTensorLayout
{
    static constexpr const char* name = "KXC";
};

struct KYXC : public BaseTensorLayout
{
    static constexpr const char* name = "KYXC";
};

struct KZYXC : public BaseTensorLayout
{
    static constexpr const char* name = "KZYXC";
};

// weight tensor
// packed GKXC/GKYXC/GKZYXC
struct GKXC : public BaseTensorLayout
{
    static constexpr const char* name = "GKXC";
};

struct GKYXC : public BaseTensorLayout
{
    static constexpr const char* name = "GKYXC";
};

struct GKZYXC : public BaseTensorLayout
{
    static constexpr const char* name = "GKZYXC";
};

// weight tensor
// packed KXGC/KYXGC/KZYXGC
struct KXGC : public BaseTensorLayout
{
    static constexpr const char* name = "KXGC";
};

struct KYXGC : public BaseTensorLayout
{
    static constexpr const char* name = "KYXGC";
};

struct KZYXGC : public BaseTensorLayout
{
    static constexpr const char* name = "KZYXGC";
};

// weight tensor
// strided
struct G_K_X_C : public BaseTensorLayout
{
    static constexpr const char* name = "G_K_X_C";
};

struct G_K_YX_C : public BaseTensorLayout
{
    static constexpr const char* name = "G_K_YX_C";
};

struct G_K_ZYX_C : public BaseTensorLayout
{
    static constexpr const char* name = "G_K_ZYX_C";
};

// output tensor
// packed NKW/NKHW/NKDHW
struct NKW : public BaseTensorLayout
{
    static constexpr const char* name = "NKW";
};

struct NKHW : public BaseTensorLayout
{
    static constexpr const char* name = "NKHW";
};

struct NKDHW : public BaseTensorLayout
{
    static constexpr const char* name = "NKDHW";
};

// output tensor
// packed GNKW/GNKHW/GNKDHW
struct GNKW : public BaseTensorLayout
{
    static constexpr const char* name = "GNKW";
};

struct GNKHW : public BaseTensorLayout
{
    static constexpr const char* name = "GNKHW";
};

struct GNKDHW : public BaseTensorLayout
{
    static constexpr const char* name = "GNKDHW";
};

// output tensor
// packed NWK/NHWK/NDHWK
struct NWK : public BaseTensorLayout
{
    static constexpr const char* name = "NWK";
};

struct NHWK : public BaseTensorLayout
{
    static constexpr const char* name = "NHWK";
};

struct NDHWK : public BaseTensorLayout
{
    static constexpr const char* name = "NDHWK";
};

// output tensor
// packed GNWK/GNHWK/GNDHWK
struct GNWK : public BaseTensorLayout
{
    static constexpr const char* name = "GNWK";
};

struct GNHWK : public BaseTensorLayout
{
    static constexpr const char* name = "GNHWK";
};

struct GNDHWK : public BaseTensorLayout
{
    static constexpr const char* name = "GNDHWK";
};

// output tensor
// packed NWGK/NHWGK/NDHWGK
struct NWGK : public BaseTensorLayout
{
    static constexpr const char* name = "NWGK";
};

struct NHWGK : public BaseTensorLayout
{
    static constexpr const char* name = "NHWGK";
};

struct NDHWGK : public BaseTensorLayout
{
    static constexpr const char* name = "NDHWGK";
};

// output tensor
// strided layout
struct G_NW_K : public BaseTensorLayout
{
    static constexpr const char* name = "G_NW_K";
};

struct G_NHW_K : public BaseTensorLayout
{
    static constexpr const char* name = "G_NHW_K";
};

struct G_NDHW_K : public BaseTensorLayout
{
    static constexpr const char* name = "G_NDHW_K";
};

// for output bias
struct G_K : public BaseTensorLayout
{
    static constexpr const char* name = "G_K";
};

// K-reduced output tensor (packed)
struct GNW : public BaseTensorLayout
{
    static constexpr const char* name = "GNW";
};

struct GNHW : public BaseTensorLayout
{
    static constexpr const char* name = "GNHW";
};

struct GNDHW : public BaseTensorLayout
{
    static constexpr const char* name = "GNDHW";
};

// K-reduced output tensor (packed)
struct NWG : public BaseTensorLayout
{
    static constexpr const char* name = "NWG";
};

struct NHWG : public BaseTensorLayout
{
    static constexpr const char* name = "NHWG";
};

struct NDHWG : public BaseTensorLayout
{
    static constexpr const char* name = "NDHWG";
};

// K-reduced output tensor (strided)
struct G_NW : public BaseTensorLayout
{
    static constexpr const char* name = "G_NW";
};

struct G_NHW : public BaseTensorLayout
{
    static constexpr const char* name = "G_NHW";
};

struct G_NDHW : public BaseTensorLayout
{
    static constexpr const char* name = "G_NDHW";
};

} // namespace convolution

/// @brief Writes a layout tag's `name` string (e.g. "NHWGC") to @p os.
///
/// Enabled only for types derived from BaseTensorLayout, so it never competes with the
/// standard stream operators for unrelated types. Because the base class lives in this
/// namespace, argument-dependent lookup finds this operator for tags declared in the
/// nested `gemm` and `convolution` namespaces as well.
/// @param os  destination stream
/// @return    @p os, to allow chaining
template <
    typename Layout,
    typename std::enable_if<std::is_base_of<BaseTensorLayout, Layout>::value, bool>::type = false>
std::ostream& operator<<(std::ostream& os, const Layout&)
{
    os << Layout::name;
    return os;
}

} // namespace tensor_layout
} // namespace ck_tile
Definition tile/ops/common/tensor_layout.hpp:27
Definition tile/ops/common/tensor_layout.hpp:14
Definition tile/ops/common/tensor_layout.hpp:8
std::ostream & operator<<(std::ostream &os, const Layout &)
Definition tile/ops/common/tensor_layout.hpp:405
Definition tile/core/algorithm/cluster_descriptor.hpp:13
Layout wrapper that performs the tensor descriptor logic.
Definition layout.hpp:24
Definition tile/ops/common/tensor_layout.hpp:11
Definition tile/ops/common/tensor_layout.hpp:138
static constexpr const char * name
Definition tile/ops/common/tensor_layout.hpp:139
Definition tile/ops/common/tensor_layout.hpp:230
static constexpr const char * name
Definition tile/ops/common/tensor_layout.hpp:231
Definition tile/ops/common/tensor_layout.hpp:235
static constexpr const char * name
Definition tile/ops/common/tensor_layout.hpp:236
Definition tile/ops/common/tensor_layout.hpp:240
static constexpr const char * name
Definition tile/ops/common/tensor_layout.hpp:241
Definition tile/ops/common/tensor_layout.hpp:348
static constexpr const char * name
Definition tile/ops/common/tensor_layout.hpp:349
Definition tile/ops/common/tensor_layout.hpp:132
static constexpr const char * name
Definition tile/ops/common/tensor_layout.hpp:133
Definition tile/ops/common/tensor_layout.hpp:342
static constexpr const char * name
Definition tile/ops/common/tensor_layout.hpp:343
Definition tile/ops/common/tensor_layout.hpp:396
static constexpr const char * name
Definition tile/ops/common/tensor_layout.hpp:397
Definition tile/ops/common/tensor_layout.hpp:127
static constexpr const char * name
Definition tile/ops/common/tensor_layout.hpp:128
Definition tile/ops/common/tensor_layout.hpp:337
static constexpr const char * name
Definition tile/ops/common/tensor_layout.hpp:338
Definition tile/ops/common/tensor_layout.hpp:391
static constexpr const char * name
Definition tile/ops/common/tensor_layout.hpp:392
Definition tile/ops/common/tensor_layout.hpp:122
static constexpr const char * name
Definition tile/ops/common/tensor_layout.hpp:123
Definition tile/ops/common/tensor_layout.hpp:332
static constexpr const char * name
Definition tile/ops/common/tensor_layout.hpp:333
Definition tile/ops/common/tensor_layout.hpp:386
static constexpr const char * name
Definition tile/ops/common/tensor_layout.hpp:387
Definition tile/ops/common/tensor_layout.hpp:98
static constexpr const char * name
Definition tile/ops/common/tensor_layout.hpp:99
Definition tile/ops/common/tensor_layout.hpp:162
static constexpr const char * name
Definition tile/ops/common/tensor_layout.hpp:163
Definition tile/ops/common/tensor_layout.hpp:167
static constexpr const char * name
Definition tile/ops/common/tensor_layout.hpp:168
Definition tile/ops/common/tensor_layout.hpp:172
static constexpr const char * name
Definition tile/ops/common/tensor_layout.hpp:173
Definition tile/ops/common/tensor_layout.hpp:196
static constexpr const char * name
Definition tile/ops/common/tensor_layout.hpp:197
Definition tile/ops/common/tensor_layout.hpp:201
static constexpr const char * name
Definition tile/ops/common/tensor_layout.hpp:202
Definition tile/ops/common/tensor_layout.hpp:206
static constexpr const char * name
Definition tile/ops/common/tensor_layout.hpp:207
Definition tile/ops/common/tensor_layout.hpp:58
static constexpr const char * name
Definition tile/ops/common/tensor_layout.hpp:59
Definition tile/ops/common/tensor_layout.hpp:53
static constexpr const char * name
Definition tile/ops/common/tensor_layout.hpp:54
Definition tile/ops/common/tensor_layout.hpp:48
static constexpr const char * name
Definition tile/ops/common/tensor_layout.hpp:49
Definition tile/ops/common/tensor_layout.hpp:92
static constexpr const char * name
Definition tile/ops/common/tensor_layout.hpp:93
Definition tile/ops/common/tensor_layout.hpp:364
static constexpr const char * name
Definition tile/ops/common/tensor_layout.hpp:365
Definition tile/ops/common/tensor_layout.hpp:308
static constexpr const char * name
Definition tile/ops/common/tensor_layout.hpp:309
Definition tile/ops/common/tensor_layout.hpp:87
static constexpr const char * name
Definition tile/ops/common/tensor_layout.hpp:88
Definition tile/ops/common/tensor_layout.hpp:359
static constexpr const char * name
Definition tile/ops/common/tensor_layout.hpp:360
Definition tile/ops/common/tensor_layout.hpp:303
static constexpr const char * name
Definition tile/ops/common/tensor_layout.hpp:304
Definition tile/ops/common/tensor_layout.hpp:274
static constexpr const char * name
Definition tile/ops/common/tensor_layout.hpp:275
Definition tile/ops/common/tensor_layout.hpp:269
static constexpr const char * name
Definition tile/ops/common/tensor_layout.hpp:270
Definition tile/ops/common/tensor_layout.hpp:264
static constexpr const char * name
Definition tile/ops/common/tensor_layout.hpp:265
Definition tile/ops/common/tensor_layout.hpp:82
static constexpr const char * name
Definition tile/ops/common/tensor_layout.hpp:83
Definition tile/ops/common/tensor_layout.hpp:354
static constexpr const char * name
Definition tile/ops/common/tensor_layout.hpp:355
Definition tile/ops/common/tensor_layout.hpp:298
static constexpr const char * name
Definition tile/ops/common/tensor_layout.hpp:299
Definition tile/ops/common/tensor_layout.hpp:145
static constexpr const char * name
Definition tile/ops/common/tensor_layout.hpp:146
Definition tile/ops/common/tensor_layout.hpp:150
static constexpr const char * name
Definition tile/ops/common/tensor_layout.hpp:151
Definition tile/ops/common/tensor_layout.hpp:155
static constexpr const char * name
Definition tile/ops/common/tensor_layout.hpp:156
Definition tile/ops/common/tensor_layout.hpp:179
static constexpr const char * name
Definition tile/ops/common/tensor_layout.hpp:180
Definition tile/ops/common/tensor_layout.hpp:213
static constexpr const char * name
Definition tile/ops/common/tensor_layout.hpp:214
Definition tile/ops/common/tensor_layout.hpp:184
static constexpr const char * name
Definition tile/ops/common/tensor_layout.hpp:185
Definition tile/ops/common/tensor_layout.hpp:218
static constexpr const char * name
Definition tile/ops/common/tensor_layout.hpp:219
Definition tile/ops/common/tensor_layout.hpp:189
static constexpr const char * name
Definition tile/ops/common/tensor_layout.hpp:190
Definition tile/ops/common/tensor_layout.hpp:223
static constexpr const char * name
Definition tile/ops/common/tensor_layout.hpp:224
Definition tile/ops/common/tensor_layout.hpp:42
static constexpr const char * name
Definition tile/ops/common/tensor_layout.hpp:43
Definition tile/ops/common/tensor_layout.hpp:37
static constexpr const char * name
Definition tile/ops/common/tensor_layout.hpp:38
Definition tile/ops/common/tensor_layout.hpp:32
static constexpr const char * name
Definition tile/ops/common/tensor_layout.hpp:33
Definition tile/ops/common/tensor_layout.hpp:75
static constexpr const char * name
Definition tile/ops/common/tensor_layout.hpp:76
Definition tile/ops/common/tensor_layout.hpp:115
static constexpr const char * name
Definition tile/ops/common/tensor_layout.hpp:116
Definition tile/ops/common/tensor_layout.hpp:380
static constexpr const char * name
Definition tile/ops/common/tensor_layout.hpp:381
Definition tile/ops/common/tensor_layout.hpp:325
static constexpr const char * name
Definition tile/ops/common/tensor_layout.hpp:326
Definition tile/ops/common/tensor_layout.hpp:291
static constexpr const char * name
Definition tile/ops/common/tensor_layout.hpp:292
Definition tile/ops/common/tensor_layout.hpp:70
static constexpr const char * name
Definition tile/ops/common/tensor_layout.hpp:71
Definition tile/ops/common/tensor_layout.hpp:110
static constexpr const char * name
Definition tile/ops/common/tensor_layout.hpp:111
Definition tile/ops/common/tensor_layout.hpp:375
static constexpr const char * name
Definition tile/ops/common/tensor_layout.hpp:376
Definition tile/ops/common/tensor_layout.hpp:320
static constexpr const char * name
Definition tile/ops/common/tensor_layout.hpp:321
Definition tile/ops/common/tensor_layout.hpp:286
static constexpr const char * name
Definition tile/ops/common/tensor_layout.hpp:287
Definition tile/ops/common/tensor_layout.hpp:257
static constexpr const char * name
Definition tile/ops/common/tensor_layout.hpp:258
Definition tile/ops/common/tensor_layout.hpp:252
static constexpr const char * name
Definition tile/ops/common/tensor_layout.hpp:253
Definition tile/ops/common/tensor_layout.hpp:247
static constexpr const char * name
Definition tile/ops/common/tensor_layout.hpp:248
Definition tile/ops/common/tensor_layout.hpp:65
static constexpr const char * name
Definition tile/ops/common/tensor_layout.hpp:66
Definition tile/ops/common/tensor_layout.hpp:105
static constexpr const char * name
Definition tile/ops/common/tensor_layout.hpp:106
Definition tile/ops/common/tensor_layout.hpp:370
static constexpr const char * name
Definition tile/ops/common/tensor_layout.hpp:371
Definition tile/ops/common/tensor_layout.hpp:315
static constexpr const char * name
Definition tile/ops/common/tensor_layout.hpp:316
Definition tile/ops/common/tensor_layout.hpp:281
static constexpr const char * name
Definition tile/ops/common/tensor_layout.hpp:282
Definition tile/ops/common/tensor_layout.hpp:22
static constexpr const char * name
Definition tile/ops/common/tensor_layout.hpp:23
Definition tile/ops/common/tensor_layout.hpp:17
static constexpr const char * name
Definition tile/ops/common/tensor_layout.hpp:18