amd_ck_fp8.hpp Source File

amd_ck_fp8.hpp Source File#

Composable Kernel: amd_ck_fp8.hpp Source File
amd_ck_fp8.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
6#include "ck/ck.hpp"
11#include "ck/utility/type.hpp"
12
13#ifndef CK_USE_FNUZ_FP8
14#define CK_USE_FNUZ_FP8 0
15#endif
16
17#ifndef CK_USE_OCP_FP8
18#define CK_USE_OCP_FP8 0
19#endif
20
21#if(defined(__gfx942__) || defined(__gfx950__) || defined(__gfx12__)) && __HIP_DEVICE_COMPILE__
22#define CK_FP8_CVT_FAST_PATH 1
23#else
24#define CK_FP8_CVT_FAST_PATH 0
25#endif
26
27#if(defined(__gfx950__) || defined(__gfx12__)) && __HIP_DEVICE_COMPILE__
28#define CK_OCP_FP8_CVT_FAST_PATH 1
29#else
30#define CK_OCP_FP8_CVT_FAST_PATH 0
31#endif
32
33namespace ck {
34
36{
37 using data_type = unsigned char;
39 __host__ __device__ explicit constexpr f8_fnuz_t(data_type in_data) : m_data(in_data) {}
40 __host__ __device__ explicit constexpr f8_fnuz_t() = default;
41 __host__ __device__ bool constexpr operator==(f8_fnuz_t other) const
42 {
43 return m_data == other.m_data;
44 }
45 __host__ __device__ explicit constexpr operator data_type() const { return m_data; }
46};
47
49{
50 using data_type = unsigned char;
52 __host__ __device__ explicit constexpr bf8_fnuz_t(data_type in_data) : m_data(in_data) {}
53 __host__ __device__ explicit constexpr bf8_fnuz_t() = default;
54 __host__ __device__ bool constexpr operator==(bf8_fnuz_t other) const
55 {
56 return m_data == other.m_data;
57 }
58 __host__ __device__ explicit constexpr operator data_type() const { return m_data; }
59};
60
61static_assert(1 == sizeof(f8_fnuz_t));
62static_assert(1 == sizeof(bf8_fnuz_t));
63
64typedef unsigned char fp8_storage_t;
65
70{
71 CK_E4M3_OCP = 0, // OCP E4M3
72 CK_E5M2_OCP = 1, // OCP E5M2
73 CK_E4M3_FNUZ = 2, // FP8
74 CK_E5M2_FNUZ = 3, // BF8
75};
76
81{
82 CK_NOSAT = 0, // No saturation - replace with NaN or Inf
83 CK_SATFINITE = 1, // Saturate to finite
84};
85
86namespace fp8_impl {
87
88typedef fp8_storage_t fp8x2_storage_t __attribute__((ext_vector_type(2)));
89typedef _Float16 half2_t __attribute__((ext_vector_type(2)));
90typedef ushort ushortx2_t __attribute__((ext_vector_type(2)));
91typedef short shortx2_t __attribute__((ext_vector_type(2)));
92typedef float float2_t __attribute__((ext_vector_type(2)));
93
94__host__ __device__ static inline constexpr bool fnuz_f8_is_nan(f8_fnuz_t a)
95{
96 return static_cast<unsigned char>(a) == 0x80;
97}
98__host__ __device__ static inline constexpr bool fnuz_bf8_is_nan(bf8_fnuz_t a)
99{
100 return static_cast<unsigned char>(a) == 0x80;
101}
102
103__host__ __device__ static inline constexpr bool ocp_f8_is_nan(fp8_storage_t a)
104{
105 return (a & 0x7f) == 0x7f;
106}
107__host__ __device__ static inline constexpr bool ocp_bf8_is_nan(fp8_storage_t a)
108{
109 return (a & 0x7f) > 0x7c;
110}
111
112// The conversion function is from rocblas
113// https://github.com/ROCm/rocBLAS/blob/9b7f692abe3c54b88d1e77e045a7db7f1f188b69/library/include/internal/rocblas_hip_f8_impl.h#L220
114// This has been modified to handle double types as well
115template <typename T, int wm, int we, bool is_fnuz, bool clip = false>
116__host__ __device__ static inline T cast_from_f8(fp8_storage_t x)
117{
118 constexpr bool is_half = __hip_internal::is_same<T, _Float16>::value;
119 constexpr bool is_float = __hip_internal::is_same<T, float>::value;
120 constexpr bool is_double = __hip_internal::is_same<T, double>::value;
121 static_assert(is_half || is_float || is_double, "only half, float and double are supported");
122
123 constexpr int weo = is_half ? 5 : (is_float ? 8 : 11);
124 constexpr int wmo = is_half ? 10 : (is_float ? 23 : 52);
125
126 T fInf, fNegInf, fNaN, fNeg0, fmax, fmin;
127 if constexpr(is_half)
128 {
129 const unsigned short int ihInf = 0x7C00;
130 const unsigned short int ihNegInf = 0xFC00;
131 const unsigned short int ihNaN = 0x7C01;
132 const unsigned short int ihNeg0 = 0x8000;
133 /* Max number in e5m2 57344*/
134 const unsigned short int ifmax = 0x7B00;
135 const unsigned short int ifmin = 0xFB00;
136
137 fInf = bit_cast<_Float16>(ihInf);
138 fNegInf = bit_cast<_Float16>(ihNegInf);
139 fNaN = bit_cast<_Float16>(ihNaN);
140 fNeg0 = bit_cast<_Float16>(ihNeg0);
141 fmax = bit_cast<_Float16>(ifmax);
142 fmin = bit_cast<_Float16>(ifmin);
143 }
144 else if constexpr(is_float)
145 {
146 const unsigned int ifInf = 0x7F800000;
147 const unsigned int ifNegInf = 0xFF800000;
148 const unsigned int ifNaN = 0x7F800001;
149 const unsigned int ifNeg0 = 0x80000000;
150 /* Max number in e5m2 57344*/
151 const unsigned int ifmax = 0x47600000;
152 const unsigned int ifmin = 0xC7600000;
153
154 fInf = bit_cast<float>(ifInf);
155 fNegInf = bit_cast<float>(ifNegInf);
156 fNaN = bit_cast<float>(ifNaN);
157 fNeg0 = bit_cast<float>(ifNeg0);
158 fmax = bit_cast<float>(ifmax);
159 fmin = bit_cast<float>(ifmin);
160 }
161 else if constexpr(is_double)
162 {
163 const unsigned long long ifInf = 0x7FF0000000000000ull;
164 const unsigned long long ifNegInf = 0xFFF0000000000000ull;
165 const unsigned long long ifNaN = 0x7FF0000000000001ull;
166 const unsigned long long ifNeg0 = 0x8000000000000000ull;
167 /* Max number in e5m2 57344*/
168 const unsigned long long ifmax = 0x40EC000000000000ull;
169 const unsigned long long ifmin = 0xC0EC000000000000ull;
170
171 fInf = bit_cast<double>(ifInf);
172 fNegInf = bit_cast<double>(ifNegInf);
173 fNaN = bit_cast<double>(ifNaN);
174 fNeg0 = bit_cast<double>(ifNeg0);
175 fmax = bit_cast<double>(ifmax);
176 fmin = bit_cast<double>(ifmin);
177 }
178
179 if(x == 0)
180 {
181 return 0;
182 }
183
184 unsigned long long sign = x >> 7;
185 unsigned long long mantissa = x & ((1 << wm) - 1);
186 int exponent = (x & 0x7F) >> wm;
187 if constexpr(is_fnuz)
188 {
189 if(x == 0x80)
190 {
191 return fNaN;
192 }
193 }
194 else
195 {
196 if(x == 0x80)
197 {
198 return fNeg0;
199 }
200 if constexpr(we == 4)
201 { // e4m3
202 if((x & 0x7F) == 0x7F)
203 {
204 return fNaN;
205 }
206 }
207 else if((x & 0x7C) == 0x7C)
208 { // e5m2
209 if((x & 0x3) == 0)
210 {
211 if constexpr(clip)
212 {
213 return sign ? fmin : fmax;
214 }
215 return sign ? fNegInf : fInf;
216 }
217 return fNaN;
218 }
219 }
220
221 typename ck::conditional_t<
222 sizeof(T) == 2,
223 unsigned short int,
224 typename ck::conditional_t<sizeof(T) == 4, unsigned int, unsigned long long>>
225 retval;
226
227 if constexpr(we == 5 && is_half && !is_fnuz)
228 {
229 retval = x << 8;
230 return bit_cast<T>(retval);
231 }
232
233 const int exp_low_cutoff = (1 << (weo - 1)) - (1 << (we - 1)) + 1 - (is_fnuz ? 1 : 0);
234
235 // subnormal input
236 if(exponent == 0)
237 {
238#if defined(__HIP_DEVICE_COMPILE__) && __HIP_DEVICE_COMPILE__
239 // guaranteed mantissa!=0 since cases 0x0 and 0x80 are handled above
240 int sh = 1 + __clz(mantissa) - (32 - wm);
241#else
242 int sh = 1 + __builtin_clz(mantissa) - (32 - wm);
243#endif
244 mantissa <<= sh;
245 exponent += 1 - sh;
246 mantissa &= ((1ull << wm) - 1);
247 }
248 exponent += exp_low_cutoff - 1;
249 mantissa <<= wmo - wm;
250
251 // subnormal output (occurs when T=half, we=5, negative_zero_nan=true)
252 if(exponent <= 0)
253 {
254 mantissa |= 1 << wmo;
255 mantissa >>= 1 - exponent;
256 exponent = 0;
257 }
258
259 if constexpr(sizeof(T) == 2)
260 retval = (sign << 15) | (exponent << 10) | mantissa;
261 else if constexpr(sizeof(T) == 4)
262 retval = (sign << 31) | (exponent << 23) | mantissa;
263 else
264 retval = (sign << 63) | (static_cast<unsigned long long>(exponent) << 52) | mantissa;
265
266 return bit_cast<T>(retval);
267}
268
269#if CK_FP8_CVT_FAST_PATH
270template <ck_fp8_interpretation_t interpret>
271static __host__ __device__ float cast_to_f32_from_f8(fp8_storage_t v)
272{
273 union
274 {
275 unsigned int i32val;
276 unsigned char i8val[4];
277 } val;
278 val.i8val[0] = v;
279
280 static_assert(interpret == ck_fp8_interpretation_t::CK_E4M3_FNUZ ||
284 "Only FNUZ and OCP interpretations are supported");
285
286 if constexpr((interpret == ck_fp8_interpretation_t::CK_E4M3_FNUZ) ||
288 {
289 return __builtin_amdgcn_cvt_f32_fp8(val.i32val, 0);
290 }
291 else
292 {
293 return __builtin_amdgcn_cvt_f32_bf8(val.i32val, 0);
294 }
295}
296
297template <ck_fp8_interpretation_t interpret>
298static __device__ float2_t cast_to_f32_from_f8(fp8x2_storage_t v)
299{
300 const auto i16val = bit_cast<uint16_t>(v);
301
302 static_assert(interpret == ck_fp8_interpretation_t::CK_E4M3_FNUZ ||
306 "Only FNUZ and OCP interpretations are supported");
307
308 if constexpr((interpret == ck_fp8_interpretation_t::CK_E4M3_FNUZ) ||
310 {
311 return __builtin_amdgcn_cvt_pk_f32_fp8(i16val, false);
312 }
313 else
314 {
315 return __builtin_amdgcn_cvt_pk_f32_bf8(i16val, false);
316 }
317}
318#endif
319
320} // namespace fp8_impl
321
323{
326
330
331 static constexpr unsigned int we = 4; // exponent width
332 static constexpr unsigned int wm = 3; // mantissa width
333
334 __host__ __device__ constexpr bool operator==(const f8_ocp_t& other) const
335 {
336 return (data == other.data) && (fp8_impl::ocp_f8_is_nan(data) == false); // NaN != NaN
337 }
338
339#if CK_USE_OCP_FP8
340 __host__ __device__ explicit operator float() const
341#else
342 __host__ explicit operator float() const
343#endif
344 {
345#if CK_OCP_FP8_CVT_FAST_PATH
346 return fp8_impl::cast_to_f32_from_f8<default_interpret>(this->data);
347#else
348 return fp8_impl::cast_from_f8<float, wm, we, false>(
349 this->data); // XXX: clip==false must be consistent with operator _Float16
350#endif
351 }
352
353#if CK_USE_OCP_FP8
354 __host__ __device__ explicit operator _Float16() const
355#else
356 __host__ explicit operator _Float16() const
357#endif
358 {
359#if CK_OCP_FP8_CVT_FAST_PATH
360 return static_cast<_Float16>(fp8_impl::cast_to_f32_from_f8<default_interpret>(this->data));
361#else
362 return fp8_impl::cast_from_f8<_Float16, wm, we, false>(
363 this->data); // XXX: clip==false must be consistent with operator float
364#endif
365 }
366};
367
369{
372
376
377 static constexpr unsigned int we = 5; // exponent width
378 static constexpr unsigned int wm = 2; // mantissa width
379
380 __host__ __device__ constexpr bool operator==(const bf8_ocp_t& other) const
381 {
382 return (data == other.data) && (fp8_impl::ocp_bf8_is_nan(data) == false); // NaN != NaN
383 }
384
385#if CK_USE_OCP_FP8
386 __host__ __device__ explicit operator float() const
387
388#else
389 __host__ explicit operator float() const
390#endif
391 {
392#if defined(__gfx950__) || defined(__gfx12__)
393 return fp8_impl::cast_to_f32_from_f8<default_interpret>(this->data);
394#else
395 return fp8_impl::cast_from_f8<float, wm, we, false>(
396 this->data); // XXX: clip==false must be consistent with operator _Float16
397#endif
398 }
399
400#if CK_USE_OCP_FP8
401 __host__ __device__ explicit operator _Float16() const
402#else
403 __host__ explicit operator _Float16() const
404#endif
405 {
406#if defined(__gfx950__) || defined(__gfx12__)
407 return static_cast<_Float16>(fp8_impl::cast_to_f32_from_f8<default_interpret>(this->data));
408#else
409 return fp8_impl::cast_from_f8<_Float16, wm, we, false>(
410 this->data); // XXX: clip==false must be consistent with operator float
411#endif
412 }
413};
414
415template <typename T>
416__host__ __device__ static inline constexpr bool fp8_is_nan(T);
417
418template <>
419__host__ __device__ inline constexpr bool fp8_is_nan(f8_ocp_t a)
420{
421 return fp8_impl::ocp_f8_is_nan(a.data);
422}
423template <>
424__host__ __device__ inline constexpr bool fp8_is_nan(bf8_ocp_t a)
425{
426 return fp8_impl::ocp_bf8_is_nan(a.data);
427}
428template <>
429__host__ __device__ inline constexpr bool fp8_is_nan(f8_fnuz_t a)
430{
431 return fp8_impl::fnuz_f8_is_nan(a);
432}
433template <>
434__host__ __device__ inline constexpr bool fp8_is_nan(bf8_fnuz_t a)
435{
436 return fp8_impl::fnuz_bf8_is_nan(a);
437}
438
439template <typename T,
442 bool> = true>
443__host__ __device__ static inline constexpr bool fp8_is_inf(T)
444{
445 return false;
446}
447template <>
448__host__ __device__ inline constexpr bool fp8_is_inf(bf8_ocp_t a)
449{
450 return (a.data & 0x7f) == 0x7c;
451}
452
453namespace fp8_impl {
454
455// Assertions to check for supported conversion types
// Fires __hip_assert when `interp` is neither CK_E4M3_OCP nor CK_E5M2_OCP.
// The parameter is parenthesized so that an arbitrary expression (for example a
// conditional expression) can be passed without changing the meaning of the comparison.
#define __fp8_impl_assert_ocp_support(interp)                                      \
    {                                                                              \
        if((interp) != ck_fp8_interpretation_t::CK_E4M3_OCP &&                     \
           (interp) != ck_fp8_interpretation_t::CK_E5M2_OCP)                       \
        {                                                                          \
            __hip_assert(false && "type is unsupported by current target device"); \
        }                                                                          \
    }
// Fires __hip_assert when `interp` is neither CK_E4M3_FNUZ nor CK_E5M2_FNUZ.
// The parameter is parenthesized so that an arbitrary expression (for example a
// conditional expression) can be passed without changing the meaning of the comparison.
#define __fp8_impl_assert_fnuz_support(interp)                                     \
    {                                                                              \
        if((interp) != ck_fp8_interpretation_t::CK_E4M3_FNUZ &&                    \
           (interp) != ck_fp8_interpretation_t::CK_E5M2_FNUZ)                      \
        {                                                                          \
            __hip_assert(false && "type is unsupported by current target device"); \
        }                                                                          \
    }
472
473__host__ __device__ static inline void
474__is_interpret_supported([[maybe_unused]] ck_fp8_interpretation_t interp)
475{
476#if defined(__HIP_DEVICE_COMPILE__) && __HIP_DEVICE_COMPILE__
477#if CK_USE_OCP_FP8
479#endif
480#if CK_USE_FNUZ_FP8
482#endif
483#endif
484}
485
486#if defined(__gfx950__)
487template <ck_fp8_interpretation_t interpret,
488 bool saturate,
489 bool stochastic_rounding = false,
492static __device__ fp8_storage_t cast_to_f8_from_f16(_Float16 v, unsigned int rng = 0)
493{
494 union
495 {
496 unsigned int i32val;
497 half2_t half_vec;
498 fp8_storage_t i8val[4];
499 } val;
500
501 constexpr unsigned int i32val = 0;
502 val.half_vec[0] = v;
503
504 if constexpr(saturate)
505 {
506 if((val.i32val & 0x7FFF) != 0x7FFF)
507 {
508 val.half_vec[0] = __builtin_amdgcn_fmed3h(val.half_vec[0], 448.0, -448.0);
509 }
510 }
511
512 val.i32val =
513 __builtin_amdgcn_cvt_scalef32_sr_fp8_f16(i32val, val.half_vec[0], rng, /* scale */ 1.f, 0);
514
515 return val.i8val[0];
516}
517
518template <ck_fp8_interpretation_t interpret,
519 bool saturate,
520 bool stochastic_rounding = false,
523static __device__ fp8x2_storage_t cast_to_f8_from_f16(half2_t v, unsigned int rng = 0)
524{
525 // there is no packed conversion with SR, so convert one element at a time
526 return fp8x2_storage_t{
527 cast_to_f8_from_f16<interpret, saturate, stochastic_rounding>(v[0], rng),
528 cast_to_f8_from_f16<interpret, saturate, stochastic_rounding>(v[1], rng)};
529}
530
531template <ck_fp8_interpretation_t interpret,
532 bool saturate,
533 bool stochastic_rounding = false,
536static __device__ fp8_storage_t cast_to_f8_from_f16(_Float16 v, unsigned int rng = 0)
537{
538 union
539 {
540 unsigned int i32val;
541 half2_t half_vec;
542 fp8_storage_t i8val[4];
543 } val;
544
545 constexpr unsigned int i32val = 0;
546 val.half_vec[0] = v;
547
548 if constexpr(saturate)
549 {
550 if((val.i32val & 0x7FFF) != 0x7FFF)
551 {
552 val.half_vec[0] = __builtin_amdgcn_fmed3h(val.half_vec[0], 57344.0, -57344.0);
553 }
554 }
555
556 val.i32val =
557 __builtin_amdgcn_cvt_scalef32_sr_bf8_f16(i32val, val.half_vec[0], rng, /* scale */ 1.f, 0);
558
559 return val.i8val[0];
560}
561
562template <ck_fp8_interpretation_t interpret,
563 bool saturate,
564 bool stochastic_rounding = false,
567static __device__ fp8x2_storage_t cast_to_f8_from_f16(half2_t v, unsigned int rng = 0)
568{
569 // there is no packed conversion with SR, so convert one element at a time
570 return fp8x2_storage_t{
571 cast_to_f8_from_f16<interpret, saturate, stochastic_rounding>(v[0], rng),
572 cast_to_f8_from_f16<interpret, saturate, stochastic_rounding>(v[1], rng)};
573}
574
575template <ck_fp8_interpretation_t interpret,
576 bool saturate,
577 bool stochastic_rounding = false,
580static __device__ fp8_storage_t cast_to_f8_from_f16(_Float16 v, unsigned int rng = 0)
581{
582 ignore = rng;
583
584 union
585 {
586 unsigned int i32val;
587 half2_t half_vec;
588 shortx2_t i16_vec;
589 fp8_storage_t i8val[4];
590 } val;
591
592 constexpr shortx2_t i16x2val = {0, 0};
593 val.half_vec[0] = v;
594
595 if constexpr(saturate)
596 {
597 if((val.i32val & 0x7FFF) != 0x7FFF)
598 {
599 val.half_vec[0] = __builtin_amdgcn_fmed3h(val.half_vec[0], 448.0, -448.0);
600 }
601 }
602
603 val.i16_vec =
604 __builtin_amdgcn_cvt_scalef32_pk_fp8_f16(i16x2val, val.half_vec, /* scale */ 1.f, 0);
605
606 return val.i8val[0];
607}
608
609template <ck_fp8_interpretation_t interpret,
610 bool saturate,
611 bool stochastic_rounding = false,
614static __device__ fp8x2_storage_t cast_to_f8_from_f16(half2_t v, unsigned int rng = 0)
615{
616#if CK_WORKAROUND_FP16_TO_FP8_CONVERSION
617 return fp8x2_storage_t{
618 cast_to_f8_from_f16<interpret, saturate, stochastic_rounding>(v[0], rng),
619 cast_to_f8_from_f16<interpret, saturate, stochastic_rounding>(v[1], rng)};
620#else
621 ignore = rng;
622
623 union
624 {
625 half2_t half_vec;
626 shortx2_t i16_vec;
627 fp8_storage_t i8val[4];
628 } val;
629
630 constexpr shortx2_t i16x2val = {0, 0};
631 val.half_vec = v;
632
633 if constexpr(saturate)
634 {
635 if((val.i16_vec[0] & 0x7FFF) != 0x7FFF)
636 {
637 val.half_vec[0] = __builtin_amdgcn_fmed3h(val.half_vec[0], 448.0, -448.0);
638 }
639 if((val.i16_vec[1] & 0x7FFF) != 0x7FFF)
640 {
641 val.half_vec[1] = __builtin_amdgcn_fmed3h(val.half_vec[1], 448.0, -448.0);
642 }
643 }
644
645 val.i16_vec =
646 __builtin_amdgcn_cvt_scalef32_pk_fp8_f16(i16x2val, val.half_vec, /* scale */ 1.f, 0);
647
648 return fp8x2_storage_t{val.i8val[0], val.i8val[1]};
649#endif
650}
651
652template <ck_fp8_interpretation_t interpret,
653 bool saturate,
654 bool stochastic_rounding = false,
657static __device__ fp8_storage_t cast_to_f8_from_f16(_Float16 v, unsigned int rng = 0)
658{
659 ignore = rng;
660
661 union
662 {
663 unsigned int i32val;
664 half2_t half_vec;
665 shortx2_t i16_vec;
666 fp8_storage_t i8val[4];
667 } val;
668
669 constexpr shortx2_t i16x2val = {0, 0};
670 val.half_vec[0] = v;
671
672 if constexpr(saturate)
673 {
674 if((val.i32val & 0x7FFF) != 0x7FFF)
675 {
676 val.half_vec[0] = __builtin_amdgcn_fmed3h(val.half_vec[0], 57344.0, -57344.0);
677 }
678 }
679
680 val.half_vec =
681 __builtin_amdgcn_cvt_scalef32_pk_bf8_f16(i16x2val, val.half_vec, /* scale */ 1.f, 0);
682
683 return val.i8val[0];
684}
685
686template <ck_fp8_interpretation_t interpret,
687 bool saturate,
688 bool stochastic_rounding = false,
691static __device__ fp8x2_storage_t cast_to_f8_from_f16(half2_t v, unsigned int rng = 0)
692{
693#if CK_WORKAROUND_FP16_TO_FP8_CONVERSION
694 return fp8x2_storage_t{
695 cast_to_f8_from_f16<interpret, saturate, stochastic_rounding>(v[0], rng),
696 cast_to_f8_from_f16<interpret, saturate, stochastic_rounding>(v[1], rng)};
697#else
698 ignore = rng;
699
700 union
701 {
702 half2_t half_vec;
703 shortx2_t i16_vec;
704 fp8_storage_t i8val[4];
705 } val;
706
707 constexpr shortx2_t i16x2val = {0, 0};
708 val.half_vec = v;
709
710 if constexpr(saturate)
711 {
712 if((val.i16_vec[0] & 0x7FFF) != 0x7FFF)
713 {
714 val.half_vec[0] = __builtin_amdgcn_fmed3h(val.half_vec[0], 57344.0, -57344.0);
715 }
716 if((val.i16_vec[1] & 0x7FFF) != 0x7FFF)
717 {
718 val.half_vec[1] = __builtin_amdgcn_fmed3h(val.half_vec[1], 57344.0, -57344.0);
719 }
720 }
721
722 val.i16_vec =
723 __builtin_amdgcn_cvt_scalef32_pk_bf8_f16(i16x2val, val.half_vec, /* scale */ 1.f, 0);
724
725 return fp8x2_storage_t{val.i8val[0], val.i8val[1]};
726#endif
727}
728
729template <ck_fp8_interpretation_t interpret,
730 bool saturate,
731 bool stochastic_rounding = false,
734static __device__ fp8_storage_t cast_to_f8_from_bf16(ushort v, unsigned int rng = 0)
735{
736 union
737 {
738 unsigned int i32val;
739 ushortx2_t bhalf_vec;
740 fp8_storage_t i8val[4];
741 } val;
742
743 constexpr unsigned int i32val = 0;
744 val.bhalf_vec[0] = v;
745
746 if constexpr(saturate)
747 {
748 if((val.i32val & 0x7FFF) != 0x7FFF)
749 {
750 val.bhalf_vec[0] =
751 ushort((bit_cast<uint32_t>(__builtin_amdgcn_fmed3f(
752 bit_cast<float>(uint32_t{val.bhalf_vec[0]} << 16), 448.0, -448.0)) >>
753 16)); // convert to float and back
754 }
755 }
756
757 val.i32val = __builtin_amdgcn_cvt_scalef32_sr_fp8_bf16(
758 i32val, val.bhalf_vec[0], rng, /* scale */ 1.f, 0);
759
760 return val.i8val[0];
761}
762
763template <ck_fp8_interpretation_t interpret,
764 bool saturate,
765 bool stochastic_rounding = false,
768static __device__ fp8x2_storage_t cast_to_f8_from_bf16(ushortx2_t v, unsigned int rng = 0)
769{
770 // there is no packed conversion with SR, so convert one element at a time
771 return fp8x2_storage_t{
772 cast_to_f8_from_bf16<interpret, saturate, stochastic_rounding>(v[0], rng),
773 cast_to_f8_from_bf16<interpret, saturate, stochastic_rounding>(v[1], rng)};
774}
775
776template <ck_fp8_interpretation_t interpret,
777 bool saturate,
778 bool stochastic_rounding = false,
781static __device__ fp8_storage_t cast_to_f8_from_bf16(ushort v, unsigned int rng = 0)
782{
783 union
784 {
785 unsigned int i32val;
786 ushortx2_t bhalf_vec;
787 fp8_storage_t i8val[4];
788 } val;
789
790 constexpr unsigned int i32val = 0;
791 val.bhalf_vec[0] = v;
792
793 if constexpr(saturate)
794 {
795 if((val.i32val & 0x7FFF) != 0x7FFF)
796 {
797 val.bhalf_vec[0] = ushort(
798 (bit_cast<uint32_t>(__builtin_amdgcn_fmed3f(
799 bit_cast<float>(uint32_t{val.bhalf_vec[0]} << 16), 57344.0, -57344.0)) >>
800 16)); // convert to float and back
801 }
802 }
803
804 val.i32val = __builtin_amdgcn_cvt_scalef32_sr_bf8_bf16(
805 i32val, val.bhalf_vec[0], rng, /* scale */ 1.f, 0);
806
807 return val.i8val[0];
808}
809
810template <ck_fp8_interpretation_t interpret,
811 bool saturate,
812 bool stochastic_rounding = false,
815static __device__ fp8x2_storage_t cast_to_f8_from_bf16(ushortx2_t v, unsigned int rng = 0)
816{
817 // there is no packed conversion with SR, so convert one element at a time
818 return fp8x2_storage_t{
819 cast_to_f8_from_bf16<interpret, saturate, stochastic_rounding>(v[0], rng),
820 cast_to_f8_from_bf16<interpret, saturate, stochastic_rounding>(v[1], rng)};
821}
822
823template <ck_fp8_interpretation_t interpret,
824 bool saturate,
825 bool stochastic_rounding = false,
828static __device__ fp8_storage_t cast_to_f8_from_bf16(ushort v, unsigned int rng = 0)
829{
830 ignore = rng;
831
832 union
833 {
834 unsigned int i32val;
835 ushortx2_t bhalf_vec;
836 shortx2_t i16_vec;
837 fp8_storage_t i8val[4];
838 } val;
839
840 constexpr shortx2_t i16x2val = {0, 0};
841 val.bhalf_vec[0] = v;
842
843 if constexpr(saturate)
844 {
845 if((val.i32val & 0x7FFF) != 0x7FFF)
846 {
847 val.bhalf_vec[0] =
848 ushort((bit_cast<uint32_t>(__builtin_amdgcn_fmed3f(
849 bit_cast<float>(uint32_t{val.bhalf_vec[0]} << 16), 448.0, -448.0)) >>
850 16)); // convert to float and back
851 }
852 }
853
854 val.i16_vec =
855 __builtin_amdgcn_cvt_scalef32_pk_fp8_bf16(i16x2val, val.bhalf_vec, /* scale */ 1.f, 0);
856
857 return val.i8val[0];
858}
859
860template <ck_fp8_interpretation_t interpret,
861 bool saturate,
862 bool stochastic_rounding = false,
865static __device__ fp8x2_storage_t cast_to_f8_from_bf16(ushortx2_t v, unsigned int rng = 0)
866{
867#if CK_WORKAROUND_BF16_TO_FP8_CONVERSION
868 return fp8x2_storage_t{
869 cast_to_f8_from_bf16<interpret, saturate, stochastic_rounding>(v[0], rng),
870 cast_to_f8_from_bf16<interpret, saturate, stochastic_rounding>(v[1], rng)};
871#else
872 ignore = rng;
873
874 union
875 {
876 ushortx2_t bhalf_vec;
877 shortx2_t i16_vec;
878 fp8_storage_t i8val[4];
879 } val;
880
881 constexpr shortx2_t i16x2val = {0, 0};
882 val.bhalf_vec = v;
883
884 if constexpr(saturate)
885 {
886 if((val.i16_vec[0] & 0x7FFF) != 0x7FFF)
887 {
888 val.bhalf_vec[0] =
889 ushort((bit_cast<uint32_t>(__builtin_amdgcn_fmed3f(
890 bit_cast<float>(uint32_t{val.bhalf_vec[0]} << 16), 448.0, -448.0)) >>
891 16)); // convert to float and back
892 }
893 if((val.i16_vec[1] & 0x7FFF) != 0x7FFF)
894 {
895 val.bhalf_vec[1] =
896 ushort((bit_cast<uint32_t>(__builtin_amdgcn_fmed3f(
897 bit_cast<float>(uint32_t{val.bhalf_vec[1]} << 16), 448.0, -448.0)) >>
898 16)); // convert to float and back
899 }
900 }
901
902 val.i16_vec =
903 __builtin_amdgcn_cvt_scalef32_pk_fp8_bf16(i16x2val, val.bhalf_vec, /* scale */ 1.f, 0);
904
905 return fp8x2_storage_t{val.i8val[0], val.i8val[1]};
906#endif
907}
908
909template <ck_fp8_interpretation_t interpret,
910 bool saturate,
911 bool stochastic_rounding = false,
914static __device__ fp8_storage_t cast_to_f8_from_bf16(ushort v, unsigned int rng = 0)
915{
916 ignore = rng;
917
918 union
919 {
920 unsigned int i32val;
921 ushortx2_t bhalf_vec;
922 shortx2_t i16_vec;
923 fp8_storage_t i8val[4];
924 } val;
925
926 constexpr shortx2_t i16x2val = {0, 0};
927 val.bhalf_vec[0] = v;
928
929 if constexpr(saturate)
930 {
931 if((val.i32val & 0x7FFF) != 0x7FFF)
932 {
933 val.bhalf_vec[0] = ushort(
934 (bit_cast<uint32_t>(__builtin_amdgcn_fmed3f(
935 bit_cast<float>(uint32_t{val.bhalf_vec[0]} << 16), 57344.0, -57344.0)) >>
936 16)); // convert to float and back
937 }
938 }
939
940 val.i16_vec =
941 __builtin_amdgcn_cvt_scalef32_pk_bf8_bf16(i16x2val, val.bhalf_vec, /* scale */ 1.f, 0);
942
943 return val.i8val[0];
944}
945
946template <ck_fp8_interpretation_t interpret,
947 bool saturate,
948 bool stochastic_rounding = false,
951static __device__ fp8x2_storage_t cast_to_f8_from_bf16(ushortx2_t v, unsigned int rng = 0)
952{
953 ignore = rng;
954
955 union
956 {
957 ushortx2_t bhalf_vec;
958 shortx2_t i16_vec;
959 fp8_storage_t i8val[4];
960 } val;
961
962 constexpr shortx2_t i16x2val = {0, 0};
963 val.bhalf_vec = v;
964
965 if constexpr(saturate)
966 {
967 if((val.i16_vec[0] & 0x7FFF) != 0x7FFF)
968 {
969 val.bhalf_vec[0] = ushort(
970 (bit_cast<uint32_t>(__builtin_amdgcn_fmed3f(
971 bit_cast<float>(uint32_t{val.bhalf_vec[0]} << 16), 57344.0, -57344.0)) >>
972 16)); // convert to float and back
973 }
974 if((val.i16_vec[1] & 0x7FFF) != 0x7FFF)
975 {
976 val.bhalf_vec[1] = ushort(
977 (bit_cast<uint32_t>(__builtin_amdgcn_fmed3f(
978 bit_cast<float>(uint32_t{val.bhalf_vec[1]} << 16), 57344.0, -57344.0)) >>
979 16)); // convert to float and back
980 }
981 }
982
983 val.i16_vec =
984 __builtin_amdgcn_cvt_scalef32_pk_bf8_bf16(i16x2val, val.bhalf_vec, /* scale */ 1.f, 0);
985
986 return fp8x2_storage_t{val.i8val[0], val.i8val[1]};
987}
988#endif // defined(__gfx950__)
989
990#if CK_FP8_CVT_FAST_PATH
991// The conversion function is from rocblas
992// https://github.com/ROCm/rocBLAS/blob/9b7f692abe3c54b88d1e77e045a7db7f1f188b69/library/include/internal/rocblas_float8.h#L79
993template <ck_fp8_interpretation_t interpret, bool saturate, bool stochastic_rounding = false>
994static __device__ fp8_storage_t cast_to_f8_from_f32(float v, unsigned int rng = 0)
995{
996 fp8_storage_t i8data;
997 union
998 {
999 float fval;
1000 unsigned int i32val;
1001 unsigned char i8val[4]; // NOTE: not endian independent
1002 } val;
1003
1004 unsigned int ival = 0;
1005 val.fval = v;
1006
1007 if constexpr(saturate)
1008 {
1009 if constexpr(interpret == ck_fp8_interpretation_t::CK_E4M3_FNUZ)
1010 {
1011 if((val.i32val & 0x7F800000) != 0x7F800000)
1012 {
1013 val.fval = __builtin_amdgcn_fmed3f(val.fval, 240.0, -240.0);
1014 }
1015 }
1016 else if constexpr(interpret == ck_fp8_interpretation_t::CK_E4M3_OCP)
1017 { // OCP type
1018 if((val.i32val & 0x7F800000) != 0x7F800000)
1019 {
1020 val.fval = __builtin_amdgcn_fmed3f(val.fval, 448.0, -448.0);
1021 }
1022 }
1023 else
1024 {
1025 if((val.i32val & 0x7F800000) != 0x7F800000)
1026 {
1027 val.fval = __builtin_amdgcn_fmed3f(val.fval, 57344.0, -57344.0);
1028 }
1029 }
1030 }
1031
1032 if constexpr(stochastic_rounding)
1033 {
1034 ival = (interpret == ck_fp8_interpretation_t::CK_E4M3_FNUZ) ||
1036 ? __builtin_amdgcn_cvt_sr_fp8_f32(val.fval, rng, ival, 0)
1037 : __builtin_amdgcn_cvt_sr_bf8_f32(val.fval, rng, ival, 0); // 0 pos
1038 val.i32val = ival;
1039 i8data = val.i8val[0]; // little endian
1040 }
1041 else
1042 { // RNE CVT
1043 ival = (interpret == ck_fp8_interpretation_t::CK_E4M3_FNUZ) ||
1045 ? __builtin_amdgcn_cvt_pk_fp8_f32(val.fval, val.fval, ival, false)
1046 : __builtin_amdgcn_cvt_pk_bf8_f32(val.fval,
1047 val.fval,
1048 ival,
1049 false); // false -> WORD0
1050 val.i32val = ival;
1051 i8data = val.i8val[0];
1052 }
1053 return i8data;
1054}
1055
1056template <ck_fp8_interpretation_t interpret, bool saturate, bool stochastic_rounding = false>
1057static __device__ fp8x2_storage_t cast_to_f8_from_f32(float2_t v, unsigned int rng = 0)
1058{
1059 if constexpr(stochastic_rounding)
1060 {
1061 // there is no packed conversion with SR, so convert one element at a time
1062 return fp8x2_storage_t{
1063 cast_to_f8_from_f32<interpret, saturate, stochastic_rounding>(v[0], rng),
1064 cast_to_f8_from_f32<interpret, saturate, stochastic_rounding>(v[1], rng)};
1065 }
1066 else
1067 {
1068 union
1069 {
1070 float fval;
1071 unsigned int i32val;
1072 unsigned char i8val[4];
1073 } val0, val1;
1074
1075 val0.fval = v[0];
1076 val1.fval = v[1];
1077
1078 unsigned int ival = 0;
1079
1080 if constexpr(saturate)
1081 {
1082 if constexpr(interpret == ck_fp8_interpretation_t::CK_E4M3_FNUZ)
1083 {
1084 if((val0.i32val & 0x7F800000) != 0x7F800000)
1085 {
1086 val0.fval = __builtin_amdgcn_fmed3f(val0.fval, 240.0, -240.0);
1087 }
1088 if((val1.i32val & 0x7F800000) != 0x7F800000)
1089 {
1090 val1.fval = __builtin_amdgcn_fmed3f(val1.fval, 240.0, -240.0);
1091 }
1092 }
1093 else if constexpr(interpret == ck_fp8_interpretation_t::CK_E4M3_OCP)
1094 { // OCP type
1095 if((val0.i32val & 0x7F800000) != 0x7F800000)
1096 {
1097 val0.fval = __builtin_amdgcn_fmed3f(val0.fval, 448.0, -448.0);
1098 }
1099 if((val1.i32val & 0x7F800000) != 0x7F800000)
1100 {
1101 val1.fval = __builtin_amdgcn_fmed3f(val1.fval, 448.0, -448.0);
1102 }
1103 }
1104 else
1105 {
1106 if((val0.i32val & 0x7F800000) != 0x7F800000)
1107 {
1108 val0.fval = __builtin_amdgcn_fmed3f(val0.fval, 57344.0, -57344.0);
1109 }
1110 if((val1.i32val & 0x7F800000) != 0x7F800000)
1111 {
1112 val1.fval = __builtin_amdgcn_fmed3f(val1.fval, 57344.0, -57344.0);
1113 }
1114 }
1115 }
1116
1117 // RNE CVT
1118 if constexpr((interpret == ck_fp8_interpretation_t::CK_E4M3_FNUZ) ||
1120 {
1121 ival = __builtin_amdgcn_cvt_pk_fp8_f32(val0.fval, val1.fval, ival, false);
1122 }
1123 else
1124 {
1125 ival = __builtin_amdgcn_cvt_pk_bf8_f32(val0.fval, val1.fval, ival, false);
1126 }
1127
1128 val0.i32val = ival;
1129
1130 return fp8x2_storage_t{val0.i8val[0], val0.i8val[1]};
1131 }
1132}
1133#endif // CK_FP8_CVT_FAST_PATH
1134
1135// The conversion function is from rocblas
1136// https://github.com/ROCm/rocBLAS/blob/9b7f692abe3c54b88d1e77e045a7db7f1f188b69/library/include/internal/rocblas_hip_f8_impl.h#L39
1137// This has been modified to add double types conversion as well
1138template <typename T, int wm, int we, bool is_fnuz, bool clip = false, bool stoch = false>
1139__host__ __device__ static inline fp8_storage_t cast_to_f8(T _x, unsigned int rng = 0)
1140{
1141 constexpr bool is_half = __hip_internal::is_same<T, _Float16>::value;
1142 constexpr bool is_float = __hip_internal::is_same<T, float>::value;
1143 constexpr bool is_double = __hip_internal::is_same<T, double>::value;
1144 static_assert(is_half || is_float || is_double,
1145 "Only half, float and double can be cast to f8");
1146
1147 constexpr int mfmt = (sizeof(T) == 8) ? 52 : ((sizeof(T) == 4) ? 23 : 10);
1148
1149 using T_bitwise = typename ck::conditional_t<
1150 sizeof(T) == 2,
1151 unsigned short int,
1152 typename ck::conditional_t<sizeof(T) == 4, unsigned int, unsigned long long>>;
1153 T_bitwise x_bitwise = bit_cast<T_bitwise>(_x);
1154
1155 unsigned long long x{x_bitwise};
1156
1157 unsigned long long head, mantissa;
1158 int exponent, bias;
1159 unsigned int sign;
1160 unsigned long long fInf, mask;
1161
1162 if constexpr(sizeof(T) == 8)
1163 {
1164 head = x & 0xFFF0000000000000ull;
1165 mantissa = x & 0xFFFFFFFFFFFFFull;
1166 exponent = (head >> 52) & 0x7FF;
1167 sign = head >> 63;
1168 bias = 1023;
1169 fInf = 0x7FF0000000000000ull;
1170 mask = 0x7FFFFFFFFFFFFFFFull;
1171 }
1172 else if constexpr(sizeof(T) == 4)
1173 {
1174 head = x & 0xFF800000;
1175 mantissa = x & 0x7FFFFF;
1176 exponent = (head >> 23) & 0xFF;
1177 sign = head >> 31;
1178 bias = 127;
1179 fInf = 0x7F800000;
1180 mask = 0x7FFFFFFF;
1181 }
1182 else
1183 {
1184 head = x & 0xFC00;
1185 mantissa = x & 0x3FF;
1186 exponent = (head >> 10) & 0x1F;
1187 sign = head >> 15;
1188 bias = 15;
1189 fInf = 0x7C00;
1190 mask = 0x7FFF;
1191 }
1192 unsigned int signed_inf = 0;
1193 unsigned int nan = 0;
1194 if constexpr(is_fnuz)
1195 {
1196 signed_inf = clip ? ((sign << 7) + 0x7f) : 0x80;
1197 nan = 0x80;
1198 }
1199 else
1200 {
1201 if constexpr(we == 4)
1202 { // e4m3
1203 signed_inf = (sign << 7) + (clip ? 0x7e : 0x7f);
1204 }
1205 else
1206 { // e5m2
1207 signed_inf = (sign << 7) + (clip ? 0x7b : 0x7c);
1208 }
1209 nan = (sign << 7) + 0x7f;
1210 }
1211 // Max values
1212 unsigned long long ifmax = 0;
1213 if constexpr(sizeof(T) == 8)
1214 {
1215 if constexpr(we == 5)
1216 { // 57344
1217 ifmax = 0x40EC000000000000ull;
1218 }
1219 else
1220 {
1221 if constexpr(is_fnuz)
1222 { // 240
1223 ifmax = 0x406E000000000000ull;
1224 }
1225 else
1226 { // 448
1227 ifmax = 0x407C000000000000ull;
1228 }
1229 }
1230 }
1231 else if(sizeof(T) == 4)
1232 {
1233 if constexpr(we == 5)
1234 {
1235 ifmax = 0x47600000;
1236 }
1237 else
1238 {
1239 if constexpr(is_fnuz)
1240 {
1241 ifmax = 0x43700000;
1242 }
1243 else
1244 {
1245 ifmax = 0x43E00000;
1246 }
1247 }
1248 }
1249 else
1250 {
1251 if constexpr(we == 5)
1252 {
1253 ifmax = 0x7B00;
1254 }
1255 else
1256 {
1257 if constexpr(is_fnuz)
1258 {
1259 ifmax = 0x5B80;
1260 }
1261 else
1262 {
1263 ifmax = 0x5F00;
1264 }
1265 }
1266 }
1267 // Deal with inf and NaNs
1268 if((x & fInf) == fInf)
1269 {
1270 if constexpr(is_fnuz)
1271 return signed_inf;
1272
1273 return mantissa != 0 ? nan : signed_inf;
1274 }
1275
1276 if((x & mask) > ifmax)
1277 {
1278 return signed_inf;
1279 }
1280
1281 if(x == 0)
1282 {
1283 return 0;
1284 }
1285
1286 // First need to check if it is normal or denorm as there is a difference of
1287 // implicit 1 Then need to adjust the exponent to align with the F8 exponent,
1288 // in the meanwhile, shift The mantissa. Then for stochastic rounding, add rng
1289 // to mantissa and truncate. And for RNE, no need to add rng. Then probably
1290 // need to check whether there is carry and adjust exponent and mantissa again
1291
1292 // For IEEE bias mode, the bias is 2^(k-1) -1 where k is the width of exponent
1293 // bits
1294 const int f8_bias = (1 << (we - 1)) - 1 + (is_fnuz ? 1 : 0);
1295 const int f8_denormal_act_exponent = 1 - f8_bias; // actual exponent of f8 denormal
1296 // act_exponent is the actual exponent of fp32/fp16 (after subtracting bias)
1297 // f8_exponent is the converted f8 exponent with bias encoding
1298 // exponent_diff is the diff between fp32/fp16 exponent and f8 exponent,
1299 // the difference needs to be adjusted and mantissa shifted
1300 int act_exponent, f8_exponent, exponent_diff;
1301
1302 if(exponent == 0)
1303 { // fp32/fp16 is in denormal.
1304 /* fp32 denormal is below 2^-127 so it is usually not a concern here, we
1305 mostly concern fp16 here. In this case, f8 is usually in denormal. But there
1306 could be exceptions. fp16 denormal has exponent bias 15 while bf8 with NANOO has
1307 exponent bias 16. It means that there are some numbers in fp16 denormal but they
1308 are bf8 (NANOO) normals - smallest bf8 (NANOO) normal is 2^-15. fp16 numbers
1309 where exponent==0 (actual exponent -14) and highest bit of mantissa is 1 are bf8
1310 (NANOO) normal. In this case, the fp16 mantissa should be shift left by 1 */
1311 act_exponent = exponent - bias + 1;
1312 exponent_diff = f8_denormal_act_exponent -
1313 act_exponent; // actual exponent is exponent-bias+1 as it is denormal
1314 }
1315 else
1316 { // fp32/fp16 is normal with implicit 1
1317 act_exponent = exponent - bias;
1318 if(act_exponent <= f8_denormal_act_exponent)
1319 {
1320 /* This is the case where fp32/fp16 is normal but it is in f8 denormal
1321 range. For example fp8 nanoo mode, denormal exponent is -7, but if the fp32/fp16
1322 actual exponent is -7, it is actually larger due to the implicit 1,
1323 Therefore it needs to be adjust to -6 and mantissa shift right by 1.
1324 So for fp32/fp16, exponent -8 is the cut point to convert to fp8 nanoo */
1325 exponent_diff = f8_denormal_act_exponent - act_exponent;
1326 }
1327 else
1328 { // both fp32/fp16 and f8 are in normal range
1329 exponent_diff = 0; // exponent_diff=0 does not mean there is no difference
1330 // for this case, act_exponent could be larger. Just
1331 // that it does not need shift mantissa
1332 }
1333 mantissa += (1ull << mfmt); // Add the implicit 1 into mantissa
1334 }
1335
1336 bool midpoint = (mantissa & ((1ull << (mfmt - wm + exponent_diff)) - 1)) ==
1337 (1ull << (mfmt - wm + exponent_diff - 1));
1338 /* This part is a bit tricky. The judgment of whether it is a tie needs to be
1339 done before we shift right as shift right could rip off some residual part and
1340 make something not midpoint look like midpoint. For example, the fp16 number
1341 0x1002 (0 00100 0000000010), it is larger than midpoint, but after shift right
1342 by 4 bits, it would look like midpoint.
1343 */
1344
1345 if(exponent_diff > 0)
1346 mantissa >>= exponent_diff;
1347 else if(exponent_diff == -1)
1348 mantissa <<= -exponent_diff;
1349 bool implicit_one = mantissa & (1ull << mfmt);
1350 // if there is no implicit 1, it means the f8 is denormal and need to adjust
1351 // to denorm exponent
1352 f8_exponent =
1353 (act_exponent + exponent_diff) /*actual f8 exponent*/ + f8_bias - (implicit_one ? 0 : 1);
1354
1355 // Now we have the exponent and mantissa adjusted
1356 unsigned long long drop_mask = (1ull << (mfmt - wm)) - 1;
1357 bool odd =
1358 mantissa & (1ull << (mfmt - wm)); // if the least significant bit that is not truncated is 1
1359 mantissa +=
1360 (stoch ? rng : (midpoint ? (odd ? mantissa : mantissa - 1ull) : mantissa)) & drop_mask;
1361
1362 // Now we deal with overflow
1363 if(f8_exponent == 0)
1364 {
1365 if((1ull << mfmt) & mantissa)
1366 {
1367 f8_exponent = 1; // denormal overflow to become normal, promote exponent
1368 }
1369 }
1370 else
1371 {
1372 if((1ull << (mfmt + 1)) & mantissa)
1373 {
1374 mantissa >>= 1;
1375 f8_exponent++;
1376 }
1377 }
1378
1379 mantissa >>= (mfmt - wm);
1380
1381 // above range: quantize to maximum possible float of the same sign
1382 const int max_exp = (1 << we) - 1;
1383 if(f8_exponent > max_exp)
1384 {
1385 if constexpr(clip)
1386 {
1387 mantissa = (1 << wm) - 1;
1388 f8_exponent = max_exp;
1389 }
1390 else
1391 {
1392 return signed_inf;
1393 }
1394 }
1395
1396 if(f8_exponent == 0 && mantissa == 0)
1397 return is_fnuz ? 0 : (sign << 7);
1398 mantissa &= (1 << wm) - 1;
1399 return (sign << 7) | (f8_exponent << wm) | mantissa;
1400}
1401
1411template <ck_fp8_interpretation_t interp,
1413 bool stochastic_rounding = false>
1414#if CK_FP8_CVT_FAST_PATH
1415__host__ __device__ static inline fp8_storage_t cvt_float_to_fp8(const float f)
1416{
1417 __is_interpret_supported(interp);
1418 uint32_t rng = 0;
1419 if constexpr(stochastic_rounding)
1420 {
1421#if defined(__gfx950__)
1422 // use HW clock for stochastic input multiply by incremented thread id
1423 rng = __builtin_amdgcn_prng_b32(__builtin_amdgcn_s_memrealtime() *
1424 (get_thread_global_1d_id() + 1));
1425#else
1426 constexpr int seed = 1254739;
1427#ifndef CK_CODE_GEN_RTC
1428 rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&f), f);
1429#else
1430 rng = prand_generator<float, seed>(reinterpret_cast<size_t>(&f), f);
1431#endif // #ifndef CK_CODE_GEN_RTC
1432#endif // #if defined(__gfx950__)
1433 }
1434 return cast_to_f8_from_f32<interp, sat == ck_saturation_t::CK_SATFINITE, stochastic_rounding>(
1435 f, rng);
1436#else
1437#if CK_USE_OCP_FP8
1438__host__ __device__ static inline fp8_storage_t cvt_float_to_fp8(const float f)
1439{
1440#else
1441__host__ static inline fp8_storage_t cvt_float_to_fp8(const float f)
1442{
1443#endif
1444 uint32_t rng = 0;
1445 if constexpr(stochastic_rounding)
1446 {
1447#if defined(__gfx950__)
1448 // use HW clock for stochastic input multiply by incremented thread id
1449 rng = __builtin_amdgcn_prng_b32(__builtin_amdgcn_s_memrealtime() *
1450 (get_thread_global_1d_id() + 1));
1451#else
1452 constexpr int seed = 1254739;
1453#ifndef CK_CODE_GEN_RTC
1454 rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&f), f);
1455#else
1456 rng = prand_generator<float, seed>(reinterpret_cast<size_t>(&f), f);
1457#endif // #ifndef CK_CODE_GEN_RTC
1458#endif // #if defined(__gfx950__)
1459 }
1460
1461 if constexpr(interp == ck_fp8_interpretation_t::CK_E4M3_FNUZ)
1462 {
1463 return cast_to_f8<float,
1464 3,
1465 4,
1466 true,
1468 stochastic_rounding>(f, rng);
1469 }
1470 else if constexpr(interp == ck_fp8_interpretation_t::CK_E5M2_FNUZ)
1471 {
1472 return cast_to_f8<float,
1473 2,
1474 5,
1475 true,
1477 stochastic_rounding>(f, rng);
1478 }
1479 else if constexpr(interp == ck_fp8_interpretation_t::CK_E4M3_OCP)
1480 {
1481 return cast_to_f8<float,
1482 3,
1483 4,
1484 false,
1486 stochastic_rounding>(f, rng);
1487 }
1488 else if constexpr(interp == ck_fp8_interpretation_t::CK_E5M2_OCP)
1489 {
1490 return cast_to_f8<float,
1491 2,
1492 5,
1493 false,
1495 stochastic_rounding>(f, rng);
1496 }
1497 else
1498 {
1499 __hip_assert(false && "FP8 type is not supported by current target device");
1500 return 0;
1501 }
1502#endif // CK_FP8_CVT_FAST_PATH
1503}
1504
1514template <ck_fp8_interpretation_t interp,
1516 bool stochastic_rounding = false>
1517#if CK_FP8_CVT_FAST_PATH
1518__device__ static inline fp8x2_storage_t cvt_float_to_fp8(const float2_t f)
1519{
1520 __is_interpret_supported(interp);
1521 uint32_t rng = 0;
1522 if constexpr(stochastic_rounding)
1523 {
1524#if defined(__gfx950__)
1525 // use HW clock for stochastic input multiply by incremented thread id
1526 rng = __builtin_amdgcn_prng_b32(__builtin_amdgcn_s_memrealtime() *
1527 (get_thread_global_1d_id() + 1));
1528#else
1529 constexpr int seed = 1254739;
1530#ifndef CK_CODE_GEN_RTC
1531 rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&f), f[0]);
1532#else
1533 rng = prand_generator<float, seed>(reinterpret_cast<size_t>(&f), f[0]);
1534#endif // #ifndef CK_CODE_GEN_RTC
1535#endif // #if defined(__gfx950__)
1536 }
1537 return cast_to_f8_from_f32<interp, sat == ck_saturation_t::CK_SATFINITE, stochastic_rounding>(
1538 f, rng);
1539#else
1540#if CK_USE_OCP_FP8
1541__host__ __device__ static inline fp8x2_storage_t cvt_float_to_fp8(const float2_t f)
1542{
1543#else
1544__host__ static inline fp8x2_storage_t cvt_float_to_fp8(const float2_t f)
1545{
1546#endif // CK_USE_OCP_FP8
1547 return fp8x2_storage_t{cvt_float_to_fp8<interp, sat, stochastic_rounding>(f[0]),
1548 cvt_float_to_fp8<interp, sat, stochastic_rounding>(f[1])};
1549#endif // CK_FP8_CVT_FAST_PATH
1550}
1551
1561template <ck_fp8_interpretation_t interp,
1563 bool stochastic_rounding = false>
1564#if CK_FP8_CVT_FAST_PATH || CK_USE_OCP_FP8
1565__host__ __device__ static inline fp8_storage_t cvt_half_t_to_fp8(const _Float16 x)
1566#else
1567__host__ static inline fp8_storage_t cvt_half_t_to_fp8(const _Float16 x)
1568#endif
1569{
1570 {
1571 __is_interpret_supported(interp);
1572 uint32_t rng = 0;
1573 if constexpr(stochastic_rounding)
1574 {
1575#if defined(__gfx950__)
1576 // use HW clock for stochastic input multiply by incremented thread id
1577 rng = __builtin_amdgcn_prng_b32(__builtin_amdgcn_s_memrealtime() *
1578 (get_thread_global_1d_id() + 1));
1579#else
1580 constexpr int seed = 1254739;
1581#ifndef CK_CODE_GEN_RTC
1582 rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&x), x);
1583#else
1584 rng = prand_generator<float, seed>(reinterpret_cast<size_t>(&x), x);
1585#endif // #ifndef CK_CODE_GEN_RTC
1586#endif // #if defined(__gfx950__)
1587 }
1588#if defined(__gfx950__)
1589 return cast_to_f8_from_f16<interp,
1591 stochastic_rounding>(x, rng);
1592#else
1593 ignore = rng;
1594 return cvt_float_to_fp8<interp, ck_saturation_t::CK_SATFINITE, stochastic_rounding>(
1595 static_cast<float>(x));
1596#endif // defined(__gfx950__)
1597 }
1598}
1599
1609template <ck_fp8_interpretation_t interp,
1611 bool stochastic_rounding = false>
1612#if CK_FP8_CVT_FAST_PATH || CK_USE_OCP_FP8
1613__host__ __device__ static inline fp8x2_storage_t cvt_half_t_to_fp8(const half2_t x)
1614#else
1615__host__ static inline fp8x2_storage_t cvt_half_t_to_fp8(const half2_t x)
1616#endif
1617{
1618 {
1619 __is_interpret_supported(interp);
1620 uint32_t rng = 0;
1621 if constexpr(stochastic_rounding)
1622 {
1623#if defined(__gfx950__)
1624 // use HW clock for stochastic input multiply by incremented thread id
1625 rng = __builtin_amdgcn_prng_b32(__builtin_amdgcn_s_memrealtime() *
1626 (get_thread_global_1d_id() + 1));
1627#else
1628 constexpr int seed = 1254739;
1629#ifndef CK_CODE_GEN_RTC
1630 rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&x), x[0]);
1631#else
1632 rng = prand_generator<float, seed>(reinterpret_cast<size_t>(&x), x[0]);
1633#endif // #ifndef CK_CODE_GEN_RTC
1634#endif // #if defined(__gfx950__)
1635 }
1636#if defined(__gfx950__)
1637 return cast_to_f8_from_f16<interp,
1639 stochastic_rounding>(x, rng);
1640#else
1641 ignore = rng;
1642 return cvt_float_to_fp8<interp, ck_saturation_t::CK_SATFINITE, stochastic_rounding>(
1643 float2_t{static_cast<float>(x[0]), static_cast<float>(x[1])});
1644#endif // defined(__gfx950__)
1645 }
1646}
1647
1657template <ck_fp8_interpretation_t interp,
1659 bool stochastic_rounding = false>
1660#if CK_FP8_CVT_FAST_PATH || CK_USE_OCP_FP8
1661__host__ __device__ static inline fp8_storage_t cvt_bhalf_t_to_fp8(const ushort x)
1662#else
1663__host__ static inline fp8_storage_t cvt_bhalf_t_to_fp8(const ushort x)
1664#endif
1665{
1666 {
1667 __is_interpret_supported(interp);
1668 uint32_t rng = 0;
1669 if constexpr(stochastic_rounding)
1670 {
1671#if defined(__gfx950__)
1672 // use HW clock for stochastic input multiply by incremented thread id
1673 rng = __builtin_amdgcn_prng_b32(__builtin_amdgcn_s_memrealtime() *
1674 (get_thread_global_1d_id() + 1));
1675#else
1676 constexpr int seed = 1254739;
1677#ifndef CK_CODE_GEN_RTC
1678 rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&x),
1679 static_cast<float>(x));
1680#else
1681 rng = prand_generator<float, seed>(reinterpret_cast<size_t>(&x), static_cast<float>(x));
1682#endif // #ifndef CK_CODE_GEN_RTC
1683#endif // #if defined(__gfx950__)
1684 }
1685#if defined(__gfx950__)
1686 return cast_to_f8_from_bf16<interp,
1688 stochastic_rounding>(x, rng);
1689#else
1690 ignore = rng;
1691 return cvt_float_to_fp8<interp, ck_saturation_t::CK_SATFINITE, stochastic_rounding>(
1692 bit_cast<float>(uint32_t{x} << 16)); // convert value to float
1693#endif // defined(__gfx950__)
1694 }
1695}
1696
1706template <ck_fp8_interpretation_t interp,
1708 bool stochastic_rounding = false>
1709#if CK_FP8_CVT_FAST_PATH || CK_USE_OCP_FP8
1710__host__ __device__ static inline fp8x2_storage_t cvt_bhalf_t_to_fp8(const ushortx2_t x)
1711#else
1712__host__ static inline fp8x2_storage_t cvt_bhalf_t_to_fp8(const ushortx2_t x)
1713#endif
1714{
1715#if CK_WORKAROUND_BF16_TO_FP8_CONVERSION
1716 return cvt_float_to_fp8<interp, ck_saturation_t::CK_SATFINITE, stochastic_rounding>(
1717 float2_t{bit_cast<float>(uint32_t{x[0]} << 16),
1718 bit_cast<float>(uint32_t{x[1]} << 16)}); // convert values to float
1719#else // CK_WORKAROUND_BF16_TO_FP8_CONVERSION
1720 {
1721 __is_interpret_supported(interp);
1722 uint32_t rng = 0;
1723 if constexpr(stochastic_rounding)
1724 {
1725#if defined(__gfx950__)
1726 // use HW clock for stochastic input multiply by incremented thread id
1727 rng = __builtin_amdgcn_prng_b32(__builtin_amdgcn_s_memrealtime() *
1728 (get_thread_global_1d_id() + 1));
1729#else
1730 constexpr int seed = 1254739;
1731#ifndef CK_CODE_GEN_RTC
1732 rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&x),
1733 static_cast<float>(x[0]));
1734#else
1735 rng = prand_generator<float, seed>(reinterpret_cast<size_t>(&x),
1736 static_cast<float>(x[0]));
1737#endif // #ifndef CK_CODE_GEN_RTC
1738#endif // #if defined(__gfx950__)
1739 }
1740#if defined(__gfx950__)
1741 return cast_to_f8_from_bf16<interp,
1743 stochastic_rounding>(x, rng);
1744#else
1745 ignore = rng;
1746 return cvt_float_to_fp8<interp, ck_saturation_t::CK_SATFINITE, stochastic_rounding>(
1747 float2_t{bit_cast<float>(uint32_t{x[0]} << 16),
1748 bit_cast<float>(uint32_t{x[1]} << 16)}); // convert values to float
1749#endif // defined(__gfx950__)
1750 }
1751#endif // CK_WORKAROUND_BF16_TO_FP8_CONVERSION
1752}
1753
1754} // namespace fp8_impl
1755
1756#if CK_USE_OCP_FP8
1757using f8_t = f8_ocp_t;
1758using bf8_t = bf8_ocp_t;
1759#define CK_FP8_TYPE_FNUZ 0
1760#define CK_FP8_TYPE_OCP 1
1761#else
1764#define CK_FP8_TYPE_FNUZ 1
1765#define CK_FP8_TYPE_OCP 0
1766#endif
1767
1768} // namespace ck
#define __fp8_impl_assert_fnuz_support(interp)
Definition amd_ck_fp8.hpp:464
#define __fp8_impl_assert_ocp_support(interp)
Definition amd_ck_fp8.hpp:456
Definition amd_ck_fp8.hpp:86
ushort ushortx2_t
Definition amd_ck_fp8.hpp:90
short shortx2_t
Definition amd_ck_fp8.hpp:91
float float2_t
Definition amd_ck_fp8.hpp:92
fp8_storage_t fp8x2_storage_t
Definition amd_ck_fp8.hpp:88
_Float16 half2_t
Definition amd_ck_fp8.hpp:89
Definition ck.hpp:268
f8_fnuz_t f8_t
Definition amd_ck_fp8.hpp:1762
typename conditional< predicate, X, Y >::type conditional_t
Definition utility/functional.hpp:115
ck_fp8_interpretation_t
Describes FP8 interpretation.
Definition amd_ck_fp8.hpp:70
@ CK_E4M3_OCP
Definition amd_ck_fp8.hpp:71
@ CK_E5M2_OCP
Definition amd_ck_fp8.hpp:72
@ CK_E5M2_FNUZ
Definition amd_ck_fp8.hpp:74
@ CK_E4M3_FNUZ
Definition amd_ck_fp8.hpp:73
constexpr detail::ignore_t ignore
Definition utility/ignore.hpp:20
bf8_fnuz_t bf8_t
Definition amd_ck_fp8.hpp:1763
__device__ index_t get_thread_global_1d_id()
Definition get_id.hpp:43
__host__ __device__ uint32_t prand_generator(index_t id, T val, uint32_t seed=seed_t)
Definition random_gen.hpp:19
constexpr bool is_same_v
Definition type.hpp:283
__host__ __device__ constexpr Y bit_cast(const X &x)
Definition type.hpp:306
typename std::enable_if< B, T >::type enable_if_t
Definition enable_if.hpp:27
ck_saturation_t
Describes saturation behavior.
Definition amd_ck_fp8.hpp:81
@ CK_SATFINITE
Definition amd_ck_fp8.hpp:83
@ CK_NOSAT
Definition amd_ck_fp8.hpp:82
unsigned char fp8_storage_t
Definition amd_ck_fp8.hpp:64
const GenericPointer< typename T::ValueType > T2 T::AllocatorType & a
Definition pointer.h:1517
_W64 unsigned int uintptr_t
Definition stdint.h:164
unsigned int uint32_t
Definition stdint.h:126
Definition amd_ck_fp8.hpp:49
data_type m_data
Definition amd_ck_fp8.hpp:51
__host__ __device__ bool constexpr operator==(bf8_fnuz_t other) const
Definition amd_ck_fp8.hpp:54
unsigned char data_type
Definition amd_ck_fp8.hpp:50
__host__ __device__ constexpr bf8_fnuz_t(data_type in_data)
Definition amd_ck_fp8.hpp:52
__host__ __device__ constexpr bf8_fnuz_t()=default
Definition amd_ck_fp8.hpp:369
static constexpr unsigned int wm
Definition amd_ck_fp8.hpp:378
static constexpr unsigned int we
Definition amd_ck_fp8.hpp:377
fp8_storage_t data_type
Definition amd_ck_fp8.hpp:370
data_type data
Definition amd_ck_fp8.hpp:371
static constexpr ck_fp8_interpretation_t default_interpret
Definition amd_ck_fp8.hpp:374
static constexpr ck_saturation_t default_saturation
Definition amd_ck_fp8.hpp:373
__host__ __device__ constexpr bool operator==(const bf8_ocp_t &other) const
Definition amd_ck_fp8.hpp:380
Definition amd_ck_fp8.hpp:36
__host__ __device__ constexpr f8_fnuz_t()=default
data_type m_data
Definition amd_ck_fp8.hpp:38
__host__ __device__ constexpr f8_fnuz_t(data_type in_data)
Definition amd_ck_fp8.hpp:39
__host__ __device__ bool constexpr operator==(f8_fnuz_t other) const
Definition amd_ck_fp8.hpp:41
unsigned char data_type
Definition amd_ck_fp8.hpp:37
Definition amd_ck_fp8.hpp:323
fp8_storage_t data_type
Definition amd_ck_fp8.hpp:324
data_type data
Definition amd_ck_fp8.hpp:325
static constexpr unsigned int we
Definition amd_ck_fp8.hpp:331
__host__ __device__ constexpr bool operator==(const f8_ocp_t &other) const
Definition amd_ck_fp8.hpp:334
static constexpr unsigned int wm
Definition amd_ck_fp8.hpp:332
static constexpr ck_fp8_interpretation_t default_interpret
Definition amd_ck_fp8.hpp:328
static constexpr ck_saturation_t default_saturation
Definition amd_ck_fp8.hpp:327