/**
 * \file src/opr/impl/dnn/dnn.sereg.h
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
 * or implied.
 */

#include "megbrain/opr/dnn/batch_norm.h"
#include "megbrain/opr/dnn/convolution.h"
#include "megbrain/opr/dnn/images2neibs.h"
#include "megbrain/opr/dnn/pooling.h"
#include "megbrain/opr/dnn/adaptive_pooling.h"
#include "megbrain/opr/dnn/roi_pooling.h"
#include "megbrain/opr/dnn/roi_align.h"
#include "megbrain/opr/dnn/local.h"
#include "megbrain/opr/dnn/lrn.h"
#include "megbrain/opr/dnn/fake_quant.h"
#include "megbrain/opr/dnn/tqt.h"
#include "megbrain/serialization/sereg.h"
#include "megdnn/opr_param_defs.h"
#include "megdnn/oprs/nn.h"

namespace mgb {
namespace serialization {

/*
 * The Make*Caller* helpers below each try to build an operator node from a
 * specific number of input vars; they return nullptr when the input count
 * does not match, so a LoadDumpImpl can chain several of them to support
 * operators whose serialized form has a variable number of inputs.
 */

//! build from exactly 1 input var (e.g. Pooling forward)
template <class Opr>
struct MakePoolingCaller1 {
    template <typename MegDNNPooling>
    static VarNode* make(const cg::VarNodeArray& inputs,
                         const typename MegDNNPooling::Param& param,
                         const OperatorNodeConfig& config) {
        if (inputs.size() == 1) {
            return Opr::make(inputs[0], param, config).node();
        }
        return nullptr;
    }
};

//! build from exactly 2 input vars (e.g. ROIAlign / AdaptivePooling forward)
template <class Opr>
struct MakeROIAlignCaller1 {
    template <typename MegDNNROIALIGN>
    static VarNode* make(const cg::VarNodeArray& inputs,
                         const typename MegDNNROIALIGN::Param& param,
                         const OperatorNodeConfig& config) {
        if (inputs.size() == 2) {
            return Opr::make(inputs[0], inputs[1], param, config).node();
        } else {
            return nullptr;
        }
    }
};

//! build from exactly 4 input vars (e.g. ROIAlign backward)
template <class Opr>
struct MakeROIAlignCaller4 {
    template <typename MegDNNROIALIGN>
    static VarNode* make(const cg::VarNodeArray& inputs,
                         const typename MegDNNROIALIGN::Param& param,
                         const OperatorNodeConfig& config) {
        if (inputs.size() == 4) {
            return Opr::make(inputs[0], inputs[1], inputs[2], inputs[3],
                             param, config)
                    .node();
        } else {
            return nullptr;
        }
    }
};

//! build from exactly 3 input vars (e.g. Pooling backward)
template <class Opr>
struct MakePoolingBackwardCaller3 {
    template <typename MegDNNPooling>
    static VarNode* make(const cg::VarNodeArray& inputs,
                         const typename MegDNNPooling::Param& param,
                         const OperatorNodeConfig& config) {
        if (inputs.size() == 3) {
            return Opr::make(inputs[0], inputs[1], inputs[2], param, config)
                    .node();
        }
        return nullptr;
    }
};

//! build from exactly 4 input vars (AdaptivePooling backward)
template <class Opr>
struct MakeAdaptivePoolingBackwardCaller3 {
    template <typename MegDNNPooling>
    static VarNode* make(const cg::VarNodeArray& inputs,
                         const typename MegDNNPooling::Param& param,
                         const OperatorNodeConfig& config) {
        if (inputs.size() == 4) {
            return Opr::make(inputs[0], inputs[1], inputs[2], inputs[3],
                             param, config)
                    .node();
        }
        return nullptr;
    }
};

//! conv-family builder for exactly 2 input vars
template <class Opr>
struct MakeConvCaller2 {
    template <typename MegDNNConv>
    static VarNode* make(const cg::VarNodeArray& inputs,
                         const typename MegDNNConv::Param& param,
                         const megdnn::param::ExecutionPolicy& execution_policy,
                         const OperatorNodeConfig& config) {
        if (inputs.size() == 2) {
            return Opr::make(inputs[0], inputs[1], param, execution_policy,
                             config)
                    .node();
        }
        return nullptr;
    }
};

//! conv-family builder for exactly 3 input vars
template <class Opr>
struct MakeConvCaller3 {
    template <typename MegDNNConv>
    static VarNode* make(const cg::VarNodeArray& inputs,
                         const typename MegDNNConv::Param& param,
                         const megdnn::param::ExecutionPolicy& execution_policy,
                         const OperatorNodeConfig& config) {
        if (inputs.size() == 3) {
            return Opr::make(inputs[0], inputs[1], inputs[2], param,
                             execution_policy, config)
                    .node();
        }
        return nullptr;
    }
};

//! conv-family builder for exactly 4 input vars
template <class Opr>
struct MakeConvCaller4 {
    template <typename MegDNNConv>
    static VarNode* make(const cg::VarNodeArray& inputs,
                         const typename MegDNNConv::Param& param,
                         const megdnn::param::ExecutionPolicy& execution_policy,
                         const OperatorNodeConfig& config) {
        if (inputs.size() == 4) {
            return Opr::make(inputs[0], inputs[1], inputs[2], inputs[3],
                             param, execution_policy, config)
                    .node();
        }
        return nullptr;
    }
};

//! conv-family builder for exactly 5 input vars
template <class Opr>
struct MakeConvCaller5 {
    template <typename MegDNNConv>
    static VarNode* make(const cg::VarNodeArray& inputs,
                         const typename MegDNNConv::Param& param,
                         const megdnn::param::ExecutionPolicy& execution_policy,
                         const OperatorNodeConfig& config) {
        if (inputs.size() == 5) {
            return Opr::make(inputs[0], inputs[1], inputs[2], inputs[3],
                             inputs[4], param, execution_policy, config)
                    .node();
        }
        return nullptr;
    }
};

//! placeholder maker that never matches; used as default chain terminator
template <class Opr>
struct MakeConvCallerEmpty {
    template <typename MegDNNConv>
    static VarNode* make(const cg::VarNodeArray&,
                         const typename MegDNNConv::Param&,
                         const megdnn::param::ExecutionPolicy&,
                         const OperatorNodeConfig&) {
        return nullptr;
    }
};

/*!
 * \brief common dump/load implementation for conv-like operators
 *
 * dump() writes the opr param followed by its execution policy; load() reads
 * them back in the same order and tries Maker0/Maker1/Maker2 in turn until
 * one accepts the stored input arity.
 */
template <class Opr, class Maker0, class MegDNNConv,
          class Maker1 = MakeConvCallerEmpty<MegDNNConv>,
          class Maker2 = MakeConvCallerEmpty<MegDNNConv>,
          typename ConvParam = megdnn::param::Convolution>
struct ConvLoadDumpImpl {
    static void dump(OprDumpContext& ctx, const cg::OperatorNodeBase& opr_) {
        auto&& opr = opr_.cast_final_safe<Opr>();
        ctx.write_param<ConvParam>(opr.param());
        ctx.write_param<megdnn::param::ExecutionPolicy>(
                opr.execution_policy());
    }

    static VarNode* make(const cg::VarNodeArray& inputs, const ConvParam& param,
                         const megdnn::param::ExecutionPolicy& execution_policy,
                         const OperatorNodeConfig& config) {
        VarNode* ret = Maker0::template make<MegDNNConv>(
                inputs, param, execution_policy, config);
        if (!ret) {
            ret = Maker1::template make<MegDNNConv>(inputs, param,
                                                    execution_policy, config);
        }
        if (!ret) {
            ret = Maker2::template make<MegDNNConv>(inputs, param,
                                                    execution_policy, config);
        }
        mgb_assert(ret);
        return ret;
    }

    static cg::OperatorNodeBase* load(OprLoadContext& ctx,
                                      const cg::VarNodeArray& inputs,
                                      const OperatorNodeConfig& config) {
        auto param = ctx.read_param<ConvParam>();
        auto execution_policy =
                ctx.read_param<megdnn::param::ExecutionPolicy>();
        return make(inputs, param, execution_policy, config)->owner_opr();
    }
};

/*!
 * \brief common dump/load implementation for pooling-like operators
 *        (param only, no execution policy)
 */
template <class Opr, class Maker0,
          typename PoolingParam = megdnn::param::Pooling>
struct PoolingLoadDumpImpl {
    static void dump(OprDumpContext& ctx, const cg::OperatorNodeBase& opr_) {
        auto&& opr = opr_.cast_final_safe<Opr>();
        ctx.write_param<PoolingParam>(opr.param());
    }

    static VarNode* make(const cg::VarNodeArray& inputs,
                         const PoolingParam& param,
                         const OperatorNodeConfig& config) {
        // Opr::Param is the same megdnn param type as PoolingParam
        VarNode* ret = Maker0::template make<Opr>(inputs, param, config);
        mgb_assert(ret);
        return ret;
    }

    static cg::OperatorNodeBase* load(OprLoadContext& ctx,
                                      const cg::VarNodeArray& inputs,
                                      const OperatorNodeConfig& config) {
        auto param = ctx.read_param<PoolingParam>();
        return make(inputs, param, config)->owner_opr();
    }
};

template <>
struct OprMaker<opr::TQTBackward, 3> {
    using Param = opr::TQTBackward::Param;
    static cg::OperatorNodeBase* make(const Param& param,
                                      const cg::VarNodeArray& i,
                                      ComputingGraph& graph,
                                      const OperatorNodeConfig& config) {
        MGB_MARK_USED_VAR(graph);
        return opr::TQTBackward::make(i[0], i[1], i[2], param, config)[0]
                .node()
                ->owner_opr();
    }
};

template <>
struct OprLoadDumpImpl<opr::AdaptivePoolingBackward, 4>
        : public PoolingLoadDumpImpl<
                  opr::AdaptivePoolingBackward,
                  MakeAdaptivePoolingBackwardCaller3<
                          opr::AdaptivePoolingBackward>,
                  megdnn::param::AdaptivePooling> {};

template <>
struct OprLoadDumpImpl<opr::AdaptivePooling, 2>
        : public PoolingLoadDumpImpl<
                  opr::AdaptivePooling,
                  MakeROIAlignCaller1<opr::AdaptivePooling>,
                  megdnn::param::AdaptivePooling> {};

template <>
struct OprLoadDumpImpl<opr::ROIAlign, 2>
        : public PoolingLoadDumpImpl<opr::ROIAlign,
                                     MakeROIAlignCaller1<opr::ROIAlign>,
                                     megdnn::param::ROIAlign> {};

template <>
struct OprLoadDumpImpl<opr::ROIAlignBackward, 4>
        : public PoolingLoadDumpImpl<
                  opr::ROIAlignBackward,
                  MakeROIAlignCaller4<opr::ROIAlignBackward>,
                  megdnn::param::ROIAlign> {};

template <>
struct OprLoadDumpImpl<opr::Pooling, 1>
        : public PoolingLoadDumpImpl<opr::Pooling,
                                     MakePoolingCaller1<opr::Pooling>,
                                     megdnn::param::Pooling> {};

template <>
struct OprLoadDumpImpl<opr::PoolingBackward, 3>
        : public PoolingLoadDumpImpl<
                  opr::PoolingBackward,
                  MakePoolingBackwardCaller3<opr::PoolingBackward>,
                  megdnn::param::Pooling> {};

template <>
struct OprLoadDumpImpl<opr::Convolution, 0>
        : public ConvLoadDumpImpl<opr::Convolution,
                                  MakeConvCaller2<opr::Convolution>,
                                  megdnn::Convolution> {};

template <>
struct OprLoadDumpImpl<opr::ConvolutionBackwardData, 0>
        : public ConvLoadDumpImpl<
                  opr::ConvolutionBackwardData,
                  MakeConvCaller2<opr::ConvolutionBackwardData>,
                  megdnn::Convolution,
                  MakeConvCaller3<opr::ConvolutionBackwardData> > {};

template <>
struct OprLoadDumpImpl<opr::ConvolutionBackwardFilter, 0>
        : public ConvLoadDumpImpl<
                  opr::ConvolutionBackwardFilter,
                  MakeConvCaller3<opr::ConvolutionBackwardFilter>,
                  megdnn::Convolution> {};

template <>
struct OprLoadDumpImpl<opr::Convolution3D, 0>
        : public ConvLoadDumpImpl<
                  opr::Convolution3D, MakeConvCaller2<opr::Convolution3D>,
                  megdnn::Convolution3D,
                  MakeConvCallerEmpty<megdnn::Convolution3D>,
                  MakeConvCallerEmpty<megdnn::Convolution3D>,
                  megdnn::param::Convolution3D> {};

template <>
struct OprLoadDumpImpl<opr::Convolution3DBackwardData, 0>
        : public ConvLoadDumpImpl<
                  opr::Convolution3DBackwardData,
                  MakeConvCaller2<opr::Convolution3DBackwardData>,
                  megdnn::Convolution3D,
                  MakeConvCaller3<opr::Convolution3DBackwardData>,
                  MakeConvCallerEmpty<megdnn::Convolution3D>,
                  megdnn::param::Convolution3D> {};

template <>
struct OprLoadDumpImpl<opr::Convolution3DBackwardFilter, 0>
        : public ConvLoadDumpImpl<
                  opr::Convolution3DBackwardFilter,
                  MakeConvCaller3<opr::Convolution3DBackwardFilter>,
                  megdnn::Convolution3D,
                  MakeConvCallerEmpty<megdnn::Convolution3D>,
                  MakeConvCallerEmpty<megdnn::Convolution3D>,
                  megdnn::param::Convolution3D> {};

template <>
struct OprLoadDumpImpl<opr::ConvBiasForward, 0>
        : public ConvLoadDumpImpl<opr::ConvBiasForward,
                                  MakeConvCaller2<opr::ConvBiasForward>,
                                  megdnn::ConvBiasForward,
                                  MakeConvCaller3<opr::ConvBiasForward>,
                                  MakeConvCaller4<opr::ConvBiasForward>,
                                  megdnn::param::ConvBias> {};

template <>
struct OprLoadDumpImpl<opr::BatchConvBiasForward, 0>
        : public ConvLoadDumpImpl<
                  opr::BatchConvBiasForward,
                  MakeConvCaller2<opr::BatchConvBiasForward>,
                  megdnn::BatchConvBiasForward,
                  MakeConvCaller3<opr::BatchConvBiasForward>,
                  MakeConvCaller4<opr::BatchConvBiasForward>,
                  megdnn::param::BatchConvBias> {};

template <>
struct OprMaker<opr::BatchNorm, 0> {
    using Param = opr::BatchNorm::Param;
    //! BatchNorm is serialized with either 3 inputs (no statistics) or 5
    static cg::OperatorNodeBase* make(const Param& param,
                                      const cg::VarNodeArray& i,
                                      ComputingGraph& graph,
                                      const OperatorNodeConfig& config) {
        MGB_MARK_USED_VAR(graph);
        if (i.size() == 3) {
            return opr::BatchNorm::make(i[0], i[1], i[2], param, config)[0]
                    .node()
                    ->owner_opr();
        } else {
            mgb_assert(i.size() == 5);
            return opr::BatchNorm::make(i[0], i[1], i[2], i[3], i[4], param,
                                        config)[0]
                    .node()
                    ->owner_opr();
        }
    }
};

template <>
struct OprMaker<opr::BatchNormBackward, 5> {
    using Param = opr::BatchNormBackward::Param;
    static cg::OperatorNodeBase* make(const Param& param,
                                      const cg::VarNodeArray& i,
                                      ComputingGraph& graph,
                                      const OperatorNodeConfig& config) {
        MGB_MARK_USED_VAR(graph);
        return opr::BatchNormBackward::make(i[0], i[1], i[2], i[3], i[4],
                                            param, config)[0]
                .node()
                ->owner_opr();
    }
};

//! local-share builder for exactly 2 input vars
template <class Opr>
struct MakeLocalShareCaller2 {
    template <typename MegDNNConv>
    static VarNode* make(const cg::VarNodeArray& inputs,
                         const typename MegDNNConv::Param& param,
                         const megdnn::param::ExecutionPolicy& execution_policy,
                         const OperatorNodeConfig& config) {
        if (inputs.size() == 2) {
            return Opr::make(inputs[0], inputs[1], param, execution_policy,
                             config)
                    .node();
        }
        return nullptr;
    }
};

//! local-share builder for exactly 3 input vars
template <class Opr>
struct MakeLocalShareCaller3 {
    template <typename MegDNNConv>
    static VarNode* make(const cg::VarNodeArray& inputs,
                         const typename MegDNNConv::Param& param,
                         const megdnn::param::ExecutionPolicy& execution_policy,
                         const OperatorNodeConfig& config) {
        if (inputs.size() == 3) {
            return Opr::make(inputs[0], inputs[1], inputs[2], param,
                             execution_policy, config)
                    .node();
        }
        return nullptr;
    }
};

//! placeholder maker that never matches; default chain terminator
template <class Opr>
struct MakeLocalShareCallerEmpty {
    template <typename MegDNNConv>
    static VarNode* make(const cg::VarNodeArray&,
                         const typename MegDNNConv::Param&,
                         const megdnn::param::ExecutionPolicy&,
                         const OperatorNodeConfig&) {
        return nullptr;
    }
};

/*!
 * \brief common dump/load implementation for local-share operators; same
 *        structure as ConvLoadDumpImpl (param + execution policy)
 */
template <class Opr, class Maker0, class MegDNNConv,
          class Maker1 = MakeLocalShareCallerEmpty<MegDNNConv>,
          class Maker2 = MakeLocalShareCallerEmpty<MegDNNConv>,
          typename LocalShareParam = megdnn::param::LocalShare>
struct LocalShareLoadDumpImpl {
    static void dump(OprDumpContext& ctx, const cg::OperatorNodeBase& opr_) {
        auto&& opr = opr_.cast_final_safe<Opr>();
        ctx.write_param<LocalShareParam>(opr.param());
        ctx.write_param<megdnn::param::ExecutionPolicy>(
                opr.execution_policy());
    }

    static VarNode* make(const cg::VarNodeArray& inputs,
                         const LocalShareParam& param,
                         const megdnn::param::ExecutionPolicy& execution_policy,
                         const OperatorNodeConfig& config) {
        VarNode* ret = Maker0::template make<MegDNNConv>(
                inputs, param, execution_policy, config);
        if (!ret) {
            ret = Maker1::template make<MegDNNConv>(inputs, param,
                                                    execution_policy, config);
        }
        if (!ret) {
            ret = Maker2::template make<MegDNNConv>(inputs, param,
                                                    execution_policy, config);
        }
        mgb_assert(ret);
        return ret;
    }

    static cg::OperatorNodeBase* load(OprLoadContext& ctx,
                                      const cg::VarNodeArray& inputs,
                                      const OperatorNodeConfig& config) {
        auto param = ctx.read_param<LocalShareParam>();
        auto execution_policy =
                ctx.read_param<megdnn::param::ExecutionPolicy>();
        return make(inputs, param, execution_policy, config)->owner_opr();
    }
};

template <>
struct OprLoadDumpImpl<opr::LocalShare, 0>
        : public LocalShareLoadDumpImpl<
                  opr::LocalShare, MakeLocalShareCaller2<opr::LocalShare>,
                  megdnn::LocalShare> {};

template <>
struct OprLoadDumpImpl<opr::LocalShareBackwardData, 0>
        : public LocalShareLoadDumpImpl<
                  opr::LocalShareBackwardData,
                  MakeLocalShareCaller3<opr::LocalShareBackwardData>,
                  megdnn::LocalShare> {};

template <>
struct OprLoadDumpImpl<opr::LocalShareBackwardFilter, 0>
        : public LocalShareLoadDumpImpl<
                  opr::LocalShareBackwardFilter,
                  MakeLocalShareCaller3<opr::LocalShareBackwardFilter>,
                  megdnn::LocalShare> {};

template <>
struct OprLoadDumpImpl<opr::DeformableConvForward, 0>
        : public ConvLoadDumpImpl<
                  opr::DeformableConvForward,
                  MakeConvCaller4<opr::DeformableConvForward>,
                  megdnn::Convolution> {};

template <>
struct OprLoadDumpImpl<opr::DeformableConvBackwardData, 0>
        : public ConvLoadDumpImpl<
                  opr::DeformableConvBackwardData,
                  MakeConvCaller5<opr::DeformableConvBackwardData>,
                  megdnn::Convolution> {};

template <>
struct OprLoadDumpImpl<opr::DeformableConvBackwardFilter, 0>
        : public ConvLoadDumpImpl<
                  opr::DeformableConvBackwardFilter,
                  MakeConvCaller5<opr::DeformableConvBackwardFilter>,
                  megdnn::Convolution> {};

}  // namespace serialization

namespace opr {

using ConvolutionV2 = Convolution;
using ConvolutionBackwardDataV2 = ConvolutionBackwardData;
using ConvolutionBackwardFilterV2 = ConvolutionBackwardFilter;
MGB_SEREG_OPR(ConvolutionV2, 0);
MGB_SEREG_OPR(ConvolutionBackwardDataV2, 0);
MGB_SEREG_OPR(ConvolutionBackwardFilterV2, 0);

MGB_SEREG_OPR(Images2Neibs, 1);
MGB_SEREG_OPR(Images2NeibsBackward, 2);

using LocalV1 = Local;
using LocalBackwardDataV1 = LocalBackwardData;
using LocalBackwardFilterV1 = LocalBackwardFilter;
MGB_SEREG_OPR(LocalV1, 2);
MGB_SEREG_OPR(LocalBackwardDataV1, 3);
MGB_SEREG_OPR(LocalBackwardFilterV1, 3);

using GroupLocalV1 = GroupLocal;
using GroupLocalBackwardDataV1 = GroupLocalBackwardData;
using GroupLocalBackwardFilterV1 = GroupLocalBackwardFilter;
MGB_SEREG_OPR(GroupLocalV1, 2);
MGB_SEREG_OPR(GroupLocalBackwardDataV1, 3);
MGB_SEREG_OPR(GroupLocalBackwardFilterV1, 3);

MGB_SEREG_OPR(LRN, 1);
MGB_SEREG_OPR(LRNBackward, 3);

using PoolingV1 = Pooling;
using PoolingBackwardV1 = PoolingBackward;
MGB_SEREG_OPR(PoolingV1, 1);
MGB_SEREG_OPR(PoolingBackwardV1, 3);

using AdaptivePoolingV1 = AdaptivePooling;
using AdaptivePoolingBackwardV1 = AdaptivePoolingBackward;
MGB_SEREG_OPR(AdaptivePoolingV1, 2);
MGB_SEREG_OPR(AdaptivePoolingBackwardV1, 4);

MGB_SEREG_OPR(ROIPooling, 3);
MGB_SEREG_OPR(ROIPoolingBackward, 4);

using MaskConvolutionV1 = MaskConvolution;
MGB_SEREG_OPR(MaskConvolutionV1, 3);
MGB_SEREG_OPR(MaskPropagate, 1);

MGB_SEREG_OPR(Convolution3D, 0);
MGB_SEREG_OPR(Convolution3DBackwardData, 0);
MGB_SEREG_OPR(Convolution3DBackwardFilter, 0);

using ConvBiasForwardV4 = ConvBiasForward;
MGB_SEREG_OPR(ConvBiasForwardV4, 0);

MGB_SEREG_OPR(BatchNorm, 0);
MGB_SEREG_OPR(BatchNormBackward, 5);

using LocalShareForwardV1 = LocalShareForward;
using LocalShareBackwardDataV1 = LocalShareBackwardData;
using LocalShareBackwardFilterV1 = LocalShareBackwardFilter;
MGB_SEREG_OPR(LocalShareForwardV1, 0);
MGB_SEREG_OPR(LocalShareBackwardDataV1, 0);
MGB_SEREG_OPR(LocalShareBackwardFilterV1, 0);

using ROIAlignV1 = ROIAlign;
using ROIAlignBackwardV1 = ROIAlignBackward;
MGB_SEREG_OPR(ROIAlignV1, 2);
MGB_SEREG_OPR(ROIAlignBackwardV1, 4);

MGB_SEREG_OPR(DeformableConvForward, 0);
MGB_SEREG_OPR(DeformableConvBackwardData, 0);
MGB_SEREG_OPR(DeformableConvBackwardFilter, 0);

MGB_SEREG_OPR(DeformablePSROIPoolingForward, 3);
MGB_SEREG_OPR(DeformablePSROIPoolingBackward, 5);

using BatchConvBiasForwardV1 = BatchConvBiasForward;
MGB_SEREG_OPR(BatchConvBiasForwardV1, 0);

MGB_SEREG_OPR(FakeQuant, 3);
MGB_SEREG_OPR(FakeQuantBackward, 4);
MGB_SEREG_OPR(TQT, 2);
MGB_SEREG_OPR(TQTBackward, 3);

}  // namespace opr
}  // namespace mgb

// vim: ft=cpp syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}