functional_with_tuple.hpp Source File

functional_with_tuple.hpp Source File

Composable Kernel: functional_with_tuple.hpp Source File
functional_with_tuple.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
6// This file should not be included inside tuple.hpp!
7
15#include <stdint.h>
16#include <utility>
17
18namespace ck_tile {
19
20namespace detail {
21
// Recursive step of static_uford: each instantiation consumes the front dimension
// of RemainLengths / RamainUnpacks and recurses on the remaining dimensions.
// NOTE(review): listing line 25 (the struct header, presumably `struct static_uford_impl`)
// and listing line 27 (the constructor signature) are not shown on this page.
22// RemainLengths: sequence<...> of lengths of the dimensions not yet looped over
// RamainUnpacks: sequence<...> of pack lengths for those same dimensions (front() = current pack length)
23// Orders: sequence<...> passed through unchanged; the terminal case uses it to reorder each finished pack
24template <class RemainLengths, class RamainUnpacks, class Orders>
26{
// Constructor: rejects instantiation with no dimensions left; the empty case is
// presumably served by the `template <class Orders>` specialization that follows.
28 {
29 static_assert(RemainLengths::size() > 0, "wrong! should not get here");
30 static_assert(RamainUnpacks::size() > 0, "wrong! should not get here");
31 }
32
// f: user callback, forwarded unchanged down the recursion.
// CurrentUnpackIds: tuple of partial multi-indices (sequences) built from the
// dimensions already consumed; only its type is used.
33 template <class F, class CurrentUnpackIds>
34 CK_TILE_HOST_DEVICE constexpr void operator()(F f, CurrentUnpackIds) const
35 {
// Number of consecutive indices of the front dimension placed into one step.
36 constexpr index_t pack_len = RamainUnpacks::front();
// Step through the front dimension in strides of pack_len.
37 static_for<0, RemainLengths::front(), pack_len>{}([=](auto I) {
// Multiply the pack count by pack_len: element idx_ is the previous pack
// idx_ / pack_len with the index I + idx_ % pack_len appended to it.
38 constexpr auto new_pack = generate_tuple(
39 [&](auto idx_) {
40 constexpr auto i_new_pack = number<I + idx_ % pack_len>{};
41 constexpr auto i_pre_pack = number<idx_ / pack_len>{};
42 return CurrentUnpackIds{}.at(i_pre_pack).push_back(i_new_pack);
43 },
44 number<CurrentUnpackIds::size() * pack_len>{});
45
// Recurse on the remaining dimensions with the extended packs.
46 static_uford_impl<decltype(RemainLengths::pop_front()),
47 decltype(RamainUnpacks::pop_front()),
48 Orders>{}(f, new_pack);
49 });
50 }
51};
52
// Terminal case of static_uford_impl: every dimension has been consumed, so each
// element of PackedId is a complete multi-index in the reordered (loop) dimension order.
// NOTE(review): listing line 54 (presumably the specialization header for empty
// RemainLengths / RamainUnpacks) is not shown on this page.
53template <class Orders>
55{
// f: user callback; PackedId: tuple of complete packs (only its type is used).
56 template <class F, class PackedId>
57 CK_TILE_HOST_DEVICE constexpr void operator()(F f, PackedId) const
58 {
// Map each pack back to the caller's original dimension order.
59 constexpr auto origin_packs = transform_tuples(
60 [](auto pack_) { return decltype(pack_)::reorder_old_to_new(Orders{}); }, PackedId{});
// Invoke f once, passing each pack as a separate argument.
61 unpack(f, origin_packs);
62 }
63};
64
// One-shot counterpart of static_uford_impl: instead of looping, it decodes a single
// linear access index (current_acc) into one step of the nd loop, one dimension at a time.
// NOTE(review): listing lines 66 (struct header, presumably `struct static_uford_one_shot_impl`),
// 72 and 74 (the initializers of r_lens_stride / r_upks_stride) are not shown on this page.
65template <class RemainLengths, class RamainUnpacks, class Orders>
67{
// f: user callback; CurrentUnpackIds: tuple of partial packs built so far (type only);
// current_acc: the part of the access index still to be decoded over the remaining dimensions.
68 template <class F, class CurrentUnpackIds, index_t current_acc>
69 CK_TILE_HOST_DEVICE constexpr void operator()(F f, CurrentUnpackIds, number<current_acc>) const
70 {
// Initializers are not shown here; presumably reverse exclusive products (strides) of
// RemainLengths and RamainUnpacks respectively — TODO confirm against the header.
71 constexpr auto r_lens_stride =
73 constexpr auto r_upks_stride =
75
// Under that assumption, current_stride is the number of accesses covered by one
// step of the front dimension.
76 constexpr index_t current_stride = r_lens_stride.front() / r_upks_stride.front();
77 constexpr index_t pack_len = RamainUnpacks::front();
// First index of the front dimension that belongs to this access.
78 constexpr index_t current_idx = (current_acc / current_stride) * pack_len;
79
// Same pack expansion as static_uford_impl, for the single step starting at current_idx.
80 constexpr auto new_pack = generate_tuple(
81 [&](auto idx_) {
82 constexpr auto i_new_pack = number<current_idx + idx_ % pack_len>{};
83 constexpr auto i_pre_pack = number<idx_ / pack_len>{};
84 return CurrentUnpackIds{}.at(i_pre_pack).push_back(i_new_pack);
85 },
86 number<CurrentUnpackIds::size() * pack_len>{});
87
// Decode the remainder of the access index over the remaining dimensions.
88 static_uford_one_shot_impl<decltype(RemainLengths::pop_front()),
89 decltype(RamainUnpacks::pop_front()),
90 Orders>{}(f, new_pack, number<current_acc % current_stride>{});
91 }
92};
93
// Terminal case of static_uford_one_shot_impl: the access index has been fully decoded,
// so each element of PackedId is a complete multi-index in the reordered (loop) order.
// current_acc is not used here.
// NOTE(review): listing line 95 (presumably the specialization header for empty
// RemainLengths / RamainUnpacks) is not shown on this page.
94template <class Orders>
96{
97 template <class F, class PackedId, index_t current_acc>
98 CK_TILE_HOST_DEVICE constexpr void operator()(F f, PackedId, number<current_acc>) const
99 {
// Map each pack back to the caller's original dimension order.
100 constexpr auto origin_packs = transform_tuples(
101 [](auto pack_) { return decltype(pack_)::reorder_old_to_new(Orders{}); }, PackedId{});
// Invoke f once, passing each pack as a separate argument.
102 unpack(f, origin_packs);
103 }
104};
105
106} // namespace detail
107
108// TODO: we may unify static_ford/static_uford in the future
109//
110// loop over an nd space (sequence), handing the callback several indices (packs) per step
111// the function passed in must take exactly num_packs arguments (one sequence per pack)
112//
113// e.g.
114// Lengths=seq<2, 3, 4>, Unpacks=seq<1, 1, 2>
115// static_uford<Lengths, Unpacks>{}([&](auto i_0, auto i_1){}); // requires 2 args (packs)
116//
117// loop #0, i_0=seq<0, 0, 0>, i_1=seq<0, 0, 1>
118// loop #1, i_0=seq<0, 0, 2>, i_1=seq<0, 0, 3>
119// loop #2, i_0=seq<0, 1, 0>, i_1=seq<0, 1, 1>
120// loop #3, i_0=seq<0, 1, 2>, i_1=seq<0, 1, 3>
121// loop #4, i_0=seq<0, 2, 0>, i_1=seq<0, 2, 1>
122// loop #5, i_0=seq<0, 2, 2>, i_1=seq<0, 2, 3>
123// loop #6, i_0=seq<1, 0, 0>, i_1=seq<1, 0, 1>
124// ...
//
// Template parameters:
//   Lengths - per-dimension loop lengths.
//   Unpacks - per-dimension pack lengths (default: all 1); each must divide the
//             matching Lengths entry (checked in the constructor).
//   Orders  - dimension permutation (default: identity 0..N-1); presumably selects
//             the loop nesting order. Packs passed to f are always mapped back to
//             the original dimension order of Lengths.
// NOTE(review): listing lines 128 (struct header), 132 (constructor signature),
// 141 (get_num_of_access signature), 161 (one-shot operator() signature) and
// 169 (its call arguments) are not shown on this page.
125template <class Lengths,
126 class Unpacks = typename uniform_sequence_gen<Lengths::size(), 1>::type,
127 class Orders = typename arithmetic_sequence_gen<0, Lengths::size(), 1>::type>
129{
// Number of packs passed to f on each call: the product of all Unpacks entries.
130 static constexpr index_t num_packs = reduce_on_sequence(Unpacks{}, multiplies{}, number<1>{});
131
// Constructor: compile-time validation of the template arguments only.
133 {
134 static_assert(Lengths::size() > 0, "wrong! Lengths is empty");
135 static_assert(Lengths::size() == Unpacks::size(), "wrong! inconsistent size");
136 static_assert(Lengths::size() == Orders::size(), "wrong! inconsistent size");
137 static_for<0, Lengths::size(), 1>{}(
138 [&](auto i) { static_assert(Lengths{}.at(i) % Unpacks{}.at(i) == 0); });
139 }
140
// get_num_of_access(): total number of calls to f made by operator()(F),
// i.e. the product over dimensions of Lengths[i] / Unpacks[i].
142 {
143 using L_ = decltype(Lengths{} / Unpacks{});
144
145 return reduce_on_sequence(L_{}, multiplies{}, number<1>{});
146 }
147
148 // F signature: F(sequence<...> multi_id...), called with num_packs arguments
149 // each multi_id is expressed in the original (unordered) dimension order of Lengths
150 template <class F>
151 CK_TILE_HOST_DEVICE constexpr void operator()(F f) const
152 {
// Loop over the reordered lengths/unpacks, starting from a single empty pack.
153 constexpr auto ordered_lengths = Lengths::reorder_new_to_old(Orders{});
154 constexpr auto ordered_unpacks = Unpacks::reorder_new_to_old(Orders{});
155 detail::static_uford_impl<decltype(ordered_lengths), decltype(ordered_unpacks), Orders>{}(
156 f, make_tuple(sequence<>{}));
157 }
158
159 // this overload calls f for a single access (i_access), so callers can issue the accesses one at a time
// i_access must be below get_num_of_access(); the access order presumably matches
// operator()(F) — TODO confirm.
160 template <class F, index_t i_access>
162 {
163 static_assert(i_access < get_num_of_access());
164 constexpr auto ordered_lengths = Lengths::reorder_new_to_old(Orders{});
165 constexpr auto ordered_unpacks = Unpacks::reorder_new_to_old(Orders{});
166 detail::static_uford_one_shot_impl<decltype(ordered_lengths),
167 decltype(ordered_unpacks),
168 Orders>{}(
170 }
171};
172
173} // namespace ck_tile
#define CK_TILE_HOST_DEVICE
Definition config.hpp:42
Definition arch.hpp:385
Definition tile/core/algorithm/cluster_descriptor.hpp:13
CK_TILE_HOST_DEVICE constexpr auto reverse_exclusive_scan_sequence(Seq, Reduce, number< Init >)
Definition tile/core/container/sequence.hpp:863
CK_TILE_HOST_DEVICE constexpr auto transform_tuples(F f, const X &x)
Definition tile/core/container/tuple.hpp:505
constant< v > number
Definition tile/core/numeric/integral_constant.hpp:37
CK_TILE_HOST_DEVICE constexpr auto generate_tuple(F &&f, number< N >)
Definition tile/core/container/tuple.hpp:429
CK_TILE_HOST_DEVICE constexpr auto unpack(F &&f, X &&x)
Definition tile/core/utility/functional.hpp:200
int32_t index_t
Definition integer.hpp:9
CK_TILE_HOST_DEVICE constexpr index_t reduce_on_sequence(Seq, Reduce f, number< Init >)
Definition tile/core/container/sequence.hpp:982
CK_TILE_HOST_DEVICE constexpr auto make_tuple(Xs &&... xs)
Definition tile/core/container/tuple.hpp:360
CK_TILE_HOST_DEVICE constexpr void operator()(F f, PackedId) const
Definition functional_with_tuple.hpp:57
Definition functional_with_tuple.hpp:26
CK_TILE_HOST_DEVICE constexpr static_uford_impl()
Definition functional_with_tuple.hpp:27
CK_TILE_HOST_DEVICE constexpr void operator()(F f, CurrentUnpackIds) const
Definition functional_with_tuple.hpp:34
CK_TILE_HOST_DEVICE constexpr void operator()(F f, PackedId, number< current_acc >) const
Definition functional_with_tuple.hpp:98
Definition functional_with_tuple.hpp:67
CK_TILE_HOST_DEVICE constexpr void operator()(F f, CurrentUnpackIds, number< current_acc >) const
Definition functional_with_tuple.hpp:69
Definition tile/core/numeric/math.hpp:98
Definition tile/core/container/sequence.hpp:49
Definition tile/core/utility/functional.hpp:43
CK_TILE_HOST_DEVICE constexpr static_uford()
Definition functional_with_tuple.hpp:132
CK_TILE_HOST_DEVICE constexpr void operator()(F f) const
Definition functional_with_tuple.hpp:151
static constexpr index_t num_packs
Definition functional_with_tuple.hpp:130
CK_TILE_HOST_DEVICE constexpr void operator()(F f, number< i_access >) const
Definition functional_with_tuple.hpp:161
static CK_TILE_HOST_DEVICE constexpr index_t get_num_of_access()
Definition functional_with_tuple.hpp:141