提交 09de5a07 编写于 作者: M Megvii Engine Team

feat(mgb/serialization): be able to serialize operator names

GitOrigin-RevId: d295abb5da0b70d4675e62e6632dc1c7bd77d58c
上级 bb8f2928
......@@ -305,6 +305,7 @@ def dump_graph(
output_vars: Union[Dict[str, VarNode], List[VarNode]],
*,
keep_var_name: int = 1,
keep_op_name: bool = True,
keep_param_name: bool = False,
keep_opr_priority: bool = False,
strip_info_file=None,
......@@ -325,6 +326,7 @@ def dump_graph(
* 0: none of the names are kept
* 1: (default)keep names of output vars
* 2: keep names of all (output and internal) vars
:param keep_op_name: whether to keep operator names.
:param keep_param_name: whether to keep param names, so param values can be
easily manipulated after loading model
:param keep_opr_priority: whether to keep priority setting for operators
......@@ -368,6 +370,7 @@ def dump_graph(
dump_content = _imperative_rt.dump_graph(
ov,
keep_var_name,
keep_op_name,
keep_param_name,
keep_opr_priority,
stat,
......
......@@ -294,6 +294,7 @@ void init_graph_rt(py::module m) {
m.def("dump_graph", [](
const std::vector<VarNode*>& dest_vars,
int keep_var_name,
bool keep_op_name,
bool keep_param_name,
bool keep_opr_priority,
py::list& stat,
......@@ -306,7 +307,7 @@ void init_graph_rt(py::module m) {
SymbolVarArray symvars(dest_vars.begin(), dest_vars.end());
ser::GraphDumper::DumpConfig config{keep_var_name, keep_param_name,
keep_opr_priority};
keep_opr_priority, keep_op_name};
auto rst = dumper->dump(symvars, config);
for (auto i : rst.inputs) {
......
......@@ -124,6 +124,7 @@ table Operator {
blobs:[Blob];
/// Operator may want to save more than one OperatorParam
additional_params:[OperatorParam];
name:string;
}
struct OutputVar {
......
......@@ -208,6 +208,11 @@ flatbuffers::Offset<fbs::Operator> GraphDumperOSS::build_single_opr(
inputs = m_builder.CreateVector(v);
}
Offset<String> operator_name;
if (m_config.keep_op_name) {
operator_name = m_builder.CreateSharedString(opr->name());
}
Offset<Vector<Offset<String>>> output_names;
if (m_config.keep_var_name >= 2 ||
(m_config.keep_var_name == 1 &&
......@@ -255,6 +260,7 @@ flatbuffers::Offset<fbs::Operator> GraphDumperOSS::build_single_opr(
}
builder.add_comp_node(comp_node);
builder.add_output_name(output_names);
builder.add_name(operator_name);
builder.add_output_dtype(output_dtype);
if (param_cnt > 0) {
builder.add_param_type(m_cur_opr_param_type[0]);
......@@ -698,6 +704,9 @@ void GraphLoaderOSS::OprLoadContextImpl::load_single_opr(
if (fbopr->output_dtype()) {
config.output_dtype(fbs::intl::load_dtype(fbopr->output_dtype()));
}
if (fbopr->name()) {
config.name(fbopr->name()->str());
}
if (fbopr->comp_node()) {
auto cnt = fbopr->comp_node()->size();
cg::OperatorNodeConfig::CompNodeArray comp_node_arr(cnt);
......
......@@ -43,6 +43,9 @@ struct GraphDumpConfig {
//! whether to keep operator priorities
bool keep_opr_priority;
//! whether to keep operator names
bool keep_op_name;
//! extra user data to be passed by dump caller into opr dump
//! implementations; useful for implementing nested opr dump
std::shared_ptr<UserDataContainer> user_data;
......@@ -57,12 +60,14 @@ struct GraphDumpConfig {
GraphDumpConfig(int keep_var_name_ = 1, bool keep_param_name_ = false,
bool keep_opr_priority_ = false,
bool keep_op_name_ = true,
const std::shared_ptr<UserDataContainer>& user_data_ =
std::make_shared<UserDataContainer>(),
const TensorValueDumper& tensor_value_dumper_ = {})
: keep_var_name{keep_var_name_},
keep_param_name{keep_param_name_},
keep_opr_priority{keep_opr_priority_},
keep_op_name{keep_op_name_},
user_data{user_data_},
tensor_value_dumper{tensor_value_dumper_} {}
};
......
......@@ -711,6 +711,39 @@ TEST(TestSerializer2, ParamerizedDType) {
load();
}
TEST(TestSerializer2, OperatorName) {
auto fname = GET_OUTPUT_FILE();
TensorShape shape{2, 3};
auto dump = [&]() {
auto cn = CompNode::load("xpu0");
auto host_x = std::make_shared<HostTensorND>(cn, shape),
host_y = std::make_shared<HostTensorND>(cn, shape);
auto graph = ComputingGraph::make();
auto x = opr::Host2DeviceCopy::make(*graph, host_x, {"x"}),
y = opr::Host2DeviceCopy::make(*graph, host_y, {"y"});
using Mode = opr::Elemwise::Mode;
auto z = opr::Elemwise::make({x, y}, Mode::ADD, {"add(x, y)"});
auto dumper = GraphDumper::make(OutputFile::make_fs(fname.c_str()),
GraphDumpFormat::FLATBUFFERS);
auto rst = dumper->dump({z.rename("z")});
};
auto load = [&]() {
HostTensorGenerator<> gen;
auto loader = GraphLoader::make(InputFile::make_fs(fname.c_str()),
GraphDumpFormat::FLATBUFFERS);
auto rst = loader->load();
auto z = rst.output_var_map.at("z");
auto op_name = z.node()->owner_opr()->cname();
int cmp = strcmp(op_name, "add(x, y)");
EXPECT_EQ(cmp, 0);
};
dump();
load();
}
TEST(TestSerializer2, HasOutputDtype) {
auto fname = GET_OUTPUT_FILE();
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册