diff --git a/imperative/src/impl/ops/opr_attr.cpp b/imperative/src/impl/ops/opr_attr.cpp index ca16f7a10d5c34b18dc1dfd826ad7421d848292a..b2e83fa41b58dfc719ecda30a9fac1204a37825b 100644 --- a/imperative/src/impl/ops/opr_attr.cpp +++ b/imperative/src/impl/ops/opr_attr.cpp @@ -76,7 +76,7 @@ public: } }; -cg::OperatorNodeBase* apply_on_var_node( +VarNodeArray apply_on_var_node( const OpDef& def, const VarNodeArray& inputs) { auto&& attr = def.cast_final_safe(); auto config = attr.config; @@ -85,7 +85,7 @@ cg::OperatorNodeBase* apply_on_var_node( auto registry = serialization::OprRegistry::find_by_name(attr.type); mgb_assert(registry, "operator %s not found", attr.type.c_str()); OprParamsLoadContext ctx{attr.param, inputs[0]->owner_graph()}; - return registry->loader(ctx, inputs, config); + return registry->loader(ctx, inputs, config).usable_output(); } std::shared_ptr make_from_op_node(cg::OperatorNodeBase* opr) { diff --git a/imperative/src/test/backward_graph.cpp b/imperative/src/test/backward_graph.cpp index 90e53ce7a77f0e5895d593e569859412ce60f873..2d0e972d61808f768ec473313ad54d0f2b980866 100644 --- a/imperative/src/test/backward_graph.cpp +++ b/imperative/src/test/backward_graph.cpp @@ -200,7 +200,7 @@ TEST(TestImperative, BatchNormGrad) { LogicalTensorDesc inp{TensorLayout{{N, C, H, W}, dtype::Float32()}, cn}; LogicalTensorDesc stat{TensorLayout{{C}, dtype::Float32()}, cn}; { - auto op = OprAttr::make("BatchNorm"); + auto op = OprAttr::make("BatchNormV1"); auto&& attr = op->cast_final_safe(); Param param; param.fwd_mode = Param::FwdMode::TRAINING; @@ -210,7 +210,7 @@ TEST(TestImperative, BatchNormGrad) { {false, false, false, false, false, true}); } { - auto op = OprAttr::make("BatchNorm"); + auto op = OprAttr::make("BatchNormV1"); auto&& attr = op->cast_final_safe(); Param param; param.fwd_mode = Param::FwdMode::TRAINING; diff --git a/imperative/src/test/imperative.cpp b/imperative/src/test/imperative.cpp index 329370f3c2bdf18458791dd4016f9063c5aada75..40a19cbb8dec905846b2e9ed063b057a607ca560 100644 --- a/imperative/src/test/imperative.cpp +++ b/imperative/src/test/imperative.cpp @@ -59,7 +59,7 @@ TEST(TestImperative, Reduce) { } TEST(TestImperative, BatchNorm) { - auto op = OprAttr::make("BatchNorm"); + auto op = OprAttr::make("BatchNormV1"); auto&& attr = op->cast_final_safe(); using Param = opr::BatchNorm::Param; Param param; diff --git a/src/opr/impl/dnn/dnn.sereg.h b/src/opr/impl/dnn/dnn.sereg.h index c0fb6b645839808851670ad7fe8094726c33a2f2..e407cb0e55328c7e0f7cbf03f94dca168c6cea54 100644 --- a/src/opr/impl/dnn/dnn.sereg.h +++ b/src/opr/impl/dnn/dnn.sereg.h @@ -16,14 +16,13 @@ #include "megbrain/opr/dnn/correlation.h" #include "megbrain/opr/dnn/fake_quant.h" #include "megbrain/opr/dnn/images2neibs.h" -#include "megbrain/opr/dnn/sliding_window_transpose.h" -#include "megbrain/opr/dnn/adaptive_pooling.h" #include "megbrain/opr/dnn/local.h" #include "megbrain/opr/dnn/lrn.h" #include "megbrain/opr/dnn/lsq.h" #include "megbrain/opr/dnn/pooling.h" #include "megbrain/opr/dnn/roi_align.h" #include "megbrain/opr/dnn/roi_pooling.h" +#include "megbrain/opr/dnn/sliding_window_transpose.h" #include "megbrain/opr/dnn/tqt.h" #include "megbrain/serialization/sereg.h" #include "megdnn/opr_param_defs.h" @@ -390,6 +389,7 @@ struct OprMaker { } }; +// OprMaker in MGB_SEREG_OPR only support unique output opr template <> struct OprMaker { using Param = opr::BatchNormBackward::Param; @@ -398,8 +398,8 @@ struct OprMaker { ComputingGraph& graph, const OperatorNodeConfig& config) { MGB_MARK_USED_VAR(graph); - return opr::BatchNormBackward::make(i[0], i[1], i[2], i[3], i[4], i[5], param, - config)[0] + return opr::BatchNormBackward::make(i[0], i[1], i[2], i[3], i[4], i[5], + param, config)[0] .node() ->owner_opr(); } @@ -575,8 +575,10 @@ MGB_SEREG_OPR(Convolution3DBackwardFilter, 0); using ConvBiasForwardV4 = ConvBiasForward; MGB_SEREG_OPR(ConvBiasForwardV4, 0); -MGB_SEREG_OPR(BatchNorm, 0); -MGB_SEREG_OPR(BatchNormBackward, 6); +using BatchNormV1 = BatchNorm; +using BatchNormBackwardV1 = BatchNormBackward; +MGB_SEREG_OPR(BatchNormV1, 0); +MGB_SEREG_OPR(BatchNormBackwardV1, 6); using LocalShareForwardV1 = LocalShareForward; using LocalShareBackwardDataV1 = LocalShareBackwardData; diff --git a/src/serialization/impl/opr_registry.cpp b/src/serialization/impl/opr_registry.cpp index d4f96717b718b0ac3babc57b13cdfe3ed0540dc5..f2aac20e4fa028721a4cb0d9ee4bb98d8852f642 100644 --- a/src/serialization/impl/opr_registry.cpp +++ b/src/serialization/impl/opr_registry.cpp @@ -39,7 +39,7 @@ namespace { return inst; } - cg::OperatorNodeBase* dynamic_loader( + OprWithOutputAccessor dynamic_loader( OprLoadContext &ctx, const cg::VarNodeArray &inputs, const OperatorNodeConfig &config) { auto name = ctx.load_buf_with_len(); @@ -171,4 +171,20 @@ std::vector> OprRegistry::dump_registries() { } #endif +namespace { +const VarNodeArray& default_accessor(const VarNodeArray& outputs) { + return outputs; +} +} + +OprWithOutputAccessor::OprWithOutputAccessor(cg::OperatorNodeBase* opr) : m_opr(opr){ + m_accessor = &default_accessor; +}; +OprWithOutputAccessor::OprWithOutputAccessor(cg::OperatorNodeBase* opr, Accessor accessor) + : OprWithOutputAccessor(opr) { + if (accessor) { + m_accessor = accessor; + } +}; + // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} diff --git a/src/serialization/impl/opr_shallow_copy.cpp b/src/serialization/impl/opr_shallow_copy.cpp index a503e6a8f5073c362fb13a5300a6c4bb5d3d42b4..4673189e3d2a20ce5b660610fcdcbb5987752765 100644 --- a/src/serialization/impl/opr_shallow_copy.cpp +++ b/src/serialization/impl/opr_shallow_copy.cpp @@ -207,7 +207,7 @@ cg::OperatorNodeBase* serialization::intl::copy_opr_shallow_default_impl( registry->dumper(dumper, opr); OprLoadContextMemory loader{opr.owner_graph(), dumper}; - return registry->loader(loader, inputs, config); + return registry->loader(loader, inputs, config).opr(); } // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} diff --git a/src/serialization/impl/serializer_oss.cpp b/src/serialization/impl/serializer_oss.cpp index dfd1af3bc7a9f13c44dbe4a8ebb1540db6339768..c90cd589acb187fa21da8bb05fa4abc64eb14a14 100644 --- a/src/serialization/impl/serializer_oss.cpp +++ b/src/serialization/impl/serializer_oss.cpp @@ -782,7 +782,8 @@ void GraphLoaderOSS::OprLoadContextImpl::load_single_opr( } // call loader - auto opr = registry->loader(*this, inputs, config); + auto accessor = registry->loader(*this, inputs, config); + auto opr = accessor.opr(); // check opr type; note that: // 1. registry->type may be empty for dynamic opr loaders or legacy oprs @@ -794,7 +795,7 @@ void GraphLoaderOSS::OprLoadContextImpl::load_single_opr( opr ? opr->dyn_typeinfo()->name : nullptr, registry->type->name); // record output vars; read output names size_t i = 0; - for (auto ovar : opr->output()) { + for (auto ovar : accessor.output()) { if (!ovar->contain_flag(VarNode::Flag::VOLATILE_CONTENT)) { m_id2varnode.push_back(ovar); if (fbopr->output_name()) { diff --git a/src/serialization/include/megbrain/serialization/opr_registry.h b/src/serialization/include/megbrain/serialization/opr_registry.h index 6946deb44339087a02d64f0261844271a28d72c8..72041f8a597947a108e6d8d9289847f481c08397 100644 --- a/src/serialization/include/megbrain/serialization/opr_registry.h +++ b/src/serialization/include/megbrain/serialization/opr_registry.h @@ -19,16 +19,36 @@ namespace serialization { class OprDumpContext; class OprLoadContext; class OprShallowCopyContext; + class OprWithOutputAccessor { + cg::OperatorNodeBase* m_opr; + using Accessor = thin_function; + Accessor m_accessor; + + public: + OprWithOutputAccessor(cg::OperatorNodeBase* opr); + OprWithOutputAccessor(cg::OperatorNodeBase* opr, Accessor accessor); + VarNode* output(size_t idx) const { return output().at(idx); } + VarNodeArray output() const { return m_accessor(m_opr->output()); } + VarNodeArray usable_output() const { return m_accessor(m_opr->usable_output()); } + cg::OperatorNodeBase* opr() { return m_opr; } + }; + //! dump opr internal params to OprDumpContext using OprDumper = thin_function; //! load and restore operator from OprLoadContext + //! is also used by GraphLoadConfig. using OprLoader = thin_function; + //! loader that can change opr output map for compatibility + using OprLoaderWrapper = thin_function; + //! shallow copy function for a single operator using OprShallowCopy = thin_function; \ - MGB_SEREG_OPR_INTL_CALL_ADD(_cls, Impl::dump, Impl::load); \ - } \ - }; \ - } \ +#define MGB_SEREG_OPR(_cls, _arity) \ + namespace { \ + namespace ser = ::mgb::serialization; \ + struct _OprReg##_cls { \ + using Impl = ser::OprLoadDumpImpl<_cls, _arity>; \ + static ser::OprWithOutputAccessor wrap_loader( \ + ser::OprLoadContext& ctx, const mgb::cg::VarNodeArray& inputs, \ + const mgb::cg::OperatorNodeConfig& config) { \ + return ser::OprWithOutputAccessor( \ + Impl::load(ctx, inputs, config)); \ + } \ + static void entry() { \ + MGB_SEREG_OPR_INTL_CALL_ADD(_cls, Impl::dump, wrap_loader); \ + } \ + }; \ + } \ MGB_SEREG_OPR_INTL_CALL_ENTRY(_cls, _OprReg##_cls) //! use to check type is complete or not, midout need a complete type @@ -187,33 +193,35 @@ template struct IsComplete : std::true_type {}; //! call OprRegistry::add with only loader, used for backward compatibility -#define MGB_SEREG_OPR_COMPAT(_name, _load) \ - namespace { \ - static_assert(IsComplete<_name>(), \ - "need a complete type for MGB_SEREG_OPR_COMPAT"); \ - struct _OprReg##_name { \ - static cg::OperatorNodeBase* compat_loader( \ - serialization::OprLoadContext& ctx, \ - const cg::VarNodeArray& inputs, \ - const OperatorNodeConfig& config) { \ - return _load( \ - static_cast(ctx), \ - inputs, config); \ - } \ - static void entry() { \ - ::mgb::serialization::OprRegistry::add( \ - {nullptr, \ - MGB_HASH_STR(#_name), \ - _MGB_SEREG_OPR_NAME_FROM_CLS(_name), \ - nullptr, \ - compat_loader, \ - {}, \ - {}}); \ - } \ - }; \ - } \ +#define MGB_SEREG_OPR_COMPAT_WITH_ACCESSOR(_name, _load, _accessor) \ + namespace { \ + static_assert(IsComplete<_name>(), \ + "need a complete type for MGB_SEREG_OPR_COMPAT"); \ + namespace ser = ::mgb::serialization; \ + struct _OprReg##_name { \ + static ser::OprWithOutputAccessor compat_loader( \ + ser::OprLoadContext& ctx, const mgb::cg::VarNodeArray& inputs, \ + const mgb::cg::OperatorNodeConfig& config) { \ + auto&& ctx_ = static_cast(ctx); \ + return ser::OprWithOutputAccessor(_load(ctx_, inputs, config), \ + _accessor); \ + } \ + static void entry() { \ + ser::OprRegistry::add({nullptr, \ + MGB_HASH_STR(#_name), \ + _MGB_SEREG_OPR_NAME_FROM_CLS(_name), \ + nullptr, \ + compat_loader, \ + {}, \ + {}}); \ + } \ + }; \ + } \ MGB_SEREG_OPR_INTL_CALL_ENTRY(_name, _OprReg##_name) +#define MGB_SEREG_OPR_COMPAT(_name, _load) \ + MGB_SEREG_OPR_COMPAT_WITH_ACCESSOR(_name, _load, nullptr) + /*! * \brief use \p _copy to implement shallow copy for given operator */