From dbe2b89331d2d4b711395159cf746e2317ef0f43 Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Fri, 3 Apr 2020 16:00:40 +0800 Subject: [PATCH] fix(midout): fix brain opr midout 2/2 (also see a6aa1574) fix extern_c_opr midout GitOrigin-RevId: 7de4f650d1c8fc3a6a4cedb04b4386c4cd66600f --- python_module/src/cpp/opr_defs.cpp | 2 +- src/serialization/impl/extern_c_opr.cpp | 1 + src/serialization/impl/extern_c_opr.sereg.h | 7 ++++--- .../include/megbrain/serialization/extern_c_opr_io.h | 9 +++++---- src/serialization/include/megbrain/serialization/sereg.h | 9 +++++++++ src/serialization/test/extern_c_opr.cpp | 4 ++-- src/serialization/test/extern_c_opr_v23.cpp | 2 +- 7 files changed, 23 insertions(+), 11 deletions(-) diff --git a/python_module/src/cpp/opr_defs.cpp b/python_module/src/cpp/opr_defs.cpp index 8107094e7..6dc69be54 100644 --- a/python_module/src/cpp/opr_defs.cpp +++ b/python_module/src/cpp/opr_defs.cpp @@ -224,7 +224,7 @@ SymbolVarArray _Opr::extern_c_opr_placeholder( } } - auto opr = serialization::ExternCOprRunner::make_placeholder( + auto opr = opr::ExternCOprRunner::make_placeholder( inputs, cpp_output_shapes, dump_name, PyBytes_AsString(data_bytes), PyBytes_Size(data_bytes), config, cpp_output_dtypes); SymbolVarArray ret; diff --git a/src/serialization/impl/extern_c_opr.cpp b/src/serialization/impl/extern_c_opr.cpp index 45acb31eb..f4f2a4e35 100644 --- a/src/serialization/impl/extern_c_opr.cpp +++ b/src/serialization/impl/extern_c_opr.cpp @@ -18,6 +18,7 @@ using namespace mgb; using namespace serialization; +using namespace opr; namespace { diff --git a/src/serialization/impl/extern_c_opr.sereg.h b/src/serialization/impl/extern_c_opr.sereg.h index de3486b95..5f545f0b2 100644 --- a/src/serialization/impl/extern_c_opr.sereg.h +++ b/src/serialization/impl/extern_c_opr.sereg.h @@ -16,18 +16,19 @@ namespace mgb { namespace serialization { template <> -struct OprLoadDumpImpl { +struct OprLoadDumpImpl { static void dump(OprDumpContext& ctx, const cg::OperatorNodeBase& opr) { - ExternCOprRunner::dump(ctx, opr); + opr::ExternCOprRunner::dump(ctx, opr); } static cg::OperatorNodeBase* load(OprLoadContext& ctx, const cg::VarNodeArray& inputs, const OperatorNodeConfig& config) { - return ExternCOprRunner::load(ctx, inputs, config); + return opr::ExternCOprRunner::load(ctx, inputs, config); } }; +using ExternCOprRunner = opr::ExternCOprRunner; MGB_SEREG_OPR(ExternCOprRunner, 0); MGB_REG_OPR_SHALLOW_COPY(ExternCOprRunner, ExternCOprRunner::shallow_copy); } // namespace serialization diff --git a/src/serialization/include/megbrain/serialization/extern_c_opr_io.h b/src/serialization/include/megbrain/serialization/extern_c_opr_io.h index ff58315a9..060a54745 100644 --- a/src/serialization/include/megbrain/serialization/extern_c_opr_io.h +++ b/src/serialization/include/megbrain/serialization/extern_c_opr_io.h @@ -16,7 +16,7 @@ #include "megbrain/serialization/opr_registry.h" namespace mgb { -namespace serialization { +namespace opr { //! an operator to run extern C oprs MGB_DEFINE_OPR_CLASS(ExternCOprRunner, @@ -68,10 +68,11 @@ public: static bool unregister_loader(const char* name); //! impl for serialization dump - static void dump(OprDumpContext& ctx, const cg::OperatorNodeBase& opr); + static void dump(serialization::OprDumpContext& ctx, + const cg::OperatorNodeBase& opr); //! impl for serialization load - static cg::OperatorNodeBase* load(OprLoadContext& ctx, + static cg::OperatorNodeBase* load(serialization::OprLoadContext& ctx, const cg::VarNodeArray& inputs, const OperatorNodeConfig& config); @@ -88,7 +89,7 @@ public: static TensorShape tensor_shape_from_c(const MGBTensorShape& shape); }; -} // namespace serialization +} // namespace opr } // namespace mgb // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} diff --git a/src/serialization/include/megbrain/serialization/sereg.h b/src/serialization/include/megbrain/serialization/sereg.h index 0f0ace3b2..51951c0d0 100644 --- a/src/serialization/include/megbrain/serialization/sereg.h +++ b/src/serialization/include/megbrain/serialization/sereg.h @@ -179,9 +179,18 @@ namespace { \ } \ MGB_SEREG_OPR_INTL_CALL_ENTRY(_cls, _OprReg##_cls) +//! use to check type is complete or not, midout need a complete type +template +struct IsComplete : std::false_type {}; + +template +struct IsComplete : std::true_type {}; + //! call OprRegistry::add with only loader, used for backward compatibility #define MGB_SEREG_OPR_COMPAT(_name, _load) \ namespace { \ + static_assert(IsComplete<_name>(), \ + "need a complete type for MGB_SEREG_OPR_COMPAT"); \ struct _OprReg##_name { \ static cg::OperatorNodeBase* compat_loader( \ serialization::OprLoadContext& ctx, \ diff --git a/src/serialization/test/extern_c_opr.cpp b/src/serialization/test/extern_c_opr.cpp index 5e46c8c40..0b29bc67c 100644 --- a/src/serialization/test/extern_c_opr.cpp +++ b/src/serialization/test/extern_c_opr.cpp @@ -182,7 +182,7 @@ std::vector create_graph_dump(float bias, float extra_scale, auto x = opr::Host2DeviceCopy::make(*graph, host_x); if (sleep) x = opr::Sleep::make(x, sleep); - x = serialization::ExternCOprRunner::make_placeholder( + x = opr::ExternCOprRunner::make_placeholder( {x}, {TensorShape{1}}, dtype == MGB_DTYPE_FLOAT32 ? "bias_adder_dump" @@ -280,7 +280,7 @@ TEST(TestExternCOpr, Dedup) { auto graph = ComputingGraph::make(); auto x = opr::Host2DeviceCopy::make(*graph, host_x); auto make_opr = [x](float bias) { - return ExternCOprRunner::make_from_desc( + return opr::ExternCOprRunner::make_from_desc( {x.node()}, MGBOprDescImpl<>::make(bias)); }; auto y0 = make_opr(0.5), y1 = make_opr(0.6), y2 = make_opr(0.5); diff --git a/src/serialization/test/extern_c_opr_v23.cpp b/src/serialization/test/extern_c_opr_v23.cpp index 166dc5d59..339b0861a 100644 --- a/src/serialization/test/extern_c_opr_v23.cpp +++ b/src/serialization/test/extern_c_opr_v23.cpp @@ -115,7 +115,7 @@ std::vector create_graph_dump(float bias, float extra_scale, auto x = opr::Host2DeviceCopy::make(*graph, host_x); if (sleep) x = opr::Sleep::make(x, sleep); - x = serialization::ExternCOprRunner::make_placeholder( + x = opr::ExternCOprRunner::make_placeholder( {x}, {TensorShape{1}}, "bias_adder_dump_v23", &bias, sizeof(bias)) ->output(0); if (extra_scale) -- GitLab