warp_gemm_attribute_wmma.hpp Source File

warp_gemm_attribute_wmma.hpp Source File

Composable Kernel: warp_gemm_attribute_wmma.hpp Source File
warp_gemm_attribute_wmma.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
6#include "ck_tile/core.hpp"
9
10namespace ck_tile {
11
12// TODO: currently only 16-bit input is supported, which means only tr16_b128 is supported;
13// ADataType will be used to determine the layout in the future
14template <typename Impl>
26
27template <typename Impl>
39
40template <typename Impl>
52
53template <typename Impl>
65
66template <typename WarpGemmAttributeWmmaImpl_, bool kTransC = false>
68{
70
71 using ADataType = typename Impl::ADataType;
72 using BDataType = typename Impl::BDataType;
73 using CDataType = typename Impl::CDataType;
74
75 using AVecType = typename Impl::AVecType;
76 using BVecType = typename Impl::BVecType;
77 using CVecType = typename Impl::CVecType;
78
79 static constexpr index_t kM = Impl::kM;
80 static constexpr index_t kN = Impl::kN;
81 static constexpr index_t kK = Impl::kK;
82 static constexpr index_t kKPerThread = Impl::kABK0PerLane * Impl::kABK1PerLane;
83
// Always 1: each operator() call below dispatches to exactly one Impl{} invocation.
// NOTE(review): presumably read by callers to count MMA issues per warp GEMM call — confirm.
84 CK_TILE_HOST_DEVICE static constexpr auto get_num_of_access() { return 1; }
85
86 // 16 bit input, kAMLane = 16, kABK0PerLane = 4, kABKLane = 2, kABK1PerLane = 2
87 // 8 bit input, kAMLane = 16, kABK0PerLane = 2, kABKLane = 2, kABK1PerLane = 4
90
91 // kCM0PerLane = 1, kCMLane = 2, kCM1PerLane = 2, kCNLane = 16
93 std::conditional_t<kTransC,
96
97 // c_vec += a_vec * b_vec
98 template <bool post_nop_ = false>
100 const AVecType& a_vec,
101 const BVecType& b_vec,
102 bool_constant<post_nop_> = {}) const
103 {
104 if constexpr(kTransC)
105 {
106 Impl{}(c_vec, b_vec, a_vec, bool_constant<post_nop_>{});
107 }
108 else
109 {
110 Impl{}(c_vec, a_vec, b_vec, bool_constant<post_nop_>{});
111 }
112 }
113
114 // c_vec = a_vec * b_vec
// Non-accumulating form: no input accumulator is taken; a new CVecType is returned
// from a single Impl{} call (contrast with the `c_vec += a_vec * b_vec` overload).
// When kTransC is true, the operands are handed to Impl in swapped order (b_vec, a_vec).
// NOTE(review): this presumably pairs with CWarpDstrEncoding selecting the transposed
// C encoding when kTransC is true — confirm against Impl's operand layout.
115 CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
116 {
117 if constexpr(kTransC)
118 {
// A/B swapped: computes the product with B as Impl's first operand.
119 return Impl{}(b_vec, a_vec);
120 }
121 else
122 {
123 return Impl{}(a_vec, b_vec);
124 }
125 }
126};
127
128template <typename ADataType,
129 typename BDataType,
130 typename AccDataType,
131 index_t M_Warp_Tile,
132 index_t N_Warp_Tile,
133 index_t K_Warp_Tile>
135{
137 {
139 ADataType,
140 BDataType,
141 AccDataType,
142 M_Warp_Tile,
143 N_Warp_Tile,
144 K_Warp_Tile>;
145 }
146 else if(is_gfx11_supported())
147 {
149 ADataType,
150 BDataType,
151 AccDataType,
152 M_Warp_Tile,
153 N_Warp_Tile,
154 K_Warp_Tile>;
155 }
156 else
157 {
158 return false;
159 }
160}
161
162} // namespace ck_tile
#define CK_TILE_DEVICE
Definition config.hpp:41
#define CK_TILE_HOST
Definition config.hpp:40
#define CK_TILE_HOST_DEVICE
Definition config.hpp:42
Definition tile/core/algorithm/cluster_descriptor.hpp:13
constexpr bool has_wmma_traits_v
Definition warp_gemm_attribute_wmma_impl.hpp:138
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition type_traits.hpp:21
CK_TILE_HOST bool check_wmma_supported()
Definition warp_gemm_attribute_wmma.hpp:134
bool is_gfx12_supported()
Definition tile/host/device_prop.hpp:63
constant< b > bool_constant
Definition tile/core/numeric/integral_constant.hpp:43
bool is_gfx11_supported()
Definition tile/host/device_prop.hpp:55
int32_t index_t
Definition integer.hpp:9
Definition warp_gemm_attribute_wmma.hpp:16
tile_distribution_encoding< sequence< Impl::kRepeat >, tuple< sequence< Impl::kAMLane >, sequence< Impl::kABK0PerLane, Impl::kABKLane, Impl::kABK1PerLane > >, tuple< typename Impl::kABPs2RHssMajor >, tuple< typename Impl::kABPs2RHssMinor >, typename Impl::kABYs2RHsMajor, typename Impl::kABYs2RHsMinor > type
Definition warp_gemm_attribute_wmma.hpp:17
Definition warp_gemm_attribute_wmma.hpp:29
tile_distribution_encoding< sequence< Impl::kRepeat >, tuple< sequence< Impl::kBNLane >, sequence< Impl::kABK0PerLane, Impl::kABKLane, Impl::kABK1PerLane > >, tuple< typename Impl::kABPs2RHssMajor >, tuple< typename Impl::kABPs2RHssMinor >, typename Impl::kABYs2RHsMajor, typename Impl::kABYs2RHsMinor > type
Definition warp_gemm_attribute_wmma.hpp:30
Definition warp_gemm_attribute_wmma.hpp:55
tile_distribution_encoding< sequence<>, tuple< sequence< Impl::kCNLane >, sequence< Impl::kCM0PerLane, Impl::kCMLane, Impl::kCM1PerLane > >, tuple< typename Impl::kCTPs2RHssMajor >, tuple< typename Impl::kCTPs2RHssMinor >, typename Impl::kCTYs2RHsMajor, typename Impl::kCTYs2RHsMinor > type
Definition warp_gemm_attribute_wmma.hpp:56
Definition warp_gemm_attribute_wmma.hpp:42
tile_distribution_encoding< sequence<>, tuple< sequence< Impl::kCM0PerLane, Impl::kCMLane, Impl::kCM1PerLane >, sequence< Impl::kCNLane > >, tuple< typename Impl::kCPs2RHssMajor >, tuple< typename Impl::kCPs2RHssMinor >, typename Impl::kCYs2RHsMajor, typename Impl::kCYs2RHsMinor > type
Definition warp_gemm_attribute_wmma.hpp:43
Definition warp_gemm_attribute_wmma.hpp:68
static CK_TILE_HOST_DEVICE constexpr auto get_num_of_access()
Definition warp_gemm_attribute_wmma.hpp:84
remove_cvref_t< WarpGemmAttributeWmmaImpl_ > Impl
Definition warp_gemm_attribute_wmma.hpp:69
CK_TILE_DEVICE void operator()(CVecType &c_vec, const AVecType &a_vec, const BVecType &b_vec, bool_constant< post_nop_ >={}) const
Definition warp_gemm_attribute_wmma.hpp:99
static constexpr index_t kN
Definition warp_gemm_attribute_wmma.hpp:80
static constexpr index_t kKPerThread
Definition warp_gemm_attribute_wmma.hpp:82
typename Impl::ADataType ADataType
Definition warp_gemm_attribute_wmma.hpp:71
typename Impl::CDataType CDataType
Definition warp_gemm_attribute_wmma.hpp:73
typename Impl::BVecType BVecType
Definition warp_gemm_attribute_wmma.hpp:76
typename Impl::BDataType BDataType
Definition warp_gemm_attribute_wmma.hpp:72
typename Impl::CVecType CVecType
Definition warp_gemm_attribute_wmma.hpp:77
std::conditional_t< kTransC, typename CTransposedWarpDstrEncodingTrait< Impl >::type, typename CWarpDstrEncodingTrait< Impl >::type > CWarpDstrEncoding
Definition warp_gemm_attribute_wmma.hpp:92
typename Impl::AVecType AVecType
Definition warp_gemm_attribute_wmma.hpp:75
static constexpr index_t kK
Definition warp_gemm_attribute_wmma.hpp:81
static constexpr index_t kM
Definition warp_gemm_attribute_wmma.hpp:79
CK_TILE_DEVICE CVecType operator()(const AVecType &a_vec, const BVecType &b_vec) const
Definition warp_gemm_attribute_wmma.hpp:115
typename AWarpDstrEncodingTrait< Impl >::type AWarpDstrEncoding
Definition warp_gemm_attribute_wmma.hpp:88
typename BWarpDstrEncodingTrait< Impl >::type BWarpDstrEncoding
Definition warp_gemm_attribute_wmma.hpp:89
Definition arch.hpp:363
Definition arch.hpp:366
Definition tile/core/container/sequence.hpp:49
Definition tile_distribution_encoding.hpp:26
Definition tile/core/container/tuple.hpp:192