threadwise_tensor_slice_transfer_v3r1_dequant.hpp Source File

threadwise_tensor_slice_transfer_v3r1_dequant.hpp Source File

Composable Kernel: threadwise_tensor_slice_transfer_v3r1_dequant.hpp Source File
threadwise_tensor_slice_transfer_v3r1_dequant.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
12
13namespace ck {
14
15namespace detail {
16// TODO: How to fix this? It uses an struct instead of lambda because lambda
17// doesn't have constructor
18template <index_t SrcVectorDim,
19 index_t SrcScalarPerVector,
20 index_t DstVectorDim,
21 index_t DstScalarPerVector>
23{
24 __host__ __device__ constexpr auto operator()(index_t i) const
25 {
26 if(i == SrcVectorDim && i == DstVectorDim)
27 {
28 return math::lcm(SrcScalarPerVector, DstScalarPerVector);
29 }
30 else if(i == SrcVectorDim)
31 {
32 return SrcScalarPerVector;
33 }
34 else if(i == DstVectorDim)
35 {
36 return DstScalarPerVector;
37 }
38 else
39 {
40 return 1;
41 }
42 }
43};
44
45} // namespace detail
46
47// Assume:
48// 1. src_desc and dst_desc are not known at compile-time
49// 2. SrcBuffer and DstBuffer are DynamicBuffer
50// 3. src_slice_origin and dst_slice_origin are not known at compile-time,
51// 4. Use thread buffer
52// 5. Dequantization happens between read and write.
53template <typename SliceLengths,
54 typename ScaleSliceLengths, // NOTE(review): not referenced anywhere else in this listing; scale reads iterate SliceLengths -- confirm intent
55 typename SrcElementwiseOperation,
56 typename ScaleElementwiseOperation, // NOTE(review): stored as scale_element_op_ but never applied in this listing -- confirm
57 typename DstElementwiseOperation,
59 typename SrcData,
60 typename ScaleData,
61 typename DstData,
62 typename SrcDesc,
63 typename ScaleDesc,
64 typename DstDesc,
65 typename SrcDimAccessOrder,
66 typename DstDimAccessOrder,
67 index_t SrcVectorDim,
68 index_t DstVectorDim,
69 index_t SrcScalarPerVector,
70 index_t ScaleScalarPerVector,
71 index_t DstScalarPerVector,
72 index_t SrcScalarStrideInVector,
73 index_t ScaleScalarStrideInVector,
74 index_t DstScalarStrideInVector,
75 bool SrcResetCoordinateAfterRun, // control whether to move back src coordinate after each
76 // RunRead(), will be fused with MoveSrcSliceWindow to
77 // save addr computation
78 bool DstResetCoordinateAfterRun, // control whether to move back dst coordinate after each
79 // RunWrite(), will be fused with MoveDstSliceWindow to
80 // save addr computation
81 index_t NumThreadScratch = 1>
83{ // struct-name line (82) is elided in this listing
84 static constexpr index_t nDim = SliceLengths::Size();
86
87 using SrcCoord = decltype(make_tensor_coordinate(SrcDesc{}, Index{}));
88 using ScaleCoord = decltype(make_tensor_coordinate(ScaleDesc{}, Index{})); // must come from ScaleDesc (was SrcDesc): scale_coord_ is constructed/assigned from a ScaleDesc, whose coordinate type need not match SrcDesc's
89 using DstCoord = decltype(make_tensor_coordinate(DstDesc{}, Index{}));
90
91 static constexpr auto I0 = Number<0>{};
92
94 const SrcDesc& src_desc, // constructor name line (93) is elided in this listing
95 const Index& src_slice_origin,
96 const SrcElementwiseOperation& src_element_op,
97 const ScaleDesc& scale_desc,
98 const Index& scale_slice_origin,
99 const ScaleElementwiseOperation& scale_element_op,
100 const DstDesc& dst_desc,
101 const Index& dst_slice_origin,
102 const DstElementwiseOperation& dst_element_op)
103 : src_coord_(make_tensor_coordinate(src_desc, src_slice_origin)), // each coordinate starts at its slice origin
104 scale_coord_(make_tensor_coordinate(scale_desc, scale_slice_origin)),
105 dst_coord_(make_tensor_coordinate(dst_desc, dst_slice_origin)),
106 src_element_op_(src_element_op), // element-wise ops are copied into const members
107 scale_element_op_(scale_element_op),
108 dst_element_op_(dst_element_op)
109 {
110 }
111
112 __device__ void SetSrcSliceOrigin(const SrcDesc& src_desc, const Index& src_slice_origin_idx) // re-anchors src_coord_ at src_slice_origin_idx, discarding its current position
113 {
114 src_coord_ = make_tensor_coordinate(src_desc, src_slice_origin_idx);
115 }
116
117 __device__ void SetScaleSliceOrigin(const ScaleDesc& scale_desc, // re-anchors scale_coord_ at scale_slice_origin_idx (RunScaleRead() never resets it itself)
118 const Index& scale_slice_origin_idx)
119 {
120 scale_coord_ = make_tensor_coordinate(scale_desc, scale_slice_origin_idx);
121 }
122
123 __device__ void SetDstSliceOrigin(const DstDesc& dst_desc, const Index& dst_slice_origin_idx) // re-anchors dst_coord_ at dst_slice_origin_idx, discarding its current position
124 {
125 dst_coord_ = make_tensor_coordinate(dst_desc, dst_slice_origin_idx);
126 }
127
128 template <typename SrcBuffer, index_t ThreadScratchId = 0> // RunRead: loads the SliceLengths window at src_coord_ from src_buf into src thread scratch #thread_scratch_id, one vector per access
129 __device__ void RunRead(const SrcDesc& src_desc,
130 const SrcBuffer& src_buf, // a further parameter line (131) is elided; presumably thread_scratch_id, used below
132 {
133 static_assert(SrcBuffer::GetAddressSpace() == AddressSpaceEnum::Global or
134 SrcBuffer::GetAddressSpace() == AddressSpaceEnum::Lds,
135 "wrong!");
136
137 static_assert(
139 "wrong! SrcBuffer and SrcData data type are inconsistent");
140
141 // scalar per access on each dim
142 // TODO: don't use lambda_scalar_per_access
143 constexpr auto src_scalar_per_access = generate_sequence( // generator arguments are on elided line 144
145
146 constexpr auto src_access_lengths = SliceLengths{} / src_scalar_per_access; // number of vector accesses along each dim
147
148 constexpr auto src_dim_access_order = SrcDimAccessOrder{};
149
150 constexpr auto ordered_src_access_lengths =
151 container_reorder_given_new2old(src_access_lengths, src_dim_access_order); // access lengths permuted into traversal order
152
153 // make forward steps
154 const auto src_forward_steps = generate_tuple( // step i advances +src_scalar_per_access[i] along dim i only
155 [&](auto i) {
156 Index forward_step_idx;
157
158 static_for<0, nDim, 1>{}([&](auto j) {
159 forward_step_idx(j) = (i.value == j.value) ? src_scalar_per_access[i] : 0;
160 });
161
162 return make_tensor_coordinate_step(src_desc, forward_step_idx);
163 },
164 Number<nDim>{});
165
166 // make backward steps
167 const auto src_backward_steps = generate_tuple( // negated counterparts of the forward steps
168 [&](auto i) {
169 Index backward_step_idx;
170
171 static_for<0, nDim, 1>{}([&](auto j) {
172 backward_step_idx(j) = (i.value == j.value) ? -src_scalar_per_access[i] : 0;
173 });
174
175 return make_tensor_coordinate_step(src_desc, backward_step_idx);
176 },
177 Number<nDim>{});
178
179 // loop over tensor and copy
180 static_ford<decltype(ordered_src_access_lengths)>{}([&](auto ordered_src_access_idx) {
181 // judge move forward or move backward
182 constexpr auto forward_sweep = [&]() { // zig-zag order: dim i runs forward when the mixed-radix index of the outer (ordered) dims is even
184
185 forward_sweep_(I0) = true;
186
187 static_for<1, nDim, 1>{}([&](auto i) {
188 index_t tmp = ordered_src_access_idx[I0];
189
190 static_for<1, i, 1>{}([&](auto j) {
191 tmp = tmp * ordered_src_access_lengths[j] + ordered_src_access_idx[j];
192 });
193
194 forward_sweep_(i) = tmp % 2 == 0;
195 });
196
197 return forward_sweep_;
198 }();
199
200 // calculate src data index
201 constexpr auto src_data_idx = [&]() { // element index of this vector, back in natural dim order
202 Index ordered_idx;
203
204 static_for<0, nDim, 1>{}([&](auto i) {
205 ordered_idx(i) = forward_sweep[i] ? ordered_src_access_idx[i]
206 : ordered_src_access_lengths[i] - 1 -
207 ordered_src_access_idx[i];
208 });
209
210 return container_reorder_given_old2new(ordered_idx, src_dim_access_order) *
211 src_scalar_per_access;
212 }();
213
214 constexpr auto src_data_idx_seq = generate_sequence_v2(
215 [&](auto i) { return Number<src_data_idx[i]>{}; }, Number<src_data_idx.Size()>{});
216
217 const bool is_src_valid = // initializer on elided line 218 (the page's symbol list suggests coordinate_has_valid_offset_assuming_visible_index_is_valid -- confirm)
219
221 using src_vector_t = typename src_vector_type::type; // src_vector_type alias is on elided line 220
222
223 // copy data from src_buf into src_vector_container
224 auto src_vector_container = src_vector_type{
225 src_buf.template Get<src_vector_t>(src_coord_.GetOffset(), is_src_valid)};
226
227 // copy data from src_vector_container into src_thread_scratch_
228 src_thread_scratch_tuple_(thread_scratch_id)
229 .template SetAsType<src_vector_t>(
230 src_data_idx_seq, src_vector_container.template AsType<src_vector_t>()[I0]);
231
232 constexpr auto move_on_dim = [&]() constexpr { // dim i moves only when it is not at its end and every later (ordered) dim is at its end
234
235 static_for<0, nDim, 1>{}([&](auto i) {
236 move_on_dim_(i) = ordered_src_access_idx[i] < ordered_src_access_lengths[i] - 1;
237
238 static_for<i + 1, nDim, 1>{}([&](auto j) {
239 move_on_dim_(i) &=
240 ordered_src_access_idx[j] == ordered_src_access_lengths[j] - 1;
241 });
242 });
243
244 return move_on_dim_;
245 }();
246
247 // move src coord
248 static_for<0, nDim, 1>{}([&](auto i) {
249 if constexpr(move_on_dim[i])
250 {
251 if constexpr(forward_sweep[i])
252 {
254 src_desc, src_coord_, src_forward_steps[src_dim_access_order[i]]);
255 }
256 else
257 {
259 src_desc, src_coord_, src_backward_steps[src_dim_access_order[i]]);
260 }
261 }
262 });
263 });
264
265 // move src coordinate back to slice origin (or not)
266 if constexpr(SrcResetCoordinateAfterRun) // otherwise the pending reset is folded into MoveSrcSliceWindow()
267 {
268 const auto src_reset_step =
270
271 move_tensor_coordinate(src_desc, src_coord_, src_reset_step);
272 }
273 }
274
275 template <typename ScaleBuffer> // RunScaleRead: loads scale values at scale_coord_ into scale_thread_scratch_, reusing the src traversal order and SrcVectorDim
276 __device__ void RunScaleRead(const ScaleDesc& scale_desc, const ScaleBuffer& scale_buf)
277 {
278 static_assert(ScaleBuffer::GetAddressSpace() == AddressSpaceEnum::Global or
279 ScaleBuffer::GetAddressSpace() == AddressSpaceEnum::Lds,
280 "wrong!");
281
282 static_assert(
284 "wrong! ScaleBuffer and ScaleData data type are inconsistent");
285
286 // scalar per access on each dim
287 // TODO: don't use lambda_scalar_per_access
288 constexpr auto scale_scalar_per_access = generate_sequence( // generator arguments are on elided line 289
290
291 constexpr auto scale_access_lengths = SliceLengths{} / scale_scalar_per_access; // NOTE(review): iterates the full SliceLengths, not ScaleSliceLengths, although only dim 1 of the scale scratch is read later -- confirm
292
293 constexpr auto scale_dim_access_order = SrcDimAccessOrder{}; // scale traversal follows the src access order
294
295 constexpr auto ordered_scale_access_lengths =
296 container_reorder_given_new2old(scale_access_lengths, scale_dim_access_order);
297
298 // make forward steps
299 const auto scale_forward_steps = generate_tuple(
300 [&](auto i) {
301 Index forward_step_idx;
302
303 static_for<0, nDim, 1>{}([&](auto j) {
304 forward_step_idx(j) = (i.value == j.value) ? scale_scalar_per_access[i] : 0;
305 });
306
307 return make_tensor_coordinate_step(scale_desc, forward_step_idx);
308 },
309 Number<nDim>{});
310
311 // make backward steps
312 const auto scale_backward_steps = generate_tuple(
313 [&](auto i) {
314 Index backward_step_idx;
315
316 static_for<0, nDim, 1>{}([&](auto j) {
317 backward_step_idx(j) = (i.value == j.value) ? -scale_scalar_per_access[i] : 0;
318 });
319
320 return make_tensor_coordinate_step(scale_desc, backward_step_idx);
321 },
322 Number<nDim>{});
323
324 // loop over tensor and copy
325 static_ford<decltype(ordered_scale_access_lengths)>{}([&](auto ordered_scale_access_idx) {
326 // judge move forward or move backward
327 constexpr auto forward_sweep = [&]() { // same zig-zag rule as RunRead(): parity of the outer-dim mixed-radix index
329
330 forward_sweep_(I0) = true;
331
332 static_for<1, nDim, 1>{}([&](auto i) {
333 index_t tmp = ordered_scale_access_idx[I0];
334
335 static_for<1, i, 1>{}([&](auto j) {
336 tmp = tmp * ordered_scale_access_lengths[j] + ordered_scale_access_idx[j];
337 });
338
339 forward_sweep_(i) = tmp % 2 == 0;
340 });
341
342 return forward_sweep_;
343 }();
344
345 // calculate scale data index
346 constexpr auto scale_data_idx = [&]() {
347 Index ordered_idx;
348
349 static_for<0, nDim, 1>{}([&](auto i) {
350 ordered_idx(i) = forward_sweep[i] ? ordered_scale_access_idx[i]
351 : ordered_scale_access_lengths[i] - 1 -
352 ordered_scale_access_idx[i];
353 });
354
355 return container_reorder_given_old2new(ordered_idx, scale_dim_access_order) *
356 scale_scalar_per_access;
357 }();
358
359 constexpr auto scale_data_idx_seq =
360 generate_sequence_v2([&](auto i) { return Number<scale_data_idx[i]>{}; },
361 Number<scale_data_idx.Size()>{});
362
364 scale_desc, scale_coord_); // tail of the is_scale_valid computation begun on elided line 363
365
367 using scale_vector_t = typename scale_vector_type::type; // scale_vector_type alias is on elided line 366
368
369 // copy data from scale_buf into scale_vector_container
370 auto scale_vector_container = scale_vector_type{
371 scale_buf.template Get<scale_vector_t>(scale_coord_.GetOffset(), is_scale_valid)};
372
373 // copy data from scale_vector_container into scale_thread_scratch_
374 scale_thread_scratch_.template SetAsType<scale_vector_t>(
375 scale_data_idx_seq, scale_vector_container.template AsType<scale_vector_t>()[I0]);
376
377 constexpr auto move_on_dim = [&]() constexpr {
379
380 static_for<0, nDim, 1>{}([&](auto i) {
381 move_on_dim_(i) =
382 ordered_scale_access_idx[i] < ordered_scale_access_lengths[i] - 1;
383
384 static_for<i + 1, nDim, 1>{}([&](auto j) {
385 move_on_dim_(i) &=
386 ordered_scale_access_idx[j] == ordered_scale_access_lengths[j] - 1;
387 });
388 });
389
390 return move_on_dim_;
391 }();
392
393 // move scale coord
394 static_for<0, nDim, 1>{}([&](auto i) {
395 if constexpr(move_on_dim[i])
396 {
397 if constexpr(forward_sweep[i])
398 {
399 move_tensor_coordinate(scale_desc,
400 scale_coord_,
401 scale_forward_steps[scale_dim_access_order[i]]);
402 }
403 else
404 {
405 move_tensor_coordinate(scale_desc,
406 scale_coord_,
407 scale_backward_steps[scale_dim_access_order[i]]);
408 }
409 }
410 });
411 });
412
413 // don't need to move scale coordinate back to slice origin; NOTE(review): scale_coord_ stays where the last access left it, so a second call without SetScaleSliceOrigin() starts from there -- confirm intended
414 /*
415 if constexpr(SrcResetCoordinateAfterRun)
416 {
417 const auto scale_reset_step =
418 make_tensor_coordinate_step(scale_desc, GetScaleCoordinateResetStep());
419
420 move_tensor_coordinate(scale_desc, scale_coord_, scale_reset_step);
421 }
422 */
423 }
424
425 template <index_t ThreadScratchId> // converts src thread scratch #thread_scratch_id into dst_thread_scratch_ (dequantizing in the #else path); the name/parameter line (427) is elided in this listing
426 __device__ void
428 {
429#if !CK_EXPERIMENTAL_USE_IN_REGISTER_SUB_DWORD_TRANSPOSE
430 static_ford<SliceLengths>{}([&](auto idx) { // NOTE(review): this path only type-converts; scale_thread_scratch_ is never applied, so no dequantization happens here -- confirm intended
431 // convert from SrcData to DstData here
432 dst_thread_scratch_(idx) =
433 type_convert<DstData>(src_thread_scratch_tuple_[thread_scratch_id][idx]);
434 });
435#else
436 // sub-dword transpose between src_thread_scratch_ and dst_thread_scratch_
437 // TODO make this logic more generic for more sub-dword datatype
438 if constexpr(SrcVectorDim != DstVectorDim && // parts of this condition (lines 439-443) are elided in this listing
441 SrcScalarPerVector % 2 == 0 && DstScalarPerVector % 2 == 0) ||
444 SrcScalarPerVector % 4 == 0 && DstScalarPerVector % 4 == 0)))
445 {
446 // each transpose does
447 // DstScalarPerVector # of src vectors in src_thread_scratch_
448 // SrcScalarPerVector # of dst vectors in dst_thread_scratch_
449 constexpr index_t num_src_vector = Number<DstScalarPerVector>{};
450 constexpr index_t num_dst_vector = Number<SrcScalarPerVector>{};
451
452 // Assume SrcVectorDim is not the same as DstVectorDim, so we do transpose
453 // TODO: make this logic generic for all scenario
454 static_assert(SrcVectorDim != DstVectorDim, "wrong"); // always holds here: the enclosing if constexpr already requires it
455
456 constexpr auto src_scalar_step_in_vector = generate_sequence(
458
459 constexpr auto dst_scalar_step_in_vector = generate_sequence(
461
462 constexpr auto scalar_per_access = generate_sequence(
464 SrcScalarPerVector,
465 DstVectorDim,
466 DstScalarPerVector>{},
467 Number<nDim>{});
468
469 constexpr auto access_lengths = SliceLengths{} / scalar_per_access;
470
471 static_ford<decltype(access_lengths)>{}([&](auto access_idx) {
472 constexpr auto data_idx = access_idx * scalar_per_access; // access index -> element index of the transposed tile
473
474 constexpr auto data_idx_seq = generate_sequence_v2(
475 [&](auto i) { return Number<data_idx[i]>{}; }, Number<nDim>{});
476
479
480 // get DstScalarPerVector # of read-only references to src vectors from
481 // src_thread_scratch_
482 const auto src_vector_refs = generate_tie(
483 [&](auto i) -> const src_vector_t& {
484 // i increment corresponds to movement in DstVectorDim
485 return src_thread_scratch_tuple_[thread_scratch_id].GetVectorTypeReference(
486 data_idx_seq + i * dst_scalar_step_in_vector);
487 },
489
490 // get SrcScalarPerVector # of references to dst vectors from dst_thread_scratch_
491 auto dst_vector_refs = generate_tie(
492 [&](auto i) -> dst_vector_t& {
493 // i increment corresponds to movement in SrcVectorDim
494 return dst_thread_scratch_.GetVectorTypeReference(
495 data_idx_seq + i * src_scalar_step_in_vector);
496 },
498
499 // do data transpose
501 src_vector_refs, dst_vector_refs); // the transpose call itself begins on elided line 500
502 });
503 } // NOTE(review): whatever the transpose wrote is overwritten below, because the element-scale loop assigns dst_thread_scratch_(idx) for every idx in SliceLengths
504
505 // Do fast numeric convert (SrcData -> DstData, SrcScalarPerVector elements at a time) into src_converted_thread_scratch_
506 constexpr auto scalar_per_access = generate_sequence(
508 SrcScalarPerVector,
509 DstVectorDim,
510 DstScalarPerVector>{},
511 Number<nDim>{});
512
513 constexpr auto access_lengths = SliceLengths{} / scalar_per_access; // NOTE(review): uses the lcm-combined src/dst width, but each conversion below covers only SrcScalarPerVector elements -- confirm they agree for instantiated configs
514
516 using src_vector_t = typename src_vector_type::type; // src_vector_type alias is on elided line 515
517
518 using src_converted_vector_type = vector_type_maker_t<DstData, SrcScalarPerVector>;
519 using src_converted_vector_t = typename src_converted_vector_type::type;
520 // Vector-wise type convert
521 static_ford<decltype(access_lengths)>{}([&](auto access_idx) {
522 auto src_vector_container = src_vector_type{
523 src_thread_scratch_tuple_[thread_scratch_id].template GetAsType<src_vector_t>(
524 access_idx)}; // NOTE(review): indexed by access_idx, while RunRead() stores vectors at access index * per-dim vector width; these coincide only when every access length is 1 -- confirm
525
526 auto src_converted_vector_container =
527 src_converted_vector_type{fast_numeric_converter(src_vector_container)};
528
529 src_converted_thread_scratch_.template SetAsType<src_converted_vector_t>(
530 access_idx, // same access_idx-vs-data-index concern as the read above
531 src_converted_vector_container.template AsType<src_converted_vector_t>()[I0]);
532 });
533
534 // Element-scale operation, expect packed multiplication: dst = src_element_op_(converted * scale)
535 static_ford<SliceLengths>{}([&](auto idx) {
536 DstData dst_v;
537 constexpr auto scale_idx = Sequence<I0, idx.At(1), I0>{}; // NOTE(review): hard-codes a 3-D slice whose scale varies only along dim 1
538 // printf("Tid: %03d, scale: %04x\n", get_thread_local_1d_id(),
539 // *(reinterpret_cast<const uint16_t*>(&scale_thread_scratch_[scale_idx])));
540 src_element_op_(dst_v,
541 src_converted_thread_scratch_[idx] * scale_thread_scratch_[scale_idx]);
542 dst_thread_scratch_(idx) = dst_v;
543 });
544#endif
545 }
546
547 template <typename DstBuffer, index_t ThreadScratchId = 0> // RunWrite: applies dst_element_op_ and stores dst_thread_scratch_ to dst_buf at dst_coord_, one vector per access
548 __device__ void RunWrite(const DstDesc& dst_desc,
549 DstBuffer& dst_buf, // a further parameter line (550) is elided; presumably thread_scratch_id
551 {
552 // if there is transpose, it's done here (the call on elided line 554 presumably runs the src->dst thread-scratch transfer -- confirm)
553 // TODO move this elsewhere
555
556 static_assert(DstBuffer::GetAddressSpace() == AddressSpaceEnum::Global or
557 DstBuffer::GetAddressSpace() == AddressSpaceEnum::Lds,
558 "wrong!");
559
560 static_assert(
562 "wrong! SrcBuffer or DstBuffer data type is wrong");
563
564 // dst scalar per access on each dim
565 // TODO: don't use this
566 constexpr auto dst_scalar_per_access = generate_sequence( // generator arguments are on elided line 567
568
569 constexpr auto dst_access_lengths = SliceLengths{} / dst_scalar_per_access; // number of vector accesses along each dim
570
571 constexpr auto dst_dim_access_order = DstDimAccessOrder{};
572
573 constexpr auto ordered_dst_access_lengths =
574 container_reorder_given_new2old(dst_access_lengths, dst_dim_access_order);
575
576 // make forward steps
577 const auto dst_forward_steps = generate_tuple(
578 [&](auto i) {
579 Index forward_step_idx;
580
581 static_for<0, nDim, 1>{}([&](auto j) {
582 forward_step_idx(j) = (i.value == j.value) ? dst_scalar_per_access[i] : 0;
583 });
584
585 return make_tensor_coordinate_step(dst_desc, forward_step_idx);
586 },
587 Number<nDim>{});
588
589 // make backward steps
590 const auto dst_backward_steps = generate_tuple(
591 [&](auto i) {
592 Index backward_step_idx;
593
594 static_for<0, nDim, 1>{}([&](auto j) {
595 backward_step_idx(j) = (i.value == j.value) ? -dst_scalar_per_access[i] : 0;
596 });
597
598 return make_tensor_coordinate_step(dst_desc, backward_step_idx);
599 },
600 Number<nDim>{});
601
602 // loop over tensor and copy
603 static_ford<decltype(ordered_dst_access_lengths)>{}([&](auto ordered_dst_access_idx) {
604 // judge move forward or move backward
605 constexpr auto forward_sweep = [&]() { // zig-zag order: dim i runs forward when the mixed-radix index of the outer (ordered) dims is even
607
608 forward_sweep_(I0) = true;
609
610 static_for<1, nDim, 1>{}([&](auto i) {
611 index_t tmp = ordered_dst_access_idx[I0];
612
613 static_for<1, i, 1>{}([&](auto j) {
614 tmp = tmp * ordered_dst_access_lengths[j] + ordered_dst_access_idx[j];
615 });
616
617 forward_sweep_(i) = tmp % 2 == 0;
618 });
619
620 return forward_sweep_;
621 }();
622
623 // calculate dst data index
624 constexpr auto dst_data_idx = [&]() {
625 Index ordered_idx;
626
627 static_for<0, nDim, 1>{}([&](auto i) {
628 ordered_idx(i) = forward_sweep[i] ? ordered_dst_access_idx[i]
629 : ordered_dst_access_lengths[i] - 1 -
630 ordered_dst_access_idx[i];
631 });
632
633 return container_reorder_given_old2new(ordered_idx, dst_dim_access_order) *
634 dst_scalar_per_access;
635 }();
636
637 constexpr auto dst_data_idx_seq = generate_sequence_v2(
638 [&](auto i) { return Number<dst_data_idx[i]>{}; }, Number<dst_data_idx.Size()>{});
639
640 const bool is_dst_valid = // initializer on elided line 641
642
644 using dst_vector_t = typename dst_vector_type::type; // dst_vector_type alias is on elided line 643
645
646 // copy data from dst_thread_scratch_ into dst_vector_container
647 auto dst_vector_container = dst_vector_type{
648 dst_thread_scratch_.template GetAsType<dst_vector_t>(dst_data_idx_seq)};
649
650 static_for<0, DstScalarPerVector, 1>{}([&](auto i) { // dst_element_op_ is applied scalar by scalar before the vector store
651 DstData dst_v;
652
653 // apply DstElementwiseOperation
654 dst_element_op_(dst_v, dst_vector_container.template AsType<DstData>()[i]);
655
656 dst_vector_container.template AsType<DstData>()(i) = dst_v;
657 });
658
659 // copy data from dst_vector_container to dst_buf
660 dst_buf.template Set<dst_vector_t>(
661 dst_coord_.GetOffset(),
662 is_dst_valid,
663 dst_vector_container.template AsType<dst_vector_t>()[I0]);
664
665 constexpr auto move_on_dim = [&]() constexpr { // dim i moves only when it is not at its end and every later (ordered) dim is at its end
667
668 static_for<0, nDim, 1>{}([&](auto i) {
669 move_on_dim_(i) = ordered_dst_access_idx[i] < ordered_dst_access_lengths[i] - 1;
670
671 static_for<i + 1, nDim, 1>{}([&](auto j) {
672 move_on_dim_(i) &=
673 ordered_dst_access_idx[j] == ordered_dst_access_lengths[j] - 1;
674 });
675 });
676
677 return move_on_dim_;
678 }();
679
680 // move dst coord
681 static_for<0, nDim, 1>{}([&](auto i) {
682 if constexpr(move_on_dim[i])
683 {
684 if constexpr(forward_sweep[i])
685 {
687 dst_desc, dst_coord_, dst_forward_steps[dst_dim_access_order[i]]);
688 }
689 else
690 {
692 dst_desc, dst_coord_, dst_backward_steps[dst_dim_access_order[i]]);
693 }
694 }
695 });
696 });
697
698 // move dst coordinate back to slice origin (or not)
699 if constexpr(DstResetCoordinateAfterRun) // otherwise the pending reset is folded into MoveDstSliceWindow()
700 {
701 const auto dst_reset_step =
703
704 move_tensor_coordinate(dst_desc, dst_coord_, dst_reset_step);
705 }
706 }
707
708 __device__ static constexpr auto GetSrcCoordinateResetStep() // index delta that moves src_coord_ from where RunRead()'s last access leaves it back to the slice origin
709 {
710 // scalar per access on each dim
711 // TODO: don't use lambda_scalar_per_access
712 constexpr auto src_scalar_per_access = generate_sequence( // generator arguments are on elided line 713
714
715 constexpr auto src_access_lengths = SliceLengths{} / src_scalar_per_access;
716
717 constexpr auto src_dim_access_order = SrcDimAccessOrder{};
718
719 constexpr auto ordered_src_access_lengths =
720 container_reorder_given_new2old(src_access_lengths, src_dim_access_order);
721
722 // judge move forward or move backward during the last iteration
723 constexpr auto forward_sweep = [&]() { // same parity rule as RunRead(), evaluated at the final (all-last) access index
725
726 forward_sweep_(I0) = true;
727
728 static_for<1, nDim, 1>{}([&](auto i) {
729 index_t tmp = ordered_src_access_lengths[I0] - 1;
730
731 static_for<1, i, 1>{}([&](auto j) {
732 tmp = tmp * ordered_src_access_lengths[j] + ordered_src_access_lengths[j] - 1;
733 });
734
735 forward_sweep_(i) = tmp % 2 == 0;
736 });
737
738 return forward_sweep_;
739 }();
740
741 // calculate src data index after last iteration in RunRead(), if it has not been reset by
742 // RunRead()
743 constexpr auto src_data_idx = [&]() {
744 Index ordered_idx;
745
746 static_for<0, nDim, 1>{}([&](auto i) {
747 ordered_idx(i) = forward_sweep[i] ? ordered_src_access_lengths[i] - 1 : 0; // a forward sweep ends at the last access, a backward one at 0
748 });
749
750 return container_reorder_given_old2new(ordered_idx, src_dim_access_order) *
751 src_scalar_per_access;
752 }();
753
754 // the reset step is the negation of the final data index
755 constexpr auto reset_src_data_step = [&]() {
756 Index reset_src_data_step_;
757
758 static_for<0, nDim, 1>{}([&](auto i) { reset_src_data_step_(i) = -src_data_idx[i]; });
759
760 return reset_src_data_step_;
761 }();
762
763 return reset_src_data_step;
764 }
765
766 __device__ static constexpr auto GetDstCoordinateResetStep() // index delta that moves dst_coord_ from where RunWrite()'s last access leaves it back to the slice origin
767 {
768 // scalar per access on each dim
769 // TODO: don't use lambda_scalar_per_access
770 constexpr auto dst_scalar_per_access = generate_sequence( // generator arguments are on elided line 771
772
773 constexpr auto dst_access_lengths = SliceLengths{} / dst_scalar_per_access;
774
775 constexpr auto dst_dim_access_order = DstDimAccessOrder{};
776
777 constexpr auto ordered_dst_access_lengths =
778 container_reorder_given_new2old(dst_access_lengths, dst_dim_access_order);
779
780 // judge move forward or move backward during the last iteration
781 constexpr auto forward_sweep = [&]() { // same parity rule as RunWrite(), evaluated at the final (all-last) access index
783
784 forward_sweep_(I0) = true;
785
786 static_for<1, nDim, 1>{}([&](auto i) {
787 index_t tmp = ordered_dst_access_lengths[I0] - 1;
788
789 static_for<1, i, 1>{}([&](auto j) {
790 tmp = tmp * ordered_dst_access_lengths[j] + ordered_dst_access_lengths[j] - 1;
791 });
792
793 forward_sweep_(i) = tmp % 2 == 0;
794 });
795
796 return forward_sweep_;
797 }();
798
799 // calculate dst data index after last iteration in RunWrite(), if it has not been reset by
800 // RunWrite()
801 constexpr auto dst_data_idx = [&]() {
802 Index ordered_idx;
803
804 static_for<0, nDim, 1>{}([&](auto i) {
805 ordered_idx(i) = forward_sweep[i] ? ordered_dst_access_lengths[i] - 1 : 0; // a forward sweep ends at the last access, a backward one at 0
806 });
807
808 return container_reorder_given_old2new(ordered_idx, dst_dim_access_order) *
809 dst_scalar_per_access;
810 }();
811
812 // the reset step is the negation of the final data index
813 constexpr auto reset_dst_data_step = [&]() {
814 Index reset_dst_data_step_;
815
816 static_for<0, nDim, 1>{}([&](auto i) { reset_dst_data_step_(i) = -dst_data_idx[i]; });
817
818 return reset_dst_data_step_;
819 }();
820
821 return reset_dst_data_step;
822 }
823
824 // Shifts the src window by src_slice_origin_step_idx, which should be known at compile time for performance reasons
825 __device__ void MoveSrcSliceWindow(const SrcDesc& src_desc,
826 const Index& src_slice_origin_step_idx)
827 {
828 // if src coord was not reset by RunRead(), then need to adjust the step here
829 const auto adjusted_step_idx =
830 SrcResetCoordinateAfterRun ? src_slice_origin_step_idx
831 : src_slice_origin_step_idx + GetSrcCoordinateResetStep(); // fuse the pending reset into this single move
832
833 // is it OK to construct a new step every time?
834 const auto adjusted_step = make_tensor_coordinate_step(src_desc, adjusted_step_idx);
835
836 move_tensor_coordinate(src_desc, src_coord_, adjusted_step);
837 }
838
839 // Shifts the dst window by dst_slice_origin_step_idx, which should be known at compile time for performance reasons
840 __device__ void MoveDstSliceWindow(const DstDesc& dst_desc,
841 const Index& dst_slice_origin_step_idx)
842 {
843 // if dst coord was not reset by RunWrite(), then need to adjust the step here
844 const auto adjusted_step_idx =
845 DstResetCoordinateAfterRun ? dst_slice_origin_step_idx
846 : dst_slice_origin_step_idx + GetDstCoordinateResetStep(); // fuse the pending reset into this single move
847
848 // is it OK to construct a new step every time?
849 const auto adjusted_step = make_tensor_coordinate_step(dst_desc, adjusted_step_idx);
850
851 move_tensor_coordinate(dst_desc, dst_coord_, adjusted_step);
852 }
853
854 __device__ static constexpr auto GetSrcThreadScratchDescriptor() // layout of the src register scratch: an nDim view over a packed (access lengths..., vector length) buffer
855 {
856 constexpr auto src_scalar_per_access = generate_sequence( // generator arguments are on elided line 857
858
859 constexpr auto src_access_lengths = SliceLengths{} / src_scalar_per_access;
860
861 constexpr auto src_access_lengths_and_vector_length = container_push_back( // appended vector length is on elided line 862 (presumably SrcScalarPerVector -- confirm)
863
864 // 1st stage of transforms: packed, so the trailing vector-length dim is innermost and each vector is contiguous
865 constexpr auto desc0 =
866 make_naive_tensor_descriptor_packed(src_access_lengths_and_vector_length);
867
868 // 2nd stage of transforms
869 constexpr auto transforms = generate_tuple(
870 [&](auto i) {
871 if constexpr(i == SrcVectorDim)
872 { // combines (access index, in-vector index) into one SrcVectorDim index; the transform call starts on elided line 873 (likely make_merge_transform_v3_division_mod per the page's symbol list)
874 make_tuple(src_access_lengths_and_vector_length[i],
875 src_access_lengths_and_vector_length[Number<nDim>{}]));
876 }
877 else
878 {
879 return make_pass_through_transform(src_access_lengths_and_vector_length[i]);
880 }
881 },
882 Number<nDim>{});
883
884 constexpr auto low_dim_idss = generate_tuple( // SrcVectorDim consumes its own dim plus the extra vector dim nDim
885 [&](auto i) {
886 if constexpr(i == SrcVectorDim)
887 {
888 return Sequence<i.value, nDim>{};
889 }
890 else
891 {
892 return Sequence<i.value>{};
893 }
894 },
895 Number<nDim>{});
896
897 constexpr auto up_dim_idss =
898 generate_tuple([&](auto i) { return Sequence<i.value>{}; }, Number<nDim>{});
899
900 return transform_tensor_descriptor(desc0, transforms, low_dim_idss, up_dim_idss);
901 }
902
903 __device__ static constexpr auto GetScaleThreadScratchDescriptor() // layout of the scale register scratch; mirrors the src layout (SliceLengths, vectorized along SrcVectorDim)
904 {
905
906 constexpr auto scale_scalar_per_access = generate_sequence( // generator arguments are on elided line 907
908
909 constexpr auto scale_access_lengths = SliceLengths{} / scale_scalar_per_access; // NOTE(review): sized by SliceLengths, not ScaleSliceLengths -- confirm
910
911 constexpr auto scale_access_lengths_and_vector_length = container_push_back( // appended vector length is on elided line 912
913
914 // 1st stage of transforms: packed, so the trailing vector-length dim is innermost
915 constexpr auto desc0 =
916 make_naive_tensor_descriptor_packed(scale_access_lengths_and_vector_length);
917
918 // 2nd stage of transforms
919 constexpr auto transforms = generate_tuple(
920 [&](auto i) {
921 if constexpr(i == SrcVectorDim)
922 { // the transform call for the vector dim starts on elided line 923
924 make_tuple(scale_access_lengths_and_vector_length[i],
925 scale_access_lengths_and_vector_length[Number<nDim>{}]));
926 }
927 else
928 {
929 return make_pass_through_transform(scale_access_lengths_and_vector_length[i]);
930 }
931 },
932 Number<nDim>{});
933
934 constexpr auto low_dim_idss = generate_tuple(
935 [&](auto i) {
936 if constexpr(i == SrcVectorDim)
937 {
938 return Sequence<i.value, nDim>{};
939 }
940 else
941 {
942 return Sequence<i.value>{};
943 }
944 },
945 Number<nDim>{});
946
947 constexpr auto up_dim_idss =
948 generate_tuple([&](auto i) { return Sequence<i.value>{}; }, Number<nDim>{});
949
950 return transform_tensor_descriptor(desc0, transforms, low_dim_idss, up_dim_idss);
951 }
952
953 __device__ static constexpr auto GetDstThreadScratchDescriptor() // layout of the dst register scratch: like the src layout but vectorized along DstVectorDim
954 {
955 // 1st stage of transforms
956 constexpr auto dst_scalar_per_access = generate_sequence( // generator arguments are on elided line 957
958
959 constexpr auto dst_access_lengths = SliceLengths{} / dst_scalar_per_access;
960
961 constexpr auto dst_access_lengths_and_vector_length = container_push_back( // appended vector length is on elided line 962
963
964 constexpr auto desc0 =
965 make_naive_tensor_descriptor_packed(dst_access_lengths_and_vector_length); // packed: trailing vector-length dim is innermost
966
967 // 2nd stage of transforms
968 constexpr auto transforms = generate_tuple(
969 [&](auto i) {
970 if constexpr(i == DstVectorDim)
971 { // the transform call for the vector dim starts on elided line 972
973 make_tuple(dst_access_lengths_and_vector_length[i],
974 dst_access_lengths_and_vector_length[Number<nDim>{}]));
975 }
976 else
977 {
978 return make_pass_through_transform(dst_access_lengths_and_vector_length[i]);
979 }
980 },
981 Number<nDim>{});
982
983 constexpr auto low_dim_idss = generate_tuple(
984 [&](auto i) {
985 if constexpr(i == DstVectorDim)
986 {
987 return Sequence<i.value, nDim>{};
988 }
989 else
990 {
991 return Sequence<i.value>{};
992 }
993 },
994 Number<nDim>{});
995
996 constexpr auto up_dim_idss =
997 generate_tuple([&](auto i) { return Sequence<i.value>{}; }, Number<nDim>{});
998
999 return transform_tensor_descriptor(desc0, transforms, low_dim_idss, up_dim_idss);
1000 }
1001
1002 private:
1003 static constexpr auto src_thread_scratch_desc_ = decltype(GetSrcThreadScratchDescriptor()){};
1004 static constexpr auto scale_thread_scratch_desc_ =
1005 decltype(GetScaleThreadScratchDescriptor()){};
1006 static constexpr auto dst_thread_scratch_desc_ = decltype(GetDstThreadScratchDescriptor()){};
1007
1008 /*
1009 template <bool kLastDim>
1010 struct ScaleThreadScratchDesc{};
1011 */
1012
1013 // Registers (Vgpr), contain raw SrcData loaded by RunRead() from a global or LDS buffer
1014 using SrcThreadScratch = StaticTensorTupleOfVectorBuffer<AddressSpaceEnum::Vgpr,
1015 SrcData,
1016 SrcScalarPerVector,
1017 decltype(src_thread_scratch_desc_),
1018 true>;
1019
1020 // Registers, contain fast converted data (DstData, same layout as the src scratch)
1021 using SrcThreadConvertedScratch =
1022 StaticTensorTupleOfVectorBuffer<AddressSpaceEnum::Vgpr,
1023 DstData,
1024 SrcScalarPerVector,
1025 decltype(src_thread_scratch_desc_),
1026 true>;
1027
1028 // Registers, contain scale data loaded by RunScaleRead()
1029 using ScaleThreadScratch = StaticTensorTupleOfVectorBuffer<AddressSpaceEnum::Vgpr,
1030 ScaleData,
1031 ScaleScalarPerVector,
1032 decltype(scale_thread_scratch_desc_),
1033 true>;
1034
1035 // Registers, contain dequantized data (only type-converted when the experimental sub-dword transpose path is disabled)
1036 using DstThreadScratch = StaticTensorTupleOfVectorBuffer<AddressSpaceEnum::Vgpr,
1037 DstData,
1038 DstScalarPerVector,
1039 decltype(dst_thread_scratch_desc_),
1040 true>;
1041
1042 using FastTypeConverter = tensor_operation::element_wise:: // vector-wise SrcData -> DstData converter, SrcScalarPerVector elements per call
1043 FastNumericArrayConverter<SrcData, DstData, SrcScalarPerVector>;
1044
1046 SrcThreadConvertedScratch src_converted_thread_scratch_; // the src_thread_scratch_tuple_ declaration (line 1045) is elided; presumably NumThreadScratch SrcThreadScratch buffers
1047 ScaleThreadScratch scale_thread_scratch_;
1048
1049 DstThreadScratch dst_thread_scratch_;
1050 FastTypeConverter fast_numeric_converter;
1051
1052 SrcCoord src_coord_;
1053 ScaleCoord scale_coord_;
1054 DstCoord dst_coord_;
1055 const SrcElementwiseOperation src_element_op_; // applied to (converted * scale) in the dequant path
1056 const ScaleElementwiseOperation scale_element_op_; // NOTE(review): never applied anywhere in this listing -- confirm
1057 const DstElementwiseOperation dst_element_op_; // applied per scalar in RunWrite() before storing
1058};
1059
1060} // namespace ck
Definition threadwise_tensor_slice_transfer_util.hpp:15
__host__ __device__ constexpr auto lcm(X x, Y y)
Definition utility/math.hpp:198
Definition ck.hpp:268
__host__ __device__ constexpr auto make_pass_through_transform(const LowLength &low_length)
Definition multi_index_transform_helper.hpp:12
typename detail::StaticallyIndexedArrayImpl< T, N >::type StaticallyIndexedArray
Definition utility/statically_indexed_array.hpp:45
int32_t index_t
Definition ck.hpp:299
__host__ __device__ constexpr auto make_tensor_coordinate_step(const TensorDesc &, const VisibleIndex &idx_diff_visible, UpdateLowerIndexHack)
Definition tensor_description/tensor_descriptor.hpp:444
__host__ __device__ constexpr void move_tensor_coordinate(const TensorDesc &tensor_desc, TensorCoord &coord, const TensorCoordStep &coord_step)
Definition tensor_description/tensor_descriptor.hpp:508
__host__ __device__ constexpr auto container_push_back(const Array< TData, NSize > &a, const TData &x)
Definition utility/container_helper.hpp:18
InMemoryDataOperationEnum
Definition ck.hpp:277
@ Set
Definition ck.hpp:278
remove_cv_t< remove_reference_t< T > > remove_cvref_t
Definition type.hpp:297
__host__ __device__ constexpr bool coordinate_has_valid_offset_assuming_visible_index_is_valid(const TensorDesc &tensor_desc, const TensorCoord &coord)
Definition tensor_description/tensor_descriptor.hpp:560
integral_constant< index_t, N > Number
Definition number.hpp:12
@ Lds
Definition amd_address_space.hpp:18
@ Global
Definition amd_address_space.hpp:17
@ Vgpr
Definition amd_address_space.hpp:20
__host__ __device__ constexpr auto sequence_to_tuple_of_number(Sequence< Is... >)
Definition utility/container_helper.hpp:380
__host__ __device__ constexpr Y type_convert(X x)
Definition utility/type_convert.hpp:98
__host__ __device__ constexpr auto generate_sequence_v2(F &&f, Number< N >)
Definition sequence_helper.hpp:25
__host__ __device__ constexpr auto container_reorder_given_old2new(const Array< TData, NSize > &old_array, Sequence< IRs... > old2new)
Definition utility/container_helper.hpp:54
__host__ __device__ constexpr auto make_merge_transform_v3_division_mod(const LowLengths &low_lengths)
Definition multi_index_transform_helper.hpp:84
__host__ __device__ constexpr auto generate_tuple(F &&f, Number< N >)
Definition tuple_helper.hpp:21
__host__ __device__ constexpr auto make_naive_tensor_descriptor_packed(const Tuple< Lengths... > &lengths)
Definition tensor_descriptor_helper.hpp:101
__host__ __device__ constexpr auto make_tuple(Xs &&... xs)
Definition utility/tuple.hpp:211
__host__ __device__ constexpr auto generate_sequence(F, Number< N >)
Definition sequence_helper.hpp:18
__host__ __device__ constexpr auto transform_tensor_descriptor(const OldTensorDescriptor &old_tensor_desc, const NewTransforms &new_transforms, NewLowerDimensionOldVisibleIdss, NewUpperDimensionNewVisibleIdss)
Definition tensor_description/tensor_descriptor.hpp:319
__host__ __device__ constexpr auto make_tensor_coordinate(const TensorDesc &tensor_desc, const VisibleIndex &idx_visible)
Definition tensor_description/tensor_descriptor.hpp:407
__host__ __device__ constexpr auto container_reorder_given_new2old(const Array< TData, NSize > &old_array, Sequence< IRs... >)
Definition utility/container_helper.hpp:43
__host__ __device__ constexpr auto generate_tie(F &&f, Number< N >)
Definition tuple_helper.hpp:34
Array< index_t, N > MultiIndex
Definition array_multi_index.hpp:12
typename vector_type_maker< T, N >::type vector_type_maker_t
Definition dtype_vector.hpp:54
const GenericPointer< typename T::ValueType > T2 value
Definition pointer.h:1697
Definition utility/sequence.hpp:43
static __device__ constexpr auto GetScaleThreadScratchDescriptor()
Definition threadwise_tensor_slice_transfer_v3r1_dequant.hpp:903
__device__ void SetDstSliceOrigin(const DstDesc &dst_desc, const Index &dst_slice_origin_idx)
Definition threadwise_tensor_slice_transfer_v3r1_dequant.hpp:123
static __device__ constexpr auto GetDstThreadScratchDescriptor()
Definition threadwise_tensor_slice_transfer_v3r1_dequant.hpp:953
__device__ void MoveSrcSliceWindow(const SrcDesc &src_desc, const Index &src_slice_origin_step_idx)
Definition threadwise_tensor_slice_transfer_v3r1_dequant.hpp:825
__device__ void RunWrite(const DstDesc &dst_desc, DstBuffer &dst_buf, Number< ThreadScratchId > thread_scratch_id=Number< ThreadScratchId >{})
Definition threadwise_tensor_slice_transfer_v3r1_dequant.hpp:548
__device__ void RunRead(const SrcDesc &src_desc, const SrcBuffer &src_buf, Number< ThreadScratchId > thread_scratch_id=Number< ThreadScratchId >{})
Definition threadwise_tensor_slice_transfer_v3r1_dequant.hpp:129
__device__ void SetScaleSliceOrigin(const ScaleDesc &scale_desc, const Index &scale_slice_origin_idx)
Definition threadwise_tensor_slice_transfer_v3r1_dequant.hpp:117
__device__ constexpr ThreadwiseTensorSliceTransfer_v3r1_dequant(const SrcDesc &src_desc, const Index &src_slice_origin, const SrcElementwiseOperation &src_element_op, const ScaleDesc &scale_desc, const Index &scale_slice_origin, const ScaleElementwiseOperation &scale_element_op, const DstDesc &dst_desc, const Index &dst_slice_origin, const DstElementwiseOperation &dst_element_op)
Definition threadwise_tensor_slice_transfer_v3r1_dequant.hpp:93
__device__ void SetSrcSliceOrigin(const SrcDesc &src_desc, const Index &src_slice_origin_idx)
Definition threadwise_tensor_slice_transfer_v3r1_dequant.hpp:112
static __device__ constexpr auto GetSrcThreadScratchDescriptor()
Definition threadwise_tensor_slice_transfer_v3r1_dequant.hpp:854
__device__ void MoveDstSliceWindow(const DstDesc &dst_desc, const Index &dst_slice_origin_step_idx)
Definition threadwise_tensor_slice_transfer_v3r1_dequant.hpp:840
__device__ void RunScaleRead(const ScaleDesc &scale_desc, const ScaleBuffer &scale_buf)
Definition threadwise_tensor_slice_transfer_v3r1_dequant.hpp:276
static __device__ constexpr auto GetSrcCoordinateResetStep()
Definition threadwise_tensor_slice_transfer_v3r1_dequant.hpp:708
static __device__ constexpr auto GetDstCoordinateResetStep()
Definition threadwise_tensor_slice_transfer_v3r1_dequant.hpp:766
__device__ void TransferDataFromSrcThreadScratchToDstThreadScratch(Number< ThreadScratchId > thread_scratch_id)
Definition threadwise_tensor_slice_transfer_v3r1_dequant.hpp:427
Definition threadwise_tensor_slice_transfer_v3r1_dequant.hpp:23
__host__ __device__ constexpr auto operator()(index_t i) const
Definition threadwise_tensor_slice_transfer_v3r1_dequant.hpp:24
Definition threadwise_tensor_slice_transfer_util.hpp:20
Definition threadwise_tensor_slice_transfer_util.hpp:29
Definition type.hpp:177
Definition functional2.hpp:33
Definition functional3.hpp:97
Definition utility/transpose_vectors.hpp:16