block_fmha_pipeline_qr_ks_vs_fp8.hpp Source File
block_fmha_pipeline_qr_ks_vs_fp8.hpp
Go to the documentation of this file.
Definition tile/core/algorithm/cluster_descriptor.hpp:13
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition type_traits.hpp:21
CK_TILE_DEVICE void set_tile(DstrTensors &dstr_tensor, const T &value)
Definition tile_elementwise.hpp:95
CK_TILE_HOST_DEVICE constexpr auto make_tensor_view(DataType *__restrict__ p, const tensor_descriptor< Ts... > &desc)
Definition tensor_view.hpp:452
CK_TILE_DEVICE constexpr auto get_slice_tile(const tile_window_with_static_lengths< BottomTensorView_, WindowLengths_ > &tile, sequence< SliceBegins... > slice_begins, sequence< SliceEnds... > slice_ends)
Definition slice_tile.hpp:23
enumerator ELEMENTWISE_BIAS
Definition block_attention_bias_enum.hpp:14
CK_TILE_DEVICE void block_tile_reduce_sync(AccDistributedTensor_ &acc_tensor, const ReduceFunc &reduce_func, bool_constant< WithBroadcast >={}, bool_constant< CrossWarp >={})
Definition block_reduce.hpp:21
CK_TILE_DEVICE void tile_elementwise_inout(const InOutElementFunc &inout_element_func, InOutDstrTensors &... inout_dstr_tensors)
Definition tile_elementwise.hpp:23
CK_TILE_DEVICE void shuffle_tile(OutTensor &out, const InTensor &in)
Definition shuffle_tile.hpp:154
CK_TILE_DEVICE constexpr auto make_tile_window(null_tensor_view, const WindowLengths &window_lengths, const multi_index< WindowLengths::size()> &, Ts &&...)
Definition null_tile_window.hpp:75
CK_TILE_DEVICE auto cast_tile(const SrcTensor &src_tensor)
Definition tile_elementwise.hpp:327
CK_TILE_DEVICE void block_tile_reduce(AccDistributedTensor_ &acc_tensor, const InDistributedTensor_ &in_tensor, sequence< InReduceDims... >, const ReduceFunc &reduce_func)
Definition block_reduce.hpp:191
CK_TILE_HOST_DEVICE constexpr auto integer_divide_ceil(X x, Y y)
Definition tile/core/numeric/math.hpp:149
CK_TILE_DEVICE void sweep_tile_span(TileDistributedSpan_, const F &f)
Definition sweep_tile.hpp:20
CK_TILE_DEVICE void move_tile_window(null_tile_window< WindowLengths > &, const typename null_tile_window< WindowLengths >::BottomTensorIndex &)
Definition null_tile_window.hpp:95
CK_TILE_HOST_DEVICE void set_tile_if(static_distributed_tensor< DataType, StaticTileDistribution > &out_tensor, DataType value, XIndicesPredicate predicate)
Definition static_distributed_tensor.hpp:175
CK_TILE_DEVICE void store_tile(tile_window_with_static_lengths< BottomTensorView_, WindowLengths_ > &tile_window_tmp, const static_distributed_tensor< DataType_, TileDistribution_ > &dstr_tensor)
Definition store_tile.hpp:23
CK_TILE_DEVICE void clear_tile(DstrTensors &dstr_tensor)
Definition tile_elementwise.hpp:177
CK_TILE_DEVICE auto load_tile(const TileWindow_ &tile_window, number< i_access >={}, bool_constant< oob_conditional_check >={})
Definition load_tile.hpp:22
CK_TILE_HOST_DEVICE constexpr auto make_tuple(Xs &&... xs)
Definition tile/core/container/tuple.hpp:360
__host__ __device__ constexpr auto make_tuple(Xs &&... xs)
Definition utility/tuple.hpp:211
Definition block_dropout.hpp:53
Definition block_fmha_pipeline_qr_ks_vs_fp8.hpp:16
static CK_TILE_HOST_DEVICE constexpr ck_tile::index_t GetSmemSize()
Definition block_fmha_pipeline_qr_ks_vs_fp8.hpp:102
static constexpr bool kHasDropout
Definition block_fmha_pipeline_qr_ks_vs_fp8.hpp:53
remove_cvref_t< typename Problem::ODataType > ODataType
Definition block_fmha_pipeline_qr_ks_vs_fp8.hpp:29
static constexpr index_t kBlockPerCu
Definition block_fmha_pipeline_qr_ks_vs_fp8.hpp:73
remove_cvref_t< typename Problem::BiasDataType > BiasDataType
Definition block_fmha_pipeline_qr_ks_vs_fp8.hpp:24
static constexpr index_t kM0
Definition block_fmha_pipeline_qr_ks_vs_fp8.hpp:39
static constexpr index_t kBlockSize
Definition block_fmha_pipeline_qr_ks_vs_fp8.hpp:37
remove_cvref_t< typename Problem::SMPLComputeDataType > SMPLComputeDataType
Definition block_fmha_pipeline_qr_ks_vs_fp8.hpp:23
static constexpr index_t kAlignmentO
Definition block_fmha_pipeline_qr_ks_vs_fp8.hpp:68
static constexpr auto BiasEnum
Definition block_fmha_pipeline_qr_ks_vs_fp8.hpp:51
static constexpr index_t kN0
Definition block_fmha_pipeline_qr_ks_vs_fp8.hpp:40
remove_cvref_t< typename Problem::VDataType > VDataType
Definition block_fmha_pipeline_qr_ks_vs_fp8.hpp:21
static constexpr index_t kAlignmentK
Definition block_fmha_pipeline_qr_ks_vs_fp8.hpp:59
static constexpr const char * name
Definition block_fmha_pipeline_qr_ks_vs_fp8.hpp:100
static constexpr index_t kK0
Definition block_fmha_pipeline_qr_ks_vs_fp8.hpp:41
remove_cvref_t< Problem_ > Problem
Definition block_fmha_pipeline_qr_ks_vs_fp8.hpp:17
remove_cvref_t< typename Problem::LSEDataType > LSEDataType
Definition block_fmha_pipeline_qr_ks_vs_fp8.hpp:26
static constexpr index_t kQKHeaddim
Definition block_fmha_pipeline_qr_ks_vs_fp8.hpp:44
remove_cvref_t< typename Problem::BlockFmhaShape > BlockFmhaShape
Definition block_fmha_pipeline_qr_ks_vs_fp8.hpp:32
remove_cvref_t< typename Problem::OaccDataType > OaccDataType
Definition block_fmha_pipeline_qr_ks_vs_fp8.hpp:28
static constexpr index_t kAlignmentQ
Definition block_fmha_pipeline_qr_ks_vs_fp8.hpp:57
static constexpr index_t kK1
Definition block_fmha_pipeline_qr_ks_vs_fp8.hpp:43
remove_cvref_t< Policy_ > Policy
Definition block_fmha_pipeline_qr_ks_vs_fp8.hpp:18
static constexpr bool kPadHeadDimQ
Definition block_fmha_pipeline_qr_ks_vs_fp8.hpp:49
static constexpr index_t kAlignmentV
Definition block_fmha_pipeline_qr_ks_vs_fp8.hpp:61
remove_cvref_t< typename Problem::QDataType > QDataType
Definition block_fmha_pipeline_qr_ks_vs_fp8.hpp:19
CK_TILE_HOST_DEVICE auto operator()(const QDramBlockWindowTmp &q_dram_block_window_tmp, const KDramBlockWindowTmp &k_dram_block_window_tmp, const VDramBlockWindowTmp &v_dram_block_window_tmp, const BiasDramBlockWindowTmp &bias_dram_block_window_tmp, RandValDramBlockWindowTmp &, LSEDramBlockWindowTmp &, FmhaMask mask, PositionEncoding, float scale_s, float descale_qk, float descale_sv, void *smem_ptr, BlockDropout &) const
Definition block_fmha_pipeline_qr_ks_vs_fp8.hpp:115
static constexpr bool kIsGroupMode
Definition block_fmha_pipeline_qr_ks_vs_fp8.hpp:46
remove_cvref_t< typename Problem::PDataType > PDataType
Definition block_fmha_pipeline_qr_ks_vs_fp8.hpp:27
remove_cvref_t< typename Problem::SaccDataType > SaccDataType
Definition block_fmha_pipeline_qr_ks_vs_fp8.hpp:22
static constexpr index_t kAlignmentBias
Definition block_fmha_pipeline_qr_ks_vs_fp8.hpp:70
static constexpr bool kQLoadOnce
Definition block_fmha_pipeline_qr_ks_vs_fp8.hpp:34
remove_cvref_t< typename BlockFmhaShape::VLayout > VLayout
Definition block_fmha_pipeline_qr_ks_vs_fp8.hpp:33
static constexpr bool kPadSeqLenK
Definition block_fmha_pipeline_qr_ks_vs_fp8.hpp:48
static constexpr bool kStoreLSE
Definition block_fmha_pipeline_qr_ks_vs_fp8.hpp:52
static constexpr bool kPadHeadDimV
Definition block_fmha_pipeline_qr_ks_vs_fp8.hpp:50
static constexpr bool kPadSeqLenQ
Definition block_fmha_pipeline_qr_ks_vs_fp8.hpp:47
remove_cvref_t< typename Problem::KDataType > KDataType
Definition block_fmha_pipeline_qr_ks_vs_fp8.hpp:20
static constexpr index_t kN1
Definition block_fmha_pipeline_qr_ks_vs_fp8.hpp:42
remove_cvref_t< typename Problem::RandValOutputDataType > RandValOutputDataType
Definition block_fmha_pipeline_qr_ks_vs_fp8.hpp:25
remove_cvref_t< typename Problem::FmhaMask > FmhaMask
Definition block_fmha_pipeline_qr_ks_vs_fp8.hpp:30
static CK_TILE_HOST_DEVICE constexpr T infinity()
Definition tile/core/numeric/numeric.hpp:38
Definition tile/core/container/sequence.hpp:49