// opr_trait.h
#pragma once
#include "megdnn/oprs.h"
#include "megdnn/oprs/nn.h"

#include <cstddef>

namespace megdnn {

template <typename Opr>
struct OprTrait {};

#define DEF(Name, Arity, HasWorkspace, CanDeduceLayout)        \
    template <>                                                \
    struct OprTrait<Name> {                                    \
        static const size_t arity = Arity;                     \
        static const bool has_workspace = HasWorkspace;        \
        static const bool can_deduce_layout = CanDeduceLayout; \
    }

20
DEF(Norm, 2, true, true);
21 22
DEF(Padding, 2, false, true);
DEF(PaddingBackward, 2, false, false);
23 24 25 26 27 28 29 30 31 32 33 34
DEF(ConvolutionForward, 3, true, true);
DEF(Convolution3DForward, 3, true, true);
DEF(ConvolutionBackwardData, 3, true, false);
DEF(ConvolutionBackwardFilter, 3, true, false);
DEF(Convolution3DBackwardData, 3, true, false);
DEF(Convolution3DBackwardFilter, 3, true, false);
DEF(ConvPoolingForward, 4, true, true);
DEF(ConvBiasForward, 5, true, true);
DEF(SeparableConvForward, 4, true, true);
DEF(SeparableFilterForward, 4, true, true);
DEF(Images2NeibsForward, 2, true, true);
DEF(Images2NeibsBackward, 2, true, false);
35 36
DEF(SlidingWindowTransposeForward, 2, true, true);
DEF(SlidingWindowTransposeBackward, 2, true, false);
37 38
DEF(PoolingForward, 2, true, true);
DEF(PoolingBackward, 4, true, false);
39 40
DEF(AdaptivePoolingForward, 2, true, false);
DEF(AdaptivePoolingBackward, 4, true, false);
41 42 43 44 45 46 47 48
DEF(LocalForward, 3, true, true);
DEF(LocalBackwardData, 3, true, false);
DEF(LocalBackwardFilter, 3, true, false);
DEF(GroupLocalForward, 3, true, true);
DEF(GroupLocalBackwardData, 3, true, false);
DEF(GroupLocalBackwardFilter, 3, true, false);
DEF(LRNForward, 2, true, true);
DEF(LRNBackward, 4, true, false);
49 50
DEF(BNForward, 9, true, true);
DEF(BNBackward, 9, true, false);
51 52
DEF(ROIPoolingForward, 4, true, false);
DEF(ROIPoolingBackward, 5, true, false);
53 54 55
DEF(CorrelationForward, 3, true, true);
DEF(CorrelationBackwardData1, 4, true, true);
DEF(CorrelationBackwardData2, 4, true, true);
56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81
DEF(WarpPerspectiveForward, 3, true, false);
DEF(WarpPerspectiveBackwardData, 3, true, false);
DEF(WarpPerspectiveBackwardMat, 4, true, false);
DEF(AddUpdateForward, 2, false, false);
DEF(DotForward, 3, true, true);
DEF(MatrixMulForward, 3, true, true);
DEF(BatchedMatrixMulForward, 3, true, true);
DEF(MatrixInverse, 2, true, true);
DEF(SVDForward, 4, true, true);
DEF(ReduceForward, 2, true, true);
DEF(CumsumForward, 2, true, true);
DEF(ArgmaxForward, 2, true, true);
DEF(ArgminForward, 2, true, true);
DEF(TransposeForward, 2, true, true);
DEF(RelayoutForward, 2, false, false);
DEF(TileForward, 2, true, true);
DEF(TileBackward, 2, true, false);
DEF(RepeatForward, 2, true, true);
DEF(RepeatBackward, 2, true, false);
DEF(ArgsortForward, 3, true, true);
DEF(ArgsortBackward, 3, true, false);
DEF(TypeCvtForward, 2, false, false);
DEF(IndexingRemapForward, 3, true, true);
DEF(IndexingRemapBackward, 3, true, false);
DEF(Linspace, 1, true, false);
DEF(Eye, 1, true, false);
82
DEF(Diag, 2, true, true);
83
DEF(Cross, 3, true, true);
84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108
DEF(Flip, 2, true, true);
DEF(ROICopy, 2, true, true);
DEF(Rotate, 2, true, true);
DEF(CvtColor, 2, true, true);
DEF(WarpAffine, 3, true, false);
DEF(GaussianBlur, 2, true, true);
DEF(Resize, 2, true, false);
DEF(ResizeBackward, 2, true, false);
DEF(IndexingOneHot, 3, true, true);
DEF(IndexingSetOneHot, 3, true, false);
DEF(MaskConvolution, 4, true, true);
DEF(MaskPropagate, 2, true, true);
DEF(RelayoutFormat, 2, true, true);
DEF(MaxTensorDiff, 2, true, false);
DEF(LocalShareForward, 3, true, true);
DEF(LocalShareBackwardData, 3, true, false);
DEF(LocalShareBackwardFilter, 3, true, false);
DEF(ROIAlignForward, 4, true, true);
DEF(ROIAlignBackward, 4, true, false);
DEF(DeformableConvForward, 5, true, true);
DEF(DeformableConvBackwardFilter, 5, true, false);
DEF(DeformableConvBackwardData, 8, true, false);
DEF(DeformablePSROIPoolingForward, 5, true, true);
DEF(DeformablePSROIPoolingBackward, 7, true, false);
DEF(BatchConvBiasForward, 5, true, true);
109
DEF(Remap, 3, true, true);
110 111
DEF(RemapBackwardData, 3, true, false);
DEF(RemapBackwardMat, 4, true, false);
112
DEF(DctChannelSelectForward, 4, true, true);
113 114
DEF(FakeQuantForward, 4, true, true);
DEF(FakeQuantBackward, 5, true, false);
M
Megvii Engine Team 已提交
115 116
DEF(TQTForward, 3, true, true);
DEF(TQTBackward, 5, true, false);
117 118 119
DEF(PowC, 2, false, true);
DEF(UniformRNG, 1, true, true);
DEF(GaussianRNG, 1, true, true);
120 121 122 123
DEF(GammaRNG, 3, true, true);
DEF(BetaRNG, 3, true, true);
DEF(PoissonRNG, 2, true, true);
DEF(PermutationRNG, 1, true, true);
124 125
DEF(ShuffleRNGForward, 3, true, true);
DEF(ShuffleRNGBackward, 3, true, false);
126
DEF(ChecksumForward, 1, true, false);
127
DEF(CheckNonFinite, 2, true, true);
M
Megvii Engine Team 已提交
128 129
DEF(LSQForward, 5, true, true);
DEF(LSQBackward, 7, true, false);
M
Megvii Engine Team 已提交
130
DEF(Fill, 1, true, false);
131 132
DEF(LayerNormForward, 6, true, true);
DEF(LayerNormBackward, 8, true, true);
133 134
DEF(GeneralNormForward, 6, true, true);
DEF(GeneralNormBackward, 8, true, true);
135
DEF(LAMBUpdate, 7, true, true);
136 137
DEF(DropoutForward, 3, true, true);
DEF(DropoutBackward, 3, true, true);
138
DEF(RNNCellForward, 7, true, true);
139
DEF(RNNForward, 6, true, true);
140 141 142 143
DEF(RNNBackward, 10, true, true);
DEF(LSTMCellForward, 10, true, true);
DEF(LSTMForward, 8, true, true);
DEF(LSTMBackward, 13, true, true);
144 145
DEF(SoftmaxForward, 2, true, true);
DEF(SoftmaxBackward, 3, true, false);
146 147 148
DEF(RegionRestrictedConvolutionForward, 5, true, true);
DEF(RegionRestrictedConvolutionBackwardData, 5, true, false);
DEF(RegionRestrictedConvolutionBackwardFilter, 5, true, false);
149 150
DEF(GroupNormForward, 6, true, true);
DEF(GroupNormBackward, 8, true, true);
M
Megvii Engine Team 已提交
151
DEF(MaskedFill, 3, false, true);
152 153
DEF(MultiHeadAttnForward, 11, true, true);
DEF(MultiHeadAttnBackward, 15, true, true);
154
DEF(Resize3D, 2, true, false);
155 156 157
}  // namespace megdnn

// vim: syntax=cpp.doxygen