提交 09de5a07 编写于 作者: M Megvii Engine Team

feat(mgb/serialization): be able to serialize operator names

GitOrigin-RevId: d295abb5da0b70d4675e62e6632dc1c7bd77d58c
上级 bb8f2928
...@@ -305,6 +305,7 @@ def dump_graph( ...@@ -305,6 +305,7 @@ def dump_graph(
output_vars: Union[Dict[str, VarNode], List[VarNode]], output_vars: Union[Dict[str, VarNode], List[VarNode]],
*, *,
keep_var_name: int = 1, keep_var_name: int = 1,
keep_op_name: bool = True,
keep_param_name: bool = False, keep_param_name: bool = False,
keep_opr_priority: bool = False, keep_opr_priority: bool = False,
strip_info_file=None, strip_info_file=None,
...@@ -325,6 +326,7 @@ def dump_graph( ...@@ -325,6 +326,7 @@ def dump_graph(
* 0: none of the names are kept * 0: none of the names are kept
* 1: (default)keep names of output vars * 1: (default)keep names of output vars
* 2: keep names of all (output and internal) vars * 2: keep names of all (output and internal) vars
:param keep_op_name: whether to keep operator names.
:param keep_param_name: whether to keep param names, so param values can be :param keep_param_name: whether to keep param names, so param values can be
easily manipulated after loading model easily manipulated after loading model
:param keep_opr_priority: whether to keep priority setting for operators :param keep_opr_priority: whether to keep priority setting for operators
...@@ -368,6 +370,7 @@ def dump_graph( ...@@ -368,6 +370,7 @@ def dump_graph(
dump_content = _imperative_rt.dump_graph( dump_content = _imperative_rt.dump_graph(
ov, ov,
keep_var_name, keep_var_name,
keep_op_name,
keep_param_name, keep_param_name,
keep_opr_priority, keep_opr_priority,
stat, stat,
......
...@@ -294,6 +294,7 @@ void init_graph_rt(py::module m) { ...@@ -294,6 +294,7 @@ void init_graph_rt(py::module m) {
m.def("dump_graph", []( m.def("dump_graph", [](
const std::vector<VarNode*>& dest_vars, const std::vector<VarNode*>& dest_vars,
int keep_var_name, int keep_var_name,
bool keep_op_name,
bool keep_param_name, bool keep_param_name,
bool keep_opr_priority, bool keep_opr_priority,
py::list& stat, py::list& stat,
...@@ -306,7 +307,7 @@ void init_graph_rt(py::module m) { ...@@ -306,7 +307,7 @@ void init_graph_rt(py::module m) {
SymbolVarArray symvars(dest_vars.begin(), dest_vars.end()); SymbolVarArray symvars(dest_vars.begin(), dest_vars.end());
ser::GraphDumper::DumpConfig config{keep_var_name, keep_param_name, ser::GraphDumper::DumpConfig config{keep_var_name, keep_param_name,
keep_opr_priority}; keep_opr_priority, keep_op_name};
auto rst = dumper->dump(symvars, config); auto rst = dumper->dump(symvars, config);
for (auto i : rst.inputs) { for (auto i : rst.inputs) {
......
...@@ -124,6 +124,7 @@ table Operator { ...@@ -124,6 +124,7 @@ table Operator {
blobs:[Blob]; blobs:[Blob];
/// Operator may want to save more than one OperatorParam /// Operator may want to save more than one OperatorParam
additional_params:[OperatorParam]; additional_params:[OperatorParam];
name:string;
} }
struct OutputVar { struct OutputVar {
......
...@@ -208,6 +208,11 @@ flatbuffers::Offset<fbs::Operator> GraphDumperOSS::build_single_opr( ...@@ -208,6 +208,11 @@ flatbuffers::Offset<fbs::Operator> GraphDumperOSS::build_single_opr(
inputs = m_builder.CreateVector(v); inputs = m_builder.CreateVector(v);
} }
Offset<String> operator_name;
if (m_config.keep_op_name) {
operator_name = m_builder.CreateSharedString(opr->name());
}
Offset<Vector<Offset<String>>> output_names; Offset<Vector<Offset<String>>> output_names;
if (m_config.keep_var_name >= 2 || if (m_config.keep_var_name >= 2 ||
(m_config.keep_var_name == 1 && (m_config.keep_var_name == 1 &&
...@@ -255,6 +260,7 @@ flatbuffers::Offset<fbs::Operator> GraphDumperOSS::build_single_opr( ...@@ -255,6 +260,7 @@ flatbuffers::Offset<fbs::Operator> GraphDumperOSS::build_single_opr(
} }
builder.add_comp_node(comp_node); builder.add_comp_node(comp_node);
builder.add_output_name(output_names); builder.add_output_name(output_names);
builder.add_name(operator_name);
builder.add_output_dtype(output_dtype); builder.add_output_dtype(output_dtype);
if (param_cnt > 0) { if (param_cnt > 0) {
builder.add_param_type(m_cur_opr_param_type[0]); builder.add_param_type(m_cur_opr_param_type[0]);
...@@ -698,6 +704,9 @@ void GraphLoaderOSS::OprLoadContextImpl::load_single_opr( ...@@ -698,6 +704,9 @@ void GraphLoaderOSS::OprLoadContextImpl::load_single_opr(
if (fbopr->output_dtype()) { if (fbopr->output_dtype()) {
config.output_dtype(fbs::intl::load_dtype(fbopr->output_dtype())); config.output_dtype(fbs::intl::load_dtype(fbopr->output_dtype()));
} }
if (fbopr->name()) {
config.name(fbopr->name()->str());
}
if (fbopr->comp_node()) { if (fbopr->comp_node()) {
auto cnt = fbopr->comp_node()->size(); auto cnt = fbopr->comp_node()->size();
cg::OperatorNodeConfig::CompNodeArray comp_node_arr(cnt); cg::OperatorNodeConfig::CompNodeArray comp_node_arr(cnt);
......
...@@ -43,6 +43,9 @@ struct GraphDumpConfig { ...@@ -43,6 +43,9 @@ struct GraphDumpConfig {
//! whether to keep operator priorities //! whether to keep operator priorities
bool keep_opr_priority; bool keep_opr_priority;
//! whether to keep operator names
bool keep_op_name;
//! extra user data to be passed by dump caller into opr dump //! extra user data to be passed by dump caller into opr dump
//! implementations; useful for implementing nested opr dump //! implementations; useful for implementing nested opr dump
std::shared_ptr<UserDataContainer> user_data; std::shared_ptr<UserDataContainer> user_data;
...@@ -57,12 +60,14 @@ struct GraphDumpConfig { ...@@ -57,12 +60,14 @@ struct GraphDumpConfig {
GraphDumpConfig(int keep_var_name_ = 1, bool keep_param_name_ = false, GraphDumpConfig(int keep_var_name_ = 1, bool keep_param_name_ = false,
bool keep_opr_priority_ = false, bool keep_opr_priority_ = false,
bool keep_op_name_ = true,
const std::shared_ptr<UserDataContainer>& user_data_ = const std::shared_ptr<UserDataContainer>& user_data_ =
std::make_shared<UserDataContainer>(), std::make_shared<UserDataContainer>(),
const TensorValueDumper& tensor_value_dumper_ = {}) const TensorValueDumper& tensor_value_dumper_ = {})
: keep_var_name{keep_var_name_}, : keep_var_name{keep_var_name_},
keep_param_name{keep_param_name_}, keep_param_name{keep_param_name_},
keep_opr_priority{keep_opr_priority_}, keep_opr_priority{keep_opr_priority_},
keep_op_name{keep_op_name_},
user_data{user_data_}, user_data{user_data_},
tensor_value_dumper{tensor_value_dumper_} {} tensor_value_dumper{tensor_value_dumper_} {}
}; };
......
...@@ -711,6 +711,39 @@ TEST(TestSerializer2, ParamerizedDType) { ...@@ -711,6 +711,39 @@ TEST(TestSerializer2, ParamerizedDType) {
load(); load();
} }
TEST(TestSerializer2, OperatorName) {
auto fname = GET_OUTPUT_FILE();
TensorShape shape{2, 3};
auto dump = [&]() {
auto cn = CompNode::load("xpu0");
auto host_x = std::make_shared<HostTensorND>(cn, shape),
host_y = std::make_shared<HostTensorND>(cn, shape);
auto graph = ComputingGraph::make();
auto x = opr::Host2DeviceCopy::make(*graph, host_x, {"x"}),
y = opr::Host2DeviceCopy::make(*graph, host_y, {"y"});
using Mode = opr::Elemwise::Mode;
auto z = opr::Elemwise::make({x, y}, Mode::ADD, {"add(x, y)"});
auto dumper = GraphDumper::make(OutputFile::make_fs(fname.c_str()),
GraphDumpFormat::FLATBUFFERS);
auto rst = dumper->dump({z.rename("z")});
};
auto load = [&]() {
HostTensorGenerator<> gen;
auto loader = GraphLoader::make(InputFile::make_fs(fname.c_str()),
GraphDumpFormat::FLATBUFFERS);
auto rst = loader->load();
auto z = rst.output_var_map.at("z");
auto op_name = z.node()->owner_opr()->cname();
int cmp = strcmp(op_name, "add(x, y)");
EXPECT_EQ(cmp, 0);
};
dump();
load();
}
TEST(TestSerializer2, HasOutputDtype) { TEST(TestSerializer2, HasOutputDtype) {
auto fname = GET_OUTPUT_FILE(); auto fname = GET_OUTPUT_FILE();
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册