提交 a694fb33 编写于 作者: M Megvii Engine Team

feat(serialization): implement the new serialization format

GitOrigin-RevId: 00f87f7ccdae7d313d8a97108c3ae712b1c27da3
上级 ca4a5da0
#include "batched_device_value_loader.h" #include "megbrain/serialization/batched_device_value_loader.h"
#include "megbrain/utils/arith_helper.h" #include "megbrain/utils/arith_helper.h"
namespace mgb { namespace mgb {
......
...@@ -57,7 +57,11 @@ GraphLoader::SharedTensorNameMap GraphLoader::shared_tensor_name_map() { ...@@ -57,7 +57,11 @@ GraphLoader::SharedTensorNameMap GraphLoader::shared_tensor_name_map() {
} }
std::unique_ptr<GraphLoader> make_fbs_loader(std::unique_ptr<InputFile> file); std::unique_ptr<GraphLoader> make_fbs_loader(std::unique_ptr<InputFile> file);
std::unique_ptr<GraphDumper> make_fbs_dumper(std::unique_ptr<OutputFile> file); std::unique_ptr<GraphDumper> make_fbs_dumper(std::unique_ptr<OutputFile> file);
std::unique_ptr<GraphLoader> make_fbs_v2_loader(std::unique_ptr<InputFile> file);
std::unique_ptr<GraphDumper> make_fbs_v2_dumper(std::unique_ptr<OutputFile> file);
bool is_fbs_file(InputFile& file); bool is_fbs_file(InputFile& file);
bool is_fbs_v2_file(InputFile& file);
bool GraphDumper::should_remove_in_dump(cg::OperatorNodeBase* opr) { bool GraphDumper::should_remove_in_dump(cg::OperatorNodeBase* opr) {
#if MGB_ENABLE_GRAD #if MGB_ENABLE_GRAD
...@@ -73,6 +77,11 @@ std::unique_ptr<GraphDumper> GraphDumper::make( ...@@ -73,6 +77,11 @@ std::unique_ptr<GraphDumper> GraphDumper::make(
case GraphDumpFormat::FLATBUFFERS: case GraphDumpFormat::FLATBUFFERS:
#if MGB_ENABLE_FBS_SERIALIZATION #if MGB_ENABLE_FBS_SERIALIZATION
return make_fbs_dumper(std::move(file)); return make_fbs_dumper(std::move(file));
#endif
MGB_FALLTHRU
case GraphDumpFormat::FLATBUFFERS_V2:
#if MGB_ENABLE_FBS_SERIALIZATION
return make_fbs_v2_dumper(std::move(file));
#endif #endif
MGB_FALLTHRU MGB_FALLTHRU
default: default:
...@@ -87,6 +96,11 @@ std::unique_ptr<GraphLoader> GraphLoader::make( ...@@ -87,6 +96,11 @@ std::unique_ptr<GraphLoader> GraphLoader::make(
case GraphDumpFormat::FLATBUFFERS: case GraphDumpFormat::FLATBUFFERS:
#if MGB_ENABLE_FBS_SERIALIZATION #if MGB_ENABLE_FBS_SERIALIZATION
return make_fbs_loader(std::move(file)); return make_fbs_loader(std::move(file));
#endif
MGB_FALLTHRU
case GraphDumpFormat::FLATBUFFERS_V2:
#if MGB_ENABLE_FBS_SERIALIZATION
return make_fbs_v2_loader(std::move(file));
#endif #endif
MGB_FALLTHRU MGB_FALLTHRU
default: default:
...@@ -100,6 +114,9 @@ Maybe<GraphDumpFormat> GraphLoader::identify_graph_dump_format(InputFile& file) ...@@ -100,6 +114,9 @@ Maybe<GraphDumpFormat> GraphLoader::identify_graph_dump_format(InputFile& file)
if (is_fbs_file(file)) { if (is_fbs_file(file)) {
return GraphDumpFormat::FLATBUFFERS; return GraphDumpFormat::FLATBUFFERS;
} }
if (is_fbs_v2_file(file)) {
return GraphDumpFormat::FLATBUFFERS_V2;
}
#endif #endif
return {}; return {};
} }
......
...@@ -11,17 +11,16 @@ ...@@ -11,17 +11,16 @@
*/ */
#if MGB_ENABLE_FBS_SERIALIZATION #if MGB_ENABLE_FBS_SERIALIZATION
#include "batched_device_value_loader.h"
#include "megbrain/graph/exc_extra_info.h" #include "megbrain/graph/exc_extra_info.h"
#include "megbrain/opr/io.h" #include "megbrain/opr/io.h"
#include "megbrain/serialization/batched_device_value_loader.h"
#include "megbrain/serialization/helper.h" #include "megbrain/serialization/helper.h"
#include "megbrain/serialization/internal/flatbuffers_helper.h" #include "megbrain/serialization/internal/flatbuffers_helper.h"
#include "megbrain/serialization/internal/schema_generated.h" #include "megbrain/serialization/internal/schema_generated.h"
#include "megbrain/serialization/metadata.h" #include "megbrain/serialization/metadata.h"
#include "megbrain/serialization/opr_load_dump.h" #include "megbrain/serialization/opr_load_dump.h"
#include "megbrain/serialization/serializer.h" #include "megbrain/serialization/serializer.h"
#include "megbrain/version.h" #include "serializer_oss_common.h"
#include <flatbuffers/flatbuffers.h> #include <flatbuffers/flatbuffers.h>
...@@ -33,47 +32,8 @@ using namespace mgb; ...@@ -33,47 +32,8 @@ using namespace mgb;
using namespace mgb::serialization; using namespace mgb::serialization;
namespace { namespace {
constexpr uint32_t MGB_VERSION = (MGE_MAJOR * 1000 + MGE_MINOR) * 100 + MGE_PATCH;
constexpr uint32_t MGB_MAGIC = 0x4342474D;
// In order to maintain compatibility and to allow old models to be loaded, we keep
// the old magic(MAGIC_V0) value and creat a new magic(MGB_MAGIC)
constexpr uint32_t MAGIC_V0 = 0x5342474D;
// Used to judge whether Magic is old or new, the new magic(MGB_MAGIC) is true and the
// old magic(MAGIC_V0) is false.
bool magic_compare = true; bool magic_compare = true;
template <typename T>
bool contains_any_in_set(const SmallVector<T>& list, const ThinHashSet<T>& set) {
for (const auto& x : list) {
if (set.count(x)) {
return true;
}
}
return false;
}
void check_tensor_value_valid(const std::string& name, const HostTensorND& tensor) {
bool cond_normal = tensor.layout().format.is_default() &&
tensor.layout().is_physical_contiguous();
bool cond_lowbit = tensor.layout().dtype.is_quantized_lowbit() &&
tensor.layout().format.is_lowbit_aligned() &&
tensor.layout().is_contiguous();
mgb_assert(
cond_normal || cond_lowbit, "non-contiguous tensor: name=%s layout=%s",
name.c_str(), tensor.layout().to_string().c_str());
if (tensor.dtype() == dtype::Float32()) {
auto ptr = tensor.ptr<float>();
for (size_t i = 0, it = tensor.shape().total_nr_elems(); i < it; ++i) {
if (!std::isfinite(ptr[i])) {
mgb_log_warn("invalid tensor value in %s: %g", name.c_str(), ptr[i]);
break;
}
}
}
}
//! feature bits for backward compatibility; default value should be 0 //! feature bits for backward compatibility; default value should be 0
struct FeatureBits64 { struct FeatureBits64 {
//! reserved for new fields //! reserved for new fields
...@@ -947,13 +907,6 @@ std::unique_ptr<GraphLoader> make_fbs_loader(std::unique_ptr<InputFile> file) { ...@@ -947,13 +907,6 @@ std::unique_ptr<GraphLoader> make_fbs_loader(std::unique_ptr<InputFile> file) {
return std::make_unique<GraphLoaderOSS>(std::move(file)); return std::make_unique<GraphLoaderOSS>(std::move(file));
} }
bool is_fbs_file(InputFile& file) {
uint64_t magic_with_reserved = 0;
file.read(&magic_with_reserved, sizeof(magic_with_reserved));
file.skip(-sizeof(magic_with_reserved));
return (magic_with_reserved == MGB_MAGIC) || (magic_with_reserved == MAGIC_V0);
}
} // namespace serialization } // namespace serialization
} // namespace mgb } // namespace mgb
......
#if MGB_ENABLE_FBS_SERIALIZATION
#include "serializer_oss_common.h"
namespace mgb {
namespace serialization {
bool is_fbs_file(InputFile& file) {
//! check whether the model format is flatbuffer v2
uint64_t magic_with_reserved = 0;
file.read(&magic_with_reserved, sizeof(magic_with_reserved));
file.skip(-sizeof(magic_with_reserved));
return (magic_with_reserved == MGB_MAGIC) || (magic_with_reserved == MAGIC_V0);
}
void check_tensor_value_valid(const std::string& name, const HostTensorND& tensor) {
bool cond_normal = tensor.layout().format.is_default() &&
tensor.layout().is_physical_contiguous();
bool cond_lowbit = tensor.layout().dtype.is_quantized_lowbit() &&
tensor.layout().format.is_lowbit_aligned() &&
tensor.layout().is_contiguous();
mgb_assert(
cond_normal || cond_lowbit, "non-contiguous tensor: name=%s layout=%s",
name.c_str(), tensor.layout().to_string().c_str());
if (tensor.dtype() == dtype::Float32()) {
auto ptr = tensor.ptr<float>();
for (size_t i = 0, it = tensor.shape().total_nr_elems(); i < it; ++i) {
if (!std::isfinite(ptr[i])) {
mgb_log_warn("invalid tensor value in %s: %g", name.c_str(), ptr[i]);
break;
}
}
}
}
} // namespace serialization
} // namespace mgb
#endif
#if MGB_ENABLE_FBS_SERIALIZATION
#include "megbrain/serialization/serializer.h"
#include "megbrain/version.h"
namespace mgb {
namespace serialization {
constexpr uint32_t MGB_VERSION = (MGE_MAJOR * 1000 + MGE_MINOR) * 100 + MGE_PATCH;
constexpr uint32_t MGB_MAGIC = 0x4342474D;
// In order to maintain compatibility and to allow old models to be loaded, we keep
// the old magic(MAGIC_V0) value and creat a new magic(MGB_MAGIC)
constexpr uint32_t MAGIC_V0 = 0x5342474D;
void check_tensor_value_valid(const std::string& name, const HostTensorND& tensor);
template <typename T>
bool contains_any_in_set(const SmallVector<T>& list, const ThinHashSet<T>& set) {
for (const auto& x : list) {
if (set.count(x)) {
return true;
}
}
return false;
}
} // namespace serialization
} // namespace mgb
#endif
#if MGB_ENABLE_FBS_SERIALIZATION
#include "megbrain/comp_node_env.h"
#include "megbrain/opr/io.h"
#include "megbrain/serialization/helper.h"
#include "megbrain/serialization/internal/flatbuffers_helper.h"
#include "megbrain/serialization/internal/schema_v2_generated.h"
#include "megbrain/serialization/metadata.h"
#include "megbrain/serialization/opr_load_dump.h"
#include "megbrain/serialization/oss_opr_load_dump.h"
#include "megbrain/utils/hash_ct.h"
#include "megdnn/tensor_format.h"
#include "serializer_oss_common.h"
#include "megbrain/gopt/framework.h"
namespace mgb {
namespace serialization {
/*!
* \brief replace the the opr who has the replace_opr methord in OprLoadDumpImplV2
*/
class PassConvertToCompatible : public gopt::Pass {
ThinHashMap<
Typeinfo*, thin_function<cg::OperatorNodeBase*(
cg::OperatorNodeBase*, const VarNodeArray&)>>
m_opr_replace_func;
gopt::VarReplaceCheckFlag m_var_replace_check_flag =
gopt::VarReplaceCheckFlag::CHECK_ALL;
public:
const char* name() const override { return "PassConvertToCompatible"; };
PassConvertToCompatible& set_var_replace_check_flag(
gopt::VarReplaceCheckFlag flag) {
m_var_replace_check_flag = flag;
return *this;
}
void apply(gopt::OptState& state) const override {
state.set_var_replace_check_flag(m_var_replace_check_flag);
auto rewriter = state.graph().make_rewriter();
auto on_opr = [this, &rewriter](cg::OperatorNodeBase* opr) {
auto it = m_opr_replace_func.find(opr->dyn_typeinfo());
if (it != m_opr_replace_func.end()) {
VarNodeArray new_inp;
new_inp.clear();
new_inp.reserve(opr->input().size());
for (auto i : opr->input()) {
new_inp.push_back(rewriter.get_var(i));
}
auto new_opr = (it->second)(opr, new_inp);
auto &&origin_out = opr->output(), &&cur_out = new_opr->output();
for (size_t i = 0; i < std::min(origin_out.size(), cur_out.size());
i++) {
rewriter.replace_var(origin_out[i], cur_out[i], nullptr);
}
} else {
rewriter.auto_replace_outputs(opr);
}
};
state.graph().iter(on_opr);
rewriter.apply_inplace();
}
static std::unique_ptr<PassConvertToCompatible> make(
const SymbolVarArray& output_vars) {
auto ret = std::make_unique<PassConvertToCompatible>();
// iterate oprs to init
auto on_opr = [&](cg::OperatorNodeBase* opr) {
if (!GraphDumper::should_remove_in_dump(opr)) {
auto registry = OprRegistryV2::versioned_find_by_typeinfo(
opr->dyn_typeinfo(), CURRENT_VERSION);
mgb_throw_if(
!registry,
cg::OperatorNodeExcExtraInfo::ExcMaker{opr}.make<MegBrainError>,
"serialization as FlatBuffers is not supported for "
"operator %s, typeinfo %p",
opr->dyn_typeinfo()->name, opr->dyn_typeinfo());
if (registry->converter) {
ret->m_opr_replace_func[opr->dyn_typeinfo()] = registry->converter;
}
}
};
cg::DepOprIter dep_opr_iter{on_opr};
for (auto i : output_vars) {
dep_opr_iter.add(i.node()->owner_opr());
}
return ret;
};
};
namespace {
fbs::v2::TensorFormat get_flatbuffer_tensor_format_type(
const TensorLayout::Format& format) {
using Type = megdnn::TensorFormat::Type;
switch (format.type()) {
case Type::DEFAULT:
return fbs::v2::TensorFormat::TensorFormat_DefaultTensorFormat;
case Type::IMAGE2D_PACK4:
return fbs::v2::TensorFormat::TensorFormat_Image2DPackedTensorFormat;
case Type::LOWBITS_ALIGNED_TO_BYTE:
return fbs::v2::TensorFormat::TensorFormat_LowbitsAlignedTensorFormat;
default:
mgb_throw(
SerializationError, "invalid tensor format type in serialization.");
}
}
} // namespace
flatbuffers::Offset<fbs::DType> GraphDumperOSSV2::build_dtype(DType dtype) {
return fbs::intl::build_dtype(m_builder, dtype);
}
flatbuffers::Offset<void> GraphDumperOSSV2::build_tensor_format(
const TensorLayout::Format& format) {
using Type = megdnn::TensorFormat::Type;
switch (format.type()) {
case Type::DEFAULT:
return fbs::v2::CreateDefaultTensorFormat(m_builder).Union();
case Type::IMAGE2D_PACK4:
return fbs::v2::CreateImage2DPackedTensorFormat(
m_builder, format.as_impl<megdnn::Image2DPack4TensorFormat>()
.align_axis())
.Union();
case Type::LOWBITS_ALIGNED_TO_BYTE: {
auto size_bite = format.as_impl<megdnn::LowbitsAlignedToBytesTensorFormat>()
.size_nbits();
auto align_size_in_bits =
format.as_impl<megdnn::LowbitsAlignedToBytesTensorFormat>()
.align_size_in_bits();
return fbs::v2::CreateLowbitsAlignedTensorFormat(
m_builder, size_bite, align_size_in_bits)
.Union();
}
default:
mgb_throw(
SerializationError, "invalid tensor format type in serialization.");
}
}
flatbuffers::Offset<fbs::v2::MiddleTensor> GraphDumperOSSV2::build_middle_tensor(
const SymbolVar var) {
mgb_assert(var.node());
auto fbname = m_builder.CreateSharedString(var.node()->name());
flatbuffers::Offset<fbs::v2::MiddleTensor> serialized_middle_tensor;
if (var.node()->dev_tensor_valid()) {
auto layout = var.node()->layout();
auto fshape =
m_builder.CreateVectorScalarCast<uint32_t>(layout.shape, layout.ndim);
auto fcomp_node = fbs::v2::CreateCompNode(
m_builder, m_builder.CreateSharedString(
var.node()->comp_node().to_string_logical()));
auto fdtype = build_dtype(layout.dtype);
auto fformat_type = get_flatbuffer_tensor_format_type(layout.format);
auto fformat = build_tensor_format(layout.format);
serialized_middle_tensor = fbs::v2::CreateMiddleTensor(
m_builder, fbname, fshape, fcomp_node, fdtype, fformat_type, fformat);
}
serialized_middle_tensor = fbs::v2::CreateMiddleTensor(m_builder, fbname);
return serialized_middle_tensor;
}
flatbuffers::Offset<fbs::v2::OutputVar> GraphDumperOSSV2::build_output_var(
const SymbolVar var) {
auto out_node = var.node();
if (m_var2midtensor_id.find(var.node()) == m_var2midtensor_id.end()) {
mgb_assert(m_var_remove_in_dump.find(var.node()) != m_var_remove_in_dump.end());
out_node = m_var_remove_in_dump[var.node()];
}
return fbs::v2::CreateOutputVar(
m_builder, m_var2midtensor_id.at(out_node), var.node()->id());
}
void GraphDumperOSSV2::init_oprs_to_dump(const SymbolVarArray& endpoints) {
m_oprs_to_dump.clear();
// iterate oprs to init
auto on_opr = [&](cg::OperatorNodeBase* opr) {
if (should_remove_in_dump(opr)) {
mgb_assert(opr->input().size() == 1);
// Copy input ID to output
for (auto i : opr->output()) {
if (m_var_remove_in_dump.find(opr->input(0)) !=
m_var_remove_in_dump.end()) {
m_var_remove_in_dump[i] = m_var_remove_in_dump[opr->input(0)];
} else {
m_var_remove_in_dump[i] = opr->input(0);
}
}
} else {
auto registry = OprRegistryV2::versioned_find_by_typeinfo(
opr->dyn_typeinfo(), CURRENT_VERSION);
if (!registry || !registry->dumper) {
mgb_throw(
cg::OperatorNodeExcExtraInfo::ExcMaker{opr}.make<MegBrainError>,
"serialization as FlatBuffers is not supported for "
"operator %s",
opr->dyn_typeinfo()->name);
}
m_oprs_to_dump.emplace_back(opr, registry);
}
};
cg::DepOprIter dep_opr_iter{on_opr};
for (auto i : endpoints) {
dep_opr_iter.add(i.node()->owner_opr());
}
}
flatbuffers::Offset<fbs::v2::Metadata> GraphDumperOSSV2::build_metadata(
const Metadata& metadata) {
auto user_info = m_builder.CreateSharedString(metadata.user_info);
fbs::v2::MetadataBuilder builder(m_builder);
builder.add_is_valid(metadata.is_valid);
builder.add_graph_modified(metadata.graph_modified);
builder.add_optimize_options(metadata.optimize_options);
builder.add_user_info(user_info);
return builder.Finish();
}
flatbuffers::Offset<fbs::v2::Operator> GraphDumperOSSV2::build_single_opr(
cg::OperatorNodeBase* opr, const OprRegistryV2* registry) {
m_cur_opr = opr;
++m_cur_rst.nr_opr;
using namespace flatbuffers;
Offset<Vector<uint32_t>> inputs;
if (m_cur_opr->input().size()) {
std::vector<uint32_t> v;
v.reserve(m_cur_opr->input().size());
for (auto inp : m_cur_opr->input()) {
if (m_var2midtensor_id.find(inp) != m_var2midtensor_id.end()) {
v.emplace_back(m_var2midtensor_id.at(inp));
} else {
mgb_assert(
m_var_remove_in_dump.find(inp) != m_var_remove_in_dump.end(),
"when dump the model, the dependence of var is wrong.");
v.emplace_back(m_var2midtensor_id.at(m_var_remove_in_dump[inp]));
}
}
inputs = m_builder.CreateVector(v);
}
m_cur_opr_tensor.clear();
m_blobs.clear();
m_cur_opr_param.clear();
m_cur_opr_param_type.clear();
registry->dumper(*this, *m_cur_opr);
Offset<Vector<Offset<fbs::v2::CompNode>>> comp_node;
auto& config = m_cur_opr->config();
if (config.has_comp_node_set()) {
std::vector<flatbuffers::Offset<fbs::v2::CompNode>> cns;
for (const auto& cn : config.comp_node()) {
cns.emplace_back(fbs::v2::CreateCompNode(
m_builder, m_builder.CreateSharedString(cn.to_string_logical())));
}
comp_node = m_builder.CreateVector(cns);
}
Offset<String> operator_name;
if (m_config.keep_op_name) {
operator_name = m_builder.CreateSharedString(m_cur_opr->name());
}
auto output_dtype = build_dtype(config.output_dtype());
Offset<Vector<uint32_t>> outputs;
if (m_cur_opr->output().size()) {
std::vector<uint32_t> v;
v.reserve(m_cur_opr->output().size());
for (auto out : m_cur_opr->output()) {
if (!out->contain_flag(VarNode::Flag::VOLATILE_CONTENT)) {
auto fbs_out = build_middle_tensor(out);
m_model_middle_tensors.push_back(fbs_out);
m_var2midtensor_id[out] = m_model_middle_tensors.size() - 1;
v.emplace_back(m_var2midtensor_id.at(out));
}
}
outputs = m_builder.CreateVector(v);
}
Offset<Vector<Offset<fbs::v2::Tensor>>> tensors;
if (m_cur_opr_tensor.size())
tensors = m_builder.CreateVector(m_cur_opr_tensor);
//! the blobs data is used by custom data
//! m_blobs will be filled by the Operator dumper function
Offset<Vector<Offset<fbs::v2::Blob>>> blobs;
if (m_blobs.size())
blobs = m_builder.CreateVector(m_blobs);
Offset<Vector<uint8_t>> additional_params_type;
Offset<Vector<Offset<void>>> additional_params;
auto param_cnt = m_cur_opr_param_type.size();
if (param_cnt > 1) {
additional_params_type = m_builder.CreateVectorScalarCast<uint8_t>(
m_cur_opr_param_type.data() + 1, param_cnt - 1);
additional_params =
m_builder.CreateVector(m_cur_opr_param.data() + 1, param_cnt - 1);
}
auto opr_type = m_builder.CreateSharedString(registry->name);
fbs::v2::OperatorBuilder builder(m_builder);
builder.add_type(opr_type);
builder.add_type_id(registry->type_id);
builder.add_inputs(inputs);
builder.add_outputs(outputs);
if (m_config.keep_opr_priority) {
builder.add_priority(opr->node_prop().attribute().priority);
}
builder.add_comp_node(comp_node);
builder.add_opr_version(registry->get_version());
builder.add_name(operator_name);
builder.add_output_dtype(output_dtype);
if (param_cnt > 0) {
builder.add_param_type(m_cur_opr_param_type[0]);
builder.add_param(m_cur_opr_param[0]);
}
if (param_cnt > 1) {
builder.add_additional_params_type(additional_params_type);
builder.add_additional_params(additional_params);
}
builder.add_tensors(tensors);
builder.add_custom_data(blobs);
m_cur_opr = nullptr;
return builder.Finish();
}
SymbolVarArray GraphDumperOSSV2::converter_all_opr_to_compatiable(
const SymbolVarArray& output_vars) {
gopt::GraphOptimizer optimizer;
VarNodeArray rets_var;
for (auto& symbolvar : output_vars) {
rets_var.push_back(symbolvar.node());
}
optimizer.add_pass(PassConvertToCompatible::make(output_vars));
optimizer.apply_inplace(rets_var);
SymbolVarArray dst_vars;
for (auto& var : rets_var) {
dst_vars.push_back({var});
}
return dst_vars;
}
GraphDumper::DumpResult GraphDumperOSSV2::dump(
const SymbolVarArray& output_vars, const DumpConfig& config,
const Metadata& metadata) {
mgb_throw_if(output_vars.empty(), SerializationError, "Can't dump empty graph");
auto&& new_output_vars = converter_all_opr_to_compatiable(output_vars);
auto begin_pos = m_file->tell();
m_config = config;
m_builder.Reset();
m_output_vars.clear();
m_cur_rst = {};
m_used_input_names.clear();
m_used_param_names.clear();
m_var_remove_in_dump.clear();
m_model_middle_tensors.clear();
m_var2midtensor_id.clear();
m_nr_shared_tensor = 0;
// process output vars
bool keep_output_var_name = m_config.keep_var_name >= 1;
std::unordered_set<std::string> output_var_names;
for (auto i : new_output_vars) {
mgb_assert(
!i.node()->contain_flag(VarNode::Flag::VOLATILE_CONTENT),
"can not dump var with VOLATILE_CONTENT flag: %s",
cg::dump_var_info({i.node()}).c_str());
if (m_output_vars.insert(i.node()).second && keep_output_var_name) {
auto name_ins = output_var_names.insert(i.node()->name()).second;
mgb_assert(name_ins, "duplicated output var name: %s", i.node()->cname());
}
}
// Dump metadata
auto fbmeta = build_metadata(metadata);
// Dump operators
init_oprs_to_dump(new_output_vars);
std::vector<flatbuffers::Offset<fbs::v2::Operator>> oprs;
for (auto&& i : m_oprs_to_dump) {
oprs.emplace_back(build_single_opr(i.first, i.second));
}
auto fb_oprs = m_builder.CreateVector(oprs);
// Dump output vars
std::vector<flatbuffers::Offset<fbs::v2::OutputVar>> output_vars_idx;
output_vars_idx.reserve(new_output_vars.size());
for (auto i : new_output_vars) {
auto foutput_vars_idx = build_output_var(i);
output_vars_idx.push_back(foutput_vars_idx);
}
auto fb_output_vars = m_builder.CreateVector(output_vars_idx);
auto fb_mid_tensor = m_builder.CreateVector(m_model_middle_tensors);
fbs::v2::ModelBuilder model(m_builder);
model.add_mge_version(MGB_VERSION);
model.add_oprs(fb_oprs);
model.add_middle_tensors(fb_mid_tensor);
model.add_output_vars_idx(fb_output_vars);
model.add_nr_shared_tensor(m_nr_shared_tensor);
model.add_metadata(fbmeta);
m_builder.FinishSizePrefixed(model.Finish(), fbs::v2::ModelIdentifier());
// Write serialized fbs::Graph
m_file->write(m_builder.GetBufferPointer(), m_builder.GetSize());
// Finalize DumpResult
auto&& ret = m_cur_rst;
for (size_t i = 0; i < new_output_vars.size(); i++) {
ret.outputs.emplace_back(
keep_output_var_name ? new_output_vars[i].node()->cname()
: ssprintf("unnamed%zu", i));
}
std::sort(ret.inputs.begin(), ret.inputs.end());
mgb_assert(ret.nr_opr == m_oprs_to_dump.size());
ret.tot_bytes = m_file->tell() - begin_pos;
return ret;
}
void GraphDumperOSSV2::dump_tensor(
const std::string& name, const HostTensorND& tensor, TensorWriteMethod method) {
using namespace flatbuffers;
using Meth = TensorWriteMethod;
mgb_assert(
(method == Meth::VALUE_ANONYMOUS) ^ (!name.empty()),
"name must be non-empty for non Meth::VALUE_ANONYMOUS tensors");
bool has_value = method != Meth::META_INPUT;
bool should_keep_name = true;
switch (method) {
case Meth::VALUE_ANONYMOUS:
should_keep_name = false;
break;
case Meth::VALUE_SHARED:
should_keep_name = m_config.keep_param_name;
++m_nr_shared_tensor;
if (m_config.keep_param_name) {
mgb_assert(
m_used_param_names.insert(name).second,
"duplicated VALUE_SHARED tensor name: %s", name.c_str());
m_cur_rst.params.emplace_back(name);
}
break;
case Meth::META_INPUT:
case Meth::VALUE_INPUT:
mgb_assert(!name.empty(), "empty input tensor name");
mgb_assert(
m_used_input_names.insert(name).second,
"duplicated input tensor name: %s", name.c_str());
m_cur_rst.inputs.emplace_back(name);
break;
}
auto& layout = tensor.layout();
flatbuffers::Offset<flatbuffers::Vector<uint8_t>> data;
if (has_value) {
check_tensor_value_valid(name, tensor);
auto&& dumper = m_config.tensor_value_dumper;
if (dumper) {
mgb_log_warn(
"serialization v2 format is pure flatbuffer format, not support "
"user tensor value dumper");
}
data = m_builder.CreateVector(
reinterpret_cast<uint8_t*>(tensor.raw_ptr()), layout.span().high_byte);
m_cur_rst.tensor_value_bytes += layout.span().high_byte;
}
auto fbname = should_keep_name ? m_builder.CreateSharedString(name) : 0;
auto fshape = m_builder.CreateVectorScalarCast<uint32_t>(layout.shape, layout.ndim);
auto fcomp_node = fbs::v2::CreateCompNode(
m_builder,
m_builder.CreateSharedString(tensor.comp_node().to_string_logical()));
auto fdtype = build_dtype(layout.dtype);
auto fformat_type = get_flatbuffer_tensor_format_type(layout.format);
auto fformat = build_tensor_format(layout.format);
auto serialized_tensor = fbs::v2::CreateTensor(
m_builder, fbname, fshape, fcomp_node, fdtype, fformat_type, fformat, data);
m_cur_opr_tensor.emplace_back(serialized_tensor);
}
void GraphDumperOSSV2::dump_buf_with_len(const void* data, uint32_t size) {
auto blob = fbs::v2::CreateBlob(
m_builder, m_builder.CreateVector(static_cast<const uint8_t*>(data), size));
m_blobs.emplace_back(blob);
}
// ----------------------------- Loader --------------------------------------
CompNode GraphLoaderOSSV2::OprLoadContextImpl::load_comp_node(
const fbs::v2::CompNode* comp_node) {
mgb_assert(comp_node);
if (!comp_node->logical_locator())
return {};
auto loc = CompNode::Locator::parse(comp_node->logical_locator()->str());
m_loader->m_cur_load_config->comp_node_mapper(loc);
return CompNode::load(loc);
}
TensorFormat load_tensor_format(
const fbs::v2::TensorFormat fformat_type, const void* fformat,
const CompNode& comp_node) {
switch (fformat_type) {
case fbs::v2::TensorFormat_DefaultTensorFormat:
return megdnn::DefaultTensorFormat::make();
case fbs::v2::TensorFormat_Image2DPackedTensorFormat: {
auto image_format =
static_cast<const fbs::v2::Image2DPackedTensorFormat*>(fformat);
auto handle =
MegDNNHandle::get(CompNodeEnv::from_comp_node(comp_node)).handle();
return megdnn::Image2DPack4TensorFormat::make(
image_format->align_axis(), handle);
}
case fbs::v2::TensorFormat_LowbitsAlignedTensorFormat: {
auto lowbit_format =
static_cast<const fbs::v2::LowbitsAlignedTensorFormat*>(fformat);
return megdnn::LowbitsAlignedToBytesTensorFormat::make(
lowbit_format->size_nbits());
}
default:
mgb_throw(
SerializationError, "invalid tensor format type in serialization.");
}
}
TensorLayout load_tensor_layout(
const fbs::v2::Tensor* tensor, const CompNode& comp_node) {
TensorLayout layout;
if (tensor->shape()) {
layout.ndim = tensor->shape()->size();
std::copy(tensor->shape()->begin(), tensor->shape()->end(), layout.shape);
}
if (tensor->dtype()) {
// modify data type inplace for TensorLayout
layout.modify_dtype_inplace(fbs::intl::load_dtype(tensor->dtype()));
}
if (tensor->format() && tensor->format_type()) {
layout.format =
load_tensor_format(tensor->format_type(), tensor->format(), comp_node);
}
layout.init_contiguous_stride();
return layout;
}
//! the opr loader should make sure the exist of tensors and the number of
//! tensor, here just assert it.
std::shared_ptr<HostTensorND> GraphLoaderOSSV2::OprLoadContextImpl::load_tensor() {
mgb_assert(
m_current_opr->tensors() &&
m_cur_opr_tensor_cnt < m_current_opr->tensors()->size());
auto tensor = m_current_opr->tensors()->Get(m_cur_opr_tensor_cnt++);
auto comp_node = load_comp_node(tensor->comp_node());
auto layout = load_tensor_layout(tensor, comp_node);
auto ret = std::make_shared<HostTensorND>(comp_node, layout);
auto&& loader = m_loader->m_cur_load_config->tensor_value_loader;
if (tensor->data() && tensor->data()->size() > 0) {
if (loader) {
mgb_log_warn(
"serialization v2 format is pure flatbuffer format, not support "
"user tensor value loader");
}
memcpy(ret->raw_ptr(), tensor->data()->data(), tensor->data()->size());
}
if (tensor->name()) {
m_tensor_map[tensor->name()->str()] = ret;
}
if (auto&& mod = m_loader->m_cur_load_config->tensor_modifier) {
bool has_value = false;
if (tensor && tensor->data()) {
has_value = tensor->data()->size() != 0;
}
mod(tensor->name() ? tensor->name()->str() : "", has_value, *ret);
}
return ret;
}
std::shared_ptr<DeviceTensorND> GraphLoaderOSSV2::OprLoadContextImpl::
load_tensor_shared() {
mgb_assert(
m_current_opr->tensors() &&
m_cur_opr_tensor_cnt < m_current_opr->tensors()->size());
auto tensor = m_current_opr->tensors()->Get(m_cur_opr_tensor_cnt++);
auto comp_node = load_comp_node(tensor->comp_node());
auto layout = load_tensor_layout(tensor, comp_node);
mgb_assert(tensor->data());
auto&& shared_pair = m_loader->m_shared_tensor_map.at(m_cur_shared_tensor_idx++);
auto&& shared_tensor_ref = shared_pair.second[comp_node.mem_node()];
if (shared_tensor_ref) {
if (shared_tensor_ref->comp_node() == comp_node)
return shared_tensor_ref;
// same mem node but different comp node, change comp node and share
// value
auto ret = std::make_shared<DeviceTensorND>(*shared_tensor_ref);
ret->comp_node(comp_node);
return ret;
}
if (tensor->name()) {
shared_pair.first = tensor->name()->str();
}
if (comp_node.mem_node() == CompNode::default_cpu().mem_node()) {
// directly forward CPU memory
HostTensorND hv{comp_node};
if (tensor->data() && tensor->data()->size() > 0) {
hv.dtype(layout.dtype).resize(layout);
memcpy(hv.raw_ptr(), tensor->data()->data(), tensor->data()->size());
}
shared_tensor_ref = std::make_shared<DeviceTensorND>();
*shared_tensor_ref = DeviceTensorND::make_proxy(hv);
} else {
// use lazy load for non-CPU devices
HostTensorND hv{CompNode::default_cpu()};
if (tensor->data() && tensor->data()->size() > 0) {
hv.dtype(layout.dtype).resize(layout);
memcpy(hv.raw_ptr(), tensor->data()->data(), tensor->data()->size());
}
shared_tensor_ref = m_device_value_loader.make(comp_node, std::move(hv));
}
return shared_tensor_ref;
}
Metadata GraphLoaderOSSV2::OprLoadContextImpl::load_metadata() {
const auto* fbmeta = m_loader->m_model->metadata();
Metadata ret;
if (fbmeta) {
ret.is_valid = fbmeta->is_valid();
ret.graph_modified = fbmeta->graph_modified();
if (fbmeta->user_info()) {
ret.user_info = fbmeta->user_info()->str();
ret.has_user_info = true;
}
if (fbmeta->optimize_options()) {
ret.optimize_options = fbmeta->optimize_options();
ret.optimized_for_inference = true;
}
}
return ret;
}
void GraphLoaderOSSV2::OprLoadContextImpl::load_single_opr(
const fbs::v2::Operator* fbopr) {
m_cur_opr_tensor_cnt = 0;
m_cur_opr_blob_cnt = 0;
m_cur_opr_param_cnt = 0;
OperatorNodeConfig config;
if (fbopr->output_dtype()) {
config.output_dtype(fbs::intl::load_dtype(fbopr->output_dtype()));
}
if (fbopr->name()) {
config.name(fbopr->name()->str());
}
if (fbopr->comp_node()) {
auto cnt = fbopr->comp_node()->size();
cg::OperatorNodeConfig::CompNodeArray comp_node_arr(cnt);
for (size_t i = 0; i < cnt; i++) {
CompNode cn{};
auto node = fbopr->comp_node()->Get(i);
if (node) {
cn = load_comp_node(node);
}
comp_node_arr[i] = cn;
}
config.comp_node_arr(comp_node_arr);
}
//! opr version must be exist
uint8_t opr_version = fbopr->opr_version();
auto type_id = fbopr->type_id();
auto opr_type = fbopr->type()->str();
const OprRegistryV2* registry =
OprRegistryV2::versioned_find_by_id(type_id, opr_version);
mgb_throw_if(
!registry, SerializationError,
"failed to find opr with type %s id is %zu, use python env "
"config.dump_registered_oprs() to get a dict that maps from "
"opr id to opr name",
fbopr->type()->str().c_str(), type_id);
// load inputs
VarNodeArray inputs;
if (fbopr->inputs()) {
inputs.resize(fbopr->inputs()->size());
for (size_t i = 0; i < inputs.size(); ++i) {
inputs[i] = m_id2varnode.at(fbopr->inputs()->Get(i));
}
}
// call loader
auto accessor = registry->loader(*this, inputs, config);
auto opr = accessor.opr();
// check opr type; note that:
// 1. registry->type may be empty for dynamic opr loaders or legacy oprs
// 2. due to some optimization, an opr may be replaced by ImmutableTensor
mgb_assert(
opr && (opr->dyn_typeinfo() == registry->type || !registry->type ||
opr->same_type<opr::ImmutableTensor>()),
"got_type=%s expected_type=%s", opr ? opr->dyn_typeinfo()->name : nullptr,
registry->type->name);
// record output vars; read output names
size_t i = 0;
for (auto ovar : accessor.output()) {
if (!ovar->contain_flag(VarNode::Flag::VOLATILE_CONTENT)) {
m_id2varnode.push_back(ovar);
if (fbopr->outputs()) {
auto id = fbopr->outputs()->Get(i);
mgb_assert(
m_id2varnode.size() - 1 == fbopr->outputs()->Get(i),
"id2var is %zu, fbs get id is %d\n", m_id2varnode.size() - 1,
fbopr->outputs()->Get(i));
if (m_middle_tensors.size() > i) {
auto name = m_middle_tensors[id]->name()->str();
ovar->name(name);
}
}
i++;
}
}
opr->node_prop().attribute().priority = fbopr->priority();
}
GraphLoader::LoadResult GraphLoaderOSSV2::OprLoadContextImpl::load_oprs() {
// load oprs
const auto* oprs = m_loader->m_model->oprs();
{
// inplace arith graph optimization is disabled during opr load
// it tries to restore the same graph as it was dumped
// see test TestSerializer2.LOGEXP for example
GraphLoader::ScopedGraphOptDisabler _(m_graph);
for (flatbuffers::uoffset_t i = 0; i < oprs->size(); ++i) {
m_current_opr = oprs->Get(i);
load_single_opr(m_current_opr);
}
}
// batched loading device values
m_device_value_loader.apply();
LoadResult ret;
ret.graph = m_graph;
ret.tensor_map = m_tensor_map;
const auto* outputs = m_loader->m_model->output_vars_idx();
ret.output_var_list.resize(outputs->size());
for (flatbuffers::uoffset_t i = 0; i < outputs->size(); i++) {
auto out = outputs->Get(i);
auto var = m_id2varnode.at(out->compact_id());
ret.output_var_map[var->name()] = var;
ret.output_var_map_id[out->original_id()] = var;
ret.output_var_list[i] = var;
}
mgb_assert(m_cur_shared_tensor_idx == m_loader->m_shared_tensor_map.size());
return ret;
}
void GraphLoaderOSSV2::OprLoadContextImpl::load_middle_tensor() {
auto model = m_loader->m_model;
if (model->middle_tensors()) {
for (unsigned int i = 0; i < m_loader->m_model->middle_tensors()->size(); i++) {
m_middle_tensors.push_back(model->middle_tensors()->Get(i));
}
}
}
GraphLoader::LoadResult GraphLoaderOSSV2::load(const LoadConfig& config, bool rewind) {
mgb_assert(m_file);
m_cur_load_config = &config;
if (rewind) {
m_file->rewind();
}
// Read fbs::Graph
uint32_t size;
m_file->read(&size, sizeof(size));
m_model_buf = m_file->read_shared(size);
mgb_throw_if(
!fbs::v2::ModelBufferHasIdentifier(m_model_buf.data()), SerializationError,
"invalid fbs model");
{
flatbuffers::Verifier verifier(
static_cast<const uint8_t*>(m_model_buf.data()), m_model_buf.size());
mgb_throw_if(
!fbs::v2::VerifyModelBuffer(verifier), SerializationError,
"model verification failed (invalid or corrupted model?)");
}
m_model = fbs::v2::GetModel(m_model_buf.data());
m_mgb_version = m_model->mge_version();
if (m_model->mge_version() > MGB_VERSION) {
mgb_log_warn(
"loading model from future runtime: version=%u "
"model_version=%u",
MGB_VERSION, m_model->mge_version());
}
if (m_shared_tensor_map.empty()) {
m_shared_tensor_map.resize(m_model->nr_shared_tensor());
} else {
mgb_assert(m_shared_tensor_map.size() == m_model->nr_shared_tensor());
}
OprLoadContextImpl ctx{this, m_model->mge_version()};
ctx.load_middle_tensor();
auto metadata = ctx.load_metadata();
auto result = ctx.load_oprs();
result.metadata = metadata;
m_model_loaded = true;
result.graph_compile_ahead();
return result;
}
std::unique_ptr<GraphDumper> make_fbs_v2_dumper(std::unique_ptr<OutputFile> file) {
return std::make_unique<GraphDumperOSSV2>(std::move(file));
}
std::unique_ptr<GraphLoader> make_fbs_v2_loader(std::unique_ptr<InputFile> file) {
return std::make_unique<GraphLoaderOSSV2>(std::move(file));
}
bool is_fbs_v2_file(InputFile& file) {
constexpr size_t identifier_length = 25;
char identifier[identifier_length];
file.read(identifier, identifier_length);
file.skip(-identifier_length);
//! skip the size in prefix of the file
return fbs::v2::ModelBufferHasIdentifier(identifier + sizeof(uint32_t));
}
} // namespace serialization
} // namespace mgb
#endif
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
...@@ -5,6 +5,7 @@ namespace serialization { ...@@ -5,6 +5,7 @@ namespace serialization {
enum class GraphDumpFormat { enum class GraphDumpFormat {
FLATBUFFERS, FLATBUFFERS,
FLATBUFFERS_V2,
}; };
} // namespace serialization } // namespace serialization
......
...@@ -20,8 +20,12 @@ class FlatBufferBuilder; ...@@ -20,8 +20,12 @@ class FlatBufferBuilder;
} // namespace flatbuffers } // namespace flatbuffers
namespace mgb { namespace mgb {
namespace serialization { constexpr uint8_t CURRENT_VERSION = 2u;
constexpr uint8_t BEGIN_VERSION = 0u;
constexpr uint8_t VERSION_1 = 1u;
constexpr uint8_t VERSION_2 = 2u;
namespace serialization {
namespace fbs { namespace fbs {
template <typename T> template <typename T>
struct OperatorParamTraits; struct OperatorParamTraits;
...@@ -187,6 +191,9 @@ class OprLoadContext : public UserDataContainer::UserData { ...@@ -187,6 +191,9 @@ class OprLoadContext : public UserDataContainer::UserData {
friend class OprLoadContextRawPOD; friend class OprLoadContextRawPOD;
friend class OprLoadContextFlatBuffers; friend class OprLoadContextFlatBuffers;
protected:
virtual ~OprLoadContext() = default;
public: public:
//! get current computing graph //! get current computing graph
virtual ComputingGraph& graph() = 0; virtual ComputingGraph& graph() = 0;
...@@ -224,6 +231,12 @@ public: ...@@ -224,6 +231,12 @@ public:
*/ */
virtual SharedBuffer load_shared_buf_with_len() = 0; virtual SharedBuffer load_shared_buf_with_len() = 0;
/*!
* \brief get the serialization data of the current opr
*
*/
virtual const void* get_current_opr_data() { return nullptr; };
/*! /*!
* \brief read a param and check that tag matches * \brief read a param and check that tag matches
*/ */
......
#pragma once
#if MGB_ENABLE_FBS_SERIALIZATION
#include "megbrain/comp_node_env.h"
#include "megbrain/graph/exc_extra_info.h"
#include "megbrain/serialization/batched_device_value_loader.h"
#include "megbrain/serialization/internal/schema_v2_generated.h"
#include "megbrain/serialization/opr_load_dump.h"
#include "megbrain/serialization/serializer.h"
#define CAST_TO_FBS_V2_CTX(cvt) static_cast<GraphLoaderOSSV2::OprLoadContextImpl&>(ctx)
namespace mgb {
namespace serialization {
class GraphDumperOSSV2 final : public GraphDumper, OprDumpContextFlatBuffers {
const std::unique_ptr<OutputFile> m_file;
flatbuffers::FlatBufferBuilder m_builder;
DumpConfig m_config;
DumpResult m_cur_rst;
size_t m_nr_shared_tensor;
std::vector<std::pair<cg::OperatorNodeBase*, const OprRegistryV2*>> m_oprs_to_dump;
ThinHashMap<VarNode*, VarNode*> m_var_remove_in_dump;
//! set of output vars specified by user
ThinHashSet<VarNode*> m_output_vars;
std::unordered_set<std::string> m_used_input_names, m_used_param_names;
//! current opr to be dumped
cg::OperatorNodeBase* m_cur_opr = nullptr;
// Will be filled in dump_tensor
std::vector<flatbuffers::Offset<fbs::v2::Tensor>> m_cur_opr_tensor;
std::vector<flatbuffers::Offset<fbs::v2::Blob>> m_blobs;
std::vector<fbs::v2::OperatorParam> m_cur_opr_param_type;
std::vector<flatbuffers::Offset<void>> m_cur_opr_param;
std::vector<flatbuffers::Offset<fbs::v2::MiddleTensor>> m_model_middle_tensors;
ThinHashMap<VarNode*, size_t> m_var2midtensor_id;
SymbolVarArray converter_all_opr_to_compatiable(const SymbolVarArray& output_vars);
void init_oprs_to_dump(const SymbolVarArray& endpoints);
flatbuffers::Offset<fbs::v2::Metadata> build_metadata(const Metadata& metadata);
flatbuffers::Offset<fbs::v2::Operator> build_single_opr(
cg::OperatorNodeBase* opr, const OprRegistryV2* registry);
flatbuffers::Offset<fbs::DType> build_dtype(DType dtype);
public:
GraphDumperOSSV2(std::unique_ptr<OutputFile> file) : m_file{std::move(file)} {}
DumpResult dump(
const SymbolVarArray& output_vars, const DumpConfig& config = {},
const Metadata& metadata = {}) override;
const GraphDumpConfig& config() const override { return m_config; }
void dump_tensor(
const std::string& name, const HostTensorND& tensor,
TensorWriteMethod method) override;
void append_param(uint32_t type, uint32_t value) override {
static_assert(
std::is_same<uint32_t, flatbuffers::uoffset_t>::value,
"append_param depends on uoffset_t being uint32_t");
static_assert(
std::is_standard_layout<flatbuffers::Offset<void>>::value,
"append_param depends on flatbuffers::Offset having "
"standard memory layout");
mgb_assert(type != fbs::v2::OperatorParam_NONE);
m_cur_opr_param_type.emplace_back(static_cast<fbs::v2::OperatorParam>(type));
m_cur_opr_param.emplace_back(value);
}
flatbuffers::FlatBufferBuilder& builder() override { return m_builder; }
void dump_buf_with_len(const void* data, uint32_t size) override;
GraphDumpFormat format() const override { return GraphDumpFormat::FLATBUFFERS_V2; }
flatbuffers::Offset<fbs::v2::MiddleTensor> build_middle_tensor(const SymbolVar var);
flatbuffers::Offset<fbs::v2::OutputVar> build_output_var(const SymbolVar var);
flatbuffers::Offset<void> build_tensor_format(const TensorLayout::Format& format);
void set_current_opr(cg::OperatorNodeBase* cur_opr) { m_cur_opr = cur_opr; }
};
// ----------------------------- Loader --------------------------------------
class GraphLoaderOSSV2 final : public GraphLoader {
const LoadConfig* m_cur_load_config = nullptr;
std::unique_ptr<InputFile> m_file;
SharedBuffer m_model_buf{{}, 0};
const fbs::v2::Model* m_model;
SharedTensorIDMap m_shared_tensor_map;
uint32_t m_mgb_version = 0;
bool m_model_loaded = false;
void verify();
public:
class OprLoadContextImpl;
friend class OprLoadContextImpl;
GraphLoaderOSSV2(std::unique_ptr<InputFile> input_file)
: m_file{std::move(input_file)} {}
std::unique_ptr<InputFile> reset_file(std::unique_ptr<InputFile> file) override {
file.swap(m_file);
return file;
}
LoadResult load(const LoadConfig& config, bool rewind) override;
const SharedTensorIDMap& shared_tensor_id_map() const override {
mgb_assert(m_model_loaded, "graph not loaded yet");
return m_shared_tensor_map;
}
GraphDumpFormat format() const override { return GraphDumpFormat::FLATBUFFERS_V2; }
};
class GraphLoaderOSSV2::OprLoadContextImpl final : public OprLoadContextFlatBuffers {
GraphLoaderOSSV2* const m_loader;
size_t m_cur_shared_tensor_idx = 0;
std::shared_ptr<ComputingGraph> m_graph;
LoadResult::TensorMap m_tensor_map;
VarNodeArray m_id2varnode;
std::vector<const fbs::v2::MiddleTensor*> m_middle_tensors;
BatchedDeviceValueLoader m_device_value_loader;
const fbs::v2::Operator* m_current_opr;
size_t m_cur_opr_tensor_cnt;
size_t m_cur_opr_blob_cnt;
size_t m_cur_opr_param_cnt;
public:
ComputingGraph& graph() override { return *m_graph; }
const GraphLoadConfig& config() const override {
return *m_loader->m_cur_load_config;
}
std::shared_ptr<HostTensorND> load_tensor() override;
std::shared_ptr<DeviceTensorND> load_tensor_shared() override;
void load_single_opr(const fbs::v2::Operator* opr);
OprLoadContextImpl(GraphLoaderOSSV2* loader, uint32_t version)
: OprLoadContextFlatBuffers(version), m_loader{loader} {
m_graph = loader->m_cur_load_config->comp_graph;
if (!m_graph) {
m_graph = ComputingGraph::make();
}
auto maker = [this]() {
return std::shared_ptr<OprLoadContext>{
std::shared_ptr<OprLoadContext>{}, this};
};
auto got = m_graph->options().user_data.get_user_data_or_create<OprLoadContext>(
maker);
mgb_assert(got == this);
}
~OprLoadContextImpl() noexcept {
auto nr = m_graph->options().user_data.pop_user_data<OprLoadContext>();
mgb_assert(nr == 1);
}
Metadata load_metadata();
LoadResult load_oprs();
CompNode load_comp_node(const fbs::v2::CompNode* comp_node);
void load_middle_tensor();
const void* get_next_param(uint32_t enumv) override {
auto type = static_cast<fbs::v2::OperatorParam>(enumv);
if (m_cur_opr_param_cnt == 0) {
m_cur_opr_param_cnt++;
if (m_current_opr->param_type() == type) {
return m_current_opr->param();
} else {
mgb_throw(
SerializationError,
"The param type is not match when load the opr.");
}
}
mgb_throw(
SerializationError,
"When load multi param in one Operator, please use read_param(index) "
"interface. ");
}
std::string load_buf_with_len() override {
mgb_assert(
m_current_opr->custom_data() &&
m_cur_opr_blob_cnt < m_current_opr->custom_data()->size());
auto blob = m_current_opr->custom_data()->Get(m_cur_opr_blob_cnt++);
mgb_assert(blob && blob->data());
auto data = blob->data()->data();
return {reinterpret_cast<const char*>(data), blob->data()->size()};
}
SharedBuffer load_shared_buf_with_len() override {
mgb_assert(
m_current_opr->custom_data() &&
m_cur_opr_blob_cnt < m_current_opr->custom_data()->size());
auto blob = m_current_opr->custom_data()->Get(m_cur_opr_blob_cnt++);
mgb_assert(blob && blob->data());
auto size = blob->data()->size();
std::shared_ptr<uint8_t> shptr{
new uint8_t[size], [](uint8_t* p) { delete[] p; }};
memcpy(shptr.get(), blob->data()->data(), size);
return {std::move(shptr), size};
}
const void* get_current_opr_data() override {
return reinterpret_cast<const void*>(m_current_opr);
}
template <class T>
T read_param(int index) {
using SourceType = typename fbs::ParamConverter<T>::FlatBufferType;
auto enumv = fbs::OperatorParamTraits<SourceType>::enum_value;
auto type = static_cast<fbs::v2::OperatorParam>(enumv);
if (index == 0) {
mgb_assert(
m_current_opr->param_type() == type,
"Load param error, the param type is not right.");
return fbs::ParamConverter<T>::to_param(
static_cast<const SourceType*>(m_current_opr->param()));
} else {
int addition_index = index - 1;
if (addition_index >=
static_cast<int>(m_current_opr->additional_params()->size())) {
mgb_log_warn(
"Model has no addition param of index %d, just construct a "
"default one.",
addition_index);
} else {
mgb_assert(
m_current_opr->additional_params_type()->Get(addition_index) ==
type,
"Load param error, the addition param type is not right.");
return fbs::ParamConverter<T>::to_param(static_cast<const SourceType*>(
m_current_opr->additional_params()->Get(addition_index)));
}
}
}
};
} // namespace serialization
} // namespace mgb
#endif
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册