提交 dee02289 编写于 作者: M Megvii Engine Team

fix(serialization): when the dump fbsv2 model is used, the middle_tensor becomes optional

GitOrigin-RevId: 3d0bbfd44136c39f50c18cd4a3da5c87fcca50fb
上级 b2959589
......@@ -161,8 +161,16 @@ flatbuffers::Offset<fbs::v2::MiddleTensor> GraphDumperOSSV2::build_middle_tensor
auto fformat = build_tensor_format(layout.format);
serialized_middle_tensor = fbs::v2::CreateMiddleTensor(
m_builder, fbname, fshape, fcomp_node, fdtype, fformat_type, fformat);
} else if (var.node()->shape().ndim > 0) {
auto shape = var.node()->shape();
auto fshape =
m_builder.CreateVectorScalarCast<uint32_t>(shape.shape, shape.ndim);
serialized_middle_tensor =
fbs::v2::CreateMiddleTensor(m_builder, fbname, fshape);
} else {
serialized_middle_tensor = fbs::v2::CreateMiddleTensor(m_builder, fbname);
}
serialized_middle_tensor = fbs::v2::CreateMiddleTensor(m_builder, fbname);
return serialized_middle_tensor;
}
......@@ -278,8 +286,12 @@ flatbuffers::Offset<fbs::v2::Operator> GraphDumperOSSV2::build_single_opr(
v.reserve(m_cur_opr->output().size());
for (auto out : m_cur_opr->output()) {
if (!out->contain_flag(VarNode::Flag::VOLATILE_CONTENT)) {
auto fbs_out = build_middle_tensor(out);
m_model_middle_tensors.push_back(fbs_out);
if (m_config.keep_var_name >= 1) {
auto fbs_out = build_middle_tensor(out);
m_model_middle_tensors.push_back(fbs_out);
} else {
m_model_middle_tensors.push_back(0);
}
m_var2midtensor_id[out] = m_model_middle_tensors.size() - 1;
v.emplace_back(m_var2midtensor_id.at(out));
}
......@@ -425,13 +437,19 @@ GraphDumper::DumpResult GraphDumperOSSV2::dump(
}
}
auto fbs_output_alias = m_builder.CreateVector(output_vars_alias);
auto fb_mid_tensor = m_builder.CreateVector(m_model_middle_tensors);
flatbuffers::Offset<flatbuffers::Vector<
flatbuffers::Offset<mgb::serialization::fbs::v2::MiddleTensor>>>
fb_mid_tensor;
if (m_config.keep_var_name >= 1)
fb_mid_tensor = m_builder.CreateVector(m_model_middle_tensors);
fbs::v2::ModelBuilder model(m_builder);
model.add_mge_version(MGB_VERSION);
model.add_model_version(m_version);
model.add_oprs(fb_oprs);
model.add_middle_tensors(fb_mid_tensor);
if (m_config.keep_var_name >= 1) {
model.add_middle_tensors(fb_mid_tensor);
}
model.add_output_vars_idx(fb_output_vars);
model.add_output_alias(fbs_output_alias);
model.add_nr_shared_tensor(m_nr_shared_tensor);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册