From 50faabf614f9dc95df699cbaa79d09ba678a1a00 Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Wed, 11 May 2022 18:02:52 +0800 Subject: [PATCH] feat(serialization): support the registry for new serialization format GitOrigin-RevId: 8eacd5e77c833d06a5aac2879d097182d4689fda --- src/core/include/megbrain/utils/hash_ct.h | 2 + src/opr/impl/custom_opnode.sereg.h | 15 ++ src/opr/impl/dnn/dnn.sereg.v2.h | 228 ++++++++++++++++++ src/opr/impl/io.sereg.v2.h | 197 +++++++++++++++ src/opr/impl/loop/forward_sereg.cpp | 10 + src/serialization/impl/opr_registry.cpp | 94 +++++++- src/serialization/impl/sereg_caller.cpp | 2 + .../megbrain/serialization/opr_registry.h | 34 ++- .../include/megbrain/serialization/sereg.h | 71 +++++- 9 files changed, 639 insertions(+), 14 deletions(-) create mode 100644 src/opr/impl/dnn/dnn.sereg.v2.h create mode 100644 src/opr/impl/io.sereg.v2.h diff --git a/src/core/include/megbrain/utils/hash_ct.h b/src/core/include/megbrain/utils/hash_ct.h index 658a660b7..e84668562 100644 --- a/src/core/include/megbrain/utils/hash_ct.h +++ b/src/core/include/megbrain/utils/hash_ct.h @@ -153,4 +153,6 @@ struct EnsureHashConstexpr { #define MGB_HASH_STR(v) \ ::mgb::EnsureHashConstexpr<::mgb::XXHash64CT::hash(v, sizeof(v), 20160701)>::val +#define MGB_HASH_RUNTIME(v) XXHash64CT::hash((v).c_str(), (v).size() + 1, 20160701) + // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} diff --git a/src/opr/impl/custom_opnode.sereg.h b/src/opr/impl/custom_opnode.sereg.h index 931841d76..3500d9241 100644 --- a/src/opr/impl/custom_opnode.sereg.h +++ b/src/opr/impl/custom_opnode.sereg.h @@ -52,6 +52,21 @@ mgb::cg::OperatorNodeBase* custom_loader( } \ MGB_SEREG_OPR_INTL_CALL_ENTRY(cls, _OprReg##cls) +#define CUSTOM_OP_SEREG_REG_V2(cls, _version_min, _version_max) \ + namespace { \ + struct _OprRegV2##cls { \ + static void entry() { \ + MGB_SEREG_OPR_INTL_CALL_ADD_V2( \ + cls, ::mgb::serialization::custom_dumper, \ + ::mgb::serialization::custom_loader, nullptr, _version_min, \ + _version_max); \ + } \ + }; \ + } \ + MGB_SEREG_OPR_INTL_CALL_ENTRY_V2(cls, _OprRegV2##cls) + using namespace mgb; using CustomOpNode = opr::CustomOpNode; CUSTOM_OP_SEREG_REG(CustomOpNode); + +CUSTOM_OP_SEREG_REG_V2(CustomOpNode, 2, CURRENT_VERSION); diff --git a/src/opr/impl/dnn/dnn.sereg.v2.h b/src/opr/impl/dnn/dnn.sereg.v2.h new file mode 100644 index 000000000..d97dd3a05 --- /dev/null +++ b/src/opr/impl/dnn/dnn.sereg.v2.h @@ -0,0 +1,228 @@ +#include "megbrain/graph/symbol_var.h" +#include "megdnn/oprs/general.h" +#if MGB_ENABLE_FBS_SERIALIZATION +#include "megbrain/opr/basic_arith.h" +#include "megbrain/opr/dnn/softmax.h" +#include "megbrain/serialization/oss_opr_load_dump.h" +#include "megbrain/serialization/sereg.h" +#include "megdnn/opr_param_defs.h" +#include "megdnn/oprs/nn.h" + +namespace mgb { +namespace serialization { + +template <> +struct OprLoadDumpImplV2 { + using Opr = opr::Softmax; + using PersisParam = opr::Softmax::Param; + static void dump(OprDumpContext& ctx, const cg::OperatorNodeBase& opr) { + ctx.write_param(opr.cast_final_safe().param()); + } + + static cg::OperatorNodeBase* replace_opr( + cg::OperatorNodeBase* opr, const VarNodeArray& inputs) { + int32_t axis = opr->cast_final_safe().param().axis; + auto input_var = inputs[0]; + auto max_reduce_out = + opr::Reduce::make(input_var, {megdnn::Reduce::Mode::MAX, axis}); + auto elemwise_sub_out = opr::Elemwise::make( + {input_var, max_reduce_out}, {megdnn::Elemwise::Mode::SUB}); + auto elemwise_exp_out = + opr::Elemwise::make({elemwise_sub_out}, {megdnn::Elemwise::Mode::EXP}); + auto sum_reduce_out = + opr::Reduce::make(elemwise_exp_out, {megdnn::Reduce::Mode::SUM, axis}); + auto out = opr::Elemwise::make( + {elemwise_exp_out, sum_reduce_out}, {megdnn::Elemwise::Mode::TRUE_DIV}); + return out.node()->owner_opr(); + } + + static cg::OperatorNodeBase* load( + OprLoadContext& ctx, const cg::VarNodeArray& inputs, + const OperatorNodeConfig& config) { + auto& fbs_ctx = CAST_TO_FBS_V2_CTX(ctx); + auto param = fbs_ctx.read_param(0); + return Opr::make(inputs[0], param, config).node()->owner_opr(); + } +}; + +template < + class Opr, class Maker0, class MegDNNConv, + class Maker1 = MakeConvCallerEmpty, + class Maker2 = MakeConvCallerEmpty, + typename ConvParam = megdnn::param::Convolution> +struct WithPolicyOprLoadDumpImpl { + static void dump(OprDumpContext& ctx, const cg::OperatorNodeBase& opr_) { + auto&& opr = opr_.cast_final_safe(); + ctx.write_param(opr.param()); + ctx.write_param( + opr.execution_policy_transient()); + } + static VarNode* make( + const cg::VarNodeArray& inputs, const ConvParam& param, + const megdnn::param::ExecutionPolicy& execution_policy, + const OperatorNodeConfig& config) { + VarNode* ret = + Maker0::template make(inputs, param, execution_policy, config); + if (!ret) { + ret = Maker1::template make(inputs, param, execution_policy, config); + } + if (!ret) { + ret = Maker2::template make(inputs, param, execution_policy, config); + } + mgb_assert(ret); + return ret; + } + + static cg::OperatorNodeBase* load( + OprLoadContext& ctx, const cg::VarNodeArray& inputs, + const OperatorNodeConfig& config) { + auto& fbs_ctx = CAST_TO_FBS_V2_CTX(ctx); + auto fopr = reinterpret_cast( + fbs_ctx.get_current_opr_data()); + auto conv_param = fbs_ctx.read_param(0); + megdnn::param::ExecutionPolicy policy; + if (fopr->additional_params() && fopr->additional_params()->size()) { + policy = fbs_ctx.read_param(1); + } + return make(inputs, conv_param, policy, config)->owner_opr(); + } +}; + +template <> +struct OprLoadDumpImplV2 + : public WithPolicyOprLoadDumpImpl< + opr::Convolution, MakeConvCaller2, + megdnn::Convolution> {}; +template <> +struct OprLoadDumpImplV2 + : public WithPolicyOprLoadDumpImpl< + opr::ConvolutionBackwardData, MakeConvCaller2, + megdnn::Convolution, MakeConvCaller3> {}; +template <> +struct OprLoadDumpImplV2 + : public WithPolicyOprLoadDumpImpl< + opr::ConvolutionBackwardFilter, MakeConvCaller3, + megdnn::Convolution> {}; + +template <> +struct OprLoadDumpImplV2 + : public WithPolicyOprLoadDumpImpl< + opr::Convolution3D, MakeConvCaller2, + megdnn::Convolution3D, MakeConvCallerEmpty, + MakeConvCallerEmpty, + megdnn::param::Convolution3D> {}; +template <> +struct OprLoadDumpImplV2 + : public WithPolicyOprLoadDumpImpl< + opr::Convolution3DBackwardData, + MakeConvCaller2, megdnn::Convolution3D, + MakeConvCaller3, + MakeConvCallerEmpty, + megdnn::param::Convolution3D> {}; +template <> +struct OprLoadDumpImplV2 + : public WithPolicyOprLoadDumpImpl< + opr::Convolution3DBackwardFilter, + MakeConvCaller3, megdnn::Convolution3D, + MakeConvCallerEmpty, + MakeConvCallerEmpty, + megdnn::param::Convolution3D> {}; +template <> +struct OprLoadDumpImplV2 + : public WithPolicyOprLoadDumpImpl< + opr::ConvBiasForward, MakeConvCaller2, + megdnn::ConvBiasForward, MakeConvCaller3, + MakeConvCaller4, megdnn::param::ConvBias> {}; +template <> +struct OprLoadDumpImplV2 + : public WithPolicyOprLoadDumpImpl< + opr::BatchConvBiasForward, + MakeConvCaller2, + megdnn::BatchConvBiasForward, + MakeConvCaller3, + MakeConvCaller4, + megdnn::param::BatchConvBias> {}; + +template <> +struct OprLoadDumpImplV2 + : public WithPolicyOprLoadDumpImpl< + opr::LocalShare, MakeLocalShareCaller2, + megdnn::LocalShare, MakeLocalShareCallerEmpty, + MakeLocalShareCallerEmpty, + megdnn::param::LocalShare> {}; +template <> +struct OprLoadDumpImplV2 + : public WithPolicyOprLoadDumpImpl< + opr::LocalShareBackwardData, + MakeLocalShareCaller3, megdnn::LocalShare, + MakeLocalShareCallerEmpty, + MakeLocalShareCallerEmpty, + megdnn::param::LocalShare> {}; +template <> +struct OprLoadDumpImplV2 + : public WithPolicyOprLoadDumpImpl< + opr::LocalShareBackwardFilter, + MakeLocalShareCaller3, megdnn::LocalShare, + MakeLocalShareCallerEmpty, + MakeLocalShareCallerEmpty, + megdnn::param::LocalShare> {}; +template <> +struct OprLoadDumpImplV2 + : public WithPolicyOprLoadDumpImpl< + opr::DeformableConvForward, + MakeConvCaller4, megdnn::Convolution> { +}; +template <> +struct OprLoadDumpImplV2 + : public WithPolicyOprLoadDumpImpl< + opr::DeformableConvBackwardData, + MakeConvCaller5, + megdnn::Convolution> {}; +template <> +struct OprLoadDumpImplV2 + : public WithPolicyOprLoadDumpImpl< + opr::DeformableConvBackwardFilter, + MakeConvCaller5, + megdnn::Convolution> {}; + +} // namespace serialization + +namespace opr { +#define SERGE_OPR_V2_CONVERTER(_cls, _arity, _converter) \ + MGB_SEREG_OPR_V2(_cls, _arity, _converter, VERSION_2, CURRENT_VERSION); + +#define SERGE_OPR_V2_NO_CONVERTER(_cls, _arity) \ + MGB_SEREG_OPR_V2(_cls, _arity, nullptr, VERSION_2, CURRENT_VERSION); + +SERGE_OPR_V2_CONVERTER( + Softmax, 1, + (mgb::serialization::OprLoadDumpImplV2::replace_opr)); + +SERGE_OPR_V2_NO_CONVERTER(ConvBiasForward, 0) +SERGE_OPR_V2_NO_CONVERTER(BatchConvBiasForward, 0); + +SERGE_OPR_V2_NO_CONVERTER(Convolution, 0) +SERGE_OPR_V2_NO_CONVERTER(ConvolutionBackwardData, 0) +SERGE_OPR_V2_NO_CONVERTER(ConvolutionBackwardFilter, 0) + +SERGE_OPR_V2_NO_CONVERTER(Convolution3D, 0); +SERGE_OPR_V2_NO_CONVERTER(Convolution3DBackwardData, 0); +SERGE_OPR_V2_NO_CONVERTER(Convolution3DBackwardFilter, 0); + +SERGE_OPR_V2_NO_CONVERTER(LocalShareForward, 0); +SERGE_OPR_V2_NO_CONVERTER(LocalShareBackwardData, 0); +SERGE_OPR_V2_NO_CONVERTER(LocalShareBackwardFilter, 0); + +SERGE_OPR_V2_NO_CONVERTER(DeformableConvForward, 0); +SERGE_OPR_V2_NO_CONVERTER(DeformableConvBackwardData, 0); +SERGE_OPR_V2_NO_CONVERTER(DeformableConvBackwardFilter, 0); + +#undef SERGE_OPR_V2_CONVERTER +#undef SERGE_OPR_V2_NO_CONVERTER +} // namespace opr + +} // namespace mgb + +#endif + +// vim: ft=cpp syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} diff --git a/src/opr/impl/io.sereg.v2.h b/src/opr/impl/io.sereg.v2.h new file mode 100644 index 000000000..5c68adb31 --- /dev/null +++ b/src/opr/impl/io.sereg.v2.h @@ -0,0 +1,197 @@ +#if MGB_ENABLE_FBS_SERIALIZATION +#include "megbrain/comp_node_env.h" +#include "megbrain/opr/dnn/softmax.h" +#include "megbrain/opr/io.h" +#include "megbrain/serialization/oss_opr_load_dump.h" +#include "megbrain/serialization/sereg.h" + +#include "megbrain/serialization/internal/mgb_cpp_opr_generated.h" +#include "megbrain/serialization/internal/schema_v2_generated.h" + +namespace mgb { +namespace serialization { + +template <> +struct OprLoadDumpImplV2 { + using Opr = opr::ImmutableTensor; + + static void dump(OprDumpContext& ctx, const cg::OperatorNodeBase& opr_) { + using Meth = OprDumpContext::TensorWriteMethod; + auto&& opr = opr_.cast_final_safe(); + ctx.dump_tensor( + {}, HostTensorND{}.copy_from(opr.value()).sync(), + Meth::VALUE_ANONYMOUS); + } + + static cg::OperatorNodeBase* load( + OprLoadContext& ctx, const cg::VarNodeArray& inputs, + const OperatorNodeConfig& config) { + auto& fbs_ctx = CAST_TO_FBS_V2_CTX(ctx); + mgb_assert(inputs.empty()); + auto fopr = reinterpret_cast( + fbs_ctx.get_current_opr_data()); + if (fopr->tensors() && fopr->tensors()->size() > 0) { + auto val = fbs_ctx.load_tensor(); + return Opr::make(fbs_ctx.graph(), *val, config).node()->owner_opr(); + } else { + mgb_throw(SerializationError, "ImmutableTensor load with no tensor data."); + } + } +}; + +template <> +struct OprLoadDumpImplV2 { + using Opr = opr::Host2DeviceCopy; + static void dump(OprDumpContext& ctx, const cg::OperatorNodeBase& opr_) { + auto&& opr = opr_.cast_final_safe(); + ctx.write_param(opr.param()); + + using Meth = OprDumpContext::TensorWriteMethod; + ctx.dump_tensor( + opr.name(), *opr.host_data(), + opr.param().dump_default_value ? Meth::VALUE_INPUT : Meth::META_INPUT); + } + static cg::OperatorNodeBase* load( + OprLoadContext& ctx, const cg::VarNodeArray& inputs, + const OperatorNodeConfig& config) { + mgb_assert(inputs.empty()); + auto& fbs_ctx = CAST_TO_FBS_V2_CTX(ctx); + auto param = fbs_ctx.read_param(0); + auto tensor = fbs_ctx.load_tensor(); + return Opr::make(fbs_ctx.graph(), tensor, param, config).node()->owner_opr(); + } +}; + +template <> +struct OprLoadDumpImplV2 { + static void dump(OprDumpContext& ctx, const cg::OperatorNodeBase& opr_) { + using Meth = OprDumpContext::TensorWriteMethod; + auto&& opr = opr_.cast_final_safe(); + HostTensorND val; + val.copy_from(opr.get_dev_tensor()).sync(); + ctx.dump_tensor({}, val, Meth::VALUE_ANONYMOUS); + } + + static cg::OperatorNodeBase* load( + OprLoadContext& ctx, const cg::VarNodeArray& inputs, + const OperatorNodeConfig& config) { + mgb_assert(inputs.empty()); + auto val = ctx.load_tensor(); + auto dev_val = + std::make_shared(val->comp_node(), val->layout()); + dev_val->copy_from_fixlayout(*val); + auto out_var = + opr::SharedDeviceTensorWithFormat::make(ctx.graph(), dev_val, config); + dev_val->sync(); + return out_var.node()->owner_opr(); + } +}; + +template <> +struct OprLoadDumpImplV2 { + using Opr = opr::MultipleDeviceTensorHolder; + + static void dump(OprDumpContext& ctx, const cg::OperatorNodeBase& opr_) { + using Meth = OprDumpContext::TensorWriteMethod; + auto&& opr = opr_.cast_final_safe(); + uint32_t nr_val = opr.values().size(); + for (uint32_t i = 0; i < nr_val; ++i) { + HostTensorND val; + val.copy_from(*opr.values()[i]).sync(); + ctx.dump_tensor(opr.output(i)->name(), val, Meth::VALUE_SHARED); + } + } + + static cg::OperatorNodeBase* load( + OprLoadContext& ctx, const cg::VarNodeArray& inputs, + const OperatorNodeConfig& config) { + mgb_assert(inputs.empty()); + auto& fbs_ctx = CAST_TO_FBS_V2_CTX(ctx); + auto fopr = reinterpret_cast( + fbs_ctx.get_current_opr_data()); + uint32_t nr = 0; + if (fopr && fopr->tensors()) { + nr = fopr->tensors()->size(); + } + Opr::ValueArray values(nr); + for (auto&& i : values) { + i = ctx.load_tensor_shared(); + } + return Opr::make(ctx.graph(), std::move(values), config)[0].node()->owner_opr(); + } +}; + +template <> +struct OprLoadDumpImplV2 { + using Opr = opr::MultipleDeviceTensorWithFormatHolder; + + static void dump(OprDumpContext& ctx, const cg::OperatorNodeBase& opr_) { + using Meth = OprDumpContext::TensorWriteMethod; + auto&& opr = opr_.cast_final_safe(); + uint32_t nr_val = opr.values().size(); + for (uint32_t i = 0; i < nr_val; ++i) { + HostTensorND val; + auto value = *opr.values()[i]; + val.copy_from(value).sync(); + ctx.dump_tensor(opr.output(i)->name(), val, Meth::VALUE_SHARED); + } + } + + static cg::OperatorNodeBase* load( + OprLoadContext& ctx, const cg::VarNodeArray& inputs, + const OperatorNodeConfig& config) { + mgb_assert(inputs.empty()); + auto& fbs_ctx = CAST_TO_FBS_V2_CTX(ctx); + auto fopr = reinterpret_cast( + fbs_ctx.get_current_opr_data()); + uint32_t nr = 0; + if (fopr && fopr->tensors()) { + nr = fopr->tensors()->size(); + } + Opr::ValueArray values(nr); + for (auto&& i : values) { + i = ctx.load_tensor_shared(); + //! set tensor format + TensorLayout layout_with_format = i->layout(); + + if (i->storage().comp_node().mem_node() == + CompNode::default_cpu().mem_node()) { + mgb_assert( + i->storage().ptr(), + "storage should not be nullptr if mem_node is " + "default_cpu"); + HostTensorND src{i->storage().comp_node(), layout_with_format}; + src.copy_from_fixlayout(*i).sync(); + *i = DeviceTensorND::make_proxy(src); + } else { + //! actually only layout of this tensor will be used later, see + //! src/serialization/impl/batched_device_value_loader.cpp:49. But we + //! have no way to reset layout only, so just construct a invalid + //! storage instead + auto size = layout_with_format.span().dist_byte(); + DeviceTensorStorage storage; + storage.reset(i->comp_node(), size, nullptr); + i->reset(storage, layout_with_format); + } + } + return Opr::make(ctx.graph(), std::move(values), config)[0].node()->owner_opr(); + } +}; + +} // namespace serialization + +namespace opr { +#define SERGE_OPR_V2_NO_CONVERTER(_cls, _arity) \ + MGB_SEREG_OPR_V2(_cls, _arity, nullptr, VERSION_2, CURRENT_VERSION); + +SERGE_OPR_V2_NO_CONVERTER(ImmutableTensor, 0); +SERGE_OPR_V2_NO_CONVERTER(Host2DeviceCopy, 0); +SERGE_OPR_V2_NO_CONVERTER(SharedDeviceTensorWithFormat, 0); +SERGE_OPR_V2_NO_CONVERTER(MultipleDeviceTensorWithFormatHolder, 0); +SERGE_OPR_V2_NO_CONVERTER(MultipleDeviceTensorHolder, 0); +} // namespace opr +} // namespace mgb + +#endif + +// vim: ft=cpp syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} diff --git a/src/opr/impl/loop/forward_sereg.cpp b/src/opr/impl/loop/forward_sereg.cpp index c2877618b..df5b2862b 100644 --- a/src/opr/impl/loop/forward_sereg.cpp +++ b/src/opr/impl/loop/forward_sereg.cpp @@ -135,6 +135,16 @@ void LoopSerializer::reg_all() { MGB_SEREG_OPR_INTL_CALL_ADD(InputMaker, dump_input_maker, load_input_maker); MGB_SEREG_OPR_INTL_CALL_ADD( CounterProvider, dump_counter_provider, load_counter_provider); + + MGB_SEREG_OPR_INTL_CALL_ADD_V2( + opr::Loop, dump_loop, load_loop, nullptr, 2, + CURRENT_VERSION); + MGB_SEREG_OPR_INTL_CALL_ADD_V2( + InputMaker, dump_input_maker, load_input_maker, nullptr, 2, + CURRENT_VERSION); + MGB_SEREG_OPR_INTL_CALL_ADD_V2( + CounterProvider, dump_counter_provider, load_counter_provider, nullptr, 2, + CURRENT_VERSION); } void LoopSerializer::dump_loop(OprDumpContext& ctx, const cg::OperatorNodeBase& opr) { diff --git a/src/serialization/impl/opr_registry.cpp b/src/serialization/impl/opr_registry.cpp index c13704d51..c0ea45eb5 100644 --- a/src/serialization/impl/opr_registry.cpp +++ b/src/serialization/impl/opr_registry.cpp @@ -20,6 +20,11 @@ struct StaticData { ThinHashMap type2reg; std::unordered_map name2reg; ThinHashMap unversioned_id2reg; + + //! versioned OprRegistryV2, version_id_reg_map is used for Operator + //! load/shallow copy and version_type_reg_map is used for Operator dump + ThinHashMap> version_id_reg_map; + ThinHashMap> version_type_reg_map; }; StaticData& static_data() { @@ -47,6 +52,20 @@ const OprRegistry* dynamic_registry() { return ret; } +const OprRegistryV2* dynamic_registry_v2() { + static const OprRegistryV2* ret = nullptr; + if (ret) + return ret; + + auto id = MGB_HASH_STR("dynamic"); + OprRegistryV2::versioned_add( + {nullptr, id, {}, {}, dynamic_loader, {}}, CURRENT_VERSION, + CURRENT_VERSION); + ret = OprRegistryV2::versioned_find_by_id(id, CURRENT_VERSION); + mgb_assert(ret); + return ret; +} + class _Init { public: _Init() { @@ -63,8 +82,7 @@ void OprRegistry::add(const OprRegistry& record) { auto registry_ins = sd.id2reg.emplace(persist_id, record); mgb_assert( registry_ins.second || persist_id == dynamic_registry()->persist_type_id, - "duplicated operator persist_type_id: %s", - std::to_string(persist_id).c_str()); + "duplicated operator name : %s", record.name.c_str()); OprRegistry* persis_record_ptr; if (registry_ins.second) { @@ -129,6 +147,73 @@ const OprRegistry* OprRegistry::find_by_unversioned_id(size_t unversioned_id) { return iter == uid2reg.end() ? nullptr : iter->second; } +//! find the registry equal to the giving version +const OprRegistryV2* OprRegistryV2::versioned_find_by_id( + const size_t id, uint8_t version) { + auto&& id_reg_map = static_data().version_id_reg_map; + auto iter_version = id_reg_map.find(version); + if (iter_version != id_reg_map.end()) { + auto iter = iter_version->second.find(id); + return iter == iter_version->second.end() ? nullptr : &iter->second; + } + return nullptr; +} +//! find the registry equal or below the giving version +const OprRegistryV2* OprRegistryV2::versioned_find_by_typeinfo( + Typeinfo* type, uint8_t version) { + const auto& type_reg_map = static_data().version_type_reg_map; + for (int version_id = version; version_id > 0; version_id--) { + auto iter_version = type_reg_map.find(version_id); + if (iter_version != type_reg_map.end()) { + auto iter = iter_version->second.find(type); + if (iter == iter_version->second.end()) { + continue; + } else { + return iter->second; + } + } + } + return nullptr; +} + +void OprRegistryV2::versioned_add( + const OprRegistryV2& record, uint8_t min_version, uint8_t max_version) { + mgb_assert(max_version >= min_version); + + auto&& sd = static_data(); + auto id = record.type_id; + uint64_t type_id = id; + //! record.type->name is nullptr when MGB_VERBOSE_TYPEINFO_NAME==0 + if (record.type && record.type->name) { + type_id = MGB_HASH_RUNTIME(std::string(record.type->name)); + } + for (uint8_t version = min_version; version <= max_version; version++) { + auto&& registry_map = sd.version_id_reg_map[version]; + auto versioned_record = record; + versioned_record.version = version; + mgb_assert( + registry_map.find(id) == registry_map.end() || + id == dynamic_registry_v2()->type_id, + "dduplicated OprRegistryV2 of %s\n", record.name.c_str()); + auto registry_ins = registry_map.emplace(id, versioned_record); + if (!registry_ins.second) { + //! the registry is dynamic + mgb_assert(!record.converter); + registry_map[id] = versioned_record; + } + //! sometimes the register id and the hash typeinfo is not same, just as + //! dynamic Operator + if (id != type_id) { + mgb_assert( + registry_map.find(type_id) == registry_map.end(), + "dduplicated OprRegistryV2 of %s\n", record.name.c_str()); + registry_map.emplace(type_id, versioned_record); + } + auto&& registry_type_map = sd.version_type_reg_map[version]; + registry_type_map.emplace(record.type, ®istry_map[id]); + } +} + void OprRegistry::add_using_dynamic_loader( Typeinfo* type, const std::string& name, const OprDumper& dumper) { // dynamic oprs are implemented by mapping different opr types to the same @@ -140,6 +225,11 @@ void OprRegistry::add_using_dynamic_loader( {}, {}, dynamic_registry()->unversioned_type_id}); + mgb_assert(type, "type must be not nullptr"); + OprRegistryV2::versioned_add( + {type, dynamic_registry_v2()->type_id, type->name, dumper, + dynamic_registry_v2()->loader, nullptr}, + CURRENT_VERSION, CURRENT_VERSION); } #if MGB_ENABLE_DEBUG_UTIL diff --git a/src/serialization/impl/sereg_caller.cpp b/src/serialization/impl/sereg_caller.cpp index d3023b6bd..02e24a09e 100644 --- a/src/serialization/impl/sereg_caller.cpp +++ b/src/serialization/impl/sereg_caller.cpp @@ -9,10 +9,12 @@ void call_sereg() {} #include "../../opr/impl/blas.sereg.h" #include "../../opr/impl/cond.sereg.h" #include "../../opr/impl/dnn/dnn.sereg.h" +#include "../../opr/impl/dnn/dnn.sereg.v2.h" #include "./extern_c_opr.sereg.h" #include "../../opr/impl/imgproc.sereg.h" #include "../../opr/impl/indexing.sereg.h" #include "../../opr/impl/io.sereg.h" +#include "../../opr/impl/io.sereg.v2.h" #include "../../opr/impl/loop/forward.sereg.h" #include "../../opr/impl/loop/grad.sereg.h" #include "../../opr/impl/misc.sereg.h" diff --git a/src/serialization/include/megbrain/serialization/opr_registry.h b/src/serialization/include/megbrain/serialization/opr_registry.h index 2cd6f7e04..1ad8a9afe 100644 --- a/src/serialization/include/megbrain/serialization/opr_registry.h +++ b/src/serialization/include/megbrain/serialization/opr_registry.h @@ -53,7 +53,6 @@ struct OprRegistry { uint64_t unversioned_type_id; MGE_WIN_DECLSPEC_FUC static void add(const OprRegistry& record); - /*! * \brief register an operator to use dynamic loader * @@ -89,6 +88,39 @@ struct OprRegistry { #endif }; +//! Convert some modified Opr to compatible Opr +using OprConvertToCompatible = thin_function; + +//! record of a single operator +struct OprRegistryV2 { + Typeinfo* type; + uint64_t type_id; + std::string name; + OprDumper dumper; + OprLoaderWrapper loader; + OprConvertToCompatible converter; + uint8_t version = 2; + + MGE_WIN_DECLSPEC_FUC uint8_t get_version() const { return version; } + + //! register opr load/dump to version2regmap + MGE_WIN_DECLSPEC_FUC static void versioned_add( + const OprRegistryV2& record, uint8_t min_version, uint8_t max_version); + + MGE_WIN_DECLSPEC_FUC static const OprRegistryV2* versioned_find_by_id( + const size_t id, uint8_t version); + + MGE_WIN_DECLSPEC_FUC static const OprRegistryV2* versioned_find_by_typeinfo( + Typeinfo* type, uint8_t version); + +#if MGB_ENABLE_DEBUG_UTIL + //! dump registered oprs + MGE_WIN_DECLSPEC_FUC static std::vector> + dump_registries(); +#endif +}; + } // namespace serialization } // namespace mgb diff --git a/src/serialization/include/megbrain/serialization/sereg.h b/src/serialization/include/megbrain/serialization/sereg.h index 6fe1cf3b0..0c59d1b53 100644 --- a/src/serialization/include/megbrain/serialization/sereg.h +++ b/src/serialization/include/megbrain/serialization/sereg.h @@ -3,6 +3,7 @@ #include "megbrain/serialization/opr_load_dump.h" #include "megbrain/serialization/opr_registry.h" #include "megbrain/serialization/opr_shallow_copy.h" +#include "megbrain/serialization/oss_opr_load_dump.h" #include "megbrain/utils/hash_ct.h" namespace mgb { @@ -66,6 +67,9 @@ struct OprLoadDumpImpl { } }; +template +struct OprLoadDumpImplV2 : public OprLoadDumpImpl {}; + #define IMPL_OPR_MAKER(_arity, _args...) \ template \ struct OprMaker { \ @@ -124,6 +128,12 @@ struct OprRegistryCaller : public OprRegistryCallerDefaultImpl {}; __caller_OprReg##_cls##_ins; \ } +#define MGB_SEREG_OPR_INTL_CALL_ENTRY_V2(_cls, _impl) \ + namespace { \ + [[gnu::unused]] ::mgb::serialization::OprRegistryCaller<_cls, _impl> \ + __caller_V2_OprReg##_cls##_ins; \ + } + // Trim the terminating null character and a "V0" like suffix from the string // then hash it. // TODO: Get rid of this. @@ -138,17 +148,35 @@ struct OprRegistryCaller : public OprRegistryCallerDefaultImpl {}; : 0), \ 20160701)>::val -//! call OprRegistry::add -#define MGB_SEREG_OPR_INTL_CALL_ADD(_cls, _dump, _load) \ - do { \ - ::mgb::serialization::OprRegistry::add( \ - {_cls::typeinfo(), \ - MGB_HASH_STR(#_cls), \ - _MGB_SEREG_OPR_NAME_FROM_CLS(_cls), \ - _dump, \ - _load, \ - {}, \ - MGB_HASH_STR_WITHOUT_TAIL_0_AND_VERSION(#_cls)}); \ +//! call OprRegistry::add for old serialization +//! call OprRegistryV2::versioned_add for new serialization which is compatiable +//! with old serialization, convert is nullptr, this registry is just only for +//! varsion 1 +#define MGB_SEREG_OPR_INTL_CALL_ADD(_cls, _dump, _load) \ + do { \ + ::mgb::serialization::OprRegistry::add( \ + {_cls::typeinfo(), \ + MGB_HASH_STR(#_cls), \ + _MGB_SEREG_OPR_NAME_FROM_CLS(_cls), \ + _dump, \ + _load, \ + {}, \ + MGB_HASH_STR_WITHOUT_TAIL_0_AND_VERSION(#_cls)}); \ + ::mgb::serialization::OprRegistryV2::versioned_add( \ + {_cls::typeinfo(), MGB_HASH_STR_WITHOUT_TAIL_0_AND_VERSION(#_cls), \ + _MGB_SEREG_OPR_NAME_FROM_CLS(_cls), _dump, _load, nullptr}, \ + ::mgb::VERSION_1, ::mgb::VERSION_1); \ + } while (0) + +//! call OprRegistryV2::versioned_add for new serialization, in which convert the +//! function converter the Operator to the compatiable +#define MGB_SEREG_OPR_INTL_CALL_ADD_V2( \ + _cls, _dump, _load, _convert, _version_min, _version_max) \ + do { \ + ::mgb::serialization::OprRegistryV2::versioned_add( \ + {_cls::typeinfo(), MGB_HASH_STR(#_cls), \ + _MGB_SEREG_OPR_NAME_FROM_CLS(_cls), _dump, _load, _convert}, \ + _version_min, _version_max); \ } while (0) /*! @@ -171,6 +199,27 @@ struct OprRegistryCaller : public OprRegistryCallerDefaultImpl {}; } \ MGB_SEREG_OPR_INTL_CALL_ENTRY(_cls, _OprReg##_cls) +//! new dump/load function should implement in OprLoadDumpImplV2, _converter is +//! optional , if not implement pass nullptr +#define MGB_SEREG_OPR_V2(_cls, _arity, _converter, _version_min, _version_max) \ + namespace { \ + namespace ser = ::mgb::serialization; \ + struct _OprRegV2##_cls { \ + using Impl = ser::OprLoadDumpImplV2<_cls, _arity>; \ + static ser::OprWithOutputAccessor wrap_loader( \ + ser::OprLoadContext& ctx, const mgb::cg::VarNodeArray& inputs, \ + const mgb::cg::OperatorNodeConfig& config) { \ + return ser::OprWithOutputAccessor(Impl::load(ctx, inputs, config)); \ + } \ + static void entry() { \ + MGB_SEREG_OPR_INTL_CALL_ADD_V2( \ + _cls, Impl::dump, wrap_loader, _converter, _version_min, \ + _version_max); \ + } \ + }; \ + } \ + MGB_SEREG_OPR_INTL_CALL_ENTRY_V2(_cls, _OprRegV2##_cls) + //! use to check type is complete or not, midout need a complete type template struct IsComplete : std::false_type {}; -- GitLab