From b9a6932341e7513341aa286c787e1c7b672217b9 Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Thu, 12 May 2022 15:42:09 +0800 Subject: [PATCH] feat(imperative): channel default model format to fbs v2 GitOrigin-RevId: 6066516c313553568c706f4b0bf584d747c9dea8 --- imperative/python/megengine/core/tensor/megbrain_graph.py | 1 + imperative/python/megengine/jit/tracing.py | 5 +++-- imperative/python/src/graph_rt.cpp | 5 ++--- src/serialization/include/megbrain/serialization/sereg.h | 1 - 4 files changed, 6 insertions(+), 6 deletions(-) diff --git a/imperative/python/megengine/core/tensor/megbrain_graph.py b/imperative/python/megengine/core/tensor/megbrain_graph.py index 5dc0f8f0c..774e9df4b 100644 --- a/imperative/python/megengine/core/tensor/megbrain_graph.py +++ b/imperative/python/megengine/core/tensor/megbrain_graph.py @@ -430,6 +430,7 @@ def dump_graph( dump_format_map = { None: None, + "FBS_V2": SerializationFormat.FBS_V2, "FBS": SerializationFormat.FBS, } dump_format = dump_format_map[dump_format] diff --git a/imperative/python/megengine/jit/tracing.py b/imperative/python/megengine/jit/tracing.py index 289295c9e..cca234806 100644 --- a/imperative/python/megengine/jit/tracing.py +++ b/imperative/python/megengine/jit/tracing.py @@ -613,8 +613,9 @@ class trace: resize_input: whether resize input image to fit input var shape. input_transform: a python expression to transform the input data. Example: data / np.std(data) - dump_format: using different dump formats. the open source MegEngine defaults to the FBS - format. internal MegEngine have a choice of FBS and internal proprietary formats + dump_format: using different dump formats. the open source MegEngine + defaults to the FBS_V2 format, there are two format FBS_V2 and FBS to choose, + internal MegEngine have an other choice of internal proprietary formats Keyword Arguments: diff --git a/imperative/python/src/graph_rt.cpp b/imperative/python/src/graph_rt.cpp index beabf78e3..332ccc707 100644 --- a/imperative/python/src/graph_rt.cpp +++ b/imperative/python/src/graph_rt.cpp @@ -308,6 +308,7 @@ void init_graph_rt(py::module m) { py::enum_<_SerializationFormat>(m, "SerializationFormat") .value("FBS", _SerializationFormat::FLATBUFFERS) + .value("FBS_V2", _SerializationFormat::FLATBUFFERS_V2) .export_values(); m.def("optimize_for_inference", @@ -384,11 +385,9 @@ void init_graph_rt(py::module m) { std::optional<_SerializationFormat> dump_format, py::list& stat, py::list& inputs, py::list& outputs, py::list& params) { std::vector buf; - ser::GraphDumpFormat format; + ser::GraphDumpFormat format = ser::GraphDumpFormat::FLATBUFFERS_V2; if (dump_format.has_value()) { format = dump_format.value(); - } else { - format = {}; } auto dumper = ser::GraphDumper::make( ser::OutputFile::make_vector_proxy(&buf), format); diff --git a/src/serialization/include/megbrain/serialization/sereg.h b/src/serialization/include/megbrain/serialization/sereg.h index 0c59d1b53..4849b0451 100644 --- a/src/serialization/include/megbrain/serialization/sereg.h +++ b/src/serialization/include/megbrain/serialization/sereg.h @@ -3,7 +3,6 @@ #include "megbrain/serialization/opr_load_dump.h" #include "megbrain/serialization/opr_registry.h" #include "megbrain/serialization/opr_shallow_copy.h" -#include "megbrain/serialization/oss_opr_load_dump.h" #include "megbrain/utils/hash_ct.h" namespace mgb { -- GitLab