提交 4ab5f970 编写于 作者: M Megvii Engine Team

fix(build): fix ci error

GitOrigin-RevId: 9cbf64dda27c8d99af009c13b34deadfa9655637
上级 b9a69323
......@@ -614,8 +614,8 @@ class trace:
input_transform: a python expression to transform the input data.
Example: data / np.std(data)
dump_format: using different dump formats. the open source MegEngine
defaults to the FBS_V2 format, there are two format FBS_V2 and FBS to choose,
internal MegEngine have an other choice of internal proprietary formats
defaults to the FBS_V2 format, there are two format FBS_V2 and FBS to choose,
internal MegEngine have an other choice of internal proprietary formats
Keyword Arguments:
......
#pragma once
#include "megbrain/graph/symbol_var.h"
#include "megdnn/oprs/general.h"
#if MGB_ENABLE_FBS_SERIALIZATION
......
#pragma once
#if MGB_ENABLE_FBS_SERIALIZATION
#include "megbrain/comp_node_env.h"
#include "megbrain/opr/dnn/softmax.h"
......
......@@ -5,7 +5,7 @@ include "mgb_cpp_opr.fbs";
namespace mgb.serialization.fbs.v2;
file_identifier "mge2";
file_identifier "mgv2";
table CompNode {
logical_locator:string;
......@@ -105,7 +105,7 @@ union OperatorParam {
param.OptionalAxisV1 = 54,
param.ExecutionPolicy = 55,
param.AssertEqual = 56,
param.FpgaConv = 57,
Reserved0 = 57,
param.CollectiveComm = 58,
param.CondExecPred = 59,
param.CondExecPredLogical = 60,
......@@ -197,6 +197,11 @@ table OutputVar {
original_id:uint;
}
table OutputAlias {
id:uint;
name:string;
}
table Model {
/// the megengine version when serialize the model
mge_version:uint;
......@@ -213,6 +218,7 @@ table Model {
middle_tensors:[MiddleTensor];
output_vars_idx:[OutputVar];
output_alias:[OutputAlias];
nr_shared_tensor:uint;
/// the Metadata to storage the custom data or some flags
......
......@@ -400,6 +400,18 @@ GraphDumper::DumpResult GraphDumperOSSV2::dump(
output_vars_idx.push_back(foutput_vars_idx);
}
auto fb_output_vars = m_builder.CreateVector(output_vars_idx);
std::vector<flatbuffers::Offset<fbs::v2::OutputAlias>> output_vars_alias;
if (m_config.alias_name_map.size() > 0) {
for (auto&& pair : m_config.alias_name_map) {
std::string name;
SymbolVar var;
std::tie(name, var) = pair;
auto fbs_name = m_builder.CreateSharedString(name);
output_vars_alias.push_back(
fbs::v2::CreateOutputAlias(m_builder, var.node()->id(), fbs_name));
}
}
auto fbs_output_alias = m_builder.CreateVector(output_vars_alias);
auto fb_mid_tensor = m_builder.CreateVector(m_model_middle_tensors);
fbs::v2::ModelBuilder model(m_builder);
......@@ -407,6 +419,7 @@ GraphDumper::DumpResult GraphDumperOSSV2::dump(
model.add_oprs(fb_oprs);
model.add_middle_tensors(fb_mid_tensor);
model.add_output_vars_idx(fb_output_vars);
model.add_output_alias(fbs_output_alias);
model.add_nr_shared_tensor(m_nr_shared_tensor);
model.add_metadata(fbmeta);
m_builder.FinishSizePrefixed(model.Finish(), fbs::v2::ModelIdentifier());
......@@ -469,7 +482,7 @@ void GraphDumperOSSV2::dump_tensor(
if (dumper) {
mgb_log_warn(
"serialization v2 format is pure flatbuffer format, not support "
"user tensor value dumper");
"user tensor value dumper callback.");
}
data = m_builder.CreateVector(
reinterpret_cast<uint8_t*>(tensor.raw_ptr()), layout.span().high_byte);
......@@ -568,7 +581,7 @@ std::shared_ptr<HostTensorND> GraphLoaderOSSV2::OprLoadContextImpl::load_tensor(
if (loader) {
mgb_log_warn(
"serialization v2 format is pure flatbuffer format, not support "
"user tensor value loader");
"user tensor value loader callback.");
}
memcpy(ret->raw_ptr(), tensor->data()->data(), tensor->data()->size());
}
......@@ -677,15 +690,14 @@ void GraphLoaderOSSV2::OprLoadContextImpl::load_single_opr(
//! opr version must be exist
uint8_t opr_version = fbopr->opr_version();
auto type_id = fbopr->type_id();
auto opr_type = fbopr->type()->str();
const OprRegistryV2* registry =
OprRegistryV2::versioned_find_by_id(type_id, opr_version);
mgb_throw_if(
!registry, SerializationError,
"failed to find opr with type %s id is %zu, use python env "
"failed to find opr with type %s , use python env "
"config.dump_registered_oprs() to get a dict that maps from "
"opr id to opr name",
fbopr->type()->str().c_str(), type_id);
fbopr->type()->str().c_str());
// load inputs
VarNodeArray inputs;
......@@ -817,6 +829,17 @@ GraphLoader::LoadResult GraphLoaderOSSV2::load(const LoadConfig& config, bool re
auto metadata = ctx.load_metadata();
auto result = ctx.load_oprs();
result.metadata = metadata;
if (m_model->output_alias() && m_model->output_alias()->size() > 0) {
auto nr_alias = m_model->output_alias()->size();
result.output_var_list.resize(nr_alias);
for (size_t i = 0; i < nr_alias; i++) {
auto output_alias = m_model->output_alias()->Get(i);
std::string name = output_alias->name()->str();
size_t id = output_alias->id();
result.output_var_map[name] = result.output_var_map_id[id];
result.output_var_list[i] = result.output_var_map_id[id];
}
}
m_model_loaded = true;
result.graph_compile_ahead();
return result;
......
......@@ -233,9 +233,8 @@ public:
int addition_index = index - 1;
if (addition_index >=
static_cast<int>(m_current_opr->additional_params()->size())) {
mgb_log_warn(
"Model has no addition param of index %d, just construct a "
"default one.",
mgb_throw(
SerializationError, "Model has no addition param of index %d.",
addition_index);
} else {
mgb_assert(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册