block_fmha_fwd_splitkv_combine_pipeline.hpp Source File
block_fmha_fwd_splitkv_combine_pipeline.hpp
Go to the documentation of this file.
Definition arch.hpp:385
Definition tile/core/algorithm/cluster_descriptor.hpp:13
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition type_traits.hpp:21
CK_TILE_DEVICE auto tile_elementwise_in(const InElementFunc &in_element_func, const InTensor &... in_dstr_tensors)
Definition tile_elementwise.hpp:40
CK_TILE_HOST_DEVICE constexpr auto make_tensor_view(DataType *__restrict__ p, const tensor_descriptor< Ts... > &desc)
Definition tensor_view.hpp:452
CK_TILE_HOST_DEVICE constexpr auto get_x_indices_from_distributed_indices(StaticTileDistribution tile_distribution, DistributedIndices distributed_indices)
Definition static_distributed_tensor.hpp:159
constant< b > bool_constant
Definition tile/core/numeric/integral_constant.hpp:43
CK_TILE_DEVICE void block_tile_reduce_sync(AccDistributedTensor_ &acc_tensor, const ReduceFunc &reduce_func, bool_constant< WithBroadcast >={}, bool_constant< CrossWarp >={})
Definition block_reduce.hpp:21
CK_TILE_DEVICE index_t get_warp_id(bool_constant< ReturnSgpr >={})
Definition arch.hpp:104
CK_TILE_HOST_DEVICE constexpr auto make_static_distributed_tensor(const StaticTileDistribution &)
Definition static_distributed_tensor.hpp:142
CK_TILE_DEVICE constexpr auto make_tile_window(null_tensor_view, const WindowLengths &window_lengths, const multi_index< WindowLengths::size()> &, Ts &&...)
Definition null_tile_window.hpp:75
CK_TILE_DEVICE void block_tile_reduce(AccDistributedTensor_ &acc_tensor, const InDistributedTensor_ &in_tensor, sequence< InReduceDims... >, const ReduceFunc &reduce_func)
Definition block_reduce.hpp:191
CK_TILE_HOST_DEVICE constexpr auto integer_divide_ceil(X x, Y y)
Definition tile/core/numeric/math.hpp:149
CK_TILE_DEVICE void sweep_tile_span(TileDistributedSpan_, const F &f)
Definition sweep_tile.hpp:20
CK_TILE_DEVICE void move_tile_window(null_tile_window< WindowLengths > &, const typename null_tile_window< WindowLengths >::BottomTensorIndex &)
Definition null_tile_window.hpp:95
CK_TILE_DEVICE void store_tile(tile_window_with_static_lengths< BottomTensorView_, WindowLengths_ > &tile_window_tmp, const static_distributed_tensor< DataType_, TileDistribution_ > &dstr_tensor)
Definition store_tile.hpp:23
CK_TILE_DEVICE void clear_tile(DstrTensors &dstr_tensor)
Definition tile_elementwise.hpp:177
CK_TILE_HOST_DEVICE constexpr Y type_convert(X x)
Definition tile/core/numeric/type_convert.hpp:29
CK_TILE_DEVICE auto load_tile(const TileWindow_ &tile_window, number< i_access >={}, bool_constant< oob_conditional_check >={})
Definition load_tile.hpp:22
CK_TILE_HOST_DEVICE constexpr auto make_tuple(Xs &&... xs)
Definition tile/core/container/tuple.hpp:360
Definition block_fmha_fwd_splitkv_combine_pipeline.hpp:48
remove_cvref_t< Problem_ > Problem
Definition block_fmha_fwd_splitkv_combine_pipeline.hpp:49
CK_TILE_HOST_DEVICE auto operator()(const LSEaccDramBlockWindowTmp &lse_acc_dram_block_window_tmp, const OaccDramBlockWindowTmp &o_acc_dram_block_window_tmp, LSEDramBlockWindowTmp &lse_dram_window_tmp, const LSEElementFunction &lse_element_func, const OaccElementFunction &o_acc_element_func, index_t num_splits, void *smem_ptr) const
Definition block_fmha_fwd_splitkv_combine_pipeline.hpp:115
static constexpr index_t kAlignmentOacc
Definition block_fmha_fwd_splitkv_combine_pipeline.hpp:73
static constexpr index_t kNumWarps
Definition block_fmha_fwd_splitkv_combine_pipeline.hpp:56
remove_cvref_t< typename Problem::OaccDataType > OaccDataType
Definition block_fmha_fwd_splitkv_combine_pipeline.hpp:53
remove_cvref_t< Policy_ > Policy
Definition block_fmha_fwd_splitkv_combine_pipeline.hpp:50
static constexpr const char * name
Definition block_fmha_fwd_splitkv_combine_pipeline.hpp:102
CK_TILE_HOST_DEVICE auto operator()(const LSEaccDramBlockWindow &lse_acc_dram_block_window, const OaccDramBlockWindow &o_acc_dram_block_window, LSEDramBlockWindow &lse_dram_block_window, index_t num_splits, void *smem_ptr) const
Definition block_fmha_fwd_splitkv_combine_pipeline.hpp:368
static constexpr index_t kBlockPerCu
Definition block_fmha_fwd_splitkv_combine_pipeline.hpp:79
static constexpr index_t kMaxSplits
Definition block_fmha_fwd_splitkv_combine_pipeline.hpp:67
static constexpr bool kStoreLSE
Definition block_fmha_fwd_splitkv_combine_pipeline.hpp:66
static constexpr index_t kM0
Definition block_fmha_fwd_splitkv_combine_pipeline.hpp:60
static constexpr bool kIsGroupMode
Definition block_fmha_fwd_splitkv_combine_pipeline.hpp:63
static constexpr index_t kBlockSize
Definition block_fmha_fwd_splitkv_combine_pipeline.hpp:57
remove_cvref_t< typename Problem::ODataType > ODataType
Definition block_fmha_fwd_splitkv_combine_pipeline.hpp:54
static constexpr index_t kAlignmentO
Definition block_fmha_fwd_splitkv_combine_pipeline.hpp:76
static constexpr index_t kN1
Definition block_fmha_fwd_splitkv_combine_pipeline.hpp:61
static constexpr index_t kHeadDimV
Definition block_fmha_fwd_splitkv_combine_pipeline.hpp:59
static CK_TILE_HOST_DEVICE constexpr ck_tile::index_t GetSmemSize()
Definition block_fmha_fwd_splitkv_combine_pipeline.hpp:104
static constexpr index_t kAlignmentLSEacc
Definition block_fmha_fwd_splitkv_combine_pipeline.hpp:71
static constexpr bool kPadHeadDimV
Definition block_fmha_fwd_splitkv_combine_pipeline.hpp:65
static constexpr bool kPadSeqLenQ
Definition block_fmha_fwd_splitkv_combine_pipeline.hpp:64
static constexpr index_t kAlignmentLSE
Definition block_fmha_fwd_splitkv_combine_pipeline.hpp:69
remove_cvref_t< typename Problem::LSEDataType > LSEDataType
Definition block_fmha_fwd_splitkv_combine_pipeline.hpp:52
Definition block_fmha_fwd_splitkv_combine_pipeline.hpp:13
Definition tile/core/utility/functional.hpp:86
static CK_TILE_HOST_DEVICE constexpr T infinity()
Definition tile/core/numeric/numeric.hpp:38
Definition tile/core/container/sequence.hpp:49
Definition tile/core/utility/functional.hpp:43