提交 accb2d8d 编写于 作者: M Megvii Engine Team

fix(mgb/serialize): fix flatbuffer compatibility issues

GitOrigin-RevId: e4771d6bc43a987a7fe725b5949b77da8769815d
上级 5b1383e0
......@@ -33,9 +33,6 @@ class ConverterWriter(IndentWriterBase):
self._last_param = p
self._param_fields = []
self._fb_fields = ["builder"]
if p.is_legacy:
self._skip_current_param = True
return
self._write("template<>\nstruct ParamConverter<megdnn::param::%s> {",
p.name, indent=1)
self._write("using MegDNNType = megdnn::param::%s;", p.name)
......
......@@ -80,9 +80,6 @@ class FlatBuffersWriter(IndentWriterBase):
def _on_param_begin(self, p):
self._last_param = p
self._cur_const_val = {}
if p.is_legacy:
self._skip_current_param = True
return
self._write_doc(p.name)
self._write("table %s {", p.name, indent=1)
......
......@@ -52,9 +52,6 @@ class ConverterWriter(IndentWriterBase):
def _on_param_begin(self, p):
self._last_param = p
if p.is_legacy:
self._skip_current_param = True
return
self._packed = True
self._current_tparams = []
self._const = set()
......
......@@ -62,6 +62,37 @@ struct PersistentAddUpdateParam {
} // namespace opr_add_update
// Old SerializedDType used in MegBrain 7.22.0 - 7.23.1
// Should be kept as-is even if there are new dtypes.
struct SerializedDTypeV1 {
static constexpr uint32_t TAG = megdnn::param::FakeSerializedDType::TAG;
DTypeEnum enumv;
union {
megdnn::DTypeParam<dtype::Quantized8Asymm> Quantized8Asymm;
megdnn::DTypeParam<dtype::QuantizedS8> QuantizedS8;
megdnn::DTypeParam<dtype::QuantizedS32> QuantizedS32;
} param;
operator DType() const {
switch (enumv) {
#define cb(_dt) \
case DTypeEnum::_dt: \
return DType::from_enum(enumv);
MEGDNN_FOREACH_DTYPE_NAME(cb)
#undef cb
case DTypeEnum::Quantized8Asymm:
return dtype::Quantized8Asymm{param.Quantized8Asymm};
case DTypeEnum::QuantizedS8:
return dtype::QuantizedS8{param.QuantizedS8};
case DTypeEnum::QuantizedS32:
return dtype::QuantizedS32{param.QuantizedS32};
default:
mgb_assert(
false, "unexpected old serialized dtype: invalid enumv %d",
static_cast<uint32_t>(enumv));
}
}
};
template <>
struct OprPersistentParam<opr::AddUpdate> {
using Param = opr_add_update::PersistentAddUpdateParam;
......@@ -104,7 +135,18 @@ struct ParamConverter<megdnn::DType> {
return fbs::intl::build_dtype(builder, dtype);
}
};
} // namespace fbs
template <>
struct ParamConverter<SerializedDTypeV1> {
using FlatBufferType = SerializedDTypeV1;
static SerializedDTypeV1 to_param(const FlatBufferType* fb) {
mgb_assert(
false,
"You are calling SerializedDTypeV1 in flatbuffer, you should not call "
"here, this code is just to avoid compiling errors, but not be used in "
"flatbuffer.");
}
};
}; // namespace fbs
#endif
template <>
......
......@@ -16,6 +16,7 @@
#include "megbrain/opr/io.h"
#include "megbrain/opr/tensor_gen.h"
#include "megbrain/opr/tensor_manip.h"
#include "megbrain/serialization/serializer.h"
#include "megbrain/test/autocheck.h"
#include "megbrain/test/helper.h"
#include "megbrain/test/megdnn_helper.h"
......@@ -907,5 +908,39 @@ TEST(TestOprBlas, MatrixMulExePolicy) {
}
#endif
#if MGB_ENABLE_FBS_SERIALIZATION
TEST(TestOprDNN, MatrixMulSerialization) {
using namespace serialization;
auto fname = output_file("MatrixMulSerializationTest");
auto dump = [&]() {
opr::MatrixMul::Param param;
auto cn = CompNode::load("cpu0");
auto graph = ComputingGraph::make();
HostTensorND a_host{cn, {24, 24}, dtype::Float32()};
HostTensorND b_host{cn, {24, 24}, dtype::Float32()};
auto a = opr::ImmutableTensor::make(*graph, a_host);
auto b = opr::ImmutableTensor::make(*graph, b_host);
auto opr = opr::MatrixMul::make(a, b, param, {});
auto dumper = GraphDumper::make(
OutputFile::make_fs(fname.c_str()), GraphDumpFormat::FLATBUFFERS);
auto rst = dumper->dump({opr});
ASSERT_EQ(rst.outputs.size(), 1u);
};
auto load = [&]() {
auto loader = GraphLoader::make(
InputFile::make_fs(fname.c_str()), GraphDumpFormat::FLATBUFFERS);
auto rst = loader->load();
ASSERT_EQ(rst.output_var_list.size(), 1u);
auto opr = rst.output_var_list[0].node()->owner_opr();
ASSERT_TRUE(opr->same_type<opr::MatrixMul>());
};
dump();
load();
}
#endif
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
//
\ No newline at end of file
......@@ -47,7 +47,13 @@ namespace {
constexpr uint32_t MGB_VERSION = (MGE_MAJOR * 1000 + MGE_MINOR) * 100 + MGE_PATCH;
constexpr uint32_t MGB_MAGIC = 0x5342474D;
constexpr uint32_t MGB_MAGIC = 0x4342474D;
// In order to maintain compatibility and to allow old models to be loaded, we keep
// the old magic(MAGIC_V0) value and creat a new magic(MGB_MAGIC)
constexpr uint32_t MAGIC_V0 = 0x5342474D;
// Used to judge whether Magic is old or new, the new magic(MGB_MAGIC) is true and the
// old magic(MAGIC_V0) is false.
bool magic_compare = true;
template <typename T>
bool contains_any_in_set(const SmallVector<T>& list, const ThinHashSet<T>& set) {
......@@ -79,6 +85,18 @@ void check_tensor_value_valid(const std::string& name, const HostTensorND& tenso
}
}
//! feature bits for backward compatibility; default value should be 0
struct FeatureBits64 {
//! reserved for new fields
uint64_t : 64;
static void write(OutputFile& fout) {
static_assert(sizeof(FeatureBits64) == 8, "bad feature bits");
FeatureBits64 fb64;
memset(&fb64, 0, sizeof(fb64));
fout.write(&fb64, 8);
}
};
} // namespace
namespace mgb {
......@@ -266,7 +284,7 @@ flatbuffers::Offset<fbs::Operator> GraphDumperOSS::build_single_opr(
}
fbs::OperatorBuilder builder(m_builder);
builder.add_type_id(registry->unversioned_type_id);
builder.add_type_id(registry->persist_type_id);
builder.add_inputs(inputs);
if (m_config.keep_opr_priority) {
builder.add_priority(opr->node_prop().attribute().priority);
......@@ -322,6 +340,8 @@ GraphDumper::DumpResult GraphDumperOSS::dump(
uint32_t magic = MGB_MAGIC;
m_file->write(&magic, sizeof(magic));
// write FeatureBits
FeatureBits64::write(*m_file);
// Padding
uint32_t reserved = 0;
m_file->write(&reserved, sizeof(reserved));
......@@ -459,6 +479,7 @@ void GraphDumperOSS::dump_buf_with_len(const void* data, uint32_t size) {
class GraphLoaderOSS final : public GraphLoader {
const LoadConfig* m_cur_load_config = nullptr;
std::unique_ptr<InputFile> m_file;
FeatureBits64 m_feature_bits;
SharedBuffer m_graph_buf{{}, 0};
const fbs::Graph* m_graph;
SharedTensorIDMap m_shared_tensor_map;
......@@ -754,8 +775,12 @@ void GraphLoaderOSS::OprLoadContextImpl::load_single_opr(const fbs::Operator* fb
}
config.comp_node_arr(comp_node_arr);
}
auto registry = OprRegistry::find_by_unversioned_id(fbopr->type_id());
const OprRegistry* registry;
if (magic_compare) {
registry = OprRegistry::find_by_id(fbopr->type_id());
} else {
registry = OprRegistry::find_by_unversioned_id(fbopr->type_id());
}
mgb_throw_if(
!registry, SerializationError,
"failed to find opr with type %s, use python env "
......@@ -841,10 +866,17 @@ GraphLoader::LoadResult GraphLoaderOSS::load(const LoadConfig& config, bool rewi
uint32_t magic;
m_file->read(&magic, sizeof(magic));
mgb_throw_if(
magic != MGB_MAGIC, SerializationError,
"wrong magic: wanted %#08x, actual %#08x (not a invalid fbs "
(magic != MGB_MAGIC) && (magic != MAGIC_V0), SerializationError,
"wrong magic: wanted %#08x or %#08x, actual %#08x (not a invalid fbs "
"model?)",
MGB_MAGIC, magic);
MGB_MAGIC, MAGIC_V0, magic);
if (magic == MGB_MAGIC) {
// read FeatureBits
magic_compare = true;
m_file->read(&m_feature_bits, 8);
} else {
magic_compare = false;
}
m_file->skip(4);
uint64_t offset_to_fbs;
......@@ -929,7 +961,7 @@ bool is_fbs_file(InputFile& file) {
uint64_t magic_with_reserved = 0;
file.read(&magic_with_reserved, sizeof(magic_with_reserved));
file.skip(-sizeof(magic_with_reserved));
return magic_with_reserved == MGB_MAGIC;
return (magic_with_reserved == MGB_MAGIC) || (magic_with_reserved == MAGIC_V0);
}
} // namespace serialization
......
......@@ -199,7 +199,7 @@ struct IsComplete<T, decltype(void(sizeof(T)))> : std::true_type {};
static ser::OprWithOutputAccessor compat_loader( \
ser::OprLoadContext& ctx, const mgb::cg::VarNodeArray& inputs, \
const mgb::cg::OperatorNodeConfig& config) { \
auto&& ctx_ = static_cast<ser::OprLoadContextRawPOD&>(ctx); \
auto&& ctx_ = static_cast<ser::OprLoadContext&>(ctx); \
return ser::OprWithOutputAccessor(_load(ctx_, inputs, config), _accessor); \
} \
static void entry() { \
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册