DeviceGroupedConvFwdMultipleABD< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, DsDataType, EDataType, AElementwiseOperation, BElementwiseOperation, CDEElementwiseOperation, AComputeType, BComputeType > Struct Template Reference#
Grouped Convolution Forward. More...
#include <device_grouped_conv_fwd_multiple_abd.hpp>
Public Types | |
| using | APointers |
| using | BPointers |
Public Member Functions | |
| virtual std::unique_ptr< BaseArgument > | MakeArgumentPointer (APointers p_a, BPointers p_b, const std::array< const void *, NumDTensor > &p_ds, void *p_e, const std::array< index_t, NDimSpatial+3 > &a_g_n_c_wis_lengths, const std::array< index_t, NDimSpatial+3 > &a_g_n_c_wis_strides, const std::array< index_t, NDimSpatial+3 > &b_g_k_c_xs_lengths, const std::array< index_t, NDimSpatial+3 > &b_g_k_c_xs_strides, const std::array< std::array< index_t, NDimSpatial+3 >, NumDTensor > &ds_g_n_k_wos_lengths, const std::array< std::array< index_t, NDimSpatial+3 >, NumDTensor > &ds_g_n_k_wos_strides, const std::array< index_t, NDimSpatial+3 > &e_g_n_k_wos_lengths, const std::array< index_t, NDimSpatial+3 > &e_g_n_k_wos_strides, const std::array< index_t, NDimSpatial > &conv_filter_strides, const std::array< index_t, NDimSpatial > &conv_filter_dilations, const std::array< index_t, NDimSpatial > &input_left_pads, const std::array< index_t, NDimSpatial > &input_right_pads, const AElementwiseOperation &a_element_op, const BElementwiseOperation &b_element_op, const CDEElementwiseOperation &cde_element_op)=0 |
| Make argument pointer for grouped conv fwd. | |
| virtual std::unique_ptr< BaseArgument > | MakeArgumentPointer (APointers p_a, BPointers p_b, const std::array< const void *, NumDTensor > &p_ds, void *p_e, const std::array< long_index_t, NDimSpatial+3 > &a_g_n_c_wis_lengths, const std::array< long_index_t, NDimSpatial+3 > &a_g_n_c_wis_strides, const std::array< long_index_t, NDimSpatial+3 > &b_g_k_c_xs_lengths, const std::array< long_index_t, NDimSpatial+3 > &b_g_k_c_xs_strides, const std::array< std::array< long_index_t, NDimSpatial+3 >, NumDTensor > &ds_g_n_k_wos_lengths, const std::array< std::array< long_index_t, NDimSpatial+3 >, NumDTensor > &ds_g_n_k_wos_strides, const std::array< long_index_t, NDimSpatial+3 > &e_g_n_k_wos_lengths, const std::array< long_index_t, NDimSpatial+3 > &e_g_n_k_wos_strides, const std::array< long_index_t, NDimSpatial > &conv_filter_strides, const std::array< long_index_t, NDimSpatial > &conv_filter_dilations, const std::array< long_index_t, NDimSpatial > &input_left_pads, const std::array< long_index_t, NDimSpatial > &input_right_pads, const AElementwiseOperation &a_element_op, const BElementwiseOperation &b_element_op, const CDEElementwiseOperation &cde_element_op)=0 |
| virtual std::unique_ptr< BaseInvoker > | MakeInvokerPointer ()=0 |
| Public Member Functions inherited from ck::tensor_operation::device::BaseOperator | |
| BaseOperator ()=default | |
| BaseOperator (const BaseOperator &)=default | |
| BaseOperator & | operator= (const BaseOperator &)=default |
| virtual bool | IsSupportedArgument (const BaseArgument *) |
| virtual std::string | GetTypeString () const |
| virtual std::string | GetInstanceString () const |
| virtual std::string | GetTypeIdName () const |
| virtual std::optional< std::string > | GetObjectName () const |
| virtual std::optional< std::string > | GetTemplateInfo () const |
| virtual std::string | GetTypeIdHashCode () const |
| virtual size_t | GetWorkSpaceSize (const BaseArgument *) const |
| virtual void | SetWorkSpacePointer (BaseArgument *p_arg, void *p_workspace, const StreamConfig &=StreamConfig{}) const |
| virtual | ~BaseOperator () |
Static Public Attributes | |
| static constexpr bool | isMultiA = is_detected<is_tuple, ADataType>::value |
| static constexpr bool | isMultiB = is_detected<is_tuple, BDataType>::value |
| static constexpr index_t | NumATensor = GetNumABTensors<isMultiA, ADataType>() |
| static constexpr index_t | NumBTensor = GetNumABTensors<isMultiB, BDataType>() |
| static constexpr index_t | NumDTensor = DsDataType::Size() |
Detailed Description
struct ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, DsDataType, EDataType, AElementwiseOperation, BElementwiseOperation, CDEElementwiseOperation, AComputeType, BComputeType >
Grouped Convolution Forward.
Inputs: input images A[G, N, C, Hi, Wi], A1[G, N, C, Hi, Wi], ...; weights B[G, K, C, Y, X], B1[G, K, C, Y, X], ...; auxiliary tensors D0[G, N, K, Ho, Wo], D1[G, N, K, Ho, Wo], .... Output: output image E[G, N, K, Ho, Wo].
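C = a_op(A, A1, ...) * b_op(B, B1, ...)
E = cde_op(C, D0, D1, ...)
Here `*` denotes the grouped convolution of the element-wise-transformed input with the element-wise-transformed weight.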
- Template Parameters
-
NDimSpatial Number of spatial dimensions. ALayout Input layout (also for a1, a2...). BLayout Weight layout (also for b1, b2...). DsLayout Ds layouts. ELayout Output layout. ADataType Input data type. Pass a tuple if there are multiple A tensors. BDataType Weight data type. Pass a tuple if there are multiple B tensors. DsDataType D data types. EDataType Output data type. AElementwiseOperation A elementwise operation. BElementwiseOperation B elementwise operation. CDEElementwiseOperation CDE elementwise operation. AComputeType Compute data type for the A tensor (default: ADataType, or its first element if a tuple is passed). BComputeType Compute data type for the B tensor (default: AComputeType).
Member Typedef Documentation
◆ APointers
| using ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, DsDataType, EDataType, AElementwiseOperation, BElementwiseOperation, CDEElementwiseOperation, AComputeType, BComputeType >::APointers |
◆ BPointers
| using ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, DsDataType, EDataType, AElementwiseOperation, BElementwiseOperation, CDEElementwiseOperation, AComputeType, BComputeType >::BPointers |
Member Function Documentation
◆ MakeArgumentPointer() [1/2]
|
pure virtual |
Make argument pointer for grouped conv fwd.
- Parameters
-
p_a Pointer to the input (a std::array<const void*, NumATensor> of pointers when there are multiple A tensors). p_b Pointer to the weight (a std::array<const void*, NumBTensor> of pointers when there are multiple B tensors). p_ds Pointers to the D tensors. p_e Pointer to the output. a_g_n_c_wis_lengths Input lengths [G, N, C, Spatial...] (for 3d). a_g_n_c_wis_strides Input strides [G, N, C, Spatial...] (for 3d). b_g_k_c_xs_lengths Weight lengths [G, K, C, Spatial...] (for 3d). b_g_k_c_xs_strides Weight strides [G, K, C, Spatial...] (for 3d). ds_g_n_k_wos_lengths Ds lengths [G, N, K, Spatial...] (for 3d). ds_g_n_k_wos_strides Ds strides [G, N, K, Spatial...] (for 3d). e_g_n_k_wos_lengths Output lengths [G, N, K, Spatial...] (for 3d). e_g_n_k_wos_strides Output strides [G, N, K, Spatial...] (for 3d). conv_filter_strides Convolution filter strides. conv_filter_dilations Convolution filter dilations. input_left_pads Input left paddings. input_right_pads Input right paddings. a_element_op A elementwise operation object. b_element_op B elementwise operation object. cde_element_op CDE elementwise operation object.
- Returns
- Pointer to the argument.
◆ MakeArgumentPointer() [2/2]
|
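pure virtual |
Make argument pointer for grouped conv fwd. This overload is identical to MakeArgumentPointer() [1/2], except that all lengths, strides, filter strides, dilations and paddings are passed as long_index_t instead of index_t.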
◆ MakeInvokerPointer()
|
pure virtual |
Implemented in ck::tensor_operation::device::DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK< NDimSpatial, ADataType, BDataType, DsDataType, EDataType, AccDataType, ALayout, BLayout, DsLayout, ELayout, AElementwiseOperation, BElementwiseOperation, CDEElementwiseOperation, ConvForwardSpecialization, GemmSpec, BlockSize, MPerBlock, NPerBlock, K0PerBlock, K1, M1PerThread, N1PerThread, KPerThread, M1N1ThreadClusterM1Xs, M1N1ThreadClusterN1Xs, ABlockTransferThreadSliceLengths_K0_M0_M1_K1, ABlockTransferThreadClusterLengths_K0_M0_M1_K1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1, ABlockTransferSrcVectorTensorContiguousDimOrder, ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1, BBlockTransferThreadSliceLengths_K0_N0_N1_K1, BBlockTransferThreadClusterLengths_K0_N0_N1_K1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1, BBlockTransferSrcVectorTensorContiguousDimOrder, BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1, CThreadTransferSrcDstAccessOrder, CThreadTransferSrcDstVectorDim, CThreadTransferDstScalarPerVector >, ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOperation, BElementwiseOperation, CDEElementwiseOperation, ConvForwardSpecialization, GemmSpec, NumGemmKPrefetchStage, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, 
BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CDEBlockTransferScalarPerVector_NPerBlock, AComputeDataType, BComputeDataType, LoopSched, NumGroupsToMerge >, ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOperation, BElementwiseOperation, CDEElementwiseOperation, ConvForwardSpecialization, GemmSpec, NumGemmKPrefetchStage, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CDEBlockTransferScalarPerVector_NPerBlock, AComputeDataType, BComputeDataType, LoopSched >, ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOperation, BElementwiseOperation, CDEElementwiseOperation, ConvForwardSpecialization, GemmSpec, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, 
ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CDEBlockTransferScalarPerVector_NPerBlock, BlkGemmPipeSched, BlkGemmPipelineVer, AComputeDataType, BComputeDataType, DirectLoad >, ck::tensor_operation::device::DeviceGroupedConvFwdMultipleD_Wmma_CShuffle< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOperation, BElementwiseOperation, CDEElementwiseOperation, ConvForwardSpecialization, GemmSpec, NumGemmKPrefetchStage, BlockSize, MPerBlock, NPerBlock, KPerBlock, K1, MPerWmma, NPerWmma, MRepeat, NRepeat, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, BBlockLdsExtraN, CShuffleMRepeatPerShuffle, CShuffleNRepeatPerShuffle, CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CDEShuffleBlockTransferScalarPerVector_NPerBlock, LoopSched, PipelineVer >, and ck::tensor_operation::device::DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOperation, BElementwiseOperation, CDEElementwiseOperation, 
ConvForwardSpecialization, GemmSpec, NumGemmKPrefetchStage, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CDEBlockTransferScalarPerVector_NPerBlock, AComputeDataType, BComputeDataType, LoopSched >.
Member Data Documentation
◆ isMultiA
|
static constexpr |
◆ isMultiB
|
static constexpr |
◆ NumATensor
|
static constexpr |
◆ NumBTensor
|
static constexpr |
◆ NumDTensor
|
static constexpr |
The documentation for this struct was generated from the following file: