binary_elementwise_operation.hpp Source File

binary_elementwise_operation.hpp Source File

Composable Kernel: binary_elementwise_operation.hpp Source File
binary_elementwise_operation.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
6#include "ck_tile/core.hpp"
7
8namespace ck_tile {
9namespace element_wise {
10
11struct Add
12{
13 template <typename Y, typename X0, typename X1>
14 __host__ __device__ constexpr void operator()(Y& y, const X0& x0, const X1& x1) const;
15
16 template <>
17 __host__ __device__ constexpr void
18 operator()<float>(float& y, const float& x0, const float& x1) const
19 {
20 y = x0 + x1;
21 };
22
23 template <>
24 __host__ __device__ constexpr void
25 operator()<double>(double& y, const double& x0, const double& x1) const
26 {
27 y = x0 + x1;
28 };
29
30 template <>
31 __host__ __device__ constexpr void
32 operator()<float>(float& y, const float& x0, const half_t& x1) const
33 {
34 y = x0 + type_convert<half_t>(x1);
35 };
36
37 template <>
38 __host__ __device__ constexpr void
39 operator()<half_t>(half_t& y, const float& x0, const float& x1) const
40 {
41 y = type_convert<half_t>(x0 + x1);
42 };
43
44 template <>
45 __host__ __device__ constexpr void
46 operator()<half_t>(half_t& y, const float& x0, const half_t& x1) const
47 {
48 y = type_convert<half_t>(x0) + x1;
49 };
50
51 template <>
52 __host__ __device__ constexpr void
53 operator()<half_t>(half_t& y, const half_t& x0, const half_t& x1) const
54 {
55 y = x0 + x1;
56 };
57
58 template <>
59 __host__ __device__ constexpr void
60 operator()<float>(float& y, const float& x0, const bf16_t& x1) const
61 {
62 const float x1_tmp = type_convert<float>(x1);
63 y = x0 + x1_tmp;
64 }
65
66 template <>
67 __host__ __device__ constexpr void
68 operator()<bf16_t>(bf16_t& y, const bf16_t& x0, const bf16_t& x1) const
69 {
70 const float x1_tmp = type_convert<float>(x0);
71 const float x2_tmp = type_convert<float>(x1);
72 const float y_tmp = x1_tmp + x2_tmp;
73 y = type_convert<bf16_t>(y_tmp);
74 }
75
76 template <>
77 __host__ __device__ constexpr void
78 operator()<bf16_t>(bf16_t& y, const float& x0, const bf16_t& x1) const
79 {
80 const float x2_tmp = type_convert<float>(x1);
81 const float y_tmp = x0 + x2_tmp;
82 y = type_convert<bf16_t>(y_tmp);
83 }
84
85 template <>
86 __host__ __device__ constexpr void
87 operator()<bf16_t>(bf16_t& y, const float& x0, const float& x1) const
88 {
89 const float y_tmp = x0 + x1;
90 y = type_convert<bf16_t>(y_tmp);
91 }
92
93 template <>
94 __host__ __device__ constexpr void
95 operator()<int8_t>(int8_t& y, const int8_t& x0, const int8_t& x1) const
96 {
97 y = x0 + x1;
98 };
99};
100
101} // namespace element_wise
102} // namespace ck_tile
Definition binary_elementwise_operation.hpp:9
Definition tile/core/algorithm/cluster_descriptor.hpp:13
_Float16 half_t
Definition half.hpp:111
int8_t int8_t
Definition int8.hpp:20
bfloat16_t bf16_t
Definition bfloat16.hpp:113
CK_TILE_HOST_DEVICE constexpr Y type_convert(X x)
Definition tile/core/numeric/type_convert.hpp:29
Definition binary_elementwise_operation.hpp:12
__host__ __device__ constexpr void operator()(Y &y, const X0 &x0, const X1 &x1) const