gemm_bquant_pipeline_ag_bg_cr_base.hpp Source File

gemm_bquant_pipeline_ag_bg_cr_base.hpp Source File

Composable Kernel: gemm_bquant_pipeline_ag_bg_cr_base.hpp Source File
gemm_bquant_pipeline_ag_bg_cr_base.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
6#include "ck_tile/core.hpp"
9
10namespace ck_tile {
11
12template <typename Problem, typename Policy>
14{
16 using ADataType = typename Base::ADataType;
17 using ALayout = typename Base::ALayout;
18 using BDataType = typename Base::BDataType;
19 using BLayout = typename Base::BLayout;
22
24
25 static constexpr index_t MPerBlock = BlockGemmShape::kM;
26 static constexpr index_t NPerBlock = BlockGemmShape::kN;
27 static constexpr index_t KPerBlock = BlockGemmShape::kK;
28
29 static constexpr index_t NPerBlockBQ = NPerBlock / QuantGroupSize::kN;
30 static constexpr index_t KPerBlockBQ = KPerBlock / QuantGroupSize::kK;
31
32 static_assert(NPerBlockBQ >= 1, "NPerBlock must be >= QuantGroupSize");
33 static_assert(KPerBlockBQ >= 1, "KPerBlock must be >= QuantGroupSize");
34
35 static_assert(NPerBlock % QuantGroupSize::kN == 0,
36 "NPerBlock must be a multiple of QuantGroupSize::kN");
37 static_assert(KPerBlock % QuantGroupSize::kK == 0,
38 "KPerBlock must be a multiple of QuantGroupSize::kK");
39
40 // Create DRAM tile window for BQ
41 template <typename BQDramBlockWindowTmp>
42 CK_TILE_DEVICE constexpr auto
43 GetBQDramLoadWindow(const BQDramBlockWindowTmp& bq_dram_block_window_tmp) const
44 {
45 static_assert(std::is_same_v<BQLayout, tensor_layout::gemm::ColumnMajor>);
46
47 using YPerTile = number<NPerBlockBQ>;
48 using XPerTile = number<KPerBlockBQ>;
49
50 auto bq_copy_dram_window =
51 make_tile_window(bq_dram_block_window_tmp.get_bottom_tensor_view(),
52 make_tuple(YPerTile(), XPerTile()),
53 bq_dram_block_window_tmp.get_window_origin(),
54 Policy::template MakeBQDramTileDistribution<Problem>());
55 return bq_copy_dram_window;
56 }
57};
58
59} // namespace ck_tile
#define CK_TILE_DEVICE
Definition config.hpp:41
Definition tile/core/algorithm/cluster_descriptor.hpp:13
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition type_traits.hpp:21
constant< v > number
Definition tile/core/numeric/integral_constant.hpp:37
CK_TILE_DEVICE constexpr auto make_tile_window(null_tensor_view, const WindowLengths &window_lengths, const multi_index< WindowLengths::size()> &, Ts &&...)
Definition null_tile_window.hpp:75
int32_t index_t
Definition integer.hpp:9
CK_TILE_HOST_DEVICE constexpr auto make_tuple(Xs &&... xs)
Definition tile/core/container/tuple.hpp:360
Definition gemm_bquant_pipeline_ag_bg_cr_base.hpp:14
static constexpr index_t KPerBlock
Definition gemm_bquant_pipeline_ag_bg_cr_base.hpp:27
GemmPipelineAgBgCrImplBase< Problem, Policy > Base
Definition gemm_bquant_pipeline_ag_bg_cr_base.hpp:15
static constexpr index_t KPerBlockBQ
Definition gemm_bquant_pipeline_ag_bg_cr_base.hpp:30
typename Base::BlockGemmShape BlockGemmShape
Definition gemm_bquant_pipeline_ag_bg_cr_base.hpp:20
static constexpr index_t NPerBlockBQ
Definition gemm_bquant_pipeline_ag_bg_cr_base.hpp:29
remove_cvref_t< typename Problem::QuantGroupSize > QuantGroupSize
Definition gemm_bquant_pipeline_ag_bg_cr_base.hpp:21
static constexpr index_t NPerBlock
Definition gemm_bquant_pipeline_ag_bg_cr_base.hpp:26
CK_TILE_DEVICE constexpr auto GetBQDramLoadWindow(const BQDramBlockWindowTmp &bq_dram_block_window_tmp) const
Definition gemm_bquant_pipeline_ag_bg_cr_base.hpp:43
typename Base::BLayout BLayout
Definition gemm_bquant_pipeline_ag_bg_cr_base.hpp:19
typename Base::BDataType BDataType
Definition gemm_bquant_pipeline_ag_bg_cr_base.hpp:18
typename Base::ADataType ADataType
Definition gemm_bquant_pipeline_ag_bg_cr_base.hpp:16
static constexpr index_t MPerBlock
Definition gemm_bquant_pipeline_ag_bg_cr_base.hpp:25
typename Base::ALayout ALayout
Definition gemm_bquant_pipeline_ag_bg_cr_base.hpp:17
remove_cvref_t< typename Problem::BQLayout > BQLayout
Definition gemm_bquant_pipeline_ag_bg_cr_base.hpp:23
Definition gemm_pipeline_ag_bg_cr_base.hpp:13
remove_cvref_t< std::tuple_element_t< number< 0 >{}, BsLayout > > BLayout
Definition gemm_pipeline_ag_bg_cr_base.hpp:23
remove_cvref_t< std::tuple_element_t< number< 0 >{}, BsDataType > > BDataType
Definition gemm_pipeline_ag_bg_cr_base.hpp:22
remove_cvref_t< std::tuple_element_t< number< 0 >{}, AsLayout > > ALayout
Definition gemm_pipeline_ag_bg_cr_base.hpp:21
remove_cvref_t< typename Problem::BlockGemmShape > BlockGemmShape
Definition gemm_pipeline_ag_bg_cr_base.hpp:18
remove_cvref_t< std::tuple_element_t< number< 0 >{}, AsDataType > > ADataType
Definition gemm_pipeline_ag_bg_cr_base.hpp:20