diff --git a/imperative/python/megengine/core/tensor/megbrain_graph.py b/imperative/python/megengine/core/tensor/megbrain_graph.py index 774e9df4b68b3b71e7f23b70ef48350ad0c90050..3053c4342846063e45394972169f01bdb9db71cf 100644 --- a/imperative/python/megengine/core/tensor/megbrain_graph.py +++ b/imperative/python/megengine/core/tensor/megbrain_graph.py @@ -367,10 +367,12 @@ def dump_graph( keep_opr_name: bool = False, keep_param_name: bool = False, keep_opr_priority: bool = False, + no_change_graph: bool = False, strip_info_file=None, append_json=False, metadata=None, - dump_format=None + dump_format=None, + model_version: int = 2 ) -> Tuple[bytes, CompGraphDumpResult]: r"""serialize the computing graph of `output_vars` and get byte result. @@ -386,12 +388,22 @@ def dump_graph( keep_param_name: whether to keep param names, so param values can be easily manipulated after loading model keep_opr_priority: whether to keep priority setting for operators + no_change_graph: whether to change the compute graph when dump, for + model compatibility, some operators will convert to its compatible + format in this version. + + * if set False, some operators maybe convert to other operator for + compatibility, all operators will ensure compatibility. + * if set True, no operator will change in the graph when dump. + strip_info_file: a string for path or a file handler. if is not None, then the dump information for code strip would be written to ``strip_info_file`` append_json: will be check when `strip_info_file` is not None. if set true, the information for code strip will be append to strip_info_file. if set false, will rewrite strip_info_file dump_format: using different dump formats. + model_version: the model version of "FBS_V2", begin with version 2, this + works only when dump format is "FBS_V2". Note: The underlying C++ API only accepts a var list. 
If a dict is given, @@ -441,8 +453,10 @@ def dump_graph( keep_opr_name, keep_param_name, keep_opr_priority, + no_change_graph, metadata, dump_format, + model_version, stat, inputs, outputs, diff --git a/imperative/python/megengine/jit/tracing.py b/imperative/python/megengine/jit/tracing.py index b6f5493484e71fe397d93934f8ef40b6f8659e7f..1d1fc5df70767157ff38f1093ec3db3dd1f3d035 100644 --- a/imperative/python/megengine/jit/tracing.py +++ b/imperative/python/megengine/jit/tracing.py @@ -549,6 +549,7 @@ class trace: keep_opr_name: bool = False, keep_param_name: bool = False, keep_opr_priority: bool = False, + no_change_graph: bool = False, strip_info_file=None, append_json=False, optimize_for_inference=True, @@ -562,6 +563,7 @@ class trace: resize_input=False, input_transform=None, dump_format: str = None, + model_version: int = 2, **kwargs ): r"""Serializes trace to file system. @@ -583,6 +585,14 @@ class trace: keep_param_name: whether to keep param names, so param values can be easily manipulated after loading model keep_opr_priority: whether to keep priority setting for operators + no_change_graph: whether to change the compute graph when dump, for + model compatibility, some operators will convert to its compatible + format in this version. + + * if set False, some operators maybe convert to other operator for + compatibility, all operators will ensure compatibility. + * if set True, no operator will change in the graph when dump. + strip_info_file: a string for path or a file handler. if is not None, then the dump information for code strip would be written to ``strip_info_file`` append_json: will be check when `strip_info_file` is not None. if set @@ -616,6 +626,9 @@ class trace: dump_format: using different dump formats. 
the open source MegEngine defaults to the FBS_V2 format, there are two format FBS_V2 and FBS to choose, internal MegEngine have an other choice of internal proprietary formats + model_version: the model version of FBS_V2, begin with version 2, this + works only when dump format is FBS_V2. + Keyword Arguments: @@ -762,10 +775,12 @@ class trace: keep_opr_name=keep_opr_name, keep_param_name=keep_param_name, keep_opr_priority=keep_opr_priority, + no_change_graph=no_change_graph, strip_info_file=strip_info_file, append_json=append_json, metadata=metadata, dump_format=dump_format, + model_version=model_version, ) file.write(dump_content) diff --git a/imperative/python/src/graph_rt.cpp b/imperative/python/src/graph_rt.cpp index 332ccc707314381d9b878d508d3f8cae720cebeb..754f3f35e00a4e292c9626d1555cf2a29981a9ea 100644 --- a/imperative/python/src/graph_rt.cpp +++ b/imperative/python/src/graph_rt.cpp @@ -381,20 +381,26 @@ void init_graph_rt(py::module m) { m.def("dump_graph", [](const std::vector& dest_vars, int keep_var_name, bool keep_opr_name, bool keep_param_name, bool keep_opr_priority, - std::optional<_SerializationMetadata> metadata, - std::optional<_SerializationFormat> dump_format, py::list& stat, - py::list& inputs, py::list& outputs, py::list& params) { + bool no_change_graph, std::optional<_SerializationMetadata> metadata, + std::optional<_SerializationFormat> dump_format, + std::optional model_version, py::list& stat, py::list& inputs, + py::list& outputs, py::list& params) { std::vector buf; ser::GraphDumpFormat format = ser::GraphDumpFormat::FLATBUFFERS_V2; + int version = 2; if (dump_format.has_value()) { format = dump_format.value(); } + if (model_version.has_value()) { + version = model_version.value(); + } auto dumper = ser::GraphDumper::make( - ser::OutputFile::make_vector_proxy(&buf), format); + ser::OutputFile::make_vector_proxy(&buf), format, version); SymbolVarArray symvars(dest_vars.begin(), dest_vars.end()); ser::GraphDumper::DumpConfig config{ 
keep_var_name, keep_param_name, keep_opr_priority, keep_opr_name}; + config.no_change_graph = no_change_graph; ser::GraphDumper::DumpResult rst; if (metadata) diff --git a/src/opr/impl/dnn/dnn.sereg.v2.h b/src/opr/impl/dnn/dnn.sereg.v2.h index cb3266730bff92dbde5dc9d1045a673a40c3fcf1..bd8b467d1f7162e3f84edad6f7376740e927afae 100644 --- a/src/opr/impl/dnn/dnn.sereg.v2.h +++ b/src/opr/impl/dnn/dnn.sereg.v2.h @@ -21,6 +21,13 @@ struct OprLoadDumpImplV2 { ctx.write_param(opr.cast_final_safe().param()); } + /** This converter is just an example of Operator serialization compatibility, + * for this situation: when the softmax Operator is optimized by + * fusing the elemwise and reduce into a big Operator, but the whole softmax + * Operator can't be recognized by an old version; for model + * compatibility the softmax Operator should be converted to elemwise and + * reduce Operators when dumping the model + */ static cg::OperatorNodeBase* replace_opr( cg::OperatorNodeBase* opr, const VarNodeArray& inputs) { int32_t axis = opr->cast_final_safe().param().axis; @@ -196,9 +203,11 @@ namespace opr { #define SERGE_OPR_V2_NO_CONVERTER(_cls, _arity) \ MGB_SEREG_OPR_V2(_cls, _arity, nullptr, VERSION_2, CURRENT_VERSION); -SERGE_OPR_V2_CONVERTER(
this is just an example of Operator compatibility +/*SERGE_OPR_V2_CONVERTER( Softmax, 1, - (mgb::serialization::OprLoadDumpImplV2::replace_opr)); + (mgb::serialization::OprLoadDumpImplV2::replace_opr));*/ +SERGE_OPR_V2_NO_CONVERTER(Softmax, 1) SERGE_OPR_V2_NO_CONVERTER(ConvBiasForward, 0) SERGE_OPR_V2_NO_CONVERTER(BatchConvBiasForward, 0); diff --git a/src/serialization/impl/serializer.cpp b/src/serialization/impl/serializer.cpp index 17fbc6227d01d5b40b51b4c453c26f250c3866ee..5ec543016313e2bf5a30e25c0254649a60c23935 100644 --- a/src/serialization/impl/serializer.cpp +++ b/src/serialization/impl/serializer.cpp @@ -59,7 +59,8 @@ std::unique_ptr make_fbs_loader(std::unique_ptr file); std::unique_ptr make_fbs_dumper(std::unique_ptr file); std::unique_ptr make_fbs_v2_loader(std::unique_ptr file); -std::unique_ptr make_fbs_v2_dumper(std::unique_ptr file); +std::unique_ptr make_fbs_v2_dumper( + std::unique_ptr file, int version); bool is_fbs_file(InputFile& file); bool is_fbs_v2_file(InputFile& file); @@ -72,7 +73,7 @@ bool GraphDumper::should_remove_in_dump(cg::OperatorNodeBase* opr) { } std::unique_ptr GraphDumper::make( - std::unique_ptr file, GraphDumpFormat format) { + std::unique_ptr file, GraphDumpFormat format, int version) { switch (format) { case GraphDumpFormat::FLATBUFFERS: #if MGB_ENABLE_FBS_SERIALIZATION @@ -81,7 +82,7 @@ std::unique_ptr GraphDumper::make( MGB_FALLTHRU case GraphDumpFormat::FLATBUFFERS_V2: #if MGB_ENABLE_FBS_SERIALIZATION - return make_fbs_v2_dumper(std::move(file)); + return make_fbs_v2_dumper(std::move(file), version); #endif MGB_FALLTHRU default: diff --git a/src/serialization/impl/serializer_oss_v2.cpp b/src/serialization/impl/serializer_oss_v2.cpp index 4903c933f57a8eb4789347962c5ca268a1d1cbe4..e9d49a1c9ce8a7802b3d3ea882ec889786073136 100644 --- a/src/serialization/impl/serializer_oss_v2.cpp +++ b/src/serialization/impl/serializer_oss_v2.cpp @@ -194,7 +194,7 @@ void GraphDumperOSSV2::init_oprs_to_dump(const SymbolVarArray& endpoints) {
} } else { auto registry = OprRegistryV2::versioned_find_by_typeinfo( - opr->dyn_typeinfo(), CURRENT_VERSION); + opr->dyn_typeinfo(), m_version); if (!registry || !registry->dumper) { mgb_throw( cg::OperatorNodeExcExtraInfo::ExcMaker{opr}.make, @@ -202,6 +202,9 @@ void GraphDumperOSSV2::init_oprs_to_dump(const SymbolVarArray& endpoints) { "operator %s", opr->dyn_typeinfo()->name); } + mgb_assert( + registry->version <= m_version, + "The Operator version should less than model version"); m_oprs_to_dump.emplace_back(opr, registry); } }; @@ -352,7 +355,10 @@ GraphDumper::DumpResult GraphDumperOSSV2::dump( const Metadata& metadata) { mgb_throw_if(output_vars.empty(), SerializationError, "Can't dump empty graph"); - auto&& new_output_vars = converter_all_opr_to_compatiable(output_vars); + auto new_output_vars = output_vars; + if (!config.no_change_graph) { + new_output_vars = converter_all_opr_to_compatiable(output_vars); + } auto begin_pos = m_file->tell(); m_config = config; @@ -416,6 +422,7 @@ GraphDumper::DumpResult GraphDumperOSSV2::dump( fbs::v2::ModelBuilder model(m_builder); model.add_mge_version(MGB_VERSION); + model.add_model_version(m_version); model.add_oprs(fb_oprs); model.add_middle_tensors(fb_mid_tensor); model.add_output_vars_idx(fb_output_vars); @@ -694,10 +701,8 @@ void GraphLoaderOSSV2::OprLoadContextImpl::load_single_opr( OprRegistryV2::versioned_find_by_id(type_id, opr_version); mgb_throw_if( !registry, SerializationError, - "failed to find opr with type %s , use python env " - "config.dump_registered_oprs() to get a dict that maps from " - "opr id to opr name", - fbopr->type()->str().c_str()); + "failed to find opr with type %s and version %d.", + fbopr->type()->str().c_str(), opr_version); // load inputs VarNodeArray inputs; @@ -811,12 +816,19 @@ GraphLoader::LoadResult GraphLoaderOSSV2::load(const LoadConfig& config, bool re m_model = fbs::v2::GetModel(m_model_buf.data()); m_mgb_version = m_model->mge_version(); + m_model_version = 
m_model->model_version(); if (m_model->mge_version() > MGB_VERSION) { mgb_log_warn( "loading model from future runtime: version=%u " "model_version=%u", MGB_VERSION, m_model->mge_version()); } + if (m_model_version > CURRENT_VERSION) { + mgb_log_warn( + "The model dump in the future version %d, try to load it, maybe case " + "load error in %d version.", + m_model_version, CURRENT_VERSION); + } if (m_shared_tensor_map.empty()) { m_shared_tensor_map.resize(m_model->nr_shared_tensor()); @@ -845,8 +857,9 @@ GraphLoader::LoadResult GraphLoaderOSSV2::load(const LoadConfig& config, bool re return result; } -std::unique_ptr make_fbs_v2_dumper(std::unique_ptr file) { - return std::make_unique(std::move(file)); +std::unique_ptr make_fbs_v2_dumper( + std::unique_ptr file, int version) { + return std::make_unique(std::move(file), version); } std::unique_ptr make_fbs_v2_loader(std::unique_ptr file) { diff --git a/src/serialization/include/megbrain/serialization/load_dump_config.h b/src/serialization/include/megbrain/serialization/load_dump_config.h index 8a2d1ead05d4fea501a85e07331c5693d302cf26..401dea8f597a31a7dbdaceb8ca7a516dccf5442c 100644 --- a/src/serialization/include/megbrain/serialization/load_dump_config.h +++ b/src/serialization/include/megbrain/serialization/load_dump_config.h @@ -58,18 +58,25 @@ struct GraphDumpConfig { //! names. this list record the mapping between output node and it's name std::vector> alias_name_map; + //! whether just to dump all the op with no change the graph, sometimes the + //! opr maybe not compatible, if false, some opr will converter to the compatibility + //! 
format and then dump + bool no_change_graph; + GraphDumpConfig( int keep_var_name_ = 1, bool keep_param_name_ = false, bool keep_opr_priority_ = false, bool keep_op_name_ = true, const std::shared_ptr& user_data_ = std::make_shared(), - const TensorValueDumper& tensor_value_dumper_ = {}) + const TensorValueDumper& tensor_value_dumper_ = {}, + bool no_change_graph_ = false) : keep_var_name{keep_var_name_}, keep_param_name{keep_param_name_}, keep_opr_priority{keep_opr_priority_}, keep_op_name{keep_op_name_}, user_data{user_data_}, - tensor_value_dumper{tensor_value_dumper_} {} + tensor_value_dumper{tensor_value_dumper_}, + no_change_graph{no_change_graph_} {} }; //! config for loading a whole graph; setup in GraphLoader diff --git a/src/serialization/include/megbrain/serialization/oss_opr_load_dump.h b/src/serialization/include/megbrain/serialization/oss_opr_load_dump.h index 428a369ca12c0097fdd79d3e40178b57141e1941..55097b99dedef7b002d2ff7fcbff2f7be0cac854 100644 --- a/src/serialization/include/megbrain/serialization/oss_opr_load_dump.h +++ b/src/serialization/include/megbrain/serialization/oss_opr_load_dump.h @@ -15,6 +15,7 @@ namespace serialization { class GraphDumperOSSV2 final : public GraphDumper, OprDumpContextFlatBuffers { const std::unique_ptr m_file; + int m_version; flatbuffers::FlatBufferBuilder m_builder; DumpConfig m_config; @@ -51,7 +52,8 @@ class GraphDumperOSSV2 final : public GraphDumper, OprDumpContextFlatBuffers { flatbuffers::Offset build_dtype(DType dtype); public: - GraphDumperOSSV2(std::unique_ptr file) : m_file{std::move(file)} {} + GraphDumperOSSV2(std::unique_ptr file, int version) + : m_file{std::move(file)}, m_version{version} {} DumpResult dump( const SymbolVarArray& output_vars, const DumpConfig& config = {}, @@ -95,6 +97,7 @@ class GraphLoaderOSSV2 final : public GraphLoader { const fbs::v2::Model* m_model; SharedTensorIDMap m_shared_tensor_map; uint32_t m_mgb_version = 0; + uint32_t m_model_version = CURRENT_VERSION; bool 
m_model_loaded = false; void verify(); diff --git a/src/serialization/include/megbrain/serialization/serializer.h b/src/serialization/include/megbrain/serialization/serializer.h index 581b6e4e6be3464869f0c1fc374a8b2fe0b9a5e9..bfacede4b204506ff7c2b3986fc21b81010f7f73 100644 --- a/src/serialization/include/megbrain/serialization/serializer.h +++ b/src/serialization/include/megbrain/serialization/serializer.h @@ -5,6 +5,7 @@ #include "megbrain/serialization/file.h" #include "megbrain/serialization/load_dump_config.h" #include "megbrain/serialization/metadata.h" +#include "megbrain/serialization/opr_load_dump.h" namespace mgb { namespace serialization { @@ -160,7 +161,8 @@ public: }; MGE_WIN_DECLSPEC_FUC static std::unique_ptr make( - std::unique_ptr file, GraphDumpFormat format = {}); + std::unique_ptr file, GraphDumpFormat format = {}, + int version = VERSION_2); virtual ~GraphDumper() = default; diff --git a/src/serialization/test/serializer_oss.cpp b/src/serialization/test/serializer_oss.cpp index 845afaaae44f9021daf298a771c2297ca21847ce..f821827744fe9b4502d4e95013fc0f56d6e1d4b4 100644 --- a/src/serialization/test/serializer_oss.cpp +++ b/src/serialization/test/serializer_oss.cpp @@ -987,7 +987,9 @@ TEST(TestSerializer2, TestSoftMaxLoadDump) { OutputFile::make_fs(fname.c_str()), GraphDumpFormat::FLATBUFFERS_V2); auto rst = dumper->dump({x}); func->execute().wait(); - ASSERT_EQ(rst.nr_opr, 6); + //! if convert to reduce and elemwise, nr_opr is 6 + // ASSERT_EQ(rst.nr_opr, 6); + ASSERT_EQ(rst.nr_opr, 2); ASSERT_EQ(rst.inputs.size(), 1); ASSERT_EQ(rst.outputs.size(), 1); ASSERT_EQ(rst.params.size(), 0);