device_elementwise_normalization_impl.hpp Source File

device_elementwise_normalization_impl.hpp Source File

Composable Kernel: device_elementwise_normalization_impl.hpp Source File
device_elementwise_normalization_impl.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
6#include <iostream>
7#include <sstream>
8
9#include "ck/utility/math.hpp"
12
20
21// X = Elementwise(input1, input2, input3, ...)
22// Y = Normalization(X, beta, gamma)
23namespace ck {
// Kernel entry point for the fused elementwise + normalization op. It only declares the
// dynamic LDS buffer for X and forwards every argument, unchanged, to
// GridwiseElementwiseReduction::Run, which does the actual work.
24template <typename GridwiseElementwiseReduction,
25 typename InDataTypePointerTuple, // Datatype tuple of inputs
26 typename XDataType, // Datatype of X
27 typename GammaDataType, // Datatype of Gamma
28 typename BetaDataType, // Datatype of Beta
29 typename YDataType, // Datatype of Y
30 typename AccDataType, // Accumulation datatype (also the type of epsilon)
31 typename XElementwiseOperation, // Elementwise operation combining the inputs into X
32 typename YElementwiseOperation, // Operation applied to the output of normalization
33 typename InGrid2dDescTuple, // Descriptor tuple of inputs
34 typename GridDesc_M_K> // Descriptor type shared by X, gamma, beta and Y
// NOTE(review): original line 35 (the `__global__ void kernel_elementwise_layernorm(`
// declarator) is missing from this listing.
36 const InGrid2dDescTuple in_grid_2d_desc_tuple, // Descriptor tuple of inputs
37 const GridDesc_M_K x_grid_desc_m_k, // Descriptor of X
38 const GridDesc_M_K gamma_grid_desc_m_k, // Descriptor of gamma
39 const GridDesc_M_K beta_grid_desc_m_k, // Descriptor of beta
40 const GridDesc_M_K y_grid_desc_m_k, // Descriptor of Y
41 index_t num_k_block_tile_iteration, // Number of block tiles to iterate along K
42 AccDataType epsilon, // Epsilon value used by the normalization
43 const InDataTypePointerTuple p_in_global_tuple, // Ptr tuple of input matrices
44 const GammaDataType* const __restrict__ p_gamma_global, // Ptr of gamma
45 const BetaDataType* const __restrict__ p_beta_global, // Ptr of beta
46 YDataType* const __restrict__ p_y_global, // Ptr of y
47 const XElementwiseOperation x_elementwise_op, // Operation combining the inputs into X
48 const YElementwiseOperation y_elementwise_op) // Operation applied to the output of normalization
49{
 // Unsized extern __shared__ array: dynamically sized LDS whose byte count is supplied
 // at kernel launch. It is handed to Run as the storage for X.
50 extern __shared__ XDataType p_x_lds[];
51 GridwiseElementwiseReduction::Run(in_grid_2d_desc_tuple, // Descriptor tuple of inputs
52 x_grid_desc_m_k, // Descriptor of X
53 gamma_grid_desc_m_k, // Descriptor of Gamma
54 beta_grid_desc_m_k, // Descriptor of Beta
55 y_grid_desc_m_k, // Descriptor of Y
56 num_k_block_tile_iteration, // Number of block tiles along K
57 epsilon, // epsilon
58 p_in_global_tuple, // Ptr tuple of inputs
59 p_x_lds, // Ptr of X
60 p_gamma_global, // Ptr of gamma
61 p_beta_global, // Ptr of beta
62 p_y_global, // Ptr of Y
63 x_elementwise_op, // Operation of input
64 y_elementwise_op); // Operation of output of normalization
65};
66} // namespace ck
67
68namespace ck {
69namespace tensor_operation {
70namespace device {
71
72// Y = YElementwiseOperation(Normalization(XElementwiseOperation(inputs...), Gamma, Beta)), e.g. Y = LayerNorm(A + B, Gamma, Beta)
// Device op fusing an elementwise combine of NumInput inputs with a normalization over the
// dimensions listed in reduceDims (given when the argument is created).
73template <typename InDataTypeTuple, // Datatype of inputs
74 typename GammaDataType, // Datatype of gamma
75 typename BetaDataType, // Datatype of beta
76 typename AccDataType, // Accumulation datatype (also the type of epsilon)
77 typename YDataType, // Datatype of Y (X uses the same type, see XDataType)
78 typename XElementwiseOperation, // Combines the inputs into X
79 typename YElementwiseOperation, // Applied to the normalized output
80 index_t Rank, // Number of dimensions of the tensors
81 index_t NumReduceDim, // Number of dimensions normalized over
82 index_t BlockSize, // Threads per workgroup
83 index_t MThreadClusterSize, // Num of threads in a block on M direction
84 index_t KThreadClusterSize, // Num of threads in a block on K direction
85 index_t MThreadSliceSize, // Number of rows (M) each thread computes
86 index_t KThreadSliceSize, // Number of columns (K) each thread computes
87 index_t XYSrcVectorDim, // Dimension X/Y are vectorized along: 0 = innermost invariant (M) dim, otherwise innermost reduced (K) dim
88 index_t XSrcVectorSize, // Size to fetch source x
89 index_t GammaSrcVectorDim, // Vector dimension for reading gamma
90 index_t GammaSrcVectorSize, // Size to fetch source gamma
91 index_t BetaSrcVectorDim, // Vector dimension for reading beta
92 index_t BetaSrcVectorSize, // Size to fetch source beta
93 index_t YDstVectorSize> // Size to write destination Y
// NOTE(review): original line 94 (the class name line) is missing from this listing.
95 : public DeviceElementwiseNormalization<InDataTypeTuple,
96 GammaDataType,
97 BetaDataType,
98 AccDataType,
99 YDataType,
100 XElementwiseOperation,
101 YElementwiseOperation,
102 Rank,
103 NumReduceDim>
104{
105 static constexpr int NumInput = InDataTypeTuple::Size();
106
 // The intermediate X is stored with the same element type as the output Y.
107 using XDataType = YDataType;
108
 // Each thread's K slice must split into whole gamma/beta vectors.
109 static_assert(
110 (KThreadSliceSize % GammaSrcVectorSize == 0),
111 "Invalid thread slice sizes and/or gamma vector sizes configuration, please check!");
112
113 static_assert(
114 (KThreadSliceSize % BetaSrcVectorSize == 0),
115 "Invalid thread slice sizes and/or beta vector sizes configuration, please check!");
116
117 static constexpr index_t M_BlockTileSize =
118 MThreadClusterSize * MThreadSliceSize; // num of rows calculated in a block
119 static constexpr index_t K_BlockTileSize =
120 KThreadClusterSize * KThreadSliceSize; // num of columns calculated in a block
121
 // Body of GenerateInDataTypePointerTuple() (its declaration line is not shown): builds a
 // tuple of null `const T*`, one per entry of InDataTypeTuple. It is used via decltype to
 // define InDataTypePointerTuple.
123 {
124 return generate_tuple(
125 [&](auto I) {
126 using DataType = remove_cvref_t<decltype(InDataTypeTuple{}[I])>;
127 return static_cast<const DataType*>(nullptr);
128 },
130 };
131
133
// Builds the padded 2D (M, K) descriptor of a Rank-D tensor given by inLengths/inStrides.
// The first NumInvariantDim dims are merged into M and the ReduceDims dims (declared on a
// line not shown here) into K. When NumInvariantDim == 0, all dims are merged and viewed
// as (1, total length). M is then right-padded to a multiple of M_BlockTileSize, and K to
// K_BlockTileSize * numBlockTileIteration * blkGroupSize.
134 static auto MakeSrc2dDescriptor(const std::vector<index_t>& inLengths,
135 const std::vector<index_t>& inStrides,
136 int blkGroupSize,
137 int numBlockTileIteration)
138 {
139 constexpr index_t NumInvariantDim = Rank - NumReduceDim;
140 static constexpr index_t numSrcDim = Rank;
141 static constexpr bool reduceAllDim = (NumInvariantDim == 0);
142
143 const auto tupleSrcLengths = make_tuple_from_array(inLengths, Number<numSrcDim>{});
144 const auto tupleSrcStrides = make_tuple_from_array(inStrides, Number<numSrcDim>{});
145
146 const auto inDesc = make_naive_tensor_descriptor(tupleSrcLengths, tupleSrcStrides);
147
 // Unpadded 2D (M, K) view of inDesc.
148 const auto in_grid_desc_m_k = [&]() {
149 if constexpr(reduceAllDim)
150 {
 // No invariant dims: merge everything into one dim, then split it as (1, length).
151 const auto one_dim_inDesc = transform_tensor_descriptor(
152 inDesc,
153 make_tuple(make_merge_transform(tupleSrcLengths)),
156
157 return transform_tensor_descriptor(one_dim_inDesc,
159 1, one_dim_inDesc.GetLength(Number<0>{})))),
162 }
163 else
164 {
165 using InvariantDims = typename arithmetic_sequence_gen<0, NumInvariantDim, 1>::type;
167
168 const auto reduceDimLengths =
169 make_tuple_from_array_and_index_seq(inLengths, ReduceDims{});
170 const auto invariantDimLengths =
171 make_tuple_from_array_and_index_seq(inLengths, InvariantDims{});
172
 // Merge InvariantDims into M and ReduceDims into K.
174 inDesc,
175 make_tuple(make_merge_transform(invariantDimLengths),
176 make_merge_transform(reduceDimLengths)),
177 make_tuple(InvariantDims{}, ReduceDims{}),
179 }
180 }();
181
182 const auto invariantLength = in_grid_desc_m_k.GetLength(Number<0>{});
183 const auto reduceLength = in_grid_desc_m_k.GetLength(Number<1>{});
184
 // Right pads up to the tile boundaries. inPad_K is only non-negative if the caller passes
 // enough tiles (numBlockTileIteration * blkGroupSize) to cover reduceLength.
185 const int reduceSizePerBlock = K_BlockTileSize * numBlockTileIteration;
186 const auto inPad_M =
187 math::integer_least_multiple(invariantLength, M_BlockTileSize) - invariantLength;
188 const auto inPad_K = reduceSizePerBlock * blkGroupSize - reduceLength;
189
190 auto in_grid_desc_m_k_padded = transform_tensor_descriptor(
191 in_grid_desc_m_k,
192 make_tuple(make_right_pad_transform(invariantLength, inPad_M),
193 make_right_pad_transform(reduceLength, inPad_K)),
196
197 return (in_grid_desc_m_k_padded);
198 };
199
// GenerateSrcGrid2dDescTuple (declaration line not shown): a tuple of TupleSize placeholder
// 2D descriptors. It is used via decltype to define InGrid2dDescTuple, so the dummy {1}/{1}
// lengths only drive type deduction.
200 template <index_t TupleSize>
202 {
203 return generate_tuple([&](auto) { return MakeSrc2dDescriptor({1}, {1}, 1, 1); },
205 };
206
208
// Type of every 2D (M, K) descriptor produced by MakeSrc2dDescriptor; the dummy arguments
// only drive type deduction inside decltype.
209 using GridDesc_M_K = decltype(MakeSrc2dDescriptor({1}, {1}, 1, 1));
210
// GridwiseReduceLayernormGeneric (opening lines not shown): gridwise kernel instantiated with
// a trailing `false`. Presumably used when the K length does not fit in one block tile
// (sweep_once_ == false); confirm.
213 XDataType,
214 GammaDataType,
215 BetaDataType,
216 YDataType,
217 AccDataType,
218 XElementwiseOperation,
219 YElementwiseOperation,
222 BlockSize,
223 MThreadClusterSize,
224 KThreadClusterSize,
225 MThreadSliceSize,
226 KThreadSliceSize,
227 XYSrcVectorDim,
228 XSrcVectorSize,
229 GammaSrcVectorDim,
230 GammaSrcVectorSize,
231 BetaSrcVectorDim,
232 BetaSrcVectorSize,
233 XYSrcVectorDim, // presumably Y's destination vector dim (same as X's source vector dim)
234 YDstVectorSize,
235 false>;
236
// GridwiseReduceLayernormSweepOnce (opening lines not shown): the same gridwise kernel with a
// trailing `true`. Presumably used when the whole K length fits in one block tile; confirm.
239 XDataType,
240 GammaDataType,
241 BetaDataType,
242 YDataType,
243 AccDataType,
244 XElementwiseOperation,
245 YElementwiseOperation,
248 BlockSize,
249 MThreadClusterSize,
250 KThreadClusterSize,
251 MThreadSliceSize,
252 KThreadSliceSize,
253 XYSrcVectorDim,
254 XSrcVectorSize,
255 GammaSrcVectorDim,
256 GammaSrcVectorSize,
257 BetaSrcVectorDim,
258 BetaSrcVectorSize,
259 XYSrcVectorDim, // presumably Y's destination vector dim (same as X's source vector dim)
260 YDstVectorSize,
261 true>;
262
// Host-side argument for this op. It holds:
// - typed device pointers, with input strides reordered by shuffle_tensor_dimensions;
// - the 2D grid descriptors;
// - launch parameters: grid size and dynamic LDS bytes for X.
263 struct Argument : public BaseArgument
264 {
265 Argument(const std::vector<index_t> lengths,
266 const std::array<std::vector<index_t>, NumInput> inStridesArray,
267 const std::vector<index_t> gammaStrides,
268 const std::vector<index_t> betaStrides,
269 const std::vector<index_t> yStrides,
270 const std::vector<index_t> reduceDims,
271 XElementwiseOperation x_elementwise_op,
272 YElementwiseOperation y_elementwise_op,
273 double epsilon,
274 const std::array<const void*, NumInput> in_dev_buffers,
275 const GammaDataType* p_gamma,
276 const BetaDataType* p_beta,
277 YDataType* p_y)
278 : p_gamma_(p_gamma),
279 p_beta_(p_beta),
280 p_y_(p_y),
281 x_elementwise_op_(x_elementwise_op),
282 y_elementwise_op_(y_elementwise_op)
283 {
 // epsilon arrives as double and is stored in the accumulation type.
284 epsilon_ = static_cast<AccDataType>(epsilon);
285
 // Reorder each input's strides according to reduceDims (assumed: invariant dims first,
 // reduced dims last, matching the indexing used by IsSupportedArgument).
287 for(int i = 0; i < NumInput; i++)
288 {
289 inStridesArray_[i] =
290 shuffle_tensor_dimensions<Rank, NumReduceDim>(inStridesArray[i], reduceDims);
291 }
292
295
298
 // Cast each type-erased input buffer to its element type from InDataTypeTuple.
300 [&](auto I) {
301 using DataType = remove_cvref_t<decltype(InDataTypeTuple{}[I])>;
302 return static_cast<const DataType*>(in_dev_buffers[I.value]);
303 },
305
306 long_index_t invariant_total_length;
307 long_index_t reduce_total_length;
308
309 std::tie(invariant_total_length, reduce_total_length) =
311
 // One block group covers all of K: iterate ceil(reduce_total_length / K_BlockTileSize) tiles.
312 blkGroupSize_ = 1;
313 numBlockTileIteration_ = (reduce_total_length + K_BlockTileSize - 1) / K_BlockTileSize;
314
 // M rounded up to whole M_BlockTileSize tiles (divisor on a line not shown; presumably M_BlockTileSize).
315 gridSize_ = math::integer_least_multiple(invariant_total_length, M_BlockTileSize) /
317
319 [&](auto I) {
320 return MakeSrc2dDescriptor(
322 },
324
327
330
333
336
 // sweep_once_ (assignment begins on a line not shown): the padded K length of X fits in one K block tile.
338 x_grid_desc_m_k_.GetLength(Number<1>{}) <= KThreadClusterSize * KThreadSliceSize;
339
340 if(!sweep_once_) // Not a single sweep: reserve LDS bytes for an M_BlockTileSize x K tile of X
341 // so intermediate results can be stored there
342 {
 // NOTE(review): M_BlockTileSize * reduce_total_length is a long_index_t narrowed to int,
 // and the byte count is narrowed to int again; very large K can overflow x_lds_size_.
 // Consider computing in 64 bits and rejecting sizes that do not fit.
343 int block_TileSize = M_BlockTileSize * reduce_total_length;
344 x_lds_size_ = block_TileSize * sizeof(XDataType);
345 }
346 else
347 x_lds_size_ = 0;
348 }
349
 // epsilon converted once to the accumulation type.
350 AccDataType epsilon_;
351
353 const GammaDataType* p_gamma_;
354 const BetaDataType* p_beta_;
355 YDataType* p_y_;
356
 // Host copies of lengths/strides (inStridesArray_ is reordered in the constructor).
357 std::vector<index_t> Lengths_;
358 std::array<std::vector<index_t>, NumInput> inStridesArray_;
359 std::vector<index_t> xStrides_;
360 std::vector<index_t> gammaStrides_;
361 std::vector<index_t> betaStrides_;
362 std::vector<index_t> yStrides_;
363
364 XElementwiseOperation x_elementwise_op_;
365 YElementwiseOperation y_elementwise_op_;
366
369 size_t gridSize_;
370
378 };
379
// Launches the gridwise kernel for an Argument and returns the time reported by
// launch_and_time_kernel.
380 struct Invoker : public BaseInvoker
381 {
382 float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
383 {
 // kernel_main is one of two kernel_elementwise_layernorm instantiations; the selecting
 // lines are not shown. Presumably it is GridwiseReduceLayernormSweepOnce when
 // arg.sweep_once_ is true and GridwiseReduceLayernormGeneric otherwise; confirm.
384 const auto kernel_main =
387 XDataType,
388 GammaDataType,
389 BetaDataType,
390 YDataType,
391 AccDataType,
392 XElementwiseOperation,
393 YElementwiseOperation,
398 XDataType,
399 GammaDataType,
400 BetaDataType,
401 YDataType,
402 AccDataType,
403 XElementwiseOperation,
404 YElementwiseOperation,
407
 // arg.gridSize_ workgroups of BlockSize threads. arg.x_lds_size_ is passed as the dynamic
 // LDS byte count backing the kernel's extern __shared__ X buffer.
408 float avg_time = 0;
409 avg_time += launch_and_time_kernel(stream_config,
410 kernel_main,
411 dim3(arg.gridSize_),
412 dim3(BlockSize),
413 arg.x_lds_size_,
420 arg.epsilon_,
421 arg.in_dev_buffers_,
422 arg.p_gamma_,
423 arg.p_beta_,
424 arg.p_y_,
427
428 return (avg_time);
429 };
430
 // Type-erased entry point. NOTE(review): the dynamic_cast result is dereferenced without a
 // check, so passing a BaseArgument created by another op is undefined behavior.
431 float Run(const BaseArgument* p_arg,
432 const StreamConfig& stream_config = StreamConfig{}) override
433 {
434 return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
435 };
436 };
437
438 bool IsSupportedArgument(const BaseArgument* p_arg) override
439 {
440 const Argument* p_arg_ = dynamic_cast<const Argument*>(p_arg);
441
442 constexpr index_t NumInvariantDim = Rank - NumReduceDim;
443
444 if constexpr(XYSrcVectorDim == 0)
445 {
446 if constexpr(NumInvariantDim == 0)
447 {
448 return false;
449 }
450 else
451 {
452 for(int i = 0; i < NumInput; i++)
453 {
454 if(p_arg_->inStridesArray_[i][NumInvariantDim - 1] != 1)
455 return false;
456 }
457
458 if(p_arg_->inStridesArray_[0][NumInvariantDim - 1] != 1 &&
459 p_arg_->inStridesArray_[1][NumInvariantDim - 1] != 1)
460 return false;
461
462 if(p_arg_->invariant_lowest_length % XSrcVectorSize != 0)
463 return false;
464 };
465 }
466 else
467 {
468 for(int i = 0; i < NumInput; i++)
469 {
470 if(p_arg_->inStridesArray_[i][Rank - 1] != 1)
471 return false;
472 }
473
474 if(p_arg_->Lengths_[Rank - 1] % XSrcVectorSize != 0)
475 return false;
476 };
477
478 if(p_arg_->Lengths_[Rank - 1] % YDstVectorSize != 0)
479 {
480 return false;
481 }
482
483 auto IsScalarPerVectorValid = [](bool isLastDimensionCoalesced, int scalarPerVector) {
484 bool ret = true;
485
486 if(!isLastDimensionCoalesced)
487 ret = scalarPerVector == 1;
488 else
489 ret = KThreadSliceSize % scalarPerVector == 0;
490
491 return ret;
492 };
493
494 if(!IsScalarPerVectorValid(p_arg_->gammaStrides_.back() == 1, GammaSrcVectorSize))
495 return false;
496
497 if(!IsScalarPerVectorValid(p_arg_->betaStrides_.back() == 1, BetaSrcVectorSize))
498 return false;
499
500 // if fastest dim is not reduced
501 if constexpr(XYSrcVectorDim == 0) //
502 {
503 if(p_arg_->gammaStrides_[NumInvariantDim - 1] != 1)
504 return (false);
505
506 if(p_arg_->Lengths_[Rank - 1] % GammaSrcVectorSize != 0)
507 return (false);
508 }
509 else // if fastest dim is reduced
510 {
511 if(p_arg_->gammaStrides_[Rank - 1] != 1)
512 return (false);
513
514 if(p_arg_->Lengths_[Rank - 1] % GammaSrcVectorSize != 0)
515 return (false);
516 }
517
518 // if fastest dim is not reduced
519 if constexpr(XYSrcVectorDim == 0)
520 {
521 if(p_arg_->betaStrides_[NumInvariantDim - 1] != 1)
522 return (false);
523
524 if(p_arg_->invariant_lowest_length % BetaSrcVectorSize != 0)
525 return (false);
526 }
527 else // if fastest dim is reduced
528 {
529 if(p_arg_->betaStrides_[Rank - 1] != 1)
530 return (false);
531
532 if(p_arg_->Lengths_[Rank - 1] % BetaSrcVectorSize != 0)
533 return (false);
534 }
535
536 if(p_arg_->x_lds_size_ >= 65536)
537 {
538 return (false);
539 }
540
541 return true;
542 };
543
544 std::unique_ptr<BaseArgument>
545 MakeArgumentPointer(const std::vector<index_t> lengths,
546 const std::array<std::vector<index_t>, NumInput> inStridesArray,
547 const std::vector<index_t> gammaStrides,
548 const std::vector<index_t> betaStrides,
549 const std::vector<index_t> yStrides,
550 const std::vector<index_t> reduceDims,
551 double epsilon,
552 const std::array<const void*, NumInput> in_dev_buffers,
553 const void* p_gamma,
554 const void* p_beta,
555 void* p_y,
556 XElementwiseOperation x_elementwise_op,
557 YElementwiseOperation y_elementwise_op) override
558 {
559 return std::make_unique<Argument>(lengths,
560 inStridesArray,
561 gammaStrides,
562 betaStrides,
563 yStrides,
564 reduceDims,
565 x_elementwise_op,
566 y_elementwise_op,
567 epsilon,
568 in_dev_buffers,
569 static_cast<const GammaDataType*>(p_gamma),
570 static_cast<const BetaDataType*>(p_beta),
571 static_cast<YDataType*>(p_y));
572 };
573
574 std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
575 {
576 return std::make_unique<Invoker>();
577 };
578
579 std::string GetTypeString() const override
580 {
581 auto str = std::stringstream();
582
583 // clang-format off
584 str << "DeviceElementwiseNormalizationImpl<" << BlockSize << ",";
585 str << "M_C" << MThreadClusterSize << "_S" << MThreadSliceSize << ",";
586 str << "K_C" << KThreadClusterSize << "_S" << KThreadSliceSize << ",";
587 str << "XYSrcVectorDim_" << XYSrcVectorDim << ",";
588 str << "VectorSize_X" << XSrcVectorSize << "_Gamma" << GammaSrcVectorSize << "_Beta" << BetaSrcVectorSize << "_Y" << YDstVectorSize << ">";
589 // clang-format on
590
591 return str.str();
592 }
593};
594
595} // namespace device
596} // namespace tensor_operation
597} // namespace ck
float launch_and_time_kernel(const StreamConfig &stream_config, F kernel, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, Args... args)
Definition host_utility/kernel_launch.hpp:14
__host__ __device__ constexpr auto integer_least_multiple(X x, Y y)
Definition utility/math.hpp:78
Definition convolution_backward_data_specialization.hpp:8
auto make_tuple_from_array(const std::vector< index_t > &lengths, Number< arraySize >)
Definition device_reduce_common.hpp:65
std::pair< long_index_t, long_index_t > get_2d_lengths(const std::vector< index_t > &inLengths)
Definition device_reduce_common.hpp:20
std::vector< index_t > shuffle_tensor_dimensions(const std::vector< index_t > &origLengthsStrides, const std::vector< int > &reduceDims)
Definition device_reduce_common.hpp:75
auto make_tuple_from_array_and_index_seq(const std::vector< index_t > &lengths, Sequence< Ns... >)
Definition device_reduce_common.hpp:59
Definition convolution_backward_data_specialization.hpp:7
Definition ck.hpp:268
int32_t index_t
Definition ck.hpp:299
__host__ __device__ constexpr auto make_naive_tensor_descriptor(const Tuple< Lengths... > &lengths, const Tuple< Strides... > &strides)
Definition tensor_descriptor_helper.hpp:49
remove_cv_t< remove_reference_t< T > > remove_cvref_t
Definition type.hpp:297
__host__ __device__ constexpr auto make_right_pad_transform(const LowLength &low_length, const RightPadLength &right_pad, integral_constant< bool, SkipIsValidCheck >=integral_constant< bool, false >{})
Definition multi_index_transform_helper.hpp:37
integral_constant< index_t, N > Number
Definition number.hpp:12
__host__ __device__ constexpr auto make_merge_transform(const LowLengths &low_lengths)
Definition multi_index_transform_helper.hpp:55
__global__ void kernel_elementwise_layernorm(const InGrid2dDescTuple in_grid_2d_desc_tuple, const GridDesc_M_K x_grid_desc_m_k, const GridDesc_M_K gamma_grid_desc_m_k, const GridDesc_M_K beta_grid_desc_m_k, const GridDesc_M_K y_grid_desc_m_k, index_t num_k_block_tile_iteration, AccDataType epsilon, const InDataTypePointerTuple p_in_global_tuple, const GammaDataType *const __restrict__ p_gamma_global, const BetaDataType *const __restrict__ p_beta_global, YDataType *const __restrict__ p_y_global, const XElementwiseOperation x_elementwise_op, const YElementwiseOperation y_elementwise_op)
Definition device_elementwise_normalization_impl.hpp:35
__host__ __device__ constexpr auto generate_tuple(F &&f, Number< N >)
Definition tuple_helper.hpp:21
__host__ __device__ constexpr auto make_tuple(Xs &&... xs)
Definition utility/tuple.hpp:211
int64_t long_index_t
Definition ck.hpp:300
__host__ __device__ constexpr auto transform_tensor_descriptor(const OldTensorDescriptor &old_tensor_desc, const NewTransforms &new_transforms, NewLowerDimensionOldVisibleIdss, NewUpperDimensionNewVisibleIdss)
Definition tensor_description/tensor_descriptor.hpp:319
__host__ __device__ constexpr auto make_unmerge_transform(const UpLengths &up_lengths, integral_constant< bool, Use24BitIntegerCalculation >=integral_constant< bool, false >{})
Definition multi_index_transform_helper.hpp:90
Definition ck/stream_config.hpp:10
Definition gridwise_elementwise_layernorm_welford_variance.hpp:42
Definition utility/sequence.hpp:43
typename conditional< kHasContent, type0, type1 >::type type
Definition utility/sequence.hpp:271
Definition device_base.hpp:197
Definition device_elementwise_normalization.hpp:25
Definition device_elementwise_normalization_impl.hpp:264
std::array< std::vector< index_t >, NumInput > inStridesArray_
Definition device_elementwise_normalization_impl.hpp:358
YElementwiseOperation y_elementwise_op_
Definition device_elementwise_normalization_impl.hpp:365
AccDataType epsilon_
Definition device_elementwise_normalization_impl.hpp:350
GridDesc_M_K gamma_grid_desc_m_k_
Definition device_elementwise_normalization_impl.hpp:373
size_t gridSize_
Definition device_elementwise_normalization_impl.hpp:369
GridDesc_M_K y_grid_desc_m_k_
Definition device_elementwise_normalization_impl.hpp:375
XElementwiseOperation x_elementwise_op_
Definition device_elementwise_normalization_impl.hpp:364
std::vector< index_t > betaStrides_
Definition device_elementwise_normalization_impl.hpp:361
std::vector< index_t > gammaStrides_
Definition device_elementwise_normalization_impl.hpp:360
bool sweep_once_
Definition device_elementwise_normalization_impl.hpp:376
int x_lds_size_
Definition device_elementwise_normalization_impl.hpp:377
int blkGroupSize_
Definition device_elementwise_normalization_impl.hpp:367
InGrid2dDescTuple in_grid_2d_desc_tuple_
Definition device_elementwise_normalization_impl.hpp:371
YDataType * p_y_
Definition device_elementwise_normalization_impl.hpp:355
std::vector< index_t > Lengths_
Definition device_elementwise_normalization_impl.hpp:357
GridDesc_M_K x_grid_desc_m_k_
Definition device_elementwise_normalization_impl.hpp:372
GridDesc_M_K beta_grid_desc_m_k_
Definition device_elementwise_normalization_impl.hpp:374
std::vector< index_t > yStrides_
Definition device_elementwise_normalization_impl.hpp:362
InDataTypePointerTuple in_dev_buffers_
Definition device_elementwise_normalization_impl.hpp:352
std::vector< index_t > xStrides_
Definition device_elementwise_normalization_impl.hpp:359
const GammaDataType * p_gamma_
Definition device_elementwise_normalization_impl.hpp:353
int numBlockTileIteration_
Definition device_elementwise_normalization_impl.hpp:368
const BetaDataType * p_beta_
Definition device_elementwise_normalization_impl.hpp:354
Argument(const std::vector< index_t > lengths, const std::array< std::vector< index_t >, NumInput > inStridesArray, const std::vector< index_t > gammaStrides, const std::vector< index_t > betaStrides, const std::vector< index_t > yStrides, const std::vector< index_t > reduceDims, XElementwiseOperation x_elementwise_op, YElementwiseOperation y_elementwise_op, double epsilon, const std::array< const void *, NumInput > in_dev_buffers, const GammaDataType *p_gamma, const BetaDataType *p_beta, YDataType *p_y)
Definition device_elementwise_normalization_impl.hpp:265
Definition device_elementwise_normalization_impl.hpp:381
float Run(const BaseArgument *p_arg, const StreamConfig &stream_config=StreamConfig{}) override
Definition device_elementwise_normalization_impl.hpp:431
float Run(const Argument &arg, const StreamConfig &stream_config=StreamConfig{})
Definition device_elementwise_normalization_impl.hpp:382
Definition device_elementwise_normalization_impl.hpp:104
std::unique_ptr< BaseArgument > MakeArgumentPointer(const std::vector< index_t > lengths, const std::array< std::vector< index_t >, NumInput > inStridesArray, const std::vector< index_t > gammaStrides, const std::vector< index_t > betaStrides, const std::vector< index_t > yStrides, const std::vector< index_t > reduceDims, double epsilon, const std::array< const void *, NumInput > in_dev_buffers, const void *p_gamma, const void *p_beta, void *p_y, XElementwiseOperation x_elementwise_op, YElementwiseOperation y_elementwise_op) override
Definition device_elementwise_normalization_impl.hpp:545
std::string GetTypeString() const override
Definition device_elementwise_normalization_impl.hpp:579
static constexpr index_t M_BlockTileSize
Definition device_elementwise_normalization_impl.hpp:117
static auto GenerateSrcGrid2dDescTuple(Number< TupleSize >)
Definition device_elementwise_normalization_impl.hpp:201
std::unique_ptr< BaseInvoker > MakeInvokerPointer() override
Definition device_elementwise_normalization_impl.hpp:574
static constexpr index_t K_BlockTileSize
Definition device_elementwise_normalization_impl.hpp:119
YDataType XDataType
Definition device_elementwise_normalization_impl.hpp:107
decltype(GenerateInDataTypePointerTuple()) InDataTypePointerTuple
Definition device_elementwise_normalization_impl.hpp:132
static auto GenerateInDataTypePointerTuple()
Definition device_elementwise_normalization_impl.hpp:122
decltype(MakeSrc2dDescriptor({1}, {1}, 1, 1)) GridDesc_M_K
Definition device_elementwise_normalization_impl.hpp:209
GridwiseElementwiseLayernormWelfordVariance_mk_to_mk< InDataTypePointerTuple, XDataType, GammaDataType, BetaDataType, YDataType, AccDataType, XElementwiseOperation, YElementwiseOperation, InGrid2dDescTuple, GridDesc_M_K, BlockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, XYSrcVectorDim, XSrcVectorSize, GammaSrcVectorDim, GammaSrcVectorSize, BetaSrcVectorDim, BetaSrcVectorSize, XYSrcVectorDim, YDstVectorSize, true > GridwiseReduceLayernormSweepOnce
Definition device_elementwise_normalization_impl.hpp:237
GridwiseElementwiseLayernormWelfordVariance_mk_to_mk< InDataTypePointerTuple, XDataType, GammaDataType, BetaDataType, YDataType, AccDataType, XElementwiseOperation, YElementwiseOperation, InGrid2dDescTuple, GridDesc_M_K, BlockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, XYSrcVectorDim, XSrcVectorSize, GammaSrcVectorDim, GammaSrcVectorSize, BetaSrcVectorDim, BetaSrcVectorSize, XYSrcVectorDim, YDstVectorSize, false > GridwiseReduceLayernormGeneric
Definition device_elementwise_normalization_impl.hpp:211
static constexpr int NumInput
Definition device_elementwise_normalization_impl.hpp:105
bool IsSupportedArgument(const BaseArgument *p_arg) override
Definition device_elementwise_normalization_impl.hpp:438
static auto MakeSrc2dDescriptor(const std::vector< index_t > &inLengths, const std::vector< index_t > &inStrides, int blkGroupSize, int numBlockTileIteration)
Definition device_elementwise_normalization_impl.hpp:134
decltype(GenerateSrcGrid2dDescTuple(Number< NumInput >{})) InGrid2dDescTuple
Definition device_elementwise_normalization_impl.hpp:207