25 const InTensor& in_tensor)
29 static_assert(std::is_same_v<typename InTensor::DataType, typename OutTensor::DataType>,
30 "Data type for InTensor and OutTensor must be the same!");
32 using DataType =
typename InTensor::DataType;
34 constexpr auto y_in_desc = InTensor::get_tile_distribution().get_ys_to_d_descriptor();
35 constexpr auto y_out_desc = OutTensor::get_tile_distribution().get_ys_to_d_descriptor();
40 constexpr auto get_rh_minor_to_y = [](
auto dstr_tensor) {
41 using DstrEncode =
typename decltype(dstr_tensor.get_tile_distribution())::DstrEncode;
46 constexpr index_t rh_minor = DstrEncode::ys_to_rhs_minor_[i];
48 rh_minor_to_y_(rh_minor) = i;
51 return rh_minor_to_y_;
56 constexpr auto rh_minor_to_y_in = get_rh_minor_to_y(InTensor{});
57 constexpr auto rh_minor_to_y_out = get_rh_minor_to_y(OutTensor{});
60 constexpr auto y_dim_out_to_in = [&] {
63 for(
const auto& [rh_minor, y_out] : rh_minor_to_y_out)
65 y_dim_out_to_in_(y_out) = rh_minor_to_y_in[rh_minor];
68 return y_dim_out_to_in_;
71 constexpr index_t NDimY = InTensor::get_tile_distribution().get_num_of_dimension_y();
72 constexpr auto y_lengths =
to_sequence(y_in_desc.get_lengths());
75 constexpr index_t y_dim_vec_in = NDimY - 1;
76 constexpr index_t y_dim_vec_out = y_dim_out_to_in[NDimY - 1];
79 constexpr index_t vec_length_in = y_lengths[y_dim_vec_in];
80 constexpr index_t vec_length_out = y_lengths[y_dim_vec_out];
83 constexpr index_t num_vec_in = vec_length_out;
84 constexpr index_t num_vec_out = vec_length_in;
89 if constexpr(vec_length_in == 1)
92 return (i == y_dim_vec_in || i == y_dim_vec_out) ? y_lengths[i] : 1;
96 constexpr auto scalars_per_access =
TO_SEQUENCE(scalars_per_access_arr, NDimY);
100 decltype(scalars_per_access)>;
102 constexpr index_t num_access = SFC_Y::get_num_of_access();
104 static_assert(num_access > 0,
"wrong! num_access should be larger than 0");
106 if constexpr(num_vec_in == 1 || num_vec_out == 1)
111 constexpr auto idx_y_start = SFC_Y::get_index(iAccess);
112 constexpr auto idx_y_in =
114 constexpr index_t in_offset = y_in_desc.calculate_offset(idx_y_in);
115 static_assert(in_offset % vec_length_in == 0);
116 constexpr auto idx_y_out_tmp =
118 constexpr auto idx_y_out =
120 constexpr index_t out_offset = y_out_desc.calculate_offset(idx_y_out);
121 if constexpr(vec_length_in == 1)
130 out_tensor.get_thread_buffer().template get_as<Vec>(
132 in_tensor.get_thread_buffer().template get_as<Vec>(
149 constexpr auto idx_y_start = SFC_Y::get_index(iAccess);
155 return ii == y_dim_vec_out ? idx_y_start[ii] + i : idx_y_start[ii];
159 constexpr index_t in_offset = y_in_desc.calculate_offset(idx_y_in);
160 static_assert(in_offset % vec_length_in == 0);
162 in_vectors(i).template get_as<InVec>()(I0) =
163 in_tensor.get_thread_buffer()
164 .template get_as<InVec>()[
number<in_offset / vec_length_in>{}];
174 return ii == y_dim_vec_in ? idx_y_start[ii] + i : idx_y_start[ii];
178 constexpr auto idx_y_out =
181 constexpr index_t out_offset = y_out_desc.calculate_offset(idx_y_out);
182 static_assert(out_offset % vec_length_out == 0);
184 out_tensor.get_thread_buffer().template set_as<OutVec>(
186 out_vectors[i].template get_as<OutVec>()[I0]);
197 using InDataType =
typename InTensor::DataType;
198 using OutDataType =
typename OutTensor::DataType;
200 using InTileDistr =
typename InTensor::StaticTileDistribution;
201 using OutTileDistr =
typename OutTensor::StaticTileDistribution;
203 using InDstrEncode =
typename InTileDistr::DstrEncode;
204 using OutDstrEncode =
typename OutTileDistr::DstrEncode;
206 using InThreadTensorDesc =
typename InTensor::ThreadTensorDesc;
207 using OutThreadTensorDesc =
typename OutTensor::ThreadTensorDesc;
210 constexpr auto in_thread_desc_lengths = InThreadTensorDesc{}.get_lengths();
211 constexpr auto out_thread_desc_lengths = OutThreadTensorDesc{}.get_lengths();
214 const auto in_tmp = [&]() {
215 if constexpr(std::is_same_v<OutDataType, InDataType>)
227 if constexpr(InDstrEncode::rs_lengths_ == OutDstrEncode::rs_lengths_ &&
228 InDstrEncode::hs_lengthss_ ==
tuple_reverse(OutDstrEncode::hs_lengthss_) &&
229 InDstrEncode::NDimY == OutDstrEncode::NDimY && InDstrEncode::NDimY == 2 &&
230 in_thread_desc_lengths ==
tuple_reverse(out_thread_desc_lengths))
239 static_assert(
false,
"Provided tensors could not be transposed!");