提交 50faabf6 编写于 作者: M Megvii Engine Team

feat(serialization): support the registry for new serialization format

GitOrigin-RevId: 8eacd5e77c833d06a5aac2879d097182d4689fda
上级 a694fb33
...@@ -153,4 +153,6 @@ struct EnsureHashConstexpr { ...@@ -153,4 +153,6 @@ struct EnsureHashConstexpr {
#define MGB_HASH_STR(v) \ #define MGB_HASH_STR(v) \
::mgb::EnsureHashConstexpr<::mgb::XXHash64CT::hash(v, sizeof(v), 20160701)>::val ::mgb::EnsureHashConstexpr<::mgb::XXHash64CT::hash(v, sizeof(v), 20160701)>::val
#define MGB_HASH_RUNTIME(v) XXHash64CT::hash((v).c_str(), (v).size() + 1, 20160701)
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
...@@ -52,6 +52,21 @@ mgb::cg::OperatorNodeBase* custom_loader( ...@@ -52,6 +52,21 @@ mgb::cg::OperatorNodeBase* custom_loader(
} \ } \
MGB_SEREG_OPR_INTL_CALL_ENTRY(cls, _OprReg##cls) MGB_SEREG_OPR_INTL_CALL_ENTRY(cls, _OprReg##cls)
#define CUSTOM_OP_SEREG_REG_V2(cls, _version_min, _version_max) \
namespace { \
struct _OprRegV2##cls { \
static void entry() { \
MGB_SEREG_OPR_INTL_CALL_ADD_V2( \
cls, ::mgb::serialization::custom_dumper, \
::mgb::serialization::custom_loader, nullptr, _version_min, \
_version_max); \
} \
}; \
} \
MGB_SEREG_OPR_INTL_CALL_ENTRY_V2(cls, _OprRegV2##cls)
using namespace mgb; using namespace mgb;
using CustomOpNode = opr::CustomOpNode; using CustomOpNode = opr::CustomOpNode;
CUSTOM_OP_SEREG_REG(CustomOpNode); CUSTOM_OP_SEREG_REG(CustomOpNode);
CUSTOM_OP_SEREG_REG_V2(CustomOpNode, 2, CURRENT_VERSION);
#include "megbrain/graph/symbol_var.h"
#include "megdnn/oprs/general.h"
#if MGB_ENABLE_FBS_SERIALIZATION
#include "megbrain/opr/basic_arith.h"
#include "megbrain/opr/dnn/softmax.h"
#include "megbrain/serialization/oss_opr_load_dump.h"
#include "megbrain/serialization/sereg.h"
#include "megdnn/opr_param_defs.h"
#include "megdnn/oprs/nn.h"
namespace mgb {
namespace serialization {
template <>
struct OprLoadDumpImplV2<opr::Softmax, 1> {
using Opr = opr::Softmax;
using PersisParam = opr::Softmax::Param;
static void dump(OprDumpContext& ctx, const cg::OperatorNodeBase& opr) {
ctx.write_param<PersisParam>(opr.cast_final_safe<Opr>().param());
}
static cg::OperatorNodeBase* replace_opr(
cg::OperatorNodeBase* opr, const VarNodeArray& inputs) {
int32_t axis = opr->cast_final_safe<Opr>().param().axis;
auto input_var = inputs[0];
auto max_reduce_out =
opr::Reduce::make(input_var, {megdnn::Reduce::Mode::MAX, axis});
auto elemwise_sub_out = opr::Elemwise::make(
{input_var, max_reduce_out}, {megdnn::Elemwise::Mode::SUB});
auto elemwise_exp_out =
opr::Elemwise::make({elemwise_sub_out}, {megdnn::Elemwise::Mode::EXP});
auto sum_reduce_out =
opr::Reduce::make(elemwise_exp_out, {megdnn::Reduce::Mode::SUM, axis});
auto out = opr::Elemwise::make(
{elemwise_exp_out, sum_reduce_out}, {megdnn::Elemwise::Mode::TRUE_DIV});
return out.node()->owner_opr();
}
static cg::OperatorNodeBase* load(
OprLoadContext& ctx, const cg::VarNodeArray& inputs,
const OperatorNodeConfig& config) {
auto& fbs_ctx = CAST_TO_FBS_V2_CTX(ctx);
auto param = fbs_ctx.read_param<PersisParam>(0);
return Opr::make(inputs[0], param, config).node()->owner_opr();
}
};
template <
class Opr, class Maker0, class MegDNNConv,
class Maker1 = MakeConvCallerEmpty<MegDNNConv>,
class Maker2 = MakeConvCallerEmpty<MegDNNConv>,
typename ConvParam = megdnn::param::Convolution>
struct WithPolicyOprLoadDumpImpl {
static void dump(OprDumpContext& ctx, const cg::OperatorNodeBase& opr_) {
auto&& opr = opr_.cast_final_safe<Opr>();
ctx.write_param<ConvParam>(opr.param());
ctx.write_param<megdnn::param::ExecutionPolicy>(
opr.execution_policy_transient());
}
static VarNode* make(
const cg::VarNodeArray& inputs, const ConvParam& param,
const megdnn::param::ExecutionPolicy& execution_policy,
const OperatorNodeConfig& config) {
VarNode* ret =
Maker0::template make<Opr>(inputs, param, execution_policy, config);
if (!ret) {
ret = Maker1::template make<Opr>(inputs, param, execution_policy, config);
}
if (!ret) {
ret = Maker2::template make<Opr>(inputs, param, execution_policy, config);
}
mgb_assert(ret);
return ret;
}
static cg::OperatorNodeBase* load(
OprLoadContext& ctx, const cg::VarNodeArray& inputs,
const OperatorNodeConfig& config) {
auto& fbs_ctx = CAST_TO_FBS_V2_CTX(ctx);
auto fopr = reinterpret_cast<const fbs::v2::Operator*>(
fbs_ctx.get_current_opr_data());
auto conv_param = fbs_ctx.read_param<ConvParam>(0);
megdnn::param::ExecutionPolicy policy;
if (fopr->additional_params() && fopr->additional_params()->size()) {
policy = fbs_ctx.read_param<megdnn::param::ExecutionPolicy>(1);
}
return make(inputs, conv_param, policy, config)->owner_opr();
}
};
template <>
struct OprLoadDumpImplV2<opr::Convolution, 0>
: public WithPolicyOprLoadDumpImpl<
opr::Convolution, MakeConvCaller2<megdnn::Convolution>,
megdnn::Convolution> {};
template <>
struct OprLoadDumpImplV2<opr::ConvolutionBackwardData, 0>
: public WithPolicyOprLoadDumpImpl<
opr::ConvolutionBackwardData, MakeConvCaller2<megdnn::Convolution>,
megdnn::Convolution, MakeConvCaller3<megdnn::Convolution>> {};
template <>
struct OprLoadDumpImplV2<opr::ConvolutionBackwardFilter, 0>
: public WithPolicyOprLoadDumpImpl<
opr::ConvolutionBackwardFilter, MakeConvCaller3<megdnn::Convolution>,
megdnn::Convolution> {};
template <>
struct OprLoadDumpImplV2<opr::Convolution3D, 0>
: public WithPolicyOprLoadDumpImpl<
opr::Convolution3D, MakeConvCaller2<megdnn::Convolution3D>,
megdnn::Convolution3D, MakeConvCallerEmpty<megdnn::Convolution3D>,
MakeConvCallerEmpty<megdnn::Convolution3D>,
megdnn::param::Convolution3D> {};
template <>
struct OprLoadDumpImplV2<opr::Convolution3DBackwardData, 0>
: public WithPolicyOprLoadDumpImpl<
opr::Convolution3DBackwardData,
MakeConvCaller2<megdnn::Convolution3D>, megdnn::Convolution3D,
MakeConvCaller3<megdnn::Convolution3D>,
MakeConvCallerEmpty<megdnn::Convolution3D>,
megdnn::param::Convolution3D> {};
template <>
struct OprLoadDumpImplV2<opr::Convolution3DBackwardFilter, 0>
: public WithPolicyOprLoadDumpImpl<
opr::Convolution3DBackwardFilter,
MakeConvCaller3<megdnn::Convolution3D>, megdnn::Convolution3D,
MakeConvCallerEmpty<megdnn::Convolution3D>,
MakeConvCallerEmpty<megdnn::Convolution3D>,
megdnn::param::Convolution3D> {};
template <>
struct OprLoadDumpImplV2<opr::ConvBiasForward, 0>
: public WithPolicyOprLoadDumpImpl<
opr::ConvBiasForward, MakeConvCaller2<megdnn::ConvBiasForward>,
megdnn::ConvBiasForward, MakeConvCaller3<megdnn::ConvBiasForward>,
MakeConvCaller4<megdnn::ConvBiasForward>, megdnn::param::ConvBias> {};
template <>
struct OprLoadDumpImplV2<opr::BatchConvBiasForward, 0>
: public WithPolicyOprLoadDumpImpl<
opr::BatchConvBiasForward,
MakeConvCaller2<megdnn::BatchConvBiasForward>,
megdnn::BatchConvBiasForward,
MakeConvCaller3<megdnn::BatchConvBiasForward>,
MakeConvCaller4<megdnn::BatchConvBiasForward>,
megdnn::param::BatchConvBias> {};
template <>
struct OprLoadDumpImplV2<opr::LocalShare, 0>
: public WithPolicyOprLoadDumpImpl<
opr::LocalShare, MakeLocalShareCaller2<megdnn::LocalShare>,
megdnn::LocalShare, MakeLocalShareCallerEmpty<megdnn::LocalShare>,
MakeLocalShareCallerEmpty<megdnn::LocalShare>,
megdnn::param::LocalShare> {};
template <>
struct OprLoadDumpImplV2<opr::LocalShareBackwardData, 0>
: public WithPolicyOprLoadDumpImpl<
opr::LocalShareBackwardData,
MakeLocalShareCaller3<megdnn::LocalShare>, megdnn::LocalShare,
MakeLocalShareCallerEmpty<megdnn::LocalShare>,
MakeLocalShareCallerEmpty<megdnn::LocalShare>,
megdnn::param::LocalShare> {};
template <>
struct OprLoadDumpImplV2<opr::LocalShareBackwardFilter, 0>
: public WithPolicyOprLoadDumpImpl<
opr::LocalShareBackwardFilter,
MakeLocalShareCaller3<megdnn::LocalShare>, megdnn::LocalShare,
MakeLocalShareCallerEmpty<megdnn::LocalShare>,
MakeLocalShareCallerEmpty<megdnn::LocalShare>,
megdnn::param::LocalShare> {};
template <>
struct OprLoadDumpImplV2<opr::DeformableConvForward, 0>
: public WithPolicyOprLoadDumpImpl<
opr::DeformableConvForward,
MakeConvCaller4<megdnn::DeformableConvForward>, megdnn::Convolution> {
};
template <>
struct OprLoadDumpImplV2<opr::DeformableConvBackwardData, 0>
: public WithPolicyOprLoadDumpImpl<
opr::DeformableConvBackwardData,
MakeConvCaller5<megdnn::DeformableConvBackwardData>,
megdnn::Convolution> {};
template <>
struct OprLoadDumpImplV2<opr::DeformableConvBackwardFilter, 0>
: public WithPolicyOprLoadDumpImpl<
opr::DeformableConvBackwardFilter,
MakeConvCaller5<megdnn::DeformableConvBackwardFilter>,
megdnn::Convolution> {};
} // namespace serialization
namespace opr {
#define SERGE_OPR_V2_CONVERTER(_cls, _arity, _converter) \
MGB_SEREG_OPR_V2(_cls, _arity, _converter, VERSION_2, CURRENT_VERSION);
#define SERGE_OPR_V2_NO_CONVERTER(_cls, _arity) \
MGB_SEREG_OPR_V2(_cls, _arity, nullptr, VERSION_2, CURRENT_VERSION);
SERGE_OPR_V2_CONVERTER(
Softmax, 1,
(mgb::serialization::OprLoadDumpImplV2<opr::Softmax, 1>::replace_opr));
SERGE_OPR_V2_NO_CONVERTER(ConvBiasForward, 0)
SERGE_OPR_V2_NO_CONVERTER(BatchConvBiasForward, 0);
SERGE_OPR_V2_NO_CONVERTER(Convolution, 0)
SERGE_OPR_V2_NO_CONVERTER(ConvolutionBackwardData, 0)
SERGE_OPR_V2_NO_CONVERTER(ConvolutionBackwardFilter, 0)
SERGE_OPR_V2_NO_CONVERTER(Convolution3D, 0);
SERGE_OPR_V2_NO_CONVERTER(Convolution3DBackwardData, 0);
SERGE_OPR_V2_NO_CONVERTER(Convolution3DBackwardFilter, 0);
SERGE_OPR_V2_NO_CONVERTER(LocalShareForward, 0);
SERGE_OPR_V2_NO_CONVERTER(LocalShareBackwardData, 0);
SERGE_OPR_V2_NO_CONVERTER(LocalShareBackwardFilter, 0);
SERGE_OPR_V2_NO_CONVERTER(DeformableConvForward, 0);
SERGE_OPR_V2_NO_CONVERTER(DeformableConvBackwardData, 0);
SERGE_OPR_V2_NO_CONVERTER(DeformableConvBackwardFilter, 0);
#undef SERGE_OPR_V2_CONVERTER
#undef SERGE_OPR_V2_NO_CONVERTER
} // namespace opr
} // namespace mgb
#endif
// vim: ft=cpp syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
#if MGB_ENABLE_FBS_SERIALIZATION
#include "megbrain/comp_node_env.h"
#include "megbrain/opr/dnn/softmax.h"
#include "megbrain/opr/io.h"
#include "megbrain/serialization/oss_opr_load_dump.h"
#include "megbrain/serialization/sereg.h"
#include "megbrain/serialization/internal/mgb_cpp_opr_generated.h"
#include "megbrain/serialization/internal/schema_v2_generated.h"
namespace mgb {
namespace serialization {
template <>
struct OprLoadDumpImplV2<opr::ImmutableTensor, 0> {
using Opr = opr::ImmutableTensor;
static void dump(OprDumpContext& ctx, const cg::OperatorNodeBase& opr_) {
using Meth = OprDumpContext::TensorWriteMethod;
auto&& opr = opr_.cast_final_safe<Opr>();
ctx.dump_tensor(
{}, HostTensorND{}.copy_from(opr.value()).sync(),
Meth::VALUE_ANONYMOUS);
}
static cg::OperatorNodeBase* load(
OprLoadContext& ctx, const cg::VarNodeArray& inputs,
const OperatorNodeConfig& config) {
auto& fbs_ctx = CAST_TO_FBS_V2_CTX(ctx);
mgb_assert(inputs.empty());
auto fopr = reinterpret_cast<const fbs::v2::Operator*>(
fbs_ctx.get_current_opr_data());
if (fopr->tensors() && fopr->tensors()->size() > 0) {
auto val = fbs_ctx.load_tensor();
return Opr::make(fbs_ctx.graph(), *val, config).node()->owner_opr();
} else {
mgb_throw(SerializationError, "ImmutableTensor load with no tensor data.");
}
}
};
template <>
struct OprLoadDumpImplV2<opr::Host2DeviceCopy, 0> {
using Opr = opr::Host2DeviceCopy;
static void dump(OprDumpContext& ctx, const cg::OperatorNodeBase& opr_) {
auto&& opr = opr_.cast_final_safe<Opr>();
ctx.write_param(opr.param());
using Meth = OprDumpContext::TensorWriteMethod;
ctx.dump_tensor(
opr.name(), *opr.host_data(),
opr.param().dump_default_value ? Meth::VALUE_INPUT : Meth::META_INPUT);
}
static cg::OperatorNodeBase* load(
OprLoadContext& ctx, const cg::VarNodeArray& inputs,
const OperatorNodeConfig& config) {
mgb_assert(inputs.empty());
auto& fbs_ctx = CAST_TO_FBS_V2_CTX(ctx);
auto param = fbs_ctx.read_param<Opr::Param>(0);
auto tensor = fbs_ctx.load_tensor();
return Opr::make(fbs_ctx.graph(), tensor, param, config).node()->owner_opr();
}
};
template <>
struct OprLoadDumpImplV2<opr::SharedDeviceTensorWithFormat, 0> {
static void dump(OprDumpContext& ctx, const cg::OperatorNodeBase& opr_) {
using Meth = OprDumpContext::TensorWriteMethod;
auto&& opr = opr_.cast_final_safe<opr::SharedDeviceTensorWithFormat>();
HostTensorND val;
val.copy_from(opr.get_dev_tensor()).sync();
ctx.dump_tensor({}, val, Meth::VALUE_ANONYMOUS);
}
static cg::OperatorNodeBase* load(
OprLoadContext& ctx, const cg::VarNodeArray& inputs,
const OperatorNodeConfig& config) {
mgb_assert(inputs.empty());
auto val = ctx.load_tensor();
auto dev_val =
std::make_shared<DeviceTensorND>(val->comp_node(), val->layout());
dev_val->copy_from_fixlayout(*val);
auto out_var =
opr::SharedDeviceTensorWithFormat::make(ctx.graph(), dev_val, config);
dev_val->sync();
return out_var.node()->owner_opr();
}
};
template <>
struct OprLoadDumpImplV2<opr::MultipleDeviceTensorHolder, 0> {
using Opr = opr::MultipleDeviceTensorHolder;
static void dump(OprDumpContext& ctx, const cg::OperatorNodeBase& opr_) {
using Meth = OprDumpContext::TensorWriteMethod;
auto&& opr = opr_.cast_final_safe<Opr>();
uint32_t nr_val = opr.values().size();
for (uint32_t i = 0; i < nr_val; ++i) {
HostTensorND val;
val.copy_from(*opr.values()[i]).sync();
ctx.dump_tensor(opr.output(i)->name(), val, Meth::VALUE_SHARED);
}
}
static cg::OperatorNodeBase* load(
OprLoadContext& ctx, const cg::VarNodeArray& inputs,
const OperatorNodeConfig& config) {
mgb_assert(inputs.empty());
auto& fbs_ctx = CAST_TO_FBS_V2_CTX(ctx);
auto fopr = reinterpret_cast<const fbs::v2::Operator*>(
fbs_ctx.get_current_opr_data());
uint32_t nr = 0;
if (fopr && fopr->tensors()) {
nr = fopr->tensors()->size();
}
Opr::ValueArray values(nr);
for (auto&& i : values) {
i = ctx.load_tensor_shared();
}
return Opr::make(ctx.graph(), std::move(values), config)[0].node()->owner_opr();
}
};
template <>
struct OprLoadDumpImplV2<opr::MultipleDeviceTensorWithFormatHolder, 0> {
using Opr = opr::MultipleDeviceTensorWithFormatHolder;
static void dump(OprDumpContext& ctx, const cg::OperatorNodeBase& opr_) {
using Meth = OprDumpContext::TensorWriteMethod;
auto&& opr = opr_.cast_final_safe<Opr>();
uint32_t nr_val = opr.values().size();
for (uint32_t i = 0; i < nr_val; ++i) {
HostTensorND val;
auto value = *opr.values()[i];
val.copy_from(value).sync();
ctx.dump_tensor(opr.output(i)->name(), val, Meth::VALUE_SHARED);
}
}
static cg::OperatorNodeBase* load(
OprLoadContext& ctx, const cg::VarNodeArray& inputs,
const OperatorNodeConfig& config) {
mgb_assert(inputs.empty());
auto& fbs_ctx = CAST_TO_FBS_V2_CTX(ctx);
auto fopr = reinterpret_cast<const fbs::v2::Operator*>(
fbs_ctx.get_current_opr_data());
uint32_t nr = 0;
if (fopr && fopr->tensors()) {
nr = fopr->tensors()->size();
}
Opr::ValueArray values(nr);
for (auto&& i : values) {
i = ctx.load_tensor_shared();
//! set tensor format
TensorLayout layout_with_format = i->layout();
if (i->storage().comp_node().mem_node() ==
CompNode::default_cpu().mem_node()) {
mgb_assert(
i->storage().ptr(),
"storage should not be nullptr if mem_node is "
"default_cpu");
HostTensorND src{i->storage().comp_node(), layout_with_format};
src.copy_from_fixlayout(*i).sync();
*i = DeviceTensorND::make_proxy(src);
} else {
//! actually only layout of this tensor will be used later, see
//! src/serialization/impl/batched_device_value_loader.cpp:49. But we
//! have no way to reset layout only, so just construct a invalid
//! storage instead
auto size = layout_with_format.span().dist_byte();
DeviceTensorStorage storage;
storage.reset(i->comp_node(), size, nullptr);
i->reset(storage, layout_with_format);
}
}
return Opr::make(ctx.graph(), std::move(values), config)[0].node()->owner_opr();
}
};
} // namespace serialization
namespace opr {
#define SERGE_OPR_V2_NO_CONVERTER(_cls, _arity) \
MGB_SEREG_OPR_V2(_cls, _arity, nullptr, VERSION_2, CURRENT_VERSION);
SERGE_OPR_V2_NO_CONVERTER(ImmutableTensor, 0);
SERGE_OPR_V2_NO_CONVERTER(Host2DeviceCopy, 0);
SERGE_OPR_V2_NO_CONVERTER(SharedDeviceTensorWithFormat, 0);
SERGE_OPR_V2_NO_CONVERTER(MultipleDeviceTensorWithFormatHolder, 0);
SERGE_OPR_V2_NO_CONVERTER(MultipleDeviceTensorHolder, 0);
} // namespace opr
} // namespace mgb
#endif
// vim: ft=cpp syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
...@@ -135,6 +135,16 @@ void LoopSerializer::reg_all() { ...@@ -135,6 +135,16 @@ void LoopSerializer::reg_all() {
MGB_SEREG_OPR_INTL_CALL_ADD(InputMaker, dump_input_maker, load_input_maker); MGB_SEREG_OPR_INTL_CALL_ADD(InputMaker, dump_input_maker, load_input_maker);
MGB_SEREG_OPR_INTL_CALL_ADD( MGB_SEREG_OPR_INTL_CALL_ADD(
CounterProvider, dump_counter_provider, load_counter_provider); CounterProvider, dump_counter_provider, load_counter_provider);
MGB_SEREG_OPR_INTL_CALL_ADD_V2(
opr::Loop, dump_loop, load_loop, nullptr, 2,
CURRENT_VERSION);
MGB_SEREG_OPR_INTL_CALL_ADD_V2(
InputMaker, dump_input_maker, load_input_maker, nullptr, 2,
CURRENT_VERSION);
MGB_SEREG_OPR_INTL_CALL_ADD_V2(
CounterProvider, dump_counter_provider, load_counter_provider, nullptr, 2,
CURRENT_VERSION);
} }
void LoopSerializer::dump_loop(OprDumpContext& ctx, const cg::OperatorNodeBase& opr) { void LoopSerializer::dump_loop(OprDumpContext& ctx, const cg::OperatorNodeBase& opr) {
......
...@@ -20,6 +20,11 @@ struct StaticData { ...@@ -20,6 +20,11 @@ struct StaticData {
ThinHashMap<Typeinfo*, OprRegistry*> type2reg; ThinHashMap<Typeinfo*, OprRegistry*> type2reg;
std::unordered_map<std::string, OprRegistry*> name2reg; std::unordered_map<std::string, OprRegistry*> name2reg;
ThinHashMap<size_t, OprRegistry*> unversioned_id2reg; ThinHashMap<size_t, OprRegistry*> unversioned_id2reg;
//! versioned OprRegistryV2, version_id_reg_map is used for Operator
//! load/shallow copy and version_type_reg_map is used for Operator dump
ThinHashMap<uint8_t, ThinHashMap<size_t, OprRegistryV2>> version_id_reg_map;
ThinHashMap<uint8_t, ThinHashMap<Typeinfo*, OprRegistryV2*>> version_type_reg_map;
}; };
StaticData& static_data() { StaticData& static_data() {
...@@ -47,6 +52,20 @@ const OprRegistry* dynamic_registry() { ...@@ -47,6 +52,20 @@ const OprRegistry* dynamic_registry() {
return ret; return ret;
} }
const OprRegistryV2* dynamic_registry_v2() {
static const OprRegistryV2* ret = nullptr;
if (ret)
return ret;
auto id = MGB_HASH_STR("dynamic");
OprRegistryV2::versioned_add(
{nullptr, id, {}, {}, dynamic_loader, {}}, CURRENT_VERSION,
CURRENT_VERSION);
ret = OprRegistryV2::versioned_find_by_id(id, CURRENT_VERSION);
mgb_assert(ret);
return ret;
}
class _Init { class _Init {
public: public:
_Init() { _Init() {
...@@ -63,8 +82,7 @@ void OprRegistry::add(const OprRegistry& record) { ...@@ -63,8 +82,7 @@ void OprRegistry::add(const OprRegistry& record) {
auto registry_ins = sd.id2reg.emplace(persist_id, record); auto registry_ins = sd.id2reg.emplace(persist_id, record);
mgb_assert( mgb_assert(
registry_ins.second || persist_id == dynamic_registry()->persist_type_id, registry_ins.second || persist_id == dynamic_registry()->persist_type_id,
"duplicated operator persist_type_id: %s", "duplicated operator name : %s", record.name.c_str());
std::to_string(persist_id).c_str());
OprRegistry* persis_record_ptr; OprRegistry* persis_record_ptr;
if (registry_ins.second) { if (registry_ins.second) {
...@@ -129,6 +147,73 @@ const OprRegistry* OprRegistry::find_by_unversioned_id(size_t unversioned_id) { ...@@ -129,6 +147,73 @@ const OprRegistry* OprRegistry::find_by_unversioned_id(size_t unversioned_id) {
return iter == uid2reg.end() ? nullptr : iter->second; return iter == uid2reg.end() ? nullptr : iter->second;
} }
//! find the registry equal to the giving version
const OprRegistryV2* OprRegistryV2::versioned_find_by_id(
const size_t id, uint8_t version) {
auto&& id_reg_map = static_data().version_id_reg_map;
auto iter_version = id_reg_map.find(version);
if (iter_version != id_reg_map.end()) {
auto iter = iter_version->second.find(id);
return iter == iter_version->second.end() ? nullptr : &iter->second;
}
return nullptr;
}
//! find the registry equal or below the giving version
const OprRegistryV2* OprRegistryV2::versioned_find_by_typeinfo(
Typeinfo* type, uint8_t version) {
const auto& type_reg_map = static_data().version_type_reg_map;
for (int version_id = version; version_id > 0; version_id--) {
auto iter_version = type_reg_map.find(version_id);
if (iter_version != type_reg_map.end()) {
auto iter = iter_version->second.find(type);
if (iter == iter_version->second.end()) {
continue;
} else {
return iter->second;
}
}
}
return nullptr;
}
void OprRegistryV2::versioned_add(
const OprRegistryV2& record, uint8_t min_version, uint8_t max_version) {
mgb_assert(max_version >= min_version);
auto&& sd = static_data();
auto id = record.type_id;
uint64_t type_id = id;
//! record.type->name is nullptr when MGB_VERBOSE_TYPEINFO_NAME==0
if (record.type && record.type->name) {
type_id = MGB_HASH_RUNTIME(std::string(record.type->name));
}
for (uint8_t version = min_version; version <= max_version; version++) {
auto&& registry_map = sd.version_id_reg_map[version];
auto versioned_record = record;
versioned_record.version = version;
mgb_assert(
registry_map.find(id) == registry_map.end() ||
id == dynamic_registry_v2()->type_id,
"dduplicated OprRegistryV2 of %s\n", record.name.c_str());
auto registry_ins = registry_map.emplace(id, versioned_record);
if (!registry_ins.second) {
//! the registry is dynamic
mgb_assert(!record.converter);
registry_map[id] = versioned_record;
}
//! sometimes the register id and the hash typeinfo is not same, just as
//! dynamic Operator
if (id != type_id) {
mgb_assert(
registry_map.find(type_id) == registry_map.end(),
"dduplicated OprRegistryV2 of %s\n", record.name.c_str());
registry_map.emplace(type_id, versioned_record);
}
auto&& registry_type_map = sd.version_type_reg_map[version];
registry_type_map.emplace(record.type, &registry_map[id]);
}
}
void OprRegistry::add_using_dynamic_loader( void OprRegistry::add_using_dynamic_loader(
Typeinfo* type, const std::string& name, const OprDumper& dumper) { Typeinfo* type, const std::string& name, const OprDumper& dumper) {
// dynamic oprs are implemented by mapping different opr types to the same // dynamic oprs are implemented by mapping different opr types to the same
...@@ -140,6 +225,11 @@ void OprRegistry::add_using_dynamic_loader( ...@@ -140,6 +225,11 @@ void OprRegistry::add_using_dynamic_loader(
{}, {},
{}, {},
dynamic_registry()->unversioned_type_id}); dynamic_registry()->unversioned_type_id});
mgb_assert(type, "type must be not nullptr");
OprRegistryV2::versioned_add(
{type, dynamic_registry_v2()->type_id, type->name, dumper,
dynamic_registry_v2()->loader, nullptr},
CURRENT_VERSION, CURRENT_VERSION);
} }
#if MGB_ENABLE_DEBUG_UTIL #if MGB_ENABLE_DEBUG_UTIL
......
...@@ -9,10 +9,12 @@ void call_sereg() {} ...@@ -9,10 +9,12 @@ void call_sereg() {}
#include "../../opr/impl/blas.sereg.h" #include "../../opr/impl/blas.sereg.h"
#include "../../opr/impl/cond.sereg.h" #include "../../opr/impl/cond.sereg.h"
#include "../../opr/impl/dnn/dnn.sereg.h" #include "../../opr/impl/dnn/dnn.sereg.h"
#include "../../opr/impl/dnn/dnn.sereg.v2.h"
#include "./extern_c_opr.sereg.h" #include "./extern_c_opr.sereg.h"
#include "../../opr/impl/imgproc.sereg.h" #include "../../opr/impl/imgproc.sereg.h"
#include "../../opr/impl/indexing.sereg.h" #include "../../opr/impl/indexing.sereg.h"
#include "../../opr/impl/io.sereg.h" #include "../../opr/impl/io.sereg.h"
#include "../../opr/impl/io.sereg.v2.h"
#include "../../opr/impl/loop/forward.sereg.h" #include "../../opr/impl/loop/forward.sereg.h"
#include "../../opr/impl/loop/grad.sereg.h" #include "../../opr/impl/loop/grad.sereg.h"
#include "../../opr/impl/misc.sereg.h" #include "../../opr/impl/misc.sereg.h"
......
...@@ -53,7 +53,6 @@ struct OprRegistry { ...@@ -53,7 +53,6 @@ struct OprRegistry {
uint64_t unversioned_type_id; uint64_t unversioned_type_id;
MGE_WIN_DECLSPEC_FUC static void add(const OprRegistry& record); MGE_WIN_DECLSPEC_FUC static void add(const OprRegistry& record);
/*! /*!
* \brief register an operator to use dynamic loader * \brief register an operator to use dynamic loader
* *
...@@ -89,6 +88,39 @@ struct OprRegistry { ...@@ -89,6 +88,39 @@ struct OprRegistry {
#endif #endif
}; };
//! Convert some modified Opr to compatible Opr
using OprConvertToCompatible = thin_function<cg::OperatorNodeBase*(
cg::OperatorNodeBase*, const VarNodeArray&)>;
//! record of a single operator
struct OprRegistryV2 {
Typeinfo* type;
uint64_t type_id;
std::string name;
OprDumper dumper;
OprLoaderWrapper loader;
OprConvertToCompatible converter;
uint8_t version = 2;
MGE_WIN_DECLSPEC_FUC uint8_t get_version() const { return version; }
//! register opr load/dump to version2regmap
MGE_WIN_DECLSPEC_FUC static void versioned_add(
const OprRegistryV2& record, uint8_t min_version, uint8_t max_version);
MGE_WIN_DECLSPEC_FUC static const OprRegistryV2* versioned_find_by_id(
const size_t id, uint8_t version);
MGE_WIN_DECLSPEC_FUC static const OprRegistryV2* versioned_find_by_typeinfo(
Typeinfo* type, uint8_t version);
#if MGB_ENABLE_DEBUG_UTIL
//! dump registered oprs
MGE_WIN_DECLSPEC_FUC static std::vector<std::pair<size_t, std::string>>
dump_registries();
#endif
};
} // namespace serialization } // namespace serialization
} // namespace mgb } // namespace mgb
......
...@@ -3,6 +3,7 @@ ...@@ -3,6 +3,7 @@
#include "megbrain/serialization/opr_load_dump.h" #include "megbrain/serialization/opr_load_dump.h"
#include "megbrain/serialization/opr_registry.h" #include "megbrain/serialization/opr_registry.h"
#include "megbrain/serialization/opr_shallow_copy.h" #include "megbrain/serialization/opr_shallow_copy.h"
#include "megbrain/serialization/oss_opr_load_dump.h"
#include "megbrain/utils/hash_ct.h" #include "megbrain/utils/hash_ct.h"
namespace mgb { namespace mgb {
...@@ -66,6 +67,9 @@ struct OprLoadDumpImpl { ...@@ -66,6 +67,9 @@ struct OprLoadDumpImpl {
} }
}; };
template <class Opr, size_t arity>
struct OprLoadDumpImplV2 : public OprLoadDumpImpl<Opr, arity> {};
#define IMPL_OPR_MAKER(_arity, _args...) \ #define IMPL_OPR_MAKER(_arity, _args...) \
template <class Opr> \ template <class Opr> \
struct OprMaker<Opr, _arity> { \ struct OprMaker<Opr, _arity> { \
...@@ -124,6 +128,12 @@ struct OprRegistryCaller : public OprRegistryCallerDefaultImpl<Callee> {}; ...@@ -124,6 +128,12 @@ struct OprRegistryCaller : public OprRegistryCallerDefaultImpl<Callee> {};
__caller_OprReg##_cls##_ins; \ __caller_OprReg##_cls##_ins; \
} }
#define MGB_SEREG_OPR_INTL_CALL_ENTRY_V2(_cls, _impl) \
namespace { \
[[gnu::unused]] ::mgb::serialization::OprRegistryCaller<_cls, _impl> \
__caller_V2_OprReg##_cls##_ins; \
}
// Trim the terminating null character and a "V0" like suffix from the string // Trim the terminating null character and a "V0" like suffix from the string
// then hash it. // then hash it.
// TODO: Get rid of this. // TODO: Get rid of this.
...@@ -138,17 +148,35 @@ struct OprRegistryCaller : public OprRegistryCallerDefaultImpl<Callee> {}; ...@@ -138,17 +148,35 @@ struct OprRegistryCaller : public OprRegistryCallerDefaultImpl<Callee> {};
: 0), \ : 0), \
20160701)>::val 20160701)>::val
//! call OprRegistry::add //! call OprRegistry::add for old serialization
#define MGB_SEREG_OPR_INTL_CALL_ADD(_cls, _dump, _load) \ //! call OprRegistryV2::versioned_add for new serialization which is compatiable
do { \ //! with old serialization, convert is nullptr, this registry is just only for
::mgb::serialization::OprRegistry::add( \ //! varsion 1
{_cls::typeinfo(), \ #define MGB_SEREG_OPR_INTL_CALL_ADD(_cls, _dump, _load) \
MGB_HASH_STR(#_cls), \ do { \
_MGB_SEREG_OPR_NAME_FROM_CLS(_cls), \ ::mgb::serialization::OprRegistry::add( \
_dump, \ {_cls::typeinfo(), \
_load, \ MGB_HASH_STR(#_cls), \
{}, \ _MGB_SEREG_OPR_NAME_FROM_CLS(_cls), \
MGB_HASH_STR_WITHOUT_TAIL_0_AND_VERSION(#_cls)}); \ _dump, \
_load, \
{}, \
MGB_HASH_STR_WITHOUT_TAIL_0_AND_VERSION(#_cls)}); \
::mgb::serialization::OprRegistryV2::versioned_add( \
{_cls::typeinfo(), MGB_HASH_STR_WITHOUT_TAIL_0_AND_VERSION(#_cls), \
_MGB_SEREG_OPR_NAME_FROM_CLS(_cls), _dump, _load, nullptr}, \
::mgb::VERSION_1, ::mgb::VERSION_1); \
} while (0)
//! call OprRegistryV2::versioned_add for new serialization, in which convert the
//! function converter the Operator to the compatiable
#define MGB_SEREG_OPR_INTL_CALL_ADD_V2( \
_cls, _dump, _load, _convert, _version_min, _version_max) \
do { \
::mgb::serialization::OprRegistryV2::versioned_add( \
{_cls::typeinfo(), MGB_HASH_STR(#_cls), \
_MGB_SEREG_OPR_NAME_FROM_CLS(_cls), _dump, _load, _convert}, \
_version_min, _version_max); \
} while (0) } while (0)
/*! /*!
...@@ -171,6 +199,27 @@ struct OprRegistryCaller : public OprRegistryCallerDefaultImpl<Callee> {}; ...@@ -171,6 +199,27 @@ struct OprRegistryCaller : public OprRegistryCallerDefaultImpl<Callee> {};
} \ } \
MGB_SEREG_OPR_INTL_CALL_ENTRY(_cls, _OprReg##_cls) MGB_SEREG_OPR_INTL_CALL_ENTRY(_cls, _OprReg##_cls)
//! new dump/load function should implement in OprLoadDumpImplV2, _converter is
//! optional , if not implement pass nullptr
#define MGB_SEREG_OPR_V2(_cls, _arity, _converter, _version_min, _version_max) \
namespace { \
namespace ser = ::mgb::serialization; \
struct _OprRegV2##_cls { \
using Impl = ser::OprLoadDumpImplV2<_cls, _arity>; \
static ser::OprWithOutputAccessor wrap_loader( \
ser::OprLoadContext& ctx, const mgb::cg::VarNodeArray& inputs, \
const mgb::cg::OperatorNodeConfig& config) { \
return ser::OprWithOutputAccessor(Impl::load(ctx, inputs, config)); \
} \
static void entry() { \
MGB_SEREG_OPR_INTL_CALL_ADD_V2( \
_cls, Impl::dump, wrap_loader, _converter, _version_min, \
_version_max); \
} \
}; \
} \
MGB_SEREG_OPR_INTL_CALL_ENTRY_V2(_cls, _OprRegV2##_cls)
//! use to check type is complete or not, midout need a complete type //! use to check type is complete or not, midout need a complete type
template <class T, class = void> template <class T, class = void>
struct IsComplete : std::false_type {}; struct IsComplete : std::false_type {};
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册