/** * \file src/opr/impl/dnn/dnn.sereg.h * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") * * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. */ #include "megbrain/opr/dnn/batch_norm.h" #include "megbrain/opr/dnn/convolution.h" #include "megbrain/opr/dnn/images2neibs.h" #include "megbrain/opr/dnn/pooling.h" #include "megbrain/opr/dnn/adaptive_pooling.h" #include "megbrain/opr/dnn/roi_pooling.h" #include "megbrain/opr/dnn/roi_align.h" #include "megbrain/opr/dnn/local.h" #include "megbrain/opr/dnn/lrn.h" #include "megbrain/serialization/sereg.h" namespace mgb { namespace serialization { template struct MakeConvCaller2 { template static VarNode* make(const cg::VarNodeArray &inputs, const typename MegDNNConv::Param ¶m, const megdnn::param::ExecutionPolicy &execution_policy, const OperatorNodeConfig &config) { if (inputs.size() == 2) { return Opr::make( inputs[0], inputs[1], param, execution_policy, config).node(); } return nullptr; } }; template struct MakeConvCaller3 { template static VarNode* make(const cg::VarNodeArray &inputs, const typename MegDNNConv::Param ¶m, const megdnn::param::ExecutionPolicy &execution_policy, const OperatorNodeConfig &config) { if (inputs.size() == 3) { return Opr::make( inputs[0], inputs[1], inputs[2], param, execution_policy, config).node(); } return nullptr; } }; template struct MakeConvCaller4 { template static VarNode* make(const cg::VarNodeArray &inputs, const typename MegDNNConv::Param ¶m, const megdnn::param::ExecutionPolicy &execution_policy, const OperatorNodeConfig &config) { if (inputs.size() == 4) { return Opr::make( inputs[0], inputs[1], inputs[2], inputs[3], param, execution_policy, config).node(); } return nullptr; } }; template struct MakeConvCaller5 { template static VarNode* make( const cg::VarNodeArray& inputs, const typename MegDNNConv::Param& param, const megdnn::param::ExecutionPolicy& execution_policy, const OperatorNodeConfig& config) { if (inputs.size() == 5) { return Opr::make(inputs[0], inputs[1], inputs[2], inputs[3], inputs[4], param, execution_policy, config) .node(); } return nullptr; } }; template struct MakeConvCallerEmpty { template static VarNode* make(const cg::VarNodeArray &, const typename MegDNNConv::Param &, const megdnn::param::ExecutionPolicy &, const OperatorNodeConfig &) { return nullptr; } }; template, class Maker2=MakeConvCallerEmpty, typename ConvParam = megdnn::param::Convolution > struct ConvLoadDumpImpl { static void dump(OprDumpContext &ctx, const cg::OperatorNodeBase &opr_) { auto &&opr = opr_.cast_final_safe(); ctx.write_param(opr.param()); ctx.write_param( opr.execution_policy()); } static VarNode* make( const cg::VarNodeArray& inputs, const ConvParam& param, const megdnn::param::ExecutionPolicy& execution_policy, const OperatorNodeConfig& config) { VarNode* ret = Maker0::template make(inputs, param, execution_policy, config); if (!ret) { ret = Maker1::template make(inputs, param, execution_policy, config); } if (!ret) { ret = Maker2::template make(inputs, param, execution_policy, config); } mgb_assert(ret); return ret; } static cg::OperatorNodeBase* load( OprLoadContext &ctx, const cg::VarNodeArray &inputs, const OperatorNodeConfig &config) { auto param = ctx.read_param(); auto execution_policy = ctx.read_param(); return make(inputs, param, execution_policy, config)->owner_opr(); } }; template<> struct OprLoadDumpImpl: public ConvLoadDumpImpl, megdnn::Convolution> {}; template<> struct OprLoadDumpImpl: public ConvLoadDumpImpl, megdnn::Convolution, MakeConvCaller3 > {}; template<> struct OprLoadDumpImpl: public ConvLoadDumpImpl, megdnn::Convolution> {}; template<> struct OprLoadDumpImpl: public ConvLoadDumpImpl, megdnn::Convolution3D, MakeConvCallerEmpty, MakeConvCallerEmpty, megdnn::param::Convolution3D> {}; template<> struct OprLoadDumpImpl: public ConvLoadDumpImpl, megdnn::Convolution3D, MakeConvCaller3, MakeConvCallerEmpty, megdnn::param::Convolution3D> {}; template<> struct OprLoadDumpImpl: public ConvLoadDumpImpl, megdnn::Convolution3D, MakeConvCallerEmpty, MakeConvCallerEmpty, megdnn::param::Convolution3D> {}; template<> struct OprLoadDumpImpl: public ConvLoadDumpImpl, megdnn::ConvBiasForward, MakeConvCaller3, MakeConvCaller4, megdnn::param::ConvBias> {}; template <> struct OprLoadDumpImpl : public ConvLoadDumpImpl< opr::BatchConvBiasForward, MakeConvCaller2, megdnn::BatchConvBiasForward, MakeConvCaller3, MakeConvCaller4, megdnn::param::BatchConvBias> {}; template <> struct OprMaker { using Param = opr::BatchNorm::Param; static cg::OperatorNodeBase* make(const Param& param, const cg::VarNodeArray& i, ComputingGraph& graph, const OperatorNodeConfig& config) { MGB_MARK_USED_VAR(graph); if (i.size() == 3) { return opr::BatchNorm::make(i[0], i[1], i[2], param, config)[0].node()->owner_opr(); } else { mgb_assert(i.size() == 5); return opr::BatchNorm::make(i[0], i[1], i[2], i[3], i[4], param, config)[0].node()->owner_opr(); } } }; template <> struct OprMaker { using Param = opr::BatchNormBackward::Param; static cg::OperatorNodeBase* make(const Param& param, const cg::VarNodeArray& i, ComputingGraph& graph, const OperatorNodeConfig& config) { MGB_MARK_USED_VAR(graph); return opr::BatchNormBackward::make(i[0], i[1], i[2], i[3], i[4], param, config)[0].node()->owner_opr(); } }; template struct MakeLocalShareCaller2 { template static VarNode* make(const cg::VarNodeArray &inputs, const typename MegDNNConv::Param ¶m, const megdnn::param::ExecutionPolicy &execution_policy, const OperatorNodeConfig &config) { if (inputs.size() == 2) { return Opr::make( inputs[0], inputs[1], param, execution_policy, config).node(); } return nullptr; } }; template struct MakeLocalShareCaller3 { template static VarNode* make(const cg::VarNodeArray &inputs, const typename MegDNNConv::Param ¶m, const megdnn::param::ExecutionPolicy &execution_policy, const OperatorNodeConfig &config) { if (inputs.size() == 3) { return Opr::make( inputs[0], inputs[1], inputs[2], param, execution_policy, config).node(); } return nullptr; } }; template struct MakeLocalShareCallerEmpty { template static VarNode* make(const cg::VarNodeArray &, const typename MegDNNConv::Param &, const megdnn::param::ExecutionPolicy &, const OperatorNodeConfig &) { return nullptr; } }; template, class Maker2=MakeLocalShareCallerEmpty, typename LocalShareParam = megdnn::param::LocalShare > struct LocalShareLoadDumpImpl { static void dump(OprDumpContext &ctx, const cg::OperatorNodeBase &opr_) { auto &&opr = opr_.cast_final_safe(); ctx.write_param(opr.param()); ctx.write_param( opr.execution_policy()); } static VarNode* make( const cg::VarNodeArray& inputs, const LocalShareParam& param, const megdnn::param::ExecutionPolicy& execution_policy, const OperatorNodeConfig& config) { VarNode* ret = Maker0::template make(inputs, param, execution_policy, config); if (!ret) { ret = Maker1::template make(inputs, param, execution_policy, config); } if (!ret) { ret = Maker2::template make(inputs, param, execution_policy, config); } mgb_assert(ret); return ret; } static cg::OperatorNodeBase* load( OprLoadContext &ctx, const cg::VarNodeArray &inputs, const OperatorNodeConfig &config) { auto param = ctx.read_param(); auto execution_policy = ctx.read_param(); return make(inputs, param, execution_policy, config)->owner_opr(); } }; template <> struct OprLoadDumpImpl : public LocalShareLoadDumpImpl< opr::LocalShare, MakeLocalShareCaller2, megdnn::LocalShare> {}; template <> struct OprLoadDumpImpl : public LocalShareLoadDumpImpl< opr::LocalShareBackwardData, MakeLocalShareCaller3, megdnn::LocalShare> {}; template <> struct OprLoadDumpImpl : public LocalShareLoadDumpImpl< opr::LocalShareBackwardFilter, MakeLocalShareCaller3, megdnn::LocalShare> {}; template<> struct OprLoadDumpImpl: public ConvLoadDumpImpl, megdnn::Convolution> {}; template<> struct OprLoadDumpImpl: public ConvLoadDumpImpl, megdnn::Convolution> {}; template<> struct OprLoadDumpImpl: public ConvLoadDumpImpl, megdnn::Convolution> {}; } // namespace serialization namespace opr { using ConvolutionV1 = Convolution; using ConvolutionBackwardDataV1 = ConvolutionBackwardData; using ConvolutionBackwardFilterV1 = ConvolutionBackwardFilter; MGB_SEREG_OPR(ConvolutionV1, 0); MGB_SEREG_OPR(ConvolutionBackwardDataV1, 0); MGB_SEREG_OPR(ConvolutionBackwardFilterV1, 0); MGB_SEREG_OPR(Images2Neibs, 1); MGB_SEREG_OPR(Images2NeibsBackward, 2); using LocalV1 = Local; using LocalBackwardDataV1 = LocalBackwardData; using LocalBackwardFilterV1 = LocalBackwardFilter; MGB_SEREG_OPR(LocalV1, 2); MGB_SEREG_OPR(LocalBackwardDataV1, 3); MGB_SEREG_OPR(LocalBackwardFilterV1, 3); using GroupLocalV1 = GroupLocal; using GroupLocalBackwardDataV1 = GroupLocalBackwardData; using GroupLocalBackwardFilterV1 = GroupLocalBackwardFilter; MGB_SEREG_OPR(GroupLocalV1, 2); MGB_SEREG_OPR(GroupLocalBackwardDataV1, 3); MGB_SEREG_OPR(GroupLocalBackwardFilterV1, 3); MGB_SEREG_OPR(LRN, 1); MGB_SEREG_OPR(LRNBackward, 3); MGB_SEREG_OPR(Pooling, 1); MGB_SEREG_OPR(PoolingBackward, 3); MGB_SEREG_OPR(AdaptivePooling, 2); MGB_SEREG_OPR(AdaptivePoolingBackward, 4); MGB_SEREG_OPR(ROIPooling, 3); MGB_SEREG_OPR(ROIPoolingBackward, 4); using MaskConvolutionV1 = MaskConvolution; MGB_SEREG_OPR(MaskConvolutionV1, 3); MGB_SEREG_OPR(MaskPropagate, 1); MGB_SEREG_OPR(Convolution3D, 0); MGB_SEREG_OPR(Convolution3DBackwardData, 0); MGB_SEREG_OPR(Convolution3DBackwardFilter, 0); using ConvBiasForwardV3 = ConvBiasForward; MGB_SEREG_OPR(ConvBiasForwardV3, 0); MGB_SEREG_OPR(BatchNorm, 0); MGB_SEREG_OPR(BatchNormBackward, 5); MGB_SEREG_OPR(LocalShareForward, 0); MGB_SEREG_OPR(LocalShareBackwardData, 0); MGB_SEREG_OPR(LocalShareBackwardFilter, 0); MGB_SEREG_OPR(ROIAlign, 2); MGB_SEREG_OPR(ROIAlignBackward, 4); MGB_SEREG_OPR(DeformableConvForward, 0); MGB_SEREG_OPR(DeformableConvBackwardData, 0); MGB_SEREG_OPR(DeformableConvBackwardFilter, 0); MGB_SEREG_OPR(DeformablePSROIPoolingForward, 3); MGB_SEREG_OPR(DeformablePSROIPoolingBackward, 5); MGB_SEREG_OPR(BatchConvBiasForward, 0); } // namespace opr } // namespace mgb // vim: ft=cpp syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}