17#if(defined(__gfx94__) || defined(__gfx12__)) && __HIP_DEVICE_COMPILE__
18#define CK_TILE_FP8_CVT_DEVICE 1
20#define CK_TILE_FP8_CVT_DEVICE 0
70#if CK_TILE_USE_CUSTOM_DATA_TYPE
71struct alignas(1) float8_e4m3_t
73 static constexpr int exponent = 4;
74 static constexpr int mantissa = 3;
75#if CK_TILE_USE_OCP_FP8
76 static constexpr int bias = 7;
78 static constexpr int bias = 8;
84 static constexpr float8_e4m3_t
bit_cast(raw_type x)
92 constexpr float8_e4m3_t() : data() {}
96 explicit constexpr float8_e4m3_t(
const float& x) : data(
float_to_fp8_raw(x)) {}
100 explicit constexpr float8_e4m3_t(
const int& x) : data(
float_to_fp8_raw(static_cast<float>(x)))
106 explicit constexpr float8_e4m3_t(
const unsigned int& x)
113 explicit constexpr operator float()
const {
return fp8_to_float_raw(data); }
117 explicit constexpr operator int()
const {
return static_cast<int>(
fp8_to_float_raw(data)); }
121 constexpr raw_type& get() {
return data; }
124 constexpr raw_type get()
const {
return data; }
126using fp8_t = float8_e4m3_t;
127using fp8_raw_t =
typename fp8_t::raw_type;
129struct alignas(1) float8_e5m2_t
131 static constexpr int exponent = 5;
132 static constexpr int mantissa = 2;
133#if CK_TILE_USE_OCP_FP8
134 static constexpr int bias = 15;
136 static constexpr int bias = 16;
142 static constexpr float8_e5m2_t
bit_cast(raw_type x)
150 constexpr float8_e5m2_t() : data() {}
154 explicit constexpr float8_e5m2_t(
const float& x) : data(
float_to_bf8_raw(x)) {}
158 explicit constexpr float8_e5m2_t(
const int& x) : data(
float_to_bf8_raw(static_cast<float>(x)))
164 explicit constexpr float8_e5m2_t(
const unsigned int& x)
171 explicit constexpr operator float()
const {
return bf8_to_float_raw(data); }
175 explicit constexpr operator int()
const {
return static_cast<int>(
bf8_to_float_raw(data)); }
179 constexpr raw_type& get() {
return data; }
182 constexpr raw_type get()
const {
return data; }
184using bf8_t = float8_e5m2_t;
185using bf8_raw_t =
typename bf8_t::raw_type;
193 using type = _BitInt(8);
199 using type =
unsigned _BitInt(8);
215 static constexpr int exp = 4;
217#if CK_TILE_USE_OCP_FP8
218 static constexpr int bias = 7;
233 static constexpr int exp = 5;
235#if CK_TILE_USE_OCP_FP8
236 static constexpr int bias = 15;
239 static constexpr int bias = 16;
249template <
typename SrcT,
typename DstT,
bool clip = true,
bool stoch = false>
252 static_assert(std::is_same<DstT, fp8_t>::value || std::is_same<DstT, bf8_t>::value,
253 "DstT type must be fp8 or bf8.");
255 constexpr bool is_half = std::is_same<SrcT, half_t>::value;
256 constexpr bool is_float = std::is_same<SrcT, float>::value;
257 static_assert(is_half || is_float,
"Only half and float can be cast to f8");
263 constexpr bool is_fnuz =
276 unsigned int head, mantissa;
283 sign = head >> (SrcT_exp + SrcT_mant);
285 unsigned int signed_inf = 0;
286 unsigned int nan = 0;
287 if constexpr(is_fnuz)
289 signed_inf = clip ? ((sign << (DstT_exp + DstT_mant)) + 0x7f) : 0x80;
294 if constexpr(DstT_exp == 4)
296 signed_inf = (sign << (DstT_exp + DstT_mant)) + (clip ? 0x7e : 0x7f);
300 signed_inf = (sign << (DstT_exp + DstT_mant)) + (clip ? 0x7b : 0x7c);
302 nan = (sign << (DstT_exp + DstT_mant)) + 0x7f;
305 unsigned int ifmax = 0;
306 if constexpr(is_float)
308 if constexpr(DstT_exp == 5)
314 if constexpr(is_fnuz)
324 else if constexpr(is_half)
326 if constexpr(DstT_exp == 5)
332 if constexpr(is_fnuz)
344 if((src_bitwise & fInf) == fInf)
346 return mantissa != 0 ? nan : signed_inf;
349 if((src_bitwise & abs_mask) > ifmax)
362 constexpr int f8_denormal_act_exponent = 1 - DstT_bias;
367 int act_exponent, f8_exponent, exponent_diff;
378 act_exponent = exponent - bias + 1;
379 exponent_diff = f8_denormal_act_exponent -
384 act_exponent = exponent - bias;
385 if(act_exponent <= f8_denormal_act_exponent)
392 exponent_diff = f8_denormal_act_exponent - act_exponent;
400 mantissa += (1u << SrcT_mant);
404 if(exponent_diff > DstT_mant + 1)
406 return is_fnuz ? 0 : (sign << (DstT_exp + DstT_mant));
408 bool midpoint = (mantissa & ((1u << (SrcT_mant - DstT_mant + exponent_diff)) - 1)) ==
409 (1u << (SrcT_mant - DstT_mant + exponent_diff - 1));
417 if(exponent_diff > 0)
418 mantissa >>= exponent_diff;
419 else if(exponent_diff == -1)
420 mantissa <<= -exponent_diff;
421 bool implicit_one = mantissa & (1u << SrcT_mant);
425 (act_exponent + exponent_diff) + DstT_bias - (implicit_one ? 0 : 1);
428 unsigned int drop_mask = (1u << (SrcT_mant - DstT_mant)) - 1;
431 (1u << (SrcT_mant - DstT_mant));
433 (stoch ? rng : (midpoint ? (odd ? mantissa : mantissa - 1u) : mantissa)) & drop_mask;
438 if((1u << SrcT_mant) & mantissa)
445 if((1u << (SrcT_mant + 1)) & mantissa)
452 mantissa >>= (SrcT_mant - DstT_mant);
455 const int max_exp = (1 << DstT_exp) - 1;
456 if(f8_exponent > max_exp)
460 mantissa = (1 << DstT_mant) - 1;
461 f8_exponent = max_exp;
469 if(f8_exponent == 0 && mantissa == 0)
470 return is_fnuz ? 0 : (sign << (DstT_exp + DstT_mant));
471 mantissa &= (1 << DstT_mant) - 1;
472 return (sign << (DstT_exp + DstT_mant)) | (f8_exponent << DstT_mant) | mantissa;
475template <
typename SrcT,
typename DstT,
bool clip = true>
478 static_assert(std::is_same<SrcT, fp8_t>::value || std::is_same<SrcT, bf8_t>::value,
479 "SrcT type must be fp8 or bf8.");
483 constexpr bool is_fnuz =
487 constexpr bool is_half = std::is_same<DstT, half_t>::value;
488 constexpr bool is_float = std::is_same<DstT, float>::value;
489 static_assert(is_half || is_float,
"DstT type must be half_t or float.");
500 DstT fmax{0}, fmin{0};
502 if constexpr(is_half)
507 else if constexpr(is_float)
518 unsigned int sign = x >> (SrcT_exp + SrcT_mant);
519 unsigned int mantissa = x & ((1 << SrcT_mant) - 1);
520 int exponent = (x & SrcT_abs_mask) >> SrcT_mant;
521 if constexpr(is_fnuz)
523 if((x & 0xff) == 0x80)
534 if constexpr(SrcT_exp == 4)
536 if((x & 0x7F) == 0x7F)
541 else if((x & 0x7C) == 0x7C)
547 return sign ? fmin : fmax;
549 return sign ? fNegInf : fInf;
557 if constexpr(SrcT_exp == 5 && is_half && !is_fnuz)
563 const int exp_low_cutoff =
564 (1 << (DstT_exp - 1)) - (1 << (SrcT_exp - 1)) + 1 - (is_fnuz ? 1 : 0);
569 int sh = 1 +
clz(mantissa) - (32 - SrcT_mant);
572 mantissa &= ((1ull << SrcT_mant) - 1);
574 exponent += exp_low_cutoff - 1;
575 mantissa <<= DstT_mant - SrcT_mant;
580 mantissa |= 1 << DstT_mant;
581 mantissa >>= 1 - exponent;
585 retval = (sign << (DstT_exp + DstT_mant)) | (exponent << DstT_mant) | mantissa;
590template <
typename X,
typename Y,
bool clip,
bool stoch>
596#if CK_TILE_FP8_CVT_DEVICE
600template <fp8_
interpretation
interpret,
bool saturate,
bool stochastic_rounding = false>
608 unsigned char i8val[4];
611 unsigned int ival = 0;
614 if constexpr(saturate)
618 if((val.i32val & 0x7F800000) != 0x7F800000)
620 val.fval = __builtin_amdgcn_fmed3f(val.fval, 240.0, -240.0);
625 if((val.i32val & 0x7F800000) != 0x7F800000)
627 val.fval = __builtin_amdgcn_fmed3f(val.fval, 448.0, -448.0);
632 if((val.i32val & 0x7F800000) != 0x7F800000)
634 val.fval = __builtin_amdgcn_fmed3f(val.fval, 57344.0, -57344.0);
639 if constexpr(stochastic_rounding)
643 ? __builtin_amdgcn_cvt_sr_fp8_f32(val.fval, rng, ival, 0)
644 : __builtin_amdgcn_cvt_sr_bf8_f32(val.fval, rng, ival, 0);
646 i8data = val.i8val[0];
652 ? __builtin_amdgcn_cvt_pk_fp8_f32(val.fval, val.fval, ival,
false)
653 : __builtin_amdgcn_cvt_pk_bf8_f32(val.fval,
658 i8data = val.i8val[0];
679template <
typename SrcT,
typename DstT>
682 constexpr bool clip =
true;
683 constexpr int seed = 42;
685#if CK_TILE_FP8_CVT_DEVICE
686 return impl::cast_to_f8_from_f32<numeric_traits<DstT>::f8_interpret, clip,
true>(x, rng);
705template <
typename SrcT,
typename DstT>
708 constexpr bool clip =
true;
709#if CK_TILE_FP8_CVT_DEVICE
710 return impl::cast_to_f8_from_f32<numeric_traits<DstT>::f8_interpret, clip,
false>(x, 0);
717template <fp8_rounding_mode rounding>
734template <fp8_rounding_mode rounding>
753#if CK_TILE_FP8_CVT_DEVICE
756 fval = __builtin_amdgcn_cvt_f32_fp8(i32val, 0);
766#if CK_TILE_FP8_CVT_DEVICE
769 fval = __builtin_amdgcn_cvt_f32_bf8(i32val, 0);
796#if CK_TILE_USE_OCP_FP8
1055#if CK_TILE_USE_CUSTOM_DATA_TYPE
1061template <
typename T>
1064 static_assert(std::is_same_v<T, fp8_t> || std::is_same_v<T, bf8_t>,
1065 "Only fp8_t and bf8_t are supported");
1074#if CK_TILE_USE_OCP_FP8
1075 return (xx & 0x7f) == 0x7f;
1080#if CK_TILE_USE_CUSTOM_DATA_TYPE
1082fp8_t sqrt(
fp8_t x) {
return static_cast<fp8_t>(__builtin_amdgcn_sqrtf(
static_cast<float>(x))); };
1085fp8_t exp(
fp8_t x) {
return static_cast<fp8_t>(__ocml_exp_f32(
static_cast<float>(x))); };
1099#if CK_TILE_USE_OCP_FP8
1100 return (xx & 0x7f) > 0x7c;
1106#if CK_TILE_USE_CUSTOM_DATA_TYPE
1108bf8_t sqrt(
bf8_t x) {
return static_cast<bf8_t>(__builtin_amdgcn_sqrtf(
static_cast<float>(x))); };
1111bf8_t exp(
bf8_t x) {
return static_cast<bf8_t>(__ocml_exp_f32(
static_cast<float>(x))); };
#define CK_TILE_DEVICE
Definition config.hpp:41
#define CK_TILE_FLOAT_TO_FP8_DEFAULT
Definition config.hpp:79
#define CK_TILE_HOST_DEVICE
Definition config.hpp:42
Definition tile/core/arch/amd_buffer_addressing.hpp:110
CK_TILE_HOST_DEVICE DstT run_cast_to_f8(SrcT src, unsigned int rng=0)
Definition float8.hpp:250
CK_TILE_HOST_DEVICE DstT run_cast_from_f8(SrcT x)
Definition float8.hpp:476
CK_TILE_HOST_DEVICE Y cast_to_f8(X x, uint32_t rng)
Definition float8.hpp:591
Definition tile/core/algorithm/cluster_descriptor.hpp:13
CK_TILE_HOST_DEVICE numeric_traits< DstT >::bitwise_type float_to_fp8_rtn_raw(SrcT x)
Converts a floating-point value to an 8-bit floating-point representation with rounding to nearest ev...
Definition float8.hpp:706
CK_TILE_DEVICE bfloat16_t log(bfloat16_t x)
Definition bfloat16.hpp:428
fp8_interpretation
FP8 interpretation used in conversion algorithms.
Definition float8.hpp:38
@ E4M3_OCP
Definition float8.hpp:39
@ E5M2_OCP
Definition float8.hpp:40
@ E5M2_FNUZ
Definition float8.hpp:42
@ E4M3_FNUZ
Definition float8.hpp:41
CK_TILE_HOST_DEVICE fp8_t float_to_fp8(float x, constant< rounding >={})
Definition float8.hpp:778
CK_TILE_HOST_DEVICE float fp8_to_float_raw(uint8_t)
Definition float8.hpp:751
_BitInt(8) fp8_t
Definition float8.hpp:204
CK_TILE_HOST_DEVICE float bf8_to_float_raw(uint8_t)
Definition float8.hpp:764
CK_TILE_HOST_DEVICE constexpr Y bit_cast(const X &x)
Definition bit_cast.hpp:11
fp8_rounding_mode
Definition float8.hpp:29
@ stochastic
Definition float8.hpp:31
@ standard
Definition float8.hpp:30
CK_TILE_DEVICE bfloat16_t sqrt(bfloat16_t x)
Definition bfloat16.hpp:413
CK_TILE_HOST_DEVICE uint8_t float_to_fp8_raw(float, constant< rounding >={})
Definition float8.hpp:718
uint8_t fp8_raw_t
Definition float8.hpp:205
CK_TILE_HOST_DEVICE float bf8_to_float(bf8_t x)
Definition float8.hpp:791
CK_TILE_DEVICE bfloat16_t exp(bfloat16_t x)
Definition bfloat16.hpp:419
CK_TILE_HOST int clz(uint32_t x)
Definition tile/core/numeric/math.hpp:264
@ standard
Definition bfloat16.hpp:20
CK_TILE_HOST_DEVICE bfloat16_t abs(const bfloat16_t &x)
Definition bfloat16.hpp:400
unsigned _BitInt(8) bf8_t
Definition float8.hpp:206
uint8_t bf8_raw_t
Definition float8.hpp:207
@ constant
Definition arch.hpp:51
CK_TILE_HOST_DEVICE bf8_t float_to_bf8(float x, constant< rounding >={})
Definition float8.hpp:784
CK_TILE_HOST_DEVICE bool isnan(const bfloat16_t &x)
Definition bfloat16.hpp:406
CK_TILE_HOST_DEVICE float fp8_to_float(fp8_t x)
Definition float8.hpp:789
CK_TILE_HOST_DEVICE uint8_t float_to_bf8_raw(float, constant< rounding >={})
Definition float8.hpp:735
CK_TILE_DEVICE bfloat16_t exp2(bfloat16_t x)
Definition bfloat16.hpp:425
CK_TILE_HOST_DEVICE numeric_traits< DstT >::bitwise_type float_to_fp8_sr_raw(SrcT x)
Converts a floating-point value to an 8-bit floating-point representation with stochastic rounding.
Definition float8.hpp:680
__host__ __device__ constexpr Y bit_cast(const X &x)
Definition type.hpp:306
_W64 unsigned int uintptr_t
Definition stdint.h:164
unsigned int uint32_t
Definition stdint.h:126
unsigned char uint8_t
Definition stdint.h:124
Definition tile/core/numeric/integral_constant.hpp:13
Definition vector_type.hpp:26
remove_cvref_t< T > type
Definition vector_type.hpp:27
static CK_TILE_HOST_DEVICE constexpr bf8_t infinity()
Definition float8.hpp:1025
static CK_TILE_HOST_DEVICE constexpr bf8_t zero()
Definition float8.hpp:1048
static CK_TILE_HOST_DEVICE constexpr bf8_t quiet_NaN()
Definition float8.hpp:1031
static CK_TILE_HOST_DEVICE constexpr bf8_t max()
Definition float8.hpp:1003
static CK_TILE_HOST_DEVICE constexpr bf8_t round_error()
Definition float8.hpp:1019
static CK_TILE_HOST_DEVICE constexpr bf8_t min()
Definition float8.hpp:991
static CK_TILE_HOST_DEVICE constexpr bf8_t lowest()
Definition float8.hpp:997
static CK_TILE_HOST_DEVICE constexpr bf8_t signaling_NaN()
Definition float8.hpp:1037
static CK_TILE_HOST_DEVICE constexpr bf8_t epsilon()
Definition float8.hpp:1009
static CK_TILE_HOST_DEVICE constexpr bf8_t denorm_min()
Definition float8.hpp:1043
static CK_TILE_HOST_DEVICE constexpr fp8_t round_error()
Definition float8.hpp:952
static CK_TILE_HOST_DEVICE constexpr fp8_t signaling_NaN()
Definition float8.hpp:970
static CK_TILE_HOST_DEVICE constexpr fp8_t epsilon()
Definition float8.hpp:942
static CK_TILE_HOST_DEVICE constexpr fp8_t max()
Definition float8.hpp:936
static CK_TILE_HOST_DEVICE constexpr fp8_t infinity()
Definition float8.hpp:958
static CK_TILE_HOST_DEVICE constexpr fp8_t zero()
Definition float8.hpp:981
static CK_TILE_HOST_DEVICE constexpr fp8_t quiet_NaN()
Definition float8.hpp:964
static CK_TILE_HOST_DEVICE constexpr fp8_t min()
Definition float8.hpp:924
static CK_TILE_HOST_DEVICE constexpr fp8_t lowest()
Definition float8.hpp:930
static CK_TILE_HOST_DEVICE constexpr fp8_t denorm_min()
Definition float8.hpp:976
static constexpr uint8_t abs_mask
Definition float8.hpp:242
static constexpr int PackedSize
Definition float8.hpp:243
static constexpr fp8_interpretation f8_interpret
Definition float8.hpp:240
static constexpr int exp
Definition float8.hpp:233
static constexpr int mant
Definition float8.hpp:234
static constexpr int bias
Definition float8.hpp:239
bf8_raw_t bitwise_type
Definition float8.hpp:231
static constexpr int bias
Definition float8.hpp:221
fp8_raw_t bitwise_type
Definition float8.hpp:213
static constexpr uint8_t abs_mask
Definition float8.hpp:224
static constexpr fp8_interpretation f8_interpret
Definition float8.hpp:222
static constexpr int mant
Definition float8.hpp:216
static constexpr int exp
Definition float8.hpp:215
static constexpr int PackedSize
Definition float8.hpp:225
Definition tile/core/numeric/numeric.hpp:81
Definition tile/core/numeric/numeric.hpp:18
static CK_TILE_HOST_DEVICE constexpr T quiet_NaN()
Definition tile/core/numeric/numeric.hpp:41
static CK_TILE_HOST_DEVICE constexpr T infinity()
Definition tile/core/numeric/numeric.hpp:38
static CK_TILE_HOST_DEVICE constexpr T lowest()
Definition tile/core/numeric/numeric.hpp:23
static CK_TILE_HOST_DEVICE constexpr T zero()
Definition tile/core/numeric/numeric.hpp:58
static CK_TILE_HOST_DEVICE constexpr T denorm_min()
Definition tile/core/numeric/numeric.hpp:53
static CK_TILE_HOST_DEVICE constexpr T round_error()
Definition tile/core/numeric/numeric.hpp:32
static CK_TILE_HOST_DEVICE constexpr T signaling_NaN()
Definition tile/core/numeric/numeric.hpp:47
static CK_TILE_HOST_DEVICE constexpr T max()
Definition tile/core/numeric/numeric.hpp:26
static CK_TILE_HOST_DEVICE constexpr T min()
Definition tile/core/numeric/numeric.hpp:20
static CK_TILE_HOST_DEVICE constexpr T epsilon()
Definition tile/core/numeric/numeric.hpp:29
#define CK_TILE_ARITHMETIC_USING_FLOAT(attr_, type_)
Definition tile/core/numeric/numeric.hpp:106