reference_im2col.hpp Source File

reference_im2col.hpp Source File

Composable Kernel: reference_im2col.hpp Source File
reference_im2col.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
6#include "ck_tile/core.hpp"
8#include <thread>
9
10namespace ck_tile {
11
12template <typename InDataType, typename OutDataType, index_t NDimSpatial>
15 const ck_tile::conv::ConvParam& conv_params)
16{
17 const long_index_t G = in_host.get_lengths()[0];
18 const long_index_t N = in_host.get_lengths()[1];
19 const long_index_t C = in_host.get_lengths()[2];
20
21 if constexpr(NDimSpatial == 1)
22 {
23 const long_index_t Wo = conv_params.output_spatial_lengths_[0];
24 auto func = [&](auto g, auto n, auto wo) {
25 long_index_t row = n * Wo + wo;
26 long_index_t column = 0;
27
28 for(long_index_t x = 0; x < conv_params.filter_spatial_lengths_[0]; ++x)
29 {
30 auto wi = static_cast<long_index_t>(wo * conv_params.conv_filter_strides_[0]) +
31 static_cast<long_index_t>(x * conv_params.conv_filter_dilations_[0]) -
32 static_cast<long_index_t>(conv_params.input_left_pads_[0]);
33
34 for(long_index_t c = 0; c < C; ++c)
35 {
36 if(wi >= 0 && type_convert<std::size_t>(wi) < in_host.get_lengths()[3])
37 {
38 InDataType v_in = in_host(g, n, c, wi);
39 out_host(g, row, column) = type_convert<OutDataType>(v_in);
40 }
41 column++;
42 }
43 }
44 };
45
46 make_ParallelTensorFunctor(func, G, N, Wo)(std::thread::hardware_concurrency());
47 }
48 else if constexpr(NDimSpatial == 2)
49 {
50 const long_index_t Ho = conv_params.output_spatial_lengths_[0];
51 const long_index_t Wo = conv_params.output_spatial_lengths_[1];
52
53 auto func = [&](auto g, auto n, auto ho, auto wo) {
54 long_index_t row = n * Ho * Wo + ho * Wo + wo;
55 long_index_t column = 0;
56
57 for(long_index_t y = 0; y < conv_params.filter_spatial_lengths_[0]; ++y)
58 {
59 auto hi = static_cast<long_index_t>(ho * conv_params.conv_filter_strides_[0]) +
60 static_cast<long_index_t>(y * conv_params.conv_filter_dilations_[0]) -
61 static_cast<long_index_t>(conv_params.input_left_pads_[0]);
62
63 for(long_index_t x = 0; x < conv_params.filter_spatial_lengths_[1]; ++x)
64 {
65 auto wi = static_cast<long_index_t>(wo * conv_params.conv_filter_strides_[1]) +
66 static_cast<long_index_t>(x * conv_params.conv_filter_dilations_[1]) -
67 static_cast<long_index_t>(conv_params.input_left_pads_[1]);
68
69 for(long_index_t c = 0; c < C; ++c)
70 {
71
72 if(hi >= 0 && type_convert<std::size_t>(hi) < in_host.get_lengths()[3] &&
73 wi >= 0 && type_convert<std::size_t>(wi) < in_host.get_lengths()[4])
74 {
75 InDataType v_in = in_host(g, n, c, hi, wi);
76 out_host(g, row, column) = type_convert<OutDataType>(v_in);
77 }
78 column++;
79 }
80 }
81 }
82 };
83
84 make_ParallelTensorFunctor(func, G, N, Ho, Wo)(std::thread::hardware_concurrency());
85 }
86 else if constexpr(NDimSpatial == 3)
87 {
88 const long_index_t Do = conv_params.output_spatial_lengths_[0];
89 const long_index_t Ho = conv_params.output_spatial_lengths_[1];
90 const long_index_t Wo = conv_params.output_spatial_lengths_[2];
91
92 auto func = [&](auto g, auto n, auto d_o, auto ho, auto wo) {
93 long_index_t row = n * Do * Ho * Wo + d_o * Ho * Wo + ho * Wo + wo;
94 long_index_t column = 0;
95
96 for(long_index_t z = 0; z < conv_params.filter_spatial_lengths_[0]; ++z)
97 {
98 auto di = static_cast<long_index_t>(d_o * conv_params.conv_filter_strides_[0]) +
99 static_cast<long_index_t>(z * conv_params.conv_filter_dilations_[0]) -
100 static_cast<long_index_t>(conv_params.input_left_pads_[0]);
101 for(long_index_t y = 0; y < conv_params.filter_spatial_lengths_[1]; ++y)
102 {
103 auto hi = static_cast<long_index_t>(ho * conv_params.conv_filter_strides_[1]) +
104 static_cast<long_index_t>(y * conv_params.conv_filter_dilations_[1]) -
105 static_cast<long_index_t>(conv_params.input_left_pads_[1]);
106 for(long_index_t x = 0; x < conv_params.filter_spatial_lengths_[2]; ++x)
107 {
108 auto wi =
109 static_cast<long_index_t>(wo * conv_params.conv_filter_strides_[2]) +
110 static_cast<long_index_t>(x * conv_params.conv_filter_dilations_[2]) -
111 static_cast<long_index_t>(conv_params.input_left_pads_[2]);
112 for(long_index_t c = 0; c < C; ++c)
113 {
114 if(di >= 0 &&
115 type_convert<std::size_t>(di) < in_host.get_lengths()[3] &&
116 hi >= 0 &&
117 type_convert<std::size_t>(hi) < in_host.get_lengths()[4] &&
118 wi >= 0 && type_convert<std::size_t>(wi) < in_host.get_lengths()[5])
119 {
120 InDataType v_in = in_host(g, n, c, di, hi, wi);
121 out_host(g, row, column) = type_convert<OutDataType>(v_in);
122 }
123 column++;
124 }
125 }
126 }
127 }
128 };
129
130 make_ParallelTensorFunctor(func, G, N, Do, Ho, Wo)(std::thread::hardware_concurrency());
131 }
132}
133} // namespace ck_tile
#define CK_TILE_HOST
Definition config.hpp:40
Definition tile/core/algorithm/cluster_descriptor.hpp:13
CK_TILE_HOST auto make_ParallelTensorFunctor(F f, Xs... xs)
Definition tile/host/host_tensor.hpp:329
CK_TILE_HOST void reference_im2col(const HostTensor< InDataType > &in_host, HostTensor< OutDataType > &out_host, const ck_tile::conv::ConvParam &conv_params)
Definition reference_im2col.hpp:13
int64_t long_index_t
Definition integer.hpp:11
CK_TILE_HOST_DEVICE constexpr Y type_convert(X x)
Definition tile/core/numeric/type_convert.hpp:29
Definition tile/host/host_tensor.hpp:336
decltype(auto) get_lengths() const
Definition tile/host/host_tensor.hpp:390
Definition tile/host/convolution_parameter.hpp:15
std::vector< ck_tile::long_index_t > output_spatial_lengths_
Definition tile/host/convolution_parameter.hpp:131
std::vector< ck_tile::long_index_t > conv_filter_strides_
Definition tile/host/convolution_parameter.hpp:133
std::vector< ck_tile::long_index_t > filter_spatial_lengths_
Definition tile/host/convolution_parameter.hpp:129
std::vector< ck_tile::long_index_t > input_left_pads_
Definition tile/host/convolution_parameter.hpp:136
std::vector< ck_tile::long_index_t > conv_filter_dilations_
Definition tile/host/convolution_parameter.hpp:134