24template <
typename ALayout,
39 template <ck::index_t NDim, typename ck::enable_if<NDim == 1, bool>::type =
false>
42 const std::array<ck::index_t, NDimSpatial + 3>& g_n_c_wis_strides,
46 const index_t N = g_n_c_wis_lengths[
I1] / split_n_size;
48 const index_t& Wi = g_n_c_wis_lengths[
I3];
50 const index_t& GStride = g_n_c_wis_strides[
I0];
51 const index_t& NStride = g_n_c_wis_strides[
I1];
52 const index_t& CStride = g_n_c_wis_strides[
I2];
53 const index_t& WiStride = g_n_c_wis_strides[
I3];
57 const auto merged_desc =
67 template <ck::index_t NDim, typename ck::enable_if<NDim == 1, bool>::type =
false>
70 const std::array<ck::index_t, NDimSpatial + 3>& g_n_c_wis_strides,
74 const index_t N = g_n_c_wis_lengths[
I1] / split_n_size;
76 const index_t& Wi = g_n_c_wis_lengths[
I3];
78 const index_t& NStride = g_n_c_wis_strides[
I1];
85 const auto merged_desc =
95 template <ck::index_t NDim, typename ck::enable_if<NDim == 2, bool>::type =
false>
98 const std::array<ck::index_t, NDimSpatial + 3>& g_n_c_wis_strides,
101 const index_t& G = g_n_c_wis_lengths[
I0];
102 const index_t N = g_n_c_wis_lengths[
I1] / split_n_size;
103 const index_t& C = g_n_c_wis_lengths[
I2];
104 const index_t& Hi = g_n_c_wis_lengths[
I3];
105 const index_t& Wi = g_n_c_wis_lengths[
I4];
107 const index_t& GStride = g_n_c_wis_strides[
I0];
108 const index_t& NStride = g_n_c_wis_strides[
I1];
109 const index_t& CStride = g_n_c_wis_strides[
I2];
110 const index_t& HiStride = g_n_c_wis_strides[
I3];
111 const index_t& WiStride = g_n_c_wis_strides[
I4];
115 const auto merged_desc =
125 template <ck::index_t NDim, typename ck::enable_if<NDim == 2, bool>::type =
false>
128 const std::array<ck::index_t, NDimSpatial + 3>& g_n_c_wis_strides,
129 const index_t split_n_size = 1)
131 const index_t& G = g_n_c_wis_lengths[
I0];
132 const index_t N = g_n_c_wis_lengths[
I1] / split_n_size;
133 const index_t& C = g_n_c_wis_lengths[
I2];
134 const index_t& Hi = g_n_c_wis_lengths[
I3];
135 const index_t& Wi = g_n_c_wis_lengths[
I4];
137 const index_t& NStride = g_n_c_wis_strides[
I1];
138 const index_t HiStride = Wi * G * C;
139 const index_t WiStride = G * C;
145 const auto merged_desc =
155 template <ck::index_t NDim, typename ck::enable_if<NDim == 3, bool>::type =
false>
158 const std::array<ck::index_t, NDimSpatial + 3>& g_n_c_wis_strides,
159 const index_t split_n_size = 1)
161 const index_t& G = g_n_c_wis_lengths[
I0];
162 const index_t N = g_n_c_wis_lengths[
I1] / split_n_size;
163 const index_t& C = g_n_c_wis_lengths[
I2];
164 const index_t& Di = g_n_c_wis_lengths[
I3];
165 const index_t& Hi = g_n_c_wis_lengths[
I4];
166 const index_t& Wi = g_n_c_wis_lengths[
I5];
168 const index_t& GStride = g_n_c_wis_strides[
I0];
169 const index_t& NStride = g_n_c_wis_strides[
I1];
170 const index_t& CStride = g_n_c_wis_strides[
I2];
171 const index_t& DiStride = g_n_c_wis_strides[
I3];
172 const index_t& HiStride = g_n_c_wis_strides[
I4];
173 const index_t& WiStride = g_n_c_wis_strides[
I5];
177 make_tuple(NStride, GStride, CStride, DiStride, HiStride, WiStride));
178 const auto merged_desc =
188 template <ck::index_t NDim, typename ck::enable_if<NDim == 3, bool>::type =
false>
191 const std::array<ck::index_t, NDimSpatial + 3>& g_n_c_wis_strides,
192 const index_t split_n_size = 1)
194 const index_t& G = g_n_c_wis_lengths[
I0];
195 const index_t N = g_n_c_wis_lengths[
I1] / split_n_size;
196 const index_t& C = g_n_c_wis_lengths[
I2];
197 const index_t& Di = g_n_c_wis_lengths[
I3];
198 const index_t& Hi = g_n_c_wis_lengths[
I4];
199 const index_t& Wi = g_n_c_wis_lengths[
I5];
201 const index_t& NStride = g_n_c_wis_strides[
I1];
202 const index_t DiStride = Hi * Wi * G * C;
203 const index_t HiStride = Wi * G * C;
204 const index_t WiStride = G * C;
210 make_tuple(NStride, GStride, CStride, DiStride, HiStride, WiStride));
211 const auto merged_desc =
221 template <ck::index_t NDim, typename ck::enable_if<NDim == 1, bool>::type =
false>
224 const std::array<ck::index_t, NDimSpatial + 3>& g_k_c_wis_strides)
226 const index_t& G = g_k_c_wis_lengths[
I0];
227 const index_t& K = g_k_c_wis_lengths[
I1];
228 const index_t& C = g_k_c_wis_lengths[
I2];
229 const index_t& X = g_k_c_wis_lengths[
I3];
231 const index_t& GStride = g_k_c_wis_strides[
I0];
232 const index_t& KStride = g_k_c_wis_strides[
I1];
233 const index_t& CStride = g_k_c_wis_strides[
I2];
234 const index_t& XStride = g_k_c_wis_strides[
I3];
247 template <ck::index_t NDim, typename ck::enable_if<NDim == 1, bool>::type =
false>
250 const std::array<ck::index_t, NDimSpatial + 3>& g_k_c_wis_strides)
252 const index_t& G = g_k_c_wis_lengths[
I0];
253 const index_t& K = g_k_c_wis_lengths[
I1];
254 const index_t& C = g_k_c_wis_lengths[
I2];
255 const index_t& X = g_k_c_wis_lengths[
I3];
257 const index_t& GStride = g_k_c_wis_strides[
I0];
258 const index_t KStride = g_k_c_wis_strides[
I1];
273 template <ck::index_t NDim, typename ck::enable_if<NDim == 2, bool>::type =
false>
276 const std::array<ck::index_t, NDimSpatial + 3>& g_k_c_wis_strides)
278 const index_t& G = g_k_c_wis_lengths[
I0];
279 const index_t& K = g_k_c_wis_lengths[
I1];
280 const index_t& C = g_k_c_wis_lengths[
I2];
281 const index_t& Y = g_k_c_wis_lengths[
I3];
282 const index_t& X = g_k_c_wis_lengths[
I4];
284 const index_t& GStride = g_k_c_wis_strides[
I0];
285 const index_t& KStride = g_k_c_wis_strides[
I1];
286 const index_t& CStride = g_k_c_wis_strides[
I2];
287 const index_t& YStride = g_k_c_wis_strides[
I3];
288 const index_t& XStride = g_k_c_wis_strides[
I4];
292 const auto merged_desc =
302 template <ck::index_t NDim, typename ck::enable_if<NDim == 2, bool>::type =
false>
305 const std::array<ck::index_t, NDimSpatial + 3>& g_k_c_wis_strides)
307 const index_t& G = g_k_c_wis_lengths[
I0];
308 const index_t& K = g_k_c_wis_lengths[
I1];
309 const index_t& C = g_k_c_wis_lengths[
I2];
310 const index_t& Y = g_k_c_wis_lengths[
I3];
311 const index_t& X = g_k_c_wis_lengths[
I4];
313 const index_t& GStride = g_k_c_wis_strides[
I0];
314 const index_t KStride = g_k_c_wis_strides[
I1];
321 const auto merged_desc =
331 template <ck::index_t NDim, typename ck::enable_if<NDim == 3, bool>::type =
false>
334 const std::array<ck::index_t, NDimSpatial + 3>& g_k_c_wis_strides)
336 const index_t& G = g_k_c_wis_lengths[
I0];
337 const index_t& K = g_k_c_wis_lengths[
I1];
338 const index_t& C = g_k_c_wis_lengths[
I2];
339 const index_t& Z = g_k_c_wis_lengths[
I3];
340 const index_t& Y = g_k_c_wis_lengths[
I4];
341 const index_t& X = g_k_c_wis_lengths[
I5];
343 const index_t& GStride = g_k_c_wis_strides[
I0];
344 const index_t& KStride = g_k_c_wis_strides[
I1];
345 const index_t& CStride = g_k_c_wis_strides[
I2];
346 const index_t& ZStride = g_k_c_wis_strides[
I3];
347 const index_t& YStride = g_k_c_wis_strides[
I4];
348 const index_t& XStride = g_k_c_wis_strides[
I5];
352 make_tuple(GStride, KStride, CStride, ZStride, YStride, XStride));
353 const auto merged_desc =
363 template <ck::index_t NDim, typename ck::enable_if<NDim == 3, bool>::type =
false>
366 const std::array<ck::index_t, NDimSpatial + 3>& g_k_c_wis_strides)
368 const index_t& G = g_k_c_wis_lengths[
I0];
369 const index_t& K = g_k_c_wis_lengths[
I1];
370 const index_t& C = g_k_c_wis_lengths[
I2];
371 const index_t& Z = g_k_c_wis_lengths[
I3];
372 const index_t& Y = g_k_c_wis_lengths[
I4];
373 const index_t& X = g_k_c_wis_lengths[
I5];
375 const index_t& GStride = g_k_c_wis_strides[
I0];
376 const index_t KStride = g_k_c_wis_strides[
I1];
378 const index_t ZStride = Y * X * C;
384 make_tuple(GStride, KStride, CStride, ZStride, YStride, XStride));
385 const auto merged_desc =
396 const std::array<index_t, NDimSpatial + 3>& g_n_c_wis_strides)
401 std::array<index_t, NDimSpatial + 3> g_n_c_wis_strides_transposed;
402 const auto G = g_n_c_wis_lengths[
I0];
403 const auto C = g_n_c_wis_lengths[
I2];
405 g_n_c_wis_strides_transposed[
I0] = C;
406 g_n_c_wis_strides_transposed[
I1] = g_n_c_wis_strides[
I1];
407 g_n_c_wis_strides_transposed[
I2] =
I1;
408 if constexpr(NDimSpatial == 2)
410 g_n_c_wis_strides_transposed[
I3] = g_n_c_wis_lengths[
I4] * G * C;
411 g_n_c_wis_strides_transposed[
I4] = G * C;
413 else if constexpr(NDimSpatial == 3)
415 g_n_c_wis_strides_transposed[
I3] =
416 g_n_c_wis_lengths[
I4] * g_n_c_wis_lengths[
I5] * G * C;
417 g_n_c_wis_strides_transposed[
I4] = g_n_c_wis_lengths[
I5] * G * C;
418 g_n_c_wis_strides_transposed[
I5] = G * C;
420 return g_n_c_wis_strides_transposed;
425 return g_n_c_wis_strides;
431 const std::array<ck::index_t, NDimSpatial + 3>& g_k_c_wis_strides)
436 std::array<index_t, NDimSpatial + 3> g_k_c_wis_strides_transposed = g_k_c_wis_strides;
439 if constexpr(NDimSpatial == 2)
442 g_k_c_wis_strides_transposed[
I2] = 1;
443 g_k_c_wis_strides_transposed[
I3] = X * C;
444 g_k_c_wis_strides_transposed[
I4] = C;
446 else if constexpr(NDimSpatial == 3)
450 g_k_c_wis_strides_transposed[
I2] = 1;
451 g_k_c_wis_strides_transposed[
I3] = Y * X * C;
452 g_k_c_wis_strides_transposed[
I4] = X * C;
453 g_k_c_wis_strides_transposed[
I5] = C;
455 return g_k_c_wis_strides_transposed;
460 return g_k_c_wis_strides;
__host__ __device__ constexpr auto PadTensorDescriptor(const TensorDesc &desc, const TileLengths &tile_lengths, DoPads)
Definition matrix_padder.hpp:19
constexpr bool is_NGCDHW_NGKDHW()
Definition device_grouped_conv_utils.hpp:112
constexpr bool is_NGCHW_GKCYX_NGKHW()
Definition device_grouped_conv_utils.hpp:64
constexpr bool is_NGCDHW_GKCZYX_NGKDHW()
Definition device_grouped_conv_utils.hpp:104
constexpr bool is_NGCHW_NGKHW()
Definition device_grouped_conv_utils.hpp:72
Definition convolution_backward_data_specialization.hpp:7
__host__ __device__ constexpr auto make_pass_through_transform(const LowLength &low_length)
Definition multi_index_transform_helper.hpp:12
int32_t index_t
Definition ck.hpp:299
__host__ __device__ constexpr auto make_naive_tensor_descriptor(const Tuple< Lengths... > &lengths, const Tuple< Strides... > &strides)
Definition tensor_descriptor_helper.hpp:49
integral_constant< index_t, N > Number
Definition number.hpp:12
__host__ __device__ constexpr auto make_merge_transform(const LowLengths &low_lengths)
Definition multi_index_transform_helper.hpp:55
__host__ __device__ constexpr auto make_tuple(Xs &&... xs)
Definition utility/tuple.hpp:211
__host__ __device__ constexpr auto transform_tensor_descriptor(const OldTensorDescriptor &old_tensor_desc, const NewTransforms &new_transforms, NewLowerDimensionOldVisibleIdss, NewUpperDimensionNewVisibleIdss)
Definition tensor_description/tensor_descriptor.hpp:319
Definition utility/sequence.hpp:43