device_grouped_gemm_softmax_gemm_permute.hpp Source File

device_grouped_gemm_softmax_gemm_permute.hpp Source File

Composable Kernel: device_grouped_gemm_softmax_gemm_permute.hpp Source File
device_grouped_gemm_softmax_gemm_permute.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
6#include <iostream>
7#include <vector>
8
9#include "device_base.hpp"
11
12namespace ck {
13namespace tensor_operation {
14namespace device {
15
// Abstract device-operation interface for a grouped
// GEMM -> softmax -> GEMM -> permute operation (operation name taken from the
// file name; the struct's own declaration line is missing from this listing).
//
// The interface exposes two pure-virtual factories:
//   * MakeArgumentPointer: builds a BaseArgument from per-problem pointers,
//     problem descriptors and element-wise operators.
//   * MakeInvokerPointer: builds a BaseInvoker.
// Concrete device implementations must override both.
//
// Template parameters:
//   NumDimG..NumDimO  - compile-time dimension counts (index_t). Per the
//                       naming, presumably the number of G, M, N, K and O
//                       dimensions; confirm against implementations.
//   *DataType         - element types for A, B0, B1, C and the two bias inputs.
//   *ElementwiseOperation - functor types passed by value to
//                       MakeArgumentPointer.
//   MaskingSpec       - a MaskingSpecialization value.
//
// NOTE(review): only the *ElementwiseOperation parameters are referenced in
// the visible body. The others presumably exist so that concrete
// implementations can be matched by type; confirm.
//
// NOTE(review): std::unique_ptr is used below, but the visible includes are
// only <iostream>, <vector> and "device_base.hpp". <memory> presumably arrives
// transitively; consider including it directly.
16template <index_t NumDimG,
17 index_t NumDimM,
18 index_t NumDimN,
19 index_t NumDimK,
20 index_t NumDimO,
21 typename ADataType,
22 typename B0DataType,
23 typename B1DataType,
24 typename CDataType,
25 typename Acc0BiasDataType,
26 typename Acc1BiasDataType,
27 typename AElementwiseOperation,
28 typename B0ElementwiseOperation,
29 typename Acc0ElementwiseOperation,
30 typename B1ElementwiseOperation,
31 typename CElementwiseOperation,
32 MaskingSpecialization MaskingSpec>
// NOTE(review): original source line 33 (the outer struct declaration) is
// missing from this listing; the numbering jumps from 32 to 34.
34{
// NOTE(review): original source line 35 (the nested struct header) is missing
// from this listing. The nested struct is presumably `ProblemDesc`, the
// element type of MakeArgumentPointer's problem_desc_vec parameter.
//
// Describes one problem as a set of per-tensor lengths and strides.
36 {
 // Lengths/strides of the A tensor. Per the name, presumably ordered
 // as G dims, then M dims, then K dims; confirm.
37 std::vector<index_t> a_gs_ms_ks_lengths;
38 std::vector<index_t> a_gs_ms_ks_strides;
39
 // Lengths/strides of the B0 tensor (name suggests G, N, K ordering).
40 std::vector<index_t> b0_gs_ns_ks_lengths;
41 std::vector<index_t> b0_gs_ns_ks_strides;
42
 // Lengths/strides of the B1 tensor (name suggests G, O, N ordering).
43 std::vector<index_t> b1_gs_os_ns_lengths;
44 std::vector<index_t> b1_gs_os_ns_strides;
45
 // Lengths/strides of the C (output) tensor (name suggests G, M, O ordering).
46 std::vector<index_t> c_gs_ms_os_lengths;
47 std::vector<index_t> c_gs_ms_os_strides;
48
 // One lengths/strides vector per acc0 bias tensor. The outer vector
 // presumably indexes multiple bias inputs; confirm.
49 std::vector<std::vector<index_t>> acc0_biases_gs_ms_ns_lengths;
50 std::vector<std::vector<index_t>> acc0_biases_gs_ms_ns_strides;
51
 // Same layout as above, for the acc1 bias tensors.
52 std::vector<std::vector<index_t>> acc1_biases_gs_ms_os_lengths;
53 std::vector<std::vector<index_t>> acc1_biases_gs_ms_os_strides;
54 };
55
 // Pure virtual: builds an argument object for a batch of problems.
 //
 // p_a_vec, p_b0_vec, p_b1_vec: input buffer pointers.
 // p_c_vec: output buffer pointers.
 // p_acc0_biases_vec, p_acc1_biases_vec: bias pointers, a vector of pointers
 //   per entry.
 // problem_desc_vec: per-problem shape/stride descriptors.
 // The pointer vectors are presumably parallel to problem_desc_vec, with one
 // entry per problem; confirm in implementations.
 //
 // *_element_op: element-wise functors, taken by value.
 // Returns an owning pointer to the created BaseArgument.
56 virtual std::unique_ptr<BaseArgument>
57 MakeArgumentPointer(std::vector<const void*> p_a_vec,
58 std::vector<const void*> p_b0_vec,
59 std::vector<const void*> p_b1_vec,
60 std::vector<void*> p_c_vec,
61 std::vector<std::vector<const void*>> p_acc0_biases_vec,
62 std::vector<std::vector<const void*>> p_acc1_biases_vec,
63 std::vector<ProblemDesc> problem_desc_vec,
64 AElementwiseOperation a_element_op,
65 B0ElementwiseOperation b0_element_op,
66 Acc0ElementwiseOperation acc0_element_op,
67 B1ElementwiseOperation b1_element_op,
68 CElementwiseOperation c_element_op) = 0;
69
 // Pure virtual: returns an owning pointer to a BaseInvoker for this
 // operation.
70 virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
71};
72
73} // namespace device
74} // namespace tensor_operation
75} // namespace ck
Definition convolution_backward_data_specialization.hpp:8
MaskingSpecialization
Definition masking_specialization.hpp:11
Definition convolution_backward_data_specialization.hpp:7
Definition ck.hpp:268
int32_t index_t
Definition ck.hpp:299
Definition device_grouped_gemm_softmax_gemm_permute.hpp:36
std::vector< index_t > b1_gs_os_ns_strides
Definition device_grouped_gemm_softmax_gemm_permute.hpp:44
std::vector< index_t > c_gs_ms_os_strides
Definition device_grouped_gemm_softmax_gemm_permute.hpp:47
std::vector< index_t > b1_gs_os_ns_lengths
Definition device_grouped_gemm_softmax_gemm_permute.hpp:43
std::vector< std::vector< index_t > > acc1_biases_gs_ms_os_strides
Definition device_grouped_gemm_softmax_gemm_permute.hpp:53
std::vector< std::vector< index_t > > acc1_biases_gs_ms_os_lengths
Definition device_grouped_gemm_softmax_gemm_permute.hpp:52
std::vector< index_t > b0_gs_ns_ks_lengths
Definition device_grouped_gemm_softmax_gemm_permute.hpp:40
std::vector< index_t > c_gs_ms_os_lengths
Definition device_grouped_gemm_softmax_gemm_permute.hpp:46
std::vector< std::vector< index_t > > acc0_biases_gs_ms_ns_strides
Definition device_grouped_gemm_softmax_gemm_permute.hpp:50
std::vector< index_t > a_gs_ms_ks_strides
Definition device_grouped_gemm_softmax_gemm_permute.hpp:38
std::vector< index_t > a_gs_ms_ks_lengths
Definition device_grouped_gemm_softmax_gemm_permute.hpp:37
std::vector< index_t > b0_gs_ns_ks_strides
Definition device_grouped_gemm_softmax_gemm_permute.hpp:41
std::vector< std::vector< index_t > > acc0_biases_gs_ms_ns_lengths
Definition device_grouped_gemm_softmax_gemm_permute.hpp:49
Definition device_grouped_gemm_softmax_gemm_permute.hpp:34
virtual std::unique_ptr< BaseInvoker > MakeInvokerPointer()=0
virtual std::unique_ptr< BaseArgument > MakeArgumentPointer(std::vector< const void * > p_a_vec, std::vector< const void * > p_b0_vec, std::vector< const void * > p_b1_vec, std::vector< void * > p_c_vec, std::vector< std::vector< const void * > > p_acc0_biases_vec, std::vector< std::vector< const void * > > p_acc1_biases_vec, std::vector< ProblemDesc > problem_desc_vec, AElementwiseOperation a_element_op, B0ElementwiseOperation b0_element_op, Acc0ElementwiseOperation acc0_element_op, B1ElementwiseOperation b1_element_op, CElementwiseOperation c_element_op)=0