fmha_fwd_appendkv_kernel.hpp Source File
fmha_fwd_appendkv_kernel.hpp
Go to the documentation of this file.
61 "b" + _TS_(FmhaPipeline::kM0) + "x" + _TS_(FmhaPipeline::kN0) + "x" + _TS_(FmhaPipeline::kK0) + "x" +
62 _TS_(FmhaPipeline::kN1) + "_" + (kBlockPerCuInput == -1 ? "" : ("o" + _TS_(kBlockPerCu) + "_")) +
63 "v" + (std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor> ? "r" : "c") + (pn.empty() ? "" : "_" + pn)
64 + (!kApplyRoPE ? _SS_("") : (_SS_("_") + RotaryEmbeddingEnumToStr<FmhaPipeline::RotaryEnum>::name))
#define _TS_
#define _SS_
Definition tile/core/algorithm/cluster_descriptor.hpp:13
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition type_traits.hpp:21
CK_TILE_HOST_DEVICE constexpr auto make_naive_tensor_view(DataType *__restrict__ p, const tuple< Lengths... > &lengths, const tuple< Strides... > &strides, number< GuaranteedLastDimensionVectorLength >=number<-1 >{}, number< GuaranteedLastDimensionVectorStride >=number<-1 >{})
Definition tensor_view.hpp:471
CK_TILE_HOST_DEVICE constexpr auto make_pass_through_transform(const LowLength &low_length)
Definition coordinate_transform.hpp:1558
__device__ uint32_t amd_wave_read_first_lane(uint16_t v)
Definition tile/core/arch/amd_buffer_addressing.hpp:35
CK_TILE_HOST_DEVICE auto make_page_block_navigator(const TensorView &tensor_view)
Definition page_block_navigator.hpp:333
CK_TILE_DEVICE constexpr auto make_tile_window(null_tensor_view, const WindowLengths &window_lengths, const multi_index< WindowLengths::size()> &, Ts &&...)
Definition null_tile_window.hpp:75
CK_TILE_HOST_DEVICE constexpr auto integer_divide_ceil(X x, Y y)
Definition tile/core/numeric/math.hpp:149
CK_TILE_HOST_DEVICE constexpr auto pad_tensor_view(const TensorView &tensor_view, const TileLengths &tile_lengths, DoPads)
Definition tensor_view.hpp:530
CK_TILE_DEVICE constexpr auto make_null_tile_window(const WindowLengths &window_lengths)
Definition null_tile_window.hpp:66
CK_TILE_HOST_DEVICE constexpr auto transform_tensor_view(const OldTensorView &old_tensor_view, const NewTransforms &new_transforms, NewLowerDimensionOldVisibleIdss, NewUpperDimensionNewVisibleIdss)
Definition tensor_view.hpp:511
CK_TILE_HOST_DEVICE constexpr auto make_tuple(Xs &&... xs)
Definition tile/core/container/tuple.hpp:360
Definition fmha_fwd_appendkv_kernel.hpp:81
ck_tile::index_t stride_q
Definition fmha_fwd_appendkv_kernel.hpp:101
const int32_t * seqlen_k_ptr
Definition fmha_fwd_appendkv_kernel.hpp:88
const void * knew_ptr
Definition fmha_fwd_appendkv_kernel.hpp:84
ck_tile::index_t stride_k
Definition fmha_fwd_appendkv_kernel.hpp:102
ck_tile::index_t batch_stride_knew
Definition fmha_fwd_appendkv_kernel.hpp:115
ck_tile::index_t batch_stride_v
Definition fmha_fwd_appendkv_kernel.hpp:116
ck_tile::index_t nhead_stride_knew
Definition fmha_fwd_appendkv_kernel.hpp:109
ck_tile::index_t nhead_stride_vnew
Definition fmha_fwd_appendkv_kernel.hpp:111
ck_tile::index_t batch_stride_k
Definition fmha_fwd_appendkv_kernel.hpp:114
ck_tile::index_t stride_knew
Definition fmha_fwd_appendkv_kernel.hpp:103
ck_tile::index_t nhead_stride_q
Definition fmha_fwd_appendkv_kernel.hpp:107
ck_tile::index_t hdim_q
Definition fmha_fwd_appendkv_kernel.hpp:93
ck_tile::index_t stride_v
Definition fmha_fwd_appendkv_kernel.hpp:104
ck_tile::index_t batch_stride_q
Definition fmha_fwd_appendkv_kernel.hpp:113
ck_tile::index_t nhead_stride_k
Definition fmha_fwd_appendkv_kernel.hpp:108
ck_tile::index_t hdim_v
Definition fmha_fwd_appendkv_kernel.hpp:94
ck_tile::index_t batch_stride_vnew
Definition fmha_fwd_appendkv_kernel.hpp:117
void * q_ptr
Definition fmha_fwd_appendkv_kernel.hpp:82
ck_tile::index_t stride_vnew
Definition fmha_fwd_appendkv_kernel.hpp:105
void * v_ptr
Definition fmha_fwd_appendkv_kernel.hpp:85
const void * vnew_ptr
Definition fmha_fwd_appendkv_kernel.hpp:86
ck_tile::index_t nhead_stride_v
Definition fmha_fwd_appendkv_kernel.hpp:110
void * k_ptr
Definition fmha_fwd_appendkv_kernel.hpp:83
ck_tile::index_t seqlen_k
Definition fmha_fwd_appendkv_kernel.hpp:91
ck_tile::index_t seqlen_q
Definition fmha_fwd_appendkv_kernel.hpp:90
ck_tile::index_t seqlen_knew
Definition fmha_fwd_appendkv_kernel.hpp:92
ck_tile::index_t num_head_q
Definition fmha_fwd_appendkv_kernel.hpp:96
ck_tile::index_t nhead_ratio_qk
Definition fmha_fwd_appendkv_kernel.hpp:99
Definition fmha_fwd_appendkv_kernel.hpp:136
const int32_t * cache_batch_idx
Definition fmha_fwd_appendkv_kernel.hpp:137
Definition fmha_fwd_appendkv_kernel.hpp:74
Definition fmha_fwd_appendkv_kernel.hpp:143
Definition fmha_fwd_appendkv_kernel.hpp:129
ck_tile::index_t batch_stride_block_table
Definition fmha_fwd_appendkv_kernel.hpp:131
const int32_t * block_table_ptr
Definition fmha_fwd_appendkv_kernel.hpp:130
ck_tile::index_t page_block_size
Definition fmha_fwd_appendkv_kernel.hpp:132
Definition fmha_fwd_appendkv_kernel.hpp:121
ck_tile::index_t rotary_dim
Definition fmha_fwd_appendkv_kernel.hpp:124
const void * rotary_sin_ptr
Definition fmha_fwd_appendkv_kernel.hpp:123
bool has_mask
Definition fmha_fwd_appendkv_kernel.hpp:125
const void * rotary_cos_ptr
Definition fmha_fwd_appendkv_kernel.hpp:122
static constexpr const char * name
Definition fmha_fwd_appendkv_kernel.hpp:40
static constexpr const char * name
Definition fmha_fwd_appendkv_kernel.hpp:42
static constexpr const char * name
Definition fmha_fwd_appendkv_kernel.hpp:39
static constexpr const char * name
Definition fmha_fwd_appendkv_kernel.hpp:41
static constexpr const char * name
Definition fmha_fwd_appendkv_kernel.hpp:38
Definition fmha_fwd_appendkv_kernel.hpp:37
Definition fmha_fwd_appendkv_kernel.hpp:15
static constexpr bool kPadHeadDimV
Definition fmha_fwd_appendkv_kernel.hpp:34
static constexpr ck_tile::index_t kBlockPerCuInput
Definition fmha_fwd_appendkv_kernel.hpp:21
ck_tile::remove_cvref_t< FmhaPipeline_ > FmhaPipeline
Definition fmha_fwd_appendkv_kernel.hpp:16
ck_tile::remove_cvref_t< typename FmhaPipeline::QDataType > QDataType
Definition fmha_fwd_appendkv_kernel.hpp:23
ck_tile::remove_cvref_t< typename FmhaPipeline::KDataType > KDataType
Definition fmha_fwd_appendkv_kernel.hpp:24
static constexpr bool kPadSeqLenQ
Definition fmha_fwd_appendkv_kernel.hpp:31
CK_TILE_DEVICE void operator()(Kargs kargs) const
Definition fmha_fwd_appendkv_kernel.hpp:260
static CK_TILE_HOST dim3 BlockSize()
Definition fmha_fwd_appendkv_kernel.hpp:258
static CK_TILE_HOST constexpr Kargs MakeKargs(void *q_ptr, void *k_ptr, const void *knew_ptr, void *v_ptr, const void *vnew_ptr, ck_tile::index_t seqlen_q, const void *seqlen_k_ptr, ck_tile::index_t seqlen_knew, ck_tile::index_t hdim_q, ck_tile::index_t hdim_v, ck_tile::index_t num_head_q, ck_tile::index_t nhead_ratio_qk, const void *rotary_cos_ptr, const void *rotary_sin_ptr, ck_tile::index_t rotary_dim, bool has_mask, const void *block_table_ptr, ck_tile::index_t batch_stride_block_table, ck_tile::index_t page_block_size, const void *cache_batch_idx, ck_tile::index_t stride_q, ck_tile::index_t stride_k, ck_tile::index_t stride_knew, ck_tile::index_t stride_v, ck_tile::index_t stride_vnew, ck_tile::index_t nhead_stride_q, ck_tile::index_t nhead_stride_k, ck_tile::index_t nhead_stride_knew, ck_tile::index_t nhead_stride_v, ck_tile::index_t nhead_stride_vnew, ck_tile::index_t batch_stride_q, ck_tile::index_t batch_stride_k, ck_tile::index_t batch_stride_knew, ck_tile::index_t batch_stride_v, ck_tile::index_t batch_stride_vnew)
Definition fmha_fwd_appendkv_kernel.hpp:146
ck_tile::remove_cvref_t< typename FmhaPipeline::VLayout > VLayout
Definition fmha_fwd_appendkv_kernel.hpp:27
static CK_TILE_HOST constexpr auto GridSize(ck_tile::index_t batch_size, ck_tile::index_t nhead, ck_tile::index_t seqlen_q, ck_tile::index_t seqlen_knew)
Definition fmha_fwd_appendkv_kernel.hpp:237
static CK_TILE_DEVICE constexpr auto GetTileIndex(const Kargs &)
Definition fmha_fwd_appendkv_kernel.hpp:249
static constexpr ck_tile::index_t kBlockPerCu
Definition fmha_fwd_appendkv_kernel.hpp:18
ck_tile::remove_cvref_t< typename FmhaPipeline::VDataType > VDataType
Definition fmha_fwd_appendkv_kernel.hpp:25
static constexpr bool kPadHeadDimQ
Definition fmha_fwd_appendkv_kernel.hpp:33
static constexpr bool kPadSeqLenK
Definition fmha_fwd_appendkv_kernel.hpp:32
static constexpr bool kIsPagedKV
Definition fmha_fwd_appendkv_kernel.hpp:29
static CK_TILE_HOST std::string GetName()
Definition fmha_fwd_appendkv_kernel.hpp:45
static constexpr bool kApplyRoPE
Definition fmha_fwd_appendkv_kernel.hpp:28
static constexpr ck_tile::index_t kBlockSize
Definition fmha_fwd_appendkv_kernel.hpp:17
Definition block_rotary_embedding.hpp:19
Definition tile/core/container/sequence.hpp:49