fused_moegemm_shape.hpp Source File

fused_moegemm_shape.hpp Source File#

Composable Kernel: fused_moegemm_shape.hpp Source File
fused_moegemm_shape.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
6#include "ck_tile/core.hpp"
7
8namespace ck_tile {
9
10/*
11tensors:
121. act (A): input feature map
132. gate (G): B matrix for the first gemm; its output is passed through the activation (e.g. Silu)
143. up (U): B matrix for the first gemm
154. down (D): B matrix for the second gemm
16 N1
17 / \
18 +----------+ |
19 | Down | |
20 x----------x |
21 hidden hidden K1 | | |
22 N0 N0 x----------x |
23 | +------x-----x------+------x-----x------+ | | |
24 dim | | Gate | | | Up | | | | | |
25 contiguous | | | | | | | | | | |
26 | | | | | | | | | | |
27 v +------x-----x------+------x-----x------+ +----------+ V
28 K0 | | | | | contiguous
29 / \ v v v v |
30 +---------+ +------x-----x------+------x-----x------+ |
31M0 | A | | | | | | | | |
32 +---------+ +------x-----x------+------x-----x------+ |
33 ----------> | | |
34 contiguous | V V
35 | x-----x +----------+
36 +------------> M1 | Y | ---------> | Out(O) |
37 ACT x-----x +----------+
38 K1 = N0 dim
39
40* Note: Act could be Gelu/Silu/...
41* Note: some models do not have Up (gate only)
42*/
43template <typename BlockTile_0_,
44 typename WarpPerBlock_0_,
45 typename WarpTile_0_,
46 typename BlockTile_1_,
47 typename WarpPerBlock_1_,
48 typename WarpTile_1_>
50{
57
58 static constexpr index_t NumWarps =
60
61 // TODO: we don't support rounding half warps up to 1 warp here
63
64 static constexpr index_t Block_M0 = BlockTile_0::at(number<0>{});
65 static constexpr index_t Block_N0 = BlockTile_0::at(number<1>{});
66 static constexpr index_t Block_K0 = BlockTile_0::at(number<2>{});
67 static constexpr index_t WarpPerBlock_M0 = WarpPerBlock_0::at(number<0>{});
68 static constexpr index_t WarpPerBlock_N0 = WarpPerBlock_0::at(number<1>{});
69 static constexpr index_t WarpPerBlock_K0 = WarpPerBlock_0::at(number<2>{});
70 static constexpr index_t Warp_M0 = WarpTile_0::at(number<0>{});
71 static constexpr index_t Warp_N0 = WarpTile_0::at(number<1>{});
72 static constexpr index_t Warp_K0 = WarpTile_0::at(number<2>{});
73
77 static_assert(Block_M0 % ThreadPerBlock_M0 == 0);
78 static_assert(Block_N0 % ThreadPerBlock_N0 == 0);
79 static_assert(Block_K0 % ThreadPerBlock_K0 == 0);
83
84 static constexpr index_t Block_M1 = BlockTile_1::at(number<0>{});
85 static constexpr index_t Block_N1 = BlockTile_1::at(number<1>{});
86 static constexpr index_t Block_K1 = BlockTile_1::at(number<2>{});
87 static constexpr index_t WarpPerBlock_M1 = WarpPerBlock_1::at(number<0>{});
88 static constexpr index_t WarpPerBlock_N1 = WarpPerBlock_1::at(number<1>{});
89 static constexpr index_t WarpPerBlock_K1 = WarpPerBlock_1::at(number<2>{});
90 static constexpr index_t Warp_M1 = WarpTile_1::at(number<0>{});
91 static constexpr index_t Warp_N1 = WarpTile_1::at(number<1>{});
92 static constexpr index_t Warp_K1 = WarpTile_1::at(number<2>{});
93
97 static_assert(Block_M1 % ThreadPerBlock_M1 == 0);
98 static_assert(Block_N1 % ThreadPerBlock_N1 == 0);
99 static_assert(Block_K1 % ThreadPerBlock_K1 == 0);
103
104 static constexpr index_t BlockSize = get_warp_size() * NumWarps;
105
106 // consistency checks between the first and second gemm tile shapes
107 static_assert(Block_M0 == Block_M1);
108 static_assert(Block_N0 == Block_K1 || (Block_N0 / 2) == Block_K1); // Gate Only or Gate+Up
109
110 // pre-shuffle tile size computation (assumed to apply only to the B matrix)
111 // we flatten each wave tile into a 1D linear tensor (at model loading time)
112 // e.g. originally we have a Block_N*Block_K tile; after pre-shuffle
113 // we have Block_Nr*Block_Kr*Block_W, where Block_W is Warp_N*Warp_K,
114 // and Block_Nr=Block_N/Warp_N, Block_Kr=Block_K/Warp_K
115 static constexpr index_t Block_W0 = Warp_N0 * Warp_K0;
116 static constexpr index_t Block_Nr0 = Block_N0 / Warp_N0;
117 static constexpr index_t Block_Kr0 = Block_K0 / Warp_K0;
118 static constexpr index_t Block_W1 = Warp_N1 * Warp_K1;
119 static constexpr index_t Block_Nr1 = Block_N1 / Warp_N1;
120 static constexpr index_t Block_Kr1 = Block_K1 / Warp_K1;
121
122 static_assert(Block_W0 == Block_W1);
123 // static_assert(Block_Nr0 == Block_Kr1);
124};
125} // namespace ck_tile
Definition tile/core/algorithm/cluster_descriptor.hpp:13
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition type_traits.hpp:21
__host__ __device__ multiplies() -> multiplies< void, void >
FIXME: create macro to replace 'host device' and nothing more.
CK_TILE_HOST_DEVICE constexpr index_t get_warp_size()
Definition arch.hpp:63
constant< v > number
Definition tile/core/numeric/integral_constant.hpp:37
int32_t index_t
Definition integer.hpp:9
CK_TILE_HOST_DEVICE constexpr index_t reduce_on_sequence(Seq, Reduce f, number< Init >)
Definition tile/core/container/sequence.hpp:982
Definition fused_moegemm_shape.hpp:50
static constexpr index_t ThreadPerBlock_N0
Definition fused_moegemm_shape.hpp:75
static constexpr index_t Repeat_K0
Definition fused_moegemm_shape.hpp:82
static constexpr index_t ThreadPerBlock_M1
Definition fused_moegemm_shape.hpp:94
static constexpr index_t BlockSize
Definition fused_moegemm_shape.hpp:104
static constexpr index_t Repeat_K1
Definition fused_moegemm_shape.hpp:102
remove_cvref_t< WarpTile_0_ > WarpTile_0
Definition fused_moegemm_shape.hpp:53
static constexpr index_t Block_N0
Definition fused_moegemm_shape.hpp:65
static constexpr index_t Warp_N0
Definition fused_moegemm_shape.hpp:71
static constexpr index_t Block_K0
Definition fused_moegemm_shape.hpp:66
remove_cvref_t< WarpPerBlock_1_ > WarpPerBlock_1
Definition fused_moegemm_shape.hpp:55
static constexpr index_t Warp_N1
Definition fused_moegemm_shape.hpp:91
static constexpr index_t Block_W0
Definition fused_moegemm_shape.hpp:115
static constexpr index_t WarpPerBlock_N0
Definition fused_moegemm_shape.hpp:68
static constexpr index_t WarpPerBlock_M0
Definition fused_moegemm_shape.hpp:67
static constexpr index_t Repeat_N0
Definition fused_moegemm_shape.hpp:81
static constexpr index_t Block_Nr1
Definition fused_moegemm_shape.hpp:119
static constexpr index_t Block_Kr1
Definition fused_moegemm_shape.hpp:120
static constexpr index_t Repeat_M1
Definition fused_moegemm_shape.hpp:100
static constexpr index_t NumWarps
Definition fused_moegemm_shape.hpp:58
static constexpr index_t Repeat_M0
Definition fused_moegemm_shape.hpp:80
remove_cvref_t< WarpTile_1_ > WarpTile_1
Definition fused_moegemm_shape.hpp:56
static constexpr index_t Block_M1
Definition fused_moegemm_shape.hpp:84
static constexpr index_t ThreadPerBlock_M0
Definition fused_moegemm_shape.hpp:74
static constexpr index_t Block_K1
Definition fused_moegemm_shape.hpp:86
static constexpr index_t Block_M0
Definition fused_moegemm_shape.hpp:64
remove_cvref_t< BlockTile_0_ > BlockTile_0
Definition fused_moegemm_shape.hpp:51
static constexpr index_t WarpPerBlock_K1
Definition fused_moegemm_shape.hpp:89
remove_cvref_t< WarpPerBlock_0_ > WarpPerBlock_0
Definition fused_moegemm_shape.hpp:52
static constexpr index_t Block_Nr0
Definition fused_moegemm_shape.hpp:116
static constexpr index_t Warp_M0
Definition fused_moegemm_shape.hpp:70
static constexpr index_t Block_Kr0
Definition fused_moegemm_shape.hpp:117
remove_cvref_t< BlockTile_1_ > BlockTile_1
Definition fused_moegemm_shape.hpp:54
static constexpr index_t WarpPerBlock_M1
Definition fused_moegemm_shape.hpp:87
static constexpr index_t Warp_M1
Definition fused_moegemm_shape.hpp:90
static constexpr index_t ThreadPerBlock_K1
Definition fused_moegemm_shape.hpp:96
static constexpr index_t Block_N1
Definition fused_moegemm_shape.hpp:85
static constexpr index_t Repeat_N1
Definition fused_moegemm_shape.hpp:101
static constexpr index_t WarpPerBlock_K0
Definition fused_moegemm_shape.hpp:69
static constexpr index_t Warp_K0
Definition fused_moegemm_shape.hpp:72
static constexpr index_t ThreadPerBlock_K0
Definition fused_moegemm_shape.hpp:76
static constexpr index_t WarpPerBlock_N1
Definition fused_moegemm_shape.hpp:88
static constexpr index_t ThreadPerBlock_N1
Definition fused_moegemm_shape.hpp:95
static constexpr index_t Warp_K1
Definition fused_moegemm_shape.hpp:92
static constexpr index_t Block_W1
Definition fused_moegemm_shape.hpp:118
Definition tile/core/numeric/math.hpp:98