tile_window.hpp Source File

tile_window.hpp Source File

Composable Kernel: tile_window.hpp Source File
tile_window.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
19
20namespace ck_tile {
21
33template <typename BottomTensorView_,
34 typename WindowLengths_,
35 typename StaticTileDistribution_,
36 index_t NumCoord>
39 tile_window_with_static_distribution<BottomTensorView_,
40 WindowLengths_,
41 StaticTileDistribution_,
42 NumCoord>,
43 BottomTensorView_,
44 WindowLengths_,
45 StaticTileDistribution_>
46{
49 WindowLengths_,
50 StaticTileDistribution_,
51 NumCoord>,
52 BottomTensorView_,
53 WindowLengths_,
54 StaticTileDistribution_>;
55
56 static constexpr auto I0 = number<0>{};
57 static constexpr auto I1 = number<1>{};
58 static_assert(NumCoord == 1);
59
60 static_assert(Base::Traits::NumAccess % NumCoord == 0,
61 "wrong! # of access is not divisible by NumCoord");
62 static constexpr index_t NumAccessPerCoord = Base::Traits::NumAccess / NumCoord;
63
65
67 const typename Base::BottomTensorView& bottom_tensor_view,
68 const typename Base::WindowLengths& window_lengths,
69 const typename Base::BottomTensorIndex& window_origin,
70 const typename Base::TileDstr& tile_distribution)
72 {
73
74 this->window_origin_ = window_origin;
75 this->window_lengths_ = window_lengths;
76 this->bottom_tensor_view_ = bottom_tensor_view;
78 const auto window_adaptor_thread_coord_tmp = make_tensor_adaptor_coordinate(
82
83 typename Base::BottomTensorIndex bottom_tensor_thread_origin_idx_tmp =
84 window_origin + window_adaptor_thread_coord_tmp.get_bottom_index();
85
86 const auto bottom_tensor_thread_coord_tmp = make_tensor_coordinate(
87 bottom_tensor_view.get_tensor_descriptor(), bottom_tensor_thread_origin_idx_tmp);
88
89 // pre-compute NumCoord (WindowAdaptorCoord, BottomTensorCoord) bundles to speed up
90 // future load/store() calls (might allocate more registers)
91 using Traits = typename Base::Traits;
92 using SFC_Ys = typename Traits::SFC_Ys;
93
94 static_for<0, NumCoord, 1>{}([&](auto iCoord) {
95 auto window_adaptor_thread_coord = window_adaptor_thread_coord_tmp;
96 auto bottom_tensor_thread_coord = bottom_tensor_thread_coord_tmp;
97
98 constexpr auto idx_diff_ys =
99 SFC_Ys::get_step_between(number<0>{}, number<iCoord * NumAccessPerCoord>{});
100
101 constexpr auto idx_diff_ps_ys = container_concat(
102 generate_tuple([&](auto) { return number<0>{}; }, number<Base::NDimP>{}),
103 idx_diff_ys);
104
106 window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
107
108 pre_computed_coords_(iCoord) =
109 make_tuple(window_adaptor_thread_coord, bottom_tensor_thread_coord);
110 });
111 }
112
113 template <index_t i_access_unsupport_ = -1, bool oob_conditional_check = true>
116 {
117 constexpr auto tile_dstr = typename Base::TileDstr{};
120 return dst_tensor;
121 }
122
133 template <typename TileWindow_,
134 typename ElementWise_,
135 index_t i_access_unsupport_ = -1,
136 bool oob_conditional_check = true>
137 CK_TILE_DEVICE auto load(const TileWindow_& tile_window,
138 ElementWise_ elementwise,
141 {
142 constexpr auto tile_dstr = typename Base::TileDstr{};
144 load(dst_tensor,
145 tile_window,
146 elementwise,
149 return dst_tensor;
150 }
151
152 template <typename DistributedTensor,
153 typename TileWindow_,
154 typename ElementWise_,
155 index_t i_access_unsupport_ = -1,
156 bool oob_conditional_check = true>
157 CK_TILE_DEVICE auto load(DistributedTensor& dst_tensor,
158 const TileWindow_& tile_window,
159 ElementWise_ elementwise,
162 {
163
164 using Traits = typename Base::Traits;
165 using vector_t = typename Traits::vector_t;
166 using SFC_Ys = typename Traits::SFC_Ys;
167
168 constexpr auto tile_dstr = typename Base::TileDstr{};
169 constexpr auto sizeOfTuple = TileWindow_::size();
170 // loop over thread tensor space [y0, y1, ...]
171 static_for<0, NumCoord, 1>{}([&](auto iCoord) {
173 auto window_adaptor_thread_coord =
174 tile_window[number<0>{}].pre_computed_coords_[iCoord][I0];
175 auto bottom_tensor_thread_coord =
176 tile_window[number<0>{}].pre_computed_coords_[iCoord][I1];
177
178 static_for<0, NumAccessPerCoord, 1>{}([&](auto iCoordAccess) {
179 constexpr auto iAccess = number<iCoord * NumAccessPerCoord + iCoordAccess>{};
180
181 // data index [y0, y1, ...]
182 constexpr auto idx_ys_start = SFC_Ys::get_index(iAccess);
183
184 // read from bottom tensor
185 const auto idx_vec_value = generate_tuple(
186 [&](auto jj) {
187 return tile_window[number<jj>{}]
188 .get_bottom_tensor_view()
189 .template get_vectorized_elements<vector_t>(
190 bottom_tensor_thread_coord,
191 0,
193 },
195
196 // write into distributed tensor
197 static_for<0, Traits::ScalarPerVector, Traits::PackedSize>{}([&](auto j) {
198 constexpr auto idx_ys = generate_tuple(
199 [&](auto jj) {
200 return jj == Traits::VectorDimY ? (idx_ys_start[jj] + j)
201 : idx_ys_start[jj];
202 },
204
205 constexpr index_t d =
206 tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys) /
207 Traits::PackedSize;
208
210 [&](auto&&... t) {
211 elementwise(dst_tensor.get_thread_buffer().template at<d>(),
212 t.template get_as<
213 typename Base::DataType>()[j / Traits::PackedSize]...);
214 },
215 idx_vec_value);
216 });
217 // move thread coordinate
218 if constexpr(iCoordAccess != (NumAccessPerCoord - 1))
219 {
220 constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess);
221
222 constexpr auto idx_diff_ps_ys = container_concat(
223 generate_tuple([&](auto) { return number<0>{}; }, number<Base::NDimP>{}),
224 idx_diff_ys);
225
227 window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
228 }
229 });
230 });
231 }
232
233 template <typename DistributedTensor,
234 index_t i_access_unsupport_ = -1,
235 bool oob_conditional_check = true>
236 CK_TILE_DEVICE auto load(DistributedTensor& dst_tensor,
239 {
240 using Traits = typename Base::Traits;
241 using vector_t = typename Traits::vector_t;
242 using SFC_Ys = typename Traits::SFC_Ys;
243
244 constexpr auto tile_dstr = typename Base::TileDstr{};
245
246 // loop over thread tensor space [y0, y1, ...]
247 static_for<0, NumCoord, 1>{}([&](auto iCoord) {
249 auto window_adaptor_thread_coord = pre_computed_coords_[iCoord][I0];
250 auto bottom_tensor_thread_coord = pre_computed_coords_[iCoord][I1];
251
252 static_for<0, NumAccessPerCoord, 1>{}([&](auto iCoordAccess) {
253 constexpr auto iAccess = number<iCoord * NumAccessPerCoord + iCoordAccess>{};
254
255 // data index [y0, y1, ...]
256 constexpr auto idx_ys_start = SFC_Ys::get_index(iAccess);
257
258 // read from bottom tensor
259 const vector_t vec_value =
260 this->get_bottom_tensor_view().template get_vectorized_elements<vector_t>(
261 bottom_tensor_thread_coord, 0, bool_constant<oob_conditional_check>{});
262 // write into distributed tensor
263 static_for<0, Traits::ScalarPerVector, Traits::PackedSize>{}([&](auto j) {
264 constexpr auto idx_ys = generate_tuple(
265 [&](auto jj) {
266 return jj == Traits::VectorDimY ? (idx_ys_start[jj] + j)
267 : idx_ys_start[jj];
268 },
270
271 constexpr index_t d =
272 tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys) /
273 Traits::PackedSize;
274
275 dst_tensor.get_thread_buffer().template at<d>() =
276 vec_value
277 .template get_as<typename Base::DataType>()[j / Traits::PackedSize];
278 });
279 // move thread coordinate
280 if constexpr(iCoordAccess != (NumAccessPerCoord - 1))
281 {
282 constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess);
283
284 constexpr auto idx_diff_ps_ys = container_concat(
285 generate_tuple([&](auto) { return number<0>{}; }, number<Base::NDimP>{}),
286 idx_diff_ys);
287
289 window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
290 }
291 });
292 });
293 }
294
295 template <typename DstTile,
296 index_t i_access_unsupport_ = -1,
297 bool oob_conditional_check = true,
298 bool pre_nop = false>
299 CK_TILE_DEVICE void load_raw(DstTile& dst_tensor,
302 bool_constant<pre_nop> = {}) const
303 {
304 using Traits = typename Base::Traits;
305 using vector_t = typename Traits::vector_t;
306 using SFC_Ys = typename Traits::SFC_Ys;
307 static constexpr index_t YElementSize =
308 typename Base::TileDstr{}.get_ys_to_d_descriptor().get_element_space_size();
309 static_assert(YElementSize % (Traits::PackedSize * Traits::ScalarPerVector) == 0);
310 using vectorized_tbuf =
311 array<vector_t, YElementSize / (Traits::PackedSize * Traits::ScalarPerVector)>;
312
313 constexpr auto tile_dstr = typename Base::TileDstr{};
314
315 auto& dst_vec_tbuf = reinterpret_cast<vectorized_tbuf&>(dst_tensor.get_thread_buffer());
316
317 // loop over thread tensor space [y0, y1, ...]
318 static_for<0, NumCoord, 1>{}([&](auto iCoord) {
320 auto window_adaptor_thread_coord = pre_computed_coords_[iCoord][I0];
321 auto bottom_tensor_thread_coord = pre_computed_coords_[iCoord][I1];
322
323 static_for<0, NumAccessPerCoord, 1>{}([&](auto iCoordAccess) {
324 constexpr auto iAccess = number<iCoord * NumAccessPerCoord + iCoordAccess>{};
325 constexpr auto pre_nop_ = [&]() {
326 if constexpr(pre_nop && iCoord == 0 && iCoordAccess == 0)
327 return bool_constant<true>{};
328 else
329 return bool_constant<false>{};
330 }();
331
332 // data index [y0, y1, ...]
333 constexpr auto idx_ys_start = SFC_Ys::get_index(iAccess);
334 constexpr index_t d =
335 tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys_start) /
336 Traits::PackedSize;
337 static_assert(d % Traits::ScalarPerVector == 0);
338
339 this->get_bottom_tensor_view().template get_vectorized_elements_raw<vector_t>(
340 dst_vec_tbuf.template at<d / Traits::ScalarPerVector>(),
341 bottom_tensor_thread_coord,
342 0 /**/,
344 pre_nop_);
345#if CK_TILE_WORKAROUND_ROCM_6_1_SCRATCH_MEMORY_ISSUE || \
346 CK_TILE_WORKAROUND_ROCM_6_2_SCRATCH_MEMORY_ISSUE
347 asm volatile(
348 ""); // this is starting from rocm-6.2, but same sympton, reuse this flag
349#endif
350 // move thread coordinate
351 if constexpr(iCoordAccess != (NumAccessPerCoord - 1))
352 {
353 constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess);
354
355 constexpr auto idx_diff_ps_ys = container_concat(
356 generate_tuple([&](auto) { return number<0>{}; }, number<Base::NDimP>{}),
357 idx_diff_ys);
358
360 window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
361 }
362 });
363 });
364 }
365
366 // TODO: currently async load only implemented in inline asm
367 template <typename LdsTileWindow_,
368 index_t i_access_unsupport_ = -1,
369 bool oob_conditional_check = true,
370 bool pre_nop = false>
371 CK_TILE_DEVICE auto async_load_raw(LdsTileWindow_&& lds_tile,
374 bool_constant<pre_nop> = {}) const
375 {
376 using LdsTileWindow = remove_cvref_t<LdsTileWindow_>;
377 // using LdsTensorView = typename LdsTileWindow::BottomTensorView;
378 using LdsDataType = typename LdsTileWindow::DataType;
379 // using LdsDescriptor = typename LdsTileWindow::BottomTensorDesc;
380
381 // issues * warps * lanes
382 static_assert(LdsTileWindow::get_num_of_dimension() == 3); // TODO: hard coded
383
384 const index_t size_per_buf =
385 lds_tile.get_bottom_tensor_view().get_tensor_descriptor().calculate_offset(
387 sizeof(LdsDataType);
388
389 const index_t size_per_wave =
390 lds_tile.get_bottom_tensor_view().get_tensor_descriptor().calculate_offset(
392 sizeof(LdsDataType) -
393 size_per_buf;
394
395 const index_t size_per_issue =
396 lds_tile.get_bottom_tensor_view().get_tensor_descriptor().calculate_offset(
398 sizeof(LdsDataType) -
399 size_per_buf;
400
401 // Use VALU so the compiler can optimize redundant/repeated computations
402 const index_t m0_init_value =
403 size_per_buf + size_per_wave * get_warp_id(/*ReturnSgpr=*/bool_constant<false>{});
405 amd_wave_read_first_lane(m0_init_value)); // This should be wave independent
406
407 using Traits = typename Base::Traits;
408
409 using vector_t = typename Traits::vector_t;
410 using SFC_Ys = typename Traits::SFC_Ys;
411
412 LdsDataType* smem = lds_tile.get_bottom_tensor_view().get_buffer_view().p_data_;
413
414 // loop over thread tensor space [y0, y1, ...]
415 static_for<0, NumCoord, 1>{}([&](auto iCoord) {
417 auto window_adaptor_thread_coord = pre_computed_coords_[iCoord][I0];
418 auto bottom_tensor_thread_coord = pre_computed_coords_[iCoord][I1];
419
420 static_for<0, NumAccessPerCoord, 1>{}([&](auto iCoordAccess) {
421 constexpr auto iAccess = number<iCoord * NumAccessPerCoord + iCoordAccess>{};
422 constexpr auto pre_nop_ = [&]() {
423 if constexpr(pre_nop && iCoord == 0 && iCoordAccess == 0)
424 return bool_constant<true>{};
425 else
426 return bool_constant<false>{};
427 }();
428
429 // read from bottom tensor
430 this->get_bottom_tensor_view().template async_get_vectorized_elements_raw<vector_t>(
431 smem, bottom_tensor_thread_coord, 0, pre_nop_);
432
433 // move thread coordinate
434 if constexpr(iCoordAccess != (NumAccessPerCoord - 1))
435 {
436 constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess);
437
438 constexpr auto idx_diff_ps_ys = container_concat(
439 generate_tuple([&](auto) { return number<0>{}; }, number<Base::NDimP>{}),
440 idx_diff_ys);
441
443 window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
444
445 m0_inc_with_memory(size_per_issue);
446 }
447 });
448 });
449 }
450
451 template <typename LdsTileWindow_,
452 index_t i_access_unsupport_ = -1,
453 bool oob_conditional_check = true>
454 CK_TILE_DEVICE auto async_load(LdsTileWindow_&& lds_tile,
457 {
458 using LdsTileWindow = remove_cvref_t<LdsTileWindow_>;
459 using LdsDataType = typename LdsTileWindow::DataType;
460 using Traits = typename Base::Traits;
461
462 using vector_t = typename Traits::vector_t;
463 using SFC_Ys = typename Traits::SFC_Ys;
464
465 // Precompute invariant values outside loops
466 const auto window_origin = lds_tile.get_window_origin();
467 const auto& bottom_tensor_view = lds_tile.get_bottom_tensor_view();
468 const auto& tensor_descriptor = bottom_tensor_view.get_tensor_descriptor();
469 auto smem_base_ptr = bottom_tensor_view.get_buffer_view().p_data_;
470
471 static_for<0, NumCoord, 1>{}([&](auto iCoord) {
472 auto window_adaptor_thread_coord = pre_computed_coords_[iCoord][I0];
473 auto bottom_tensor_thread_coord = pre_computed_coords_[iCoord][I1];
474
475 static_for<0, NumAccessPerCoord, 1>{}([&](auto iCoordAccess) {
476 constexpr auto iAccess = number<iCoord * NumAccessPerCoord + iCoordAccess>{};
477
478 // Use precomputed window origin
479 auto lds_bottom_tensor_thread_idx =
480 window_origin + window_adaptor_thread_coord.get_bottom_index();
481
482 // Use precomputed tensor descriptor
483 const auto lds_coord =
484 make_tensor_coordinate(tensor_descriptor, lds_bottom_tensor_thread_idx);
485
486 // Calculate SMEM address using base pointer
487 CK_TILE_LDS_ADDR LdsDataType* smem = smem_base_ptr + lds_coord.get_offset();
488
489 // Write into bottom tensor
490 this->get_bottom_tensor_view().template async_get_vectorized_elements<vector_t>(
491 smem,
492 bottom_tensor_thread_coord,
493 number<0>{},
495
496 // Move thread coordinate if not last access
497 if constexpr(iCoordAccess != (NumAccessPerCoord - 1))
498 {
499 constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess);
500 constexpr auto idx_diff_ps_ys = container_concat(
501 generate_tuple([&](auto) { return number<0>{}; }, number<Base::NDimP>{}),
502 idx_diff_ys);
503
505 window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
506 }
507 });
508 });
509 }
510
511 template <typename Policy, index_t i_access_unsupport_ = -1, bool oob_conditional_check = true>
513 {
514 constexpr auto tile_dstr = typename Base::TileDstr{};
516 this->template load_transpose<Policy>(
518 return dst_tensor;
519 }
520
521 template <typename Policy,
522 typename DistributedTensor,
523 index_t i_access_unsupport_ = -1,
524 bool oob_conditional_check = true>
525 CK_TILE_DEVICE auto load_transpose(DistributedTensor& dst_tensor,
528 {
529 using Traits = typename Base::Traits;
530 using vector_t = typename Traits::vector_t;
531 using SFC_Ys = typename Traits::SFC_Ys;
532
533 constexpr auto tile_dstr = typename Base::TileDstr{};
534
535 constexpr auto group_func = Policy::group_func;
536
537 // loop over thread tensor space [y0, y1, ...]
538 static_for<0, NumCoord, 1>{}([&](auto iCoord) {
540 auto window_adaptor_thread_coord = pre_computed_coords_[iCoord][I0];
541 auto bottom_tensor_thread_coord = pre_computed_coords_[iCoord][I1];
542
543 static_for<0, NumAccessPerCoord, 1>{}([&](auto iCoordAccess) {
544 constexpr auto iAccess = number<iCoord * NumAccessPerCoord + iCoordAccess>{};
545
546 // data index [y0, y1, ...]
547 constexpr auto idx_ys_start = SFC_Ys::get_index(iAccess);
548
549 // read from bottom tensor
550 const vector_t vec_value =
552 .template get_transpose_vectorized_elements<vector_t>(
553 bottom_tensor_thread_coord, 0);
554 // write into distributed tensor
555 static_for<0, Traits::ScalarPerVector, 1>{}([&](auto j) {
556 constexpr auto orig_idx_ys = generate_tuple(
557 [&](auto jj) {
558 return jj == Traits::VectorDimY ? (idx_ys_start[jj] + j)
559 : idx_ys_start[jj];
560 },
562
563 constexpr auto grouped_idx_ys = group_func(orig_idx_ys);
564
565 constexpr index_t linear_distributed_index =
566 tile_dstr.get_ys_to_d_descriptor().calculate_offset(grouped_idx_ys);
567
568 dst_tensor.get_thread_buffer().template at<linear_distributed_index>() =
569 vec_value.template get_as<typename Base::DataType>()[j];
570 });
571 // move thread coordinate
572 if constexpr(iCoordAccess != (NumAccessPerCoord - 1))
573 {
574 constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess);
575
576 constexpr auto idx_diff_ps_ys = container_concat(
577 generate_tuple([&](auto) { return number<0>{}; }, number<Base::NDimP>{}),
578 idx_diff_ys);
579
581 window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
582 }
583 });
584 });
585 }
586
587 template <index_t i_access_unsupport_ = -1, bool oob_conditional_check = true>
589 typename Base::TileDstr>& dstr_tensor,
592 {
593 using Traits = typename Base::Traits;
594
595 using vector_t = typename Traits::vector_t;
596 using SFC_Ys = typename Traits::SFC_Ys;
597
598 constexpr auto tile_dstr = typename Base::TileDstr{};
599
600 // loop over thread tensor space [y0, y1, ...]
601 static_for<0, NumCoord, 1>{}([&](auto iCoord) {
602 auto window_adaptor_thread_coord = pre_computed_coords_[iCoord][I0];
603 auto bottom_tensor_thread_coord = pre_computed_coords_[iCoord][I1];
604
605 static_for<0, NumAccessPerCoord, 1>{}([&](auto iCoordAccess) {
606 constexpr auto iAccess = number<iCoord * NumAccessPerCoord + iCoordAccess>{};
607
608 // data index [y0, y1, ...]
609 constexpr auto idx_ys_start = SFC_Ys::get_index(iAccess);
610
611 // read from distributed tensor
612 // vector_type_t vec;
613 vector_t vec_value;
614
615 static_for<0, Traits::ScalarPerVector, Traits::PackedSize>{}([&](auto j) {
616 constexpr auto idx_ys = generate_tuple(
617 [&](auto jj) {
618 return jj == Traits::VectorDimY ? (idx_ys_start[jj] + j)
619 : idx_ys_start[jj];
620 },
622
623 constexpr index_t d =
624 tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys) /
625 Traits::PackedSize;
626
627 vec_value.template get_as<typename Base::DataType>()(j / Traits::PackedSize) =
628 dstr_tensor.get_thread_buffer().template at<d>();
629 });
630
631 // const vector_t vec_value = vec.template get_as<vector_t>().template at<0>();
632
633 // write into bottom tensor
634 this->get_bottom_tensor_view().template set_vectorized_elements<vector_t>(
635 bottom_tensor_thread_coord,
636 0,
637 vec_value,
639
640 // move thread coordinate
641 if constexpr(iCoordAccess != (NumAccessPerCoord - 1))
642 {
643 constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess);
644
645 constexpr auto idx_diff_ps_ys = container_concat(
646 generate_tuple([&](auto) { return number<0>{}; }, number<Base::NDimP>{}),
647 idx_diff_ys);
648
650 window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
651 }
652 });
653 });
654 }
655
656 template <index_t i_access_unsupport_ = -1>
657 CK_TILE_DEVICE void
659 dstr_tensor,
661 {
662 using Traits = typename Base::Traits;
663
664 using vector_t = typename Traits::vector_t;
665 using SFC_Ys = typename Traits::SFC_Ys;
666
667 constexpr auto tile_dstr = typename Base::TileDstr{};
668 static constexpr bool oob_conditional_check = true;
669
670 // loop over thread tensor space [y0, y1, ...]
671 static_for<0, NumCoord, 1>{}([&](auto iCoord) {
673 auto window_adaptor_thread_coord = pre_computed_coords_[iCoord][I0];
674 auto bottom_tensor_thread_coord = pre_computed_coords_[iCoord][I1];
675
676 static_for<0, NumAccessPerCoord, 1>{}([&](auto iCoordAccess) {
677 constexpr auto iAccess = number<iCoord * NumAccessPerCoord + iCoordAccess>{};
678
679 // data index [y0, y1, ...]
680 constexpr auto idx_ys_start = SFC_Ys::get_index(iAccess);
681
682 // read from distributed tensor
683 vector_t vec_value;
684 static_for<0, Traits::ScalarPerVector, Traits::PackedSize>{}([&](auto j) {
685 constexpr auto idx_ys = generate_tuple(
686 [&](auto jj) {
687 return jj == Traits::VectorDimY ? (idx_ys_start[jj] + j)
688 : idx_ys_start[jj];
689 },
691 constexpr index_t d =
692 tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys) /
693 Traits::PackedSize;
694 vec_value.template get_as<typename Base::DataType>()(j / Traits::PackedSize) =
695 dstr_tensor.get_thread_buffer().template at<d>();
696 });
697
698 // write into bottom tensor
700 .template set_vectorized_elements_raw<vector_t, oob_conditional_check>(
701 bottom_tensor_thread_coord, 0, vec_value);
702
703 // move thread coordinate
704 if constexpr(iCoordAccess != (NumAccessPerCoord - 1))
705 {
706 constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess);
707
708 constexpr auto idx_diff_ps_ys = container_concat(
709 generate_tuple([&](auto) { return number<0>{}; }, number<Base::NDimP>{}),
710 idx_diff_ys);
711
713 window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
714 }
715 });
716 });
717 }
718
719 template <index_t i_access_unsupport_ = -1, bool oob_conditional_check = true>
720 CK_TILE_DEVICE void
722 dstr_tensor,
725 {
726 using Traits = typename Base::Traits;
727
728 using vector_t = typename Traits::vector_t;
729 using SFC_Ys = typename Traits::SFC_Ys;
730
731 constexpr auto tile_dstr = typename Base::TileDstr{};
732
733 // loop over thread tensor space [y0, y1, ...]
734 static_for<0, NumCoord, 1>{}([&](auto iCoord) {
736 auto window_adaptor_thread_coord = pre_computed_coords_[iCoord][I0];
737 auto bottom_tensor_thread_coord = pre_computed_coords_[iCoord][I1];
738
739 static_for<0, NumAccessPerCoord, 1>{}([&](auto iCoordAccess) {
740 constexpr auto iAccess = number<iCoord * NumAccessPerCoord + iCoordAccess>{};
741
742 // data index [y0, y1, ...]
743 constexpr auto idx_ys_start = SFC_Ys::get_index(iAccess);
744
745 // read from distributed tensor
746 vector_t vec_value;
747
748 static_for<0, Traits::ScalarPerVector, Traits::PackedSize>{}([&](auto j) {
749 constexpr auto idx_ys = generate_tuple(
750 [&](auto jj) {
751 return jj == Traits::VectorDimY ? (idx_ys_start[jj] + j)
752 : idx_ys_start[jj];
753 },
755
756 constexpr index_t d =
757 tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys) /
758 Traits::PackedSize;
759
760 vec_value.template get_as<typename Base::DataType>()(j / Traits::PackedSize) =
761 dstr_tensor.get_thread_buffer().template at<d>();
762 });
763
764 // write into bottom tensor
765 this->get_bottom_tensor_view().template update_vectorized_elements<vector_t>(
766 bottom_tensor_thread_coord,
767 0,
768 vec_value,
770
771 // move thread coordinate
772 if constexpr(iCoordAccess != (NumAccessPerCoord - 1))
773 {
774 constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess);
775
776 constexpr auto idx_diff_ps_ys = container_concat(
777 generate_tuple([&](auto) { return number<0>{}; }, number<Base::NDimP>{}),
778 idx_diff_ys);
779
781 window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
782 }
783 });
784 });
785 }
786
787 template <index_t i_access_unsupport_ = -1, bool oob_conditional_check = true, bool pre_nop>
788 CK_TILE_DEVICE void
790 dstr_tensor,
793 bool_constant<pre_nop> = {}) const
794 {
795 using Traits = typename Base::Traits;
796
797 using vector_t = typename Traits::vector_t;
798 using SFC_Ys = typename Traits::SFC_Ys;
799
800 constexpr auto tile_dstr = typename Base::TileDstr{};
801
802 // loop over thread tensor space [y0, y1, ...]
803 static_for<0, NumCoord, 1>{}([&](auto iCoord) {
805 auto window_adaptor_thread_coord = pre_computed_coords_[iCoord][I0];
806 auto bottom_tensor_thread_coord = pre_computed_coords_[iCoord][I1];
807
808 static_for<0, NumAccessPerCoord, 1>{}([&](auto iCoordAccess) {
809 constexpr auto iAccess = number<iCoord * NumAccessPerCoord + iCoordAccess>{};
810
811 // data index [y0, y1, ...]
812 constexpr auto idx_ys_start = SFC_Ys::get_index(iAccess);
813
814 // read from distributed tensor
815 vector_t vec_value;
816
817 static_for<0, Traits::ScalarPerVector, Traits::PackedSize>{}([&](auto j) {
818 constexpr auto idx_ys = generate_tuple(
819 [&](auto jj) {
820 return jj == Traits::VectorDimY ? (idx_ys_start[jj] + j)
821 : idx_ys_start[jj];
822 },
824
825 constexpr index_t d =
826 tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys) /
827 Traits::PackedSize;
828
829 vec_value.template get_as<typename Base::DataType>()(j / Traits::PackedSize) =
830 dstr_tensor.get_thread_buffer().template at<d>();
831 });
832
833 // write into bottom tensor
834 this->get_bottom_tensor_view().template update_vectorized_elements_raw<vector_t>(
835 bottom_tensor_thread_coord,
836 0,
837 vec_value,
840
841 // move thread coordinate
842 if constexpr(iCoordAccess != (NumAccessPerCoord - 1))
843 {
844 constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess);
845
846 constexpr auto idx_diff_ps_ys = container_concat(
847 generate_tuple([&](auto) { return number<0>{}; }, number<Base::NDimP>{}),
848 idx_diff_ys);
849
851 window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
852 }
853 });
854 });
855 }
856
857 // Custom move behavior
859 {
860 static_for<0, NumCoord, 1>{}([&](auto iCoord) {
861 move_tensor_coordinate(this->bottom_tensor_view_.get_tensor_descriptor(),
862 pre_computed_coords_(iCoord)(I1),
863 step);
864 });
865 }
866
868 {
869 // TODO: this use less register for FA, but more register for GEMM
870 // need investigation
871 const auto window_adaptor_thread_coord_tmp = make_tensor_adaptor_coordinate(
872 this->tile_dstr_.get_ps_ys_to_xs_adaptor(),
875
876 typename Base::BottomTensorIndex bottom_tensor_thread_origin_idx_tmp =
877 this->window_origin_ + window_adaptor_thread_coord_tmp.get_bottom_index();
878
879 const auto bottom_tensor_thread_coord_tmp = make_tensor_coordinate(
880 this->bottom_tensor_view_.get_tensor_descriptor(), bottom_tensor_thread_origin_idx_tmp);
881
882 // pre-compute NumCoord (WindowAdaptorCoord, BottomTensorCoord) bundles to speed up
883 // future load/store() calls (might allocate more registers)
884 using Traits = typename Base::Traits;
885 using SFC_Ys = typename Traits::SFC_Ys;
886
887 static_for<0, NumCoord, 1>{}([&](auto iCoord) {
888 auto window_adaptor_thread_coord = window_adaptor_thread_coord_tmp;
889 auto bottom_tensor_thread_coord = bottom_tensor_thread_coord_tmp;
890
891 constexpr auto idx_diff_ys =
892 SFC_Ys::get_step_between(number<0>{}, number<iCoord * NumAccessPerCoord>{});
893
894 constexpr auto idx_diff_ps_ys = container_concat(
895 generate_tuple([&](auto) { return number<0>{}; }, number<Base::NDimP>{}),
896 idx_diff_ys);
897
899 window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
900
901 pre_computed_coords_(iCoord) =
902 make_tuple(window_adaptor_thread_coord, bottom_tensor_thread_coord);
903 });
904 }
905
906 // this contains:
907 // per-thread coordinate for window adaptor
908 // per-thread coordinate for bottom tensor
911};
912
913// TODO: use strategy
914template <typename TensorView_,
915 typename WindowLengths_,
916 typename StaticTileDistribution_,
917 index_t NumCoord = 1>
918CK_TILE_DEVICE constexpr auto
919make_tile_window(const TensorView_& tensor_view,
920 const WindowLengths_& window_lengths,
921 const multi_index<TensorView_::get_num_of_dimension()>& origin,
922 const StaticTileDistribution_& tile_distribution,
923 number<NumCoord> = {})
924{
925 return tile_window_with_static_distribution<remove_cvref_t<TensorView_>,
928 NumCoord>{
929 tensor_view, window_lengths, origin, tile_distribution};
930}
931
932// this version can't be called in a constexpr context
933template <typename TensorView_,
934 typename WindowLengths_,
935 typename StaticTileDistribution_,
936 index_t NumCoord = 1>
939 const WindowLengths_& window_lengths,
940 const multi_index<TensorView_::get_num_of_dimension()>& origin,
941 const StaticTileDistribution_& tile_distribution,
942 number<NumCoord> = {})
943{
944 auto w = tile_window_with_static_distribution<remove_cvref_t<TensorView_>,
947 NumCoord>{
948 tensor_view, window_lengths, origin, tile_distribution};
949 w.init_raw();
950 return w;
951}
952
953template <typename TensorView_,
954 typename WindowLengths_,
955 typename StaticTileDistribution_,
956 index_t NumCoord>
959 WindowLengths_,
960 StaticTileDistribution_,
961 NumCoord>& window,
962 const typename tile_window_with_static_distribution<TensorView_,
963 WindowLengths_,
964 StaticTileDistribution_,
965 NumCoord>::BottomTensorIndex& step)
966{
967 window.move(step);
968}
969
970template <typename TensorView_,
971 typename WindowLengths_,
972 typename StaticTileDistribution_,
973 index_t NumCoord>
976 WindowLengths_,
977 StaticTileDistribution_,
978 NumCoord>>& window,
979 const typename tile_window_with_static_distribution<TensorView_,
980 WindowLengths_,
981 StaticTileDistribution_,
982 NumCoord>::BottomTensorIndex& step)
983{
984 using T = tuple<tile_window_with_static_distribution<TensorView_,
985 WindowLengths_,
986 StaticTileDistribution_,
987 NumCoord>>;
988
989 static constexpr auto N = T::size();
990 static_for<0, N, 1>{}([&](auto Is) { window[number<Is>{}].move(step); });
991}
992
993template <typename TileWindowWithStaticDistributionType,
994 typename StepType,
995 typename std::enable_if_t<
996 is_detected<is_tuple, TileWindowWithStaticDistributionType>::value>* = nullptr>
997CK_TILE_DEVICE void move_tile_window(TileWindowWithStaticDistributionType& window, StepType& step)
998{
999 static constexpr auto N = TileWindowWithStaticDistributionType::size();
1000 static_for<0, N, 1>{}([&](auto Is) { window[number<Is>{}].move(step); });
1001}
1002
1011template <typename BottomTensorView_, typename WindowLengths_>
1013 : public tile_window_base<tile_window_with_static_lengths<BottomTensorView_, WindowLengths_>,
1014 BottomTensorView_,
1015 WindowLengths_>
1016{
1017 using Base =
1019 BottomTensorView_,
1020 WindowLengths_>;
1021
1023
1025 const typename Base::BottomTensorView& bottom_tensor_view,
1026 const typename Base::WindowLengths& window_lengths,
1027 const typename Base::BottomTensorIndex& window_origin)
1028 {
1029 this->window_origin_ = window_origin;
1030 this->window_lengths_ = window_lengths;
1031 this->bottom_tensor_view_ = bottom_tensor_view;
1032 }
1033
1047 template <typename DataType>
1049 index_t end_i,
1050 index_t start_j,
1051 index_t end_j,
1052 const char* label = "") const
1053 {
1054 const auto& tensor_view = this->get_bottom_tensor_view();
1055 const auto window_origin = this->get_window_origin();
1056
1057 printf("%s Window Range [%d:%d, %d:%d] (origin: %d, %d):\n",
1058 label,
1059 start_i,
1060 end_i - 1,
1061 start_j,
1062 end_j - 1,
1063 window_origin[0],
1064 window_origin[1]);
1065
1066 for(index_t i = start_i; i < end_i; i++)
1067 {
1068 for(index_t j = start_j; j < end_j; j++)
1069 {
1070 // Create coordinate for this element relative to window origin
1071 auto coord =
1073 make_tuple(window_origin[0] + i, window_origin[1] + j));
1074
1075 // Get the element using thread buffer type directly
1076 using ThreadBuf = thread_buffer<DataType, 2>;
1077 auto buf = tensor_view.template get_vectorized_elements<ThreadBuf>(coord, 0);
1078 auto value = buf.at(number<0>{}); // Extract first element from thread buffer
1079 printf(" %s[%d,%d] = %f", label, i, j, type_convert<float>(value));
1080 }
1081 printf("\n");
1082 }
1083 printf("\n");
1084 }
1085};
1086
1087template <typename TensorView_, typename WindowLengths_>
1088CK_TILE_DEVICE constexpr auto
1090 const WindowLengths_& window_lengths,
1091 const multi_index<TensorView_::get_num_of_dimension()>& origin)
1092{
1094 "wrong! lengths should be static");
1095
1098 tensor_view, window_lengths, origin};
1099}
1100
1101// duplicate tile window and replace its origin
1102template <typename TensorView, typename WindowLengths>
1103CK_TILE_DEVICE constexpr auto
1105 const multi_index<TensorView::get_num_of_dimension()>& origin)
1106{
1108 tile_window.get_bottom_tensor_view(), tile_window.get_window_lengths(), origin};
1109}
1110
1111template <typename TensorView, typename WindowLengths, typename StaticTileDistribution>
1112CK_TILE_DEVICE constexpr auto
1114 const multi_index<TensorView::get_num_of_dimension()>& origin,
1115 const StaticTileDistribution& tile_distribution)
1116{
1117 return make_tile_window(tile_window.get_bottom_tensor_view(),
1118 tile_window.get_window_lengths(),
1119 origin,
1121}
1122
1123template <typename TensorView, typename WindowLengths, typename StaticTileDistribution>
1124CK_TILE_DEVICE constexpr auto
1126 const StaticTileDistribution& tile_distribution)
1127{
1128 return make_tile_window(tile_window.get_bottom_tensor_view(),
1129 tile_window.get_window_lengths(),
1130 tile_window.get_window_origin(),
1132}
1133
1134template <typename TensorView, typename WindowLengths, typename StaticTileDistribution>
1135CK_TILE_DEVICE constexpr auto
1137 const StaticTileDistribution& tile_distribution)
1138{
1139 auto w = make_tile_window(tile_window.get_bottom_tensor_view(),
1140 tile_window.get_window_lengths(),
1141 tile_window.get_window_origin(),
1143 w.init_raw();
1144 return w;
1145}
1146
1147template <typename TensorView_, typename WindowLengths_>
1155
1156template <typename NewTensorView_,
1157 typename OldTensorView_,
1158 typename WindowLengths_,
1159 typename StaticTileDistribution_,
1160 index_t NumCoord = 1>
1161CK_TILE_DEVICE auto
1162replace_bottom_tensor_view(const NewTensorView_& new_tensor_view,
1163 const tile_window_with_static_distribution<OldTensorView_,
1164 WindowLengths_,
1165 StaticTileDistribution_,
1166 NumCoord>& tile_window)
1167{
1168 return make_tile_window(new_tensor_view,
1169 tile_window.get_window_lengths(),
1170 tile_window.get_window_origin(),
1171 tile_window.get_tile_distribution());
1172}
1173
1174template <typename NewTensorView_, typename OldTensorView_, typename WindowLengths_>
1176 const NewTensorView_& new_tensor_view,
1178{
1179 return make_tile_window(
1180 new_tensor_view, tile_window.get_window_lengths(), tile_window.get_window_origin());
1181}
1182
1190template <typename T>
1192{
1193};
1194
1203template <typename BottomTensorView_,
1204 typename WindowLengths_,
1205 typename StaticTileDistribution_,
1206 index_t NumCoord>
1208 tile_window_with_static_distribution<BottomTensorView_,
1209 WindowLengths_,
1210 StaticTileDistribution_,
1211 NumCoord>> : std::true_type
1212{
1213};
1214
1222template <typename T>
1225
1233template <typename T>
1235{
1236};
1237
1244template <typename BottomTensorView_, typename WindowLengths_>
1246 tile_window_with_static_lengths<BottomTensorView_, WindowLengths_>> : std::true_type
1247{
1248};
1249
1257template <typename T>
1260
1261} // namespace ck_tile
#define CK_TILE_DEVICE
Definition config.hpp:41
#define CK_TILE_LDS_ADDR
Definition config.hpp:58
CK_TILE_HOST_DEVICE auto get_partition_index(Distribution)
Definition tile_distribution.hpp:22
Definition tile/core/algorithm/cluster_descriptor.hpp:13
constexpr decltype(auto) apply(F &&f, Tuple &&t)
Definition tile/core/container/tuple.hpp:526
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition type_traits.hpp:21
constexpr bool is_tile_window_with_static_distribution_v
Helper variable template to check if a type is a tile window with static distribution.
Definition tile_window.hpp:1223
__device__ uint32_t amd_wave_read_first_lane(uint16_t v)
Definition tile/core/arch/amd_buffer_addressing.hpp:35
CK_TILE_DEVICE auto replace_bottom_tensor_view(const NewTensorView_ &new_tensor_view, const tile_scatter_gather< OldTensorView_, WindowLengths_, StaticTileDistribution_, StaticPageIndexArray_, StaticValidArray_, HsGatherDim, NumCoord > &tile_window)
Definition tile_scatter_gather.hpp:1043
constant< b > bool_constant
Definition tile/core/numeric/integral_constant.hpp:43
CK_TILE_DEVICE index_t get_warp_id(bool_constant< ReturnSgpr >={})
Definition arch.hpp:104
CK_TILE_DEVICE auto make_tile_window_raw(const TensorView_ &tensor_view, const WindowLengths_ &window_lengths, const multi_index< TensorView_::get_num_of_dimension()> &origin, const StaticTileDistribution_ &tile_distribution, number< NumCoord >={})
Definition tile_window.hpp:938
CK_TILE_HOST_DEVICE constexpr auto container_concat(const X &x, const Ys &... ys)
Definition tile/core/container/container_helper.hpp:363
CK_TILE_HOST_DEVICE constexpr auto make_static_distributed_tensor(const StaticTileDistribution &)
Definition static_distributed_tensor.hpp:142
constant< v > number
Definition tile/core/numeric/integral_constant.hpp:37
CK_TILE_DEVICE constexpr auto make_tile_window(null_tensor_view, const WindowLengths &window_lengths, const multi_index< WindowLengths::size()> &, Ts &&...)
Definition null_tile_window.hpp:75
CK_TILE_HOST_DEVICE constexpr auto generate_tuple(F &&f, number< N >)
Definition tile/core/container/tuple.hpp:429
array< index_t, N > multi_index
Definition tile/core/container/multi_index.hpp:17
CK_TILE_HOST_DEVICE constexpr void move_tensor_coordinate(const TensorDesc &tensor_desc, TensorCoord &coord, const Index &coord_step)
Definition tensor_coordinate.hpp:72
CK_TILE_HOST_DEVICE constexpr auto make_tensor_adaptor_coordinate(const Adaptor &adaptor, const TopIndex &idx_top)
Definition tensor_adaptor_coordinate.hpp:55
constexpr bool is_tile_window_with_static_lengths_v
Helper variable template to check if a type is a tile window with static lengths.
Definition tile_window.hpp:1258
CK_TILE_DEVICE void move_tile_window(null_tile_window< WindowLengths > &, const typename null_tile_window< WindowLengths >::BottomTensorIndex &)
Definition null_tile_window.hpp:95
CK_TILE_DEVICE void m0_set_with_memory(index_t v)
Definition utility.hpp:19
int32_t index_t
Definition integer.hpp:9
CK_TILE_HOST_DEVICE constexpr Y type_convert(X x)
Definition tile/core/numeric/type_convert.hpp:29
CK_TILE_DEVICE void m0_inc_with_memory(index_t v)
Definition utility.hpp:25
CK_TILE_HOST_DEVICE constexpr auto make_tensor_coordinate(const TensorDesc &tensor_desc, const TopIndex &idx_top)
Definition tensor_coordinate.hpp:60
CK_TILE_HOST_DEVICE constexpr auto make_tuple(Xs &&... xs)
Definition tile/core/container/tuple.hpp:360
const GenericPointer< typename T::ValueType > T2 value
Definition pointer.h:1697
A fixed-size array container similar to std::array with additional utilities.
Definition tile/core/container/array.hpp:43
static constexpr bool value
Definition type_traits.hpp:77
Type trait to determine if a type is a tile window with static distribution.
Definition tile_window.hpp:1192
Type trait to determine if a type is a tile window with static lengths.
Definition tile_window.hpp:1235
Definition static_distributed_tensor.hpp:21
CK_TILE_HOST_DEVICE constexpr const auto & get_thread_buffer() const
Definition static_distributed_tensor.hpp:58
Definition tile/core/utility/functional.hpp:43
Definition tensor_view.hpp:41
CK_TILE_HOST_DEVICE constexpr auto & get_tensor_descriptor() const
Definition tensor_view.hpp:61
Definition tile/core/utility/debug.hpp:67
Definition tile_distribution.hpp:72
CK_TILE_HOST_DEVICE constexpr const auto & get_ps_ys_to_xs_adaptor() const
Definition tile_distribution.hpp:126
This class provides description of tile windowed view on the device memory.
Definition tile_window_base.hpp:31
BottomTensorView bottom_tensor_view_
Definition tile_window_base.hpp:85
remove_cvref_t< typename BottomTensorView::DataType > DataType
Definition tile_window_base.hpp:36
CK_TILE_DEVICE constexpr auto get_bottom_tensor_view() const
Definition tile_window_base.hpp:47
BottomTensorIndex window_origin_
Definition tile_window_base.hpp:79
CK_TILE_DEVICE constexpr auto get_window_lengths() const
Definition tile_window_base.hpp:46
CK_TILE_DEVICE void move(const BottomTensorIndex &step)
Definition tile_window_base.hpp:67
remove_reference_t< BottomTensorView_ > BottomTensorView
Definition tile_window_base.hpp:33
remove_cvref_t< WindowLengths_ > WindowLengths
Definition tile_window_base.hpp:34
array< index_t, NDimBottomTensor > BottomTensorIndex
Definition tile_window_base.hpp:43
WindowLengths window_lengths_
Definition tile_window_base.hpp:81
This class provides tile (windowed) view and access to the device memory.
Definition tile_window.hpp:46
CK_TILE_DEVICE void store_raw(const static_distributed_tensor< typename Base::DataType, typename Base::TileDstr > &dstr_tensor, number< i_access_unsupport_ >={}) const
Definition tile_window.hpp:658
CK_TILE_DEVICE void move_extended(const typename Base::BottomTensorIndex &step)
Definition tile_window.hpp:858
CK_TILE_DEVICE auto load_transpose() const
Definition tile_window.hpp:512
CK_TILE_DEVICE void set_window_origin_extended(const typename Base::BottomTensorIndex &)
Definition tile_window.hpp:867
array< tuple< typename Base::WindowAdaptorCoord, typename Base::BottomTensorCoord >, NumCoord > pre_computed_coords_
Definition tile_window.hpp:910
CK_TILE_DEVICE void update_raw(const static_distributed_tensor< typename Base::DataType, typename Base::TileDstr > &dstr_tensor, number< i_access_unsupport_ >={}, bool_constant< oob_conditional_check >={}, bool_constant< pre_nop >={}) const
Definition tile_window.hpp:789
CK_TILE_DEVICE auto async_load(LdsTileWindow_ &&lds_tile, number< i_access_unsupport_ >={}, bool_constant< oob_conditional_check >={}) const
Definition tile_window.hpp:454
CK_TILE_DEVICE auto load_transpose(DistributedTensor &dst_tensor, number< i_access_unsupport_ >={}, bool_constant< oob_conditional_check >={}) const
Definition tile_window.hpp:525
CK_TILE_DEVICE auto load(DistributedTensor &dst_tensor, const TileWindow_ &tile_window, ElementWise_ elementwise, number< i_access_unsupport_ >={}, bool_constant< oob_conditional_check >={}) const
Definition tile_window.hpp:157
CK_TILE_DEVICE auto load(DistributedTensor &dst_tensor, number< i_access_unsupport_ >={}, bool_constant< oob_conditional_check >={}) const
Definition tile_window.hpp:236
CK_TILE_DEVICE auto load(number< i_access_unsupport_ >={}, bool_constant< oob_conditional_check >={}) const
Definition tile_window.hpp:114
static constexpr auto I0
Definition tile_window.hpp:56
CK_TILE_DEVICE constexpr tile_window_with_static_distribution()=default
CK_TILE_DEVICE auto load(const TileWindow_ &tile_window, ElementWise_ elementwise, number< i_access_unsupport_ >={}, bool_constant< oob_conditional_check >={}) const
Load tile with elementwise function.
Definition tile_window.hpp:137
CK_TILE_DEVICE void load_raw(DstTile &dst_tensor, number< i_access_unsupport_ >={}, bool_constant< oob_conditional_check >={}, bool_constant< pre_nop >={}) const
Definition tile_window.hpp:299
static constexpr auto I1
Definition tile_window.hpp:57
CK_TILE_DEVICE auto async_load_raw(LdsTileWindow_ &&lds_tile, number< i_access_unsupport_ >={}, bool_constant< oob_conditional_check >={}, bool_constant< pre_nop >={}) const
Definition tile_window.hpp:371
tile_window_with_tile_dstr_base< tile_window_with_static_distribution< BottomTensorView_, WindowLengths_, StaticTileDistribution_, NumCoord >, BottomTensorView_, WindowLengths_, StaticTileDistribution_ > Base
Definition tile_window.hpp:47
CK_TILE_DEVICE void update(const static_distributed_tensor< typename Base::DataType, typename Base::TileDstr > &dstr_tensor, number< i_access_unsupport_ >={}, bool_constant< oob_conditional_check >={}) const
Definition tile_window.hpp:721
static constexpr index_t NumAccessPerCoord
Definition tile_window.hpp:62
CK_TILE_DEVICE constexpr tile_window_with_static_distribution(const typename Base::BottomTensorView &bottom_tensor_view, const typename Base::WindowLengths &window_lengths, const typename Base::BottomTensorIndex &window_origin, const typename Base::TileDstr &tile_distribution)
Definition tile_window.hpp:66
CK_TILE_DEVICE void store(const static_distributed_tensor< typename Base::DataType, typename Base::TileDstr > &dstr_tensor, number< i_access_unsupport_ >={}, bool_constant< oob_conditional_check >={}) const
Definition tile_window.hpp:588
This class provides description of tile windowed view on the device memory.
Definition tile_window.hpp:1016
CK_TILE_DEVICE constexpr tile_window_with_static_lengths(const typename Base::BottomTensorView &bottom_tensor_view, const typename Base::WindowLengths &window_lengths, const typename Base::BottomTensorIndex &window_origin)
Definition tile_window.hpp:1024
tile_window_base< tile_window_with_static_lengths< BottomTensorView_, WindowLengths_ >, BottomTensorView_, WindowLengths_ > Base
Definition tile_window.hpp:1017
CK_TILE_DEVICE void print_tile_window_range(index_t start_i, index_t end_i, index_t start_j, index_t end_j, const char *label="") const
Definition tile_window.hpp:1048
CK_TILE_DEVICE constexpr tile_window_with_static_lengths()=default
Definition tile_window_base.hpp:94
CK_TILE_DEVICE void move_window_adaptor_and_bottom_tensor_thread_coordinate(WindowAdaptorCoord &window_adaptor_thread_coord, BottomTensorCoord &bottom_tensor_thread_coord, const ATopIndex &idx_diff_adaptor_top) const
Definition tile_window_base.hpp:129
Definition tile/core/container/tuple.hpp:192