diff --git a/src/opr/impl/nn_int.sereg.h b/src/opr/impl/nn_int.sereg.h index 4527d44b1396298cf5a333b6e18e3c7f00255919..9cb0fc93ad9fe77eae45d936f38a216b24fa6cf6 100644 --- a/src/opr/impl/nn_int.sereg.h +++ b/src/opr/impl/nn_int.sereg.h @@ -72,7 +72,7 @@ struct OprLoadDumpImplV2 { namespace opr { MGB_SEREG_OPR_CONDITION(ElemwiseMultiType, 0, false); -MGB_SEREG_OPR_V2( +MGB_SEREG_OPR_V2_HASH_WITHOUT_TAIL_0( ElemwiseMultiType, 0, (mgb::serialization::OprLoadDumpImplV2::replace_opr), VERSION_1, VERSION_1); diff --git a/src/serialization/impl/opr_registry.cpp b/src/serialization/impl/opr_registry.cpp index 69ad8676bc106a6f0c182ca78611d5268bb14f71..48f1bbc89618671aee8fa5ccdb084a44190802a4 100644 --- a/src/serialization/impl/opr_registry.cpp +++ b/src/serialization/impl/opr_registry.cpp @@ -64,8 +64,8 @@ const OprRegistryV2* dynamic_registry_v2() { auto id = MGB_HASH_STR("dynamic"); OprRegistryV2::versioned_add( - {nullptr, id, {}, {}, dynamic_loader, {}}, CURRENT_VERSION, - CURRENT_VERSION); + {nullptr, id, {}, {}, dynamic_loader, {}}, CURRENT_VERSION, CURRENT_VERSION, + true); ret = OprRegistryV2::versioned_find_by_id(id, CURRENT_VERSION); mgb_assert(ret); return ret; @@ -182,7 +182,8 @@ const OprRegistryV2* OprRegistryV2::versioned_find_by_typeinfo( } void OprRegistryV2::versioned_add( - const OprRegistryV2& record, uint8_t min_version, uint8_t max_version) { + const OprRegistryV2& record, uint8_t min_version, uint8_t max_version, + bool dynamic) { mgb_assert(max_version >= min_version); auto&& sd = static_data(); @@ -190,7 +191,7 @@ void OprRegistryV2::versioned_add( uint64_t type_id = id; //! record.type->name is nullptr when MGB_VERBOSE_TYPEINFO_NAME==0 #if MGB_VERBOSE_TYPEINFO_NAME - if (record.type && record.type->name) { + if (dynamic && record.type && record.type->name) { type_id = MGB_HASH_RUNTIME(std::string(record.type->name)); } #endif @@ -236,7 +237,7 @@ void OprRegistry::add_using_dynamic_loader( OprRegistryV2::versioned_add( {type, dynamic_registry_v2()->type_id, type->name, dumper, dynamic_registry_v2()->loader, nullptr}, - CURRENT_VERSION, CURRENT_VERSION); + CURRENT_VERSION, CURRENT_VERSION, true); } #if MGB_ENABLE_DEBUG_UTIL diff --git a/src/serialization/include/megbrain/serialization/opr_registry.h b/src/serialization/include/megbrain/serialization/opr_registry.h index c623e1bfc27c910daf1e4d8b4d04deda7e574fac..ae28464773af44716b4b31616a3b5d51998be00b 100644 --- a/src/serialization/include/megbrain/serialization/opr_registry.h +++ b/src/serialization/include/megbrain/serialization/opr_registry.h @@ -111,7 +111,8 @@ struct OprRegistryV2 { //! register opr load/dump to version2regmap MGE_WIN_DECLSPEC_FUC static void versioned_add( - const OprRegistryV2& record, uint8_t min_version, uint8_t max_version); + const OprRegistryV2& record, uint8_t min_version, uint8_t max_version, + bool dynamic = false); MGE_WIN_DECLSPEC_FUC static const OprRegistryV2* versioned_find_by_id( const size_t id, uint8_t version); diff --git a/src/serialization/include/megbrain/serialization/sereg.h b/src/serialization/include/megbrain/serialization/sereg.h index 3c61bed5d8450724c3a0f08526abfd931b19efca..3776e9d37022ea57c66c1df1930b6b8443a064bf 100644 --- a/src/serialization/include/megbrain/serialization/sereg.h +++ b/src/serialization/include/megbrain/serialization/sereg.h @@ -180,6 +180,18 @@ struct OprRegistryCaller : public OprRegistryCallerDefaultImpl {}; _version_min, _version_max); \ } while (0) +//! in order to compatibility with MGB_SEREG_OPR_INTL_CALL_ADD, the macro use +//! the same hash with MGB_SEREG_OPR_INTL_CALL_ADD, +//! MGB_HASH_STR_WITHOUT_TAIL_0_AND_VERSION is different with MGB_HASH_STR +#define MGB_SEREG_OPR_INTL_CALL_ADD_V2_WITHOUT_TAIL_0_AND_VERSION_HASH( \ + _cls, _dump, _load, _convert, _version_min, _version_max) \ + do { \ + ::mgb::serialization::OprRegistryV2::versioned_add( \ + {_cls::typeinfo(), MGB_HASH_STR_WITHOUT_TAIL_0_AND_VERSION(#_cls), \ + _MGB_SEREG_OPR_NAME_FROM_CLS(_cls), _dump, _load, _convert}, \ + _version_min, _version_max); \ + } while (0) + /*! * \brief register opr serialization methods */ @@ -223,6 +235,27 @@ struct OprRegistryCaller : public OprRegistryCallerDefaultImpl {}; } \ MGB_SEREG_OPR_INTL_CALL_ENTRY_V2(_cls, _OprRegV2##_cls) +//! using MGB_HASH_STR_WITHOUT_TAIL_0_AND_VERSION macro to get the type id +#define MGB_SEREG_OPR_V2_HASH_WITHOUT_TAIL_0( \ + _cls, _arity, _converter, _version_min, _version_max) \ + namespace { \ + namespace ser = ::mgb::serialization; \ + struct _OprRegV2##_cls { \ + using Impl = ser::OprLoadDumpImplV2<_cls, _arity>; \ + static ser::OprWithOutputAccessor wrap_loader( \ + ser::OprLoadContext& ctx, const mgb::cg::VarNodeArray& inputs, \ + const mgb::cg::OperatorNodeConfig& config) { \ + return ser::OprWithOutputAccessor(Impl::load(ctx, inputs, config)); \ + } \ + static void entry() { \ + MGB_SEREG_OPR_INTL_CALL_ADD_V2_WITHOUT_TAIL_0_AND_VERSION_HASH( \ + _cls, Impl::dump, wrap_loader, _converter, _version_min, \ + _version_max); \ + } \ + }; \ + } \ + MGB_SEREG_OPR_INTL_CALL_ENTRY_V2(_cls, _OprRegV2##_cls) + //! use to check type is complete or not, midout need a complete type template struct IsComplete : std::false_type {};