handle_impl.h 36.7 KB
Newer Older
1 2 3 4
/**
 * \file dnn/src/common/handle_impl.h
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */

#pragma once

#include "megdnn/handle.h"
#include "megdnn/oprs.h"

#include "src/common/utils.h"

#include <array>
#include <atomic>
#include <memory>
#include <mutex>
#include <utility>

21 22 23 24
#include "midout.h"

MIDOUT_DECL(dnn_src_common_handle_impl)

25 26 27 28 29 30 31
namespace megdnn {

class HandleImplHelper : public Handle {
public:
    using Handle::Handle;

    //! global matmul opr
M
Megvii Engine Team 已提交
32
    virtual MatrixMul* matmul_opr() { megdnn_throw("Unimplement matmul opr.\n"); }
33 34

    //! global matmul opr with first operand transposed
M
Megvii Engine Team 已提交
35
    virtual MatrixMul* matmul_aT_opr() { megdnn_throw("Unimplement matmul_aT opr.\n"); }
36 37

    //! global matmul opr with second operand transposed
M
Megvii Engine Team 已提交
38
    virtual MatrixMul* matmul_bT_opr() { megdnn_throw("Unimplement matmul_bT opr.\n"); }
39 40 41 42 43 44 45

    //! global matmul opr with both operand transposed
    virtual MatrixMul* matmul_aT_bT_opr() {
        megdnn_throw("Unimplement matmul_aT_bT opr.\n");
    }

    //! global relayout opr
M
Megvii Engine Team 已提交
46
    virtual Relayout* relayout_opr() { megdnn_throw("Unimplement Relayout opr.\n"); }
47

M
Megvii Engine Team 已提交
48
    virtual Checksum* checksum_opr() { megdnn_throw("Unimplement Checksum opr.\n"); }
49 50 51 52 53 54 55 56 57

    virtual MaxTensorDiff* max_tensor_diff_opr() {
        megdnn_throw("Unimplement MaxTensorDiff opr.\n");
    }

protected:
    static constexpr size_t NR_HELPER_OPRS = 7;

    template <class Opr, size_t idx, class Self>
M
Megvii Engine Team 已提交
58
    static Opr* get_helper_opr(Self self, const typename Opr::Param& param = {}) {
59 60
        MIDOUT_BEGIN(dnn_src_common_handle_impl, Opr, idx) {
            static_assert(idx < NR_HELPER_OPRS, "invalid idx");
61
            if (!self->m_helper_oprs[idx]) {
62
                MEGDNN_LOCK_GUARD(self->m_helper_oprs_mtx);
63
                if (!self->m_helper_oprs[idx]) {
M
Megvii Engine Team 已提交
64 65
                    self->m_helper_oprs[idx] = self->template create_operator<Opr>();
                    auto ret = static_cast<Opr*>(self->m_helper_oprs[idx].get());
66 67 68 69
                    ret->param() = param;
                    megdnn_assert(ret->is_thread_safe());
                    return ret;
                }
70
            }
71
            return static_cast<Opr*>(self->m_helper_oprs[idx].get());
72
        }
73
        MIDOUT_END();
74 75 76 77
    }

private:
    std::array<std::unique_ptr<OperatorBase>, NR_HELPER_OPRS> m_helper_oprs;
78
    DNN_MUTEX m_helper_oprs_mtx;
79 80 81 82 83 84 85
};

}  // namespace megdnn
/*!
 * \brief iterate through each operator class name; useful for explicit
 *      instantiation of create_operator<> templates
 */
M
Megvii Engine Team 已提交
86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149
#define MEGDNN_FOREACH_OPR_CLASS(cb)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                    \
    cb(ConvolutionForward) cb(ConvolutionBackwardData) cb(ConvolutionBackwardFilter) cb(                                                                                                                                                                                                                                                                                                                                                                                                                                                                \
            ConvPoolingForward) cb(ConvBiasForward) cb(Images2NeibsForward) cb(Images2NeibsBackward)                                                                                                                                                                                                                                                                                                                                                                                                                                                    \
            cb(SlidingWindowTransposeForward) cb(SlidingWindowTransposeBackward) cb(                                                                                                                                                                                                                                                                                                                                                                                                                                                                    \
                    ElemwiseForward) cb(ElemwiseMultiType) cb(AddUpdateForward)                                                                                                                                                                                                                                                                                                                                                                                                                                                                         \
                    cb(RelayoutForward) cb(PoolingForward) cb(PoolingBackward) cb(                                                                                                                                                                                                                                                                                                                                                                                                                                                                      \
                            LocalForward) cb(LocalBackwardData) cb(LocalBackwardFilter)                                                                                                                                                                                                                                                                                                                                                                                                                                                                 \
                            cb(LRNForward) cb(LRNBackward) cb(ROIPoolingForward) cb(                                                                                                                                                                                                                                                                                                                                                                                                                                                                    \
                                    ROIPoolingBackward) cb(WarpPerspectiveForward)                                                                                                                                                                                                                                                                                                                                                                                                                                                                      \
                                    cb(WarpPerspectiveBackwardData) cb(                                                                                                                                                                                                                                                                                                                                                                                                                                                                                 \
                                            WarpPerspectiveBackwardMat) cb(DotForward)                                                                                                                                                                                                                                                                                                                                                                                                                                                                  \
                                            cb(MatrixInverse) cb(MatrixMulForward) cb(                                                                                                                                                                                                                                                                                                                                                                                                                                                                  \
                                                    BatchedMatrixMulForward)                                                                                                                                                                                                                                                                                                                                                                                                                                                                            \
                                                    cb(SVDForward) cb(                                                                                                                                                                                                                                                                                                                                                                                                                                                                                  \
                                                            ReduceForward) cb(CondTake)                                                                                                                                                                                                                                                                                                                                                                                                                                                                 \
                                                            cb(CumsumForward) cb(                                                                                                                                                                                                                                                                                                                                                                                                                                                                       \
                                                                    ArgmaxForward)                                                                                                                                                                                                                                                                                                                                                                                                                                                                      \
                                                                    cb(ArgminForward)                                                                                                                                                                                                                                                                                                                                                                                                                                                                   \
                                                                            cb(TransposeForward)                                                                                                                                                                                                                                                                                                                                                                                                                                                        \
                                                                                    cb(ConcatForward)                                                                                                                                                                                                                                                                                                                                                                                                                                                   \
                                                                                            cb(SplitForward)                                                                                                                                                                                                                                                                                                                                                                                                                                            \
                                                                                                    cb(TileForward)                                                                                                                                                                                                                                                                                                                                                                                                                                     \
                                                                                                            cb(TileBackward)                                                                                                                                                                                                                                                                                                                                                                                                                            \
                                                                                                                    cb(RepeatForward)                                                                                                                                                                                                                                                                                                                                                                                                                   \
                                                                                                                            cb(RepeatBackward)                                                                                                                                                                                                                                                                                                                                                                                                          \
                                                                                                                                    cb(ArgsortForward)                                                                                                                                                                                                                                                                                                                                                                                                  \
                                                                                                                                            cb(ArgsortBackward)                                                                                                                                                                                                                                                                                                                                                                                         \
                                                                                                                                                    cb(TypeCvt)                                                                                                                                                                                                                                                                                                                                                                                         \
                                                                                                                                                            cb(IndexingRemapForward)                                                                                                                                                                                                                                                                                                                                                                    \
                                                                                                                                                                    cb(IndexingRemapBackward)                                                                                                                                                                                                                                                                                                                                                           \
                                                                                                                                                                            cb(ChecksumForward) cb(IndexingOneHotForward) cb(IndexingSetOneHotForward) cb(IndexingMultiAxisVec) cb(IndexingSetMultiAxisVec) cb(IndexingIncrMultiAxisVec)                                                                                                                                                                                                                \
                                                                                                                                                                                    cb(                                                                                                                                                                                                                                                                                                                                                                 \
                                                                                                                                                                                            MeshIndexing) cb(IncrMeshIndexing) cb(SetMeshIndexing) cb(BatchedMeshIndexing) cb(BatchedIncrMeshIndexing) cb(BatchedSetMeshIndexing) cb(Linspace) cb(Eye) cb(SleepForward)                                                                                                                                                                                 \
                                                                                                                                                                                            cb(UniformRNG) cb(GaussianRNG) cb(                                                                                                                                                                                                                                                                                                                          \
                                                                                                                                                                                                    GammaRNG)                                                                                                                                                                                                                                                                                                                                           \
                                                                                                                                                                                                    cb(BetaRNG) cb(PoissonRNG) cb(PermutationRNG) cb(ShuffleRNGForward) cb(ShuffleRNGBackward) cb(SeparableConvForward) cb(                                                                                                                                                                                                                             \
                                                                                                                                                                                                            SeparableFilterForward)                                                                                                                                                                                                                                                                                                                     \
                                                                                                                                                                                                            cb(                                                                                                                                                                                                                                                                                                                                         \
                                                                                                                                                                                                                    BNForward) cb(BNBackward) cb(GroupLocalForward) cb(GroupLocalBackwardData)                                                                                                                                                                                                                                                          \
                                                                                                                                                                                                                    cb(GroupLocalBackwardFilter)                                                                                                                                                                                                                                                                                                        \
                                                                                                                                                                                                                            cb(Flip) cb(                                                                                                                                                                                                                                                                                                                \
                                                                                                                                                                                                                                    Rotate)                                                                                                                                                                                                                                                                                                             \
                                                                                                                                                                                                                                    cb(                                                                                                                                                                                                                                                                                                                 \
                                                                                                                                                                                                                                            ROICopy) cb(CvtColor) cb(WarpAffine) cb(GaussianBlur) cb(Resize) cb(ResizeBackward)                                                                                                                                                                                                                         \
                                                                                                                                                                                                                                            cb(ParamPackConcat) cb(MaxTensorDiff) cb(MaskConvForward) cb(                                                                                                                                                                                                                                               \
                                                                                                                                                                                                                                                    MaskPropagate)                                                                                                                                                                                                                                                                                      \
                                                                                                                                                                                                                                                    cb(Convolution3DForward)                                                                                                                                                                                                                                                                            \
                                                                                                                                                                                                                                                            cb(Convolution3DBackwardData) cb(Convolution3DBackwardFilter) cb(DeformableConvForward) cb(                                                                                                                                                                                                 \
                                                                                                                                                                                                                                                                    DeformableConvBackwardFilter)                                                                                                                                                                                                                                                       \
                                                                                                                                                                                                                                                                    cb(                                                                                                                                                                                                                                                                                 \
                                                                                                                                                                                                                                                                            DeformableConvBackwardData) cb(DeformablePSROIPoolingForward) cb(DeformablePSROIPoolingBackward) cb(RelayoutFormat) cb(TopK)                                                                                                                                                \
                                                                                                                                                                                                                                                                            cb(PowC) cb(LocalShareForward) cb(                                                                                                                                                                                                                                          \
                                                                                                                                                                                                                                                                                    LocalShareBackwardData) cb(LocalShareBackwardFilter)                                                                                                                                                                                                                \
                                                                                                                                                                                                                                                                                    cb(                                                                                                                                                                                                                                                                 \
                                                                                                                                                                                                                                                                                            ROIAlignForward) cb(ROIAlignBackward) cb(CorrelationForward) cb(CorrelationBackwardData1) cb(CorrelationBackwardData2) cb(BatchConvBiasForward) cb(Remap) cb(RemapBackwardData) cb(RemapBackwardMat) cb(AdaptivePoolingForward) cb(AdaptivePoolingBackward) \
                                                                                                                                                                                                                                                                                            cb(DctChannelSelectForward) cb(FakeQuantForward) cb(FakeQuantBackward)                                                                                                                                                                                      \
                                                                                                                                                                                                                                                                                                    cb(TQTForward) cb(                                                                                                                                                                                                                                  \
                                                                                                                                                                                                                                                                                                            TQTBackward)                                                                                                                                                                                                                                \
                                                                                                                                                                                                                                                                                                            cb(CheckNonFinite)                                                                                                                                                                                                                          \
                                                                                                                                                                                                                                                                                                                    cb(LSQForward) cb(                                                                                                                                                                                                                  \
                                                                                                                                                                                                                                                                                                                            LSQBackward)                                                                                                                                                                                                                \
                                                                                                                                                                                                                                                                                                                            cb(Fill) cb(                                                                                                                                                                                                                \
                                                                                                                                                                                                                                                                                                                                    PaddingForward)                                                                                                                                                                                                     \
                                                                                                                                                                                                                                                                                                                                    cb(PaddingBackward)
150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167

/*!
 * \brief Define the explicit specialization of HandleImpl::create_operator
 *      for the operator type megdnn::Opr.
 *
 * The generated member constructs an object of type `Opr##Impl` (e.g.
 * `FooImpl` for `Foo`, looked up unqualified at the expansion site), hands
 * `this` handle to its constructor, and returns it as a
 * std::unique_ptr<megdnn::Opr>.
 *
 * \param Opr unqualified name of an operator class in namespace megdnn
 */
#define MEGDNN_SPECIALIZE_CREATE_OPERATOR(Opr)                                  \
    template <>                                                                 \
    std::unique_ptr<megdnn::Opr> HandleImpl::create_operator<megdnn::Opr>() {   \
        return megdnn::make_unique<Opr##Impl>(this);                            \
    }

/*!
 * \brief Emit an explicit instantiation of HandleImpl::create_operator for
 *      the operator type megdnn::Opr.
 *
 * The expansion is a complete declaration, including its terminating `;`.
 * It can therefore be passed as the `cb` of a list macro whose entries are
 * written as bare `cb(Name)` calls without separators.
 *
 * \param Opr unqualified name of an operator class in namespace megdnn
 */
#define MEGDNN_INST_CREATE_OPERATOR(Opr) \
    template std::unique_ptr<megdnn::Opr> HandleImpl::create_operator<megdnn::Opr>();

// vim: syntax=cpp.doxygen