From 4ab5f970e9cda65009186c3853ec8d87120e7023 Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Thu, 12 May 2022 18:57:04 +0800 Subject: [PATCH] fix(build): fix ci error GitOrigin-RevId: 9cbf64dda27c8d99af009c13b34deadfa9655637 --- imperative/python/megengine/jit/tracing.py | 4 +-- src/opr/impl/dnn/dnn.sereg.v2.h | 2 ++ src/opr/impl/io.sereg.v2.h | 1 + src/serialization/impl/schema_v2.fbs | 10 ++++-- src/serialization/impl/serializer_oss_v2.cpp | 33 ++++++++++++++++--- .../serialization/oss_opr_load_dump.h | 5 ++- 6 files changed, 43 insertions(+), 12 deletions(-) diff --git a/imperative/python/megengine/jit/tracing.py b/imperative/python/megengine/jit/tracing.py index cca234806..b6f549348 100644 --- a/imperative/python/megengine/jit/tracing.py +++ b/imperative/python/megengine/jit/tracing.py @@ -614,8 +614,8 @@ class trace: input_transform: a python expression to transform the input data. Example: data / np.std(data) dump_format: using different dump formats. the open source MegEngine - defaults to the FBS_V2 format, there are two format FBS_V2 and FBS to choose, - internal MegEngine have an other choice of internal proprietary formats + defaults to the FBS_V2 format, there are two format FBS_V2 and FBS to choose, + internal MegEngine have an other choice of internal proprietary formats Keyword Arguments: diff --git a/src/opr/impl/dnn/dnn.sereg.v2.h b/src/opr/impl/dnn/dnn.sereg.v2.h index d97dd3a05..cb3266730 100644 --- a/src/opr/impl/dnn/dnn.sereg.v2.h +++ b/src/opr/impl/dnn/dnn.sereg.v2.h @@ -1,3 +1,5 @@ +#pragma once + #include "megbrain/graph/symbol_var.h" #include "megdnn/oprs/general.h" #if MGB_ENABLE_FBS_SERIALIZATION diff --git a/src/opr/impl/io.sereg.v2.h b/src/opr/impl/io.sereg.v2.h index 5c68adb31..1ae18325d 100644 --- a/src/opr/impl/io.sereg.v2.h +++ b/src/opr/impl/io.sereg.v2.h @@ -1,3 +1,4 @@ +#pragma once #if MGB_ENABLE_FBS_SERIALIZATION #include "megbrain/comp_node_env.h" #include "megbrain/opr/dnn/softmax.h" diff --git a/src/serialization/impl/schema_v2.fbs b/src/serialization/impl/schema_v2.fbs index 2e1f58b90..d931d7947 100644 --- a/src/serialization/impl/schema_v2.fbs +++ b/src/serialization/impl/schema_v2.fbs @@ -5,7 +5,7 @@ include "mgb_cpp_opr.fbs"; namespace mgb.serialization.fbs.v2; -file_identifier "mge2"; +file_identifier "mgv2"; table CompNode { logical_locator:string; @@ -105,7 +105,7 @@ union OperatorParam { param.OptionalAxisV1 = 54, param.ExecutionPolicy = 55, param.AssertEqual = 56, - param.FpgaConv = 57, + Reserved0 = 57, param.CollectiveComm = 58, param.CondExecPred = 59, param.CondExecPredLogical = 60, @@ -197,6 +197,11 @@ table OutputVar { original_id:uint; } +table OutputAlias { + id:uint; + name:string; +} + table Model { /// the megengine version when serialize the model mge_version:uint; @@ -213,6 +218,7 @@ table Model { middle_tensors:[MiddleTensor]; output_vars_idx:[OutputVar]; + output_alias:[OutputAlias]; nr_shared_tensor:uint; /// the Metadata to storage the custom data or some flags diff --git a/src/serialization/impl/serializer_oss_v2.cpp b/src/serialization/impl/serializer_oss_v2.cpp index f5a9a9fe4..4903c933f 100644 --- a/src/serialization/impl/serializer_oss_v2.cpp +++ b/src/serialization/impl/serializer_oss_v2.cpp @@ -400,6 +400,18 @@ GraphDumper::DumpResult GraphDumperOSSV2::dump( output_vars_idx.push_back(foutput_vars_idx); } auto fb_output_vars = m_builder.CreateVector(output_vars_idx); + std::vector> output_vars_alias; + if (m_config.alias_name_map.size() > 0) { + for (auto&& pair : m_config.alias_name_map) { + std::string name; + SymbolVar var; + std::tie(name, var) = pair; + auto fbs_name = m_builder.CreateSharedString(name); + output_vars_alias.push_back( + fbs::v2::CreateOutputAlias(m_builder, var.node()->id(), fbs_name)); + } + } + auto fbs_output_alias = m_builder.CreateVector(output_vars_alias); auto fb_mid_tensor = m_builder.CreateVector(m_model_middle_tensors); fbs::v2::ModelBuilder model(m_builder); @@ -407,6 +419,7 @@ GraphDumper::DumpResult GraphDumperOSSV2::dump( model.add_oprs(fb_oprs); model.add_middle_tensors(fb_mid_tensor); model.add_output_vars_idx(fb_output_vars); + model.add_output_alias(fbs_output_alias); model.add_nr_shared_tensor(m_nr_shared_tensor); model.add_metadata(fbmeta); m_builder.FinishSizePrefixed(model.Finish(), fbs::v2::ModelIdentifier()); @@ -469,7 +482,7 @@ void GraphDumperOSSV2::dump_tensor( if (dumper) { mgb_log_warn( "serialization v2 format is pure flatbuffer format, not support " - "user tensor value dumper"); + "user tensor value dumper callback."); } data = m_builder.CreateVector( reinterpret_cast(tensor.raw_ptr()), layout.span().high_byte); @@ -568,7 +581,7 @@ std::shared_ptr GraphLoaderOSSV2::OprLoadContextImpl::load_tensor( if (loader) { mgb_log_warn( "serialization v2 format is pure flatbuffer format, not support " - "user tensor value loader"); + "user tensor value loader callback."); } memcpy(ret->raw_ptr(), tensor->data()->data(), tensor->data()->size()); } @@ -677,15 +690,14 @@ void GraphLoaderOSSV2::OprLoadContextImpl::load_single_opr( //! opr version must be exist uint8_t opr_version = fbopr->opr_version(); auto type_id = fbopr->type_id(); - auto opr_type = fbopr->type()->str(); const OprRegistryV2* registry = OprRegistryV2::versioned_find_by_id(type_id, opr_version); mgb_throw_if( !registry, SerializationError, - "failed to find opr with type %s id is %zu, use python env " + "failed to find opr with type %s , use python env " "config.dump_registered_oprs() to get a dict that maps from " "opr id to opr name", - fbopr->type()->str().c_str(), type_id); + fbopr->type()->str().c_str()); // load inputs VarNodeArray inputs; @@ -817,6 +829,17 @@ GraphLoader::LoadResult GraphLoaderOSSV2::load(const LoadConfig& config, bool re auto metadata = ctx.load_metadata(); auto result = ctx.load_oprs(); result.metadata = metadata; + if (m_model->output_alias() && m_model->output_alias()->size() > 0) { + auto nr_alias = m_model->output_alias()->size(); + result.output_var_list.resize(nr_alias); + for (size_t i = 0; i < nr_alias; i++) { + auto output_alias = m_model->output_alias()->Get(i); + std::string name = output_alias->name()->str(); + size_t id = output_alias->id(); + result.output_var_map[name] = result.output_var_map_id[id]; + result.output_var_list[i] = result.output_var_map_id[id]; + } + } m_model_loaded = true; result.graph_compile_ahead(); return result; diff --git a/src/serialization/include/megbrain/serialization/oss_opr_load_dump.h b/src/serialization/include/megbrain/serialization/oss_opr_load_dump.h index 655adf016..428a369ca 100644 --- a/src/serialization/include/megbrain/serialization/oss_opr_load_dump.h +++ b/src/serialization/include/megbrain/serialization/oss_opr_load_dump.h @@ -233,9 +233,8 @@ public: int addition_index = index - 1; if (addition_index >= static_cast(m_current_opr->additional_params()->size())) { - mgb_log_warn( - "Model has no addition param of index %d, just construct a " - "default one.", + mgb_throw( + SerializationError, "Model has no addition param of index %d.", addition_index); } else { mgb_assert( -- GitLab