提交 a694fb33 编写于 作者: M Megvii Engine Team

feat(serialization): implement the new serialization format

GitOrigin-RevId: 00f87f7ccdae7d313d8a97108c3ae712b1c27da3
上级 ca4a5da0
#include "batched_device_value_loader.h" #include "megbrain/serialization/batched_device_value_loader.h"
#include "megbrain/utils/arith_helper.h" #include "megbrain/utils/arith_helper.h"
namespace mgb { namespace mgb {
......
...@@ -57,7 +57,11 @@ GraphLoader::SharedTensorNameMap GraphLoader::shared_tensor_name_map() { ...@@ -57,7 +57,11 @@ GraphLoader::SharedTensorNameMap GraphLoader::shared_tensor_name_map() {
} }
std::unique_ptr<GraphLoader> make_fbs_loader(std::unique_ptr<InputFile> file); std::unique_ptr<GraphLoader> make_fbs_loader(std::unique_ptr<InputFile> file);
std::unique_ptr<GraphDumper> make_fbs_dumper(std::unique_ptr<OutputFile> file); std::unique_ptr<GraphDumper> make_fbs_dumper(std::unique_ptr<OutputFile> file);
std::unique_ptr<GraphLoader> make_fbs_v2_loader(std::unique_ptr<InputFile> file);
std::unique_ptr<GraphDumper> make_fbs_v2_dumper(std::unique_ptr<OutputFile> file);
bool is_fbs_file(InputFile& file); bool is_fbs_file(InputFile& file);
bool is_fbs_v2_file(InputFile& file);
bool GraphDumper::should_remove_in_dump(cg::OperatorNodeBase* opr) { bool GraphDumper::should_remove_in_dump(cg::OperatorNodeBase* opr) {
#if MGB_ENABLE_GRAD #if MGB_ENABLE_GRAD
...@@ -73,6 +77,11 @@ std::unique_ptr<GraphDumper> GraphDumper::make( ...@@ -73,6 +77,11 @@ std::unique_ptr<GraphDumper> GraphDumper::make(
case GraphDumpFormat::FLATBUFFERS: case GraphDumpFormat::FLATBUFFERS:
#if MGB_ENABLE_FBS_SERIALIZATION #if MGB_ENABLE_FBS_SERIALIZATION
return make_fbs_dumper(std::move(file)); return make_fbs_dumper(std::move(file));
#endif
MGB_FALLTHRU
case GraphDumpFormat::FLATBUFFERS_V2:
#if MGB_ENABLE_FBS_SERIALIZATION
return make_fbs_v2_dumper(std::move(file));
#endif #endif
MGB_FALLTHRU MGB_FALLTHRU
default: default:
...@@ -87,6 +96,11 @@ std::unique_ptr<GraphLoader> GraphLoader::make( ...@@ -87,6 +96,11 @@ std::unique_ptr<GraphLoader> GraphLoader::make(
case GraphDumpFormat::FLATBUFFERS: case GraphDumpFormat::FLATBUFFERS:
#if MGB_ENABLE_FBS_SERIALIZATION #if MGB_ENABLE_FBS_SERIALIZATION
return make_fbs_loader(std::move(file)); return make_fbs_loader(std::move(file));
#endif
MGB_FALLTHRU
case GraphDumpFormat::FLATBUFFERS_V2:
#if MGB_ENABLE_FBS_SERIALIZATION
return make_fbs_v2_loader(std::move(file));
#endif #endif
MGB_FALLTHRU MGB_FALLTHRU
default: default:
...@@ -100,6 +114,9 @@ Maybe<GraphDumpFormat> GraphLoader::identify_graph_dump_format(InputFile& file) ...@@ -100,6 +114,9 @@ Maybe<GraphDumpFormat> GraphLoader::identify_graph_dump_format(InputFile& file)
if (is_fbs_file(file)) { if (is_fbs_file(file)) {
return GraphDumpFormat::FLATBUFFERS; return GraphDumpFormat::FLATBUFFERS;
} }
if (is_fbs_v2_file(file)) {
return GraphDumpFormat::FLATBUFFERS_V2;
}
#endif #endif
return {}; return {};
} }
......
...@@ -11,17 +11,16 @@ ...@@ -11,17 +11,16 @@
*/ */
#if MGB_ENABLE_FBS_SERIALIZATION #if MGB_ENABLE_FBS_SERIALIZATION
#include "batched_device_value_loader.h"
#include "megbrain/graph/exc_extra_info.h" #include "megbrain/graph/exc_extra_info.h"
#include "megbrain/opr/io.h" #include "megbrain/opr/io.h"
#include "megbrain/serialization/batched_device_value_loader.h"
#include "megbrain/serialization/helper.h" #include "megbrain/serialization/helper.h"
#include "megbrain/serialization/internal/flatbuffers_helper.h" #include "megbrain/serialization/internal/flatbuffers_helper.h"
#include "megbrain/serialization/internal/schema_generated.h" #include "megbrain/serialization/internal/schema_generated.h"
#include "megbrain/serialization/metadata.h" #include "megbrain/serialization/metadata.h"
#include "megbrain/serialization/opr_load_dump.h" #include "megbrain/serialization/opr_load_dump.h"
#include "megbrain/serialization/serializer.h" #include "megbrain/serialization/serializer.h"
#include "megbrain/version.h" #include "serializer_oss_common.h"
#include <flatbuffers/flatbuffers.h> #include <flatbuffers/flatbuffers.h>
...@@ -33,47 +32,8 @@ using namespace mgb; ...@@ -33,47 +32,8 @@ using namespace mgb;
using namespace mgb::serialization; using namespace mgb::serialization;
namespace { namespace {
constexpr uint32_t MGB_VERSION = (MGE_MAJOR * 1000 + MGE_MINOR) * 100 + MGE_PATCH;
constexpr uint32_t MGB_MAGIC = 0x4342474D;
// In order to maintain compatibility and to allow old models to be loaded, we keep
// the old magic(MAGIC_V0) value and creat a new magic(MGB_MAGIC)
constexpr uint32_t MAGIC_V0 = 0x5342474D;
// Used to judge whether Magic is old or new, the new magic(MGB_MAGIC) is true and the
// old magic(MAGIC_V0) is false.
bool magic_compare = true; bool magic_compare = true;
template <typename T>
bool contains_any_in_set(const SmallVector<T>& list, const ThinHashSet<T>& set) {
for (const auto& x : list) {
if (set.count(x)) {
return true;
}
}
return false;
}
void check_tensor_value_valid(const std::string& name, const HostTensorND& tensor) {
bool cond_normal = tensor.layout().format.is_default() &&
tensor.layout().is_physical_contiguous();
bool cond_lowbit = tensor.layout().dtype.is_quantized_lowbit() &&
tensor.layout().format.is_lowbit_aligned() &&
tensor.layout().is_contiguous();
mgb_assert(
cond_normal || cond_lowbit, "non-contiguous tensor: name=%s layout=%s",
name.c_str(), tensor.layout().to_string().c_str());
if (tensor.dtype() == dtype::Float32()) {
auto ptr = tensor.ptr<float>();
for (size_t i = 0, it = tensor.shape().total_nr_elems(); i < it; ++i) {
if (!std::isfinite(ptr[i])) {
mgb_log_warn("invalid tensor value in %s: %g", name.c_str(), ptr[i]);
break;
}
}
}
}
//! feature bits for backward compatibility; default value should be 0 //! feature bits for backward compatibility; default value should be 0
struct FeatureBits64 { struct FeatureBits64 {
//! reserved for new fields //! reserved for new fields
...@@ -947,13 +907,6 @@ std::unique_ptr<GraphLoader> make_fbs_loader(std::unique_ptr<InputFile> file) { ...@@ -947,13 +907,6 @@ std::unique_ptr<GraphLoader> make_fbs_loader(std::unique_ptr<InputFile> file) {
return std::make_unique<GraphLoaderOSS>(std::move(file)); return std::make_unique<GraphLoaderOSS>(std::move(file));
} }
bool is_fbs_file(InputFile& file) {
uint64_t magic_with_reserved = 0;
file.read(&magic_with_reserved, sizeof(magic_with_reserved));
file.skip(-sizeof(magic_with_reserved));
return (magic_with_reserved == MGB_MAGIC) || (magic_with_reserved == MAGIC_V0);
}
} // namespace serialization } // namespace serialization
} // namespace mgb } // namespace mgb
......
#if MGB_ENABLE_FBS_SERIALIZATION
#include "serializer_oss_common.h"
namespace mgb {
namespace serialization {
bool is_fbs_file(InputFile& file) {
//! check whether the model format is flatbuffer v2
uint64_t magic_with_reserved = 0;
file.read(&magic_with_reserved, sizeof(magic_with_reserved));
file.skip(-sizeof(magic_with_reserved));
return (magic_with_reserved == MGB_MAGIC) || (magic_with_reserved == MAGIC_V0);
}
void check_tensor_value_valid(const std::string& name, const HostTensorND& tensor) {
bool cond_normal = tensor.layout().format.is_default() &&
tensor.layout().is_physical_contiguous();
bool cond_lowbit = tensor.layout().dtype.is_quantized_lowbit() &&
tensor.layout().format.is_lowbit_aligned() &&
tensor.layout().is_contiguous();
mgb_assert(
cond_normal || cond_lowbit, "non-contiguous tensor: name=%s layout=%s",
name.c_str(), tensor.layout().to_string().c_str());
if (tensor.dtype() == dtype::Float32()) {
auto ptr = tensor.ptr<float>();
for (size_t i = 0, it = tensor.shape().total_nr_elems(); i < it; ++i) {
if (!std::isfinite(ptr[i])) {
mgb_log_warn("invalid tensor value in %s: %g", name.c_str(), ptr[i]);
break;
}
}
}
}
} // namespace serialization
} // namespace mgb
#endif
#if MGB_ENABLE_FBS_SERIALIZATION
#include "megbrain/serialization/serializer.h"
#include "megbrain/version.h"
namespace mgb {
namespace serialization {
constexpr uint32_t MGB_VERSION = (MGE_MAJOR * 1000 + MGE_MINOR) * 100 + MGE_PATCH;
constexpr uint32_t MGB_MAGIC = 0x4342474D;
// In order to maintain compatibility and to allow old models to be loaded, we keep
// the old magic(MAGIC_V0) value and creat a new magic(MGB_MAGIC)
constexpr uint32_t MAGIC_V0 = 0x5342474D;
void check_tensor_value_valid(const std::string& name, const HostTensorND& tensor);
template <typename T>
bool contains_any_in_set(const SmallVector<T>& list, const ThinHashSet<T>& set) {
for (const auto& x : list) {
if (set.count(x)) {
return true;
}
}
return false;
}
} // namespace serialization
} // namespace mgb
#endif
此差异已折叠。
...@@ -5,6 +5,7 @@ namespace serialization { ...@@ -5,6 +5,7 @@ namespace serialization {
enum class GraphDumpFormat { enum class GraphDumpFormat {
FLATBUFFERS, FLATBUFFERS,
FLATBUFFERS_V2,
}; };
} // namespace serialization } // namespace serialization
......
...@@ -20,8 +20,12 @@ class FlatBufferBuilder; ...@@ -20,8 +20,12 @@ class FlatBufferBuilder;
} // namespace flatbuffers } // namespace flatbuffers
namespace mgb { namespace mgb {
namespace serialization { constexpr uint8_t CURRENT_VERSION = 2u;
constexpr uint8_t BEGIN_VERSION = 0u;
constexpr uint8_t VERSION_1 = 1u;
constexpr uint8_t VERSION_2 = 2u;
namespace serialization {
namespace fbs { namespace fbs {
template <typename T> template <typename T>
struct OperatorParamTraits; struct OperatorParamTraits;
...@@ -187,6 +191,9 @@ class OprLoadContext : public UserDataContainer::UserData { ...@@ -187,6 +191,9 @@ class OprLoadContext : public UserDataContainer::UserData {
friend class OprLoadContextRawPOD; friend class OprLoadContextRawPOD;
friend class OprLoadContextFlatBuffers; friend class OprLoadContextFlatBuffers;
protected:
virtual ~OprLoadContext() = default;
public: public:
//! get current computing graph //! get current computing graph
virtual ComputingGraph& graph() = 0; virtual ComputingGraph& graph() = 0;
...@@ -224,6 +231,12 @@ public: ...@@ -224,6 +231,12 @@ public:
*/ */
virtual SharedBuffer load_shared_buf_with_len() = 0; virtual SharedBuffer load_shared_buf_with_len() = 0;
/*!
* \brief get the serialization data of the current opr
*
*/
virtual const void* get_current_opr_data() { return nullptr; };
/*! /*!
* \brief read a param and check that tag matches * \brief read a param and check that tag matches
*/ */
......
#pragma once
#if MGB_ENABLE_FBS_SERIALIZATION
#include "megbrain/comp_node_env.h"
#include "megbrain/graph/exc_extra_info.h"
#include "megbrain/serialization/batched_device_value_loader.h"
#include "megbrain/serialization/internal/schema_v2_generated.h"
#include "megbrain/serialization/opr_load_dump.h"
#include "megbrain/serialization/serializer.h"
#define CAST_TO_FBS_V2_CTX(cvt) static_cast<GraphLoaderOSSV2::OprLoadContextImpl&>(ctx)
namespace mgb {
namespace serialization {
class GraphDumperOSSV2 final : public GraphDumper, OprDumpContextFlatBuffers {
const std::unique_ptr<OutputFile> m_file;
flatbuffers::FlatBufferBuilder m_builder;
DumpConfig m_config;
DumpResult m_cur_rst;
size_t m_nr_shared_tensor;
std::vector<std::pair<cg::OperatorNodeBase*, const OprRegistryV2*>> m_oprs_to_dump;
ThinHashMap<VarNode*, VarNode*> m_var_remove_in_dump;
//! set of output vars specified by user
ThinHashSet<VarNode*> m_output_vars;
std::unordered_set<std::string> m_used_input_names, m_used_param_names;
//! current opr to be dumped
cg::OperatorNodeBase* m_cur_opr = nullptr;
// Will be filled in dump_tensor
std::vector<flatbuffers::Offset<fbs::v2::Tensor>> m_cur_opr_tensor;
std::vector<flatbuffers::Offset<fbs::v2::Blob>> m_blobs;
std::vector<fbs::v2::OperatorParam> m_cur_opr_param_type;
std::vector<flatbuffers::Offset<void>> m_cur_opr_param;
std::vector<flatbuffers::Offset<fbs::v2::MiddleTensor>> m_model_middle_tensors;
ThinHashMap<VarNode*, size_t> m_var2midtensor_id;
SymbolVarArray converter_all_opr_to_compatiable(const SymbolVarArray& output_vars);
void init_oprs_to_dump(const SymbolVarArray& endpoints);
flatbuffers::Offset<fbs::v2::Metadata> build_metadata(const Metadata& metadata);
flatbuffers::Offset<fbs::v2::Operator> build_single_opr(
cg::OperatorNodeBase* opr, const OprRegistryV2* registry);
flatbuffers::Offset<fbs::DType> build_dtype(DType dtype);
public:
GraphDumperOSSV2(std::unique_ptr<OutputFile> file) : m_file{std::move(file)} {}
DumpResult dump(
const SymbolVarArray& output_vars, const DumpConfig& config = {},
const Metadata& metadata = {}) override;
const GraphDumpConfig& config() const override { return m_config; }
void dump_tensor(
const std::string& name, const HostTensorND& tensor,
TensorWriteMethod method) override;
void append_param(uint32_t type, uint32_t value) override {
static_assert(
std::is_same<uint32_t, flatbuffers::uoffset_t>::value,
"append_param depends on uoffset_t being uint32_t");
static_assert(
std::is_standard_layout<flatbuffers::Offset<void>>::value,
"append_param depends on flatbuffers::Offset having "
"standard memory layout");
mgb_assert(type != fbs::v2::OperatorParam_NONE);
m_cur_opr_param_type.emplace_back(static_cast<fbs::v2::OperatorParam>(type));
m_cur_opr_param.emplace_back(value);
}
flatbuffers::FlatBufferBuilder& builder() override { return m_builder; }
void dump_buf_with_len(const void* data, uint32_t size) override;
GraphDumpFormat format() const override { return GraphDumpFormat::FLATBUFFERS_V2; }
flatbuffers::Offset<fbs::v2::MiddleTensor> build_middle_tensor(const SymbolVar var);
flatbuffers::Offset<fbs::v2::OutputVar> build_output_var(const SymbolVar var);
flatbuffers::Offset<void> build_tensor_format(const TensorLayout::Format& format);
void set_current_opr(cg::OperatorNodeBase* cur_opr) { m_cur_opr = cur_opr; }
};
// ----------------------------- Loader --------------------------------------
class GraphLoaderOSSV2 final : public GraphLoader {
const LoadConfig* m_cur_load_config = nullptr;
std::unique_ptr<InputFile> m_file;
SharedBuffer m_model_buf{{}, 0};
const fbs::v2::Model* m_model;
SharedTensorIDMap m_shared_tensor_map;
uint32_t m_mgb_version = 0;
bool m_model_loaded = false;
void verify();
public:
class OprLoadContextImpl;
friend class OprLoadContextImpl;
GraphLoaderOSSV2(std::unique_ptr<InputFile> input_file)
: m_file{std::move(input_file)} {}
std::unique_ptr<InputFile> reset_file(std::unique_ptr<InputFile> file) override {
file.swap(m_file);
return file;
}
LoadResult load(const LoadConfig& config, bool rewind) override;
const SharedTensorIDMap& shared_tensor_id_map() const override {
mgb_assert(m_model_loaded, "graph not loaded yet");
return m_shared_tensor_map;
}
GraphDumpFormat format() const override { return GraphDumpFormat::FLATBUFFERS_V2; }
};
class GraphLoaderOSSV2::OprLoadContextImpl final : public OprLoadContextFlatBuffers {
GraphLoaderOSSV2* const m_loader;
size_t m_cur_shared_tensor_idx = 0;
std::shared_ptr<ComputingGraph> m_graph;
LoadResult::TensorMap m_tensor_map;
VarNodeArray m_id2varnode;
std::vector<const fbs::v2::MiddleTensor*> m_middle_tensors;
BatchedDeviceValueLoader m_device_value_loader;
const fbs::v2::Operator* m_current_opr;
size_t m_cur_opr_tensor_cnt;
size_t m_cur_opr_blob_cnt;
size_t m_cur_opr_param_cnt;
public:
ComputingGraph& graph() override { return *m_graph; }
const GraphLoadConfig& config() const override {
return *m_loader->m_cur_load_config;
}
std::shared_ptr<HostTensorND> load_tensor() override;
std::shared_ptr<DeviceTensorND> load_tensor_shared() override;
void load_single_opr(const fbs::v2::Operator* opr);
OprLoadContextImpl(GraphLoaderOSSV2* loader, uint32_t version)
: OprLoadContextFlatBuffers(version), m_loader{loader} {
m_graph = loader->m_cur_load_config->comp_graph;
if (!m_graph) {
m_graph = ComputingGraph::make();
}
auto maker = [this]() {
return std::shared_ptr<OprLoadContext>{
std::shared_ptr<OprLoadContext>{}, this};
};
auto got = m_graph->options().user_data.get_user_data_or_create<OprLoadContext>(
maker);
mgb_assert(got == this);
}
~OprLoadContextImpl() noexcept {
auto nr = m_graph->options().user_data.pop_user_data<OprLoadContext>();
mgb_assert(nr == 1);
}
Metadata load_metadata();
LoadResult load_oprs();
CompNode load_comp_node(const fbs::v2::CompNode* comp_node);
void load_middle_tensor();
const void* get_next_param(uint32_t enumv) override {
auto type = static_cast<fbs::v2::OperatorParam>(enumv);
if (m_cur_opr_param_cnt == 0) {
m_cur_opr_param_cnt++;
if (m_current_opr->param_type() == type) {
return m_current_opr->param();
} else {
mgb_throw(
SerializationError,
"The param type is not match when load the opr.");
}
}
mgb_throw(
SerializationError,
"When load multi param in one Operator, please use read_param(index) "
"interface. ");
}
std::string load_buf_with_len() override {
mgb_assert(
m_current_opr->custom_data() &&
m_cur_opr_blob_cnt < m_current_opr->custom_data()->size());
auto blob = m_current_opr->custom_data()->Get(m_cur_opr_blob_cnt++);
mgb_assert(blob && blob->data());
auto data = blob->data()->data();
return {reinterpret_cast<const char*>(data), blob->data()->size()};
}
SharedBuffer load_shared_buf_with_len() override {
mgb_assert(
m_current_opr->custom_data() &&
m_cur_opr_blob_cnt < m_current_opr->custom_data()->size());
auto blob = m_current_opr->custom_data()->Get(m_cur_opr_blob_cnt++);
mgb_assert(blob && blob->data());
auto size = blob->data()->size();
std::shared_ptr<uint8_t> shptr{
new uint8_t[size], [](uint8_t* p) { delete[] p; }};
memcpy(shptr.get(), blob->data()->data(), size);
return {std::move(shptr), size};
}
const void* get_current_opr_data() override {
return reinterpret_cast<const void*>(m_current_opr);
}
template <class T>
T read_param(int index) {
using SourceType = typename fbs::ParamConverter<T>::FlatBufferType;
auto enumv = fbs::OperatorParamTraits<SourceType>::enum_value;
auto type = static_cast<fbs::v2::OperatorParam>(enumv);
if (index == 0) {
mgb_assert(
m_current_opr->param_type() == type,
"Load param error, the param type is not right.");
return fbs::ParamConverter<T>::to_param(
static_cast<const SourceType*>(m_current_opr->param()));
} else {
int addition_index = index - 1;
if (addition_index >=
static_cast<int>(m_current_opr->additional_params()->size())) {
mgb_log_warn(
"Model has no addition param of index %d, just construct a "
"default one.",
addition_index);
} else {
mgb_assert(
m_current_opr->additional_params_type()->Get(addition_index) ==
type,
"Load param error, the addition param type is not right.");
return fbs::ParamConverter<T>::to_param(static_cast<const SourceType*>(
m_current_opr->additional_params()->Get(addition_index)));
}
}
}
};
} // namespace serialization
} // namespace mgb
#endif
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册