From accb2d8d47732c71493623e1f6805720005f32fd Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Wed, 29 Sep 2021 17:15:55 +0800 Subject: [PATCH] fix(mgb/serialize): fix flatbuffer compatibility issues GitOrigin-RevId: e4771d6bc43a987a7fe725b5949b77da8769815d --- dnn/scripts/gen_flatbuffers_converter.py | 3 -- dnn/scripts/gen_flatbuffers_schema.py | 3 -- dnn/scripts/gen_tablegen.py | 3 -- src/opr/impl/basic_arith.sereg.h | 44 ++++++++++++++++- src/opr/test/blas.cpp | 37 +++++++++++++- src/serialization/impl/serializer_oss.cpp | 48 +++++++++++++++---- .../include/megbrain/serialization/sereg.h | 2 +- 7 files changed, 120 insertions(+), 20 deletions(-) diff --git a/dnn/scripts/gen_flatbuffers_converter.py b/dnn/scripts/gen_flatbuffers_converter.py index 45e806fe..14db38f9 100755 --- a/dnn/scripts/gen_flatbuffers_converter.py +++ b/dnn/scripts/gen_flatbuffers_converter.py @@ -33,9 +33,6 @@ class ConverterWriter(IndentWriterBase): self._last_param = p self._param_fields = [] self._fb_fields = ["builder"] - if p.is_legacy: - self._skip_current_param = True - return self._write("template<>\nstruct ParamConverter {", p.name, indent=1) self._write("using MegDNNType = megdnn::param::%s;", p.name) diff --git a/dnn/scripts/gen_flatbuffers_schema.py b/dnn/scripts/gen_flatbuffers_schema.py index d6165d0d..11805b14 100755 --- a/dnn/scripts/gen_flatbuffers_schema.py +++ b/dnn/scripts/gen_flatbuffers_schema.py @@ -80,9 +80,6 @@ class FlatBuffersWriter(IndentWriterBase): def _on_param_begin(self, p): self._last_param = p self._cur_const_val = {} - if p.is_legacy: - self._skip_current_param = True - return self._write_doc(p.name) self._write("table %s {", p.name, indent=1) diff --git a/dnn/scripts/gen_tablegen.py b/dnn/scripts/gen_tablegen.py index 911cf749..4de6eb65 100755 --- a/dnn/scripts/gen_tablegen.py +++ b/dnn/scripts/gen_tablegen.py @@ -52,9 +52,6 @@ class ConverterWriter(IndentWriterBase): def _on_param_begin(self, p): self._last_param = p - if p.is_legacy: - 
self._skip_current_param = True - return self._packed = True self._current_tparams = [] self._const = set() diff --git a/src/opr/impl/basic_arith.sereg.h b/src/opr/impl/basic_arith.sereg.h index aced1fed..b7746741 100644 --- a/src/opr/impl/basic_arith.sereg.h +++ b/src/opr/impl/basic_arith.sereg.h @@ -62,6 +62,37 @@ struct PersistentAddUpdateParam { } // namespace opr_add_update +// Old SerializedDType used in MegBrain 7.22.0 - 7.23.1 +// Should be kept as-is even if there are new dtypes. +struct SerializedDTypeV1 { + static constexpr uint32_t TAG = megdnn::param::FakeSerializedDType::TAG; + DTypeEnum enumv; + union { + megdnn::DTypeParam Quantized8Asymm; + megdnn::DTypeParam QuantizedS8; + megdnn::DTypeParam QuantizedS32; + } param; + + operator DType() const { + switch (enumv) { +#define cb(_dt) \ + case DTypeEnum::_dt: \ + return DType::from_enum(enumv); + MEGDNN_FOREACH_DTYPE_NAME(cb) +#undef cb + case DTypeEnum::Quantized8Asymm: + return dtype::Quantized8Asymm{param.Quantized8Asymm}; + case DTypeEnum::QuantizedS8: + return dtype::QuantizedS8{param.QuantizedS8}; + case DTypeEnum::QuantizedS32: + return dtype::QuantizedS32{param.QuantizedS32}; + default: + mgb_assert( + false, "unexpected old serialized dtype: invalid enumv %d", + static_cast(enumv)); + } + } +}; template <> struct OprPersistentParam { using Param = opr_add_update::PersistentAddUpdateParam; @@ -104,7 +135,18 @@ struct ParamConverter { return fbs::intl::build_dtype(builder, dtype); } }; -} // namespace fbs +template <> +struct ParamConverter { + using FlatBufferType = SerializedDTypeV1; + static SerializedDTypeV1 to_param(const FlatBufferType* fb) { + mgb_assert( + false, + "You are calling SerializedDTypeV1 in flatbuffer, you should not call " + "here, this code is just to avoid compiling errors, but not be used in " + "flatbuffer."); + } +}; +}; // namespace fbs #endif template <> diff --git a/src/opr/test/blas.cpp b/src/opr/test/blas.cpp index 7be6a410..e94e96a6 100644 --- 
a/src/opr/test/blas.cpp +++ b/src/opr/test/blas.cpp @@ -16,6 +16,7 @@ #include "megbrain/opr/io.h" #include "megbrain/opr/tensor_gen.h" #include "megbrain/opr/tensor_manip.h" +#include "megbrain/serialization/serializer.h" #include "megbrain/test/autocheck.h" #include "megbrain/test/helper.h" #include "megbrain/test/megdnn_helper.h" @@ -907,5 +908,39 @@ TEST(TestOprBlas, MatrixMulExePolicy) { } #endif +#if MGB_ENABLE_FBS_SERIALIZATION +TEST(TestOprDNN, MatrixMulSerialization) { + using namespace serialization; + + auto fname = output_file("MatrixMulSerializationTest"); + auto dump = [&]() { + opr::MatrixMul::Param param; + + auto cn = CompNode::load("cpu0"); + auto graph = ComputingGraph::make(); + HostTensorND a_host{cn, {24, 24}, dtype::Float32()}; + HostTensorND b_host{cn, {24, 24}, dtype::Float32()}; + auto a = opr::ImmutableTensor::make(*graph, a_host); + auto b = opr::ImmutableTensor::make(*graph, b_host); + auto opr = opr::MatrixMul::make(a, b, param, {}); + auto dumper = GraphDumper::make( + OutputFile::make_fs(fname.c_str()), GraphDumpFormat::FLATBUFFERS); + auto rst = dumper->dump({opr}); + ASSERT_EQ(rst.outputs.size(), 1u); + }; + + auto load = [&]() { + auto loader = GraphLoader::make( + InputFile::make_fs(fname.c_str()), GraphDumpFormat::FLATBUFFERS); + auto rst = loader->load(); + ASSERT_EQ(rst.output_var_list.size(), 1u); + auto opr = rst.output_var_list[0].node()->owner_opr(); + ASSERT_TRUE(opr->same_type()); + }; + + dump(); + load(); +} +#endif // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} -// +// \ No newline at end of file diff --git a/src/serialization/impl/serializer_oss.cpp b/src/serialization/impl/serializer_oss.cpp index b2430237..b2aa7370 100644 --- a/src/serialization/impl/serializer_oss.cpp +++ b/src/serialization/impl/serializer_oss.cpp @@ -47,7 +47,13 @@ namespace { constexpr uint32_t MGB_VERSION = (MGE_MAJOR * 1000 + MGE_MINOR) * 100 + MGE_PATCH; -constexpr uint32_t MGB_MAGIC = 0x5342474D; +constexpr uint32_t 
MGB_MAGIC = 0x4342474D; +// In order to maintain compatibility and to allow old models to be loaded, we keep +// the old magic(MAGIC_V0) value and create a new magic(MGB_MAGIC) +constexpr uint32_t MAGIC_V0 = 0x5342474D; +// Used to judge whether Magic is old or new, the new magic(MGB_MAGIC) is true and the +// old magic(MAGIC_V0) is false. +bool magic_compare = true; template bool contains_any_in_set(const SmallVector& list, const ThinHashSet& set) { @@ -79,6 +85,18 @@ void check_tensor_value_valid(const std::string& name, const HostTensorND& tenso } } +//! feature bits for backward compatibility; default value should be 0 +struct FeatureBits64 { + //! reserved for new fields + uint64_t : 64; + static void write(OutputFile& fout) { + static_assert(sizeof(FeatureBits64) == 8, "bad feature bits"); + FeatureBits64 fb64; + memset(&fb64, 0, sizeof(fb64)); + fout.write(&fb64, 8); + } +}; + } // namespace namespace mgb { @@ -266,7 +284,7 @@ flatbuffers::Offset GraphDumperOSS::build_single_opr( } fbs::OperatorBuilder builder(m_builder); - builder.add_type_id(registry->unversioned_type_id); + builder.add_type_id(registry->persist_type_id); builder.add_inputs(inputs); if (m_config.keep_opr_priority) { builder.add_priority(opr->node_prop().attribute().priority); @@ -322,6 +340,8 @@ GraphDumper::DumpResult GraphDumperOSS::dump( uint32_t magic = MGB_MAGIC; m_file->write(&magic, sizeof(magic)); + // write FeatureBits + FeatureBits64::write(*m_file); // Padding uint32_t reserved = 0; m_file->write(&reserved, sizeof(reserved)); @@ -459,6 +479,7 @@ void GraphDumperOSS::dump_buf_with_len(const void* data, uint32_t size) { class GraphLoaderOSS final : public GraphLoader { const LoadConfig* m_cur_load_config = nullptr; std::unique_ptr m_file; + FeatureBits64 m_feature_bits; SharedBuffer m_graph_buf{{}, 0}; const fbs::Graph* m_graph; SharedTensorIDMap m_shared_tensor_map; @@ -754,8 +775,12 @@ void GraphLoaderOSS::OprLoadContextImpl::load_single_opr(const fbs::Operator* fb }
config.comp_node_arr(comp_node_arr); } - - auto registry = OprRegistry::find_by_unversioned_id(fbopr->type_id()); + const OprRegistry* registry; + if (magic_compare) { + registry = OprRegistry::find_by_id(fbopr->type_id()); + } else { + registry = OprRegistry::find_by_unversioned_id(fbopr->type_id()); + } mgb_throw_if( !registry, SerializationError, "failed to find opr with type %s, use python env " @@ -841,10 +866,17 @@ GraphLoader::LoadResult GraphLoaderOSS::load(const LoadConfig& config, bool rewi uint32_t magic; m_file->read(&magic, sizeof(magic)); mgb_throw_if( - magic != MGB_MAGIC, SerializationError, - "wrong magic: wanted %#08x, actual %#08x (not a invalid fbs " + (magic != MGB_MAGIC) && (magic != MAGIC_V0), SerializationError, + "wrong magic: wanted %#08x or %#08x, actual %#08x (not a invalid fbs " "model?)", - MGB_MAGIC, magic); + MGB_MAGIC, MAGIC_V0, magic); + if (magic == MGB_MAGIC) { + // read FeatureBits + magic_compare = true; + m_file->read(&m_feature_bits, 8); + } else { + magic_compare = false; + } m_file->skip(4); uint64_t offset_to_fbs; @@ -929,7 +961,7 @@ bool is_fbs_file(InputFile& file) { uint64_t magic_with_reserved = 0; file.read(&magic_with_reserved, sizeof(magic_with_reserved)); file.skip(-sizeof(magic_with_reserved)); - return magic_with_reserved == MGB_MAGIC; + return (magic_with_reserved == MGB_MAGIC) || (magic_with_reserved == MAGIC_V0); } } // namespace serialization diff --git a/src/serialization/include/megbrain/serialization/sereg.h b/src/serialization/include/megbrain/serialization/sereg.h index c1ee6cf9..348a7125 100644 --- a/src/serialization/include/megbrain/serialization/sereg.h +++ b/src/serialization/include/megbrain/serialization/sereg.h @@ -199,7 +199,7 @@ struct IsComplete : std::true_type {}; static ser::OprWithOutputAccessor compat_loader( \ ser::OprLoadContext& ctx, const mgb::cg::VarNodeArray& inputs, \ const mgb::cg::OperatorNodeConfig& config) { \ - auto&& ctx_ = static_cast(ctx); \ + auto&& ctx_ = 
static_cast(ctx); \ return ser::OprWithOutputAccessor(_load(ctx_, inputs, config), _accessor); \ } \ static void entry() { \ -- GitLab