diff --git a/src/opr/impl/custom_opnode.sereg.h b/src/opr/impl/custom_opnode.sereg.h index 3500d92415e7ae041fb443665f50f48ebec6b81d..dd4a6bd87f08fd4f924c6c5d424ce556b175e703 100644 --- a/src/opr/impl/custom_opnode.sereg.h +++ b/src/opr/impl/custom_opnode.sereg.h @@ -40,16 +40,16 @@ mgb::cg::OperatorNodeBase* custom_loader( } // namespace serialization } // namespace mgb -#define CUSTOM_OP_SEREG_REG(cls) \ - namespace { \ - struct _OprReg##cls { \ - static void entry() { \ - MGB_SEREG_OPR_INTL_CALL_ADD( \ - cls, ::mgb::serialization::custom_dumper, \ - ::mgb::serialization::custom_loader); \ - } \ - }; \ - } \ +#define CUSTOM_OP_SEREG_REG(cls) \ + namespace { \ + struct _OprReg##cls { \ + static void entry() { \ + MGB_SEREG_OPR_INTL_CALL_ADD( \ + cls, ::mgb::serialization::custom_dumper, \ + ::mgb::serialization::custom_loader, true); \ + } \ + }; \ + } \ MGB_SEREG_OPR_INTL_CALL_ENTRY(cls, _OprReg##cls) #define CUSTOM_OP_SEREG_REG_V2(cls, _version_min, _version_max) \ diff --git a/src/opr/impl/loop/forward_sereg.cpp b/src/opr/impl/loop/forward_sereg.cpp index 8eb1bccd0f8799263dc52bb5fc1b4d47a5b4646c..1884adc7dec4d0e70ada570ad7441854e9010b13 100644 --- a/src/opr/impl/loop/forward_sereg.cpp +++ b/src/opr/impl/loop/forward_sereg.cpp @@ -131,10 +131,10 @@ cg::OperatorNodeBase* serialization::opr_shallow_copy_loop( } void LoopSerializer::reg_all() { - MGB_SEREG_OPR_INTL_CALL_ADD(opr::Loop, dump_loop, load_loop); - MGB_SEREG_OPR_INTL_CALL_ADD(InputMaker, dump_input_maker, load_input_maker); + MGB_SEREG_OPR_INTL_CALL_ADD(opr::Loop, dump_loop, load_loop, true); + MGB_SEREG_OPR_INTL_CALL_ADD(InputMaker, dump_input_maker, load_input_maker, true); MGB_SEREG_OPR_INTL_CALL_ADD( - CounterProvider, dump_counter_provider, load_counter_provider); + CounterProvider, dump_counter_provider, load_counter_provider, true); MGB_SEREG_OPR_INTL_CALL_ADD_V2( opr::Loop, dump_loop, load_loop, nullptr, 2, CURRENT_VERSION); diff --git a/src/opr/impl/nn_int.sereg.h b/src/opr/impl/nn_int.sereg.h index b4ae73070343230015bed23c5df3b34d61316762..4527d44b1396298cf5a333b6e18e3c7f00255919 100644 --- a/src/opr/impl/nn_int.sereg.h +++ b/src/opr/impl/nn_int.sereg.h @@ -1,3 +1,4 @@ +#include "megbrain/opr/basic_arith.h" #include "megbrain/opr/nn_int.h" #include "megbrain/serialization/sereg.h" @@ -7,10 +8,74 @@ template <> struct OprMaker : public OprMakerVariadic {}; +template <> +struct OprLoadDumpImplV2 { + using Opr = opr::ElemwiseMultiType; + using PersisParam = opr::ElemwiseMultiType::Param; + using PersisElemwseiParam = opr::Elemwise::Param; + static void dump(OprDumpContext& ctx, const cg::OperatorNodeBase& opr) { + ctx.write_param(opr.cast_final_safe().param()); + } + + static cg::OperatorNodeBase* replace_opr( + cg::OperatorNodeBase* opr, const VarNodeArray& inputs) { + auto mode = opr->cast_final_safe().param().mode; + auto change_to_elemwise_mode = [&](PersisParam::Mode multitype_mode) { + if (multitype_mode == PersisParam::Mode::EQ) { + return PersisElemwseiParam::Mode::EQ; + } else if (multitype_mode == PersisParam::Mode::LT) { + return PersisElemwseiParam::Mode::LT; + } else if (multitype_mode == PersisParam::Mode::LEQ) { + return PersisElemwseiParam::Mode::LEQ; + } + mgb_assert(0, "no supported model."); + }; + if (PersisParam::Mode::EQ == mode || PersisParam::Mode::LT == mode || + PersisParam::Mode::LEQ == mode) { + auto elemwise_mode = change_to_elemwise_mode(mode); + auto elemiwse_out = opr::Elemwise::make(inputs, {elemwise_mode}); + return opr::TypeCvt::make(elemiwse_out, dtype::Bool()).node()->owner_opr(); + } else if (PersisParam::Mode::NEQ == mode) { + auto elemiwse_out = + opr::Elemwise::make(inputs, {PersisElemwseiParam::Mode::EQ}); + auto bool_out = opr::TypeCvt::make(elemiwse_out, dtype::Bool()); + return opr::Elemwise::make({bool_out}, {PersisElemwseiParam::Mode::NOT}) + .node() + ->owner_opr(); + } else if (PersisParam::Mode::ISNAN == mode) { + auto elemiwse_out = opr::Elemwise::make( + {inputs[0], inputs[0]}, {PersisElemwseiParam::Mode::EQ}); + auto bool_out = opr::TypeCvt::make(elemiwse_out, dtype::Bool()); + return opr::Elemwise::make({bool_out}, {PersisElemwseiParam::Mode::NOT}) + .node() + ->owner_opr(); + } else if (PersisParam::Mode::ISINF == mode) { + auto input_var = SymbolVar{inputs[0]}; + auto inf_var = input_var.make_scalar(INFINITY); + auto float_out = opr::TypeCvt::make(inputs[0], dtype::Float32()); + auto elemiwse_out = opr::Elemwise::make( + {float_out, inf_var}, {PersisElemwseiParam::Mode::EQ}); + return opr::TypeCvt::make(elemiwse_out, dtype::Bool()).node()->owner_opr(); + } + return opr; + } + + static cg::OperatorNodeBase* load( + OprLoadContext& ctx, const cg::VarNodeArray& inputs, + const OperatorNodeConfig& config) { + return OprMaker::make( + ctx.read_param(), inputs, ctx.graph(), config); + } +}; + } // namespace serialization namespace opr { -MGB_SEREG_OPR(ElemwiseMultiType, 0); +MGB_SEREG_OPR_CONDITION(ElemwiseMultiType, 0, false); +MGB_SEREG_OPR_V2( + ElemwiseMultiType, 0, + (mgb::serialization::OprLoadDumpImplV2::replace_opr), + VERSION_1, VERSION_1); MGB_SEREG_OPR(AffineInt, 3); } // namespace opr } // namespace mgb diff --git a/src/serialization/impl/opr_shallow_copy.cpp b/src/serialization/impl/opr_shallow_copy.cpp index 943fddf6f81c1086f06d8e3f82f54b500e35edc5..c4f1a9f3359e1ecc86177d8aa38eba7c01a695da 100644 --- a/src/serialization/impl/opr_shallow_copy.cpp +++ b/src/serialization/impl/opr_shallow_copy.cpp @@ -125,16 +125,18 @@ ComputingGraph* serialization::OprShallowCopyContext::owner_graph( cg::OperatorNodeBase* serialization::copy_opr_shallow( const cg::OperatorNodeBase& opr, const VarNodeArray& inputs, const OperatorNodeConfig& config, const OprShallowCopyContext& ctx) { - auto registry = OprRegistry::find_by_type(opr.dyn_typeinfo()); - mgb_assert( - registry, "could not find OprReceiver to copy opr %s{%s}", opr.cname(), - opr.dyn_typeinfo()->name); + OprShallowCopy shallow_copy = nullptr; + if (auto registry = OprRegistry::find_by_type(opr.dyn_typeinfo())) { + shallow_copy = registry->shallow_copy; + } else { + shallow_copy = intl::copy_opr_shallow_default_impl; + } mgb_assert(inputs.size() == opr.input().size()); auto dst_og = ctx.owner_graph(opr, inputs); auto do_copy = [&]() { auto nr_opr_before = opr.owner_graph()->nr_oprs_in_graph(); - auto ret = registry->shallow_copy(ctx, opr, inputs, config); + auto ret = shallow_copy(ctx, opr, inputs, config); if (dst_og != opr.owner_graph() || opr.owner_graph()->nr_oprs_in_graph() != nr_opr_before) { @@ -188,18 +190,28 @@ cg::OperatorNodeBase* serialization::intl::copy_opr_shallow_default_impl( const OprShallowCopyContext& ctx, const cg::OperatorNodeBase& opr, const VarNodeArray& inputs, const OperatorNodeConfig& config) { MGB_MARK_USED_VAR(ctx); + OprDumper opr_dumper = nullptr; + OprLoaderWrapper opr_loader = nullptr; - auto registry = OprRegistry::find_by_type(opr.dyn_typeinfo()); + if (auto registry = OprRegistry::find_by_type(opr.dyn_typeinfo())) { + opr_loader = registry->loader; + opr_dumper = registry->dumper; + } else { + auto registryv2 = OprRegistryV2::versioned_find_by_typeinfo( + opr.dyn_typeinfo(), CURRENT_VERSION); + opr_loader = registryv2->loader; + opr_dumper = registryv2->dumper; + } mgb_assert( - registry && registry->dumper && registry->loader, + opr_dumper && opr_loader, "can not shallow_copy operator %s{%s}: " "no dumper/loader registered", opr.cname(), opr.dyn_typeinfo()->name); - OprDumpContextMemory dumper; - registry->dumper(dumper, opr); + OprDumpContextMemory memory_dumper; + opr_dumper(memory_dumper, opr); - OprLoadContextMemory loader{opr.owner_graph(), dumper}; - return registry->loader(loader, inputs, config).opr(); + OprLoadContextMemory loader{opr.owner_graph(), memory_dumper}; + return opr_loader(loader, inputs, config).opr(); } // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} diff --git a/src/serialization/impl/serializer_oss_v2.cpp b/src/serialization/impl/serializer_oss_v2.cpp index c1e94e14aa12e5d3415e8d6860a7027930c26496..0873ff236414a0c69e4703b8bd74036060789e2f 100644 --- a/src/serialization/impl/serializer_oss_v2.cpp +++ b/src/serialization/impl/serializer_oss_v2.cpp @@ -358,6 +358,11 @@ GraphDumper::DumpResult GraphDumperOSSV2::dump( auto new_output_vars = output_vars; if (!config.no_change_graph) { new_output_vars = converter_all_opr_to_compatiable(output_vars); + mgb_assert(output_vars.size() == new_output_vars.size()); + for (size_t id = 0; id < output_vars.size(); id++) { + auto& new_var = new_output_vars[id]; + new_var.rename(output_vars[id].node()->name()); + } } auto begin_pos = m_file->tell(); diff --git a/src/serialization/include/megbrain/serialization/sereg.h b/src/serialization/include/megbrain/serialization/sereg.h index 4849b0451d6ea90dad8391af848905590c9c24c9..3c61bed5d8450724c3a0f08526abfd931b19efca 100644 --- a/src/serialization/include/megbrain/serialization/sereg.h +++ b/src/serialization/include/megbrain/serialization/sereg.h @@ -151,20 +151,22 @@ struct OprRegistryCaller : public OprRegistryCallerDefaultImpl {}; //! call OprRegistryV2::versioned_add for new serialization which is compatiable //! with old serialization, convert is nullptr, this registry is just only for //! varsion 1 -#define MGB_SEREG_OPR_INTL_CALL_ADD(_cls, _dump, _load) \ - do { \ - ::mgb::serialization::OprRegistry::add( \ - {_cls::typeinfo(), \ - MGB_HASH_STR(#_cls), \ - _MGB_SEREG_OPR_NAME_FROM_CLS(_cls), \ - _dump, \ - _load, \ - {}, \ - MGB_HASH_STR_WITHOUT_TAIL_0_AND_VERSION(#_cls)}); \ - ::mgb::serialization::OprRegistryV2::versioned_add( \ - {_cls::typeinfo(), MGB_HASH_STR_WITHOUT_TAIL_0_AND_VERSION(#_cls), \ - _MGB_SEREG_OPR_NAME_FROM_CLS(_cls), _dump, _load, nullptr}, \ - ::mgb::VERSION_1, ::mgb::VERSION_1); \ +#define MGB_SEREG_OPR_INTL_CALL_ADD(_cls, _dump, _load, _registerv2) \ + do { \ + ::mgb::serialization::OprRegistry::add( \ + {_cls::typeinfo(), \ + MGB_HASH_STR(#_cls), \ + _MGB_SEREG_OPR_NAME_FROM_CLS(_cls), \ + _dump, \ + _load, \ + {}, \ + MGB_HASH_STR_WITHOUT_TAIL_0_AND_VERSION(#_cls)}); \ + if (_registerv2) { \ + ::mgb::serialization::OprRegistryV2::versioned_add( \ + {_cls::typeinfo(), MGB_HASH_STR_WITHOUT_TAIL_0_AND_VERSION(#_cls), \ + _MGB_SEREG_OPR_NAME_FROM_CLS(_cls), _dump, _load, nullptr}, \ + ::mgb::VERSION_1, ::mgb::VERSION_1); \ + } \ } while (0) //! call OprRegistryV2::versioned_add for new serialization, in which convert the @@ -181,23 +183,25 @@ struct OprRegistryCaller : public OprRegistryCallerDefaultImpl {}; /*! * \brief register opr serialization methods */ -#define MGB_SEREG_OPR(_cls, _arity) \ - namespace { \ - namespace ser = ::mgb::serialization; \ - struct _OprReg##_cls { \ - using Impl = ser::OprLoadDumpImpl<_cls, _arity>; \ - static ser::OprWithOutputAccessor wrap_loader( \ - ser::OprLoadContext& ctx, const mgb::cg::VarNodeArray& inputs, \ - const mgb::cg::OperatorNodeConfig& config) { \ - return ser::OprWithOutputAccessor(Impl::load(ctx, inputs, config)); \ - } \ - static void entry() { \ - MGB_SEREG_OPR_INTL_CALL_ADD(_cls, Impl::dump, wrap_loader); \ - } \ - }; \ - } \ +#define MGB_SEREG_OPR_CONDITION(_cls, _arity, _registerv2) \ + namespace { \ + namespace ser = ::mgb::serialization; \ + struct _OprReg##_cls { \ + using Impl = ser::OprLoadDumpImpl<_cls, _arity>; \ + static ser::OprWithOutputAccessor wrap_loader( \ + ser::OprLoadContext& ctx, const mgb::cg::VarNodeArray& inputs, \ + const mgb::cg::OperatorNodeConfig& config) { \ + return ser::OprWithOutputAccessor(Impl::load(ctx, inputs, config)); \ + } \ + static void entry() { \ + MGB_SEREG_OPR_INTL_CALL_ADD(_cls, Impl::dump, wrap_loader, _registerv2); \ + } \ + }; \ + } \ MGB_SEREG_OPR_INTL_CALL_ENTRY(_cls, _OprReg##_cls) +#define MGB_SEREG_OPR(_cls, _arity) MGB_SEREG_OPR_CONDITION(_cls, _arity, true) + //! new dump/load function should implement in OprLoadDumpImplV2, _converter is //! optional , if not implement pass nullptr #define MGB_SEREG_OPR_V2(_cls, _arity, _converter, _version_min, _version_max) \ diff --git a/src/serialization/test/serializer_oss.cpp b/src/serialization/test/serializer_oss.cpp index f821827744fe9b4502d4e95013fc0f56d6e1d4b4..cc0d618a04bf3c9dece041bf881ba97ea9085200 100644 --- a/src/serialization/test/serializer_oss.cpp +++ b/src/serialization/test/serializer_oss.cpp @@ -1,3 +1,4 @@ +#include "megbrain/opr/nn_int.h" #if MGB_ENABLE_FBS_SERIALIZATION #include "megbrain/opr/basic_arith_wrapper.h" @@ -1016,4 +1017,107 @@ TEST(TestSerializer2, TestSoftMaxLoadDump) { load(); } +TEST(TestSerializer2, TestElemwiseMultiTypeLoadDump) { + auto fname = GET_OUTPUT_FILE(GraphDumpFormat::FLATBUFFERS_V2); + TensorShape shape{3}; + auto cn = CompNode::load("xpu0"); + std::shared_ptr host0 = + std::make_shared(cn, shape, dtype::Float32{}); + std::shared_ptr host1 = + std::make_shared(cn, shape, dtype::Float32{}); + HostTensorND dst_truth; + host0->ptr()[0] = 2; + host0->ptr()[1] = 2; + host0->ptr()[2] = -1; + host1->ptr()[0] = 1; + host1->ptr()[1] = 2; + host1->ptr()[2] = 3; + + auto dump = [&](opr::ElemwiseMultiType::Param::Mode mode, size_t nr_opr) { + auto graph = ComputingGraph::make(); + OperatorNodeConfig config; + config.name("input0"); + auto h2d0 = opr::Host2DeviceCopy::make(*graph, host0, config); + config.name("input1"); + auto h2d1 = opr::Host2DeviceCopy::make(*graph, host1, config); + + auto x = opr::ElemwiseMultiType::make( + {h2d0, h2d1}, {mode}, OperatorNodeConfig{dtype::Bool()}); + x.rename("out"); + auto func = graph->compile({make_callback_copy(x, dst_truth)}); + auto dumper = GraphDumper::make( + OutputFile::make_fs(fname.c_str()), GraphDumpFormat::FLATBUFFERS_V2); + auto rst = dumper->dump({x}); + func->execute().wait(); + ASSERT_EQ(rst.nr_opr, nr_opr); + }; + auto load = [&]() { + auto loader = GraphLoader::make( + InputFile::make_fs(fname.c_str()), GraphDumpFormat::FLATBUFFERS_V2); + auto rst = loader->load(); + ASSERT_EQ(rst.tensor_map.size(), 2); + ASSERT_EQ(rst.output_var_map.count("out"), 1); + + HostTensorND host_x; + auto func = + rst.graph_compile({make_callback_copy(rst.output_var_list[0], host_x)}); + for (auto& input : rst.tensor_map) { + if (input.first == "input0") { + input.second->copy_from(*host0).sync(); + } else if (input.first == "input1") { + input.second->copy_from(*host1).sync(); + } + } + func->execute().wait(); + for (int i = 0; i < 3; i++) { + EXPECT_EQ(host_x.ptr()[i], dst_truth.ptr()[i]); + } + }; + dump(opr::ElemwiseMultiType::Param::Mode::EQ, 4); + load(); + dump(opr::ElemwiseMultiType::Param::Mode::LT, 4); + load(); + dump(opr::ElemwiseMultiType::Param::Mode::LEQ, 4); + load(); + dump(opr::ElemwiseMultiType::Param::Mode::NEQ, 5); + load(); + + auto dump_single_input = [&](opr::ElemwiseMultiType::Param::Mode mode, + size_t nr_opr) { + auto graph = ComputingGraph::make(); + auto h2d0 = opr::Host2DeviceCopy::make(*graph, host0); + auto x = opr::ElemwiseMultiType::make( + {h2d0}, {mode}, OperatorNodeConfig{dtype::Bool()}); + x.rename("out"); + auto func = graph->compile({make_callback_copy(x, dst_truth)}); + auto dumper = GraphDumper::make( + OutputFile::make_fs(fname.c_str()), GraphDumpFormat::FLATBUFFERS_V2); + auto rst = dumper->dump({x}); + func->execute().wait(); + ASSERT_EQ(rst.nr_opr, nr_opr); + }; + auto load_single_input = [&]() { + auto loader = GraphLoader::make( + InputFile::make_fs(fname.c_str()), GraphDumpFormat::FLATBUFFERS_V2); + auto rst = loader->load(); + ASSERT_EQ(rst.tensor_map.size(), 1); + ASSERT_EQ(rst.output_var_map.count("out"), 1); + + HostTensorND host_x; + auto func = + rst.graph_compile({make_callback_copy(rst.output_var_list[0], host_x)}); + rst.tensor_map.begin()->second->copy_from(*host0).sync(); + func->execute().wait(); + for (int i = 0; i < 3; i++) { + EXPECT_EQ(host_x.ptr()[i], dst_truth.ptr()[i]); + } + }; + host0->ptr()[2] = INFINITY; + dump_single_input(opr::ElemwiseMultiType::Param::Mode::ISINF, 4); + load_single_input(); + host0->ptr()[2] = NAN; + dump_single_input(opr::ElemwiseMultiType::Param::Mode::ISNAN, 4); + load_single_input(); +} + #endif