reference_batched_transpose.hpp Source File

reference_batched_transpose.hpp Source File

Composable Kernel: reference_batched_transpose.hpp Source File
reference_batched_transpose.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
#include "ck_tile/core.hpp"
#include "ck_tile/host/host_tensor.hpp"
#include <stdexcept>
#include <string>
#include <thread>
9
10namespace ck_tile {
11
12template <typename Type>
15 std::string layout_in = "NCHW",
16 std::string layout_out = "NHWC")
17{
18 const int N = x.mDesc.get_lengths()[0];
19
20 auto f = [&](auto batch) {
21 if(layout_in == "NCHW" && layout_out == "NHWC")
22 {
23 const int C = x.mDesc.get_lengths()[1];
24 const int H = x.mDesc.get_lengths()[2];
25 const int W = x.mDesc.get_lengths()[3];
26 for(int c = 0; c < C; ++c)
27 {
28 for(int h = 0; h < H; ++h)
29 {
30 for(int w = 0; w < W; ++w)
31 {
32 Type v_x = x(batch, c, h, w);
33 y(batch, h, w, c) = v_x;
34 }
35 }
36 }
37 }
38 else if(layout_in == "NHWC" && layout_out == "NCHW")
39 {
40 const int H = x.mDesc.get_lengths()[1];
41 const int W = x.mDesc.get_lengths()[2];
42 const int C = x.mDesc.get_lengths()[3];
43 for(int h = 0; h < H; ++h)
44 {
45 for(int w = 0; w < W; ++w)
46 {
47 for(int c = 0; c < C; ++c)
48 {
49 Type v_x = x(batch, h, w, c);
50 y(batch, c, h, w) = v_x;
51 }
52 }
53 }
54 }
55 };
56
57 make_ParallelTensorFunctor(f, N)(std::thread::hardware_concurrency());
58}
59} // namespace ck_tile
#define CK_TILE_HOST
Definition config.hpp:40
Definition tile/core/algorithm/cluster_descriptor.hpp:13
CK_TILE_HOST auto make_ParallelTensorFunctor(F f, Xs... xs)
Definition tile/host/host_tensor.hpp:329
CK_TILE_HOST void reference_batched_transpose(const HostTensor< Type > &x, HostTensor< Type > &y, std::string layout_in="NCHW", std::string layout_out="NHWC")
Definition reference_batched_transpose.hpp:13
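Type
Template parameter of reference_batched_transpose: the element type stored in the input and output HostTensor objects. (Doxygen mislinked this name to rapidjson's "Type of JSON value" enum, defined at rapidjson.h:760, which is unrelated.)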
const std::vector< std::size_t > & get_lengths() const
Definition tile/host/host_tensor.hpp:198
Definition tile/host/host_tensor.hpp:336
Descriptor mDesc
Definition tile/host/host_tensor.hpp:800