block_fmha_pipeline_problem.hpp Source File

block_fmha_pipeline_problem.hpp Source File

Composable Kernel: block_fmha_pipeline_problem.hpp Source File
block_fmha_pipeline_problem.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
6#include "ck_tile/core.hpp"
8
9namespace ck_tile {
10
// Compile-time problem descriptor: exposes the tile shape, the group-mode and
// kUseTrLoad_ switches, and the feature flags of Traits_ as static members.
// (The struct-name line and the type-alias lines are not part of this listing;
// presumably this is consumed by the FMHA block pipelines — confirm.)
11template <typename QDataType_,
12 typename KDataType_,
13 typename VDataType_,
14 typename SaccDataType_,
15 typename SMPLComputeDataType_,
16 typename BiasDataType_,
17 typename RandValOutputDataType_,
18 typename LSEDataType_,
19 typename PDataType_,
20 typename OaccDataType_,
21 typename ODataType_,
22 typename BlockFmhaShape_,
23 bool kIsGroupMode_,
24 typename AttentionVariant_,
25 typename FmhaMask_,
26 bool kUseTrLoad_,
27 typename Traits_>
29{
45
// Thread-block geometry taken from the tile shape: warp counts used by the two
// GEMM stages, and the total threads per block (all warps * threads per warp).
46 static constexpr index_t kNumGemm0Warps = BlockFmhaShape::NumGemm0Warps;
47 static constexpr index_t kNumGemm1Warps = BlockFmhaShape::NumGemm1Warps;
48 static constexpr index_t kBlockSize = BlockFmhaShape::NumWarps * get_warp_size();
49
// Template switches forwarded unchanged; their meaning is defined by the
// consuming pipeline (NOTE(review): kUseTrLoad presumably selects a
// transposed-load path — confirm).
50 static constexpr bool kIsGroupMode = kIsGroupMode_;
51 static constexpr bool kUseTrLoad = kUseTrLoad_;
52
53 // attributes from traits
// Every flag below is copied verbatim from Traits_; nothing is derived here.
54 static constexpr bool kPadSeqLenQ = Traits::kPadSeqLenQ;
55 static constexpr bool kPadSeqLenK = Traits::kPadSeqLenK;
56 static constexpr bool kPadHeadDimQ = Traits::kPadHeadDimQ;
57 static constexpr bool kPadHeadDimV = Traits::kPadHeadDimV;
58 static constexpr bool kHasLogitsSoftCap = Traits::kHasLogitsSoftCap;
59 static constexpr bool kSkipMinSeqlenQ = Traits::kSkipMinSeqlenQ;
60 static constexpr auto BiasEnum = Traits::BiasEnum;
61 static constexpr bool kStoreLSE = Traits::kStoreLSE;
62 static constexpr bool kHasDropout = Traits::kHasDropout;
63 static constexpr bool kDoFp8StaticQuant = Traits::kDoFp8StaticQuant;
64 static constexpr index_t kBlockPerCu = Traits::kBlockPerCu;
65};
66
// Compile-time problem descriptor that additionally exposes a paged-KV flag
// (kIsPagedKV, read from Traits_). It takes no RandValOutputDataType_ or
// kUseTrLoad_ parameter and exposes no kHasDropout flag.
// (The struct-name line and the type-alias lines are not part of this listing.)
67template <typename QDataType_,
68 typename KDataType_,
69 typename VDataType_,
70 typename SaccDataType_,
71 typename SMPLComputeDataType_,
72 typename BiasDataType_,
73 typename LSEDataType_,
74 typename PDataType_,
75 typename OaccDataType_,
76 typename ODataType_,
77 typename BlockFmhaShape_,
78 bool kIsGroupMode_,
79 typename AttentionVariant_,
80 typename FmhaMask_,
81 typename Traits_>
83{
98
// Thread-block geometry taken from the tile shape: warp counts of the two GEMM
// stages, and total threads per block (all warps * threads per warp).
99 static constexpr index_t kNumGemm0Warps = BlockFmhaShape::NumGemm0Warps;
100 static constexpr index_t kNumGemm1Warps = BlockFmhaShape::NumGemm1Warps;
101 static constexpr index_t kBlockSize = BlockFmhaShape::NumWarps * get_warp_size();
102
103 static constexpr bool kIsGroupMode = kIsGroupMode_;
104
105 // attributes from traits
// Copied verbatim from Traits_; kIsPagedKV is the flag specific to this descriptor.
106 static constexpr bool kPadSeqLenQ = Traits::kPadSeqLenQ;
107 static constexpr bool kPadSeqLenK = Traits::kPadSeqLenK;
108 static constexpr bool kPadHeadDimQ = Traits::kPadHeadDimQ;
109 static constexpr bool kPadHeadDimV = Traits::kPadHeadDimV;
110 static constexpr bool kHasLogitsSoftCap = Traits::kHasLogitsSoftCap;
111 static constexpr bool kSkipMinSeqlenQ = Traits::kSkipMinSeqlenQ;
112 static constexpr auto BiasEnum = Traits::BiasEnum;
113 static constexpr bool kStoreLSE = Traits::kStoreLSE;
114 static constexpr bool kDoFp8StaticQuant = Traits::kDoFp8StaticQuant;
115 static constexpr bool kIsPagedKV = Traits::kIsPagedKV;
116 static constexpr index_t kBlockPerCu = Traits::kBlockPerCu;
117};
118
// Compile-time problem descriptor with split-KV related flags: paged KV,
// uneven splits and kMergeNumHeadGroupsSeqLenQ. It does not expose
// kSkipMinSeqlenQ. (The struct-name line and the type-alias lines are not part
// of this listing.)
119template <typename QDataType_,
120 typename KDataType_,
121 typename VDataType_,
122 typename SaccDataType_,
123 typename SMPLComputeDataType_,
124 typename BiasDataType_,
125 typename LSEDataType_,
126 typename PDataType_,
127 typename OaccDataType_,
128 typename ODataType_,
129 typename BlockFmhaShape_,
130 bool kIsGroupMode_,
131 typename AttentionVariant_,
132 typename FmhaMask_,
133 typename Traits_>
135{
150
// Thread-block geometry taken from the tile shape: warp counts of the two GEMM
// stages, and total threads per block (all warps * threads per warp).
151 static constexpr index_t kNumGemm0Warps = BlockFmhaShape::NumGemm0Warps;
152 static constexpr index_t kNumGemm1Warps = BlockFmhaShape::NumGemm1Warps;
153 static constexpr index_t kBlockSize = BlockFmhaShape::NumWarps * get_warp_size();
154
155 static constexpr bool kIsGroupMode = kIsGroupMode_;
156
157 // attributes from traits
158 static constexpr bool kPadSeqLenQ = Traits::kPadSeqLenQ;
159 static constexpr bool kPadSeqLenK = Traits::kPadSeqLenK;
160 static constexpr bool kPadHeadDimQ = Traits::kPadHeadDimQ;
161 static constexpr bool kPadHeadDimV = Traits::kPadHeadDimV;
162 static constexpr bool kHasLogitsSoftCap = Traits::kHasLogitsSoftCap;
163 static constexpr auto BiasEnum = Traits::BiasEnum;
164 static constexpr bool kStoreLSE = Traits::kStoreLSE;
165 static constexpr bool kDoFp8StaticQuant = Traits::kDoFp8StaticQuant;
166 static constexpr bool kIsPagedKV = Traits::kIsPagedKV;
// Not a plain copy: group mode always turns kHasUnevenSplits on, whatever
// Traits_ says (NOTE(review): presumably because per-sequence lengths differ in
// group mode — confirm).
167 static constexpr bool kHasUnevenSplits = kIsGroupMode || Traits::kHasUnevenSplits;
168 static constexpr bool kMergeNumHeadGroupsSeqLenQ = Traits::kMergeNumHeadGroupsSeqLenQ;
169 static constexpr index_t kBlockPerCu = Traits::kBlockPerCu;
170};
171
// Tile-size attributes for the split-KV combine step. They are computed only
// from OaccDataType_ and kN1_, so they can be evaluated without a Traits_ type.
172// extract tile size attributes to remove dependency on traits
173template <typename OaccDataType_, ck_tile::index_t kN1_>
175{
// Number of OaccDataType_ elements that fit in 16 bytes
// (e.g. 4 for a 4-byte type).
176 static constexpr index_t MaxVectorSize = 16 / sizeof(OaccDataType_);
177
178 static constexpr index_t kN1 = kN1_;
// Threads needed to span kN1 elements at MaxVectorSize elements per thread.
// NOTE(review): integer division. If kN1 < MaxVectorSize this is 0, and the
// kM0 line below then divides by zero at compile time. If NThreads exceeds the
// warp size, kM0 becomes 0. kN1 is presumably required to be a suitable
// multiple of MaxVectorSize — confirm or add a static_assert.
179 static constexpr index_t NThreads = kN1 / MaxVectorSize;
// Rows handled per warp: warp size divided by the threads used per row.
180 static constexpr index_t kM0 = get_warp_size() / NThreads; // MThreadPerWarp
181};
182
// Compile-time problem descriptor for the split-KV combine pipeline. It inherits
// kM0 / kN1 / NThreads from
// BlockFmhaSplitKVCombinePipelineTileSizes<OaccDataType_, kN1_> and checks the
// head-dim, split-count and warp-size constraints with static_asserts.
// (The struct-name line and the type-alias lines are not part of this listing.)
183template <typename LSEDataType_,
184 typename OaccDataType_,
185 typename ODataType_,
186 index_t HeadDimV_,
187 bool kIsGroupMode_,
188 ck_tile::index_t kN1_,
189 typename Traits_>
191 : BlockFmhaSplitKVCombinePipelineTileSizes<OaccDataType_, kN1_>
192{
194
199
// The LSE type and the output accumulator type must be the same type.
200 static_assert(std::is_same_v<LSEDataType, OaccDataType>);
201
202 static constexpr index_t kHeadDimV = HeadDimV_;
203 static constexpr bool kIsGroupMode = kIsGroupMode_;
204
// Re-export the tile sizes computed by the base class.
205 using BaseType::kM0;
206 using BaseType::kN1;
207 using BaseType::NThreads;
208
// kN1 must not exceed kHeadDimV and must divide it evenly.
209 static_assert(kN1 <= kHeadDimV && kHeadDimV % kN1 == 0);
210
211 // attributes from traits
212 static constexpr bool kPadSeqLenQ = Traits::kPadSeqLenQ;
213 static constexpr bool kPadHeadDimV = Traits::kPadHeadDimV;
214 static constexpr bool kStoreLSE = Traits::kStoreLSE;
215 static constexpr bool kDoFp8StaticQuant = Traits::kDoFp8StaticQuant;
216 static constexpr index_t kBlockPerCu = Traits::kBlockPerCu;
217 static constexpr index_t kMaxSplits = Traits::kMaxSplits;
// Traits_ must allow at least 8 splits.
218 static_assert(8 <= kMaxSplits);
219
// Fixed warp count for this pipeline (not derived from a tile shape).
220 static constexpr index_t kNumWarps = 4;
222
// kM0 * kMaxSplits must be at least one warp and a whole multiple of the warp size.
223 static_assert(get_warp_size() <= (kM0 * kMaxSplits) &&
224 (kM0 * kMaxSplits) % get_warp_size() == 0);
225};
226
// Compile-time problem descriptor that takes explicit tile sizes
// (kM0/kN0/kK0/kN1), a V-layout switch, a RotaryEmbeddingEnum mode and a paged-KV
// flag as template parameters. Padding flags and kBlockPerCu come from Traits_.
// (The struct-name line and the type-alias lines are not part of this listing.)
227template <typename QDataType_,
228 typename KDataType_,
229 typename VDataType_,
230 index_t kM0_,
231 index_t kN0_,
232 index_t kK0_,
233 index_t kN1_,
234 bool kIsVLayoutRowMajor_,
235 RotaryEmbeddingEnum RotaryEnum_,
236 bool kIsPagedKV_,
237 typename Traits_>
239{
244
// The block size is fixed at 256 threads here rather than derived from a warp
// count of a tile shape.
245 static constexpr index_t kBlockSize = 256;
246
// Tile sizes forwarded directly from the template arguments.
247 static constexpr index_t kM0 = kM0_;
248 static constexpr index_t kN0 = kN0_;
249 static constexpr index_t kK0 = kK0_;
250 static constexpr index_t kN1 = kN1_;
251
// VLayout is chosen at compile time from kIsVLayoutRowMajor_.
252 using VLayout = std::conditional_t<kIsVLayoutRowMajor_,
255
256 static constexpr auto RotaryEnum = RotaryEnum_;
257 static constexpr bool kIsPagedKV = kIsPagedKV_;
258
259 // attributes from traits
260 static constexpr bool kPadSeqLenQ = Traits::kPadSeqLenQ;
261 static constexpr bool kPadSeqLenK = Traits::kPadSeqLenK;
262 static constexpr bool kPadHeadDimQ = Traits::kPadHeadDimQ;
263 static constexpr bool kPadHeadDimV = Traits::kPadHeadDimV;
264 static constexpr index_t kBlockPerCu = Traits::kBlockPerCu;
265};
266
// Compile-time problem descriptor without BiasDataType_ or AttentionVariant_
// parameters. From Traits_ it exposes only the padding flags, kStoreLSE and
// kBlockPerCu. (The struct-name line and the type-alias lines are not part of
// this listing.)
267template <typename QDataType_,
268 typename KDataType_,
269 typename VDataType_,
270 typename SaccDataType_,
271 typename SMPLComputeDataType_,
272 typename LSEDataType_,
273 typename PDataType_,
274 typename OaccDataType_,
275 typename ODataType_,
276 typename BlockFmhaShape_,
277 bool kIsGroupMode_,
278 typename FmhaMask_,
279 typename Traits_>
281{
294
// Thread-block geometry taken from the tile shape: warp counts of the two GEMM
// stages, and total threads per block (all warps * threads per warp).
295 static constexpr index_t kNumGemm0Warps = BlockFmhaShape::NumGemm0Warps;
296 static constexpr index_t kNumGemm1Warps = BlockFmhaShape::NumGemm1Warps;
297 static constexpr index_t kBlockSize = BlockFmhaShape::NumWarps * get_warp_size();
298
299 static constexpr bool kIsGroupMode = kIsGroupMode_;
300
301 // attributes from traits
302 static constexpr bool kPadSeqLenQ = Traits::kPadSeqLenQ;
303 static constexpr bool kPadSeqLenK = Traits::kPadSeqLenK;
304 static constexpr bool kPadHeadDimQ = Traits::kPadHeadDimQ;
305 static constexpr bool kPadHeadDimV = Traits::kPadHeadDimV;
306 static constexpr bool kStoreLSE = Traits::kStoreLSE;
307 static constexpr index_t kBlockPerCu = Traits::kBlockPerCu;
308};
309
310} // namespace ck_tile
Definition tile/core/algorithm/cluster_descriptor.hpp:13
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition type_traits.hpp:21
CK_TILE_HOST_DEVICE constexpr index_t get_warp_size()
Definition arch.hpp:63
RotaryEmbeddingEnum
Definition block_rotary_embedding.hpp:12
int32_t index_t
Definition integer.hpp:9
Definition block_fmha_pipeline_problem.hpp:239
std::conditional_t< kIsVLayoutRowMajor_, ck_tile::tensor_layout::gemm::RowMajor, ck_tile::tensor_layout::gemm::ColumnMajor > VLayout
Definition block_fmha_pipeline_problem.hpp:252
remove_cvref_t< QDataType_ > QDataType
Definition block_fmha_pipeline_problem.hpp:240
static constexpr bool kPadSeqLenK
Definition block_fmha_pipeline_problem.hpp:261
static constexpr bool kPadHeadDimQ
Definition block_fmha_pipeline_problem.hpp:262
static constexpr auto RotaryEnum
Definition block_fmha_pipeline_problem.hpp:256
static constexpr index_t kK0
Definition block_fmha_pipeline_problem.hpp:249
static constexpr bool kPadSeqLenQ
Definition block_fmha_pipeline_problem.hpp:260
remove_cvref_t< Traits_ > Traits
Definition block_fmha_pipeline_problem.hpp:243
static constexpr bool kIsPagedKV
Definition block_fmha_pipeline_problem.hpp:257
remove_cvref_t< VDataType_ > VDataType
Definition block_fmha_pipeline_problem.hpp:242
static constexpr index_t kM0
Definition block_fmha_pipeline_problem.hpp:247
static constexpr index_t kN1
Definition block_fmha_pipeline_problem.hpp:250
static constexpr index_t kBlockPerCu
Definition block_fmha_pipeline_problem.hpp:264
static constexpr index_t kBlockSize
Definition block_fmha_pipeline_problem.hpp:245
static constexpr bool kPadHeadDimV
Definition block_fmha_pipeline_problem.hpp:263
remove_cvref_t< KDataType_ > KDataType
Definition block_fmha_pipeline_problem.hpp:241
static constexpr index_t kN0
Definition block_fmha_pipeline_problem.hpp:248
Definition block_fmha_pipeline_problem.hpp:83
remove_cvref_t< SMPLComputeDataType_ > SMPLComputeDataType
Definition block_fmha_pipeline_problem.hpp:88
static constexpr bool kPadHeadDimQ
Definition block_fmha_pipeline_problem.hpp:108
static constexpr bool kDoFp8StaticQuant
Definition block_fmha_pipeline_problem.hpp:114
remove_cvref_t< Traits_ > Traits
Definition block_fmha_pipeline_problem.hpp:97
static constexpr index_t kBlockPerCu
Definition block_fmha_pipeline_problem.hpp:116
remove_cvref_t< BlockFmhaShape_ > BlockFmhaShape
Definition block_fmha_pipeline_problem.hpp:94
static constexpr bool kPadHeadDimV
Definition block_fmha_pipeline_problem.hpp:109
remove_cvref_t< VDataType_ > VDataType
Definition block_fmha_pipeline_problem.hpp:86
remove_cvref_t< AttentionVariant_ > AttentionVariant
Definition block_fmha_pipeline_problem.hpp:95
remove_cvref_t< FmhaMask_ > FmhaMask
Definition block_fmha_pipeline_problem.hpp:96
remove_cvref_t< PDataType_ > PDataType
Definition block_fmha_pipeline_problem.hpp:91
remove_cvref_t< SaccDataType_ > SaccDataType
Definition block_fmha_pipeline_problem.hpp:87
remove_cvref_t< KDataType_ > KDataType
Definition block_fmha_pipeline_problem.hpp:85
static constexpr auto BiasEnum
Definition block_fmha_pipeline_problem.hpp:112
remove_cvref_t< QDataType_ > QDataType
Definition block_fmha_pipeline_problem.hpp:84
remove_cvref_t< BiasDataType_ > BiasDataType
Definition block_fmha_pipeline_problem.hpp:89
remove_cvref_t< OaccDataType_ > OaccDataType
Definition block_fmha_pipeline_problem.hpp:92
static constexpr bool kHasLogitsSoftCap
Definition block_fmha_pipeline_problem.hpp:110
static constexpr bool kPadSeqLenQ
Definition block_fmha_pipeline_problem.hpp:106
static constexpr index_t kBlockSize
Definition block_fmha_pipeline_problem.hpp:101
static constexpr bool kPadSeqLenK
Definition block_fmha_pipeline_problem.hpp:107
remove_cvref_t< ODataType_ > ODataType
Definition block_fmha_pipeline_problem.hpp:93
static constexpr index_t kNumGemm0Warps
Definition block_fmha_pipeline_problem.hpp:99
static constexpr bool kIsGroupMode
Definition block_fmha_pipeline_problem.hpp:103
remove_cvref_t< LSEDataType_ > LSEDataType
Definition block_fmha_pipeline_problem.hpp:90
static constexpr bool kIsPagedKV
Definition block_fmha_pipeline_problem.hpp:115
static constexpr bool kStoreLSE
Definition block_fmha_pipeline_problem.hpp:113
static constexpr bool kSkipMinSeqlenQ
Definition block_fmha_pipeline_problem.hpp:111
static constexpr index_t kNumGemm1Warps
Definition block_fmha_pipeline_problem.hpp:100
Definition block_fmha_pipeline_problem.hpp:135
static constexpr bool kHasUnevenSplits
Definition block_fmha_pipeline_problem.hpp:167
remove_cvref_t< VDataType_ > VDataType
Definition block_fmha_pipeline_problem.hpp:138
static constexpr bool kHasLogitsSoftCap
Definition block_fmha_pipeline_problem.hpp:162
remove_cvref_t< FmhaMask_ > FmhaMask
Definition block_fmha_pipeline_problem.hpp:148
static constexpr bool kPadHeadDimQ
Definition block_fmha_pipeline_problem.hpp:160
static constexpr bool kDoFp8StaticQuant
Definition block_fmha_pipeline_problem.hpp:165
static constexpr index_t kNumGemm0Warps
Definition block_fmha_pipeline_problem.hpp:151
remove_cvref_t< QDataType_ > QDataType
Definition block_fmha_pipeline_problem.hpp:136
remove_cvref_t< OaccDataType_ > OaccDataType
Definition block_fmha_pipeline_problem.hpp:144
remove_cvref_t< LSEDataType_ > LSEDataType
Definition block_fmha_pipeline_problem.hpp:142
static constexpr bool kIsGroupMode
Definition block_fmha_pipeline_problem.hpp:155
static constexpr bool kMergeNumHeadGroupsSeqLenQ
Definition block_fmha_pipeline_problem.hpp:168
static constexpr index_t kNumGemm1Warps
Definition block_fmha_pipeline_problem.hpp:152
remove_cvref_t< SaccDataType_ > SaccDataType
Definition block_fmha_pipeline_problem.hpp:139
static constexpr bool kIsPagedKV
Definition block_fmha_pipeline_problem.hpp:166
remove_cvref_t< SMPLComputeDataType_ > SMPLComputeDataType
Definition block_fmha_pipeline_problem.hpp:140
remove_cvref_t< BlockFmhaShape_ > BlockFmhaShape
Definition block_fmha_pipeline_problem.hpp:146
remove_cvref_t< KDataType_ > KDataType
Definition block_fmha_pipeline_problem.hpp:137
static constexpr bool kPadSeqLenQ
Definition block_fmha_pipeline_problem.hpp:158
static constexpr index_t kBlockSize
Definition block_fmha_pipeline_problem.hpp:153
static constexpr index_t kBlockPerCu
Definition block_fmha_pipeline_problem.hpp:169
remove_cvref_t< PDataType_ > PDataType
Definition block_fmha_pipeline_problem.hpp:143
remove_cvref_t< ODataType_ > ODataType
Definition block_fmha_pipeline_problem.hpp:145
static constexpr auto BiasEnum
Definition block_fmha_pipeline_problem.hpp:163
remove_cvref_t< AttentionVariant_ > AttentionVariant
Definition block_fmha_pipeline_problem.hpp:147
static constexpr bool kPadSeqLenK
Definition block_fmha_pipeline_problem.hpp:159
static constexpr bool kStoreLSE
Definition block_fmha_pipeline_problem.hpp:164
remove_cvref_t< BiasDataType_ > BiasDataType
Definition block_fmha_pipeline_problem.hpp:141
static constexpr bool kPadHeadDimV
Definition block_fmha_pipeline_problem.hpp:161
remove_cvref_t< Traits_ > Traits
Definition block_fmha_pipeline_problem.hpp:149
Definition block_fmha_pipeline_problem.hpp:281
remove_cvref_t< ODataType_ > ODataType
Definition block_fmha_pipeline_problem.hpp:290
remove_cvref_t< LSEDataType_ > LSEDataType
Definition block_fmha_pipeline_problem.hpp:287
remove_cvref_t< SaccDataType_ > SaccDataType
Definition block_fmha_pipeline_problem.hpp:285
static constexpr index_t kNumGemm0Warps
Definition block_fmha_pipeline_problem.hpp:295
static constexpr bool kPadHeadDimV
Definition block_fmha_pipeline_problem.hpp:305
static constexpr bool kIsGroupMode
Definition block_fmha_pipeline_problem.hpp:299
remove_cvref_t< OaccDataType_ > OaccDataType
Definition block_fmha_pipeline_problem.hpp:289
static constexpr index_t kBlockSize
Definition block_fmha_pipeline_problem.hpp:297
static constexpr bool kPadSeqLenQ
Definition block_fmha_pipeline_problem.hpp:302
remove_cvref_t< BlockFmhaShape_ > BlockFmhaShape
Definition block_fmha_pipeline_problem.hpp:291
remove_cvref_t< KDataType_ > KDataType
Definition block_fmha_pipeline_problem.hpp:283
remove_cvref_t< QDataType_ > QDataType
Definition block_fmha_pipeline_problem.hpp:282
remove_cvref_t< VDataType_ > VDataType
Definition block_fmha_pipeline_problem.hpp:284
static constexpr bool kPadSeqLenK
Definition block_fmha_pipeline_problem.hpp:303
remove_cvref_t< FmhaMask_ > FmhaMask
Definition block_fmha_pipeline_problem.hpp:292
static constexpr index_t kNumGemm1Warps
Definition block_fmha_pipeline_problem.hpp:296
static constexpr index_t kBlockPerCu
Definition block_fmha_pipeline_problem.hpp:307
static constexpr bool kPadHeadDimQ
Definition block_fmha_pipeline_problem.hpp:304
remove_cvref_t< PDataType_ > PDataType
Definition block_fmha_pipeline_problem.hpp:288
remove_cvref_t< Traits_ > Traits
Definition block_fmha_pipeline_problem.hpp:293
static constexpr bool kStoreLSE
Definition block_fmha_pipeline_problem.hpp:306
remove_cvref_t< SMPLComputeDataType_ > SMPLComputeDataType
Definition block_fmha_pipeline_problem.hpp:286
Definition block_fmha_pipeline_problem.hpp:29
static constexpr bool kPadSeqLenK
Definition block_fmha_pipeline_problem.hpp:55
remove_cvref_t< AttentionVariant_ > AttentionVariant
Definition block_fmha_pipeline_problem.hpp:42
static constexpr bool kHasDropout
Definition block_fmha_pipeline_problem.hpp:62
static constexpr bool kStoreLSE
Definition block_fmha_pipeline_problem.hpp:61
static constexpr bool kHasLogitsSoftCap
Definition block_fmha_pipeline_problem.hpp:58
static constexpr auto BiasEnum
Definition block_fmha_pipeline_problem.hpp:60
remove_cvref_t< BlockFmhaShape_ > BlockFmhaShape
Definition block_fmha_pipeline_problem.hpp:41
remove_cvref_t< OaccDataType_ > OaccDataType
Definition block_fmha_pipeline_problem.hpp:39
static constexpr bool kSkipMinSeqlenQ
Definition block_fmha_pipeline_problem.hpp:59
static constexpr index_t kNumGemm0Warps
Definition block_fmha_pipeline_problem.hpp:46
remove_cvref_t< Traits_ > Traits
Definition block_fmha_pipeline_problem.hpp:44
static constexpr bool kPadHeadDimQ
Definition block_fmha_pipeline_problem.hpp:56
remove_cvref_t< SaccDataType_ > SaccDataType
Definition block_fmha_pipeline_problem.hpp:33
remove_cvref_t< LSEDataType_ > LSEDataType
Definition block_fmha_pipeline_problem.hpp:37
static constexpr bool kIsGroupMode
Definition block_fmha_pipeline_problem.hpp:50
remove_cvref_t< KDataType_ > KDataType
Definition block_fmha_pipeline_problem.hpp:31
static constexpr index_t kBlockPerCu
Definition block_fmha_pipeline_problem.hpp:64
remove_cvref_t< RandValOutputDataType_ > RandValOutputDataType
Definition block_fmha_pipeline_problem.hpp:36
remove_cvref_t< ODataType_ > ODataType
Definition block_fmha_pipeline_problem.hpp:40
static constexpr index_t kBlockSize
Definition block_fmha_pipeline_problem.hpp:48
static constexpr bool kPadHeadDimV
Definition block_fmha_pipeline_problem.hpp:57
remove_cvref_t< PDataType_ > PDataType
Definition block_fmha_pipeline_problem.hpp:38
remove_cvref_t< VDataType_ > VDataType
Definition block_fmha_pipeline_problem.hpp:32
static constexpr bool kUseTrLoad
Definition block_fmha_pipeline_problem.hpp:51
remove_cvref_t< SMPLComputeDataType_ > SMPLComputeDataType
Definition block_fmha_pipeline_problem.hpp:34
static constexpr bool kPadSeqLenQ
Definition block_fmha_pipeline_problem.hpp:54
remove_cvref_t< BiasDataType_ > BiasDataType
Definition block_fmha_pipeline_problem.hpp:35
remove_cvref_t< FmhaMask_ > FmhaMask
Definition block_fmha_pipeline_problem.hpp:43
static constexpr index_t kNumGemm1Warps
Definition block_fmha_pipeline_problem.hpp:47
static constexpr bool kDoFp8StaticQuant
Definition block_fmha_pipeline_problem.hpp:63
remove_cvref_t< QDataType_ > QDataType
Definition block_fmha_pipeline_problem.hpp:30
Definition block_fmha_pipeline_problem.hpp:192
BlockFmhaSplitKVCombinePipelineTileSizes< OaccDataType_, kN1_ > BaseType
Definition block_fmha_pipeline_problem.hpp:193
remove_cvref_t< ODataType_ > ODataType
Definition block_fmha_pipeline_problem.hpp:197
static constexpr index_t kNumWarps
Definition block_fmha_pipeline_problem.hpp:220
remove_cvref_t< Traits_ > Traits
Definition block_fmha_pipeline_problem.hpp:198
static constexpr index_t kHeadDimV
Definition block_fmha_pipeline_problem.hpp:202
static constexpr index_t kBlockSize
Definition block_fmha_pipeline_problem.hpp:221
static constexpr index_t kM0
Definition block_fmha_pipeline_problem.hpp:180
static constexpr index_t kBlockPerCu
Definition block_fmha_pipeline_problem.hpp:216
static constexpr bool kIsGroupMode
Definition block_fmha_pipeline_problem.hpp:203
static constexpr index_t kMaxSplits
Definition block_fmha_pipeline_problem.hpp:217
static constexpr bool kPadHeadDimV
Definition block_fmha_pipeline_problem.hpp:213
static constexpr bool kStoreLSE
Definition block_fmha_pipeline_problem.hpp:214
static constexpr bool kDoFp8StaticQuant
Definition block_fmha_pipeline_problem.hpp:215
remove_cvref_t< LSEDataType_ > LSEDataType
Definition block_fmha_pipeline_problem.hpp:195
static constexpr bool kPadSeqLenQ
Definition block_fmha_pipeline_problem.hpp:212
static constexpr index_t kN1
Definition block_fmha_pipeline_problem.hpp:178
remove_cvref_t< OaccDataType_ > OaccDataType
Definition block_fmha_pipeline_problem.hpp:196
Definition block_fmha_pipeline_problem.hpp:175
static constexpr index_t NThreads
Definition block_fmha_pipeline_problem.hpp:179
static constexpr index_t kM0
Definition block_fmha_pipeline_problem.hpp:180
static constexpr index_t MaxVectorSize
Definition block_fmha_pipeline_problem.hpp:176
static constexpr index_t kN1
Definition block_fmha_pipeline_problem.hpp:178
Definition tile/ops/common/tensor_layout.hpp:22
Definition tile/ops/common/tensor_layout.hpp:17