static_distributed_tensor.hpp Source File

static_distributed_tensor.hpp Source File

Composable Kernel: static_distributed_tensor.hpp Source File
static_distributed_tensor.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
16
17namespace ck_tile {
18
19template <typename DataType_, typename StaticTileDistribution_>
21{
24
25 static_assert(StaticTileDistribution::is_static(),
26 "wrong! StaticTileDistribution should be known at compile tile");
27
29 remove_cvref_t<decltype(StaticTileDistribution{}.get_ys_to_d_descriptor())>;
30 static constexpr index_t PackedSize =
32
33 static constexpr index_t kThreadElementSpaceSize = ThreadTensorDesc{}.get_element_space_size();
34 static_assert(0 < kThreadElementSpaceSize, "Make sure tile distribution is valid");
35
37 {
38 return StaticTileDistribution::get_num_of_dimension_x();
39 }
40
41 CK_TILE_HOST_DEVICE static constexpr auto get_lengths()
42 {
43 return StaticTileDistribution::get_lengths();
44 }
45
47 {
49 }
50
52 {
53 return StaticTileDistribution::get_distributed_spans();
54 }
55
56 CK_TILE_HOST_DEVICE void initialize(const DataType& x) { thread_buf_.initialize(x); }
57
58 CK_TILE_HOST_DEVICE constexpr const auto& get_thread_buffer() const { return thread_buf_; }
59
61
66
67 template <index_t... YSliceOrigins, index_t... YSliceLengths>
70 {
71 static_assert(sizeof...(YSliceOrigins) == StaticTileDistribution::NDimY &&
72 sizeof...(YSliceLengths) == StaticTileDistribution::NDimY,
73 "wrong!");
74
75 constexpr auto sliced_thread_tensor_desc =
77
78 thread_buffer<DataType, sliced_thread_tensor_desc.get_element_space_size()>
79 sliced_thread_data;
80
81 static_ford<sequence<YSliceLengths...>>{}([&](auto idx) {
82 constexpr auto idx_ys = idx + sequence<YSliceOrigins...>{};
83
84 sliced_thread_data(
85 number<sliced_thread_tensor_desc.calculate_offset(idx) / PackedSize>{}) =
86 thread_buf_[number<ThreadTensorDesc{}.calculate_offset(idx_ys) / PackedSize>{}];
87 });
88
89 return sliced_thread_data;
90 }
91
92 template <index_t... YSliceOrigins, index_t... YSliceLengths, typename SlicedThreadData>
95 const SlicedThreadData& sliced_thread_data)
96 {
97 static_assert(sizeof...(YSliceOrigins) == StaticTileDistribution::NDimY &&
98 sizeof...(YSliceLengths) == StaticTileDistribution::NDimY,
99 "wrong!");
100
101 constexpr auto sliced_thread_tensor_desc =
103
104 static_ford<sequence<YSliceLengths...>>{}([&](auto idx) {
105 constexpr auto idx_ys = idx + sequence<YSliceOrigins...>{};
106
107 thread_buf_(number<ThreadTensorDesc{}.calculate_offset(idx_ys) / PackedSize>{}) =
108 sliced_thread_data[number<sliced_thread_tensor_desc.calculate_offset(idx) /
109 PackedSize>{}];
110 });
111 }
112
113 template <typename TileDistributedIndices>
114 CK_TILE_HOST_DEVICE constexpr const DataType& operator[](TileDistributedIndices) const
115 {
117 "wrong! Tile Distributed Indices should be static");
118
119 constexpr auto y_idx = get_tile_distribution().get_y_indices_from_distributed_indices(
120 TileDistributedIndices{});
121
122 return thread_buf_[number<ThreadTensorDesc{}.calculate_offset(y_idx) / PackedSize>{}];
123 }
124
125 template <typename TileDistributedIndices>
126 CK_TILE_HOST_DEVICE constexpr DataType& operator()(TileDistributedIndices)
127 {
129 "wrong! Tile Distributed Indices should be static");
130
131 constexpr auto y_idx = get_tile_distribution().get_y_indices_from_distributed_indices(
132 TileDistributedIndices{});
133
134 return thread_buf_(number<ThreadTensorDesc{}.calculate_offset(y_idx) / PackedSize>{});
135 }
136
137 //
139};
140
141template <typename DataType, typename StaticTileDistribution>
147
148template <typename DataType, typename StaticTileDistribution, typename ThreadBuffer>
149CK_TILE_HOST_DEVICE constexpr auto make_static_distributed_tensor(const StaticTileDistribution&,
150 ThreadBuffer&& thread_buffer_)
151{
154}
155
156// get X indices from tuple of tile_distributed_index<>
157template <typename StaticTileDistribution, typename DistributedIndices>
158CK_TILE_HOST_DEVICE constexpr auto
160 DistributedIndices distributed_indices)
161{
162 const auto partition_index = detail::get_partition_index(tile_distribution);
163 constexpr auto y_indices =
165
166 const auto x_coord = make_tensor_adaptor_coordinate(
168 container_concat(partition_index, to_array<ck_tile::index_t, y_indices.size()>(y_indices)));
169
170 return x_coord.get_bottom_index();
171}
172
173template <typename DataType, typename StaticTileDistribution, typename XIndicesPredicate>
176 DataType value,
177 XIndicesPredicate predicate)
178{
179 constexpr auto out_spans =
181 sweep_tile_span(out_spans[number<0>{}], [&](auto idx0) {
182 sweep_tile_span(out_spans[number<1>{}], [&](auto idx1) {
183 constexpr auto distributed_indices = make_tuple(idx0, idx1);
184 const auto x_indices = get_x_indices_from_distributed_indices(StaticTileDistribution{},
185 distributed_indices);
186
187 if(predicate(x_indices))
188 {
189 out_tensor(distributed_indices) = value;
190 }
191 });
192 });
193}
194
195// this function used inside span loop over
196template <typename YLengths, index_t XUnpacks>
198{
199 constexpr auto y_size = reduce_on_sequence(YLengths{}, multiplies{}, number<1>{});
200 constexpr auto y_packs = number<XUnpacks>{};
201 static_assert(y_size % y_packs == 0);
202 constexpr auto y_slice_size = y_size / y_packs;
203
204 constexpr auto slice_info = slice_sequence(YLengths{}, number<y_slice_size>{});
205 constexpr auto unpacks = slice_info[number<1>{}];
206 return unpacks;
207}
208
209namespace detail {
210
211// check if 2 static_distributed_tensor has same data type and size of element
212// but only difference in distribution
213template <typename X, typename Y>
215{
216 static constexpr bool value = false;
217};
218
219template <typename TypeX, typename DistX, typename TypeY, typename DistY>
221 static_distributed_tensor<TypeY, DistY>>
222{
225 static constexpr bool value = std::is_same_v<typename Tx::DataType, typename Ty::DataType> &&
227};
228
229template <typename X, typename Y>
230inline constexpr bool is_similiar_distributed_tensor_v =
232
233} // namespace detail
234
235} // namespace ck_tile
#define CK_TILE_HOST_DEVICE
Definition config.hpp:42
constexpr bool is_similiar_distributed_tensor_v
Definition static_distributed_tensor.hpp:230
CK_TILE_HOST_DEVICE auto get_partition_index(Distribution)
Definition tile_distribution.hpp:22
Definition tile/core/algorithm/cluster_descriptor.hpp:13
CK_TILE_HOST_DEVICE constexpr auto make_naive_tensor_descriptor_packed(const tuple< Lengths... > &lengths, number< GuaranteedLastDimensionVectorLength >=number<-1 >{})
Definition tile/core/tensor/tensor_descriptor.hpp:371
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition type_traits.hpp:21
CK_TILE_HOST_DEVICE constexpr auto get_x_indices_from_distributed_indices(StaticTileDistribution tile_distribution, DistributedIndices distributed_indices)
Definition static_distributed_tensor.hpp:159
CK_TILE_HOST_DEVICE constexpr auto container_concat(const X &x, const Ys &... ys)
Definition tile/core/container/container_helper.hpp:363
CK_TILE_HOST_DEVICE constexpr auto make_static_distributed_tensor(const StaticTileDistribution &)
Definition static_distributed_tensor.hpp:142
constant< v > number
Definition tile/core/numeric/integral_constant.hpp:37
CK_TILE_HOST_DEVICE constexpr auto to_array(const std::vector< X > &x)
Definition tile/core/container/array.hpp:286
CK_TILE_HOST_DEVICE constexpr auto make_tensor_adaptor_coordinate(const Adaptor &adaptor, const TopIndex &idx_top)
Definition tensor_adaptor_coordinate.hpp:55
CK_TILE_DEVICE void sweep_tile_span(TileDistributedSpan_, const F &f)
Definition sweep_tile.hpp:20
constexpr bool is_static_v
Definition type_traits.hpp:90
CK_TILE_HOST_DEVICE constexpr auto get_y_unpacks_from_x_unpacks(YLengths, number< XUnpacks >)
Definition static_distributed_tensor.hpp:197
CK_TILE_HOST_DEVICE void set_tile_if(static_distributed_tensor< DataType, StaticTileDistribution > &out_tensor, DataType value, XIndicesPredicate predicate)
Definition static_distributed_tensor.hpp:175
constexpr auto slice_sequence(Seq, number< SliceSize >, Mask=typename uniform_sequence_gen< Seq::size(), 1 >::type{})
Definition tile/core/container/sequence.hpp:1249
int32_t index_t
Definition integer.hpp:9
CK_TILE_HOST_DEVICE constexpr index_t reduce_on_sequence(Seq, Reduce f, number< Init >)
Definition tile/core/container/sequence.hpp:982
CK_TILE_HOST_DEVICE constexpr auto make_tuple(Xs &&... xs)
Definition tile/core/container/tuple.hpp:360
const GenericPointer< typename T::ValueType > T2 value
Definition pointer.h:1697
Definition static_distributed_tensor.hpp:215
static constexpr bool value
Definition static_distributed_tensor.hpp:216
Definition tile/core/numeric/math.hpp:98
Definition tile/core/numeric/numeric.hpp:81
Definition tile/core/container/sequence.hpp:49
Definition static_distributed_tensor.hpp:21
CK_TILE_HOST_DEVICE constexpr const DataType & operator[](TileDistributedIndices) const
Definition static_distributed_tensor.hpp:114
static constexpr index_t kThreadElementSpaceSize
Definition static_distributed_tensor.hpp:33
static CK_TILE_HOST_DEVICE constexpr auto get_lengths()
Definition static_distributed_tensor.hpp:41
static CK_TILE_HOST_DEVICE constexpr index_t get_thread_buffer_size()
Definition static_distributed_tensor.hpp:62
remove_cvref_t< decltype(StaticTileDistribution{}.get_ys_to_d_descriptor())> ThreadTensorDesc
Definition static_distributed_tensor.hpp:28
remove_cvref_t< StaticTileDistribution_ > StaticTileDistribution
Definition static_distributed_tensor.hpp:23
CK_TILE_HOST_DEVICE constexpr auto & get_thread_buffer()
Definition static_distributed_tensor.hpp:60
static CK_TILE_HOST_DEVICE constexpr auto get_distributed_spans()
Definition static_distributed_tensor.hpp:51
CK_TILE_HOST_DEVICE constexpr DataType & operator()(TileDistributedIndices)
Definition static_distributed_tensor.hpp:126
static CK_TILE_HOST_DEVICE constexpr auto get_num_of_dimension()
Definition static_distributed_tensor.hpp:36
CK_TILE_HOST_DEVICE void set_y_sliced_thread_data(sequence< YSliceOrigins... >, sequence< YSliceLengths... >, const SlicedThreadData &sliced_thread_data)
Definition static_distributed_tensor.hpp:93
CK_TILE_HOST_DEVICE auto get_y_sliced_thread_data(sequence< YSliceOrigins... >, sequence< YSliceLengths... >) const
Definition static_distributed_tensor.hpp:68
thread_buffer< DataType, get_thread_buffer_size()> thread_buf_
Definition static_distributed_tensor.hpp:138
static CK_TILE_HOST_DEVICE constexpr auto get_tile_distribution()
Definition static_distributed_tensor.hpp:46
CK_TILE_HOST_DEVICE constexpr const auto & get_thread_buffer() const
Definition static_distributed_tensor.hpp:58
remove_cvref_t< DataType_ > DataType
Definition static_distributed_tensor.hpp:22
static constexpr index_t PackedSize
Definition static_distributed_tensor.hpp:30
CK_TILE_HOST_DEVICE void initialize(const DataType &x)
Definition static_distributed_tensor.hpp:56
Definition tile/core/utility/functional.hpp:141
Definition tile/core/utility/debug.hpp:67
Definition tile_distribution.hpp:72
CK_TILE_HOST_DEVICE constexpr const auto & get_ps_ys_to_xs_adaptor() const
Definition tile_distribution.hpp:126
static CK_TILE_HOST_DEVICE constexpr auto get_y_indices_from_distributed_indices(DistributedIndices)
Definition tile_distribution.hpp:205