data_type.hpp Source File

data_type.hpp Source File#

Composable Kernel: data_type.hpp Source File
data_type.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5#include <stdint.h>
7#include "ck/utility/e8m0.hpp"
9
12
// Fallback declarations for HIPRTC / CK code-generation builds.
// CHAR_BIT is not defined under HIPRTC, so it is supplied here together with the
// fixed-width integer aliases and float_t that the rest of this header uses.
#if defined(__HIPCC_RTC__) || defined(CK_CODE_GEN_RTC)
// Only define CHAR_BIT when nothing else has provided it, to avoid a macro redefinition.
#ifndef CHAR_BIT
#define CHAR_BIT 8
#endif
using int8_t   = signed char;
using uint8_t  = unsigned char;
using int16_t  = signed short;
using uint16_t = unsigned short;
using float_t  = float;
#endif // defined(__HIPCC_RTC__) || defined(CK_CODE_GEN_RTC)
21
22namespace ck {
23#if defined(__HIPCC_RTC__) || defined(CK_CODE_GEN_RTC)
24using byte = unsigned char;
25#else
26using std::byte;
27#endif
28
29using tf32_t = _BitInt(19); // 1 sign bit, 8 exponent bits, 10 mantissa bits
30using bhalf_t = ushort;
31using half_t = _Float16;
32using int4_t = _BitInt(4);
33using f4_t = unsigned _BitInt(4);
34using f6_t = _BitInt(6); // e2m3 format
35using bf6_t = unsigned _BitInt(6); // e3m2 format
36
37// scalar_type
38template <typename TV>
40
42{
43 static constexpr int packed_size = 2;
44
45 using type = uint8_t;
47 __host__ __device__ constexpr f4x2_pk_t() : data{type{}} {}
48 __host__ __device__ constexpr f4x2_pk_t(const type init) : data{init} {}
49
50 template <index_t I>
51 __host__ __device__ inline type unpack(Number<I>) const
52 {
53 static_assert(I < 2, "Index is out of range.");
54 if constexpr(I == 1)
55 return (data >> 4);
56 else
57 return data & 0b00001111;
58 }
59
60 __host__ __device__ inline type pack(const type x0, const type x1)
61 {
62 return (x1 << 4) | (x0 & 0b00001111);
63 }
64
65 // Compare operator
66 __host__ __device__ friend bool operator==(const f4x2_pk_t& lhs, const f4x2_pk_t& rhs)
67 {
68 return lhs.data == rhs.data;
69 }
70
71 __host__ __device__ friend bool operator!=(const f4x2_pk_t& lhs, const f4x2_pk_t& rhs)
72 {
73 return !(lhs == rhs);
74 }
75};
76
77template <typename BitType, index_t pk_size>
78struct f6_pk_t
79{
80 using element_type = uint32_t; // element storage fundamental type
81
82 static constexpr index_t packed_size = pk_size; // 16 or 32 for now
83 static constexpr index_t num_bits_elem = 6; // specialized for 6-bit data
84 // XXX: CHAR_BIT is not defined in HIPRTC, so we must use 8
85 static constexpr index_t num_bits_vec_elem =
86 sizeof(element_type) * 8; // 32-bit uint for storage
87 static_assert((packed_size * num_bits_elem) % num_bits_vec_elem == 0,
88 "Packed elements must fit exactly into the element storage.");
89 static constexpr index_t vector_size =
90 (packed_size * num_bits_elem) / num_bits_vec_elem; // 3 or 6 element_type units
91
92 using storage_type = element_type __attribute__((ext_vector_type(vector_size)));
93 storage_type data_{storage_type(0)}; // packed data
94
96
97 __host__ __device__ constexpr f6_pk_t() {}
98 __host__ __device__ constexpr f6_pk_t(const storage_type& init) : data_{init}
99 {
100 // TODO: consider removing initialization similar to vector_type<T, 256>
101 }
102
103 // Initialize from a vector type with the same size as packed_size
104 template <typename T, typename = enable_if_t<scalar_type<T>::vector_size == packed_size>>
105 __host__ __device__ f6_pk_t(const T& v)
106 {
108 [&](auto i) { pack(v[static_cast<index_t>(i)], static_cast<index_t>(i)); });
109 }
110
111 // Broadcast single initialization value to all packed elements
112 __host__ __device__ f6_pk_t(const int8_t v)
113 : f6_pk_t(static_cast<int8_t __attribute__((ext_vector_type(packed_size)))>(v))
114 {
115 // TODO: consider removing initialization similar to vector_type<T, 256>
116 }
117
118 template <typename T>
119 __host__ __device__ void pack(const T x, const index_t i)
120 {
122 "T must be an integral type.");
123
124 uint32_t bits = static_cast<uint32_t>(x) & 0x3F;
125 const int bit_pos = i * num_bits_elem;
126 const int arr_index = bit_pos / num_bits_vec_elem;
127 const int bit_offset = bit_pos % num_bits_vec_elem;
128 const int overhang = bit_offset + num_bits_elem - num_bits_vec_elem;
129 uint32_t old_value = data_[arr_index];
130
131 // insert bits into the current 32-bit block
132 old_value |= (bits << bit_offset);
133 data_[arr_index] = old_value;
134
135 // if it crosses into the next block, shift the remainder
136 if(overhang > 0 && (arr_index + 1) < vector_size)
137 {
138 uint32_t next_value = data_[arr_index + 1];
139 next_value |= (bits >> (num_bits_elem - overhang));
140 data_[arr_index + 1] = next_value;
141 }
142 }
143
144 __host__ __device__ static inline BitType unpack(const type& pk, const index_t i)
145 {
146 const int bit_pos = i * num_bits_elem;
147 const int arr_idx = bit_pos / num_bits_vec_elem;
148 const int bit_offset = bit_pos % num_bits_vec_elem;
149 const int overhang = bit_offset + num_bits_elem - num_bits_vec_elem;
150
151 uint32_t bits = pk.data_[arr_idx] >> bit_offset;
152 if(overhang > 0 && (arr_idx + 1) < vector_size)
153 {
154 bits |= (pk.data_[arr_idx + 1] & ((1u << overhang) - 1)) << (num_bits_elem - overhang);
155 }
156
157 return static_cast<BitType>(bits & 0x3F);
158 }
159
160 __host__ __device__ inline BitType unpack(const index_t i) const { return unpack(*this, i); }
161
162 // Compare operator
163 __host__ __device__ friend bool operator==(const f6_pk_t& lhs, const f6_pk_t& rhs)
164 {
165#pragma unroll
166 for(index_t i = 0; i < vector_size; ++i)
167 {
168 if(lhs.data_[i] != rhs.data_[i])
169 return false;
170 }
171 return true;
172 }
173
174 __host__ __device__ friend bool operator!=(const f6_pk_t& lhs, const f6_pk_t& rhs)
175 {
176 return !(lhs == rhs);
177 }
178};
179
184
185// custom data type - pack int4 data
187{
188 using type = int8_t;
190 __host__ __device__ constexpr pk_i4_t() : data{type{}} {}
191 __host__ __device__ constexpr pk_i4_t(type init) : data{init} {}
192};
193
// Rounds x up to the nearest power of two.
//
// Values 0 and 1 are returned unchanged. For x > 1 the result is the smallest power of
// two that is >= x, so an exact power of two maps to itself.
// The result is only representable for x <= 2^31; larger inputs would need a shift by the
// full 32-bit width and must not be passed.
inline constexpr auto next_pow2(uint32_t x)
{
    // Handle 0 and 1 up front: this also keeps __builtin_clz from ever seeing a zero argument.
    if(x <= 1u)
    {
        return x;
    }

    // Number of leading zero bits of (x - 1); 32 minus that is the bit width of (x - 1).
    const uint32_t leading_zeros = static_cast<uint32_t>(__builtin_clz(x - 1u));
    return uint32_t{1} << (32u - leading_zeros);
}
199
200// native types: double, float, _Float16, ushort, int32_t, int8_t, uint8_t, f8_fnuz_t, bf8_fnuz_t,
201// native types: bool
202template <typename T>
210
211// scalar_type
212template <typename TV>
213struct scalar_type;
214
215// is_scalar_type
216template <typename TV>
218{
219 static constexpr bool value = (scalar_type<remove_cvref_t<TV>>::vector_size == 1);
220};
221
222// has_same_scalar_type
223template <typename X, typename Y>
225 typename scalar_type<remove_cvref_t<Y>>::type>;
226
227template <typename T, index_t N>
228struct scalar_type<T __attribute__((ext_vector_type(N)))>
229{
230 using type = T;
231 static constexpr index_t vector_size = N;
232};
233
234//
235template <>
236struct scalar_type<double>
237{
238 using type = double;
239 static constexpr index_t vector_size = 1;
240};
241
242template <>
243struct scalar_type<float>
244{
245 using type = float;
246 static constexpr index_t vector_size = 1;
247};
248
249template <>
251{
252 using type = half_t;
253 static constexpr index_t vector_size = 1;
254};
255
256template <>
258{
259 using type = bhalf_t;
260 static constexpr index_t vector_size = 1;
261};
262
263template <>
265{
266 using type = int32_t;
267 static constexpr index_t vector_size = 1;
268};
269
270template <>
272{
273 using type = int8_t;
274 static constexpr index_t vector_size = 1;
275};
276
277template <>
279{
280 using type = uint8_t;
281 static constexpr index_t vector_size = 1;
282};
283
284#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
285template <>
286struct scalar_type<int4_t>
287{
288 using type = int4_t;
289 static constexpr index_t vector_size = 1;
290};
291#endif
292
293template <>
295{
296 using type = pk_i4_t;
297 static constexpr index_t vector_size = 1;
298};
299
300template <>
302{
304 static constexpr index_t vector_size = 1;
305};
306
307template <>
309{
311 static constexpr index_t vector_size = 1;
312};
313
314template <>
316{
318 static constexpr index_t vector_size = 1;
319};
320
321template <>
323{
325 static constexpr index_t vector_size = 1;
326};
327
328#ifndef CK_CODE_GEN_RTC
329template <>
331{
333 static constexpr index_t vector_size = 1;
334};
335#endif
336
337template <>
339{
341 static constexpr index_t vector_size = 1;
342};
343
344template <>
346{
348 static constexpr index_t vector_size = 1;
349};
350
351template <>
353{
355 static constexpr index_t vector_size = 1;
356};
357
358template <>
360{
362 static constexpr index_t vector_size = 1;
363};
364
365template <>
367{
369 static constexpr index_t vector_size = 1;
370};
371
372template <>
373struct scalar_type<bool>
374{
375 using type = bool;
376 static constexpr index_t vector_size = 1;
377};
378
379template <typename T>
381{
382 private:
383 static constexpr auto get_packed_type_info()
384 {
385 using U = remove_cvref_t<T>;
386 if constexpr(is_same_v<U, pk_i4_t>)
387 return ck::Tuple<ck::Number<2>, int4_t>{};
388 else if constexpr(is_same_v<U, f4x2_pk_t>)
389 return ck::Tuple<ck::Number<2>, f4_t>{};
390 else if constexpr(is_same_v<U, f6x16_pk_t>)
391 return ck::Tuple<ck::Number<16>, f6_t>{};
392 else if constexpr(is_same_v<U, bf6x16_pk_t>)
393 return ck::Tuple<ck::Number<16>, bf6_t>{};
394 else if constexpr(is_same_v<U, f6x32_pk_t>)
395 return ck::Tuple<ck::Number<32>, f6_t>{};
396 else if constexpr(is_same_v<U, bf6x32_pk_t>)
397 return ck::Tuple<ck::Number<32>, bf6_t>{};
398 else
399 return ck::Tuple<ck::Number<1>, T>{};
400 }
401
402 public:
403 using element_type = remove_cvref_t<decltype(get_packed_type_info().At(ck::Number<1>{}))>;
404 static constexpr auto packed_size =
405 static_cast<index_t>(get_packed_type_info().At(ck::Number<0>{}));
406};
407template <typename T>
409
410template <typename T>
412
413template <typename T>
414inline constexpr bool is_packed_type_v = packed_size_v<T> > 1;
415
416template <typename T, index_t N = 0>
418{
419 private:
420 static constexpr auto get_packed_type()
421 {
422 using U = remove_cvref_t<T>;
423 if constexpr(is_same_v<U, int4_t>)
424 {
425 static_assert(N == 0 || N == 2, "Packed size N for int4_t must be 2.");
426 return pk_i4_t{};
427 }
428 else if constexpr(is_same_v<U, f4_t>)
429 {
430 static_assert(N == 0 || N == 2, "Packed size N for f4_t must be 2.");
431 return f4x2_pk_t{};
432 }
433 else if constexpr(is_same_v<U, f6_t>)
434 {
435 static_assert(N == 0 || N == 16 || N == 32, "Packed size N for f6_t must be 16 or 32.");
436 if constexpr(N == 16)
437 return f6x16_pk_t{};
438 else if constexpr(N == 0 || N == 32)
439 return f6x32_pk_t{};
440 }
441 else if constexpr(is_same_v<U, bf6_t>)
442 {
443 static_assert(N == 0 || N == 16 || N == 32,
444 "Packed size N for bf6_t must be 16 or 32.");
445 if constexpr(N == 16)
446 return bf6x16_pk_t{};
447 else if constexpr(N == 0 || N == 32)
448 return bf6x32_pk_t{};
449 }
450 else
451 return T{};
452 }
453
454 public:
455 using packed_type = remove_cvref_t<decltype(get_packed_type())>;
456};
457
458template <typename T, index_t N = 0>
460
// int64_t alias: `long` everywhere except Windows, where `long long` is used instead
// (presumably because `long` is narrower than 64 bits there -- confirm against the toolchain).
#if !defined(_WIN32)
using int64_t = long;
#else
using int64_t = long long;
#endif
466
467template <typename T>
468inline const char* get_type_name()
469{
470 if constexpr(is_same_v<T, half_t>)
471 return "fp16";
472 else if constexpr(is_same_v<T, bhalf_t>)
473 return "bf16";
474 else if constexpr(is_same_v<T, tf32_t>)
475 return "tf32";
476 else if constexpr(is_same_v<T, int4_t>)
477 return "int4";
478 else if constexpr(is_same_v<T, f4_t>)
479 return "f4";
480 else if constexpr(is_same_v<T, f6_t>)
481 return "f6";
482 else if constexpr(is_same_v<T, bf6_t>)
483 return "bf6";
484 else if constexpr(is_same_v<T, f8_t>)
485 return "f8";
486 else if constexpr(is_same_v<T, bf8_t>)
487 return "bf8";
488#ifndef CK_CODE_GEN_RTC
489 else if constexpr(is_same_v<T, e8m0_bexp_t>)
490 return "e8m0";
491#endif
492 else if constexpr(is_same_v<T, float>)
493 return "fp32";
494#if defined(__HIPCC_RTC__) || defined(CK_CODE_GEN_RTC)
495 else
496 return "unknown";
497#else
498 else
499 return typeid(T).name();
500#endif
501}
502
503} // namespace ck
Definition ck.hpp:268
ushort bhalf_t
Definition data_type.hpp:30
f6_pk_t< f6_t, 16 > f6x16_pk_t
Definition data_type.hpp:180
int32_t index_t
Definition ck.hpp:299
constexpr bool is_native_type()
Definition data_type.hpp:203
remove_cv_t< remove_reference_t< T > > remove_cvref_t
Definition type.hpp:297
_Float16 half_t
Definition data_type.hpp:31
integral_constant< index_t, N > Number
Definition number.hpp:12
typename packed_type_maker< T, N >::packed_type packed_type_t
Definition data_type.hpp:459
long int64_t
Definition data_type.hpp:464
const char * get_type_name()
Definition data_type.hpp:468
typename packed_type_info< T >::element_type element_type_t
Definition data_type.hpp:408
f6_pk_t< bf6_t, 32 > bf6x32_pk_t
Definition data_type.hpp:183
is_same< typename scalar_type< remove_cvref_t< X > >::type, typename scalar_type< remove_cvref_t< Y > >::type > has_same_scalar_type
Definition data_type.hpp:224
unsigned _BitInt(4) f4_t
Definition data_type.hpp:33
_BitInt(6) f6_t
Definition data_type.hpp:34
f6_pk_t< f6_t, 32 > f6x32_pk_t
Definition data_type.hpp:181
unsigned _BitInt(6) bf6_t
Definition data_type.hpp:35
constexpr auto next_pow2(uint32_t x)
Definition data_type.hpp:194
constexpr bool is_same_v
Definition type.hpp:283
constexpr bool is_packed_type_v
Definition data_type.hpp:414
_BitInt(4) int4_t
Definition data_type.hpp:32
_BitInt(19) tf32_t
Definition data_type.hpp:29
constexpr index_t packed_size_v
Definition data_type.hpp:411
f6_pk_t< bf6_t, 16 > bf6x16_pk_t
Definition data_type.hpp:182
signed short int16_t
Definition stdint.h:122
unsigned short uint16_t
Definition stdint.h:125
unsigned int uint32_t
Definition stdint.h:126
signed int int32_t
Definition stdint.h:123
unsigned char uint8_t
Definition stdint.h:124
signed char int8_t
Definition stdint.h:121
Definition utility/tuple.hpp:117
Definition amd_ck_fp8.hpp:49
unsigned char data_type
Definition amd_ck_fp8.hpp:50
Definition amd_ck_fp8.hpp:369
fp8_storage_t data_type
Definition amd_ck_fp8.hpp:370
static constexpr value_type value
Definition utility/integral_constant.hpp:13
Unsigned representation of a conventional biased Float32 exponent.
Definition utility/e8m0.hpp:26
uint8_t type
Definition utility/e8m0.hpp:27
Definition data_type.hpp:42
static constexpr int packed_size
Definition data_type.hpp:43
__host__ __device__ type unpack(Number< I >) const
Definition data_type.hpp:51
__host__ __device__ friend bool operator!=(const f4x2_pk_t &lhs, const f4x2_pk_t &rhs)
Definition data_type.hpp:71
type data
Definition data_type.hpp:46
__host__ __device__ friend bool operator==(const f4x2_pk_t &lhs, const f4x2_pk_t &rhs)
Definition data_type.hpp:66
__host__ __device__ type pack(const type x0, const type x1)
Definition data_type.hpp:60
__host__ __device__ constexpr f4x2_pk_t(const type init)
Definition data_type.hpp:48
uint8_t type
Definition data_type.hpp:45
__host__ __device__ constexpr f4x2_pk_t()
Definition data_type.hpp:47
Definition data_type.hpp:79
__host__ __device__ friend bool operator==(const f6_pk_t &lhs, const f6_pk_t &rhs)
Definition data_type.hpp:163
static constexpr index_t vector_size
Definition data_type.hpp:89
__host__ __device__ friend bool operator!=(const f6_pk_t &lhs, const f6_pk_t &rhs)
Definition data_type.hpp:174
static constexpr index_t packed_size
Definition data_type.hpp:82
f6_pk_t< BitType, packed_size > type
Definition data_type.hpp:95
__host__ __device__ void pack(const T x, const index_t i)
Definition data_type.hpp:119
uint32_t element_type
Definition data_type.hpp:80
__host__ static __device__ BitType unpack(const type &pk, const index_t i)
Definition data_type.hpp:144
element_type storage_type
Definition data_type.hpp:92
static constexpr index_t num_bits_elem
Definition data_type.hpp:83
__host__ __device__ constexpr f6_pk_t()
Definition data_type.hpp:97
__host__ __device__ f6_pk_t(const int8_t v)
Definition data_type.hpp:112
static constexpr index_t num_bits_vec_elem
Definition data_type.hpp:85
__host__ __device__ constexpr f6_pk_t(const storage_type &init)
Definition data_type.hpp:98
storage_type data_
Definition data_type.hpp:93
__host__ __device__ f6_pk_t(const T &v)
Definition data_type.hpp:105
__host__ __device__ BitType unpack(const index_t i) const
Definition data_type.hpp:160
Definition amd_ck_fp8.hpp:36
unsigned char data_type
Definition amd_ck_fp8.hpp:37
Definition amd_ck_fp8.hpp:323
fp8_storage_t data_type
Definition amd_ck_fp8.hpp:324
Definition type.hpp:177
Definition data_type.hpp:218
static constexpr bool value
Definition data_type.hpp:219
Definition data_type.hpp:381
static constexpr auto packed_size
Definition data_type.hpp:404
remove_cvref_t< decltype(get_packed_type_info().At(ck::Number< 1 >{}))> element_type
Definition data_type.hpp:403
Definition data_type.hpp:418
remove_cvref_t< decltype(get_packed_type())> packed_type
Definition data_type.hpp:455
Definition data_type.hpp:187
type data
Definition data_type.hpp:189
__host__ __device__ constexpr pk_i4_t()
Definition data_type.hpp:190
int8_t type
Definition data_type.hpp:188
__host__ __device__ constexpr pk_i4_t(type init)
Definition data_type.hpp:191
static constexpr index_t vector_size
Definition data_type.hpp:231
T type
Definition data_type.hpp:230
bf6x16_pk_t::storage_type type
Definition data_type.hpp:368
static constexpr index_t vector_size
Definition data_type.hpp:369
bf6x32_pk_t::storage_type type
Definition data_type.hpp:354
static constexpr index_t vector_size
Definition data_type.hpp:355
bf8_fnuz_t::data_type type
Definition data_type.hpp:310
static constexpr index_t vector_size
Definition data_type.hpp:311
static constexpr index_t vector_size
Definition data_type.hpp:325
bf8_ocp_t::data_type type
Definition data_type.hpp:324
bhalf_t type
Definition data_type.hpp:259
static constexpr index_t vector_size
Definition data_type.hpp:260
static constexpr index_t vector_size
Definition data_type.hpp:376
bool type
Definition data_type.hpp:375
double type
Definition data_type.hpp:238
static constexpr index_t vector_size
Definition data_type.hpp:239
static constexpr index_t vector_size
Definition data_type.hpp:333
e8m0_bexp_t::type type
Definition data_type.hpp:332
static constexpr index_t vector_size
Definition data_type.hpp:341
f4x2_pk_t::type type
Definition data_type.hpp:340
static constexpr index_t vector_size
Definition data_type.hpp:362
f6x16_pk_t::storage_type type
Definition data_type.hpp:361
f6x32_pk_t::storage_type type
Definition data_type.hpp:347
static constexpr index_t vector_size
Definition data_type.hpp:348
static constexpr index_t vector_size
Definition data_type.hpp:304
f8_fnuz_t::data_type type
Definition data_type.hpp:303
static constexpr index_t vector_size
Definition data_type.hpp:318
f8_ocp_t::data_type type
Definition data_type.hpp:317
float type
Definition data_type.hpp:245
static constexpr index_t vector_size
Definition data_type.hpp:246
half_t type
Definition data_type.hpp:252
static constexpr index_t vector_size
Definition data_type.hpp:253
int32_t type
Definition data_type.hpp:266
static constexpr index_t vector_size
Definition data_type.hpp:267
int8_t type
Definition data_type.hpp:273
static constexpr index_t vector_size
Definition data_type.hpp:274
pk_i4_t type
Definition data_type.hpp:296
static constexpr index_t vector_size
Definition data_type.hpp:297
static constexpr index_t vector_size
Definition data_type.hpp:281
uint8_t type
Definition data_type.hpp:280
Definition data_type.hpp:39
Definition functional2.hpp:33