variants.hpp Source File

variants.hpp Source File

Composable Kernel: variants.hpp Source File
variants.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
6#include <type_traits>
7
10
11#define CK_TILE_ATTENTION_LOGITS_SOFT_CAP_TANH 0
12#define CK_TILE_ATTENTION_LOGITS_SOFT_CAP_SOFTSIGN 1
13
14#ifndef CK_TILE_ATTENTION_LOGITS_SOFT_CAP_DEFAULT
15#define CK_TILE_ATTENTION_LOGITS_SOFT_CAP_DEFAULT CK_TILE_ATTENTION_LOGITS_SOFT_CAP_TANH
16#endif
17
18#ifndef CK_TILE_ATTENTION_USE_SOFTSIGN_ASM
19#define CK_TILE_ATTENTION_USE_SOFTSIGN_ASM 0
20#endif
21
22namespace ck_tile {
23namespace internal {
24__device__ inline float
25exp2_soft_sign_impl(float softmax_scale, float logits, float logits_soft_cap_rcp)
26{
27#if(defined(__gfx90a__) || defined(__gfx94__)) && \
28 (CK_TILE_ATTENTION_LOGITS_SOFT_CAP_DEFAULT == CK_TILE_ATTENTION_LOGITS_SOFT_CAP_SOFTSIGN && \
29 CK_TILE_ATTENTION_USE_SOFTSIGN_ASM)
31 float result, numerator, denominator;
32 asm volatile(
33 "v_mul_f32_e32 %[denominator], %[logits], %[logits_soft_cap_rcp]\n"
34 "v_add_f32_e64 %[denominator], |%[denominator]|, 1.0\n"
35 "v_rcp_f32_e32 %[denominator], %[denominator]\n"
36 "v_mul_f32_e32 %[numerator], %[softmax_scale], %[logits]\n"
37 "v_mul_f32_e32 %[result], %[numerator], %[denominator]"
38 : [numerator] "=&v"(numerator), [denominator] "=&v"(denominator), [result] "=v"(result)
39 : [softmax_scale] "s"(softmax_scale),
40 [logits] "v"(logits),
41 [logits_soft_cap_rcp] "v"(logits_soft_cap_rcp));
42 return result;
43#else
44 return softmax_scale * logits * rcp<float>(1.f + abs(logits * logits_soft_cap_rcp));
45#endif
46}
47} // namespace internal
48
49template <typename ImplMask>
51{
52 __device__ __host__ StandardAttentionParams(const ImplMask& impl_mask_, float sm_scale_)
53 : impl_mask(impl_mask_), sm_scale(sm_scale_)
54 {
55 }
56
57 const ImplMask& impl_mask;
58 float sm_scale;
59};
60
61template <typename ImplMask, bool UseExp2 = false>
63{
64 __device__
65 LogitsSoftCapParams(const ImplMask& impl_mask_, float sm_scale_, float logits_soft_cap_)
66 : impl_mask(impl_mask_), sm_scale(sm_scale_), logits_soft_cap(logits_soft_cap_)
67 {
68 if(0.f < logits_soft_cap)
69 {
70 logits_soft_cap_rcp = __builtin_amdgcn_rcpf(logits_soft_cap);
71 }
72 else
73 {
75 }
76
77 // move computation here to prevent compiler from generating inefficient instruction
78 // sequence
79 if constexpr(UseExp2)
80 {
83 }
84 }
85
86 __host__
87 LogitsSoftCapParams(const ImplMask& impl_mask_, float sm_scale_, float logits_soft_cap_)
88 : impl_mask(impl_mask_), sm_scale(sm_scale_), logits_soft_cap(logits_soft_cap_)
89 {
90 if(0.f < logits_soft_cap)
91 {
93 }
94 else
95 {
97 }
98
99 // move computation here to prevent compiler from generating inefficient instruction
100 // sequence
101 if constexpr(UseExp2)
102 {
105 }
106 }
107
108 __device__ __host__ LogitsSoftCapParams(const ImplMask& impl_mask_,
109 float sm_scale_,
110 float logits_soft_cap_,
111 float logits_soft_cap_rcp_)
112 : impl_mask(impl_mask_),
113 sm_scale(sm_scale_),
114 logits_soft_cap(logits_soft_cap_),
115 logits_soft_cap_rcp(logits_soft_cap_rcp_)
116 {
117 // move computation here to prevent compiler from generating inefficient instruction
118 // sequence
119 if constexpr(UseExp2)
120 {
123 }
124 }
125
126 const ImplMask& impl_mask;
127 float sm_scale;
130};
131
133{
134 __device__ __host__ StandardAttention() = default;
135
136 template <typename Params, typename T>
137 __device__ __forceinline__ T QueryTransform(const Params& params, T q) const
138 {
139 return type_convert<float>(q) * params.sm_scale;
140 }
141
144 template <typename Params, typename T>
145 __device__ __forceinline__ T LogitsTransform([[maybe_unused]] const Params& params,
146 T logits,
147 [[maybe_unused]] uint32_t batch_idx,
148 /*uint32_t qo_idx, uint32_t kv_idx,*/
149 [[maybe_unused]] uint32_t qo_head_idx,
150 [[maybe_unused]] uint32_t kv_head_idx) const
151 {
152 return logits;
153 }
154
155 template <typename Params>
156 __device__ __forceinline__ bool LogitsMask(const Params& params,
157 [[maybe_unused]] uint32_t batch_idx,
158 uint32_t qo_idx,
159 uint32_t kv_idx,
160 [[maybe_unused]] uint32_t qo_head_idx,
161 [[maybe_unused]] uint32_t kv_head_idx) const
162 {
163 return !params.impl_mask.IsOutOfBound(qo_idx, kv_idx);
164 }
165};
166
167template <bool UseExp2 = false>
169{
170 __device__ __host__ LogitsSoftCap() = default;
171
172 template <typename Params, typename T>
173 __device__ __forceinline__ T QueryTransform(const Params& params, T q) const
174 {
175 if constexpr(UseExp2)
176 {
177 return q;
178 }
179 else
180 {
181 return type_convert<float>(q) * params.sm_scale;
182 }
183 }
184
187 template <typename Params, typename T>
188 __device__ __forceinline__ T LogitsTransform(const Params& params,
189 T logits,
190 [[maybe_unused]] uint32_t batch_idx,
191 /*uint32_t qo_idx, uint32_t kv_idx,*/
192 [[maybe_unused]] uint32_t qo_head_idx,
193 [[maybe_unused]] uint32_t kv_head_idx) const
194 {
195 if constexpr(UseExp2)
196 {
197#if CK_TILE_ATTENTION_LOGITS_SOFT_CAP_DEFAULT == CK_TILE_ATTENTION_LOGITS_SOFT_CAP_TANH
198 return params.logits_soft_cap *
199 tanh_fast<float>(type_convert<float>(logits) * params.logits_soft_cap_rcp);
200#elif CK_TILE_ATTENTION_LOGITS_SOFT_CAP_DEFAULT == CK_TILE_ATTENTION_LOGITS_SOFT_CAP_SOFTSIGN
202 params.sm_scale, type_convert<float>(logits), params.logits_soft_cap_rcp);
203#endif
204 }
205 else
206 {
207#if CK_TILE_ATTENTION_LOGITS_SOFT_CAP_DEFAULT == CK_TILE_ATTENTION_LOGITS_SOFT_CAP_TANH
208 return params.logits_soft_cap *
209 tanhf(type_convert<float>(logits) * params.logits_soft_cap_rcp);
210#elif CK_TILE_ATTENTION_LOGITS_SOFT_CAP_DEFAULT == CK_TILE_ATTENTION_LOGITS_SOFT_CAP_SOFTSIGN
211 return type_convert<float>(logits) *
212 rcp<float>(1.f + abs(type_convert<float>(logits) * params.logits_soft_cap_rcp));
213#endif
214 }
215 }
216
217 template <typename Params>
218 __device__ __forceinline__ bool LogitsMask(const Params& params,
219 [[maybe_unused]] uint32_t batch_idx,
220 uint32_t qo_idx,
221 uint32_t kv_idx,
222 [[maybe_unused]] uint32_t qo_head_idx,
223 [[maybe_unused]] uint32_t kv_head_idx) const
224 {
225 return !params.impl_mask.IsOutOfBound(qo_idx, kv_idx);
226 }
227};
228
229constexpr uint32_t CUSTOM_MASK = 1U;
232constexpr uint32_t ALIBI = 8U;
233
234template <uint32_t VARIANT_CODE, bool UseExp2 = false>
236{
237 static constexpr bool use_exp2 = UseExp2;
238
239 static constexpr bool use_logits_soft_cap = (VARIANT_CODE & LOGITS_SOFT_CAP) != 0;
240
241 __device__ __host__ ComposedAttention() = default;
242
243 template <typename Params, typename T>
244 __device__ __forceinline__ T QueryTransform(const Params& params, T q) const
245 {
246 if constexpr(use_logits_soft_cap && UseExp2)
247 {
248 return q;
249 }
250 return type_convert<float>(q) * params.sm_scale;
251 }
252
255 template <typename Params, typename T>
256 __device__ __forceinline__ T LogitsTransform(const Params& params,
257 T logits,
258 [[maybe_unused]] uint32_t batch_idx,
259 /*uint32_t qo_idx, uint32_t kv_idx,*/
260 [[maybe_unused]] uint32_t qo_head_idx,
261 [[maybe_unused]] uint32_t kv_head_idx) const
262 {
263 if constexpr(use_logits_soft_cap)
264 {
265 if constexpr(UseExp2)
266 {
267#if CK_TILE_ATTENTION_LOGITS_SOFT_CAP_DEFAULT == CK_TILE_ATTENTION_LOGITS_SOFT_CAP_TANH
268 return params.logits_soft_cap *
269 tanh_fast<float>(type_convert<float>(logits) * params.logits_soft_cap_rcp);
270#elif CK_TILE_ATTENTION_LOGITS_SOFT_CAP_DEFAULT == CK_TILE_ATTENTION_LOGITS_SOFT_CAP_SOFTSIGN
272 params.sm_scale, type_convert<float>(logits), params.logits_soft_cap_rcp);
273#endif
274 }
275 else
276 {
277#if CK_TILE_ATTENTION_LOGITS_SOFT_CAP_DEFAULT == CK_TILE_ATTENTION_LOGITS_SOFT_CAP_TANH
278 return params.logits_soft_cap *
279 tanhf(type_convert<float>(logits) * params.logits_soft_cap_rcp);
280#elif CK_TILE_ATTENTION_LOGITS_SOFT_CAP_DEFAULT == CK_TILE_ATTENTION_LOGITS_SOFT_CAP_SOFTSIGN
281 return type_convert<float>(logits) *
282 rcp<float>(1.f +
283 abs(type_convert<float>(logits) * params.logits_soft_cap_rcp));
284#endif
285 }
286 }
287 return logits;
288 }
289
290 template <typename Params>
291 __device__ __forceinline__ bool LogitsMask(const Params& params,
292 [[maybe_unused]] uint32_t batch_idx,
293 uint32_t qo_idx,
294 uint32_t kv_idx,
295 [[maybe_unused]] uint32_t qo_head_idx,
296 [[maybe_unused]] uint32_t kv_head_idx) const
297 {
298 return !params.impl_mask.IsOutOfBound(qo_idx, kv_idx);
299 }
300};
301
302} // namespace ck_tile
__device__ float exp2_soft_sign_impl(float softmax_scale, float logits, float logits_soft_cap_rcp)
Definition variants.hpp:25
Definition tile/core/algorithm/cluster_descriptor.hpp:13
constexpr T log2e_v
Definition tile/core/numeric/math.hpp:488
constexpr uint32_t ALIBI
Definition variants.hpp:232
CK_TILE_DEVICE float tanh_fast< float >(float x)
Definition tile/core/numeric/math.hpp:1394
constexpr uint32_t LOGITS_SOFT_CAP
Definition variants.hpp:231
constexpr uint32_t CUSTOM_MASK
Definition variants.hpp:229
constexpr T log2e_rcp_v
Definition tile/core/numeric/math.hpp:491
CK_TILE_HOST_DEVICE bfloat16_t abs(const bfloat16_t &x)
Definition bfloat16.hpp:400
constexpr uint32_t SLIDING_WINDOW
Definition variants.hpp:230
CK_TILE_HOST_DEVICE constexpr Y type_convert(X x)
Definition tile/core/numeric/type_convert.hpp:29
CK_TILE_HOST T rcp(T x)
Definition tile/core/numeric/math.hpp:896
Definition allocators.h:459
unsigned int uint32_t
Definition stdint.h:126
__device__ __host__ ComposedAttention()=default
__device__ __forceinline__ bool LogitsMask(const Params &params, uint32_t batch_idx, uint32_t qo_idx, uint32_t kv_idx, uint32_t qo_head_idx, uint32_t kv_head_idx) const
Definition variants.hpp:291
static constexpr bool use_exp2
Definition variants.hpp:237
__device__ __forceinline__ T LogitsTransform(const Params &params, T logits, uint32_t batch_idx, uint32_t qo_head_idx, uint32_t kv_head_idx) const
Definition variants.hpp:256
__device__ __forceinline__ T QueryTransform(const Params &params, T q) const
Definition variants.hpp:244
static constexpr bool use_logits_soft_cap
Definition variants.hpp:239
__device__ __host__ LogitsSoftCap()=default
__device__ __forceinline__ bool LogitsMask(const Params &params, uint32_t batch_idx, uint32_t qo_idx, uint32_t kv_idx, uint32_t qo_head_idx, uint32_t kv_head_idx) const
Definition variants.hpp:218
__device__ __forceinline__ T LogitsTransform(const Params &params, T logits, uint32_t batch_idx, uint32_t qo_head_idx, uint32_t kv_head_idx) const
Definition variants.hpp:188
__device__ __forceinline__ T QueryTransform(const Params &params, T q) const
Definition variants.hpp:173
float logits_soft_cap_rcp
Definition variants.hpp:129
__host__ LogitsSoftCapParams(const ImplMask &impl_mask_, float sm_scale_, float logits_soft_cap_)
Definition variants.hpp:87
const ImplMask & impl_mask
Definition variants.hpp:126
__device__ __host__ LogitsSoftCapParams(const ImplMask &impl_mask_, float sm_scale_, float logits_soft_cap_, float logits_soft_cap_rcp_)
Definition variants.hpp:108
__device__ LogitsSoftCapParams(const ImplMask &impl_mask_, float sm_scale_, float logits_soft_cap_)
Definition variants.hpp:65
float sm_scale
Definition variants.hpp:127
float logits_soft_cap
Definition variants.hpp:128
__device__ __forceinline__ T QueryTransform(const Params &params, T q) const
Definition variants.hpp:137
__device__ __host__ StandardAttention()=default
__device__ __forceinline__ T LogitsTransform(const Params &params, T logits, uint32_t batch_idx, uint32_t qo_head_idx, uint32_t kv_head_idx) const
Definition variants.hpp:145
__device__ __forceinline__ bool LogitsMask(const Params &params, uint32_t batch_idx, uint32_t qo_idx, uint32_t kv_idx, uint32_t qo_head_idx, uint32_t kv_head_idx) const
Definition variants.hpp:156
const ImplMask & impl_mask
Definition variants.hpp:57
__device__ __host__ StandardAttentionParams(const ImplMask &impl_mask_, float sm_scale_)
Definition variants.hpp:52
float sm_scale
Definition variants.hpp:58