提交 50faabf6 编写于 作者: M Megvii Engine Team

feat(serialization): support the registry for new serialization format

GitOrigin-RevId: 8eacd5e77c833d06a5aac2879d097182d4689fda
上级 a694fb33
......@@ -153,4 +153,6 @@ struct EnsureHashConstexpr {
#define MGB_HASH_STR(v) \
::mgb::EnsureHashConstexpr<::mgb::XXHash64CT::hash(v, sizeof(v), 20160701)>::val
#define MGB_HASH_RUNTIME(v) XXHash64CT::hash((v).c_str(), (v).size() + 1, 20160701)
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
......@@ -52,6 +52,21 @@ mgb::cg::OperatorNodeBase* custom_loader(
} \
MGB_SEREG_OPR_INTL_CALL_ENTRY(cls, _OprReg##cls)
#define CUSTOM_OP_SEREG_REG_V2(cls, _version_min, _version_max) \
namespace { \
struct _OprRegV2##cls { \
static void entry() { \
MGB_SEREG_OPR_INTL_CALL_ADD_V2( \
cls, ::mgb::serialization::custom_dumper, \
::mgb::serialization::custom_loader, nullptr, _version_min, \
_version_max); \
} \
}; \
} \
MGB_SEREG_OPR_INTL_CALL_ENTRY_V2(cls, _OprRegV2##cls)
using namespace mgb;
using CustomOpNode = opr::CustomOpNode;
CUSTOM_OP_SEREG_REG(CustomOpNode);
CUSTOM_OP_SEREG_REG_V2(CustomOpNode, 2, CURRENT_VERSION);
#include "megbrain/graph/symbol_var.h"
#include "megdnn/oprs/general.h"
#if MGB_ENABLE_FBS_SERIALIZATION
#include "megbrain/opr/basic_arith.h"
#include "megbrain/opr/dnn/softmax.h"
#include "megbrain/serialization/oss_opr_load_dump.h"
#include "megbrain/serialization/sereg.h"
#include "megdnn/opr_param_defs.h"
#include "megdnn/oprs/nn.h"
namespace mgb {
namespace serialization {
template <>
struct OprLoadDumpImplV2<opr::Softmax, 1> {
using Opr = opr::Softmax;
using PersisParam = opr::Softmax::Param;
static void dump(OprDumpContext& ctx, const cg::OperatorNodeBase& opr) {
ctx.write_param<PersisParam>(opr.cast_final_safe<Opr>().param());
}
static cg::OperatorNodeBase* replace_opr(
cg::OperatorNodeBase* opr, const VarNodeArray& inputs) {
int32_t axis = opr->cast_final_safe<Opr>().param().axis;
auto input_var = inputs[0];
auto max_reduce_out =
opr::Reduce::make(input_var, {megdnn::Reduce::Mode::MAX, axis});
auto elemwise_sub_out = opr::Elemwise::make(
{input_var, max_reduce_out}, {megdnn::Elemwise::Mode::SUB});
auto elemwise_exp_out =
opr::Elemwise::make({elemwise_sub_out}, {megdnn::Elemwise::Mode::EXP});
auto sum_reduce_out =
opr::Reduce::make(elemwise_exp_out, {megdnn::Reduce::Mode::SUM, axis});
auto out = opr::Elemwise::make(
{elemwise_exp_out, sum_reduce_out}, {megdnn::Elemwise::Mode::TRUE_DIV});
return out.node()->owner_opr();
}
static cg::OperatorNodeBase* load(
OprLoadContext& ctx, const cg::VarNodeArray& inputs,
const OperatorNodeConfig& config) {
auto& fbs_ctx = CAST_TO_FBS_V2_CTX(ctx);
auto param = fbs_ctx.read_param<PersisParam>(0);
return Opr::make(inputs[0], param, config).node()->owner_opr();
}
};
template <
class Opr, class Maker0, class MegDNNConv,
class Maker1 = MakeConvCallerEmpty<MegDNNConv>,
class Maker2 = MakeConvCallerEmpty<MegDNNConv>,
typename ConvParam = megdnn::param::Convolution>
struct WithPolicyOprLoadDumpImpl {
static void dump(OprDumpContext& ctx, const cg::OperatorNodeBase& opr_) {
auto&& opr = opr_.cast_final_safe<Opr>();
ctx.write_param<ConvParam>(opr.param());
ctx.write_param<megdnn::param::ExecutionPolicy>(
opr.execution_policy_transient());
}
static VarNode* make(
const cg::VarNodeArray& inputs, const ConvParam& param,
const megdnn::param::ExecutionPolicy& execution_policy,
const OperatorNodeConfig& config) {
VarNode* ret =
Maker0::template make<Opr>(inputs, param, execution_policy, config);
if (!ret) {
ret = Maker1::template make<Opr>(inputs, param, execution_policy, config);
}
if (!ret) {
ret = Maker2::template make<Opr>(inputs, param, execution_policy, config);
}
mgb_assert(ret);
return ret;
}
static cg::OperatorNodeBase* load(
OprLoadContext& ctx, const cg::VarNodeArray& inputs,
const OperatorNodeConfig& config) {
auto& fbs_ctx = CAST_TO_FBS_V2_CTX(ctx);
auto fopr = reinterpret_cast<const fbs::v2::Operator*>(
fbs_ctx.get_current_opr_data());
auto conv_param = fbs_ctx.read_param<ConvParam>(0);
megdnn::param::ExecutionPolicy policy;
if (fopr->additional_params() && fopr->additional_params()->size()) {
policy = fbs_ctx.read_param<megdnn::param::ExecutionPolicy>(1);
}
return make(inputs, conv_param, policy, config)->owner_opr();
}
};
template <>
struct OprLoadDumpImplV2<opr::Convolution, 0>
: public WithPolicyOprLoadDumpImpl<
opr::Convolution, MakeConvCaller2<megdnn::Convolution>,
megdnn::Convolution> {};
template <>
struct OprLoadDumpImplV2<opr::ConvolutionBackwardData, 0>
: public WithPolicyOprLoadDumpImpl<
opr::ConvolutionBackwardData, MakeConvCaller2<megdnn::Convolution>,
megdnn::Convolution, MakeConvCaller3<megdnn::Convolution>> {};
template <>
struct OprLoadDumpImplV2<opr::ConvolutionBackwardFilter, 0>
: public WithPolicyOprLoadDumpImpl<
opr::ConvolutionBackwardFilter, MakeConvCaller3<megdnn::Convolution>,
megdnn::Convolution> {};
template <>
struct OprLoadDumpImplV2<opr::Convolution3D, 0>
: public WithPolicyOprLoadDumpImpl<
opr::Convolution3D, MakeConvCaller2<megdnn::Convolution3D>,
megdnn::Convolution3D, MakeConvCallerEmpty<megdnn::Convolution3D>,
MakeConvCallerEmpty<megdnn::Convolution3D>,
megdnn::param::Convolution3D> {};
template <>
struct OprLoadDumpImplV2<opr::Convolution3DBackwardData, 0>
: public WithPolicyOprLoadDumpImpl<
opr::Convolution3DBackwardData,
MakeConvCaller2<megdnn::Convolution3D>, megdnn::Convolution3D,
MakeConvCaller3<megdnn::Convolution3D>,
MakeConvCallerEmpty<megdnn::Convolution3D>,
megdnn::param::Convolution3D> {};
template <>
struct OprLoadDumpImplV2<opr::Convolution3DBackwardFilter, 0>
: public WithPolicyOprLoadDumpImpl<
opr::Convolution3DBackwardFilter,
MakeConvCaller3<megdnn::Convolution3D>, megdnn::Convolution3D,
MakeConvCallerEmpty<megdnn::Convolution3D>,
MakeConvCallerEmpty<megdnn::Convolution3D>,
megdnn::param::Convolution3D> {};
template <>
struct OprLoadDumpImplV2<opr::ConvBiasForward, 0>
: public WithPolicyOprLoadDumpImpl<
opr::ConvBiasForward, MakeConvCaller2<megdnn::ConvBiasForward>,
megdnn::ConvBiasForward, MakeConvCaller3<megdnn::ConvBiasForward>,
MakeConvCaller4<megdnn::ConvBiasForward>, megdnn::param::ConvBias> {};
template <>
struct OprLoadDumpImplV2<opr::BatchConvBiasForward, 0>
: public WithPolicyOprLoadDumpImpl<
opr::BatchConvBiasForward,
MakeConvCaller2<megdnn::BatchConvBiasForward>,
megdnn::BatchConvBiasForward,
MakeConvCaller3<megdnn::BatchConvBiasForward>,
MakeConvCaller4<megdnn::BatchConvBiasForward>,
megdnn::param::BatchConvBias> {};
template <>
struct OprLoadDumpImplV2<opr::LocalShare, 0>
: public WithPolicyOprLoadDumpImpl<
opr::LocalShare, MakeLocalShareCaller2<megdnn::LocalShare>,
megdnn::LocalShare, MakeLocalShareCallerEmpty<megdnn::LocalShare>,
MakeLocalShareCallerEmpty<megdnn::LocalShare>,
megdnn::param::LocalShare> {};
template <>
struct OprLoadDumpImplV2<opr::LocalShareBackwardData, 0>
: public WithPolicyOprLoadDumpImpl<
opr::LocalShareBackwardData,
MakeLocalShareCaller3<megdnn::LocalShare>, megdnn::LocalShare,
MakeLocalShareCallerEmpty<megdnn::LocalShare>,
MakeLocalShareCallerEmpty<megdnn::LocalShare>,
megdnn::param::LocalShare> {};
template <>
struct OprLoadDumpImplV2<opr::LocalShareBackwardFilter, 0>
: public WithPolicyOprLoadDumpImpl<
opr::LocalShareBackwardFilter,
MakeLocalShareCaller3<megdnn::LocalShare>, megdnn::LocalShare,
MakeLocalShareCallerEmpty<megdnn::LocalShare>,
MakeLocalShareCallerEmpty<megdnn::LocalShare>,
megdnn::param::LocalShare> {};
template <>
struct OprLoadDumpImplV2<opr::DeformableConvForward, 0>
: public WithPolicyOprLoadDumpImpl<
opr::DeformableConvForward,
MakeConvCaller4<megdnn::DeformableConvForward>, megdnn::Convolution> {
};
template <>
struct OprLoadDumpImplV2<opr::DeformableConvBackwardData, 0>
: public WithPolicyOprLoadDumpImpl<
opr::DeformableConvBackwardData,
MakeConvCaller5<megdnn::DeformableConvBackwardData>,
megdnn::Convolution> {};
template <>
struct OprLoadDumpImplV2<opr::DeformableConvBackwardFilter, 0>
: public WithPolicyOprLoadDumpImpl<
opr::DeformableConvBackwardFilter,
MakeConvCaller5<megdnn::DeformableConvBackwardFilter>,
megdnn::Convolution> {};
} // namespace serialization
namespace opr {
#define SERGE_OPR_V2_CONVERTER(_cls, _arity, _converter) \
MGB_SEREG_OPR_V2(_cls, _arity, _converter, VERSION_2, CURRENT_VERSION);
#define SERGE_OPR_V2_NO_CONVERTER(_cls, _arity) \
MGB_SEREG_OPR_V2(_cls, _arity, nullptr, VERSION_2, CURRENT_VERSION);
SERGE_OPR_V2_CONVERTER(
Softmax, 1,
(mgb::serialization::OprLoadDumpImplV2<opr::Softmax, 1>::replace_opr));
SERGE_OPR_V2_NO_CONVERTER(ConvBiasForward, 0)
SERGE_OPR_V2_NO_CONVERTER(BatchConvBiasForward, 0);
SERGE_OPR_V2_NO_CONVERTER(Convolution, 0)
SERGE_OPR_V2_NO_CONVERTER(ConvolutionBackwardData, 0)
SERGE_OPR_V2_NO_CONVERTER(ConvolutionBackwardFilter, 0)
SERGE_OPR_V2_NO_CONVERTER(Convolution3D, 0);
SERGE_OPR_V2_NO_CONVERTER(Convolution3DBackwardData, 0);
SERGE_OPR_V2_NO_CONVERTER(Convolution3DBackwardFilter, 0);
SERGE_OPR_V2_NO_CONVERTER(LocalShareForward, 0);
SERGE_OPR_V2_NO_CONVERTER(LocalShareBackwardData, 0);
SERGE_OPR_V2_NO_CONVERTER(LocalShareBackwardFilter, 0);
SERGE_OPR_V2_NO_CONVERTER(DeformableConvForward, 0);
SERGE_OPR_V2_NO_CONVERTER(DeformableConvBackwardData, 0);
SERGE_OPR_V2_NO_CONVERTER(DeformableConvBackwardFilter, 0);
#undef SERGE_OPR_V2_CONVERTER
#undef SERGE_OPR_V2_NO_CONVERTER
} // namespace opr
} // namespace mgb
#endif
// vim: ft=cpp syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
#if MGB_ENABLE_FBS_SERIALIZATION
#include "megbrain/comp_node_env.h"
#include "megbrain/opr/dnn/softmax.h"
#include "megbrain/opr/io.h"
#include "megbrain/serialization/oss_opr_load_dump.h"
#include "megbrain/serialization/sereg.h"
#include "megbrain/serialization/internal/mgb_cpp_opr_generated.h"
#include "megbrain/serialization/internal/schema_v2_generated.h"
namespace mgb {
namespace serialization {
template <>
struct OprLoadDumpImplV2<opr::ImmutableTensor, 0> {
using Opr = opr::ImmutableTensor;
static void dump(OprDumpContext& ctx, const cg::OperatorNodeBase& opr_) {
using Meth = OprDumpContext::TensorWriteMethod;
auto&& opr = opr_.cast_final_safe<Opr>();
ctx.dump_tensor(
{}, HostTensorND{}.copy_from(opr.value()).sync(),
Meth::VALUE_ANONYMOUS);
}
static cg::OperatorNodeBase* load(
OprLoadContext& ctx, const cg::VarNodeArray& inputs,
const OperatorNodeConfig& config) {
auto& fbs_ctx = CAST_TO_FBS_V2_CTX(ctx);
mgb_assert(inputs.empty());
auto fopr = reinterpret_cast<const fbs::v2::Operator*>(
fbs_ctx.get_current_opr_data());
if (fopr->tensors() && fopr->tensors()->size() > 0) {
auto val = fbs_ctx.load_tensor();
return Opr::make(fbs_ctx.graph(), *val, config).node()->owner_opr();
} else {
mgb_throw(SerializationError, "ImmutableTensor load with no tensor data.");
}
}
};
template <>
struct OprLoadDumpImplV2<opr::Host2DeviceCopy, 0> {
using Opr = opr::Host2DeviceCopy;
static void dump(OprDumpContext& ctx, const cg::OperatorNodeBase& opr_) {
auto&& opr = opr_.cast_final_safe<Opr>();
ctx.write_param(opr.param());
using Meth = OprDumpContext::TensorWriteMethod;
ctx.dump_tensor(
opr.name(), *opr.host_data(),
opr.param().dump_default_value ? Meth::VALUE_INPUT : Meth::META_INPUT);
}
static cg::OperatorNodeBase* load(
OprLoadContext& ctx, const cg::VarNodeArray& inputs,
const OperatorNodeConfig& config) {
mgb_assert(inputs.empty());
auto& fbs_ctx = CAST_TO_FBS_V2_CTX(ctx);
auto param = fbs_ctx.read_param<Opr::Param>(0);
auto tensor = fbs_ctx.load_tensor();
return Opr::make(fbs_ctx.graph(), tensor, param, config).node()->owner_opr();
}
};
template <>
struct OprLoadDumpImplV2<opr::SharedDeviceTensorWithFormat, 0> {
static void dump(OprDumpContext& ctx, const cg::OperatorNodeBase& opr_) {
using Meth = OprDumpContext::TensorWriteMethod;
auto&& opr = opr_.cast_final_safe<opr::SharedDeviceTensorWithFormat>();
HostTensorND val;
val.copy_from(opr.get_dev_tensor()).sync();
ctx.dump_tensor({}, val, Meth::VALUE_ANONYMOUS);
}
static cg::OperatorNodeBase* load(
OprLoadContext& ctx, const cg::VarNodeArray& inputs,
const OperatorNodeConfig& config) {
mgb_assert(inputs.empty());
auto val = ctx.load_tensor();
auto dev_val =
std::make_shared<DeviceTensorND>(val->comp_node(), val->layout());
dev_val->copy_from_fixlayout(*val);
auto out_var =
opr::SharedDeviceTensorWithFormat::make(ctx.graph(), dev_val, config);
dev_val->sync();
return out_var.node()->owner_opr();
}
};
template <>
struct OprLoadDumpImplV2<opr::MultipleDeviceTensorHolder, 0> {
using Opr = opr::MultipleDeviceTensorHolder;
static void dump(OprDumpContext& ctx, const cg::OperatorNodeBase& opr_) {
using Meth = OprDumpContext::TensorWriteMethod;
auto&& opr = opr_.cast_final_safe<Opr>();
uint32_t nr_val = opr.values().size();
for (uint32_t i = 0; i < nr_val; ++i) {
HostTensorND val;
val.copy_from(*opr.values()[i]).sync();
ctx.dump_tensor(opr.output(i)->name(), val, Meth::VALUE_SHARED);
}
}
static cg::OperatorNodeBase* load(
OprLoadContext& ctx, const cg::VarNodeArray& inputs,
const OperatorNodeConfig& config) {
mgb_assert(inputs.empty());
auto& fbs_ctx = CAST_TO_FBS_V2_CTX(ctx);
auto fopr = reinterpret_cast<const fbs::v2::Operator*>(
fbs_ctx.get_current_opr_data());
uint32_t nr = 0;
if (fopr && fopr->tensors()) {
nr = fopr->tensors()->size();
}
Opr::ValueArray values(nr);
for (auto&& i : values) {
i = ctx.load_tensor_shared();
}
return Opr::make(ctx.graph(), std::move(values), config)[0].node()->owner_opr();
}
};
template <>
struct OprLoadDumpImplV2<opr::MultipleDeviceTensorWithFormatHolder, 0> {
using Opr = opr::MultipleDeviceTensorWithFormatHolder;
static void dump(OprDumpContext& ctx, const cg::OperatorNodeBase& opr_) {
using Meth = OprDumpContext::TensorWriteMethod;
auto&& opr = opr_.cast_final_safe<Opr>();
uint32_t nr_val = opr.values().size();
for (uint32_t i = 0; i < nr_val; ++i) {
HostTensorND val;
auto value = *opr.values()[i];
val.copy_from(value).sync();
ctx.dump_tensor(opr.output(i)->name(), val, Meth::VALUE_SHARED);
}
}
static cg::OperatorNodeBase* load(
OprLoadContext& ctx, const cg::VarNodeArray& inputs,
const OperatorNodeConfig& config) {
mgb_assert(inputs.empty());
auto& fbs_ctx = CAST_TO_FBS_V2_CTX(ctx);
auto fopr = reinterpret_cast<const fbs::v2::Operator*>(
fbs_ctx.get_current_opr_data());
uint32_t nr = 0;
if (fopr && fopr->tensors()) {
nr = fopr->tensors()->size();
}
Opr::ValueArray values(nr);
for (auto&& i : values) {
i = ctx.load_tensor_shared();
//! set tensor format
TensorLayout layout_with_format = i->layout();
if (i->storage().comp_node().mem_node() ==
CompNode::default_cpu().mem_node()) {
mgb_assert(
i->storage().ptr(),
"storage should not be nullptr if mem_node is "
"default_cpu");
HostTensorND src{i->storage().comp_node(), layout_with_format};
src.copy_from_fixlayout(*i).sync();
*i = DeviceTensorND::make_proxy(src);
} else {
//! actually only layout of this tensor will be used later, see
//! src/serialization/impl/batched_device_value_loader.cpp:49. But we
//! have no way to reset layout only, so just construct a invalid
//! storage instead
auto size = layout_with_format.span().dist_byte();
DeviceTensorStorage storage;
storage.reset(i->comp_node(), size, nullptr);
i->reset(storage, layout_with_format);
}
}
return Opr::make(ctx.graph(), std::move(values), config)[0].node()->owner_opr();
}
};
} // namespace serialization
namespace opr {
#define SERGE_OPR_V2_NO_CONVERTER(_cls, _arity) \
MGB_SEREG_OPR_V2(_cls, _arity, nullptr, VERSION_2, CURRENT_VERSION);
SERGE_OPR_V2_NO_CONVERTER(ImmutableTensor, 0);
SERGE_OPR_V2_NO_CONVERTER(Host2DeviceCopy, 0);
SERGE_OPR_V2_NO_CONVERTER(SharedDeviceTensorWithFormat, 0);
SERGE_OPR_V2_NO_CONVERTER(MultipleDeviceTensorWithFormatHolder, 0);
SERGE_OPR_V2_NO_CONVERTER(MultipleDeviceTensorHolder, 0);
} // namespace opr
} // namespace mgb
#endif
// vim: ft=cpp syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
......@@ -135,6 +135,16 @@ void LoopSerializer::reg_all() {
MGB_SEREG_OPR_INTL_CALL_ADD(InputMaker, dump_input_maker, load_input_maker);
MGB_SEREG_OPR_INTL_CALL_ADD(
CounterProvider, dump_counter_provider, load_counter_provider);
MGB_SEREG_OPR_INTL_CALL_ADD_V2(
opr::Loop, dump_loop, load_loop, nullptr, 2,
CURRENT_VERSION);
MGB_SEREG_OPR_INTL_CALL_ADD_V2(
InputMaker, dump_input_maker, load_input_maker, nullptr, 2,
CURRENT_VERSION);
MGB_SEREG_OPR_INTL_CALL_ADD_V2(
CounterProvider, dump_counter_provider, load_counter_provider, nullptr, 2,
CURRENT_VERSION);
}
void LoopSerializer::dump_loop(OprDumpContext& ctx, const cg::OperatorNodeBase& opr) {
......
......@@ -20,6 +20,11 @@ struct StaticData {
ThinHashMap<Typeinfo*, OprRegistry*> type2reg;
std::unordered_map<std::string, OprRegistry*> name2reg;
ThinHashMap<size_t, OprRegistry*> unversioned_id2reg;
//! versioned OprRegistryV2, version_id_reg_map is used for Operator
//! load/shallow copy and version_type_reg_map is used for Operator dump
ThinHashMap<uint8_t, ThinHashMap<size_t, OprRegistryV2>> version_id_reg_map;
ThinHashMap<uint8_t, ThinHashMap<Typeinfo*, OprRegistryV2*>> version_type_reg_map;
};
StaticData& static_data() {
......@@ -47,6 +52,20 @@ const OprRegistry* dynamic_registry() {
return ret;
}
const OprRegistryV2* dynamic_registry_v2() {
static const OprRegistryV2* ret = nullptr;
if (ret)
return ret;
auto id = MGB_HASH_STR("dynamic");
OprRegistryV2::versioned_add(
{nullptr, id, {}, {}, dynamic_loader, {}}, CURRENT_VERSION,
CURRENT_VERSION);
ret = OprRegistryV2::versioned_find_by_id(id, CURRENT_VERSION);
mgb_assert(ret);
return ret;
}
class _Init {
public:
_Init() {
......@@ -63,8 +82,7 @@ void OprRegistry::add(const OprRegistry& record) {
auto registry_ins = sd.id2reg.emplace(persist_id, record);
mgb_assert(
registry_ins.second || persist_id == dynamic_registry()->persist_type_id,
"duplicated operator persist_type_id: %s",
std::to_string(persist_id).c_str());
"duplicated operator name : %s", record.name.c_str());
OprRegistry* persis_record_ptr;
if (registry_ins.second) {
......@@ -129,6 +147,73 @@ const OprRegistry* OprRegistry::find_by_unversioned_id(size_t unversioned_id) {
return iter == uid2reg.end() ? nullptr : iter->second;
}
//! find the registry equal to the giving version
const OprRegistryV2* OprRegistryV2::versioned_find_by_id(
const size_t id, uint8_t version) {
auto&& id_reg_map = static_data().version_id_reg_map;
auto iter_version = id_reg_map.find(version);
if (iter_version != id_reg_map.end()) {
auto iter = iter_version->second.find(id);
return iter == iter_version->second.end() ? nullptr : &iter->second;
}
return nullptr;
}
//! find the registry equal or below the giving version
const OprRegistryV2* OprRegistryV2::versioned_find_by_typeinfo(
Typeinfo* type, uint8_t version) {
const auto& type_reg_map = static_data().version_type_reg_map;
for (int version_id = version; version_id > 0; version_id--) {
auto iter_version = type_reg_map.find(version_id);
if (iter_version != type_reg_map.end()) {
auto iter = iter_version->second.find(type);
if (iter == iter_version->second.end()) {
continue;
} else {
return iter->second;
}
}
}
return nullptr;
}
void OprRegistryV2::versioned_add(
const OprRegistryV2& record, uint8_t min_version, uint8_t max_version) {
mgb_assert(max_version >= min_version);
auto&& sd = static_data();
auto id = record.type_id;
uint64_t type_id = id;
//! record.type->name is nullptr when MGB_VERBOSE_TYPEINFO_NAME==0
if (record.type && record.type->name) {
type_id = MGB_HASH_RUNTIME(std::string(record.type->name));
}
for (uint8_t version = min_version; version <= max_version; version++) {
auto&& registry_map = sd.version_id_reg_map[version];
auto versioned_record = record;
versioned_record.version = version;
mgb_assert(
registry_map.find(id) == registry_map.end() ||
id == dynamic_registry_v2()->type_id,
"dduplicated OprRegistryV2 of %s\n", record.name.c_str());
auto registry_ins = registry_map.emplace(id, versioned_record);
if (!registry_ins.second) {
//! the registry is dynamic
mgb_assert(!record.converter);
registry_map[id] = versioned_record;
}
//! sometimes the register id and the hash typeinfo is not same, just as
//! dynamic Operator
if (id != type_id) {
mgb_assert(
registry_map.find(type_id) == registry_map.end(),
"dduplicated OprRegistryV2 of %s\n", record.name.c_str());
registry_map.emplace(type_id, versioned_record);
}
auto&& registry_type_map = sd.version_type_reg_map[version];
registry_type_map.emplace(record.type, &registry_map[id]);
}
}
void OprRegistry::add_using_dynamic_loader(
Typeinfo* type, const std::string& name, const OprDumper& dumper) {
// dynamic oprs are implemented by mapping different opr types to the same
......@@ -140,6 +225,11 @@ void OprRegistry::add_using_dynamic_loader(
{},
{},
dynamic_registry()->unversioned_type_id});
mgb_assert(type, "type must be not nullptr");
OprRegistryV2::versioned_add(
{type, dynamic_registry_v2()->type_id, type->name, dumper,
dynamic_registry_v2()->loader, nullptr},
CURRENT_VERSION, CURRENT_VERSION);
}
#if MGB_ENABLE_DEBUG_UTIL
......
......@@ -9,10 +9,12 @@ void call_sereg() {}
#include "../../opr/impl/blas.sereg.h"
#include "../../opr/impl/cond.sereg.h"
#include "../../opr/impl/dnn/dnn.sereg.h"
#include "../../opr/impl/dnn/dnn.sereg.v2.h"
#include "./extern_c_opr.sereg.h"
#include "../../opr/impl/imgproc.sereg.h"
#include "../../opr/impl/indexing.sereg.h"
#include "../../opr/impl/io.sereg.h"
#include "../../opr/impl/io.sereg.v2.h"
#include "../../opr/impl/loop/forward.sereg.h"
#include "../../opr/impl/loop/grad.sereg.h"
#include "../../opr/impl/misc.sereg.h"
......
......@@ -53,7 +53,6 @@ struct OprRegistry {
uint64_t unversioned_type_id;
MGE_WIN_DECLSPEC_FUC static void add(const OprRegistry& record);
/*!
* \brief register an operator to use dynamic loader
*
......@@ -89,6 +88,39 @@ struct OprRegistry {
#endif
};
//! Convert some modified Opr to compatible Opr
using OprConvertToCompatible = thin_function<cg::OperatorNodeBase*(
cg::OperatorNodeBase*, const VarNodeArray&)>;
//! record of a single operator
struct OprRegistryV2 {
Typeinfo* type;
uint64_t type_id;
std::string name;
OprDumper dumper;
OprLoaderWrapper loader;
OprConvertToCompatible converter;
uint8_t version = 2;
MGE_WIN_DECLSPEC_FUC uint8_t get_version() const { return version; }
//! register opr load/dump to version2regmap
MGE_WIN_DECLSPEC_FUC static void versioned_add(
const OprRegistryV2& record, uint8_t min_version, uint8_t max_version);
MGE_WIN_DECLSPEC_FUC static const OprRegistryV2* versioned_find_by_id(
const size_t id, uint8_t version);
MGE_WIN_DECLSPEC_FUC static const OprRegistryV2* versioned_find_by_typeinfo(
Typeinfo* type, uint8_t version);
#if MGB_ENABLE_DEBUG_UTIL
//! dump registered oprs
MGE_WIN_DECLSPEC_FUC static std::vector<std::pair<size_t, std::string>>
dump_registries();
#endif
};
} // namespace serialization
} // namespace mgb
......
......@@ -3,6 +3,7 @@
#include "megbrain/serialization/opr_load_dump.h"
#include "megbrain/serialization/opr_registry.h"
#include "megbrain/serialization/opr_shallow_copy.h"
#include "megbrain/serialization/oss_opr_load_dump.h"
#include "megbrain/utils/hash_ct.h"
namespace mgb {
......@@ -66,6 +67,9 @@ struct OprLoadDumpImpl {
}
};
template <class Opr, size_t arity>
struct OprLoadDumpImplV2 : public OprLoadDumpImpl<Opr, arity> {};
#define IMPL_OPR_MAKER(_arity, _args...) \
template <class Opr> \
struct OprMaker<Opr, _arity> { \
......@@ -124,6 +128,12 @@ struct OprRegistryCaller : public OprRegistryCallerDefaultImpl<Callee> {};
__caller_OprReg##_cls##_ins; \
}
#define MGB_SEREG_OPR_INTL_CALL_ENTRY_V2(_cls, _impl) \
namespace { \
[[gnu::unused]] ::mgb::serialization::OprRegistryCaller<_cls, _impl> \
__caller_V2_OprReg##_cls##_ins; \
}
// Trim the terminating null character and a "V0" like suffix from the string
// then hash it.
// TODO: Get rid of this.
......@@ -138,7 +148,10 @@ struct OprRegistryCaller : public OprRegistryCallerDefaultImpl<Callee> {};
: 0), \
20160701)>::val
//! call OprRegistry::add
//! call OprRegistry::add for old serialization
//! call OprRegistryV2::versioned_add for new serialization which is compatiable
//! with old serialization, convert is nullptr, this registry is just only for
//! varsion 1
#define MGB_SEREG_OPR_INTL_CALL_ADD(_cls, _dump, _load) \
do { \
::mgb::serialization::OprRegistry::add( \
......@@ -149,6 +162,21 @@ struct OprRegistryCaller : public OprRegistryCallerDefaultImpl<Callee> {};
_load, \
{}, \
MGB_HASH_STR_WITHOUT_TAIL_0_AND_VERSION(#_cls)}); \
::mgb::serialization::OprRegistryV2::versioned_add( \
{_cls::typeinfo(), MGB_HASH_STR_WITHOUT_TAIL_0_AND_VERSION(#_cls), \
_MGB_SEREG_OPR_NAME_FROM_CLS(_cls), _dump, _load, nullptr}, \
::mgb::VERSION_1, ::mgb::VERSION_1); \
} while (0)
//! call OprRegistryV2::versioned_add for new serialization, in which convert the
//! function converter the Operator to the compatiable
#define MGB_SEREG_OPR_INTL_CALL_ADD_V2( \
_cls, _dump, _load, _convert, _version_min, _version_max) \
do { \
::mgb::serialization::OprRegistryV2::versioned_add( \
{_cls::typeinfo(), MGB_HASH_STR(#_cls), \
_MGB_SEREG_OPR_NAME_FROM_CLS(_cls), _dump, _load, _convert}, \
_version_min, _version_max); \
} while (0)
/*!
......@@ -171,6 +199,27 @@ struct OprRegistryCaller : public OprRegistryCallerDefaultImpl<Callee> {};
} \
MGB_SEREG_OPR_INTL_CALL_ENTRY(_cls, _OprReg##_cls)
//! new dump/load function should implement in OprLoadDumpImplV2, _converter is
//! optional , if not implement pass nullptr
#define MGB_SEREG_OPR_V2(_cls, _arity, _converter, _version_min, _version_max) \
namespace { \
namespace ser = ::mgb::serialization; \
struct _OprRegV2##_cls { \
using Impl = ser::OprLoadDumpImplV2<_cls, _arity>; \
static ser::OprWithOutputAccessor wrap_loader( \
ser::OprLoadContext& ctx, const mgb::cg::VarNodeArray& inputs, \
const mgb::cg::OperatorNodeConfig& config) { \
return ser::OprWithOutputAccessor(Impl::load(ctx, inputs, config)); \
} \
static void entry() { \
MGB_SEREG_OPR_INTL_CALL_ADD_V2( \
_cls, Impl::dump, wrap_loader, _converter, _version_min, \
_version_max); \
} \
}; \
} \
MGB_SEREG_OPR_INTL_CALL_ENTRY_V2(_cls, _OprRegV2##_cls)
//! use to check type is complete or not, midout need a complete type
template <class T, class = void>
struct IsComplete : std::false_type {};
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册