Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
09de5a07
MegEngine
项目概览
MegEngine 天元
/
MegEngine
1 年多 前同步成功
通知
404
Star
4705
Fork
582
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
提交
09de5a07
编写于
1月 28, 2021
作者:
M
Megvii Engine Team
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
feat(mgb/serialization): be able to serialize operator names
GitOrigin-RevId: d295abb5da0b70d4675e62e6632dc1c7bd77d58c
上级
bb8f2928
变更
6
显示空白变更内容
内联
并排
Showing
6 changed file
with
53 addition
and
1 deletion
+53
-1
imperative/python/megengine/core/tensor/megbrain_graph.py
imperative/python/megengine/core/tensor/megbrain_graph.py
+3
-0
imperative/python/src/graph_rt.cpp
imperative/python/src/graph_rt.cpp
+2
-1
src/serialization/impl/schema.fbs
src/serialization/impl/schema.fbs
+1
-0
src/serialization/impl/serializer_oss.cpp
src/serialization/impl/serializer_oss.cpp
+9
-0
src/serialization/include/megbrain/serialization/load_dump_config.h
...ization/include/megbrain/serialization/load_dump_config.h
+5
-0
src/serialization/test/serializer_oss.cpp
src/serialization/test/serializer_oss.cpp
+33
-0
未找到文件。
imperative/python/megengine/core/tensor/megbrain_graph.py
浏览文件 @
09de5a07
...
@@ -305,6 +305,7 @@ def dump_graph(
...
@@ -305,6 +305,7 @@ def dump_graph(
output_vars
:
Union
[
Dict
[
str
,
VarNode
],
List
[
VarNode
]],
output_vars
:
Union
[
Dict
[
str
,
VarNode
],
List
[
VarNode
]],
*
,
*
,
keep_var_name
:
int
=
1
,
keep_var_name
:
int
=
1
,
keep_op_name
:
bool
=
True
,
keep_param_name
:
bool
=
False
,
keep_param_name
:
bool
=
False
,
keep_opr_priority
:
bool
=
False
,
keep_opr_priority
:
bool
=
False
,
strip_info_file
=
None
,
strip_info_file
=
None
,
...
@@ -325,6 +326,7 @@ def dump_graph(
...
@@ -325,6 +326,7 @@ def dump_graph(
* 0: none of the names are kept
* 0: none of the names are kept
* 1: (default)keep names of output vars
* 1: (default)keep names of output vars
* 2: keep names of all (output and internal) vars
* 2: keep names of all (output and internal) vars
:param keep_op_name: whether to keep operator names.
:param keep_param_name: whether to keep param names, so param values can be
:param keep_param_name: whether to keep param names, so param values can be
easily manipulated after loading model
easily manipulated after loading model
:param keep_opr_priority: whether to keep priority setting for operators
:param keep_opr_priority: whether to keep priority setting for operators
...
@@ -368,6 +370,7 @@ def dump_graph(
...
@@ -368,6 +370,7 @@ def dump_graph(
dump_content
=
_imperative_rt
.
dump_graph
(
dump_content
=
_imperative_rt
.
dump_graph
(
ov
,
ov
,
keep_var_name
,
keep_var_name
,
keep_op_name
,
keep_param_name
,
keep_param_name
,
keep_opr_priority
,
keep_opr_priority
,
stat
,
stat
,
...
...
imperative/python/src/graph_rt.cpp
浏览文件 @
09de5a07
...
@@ -294,6 +294,7 @@ void init_graph_rt(py::module m) {
...
@@ -294,6 +294,7 @@ void init_graph_rt(py::module m) {
m
.
def
(
"dump_graph"
,
[](
m
.
def
(
"dump_graph"
,
[](
const
std
::
vector
<
VarNode
*>&
dest_vars
,
const
std
::
vector
<
VarNode
*>&
dest_vars
,
int
keep_var_name
,
int
keep_var_name
,
bool
keep_op_name
,
bool
keep_param_name
,
bool
keep_param_name
,
bool
keep_opr_priority
,
bool
keep_opr_priority
,
py
::
list
&
stat
,
py
::
list
&
stat
,
...
@@ -306,7 +307,7 @@ void init_graph_rt(py::module m) {
...
@@ -306,7 +307,7 @@ void init_graph_rt(py::module m) {
SymbolVarArray
symvars
(
dest_vars
.
begin
(),
dest_vars
.
end
());
SymbolVarArray
symvars
(
dest_vars
.
begin
(),
dest_vars
.
end
());
ser
::
GraphDumper
::
DumpConfig
config
{
keep_var_name
,
keep_param_name
,
ser
::
GraphDumper
::
DumpConfig
config
{
keep_var_name
,
keep_param_name
,
keep_opr_priority
};
keep_opr_priority
,
keep_op_name
};
auto
rst
=
dumper
->
dump
(
symvars
,
config
);
auto
rst
=
dumper
->
dump
(
symvars
,
config
);
for
(
auto
i
:
rst
.
inputs
)
{
for
(
auto
i
:
rst
.
inputs
)
{
...
...
src/serialization/impl/schema.fbs
浏览文件 @
09de5a07
...
@@ -124,6 +124,7 @@ table Operator {
...
@@ -124,6 +124,7 @@ table Operator {
blobs:[Blob];
blobs:[Blob];
/// Operator may want to save more than one OperatorParam
/// Operator may want to save more than one OperatorParam
additional_params:[OperatorParam];
additional_params:[OperatorParam];
name:string;
}
}
struct OutputVar {
struct OutputVar {
...
...
src/serialization/impl/serializer_oss.cpp
浏览文件 @
09de5a07
...
@@ -208,6 +208,11 @@ flatbuffers::Offset<fbs::Operator> GraphDumperOSS::build_single_opr(
...
@@ -208,6 +208,11 @@ flatbuffers::Offset<fbs::Operator> GraphDumperOSS::build_single_opr(
inputs
=
m_builder
.
CreateVector
(
v
);
inputs
=
m_builder
.
CreateVector
(
v
);
}
}
Offset
<
String
>
operator_name
;
if
(
m_config
.
keep_op_name
)
{
operator_name
=
m_builder
.
CreateSharedString
(
opr
->
name
());
}
Offset
<
Vector
<
Offset
<
String
>>>
output_names
;
Offset
<
Vector
<
Offset
<
String
>>>
output_names
;
if
(
m_config
.
keep_var_name
>=
2
||
if
(
m_config
.
keep_var_name
>=
2
||
(
m_config
.
keep_var_name
==
1
&&
(
m_config
.
keep_var_name
==
1
&&
...
@@ -255,6 +260,7 @@ flatbuffers::Offset<fbs::Operator> GraphDumperOSS::build_single_opr(
...
@@ -255,6 +260,7 @@ flatbuffers::Offset<fbs::Operator> GraphDumperOSS::build_single_opr(
}
}
builder
.
add_comp_node
(
comp_node
);
builder
.
add_comp_node
(
comp_node
);
builder
.
add_output_name
(
output_names
);
builder
.
add_output_name
(
output_names
);
builder
.
add_name
(
operator_name
);
builder
.
add_output_dtype
(
output_dtype
);
builder
.
add_output_dtype
(
output_dtype
);
if
(
param_cnt
>
0
)
{
if
(
param_cnt
>
0
)
{
builder
.
add_param_type
(
m_cur_opr_param_type
[
0
]);
builder
.
add_param_type
(
m_cur_opr_param_type
[
0
]);
...
@@ -698,6 +704,9 @@ void GraphLoaderOSS::OprLoadContextImpl::load_single_opr(
...
@@ -698,6 +704,9 @@ void GraphLoaderOSS::OprLoadContextImpl::load_single_opr(
if
(
fbopr
->
output_dtype
())
{
if
(
fbopr
->
output_dtype
())
{
config
.
output_dtype
(
fbs
::
intl
::
load_dtype
(
fbopr
->
output_dtype
()));
config
.
output_dtype
(
fbs
::
intl
::
load_dtype
(
fbopr
->
output_dtype
()));
}
}
if
(
fbopr
->
name
())
{
config
.
name
(
fbopr
->
name
()
->
str
());
}
if
(
fbopr
->
comp_node
())
{
if
(
fbopr
->
comp_node
())
{
auto
cnt
=
fbopr
->
comp_node
()
->
size
();
auto
cnt
=
fbopr
->
comp_node
()
->
size
();
cg
::
OperatorNodeConfig
::
CompNodeArray
comp_node_arr
(
cnt
);
cg
::
OperatorNodeConfig
::
CompNodeArray
comp_node_arr
(
cnt
);
...
...
src/serialization/include/megbrain/serialization/load_dump_config.h
浏览文件 @
09de5a07
...
@@ -43,6 +43,9 @@ struct GraphDumpConfig {
...
@@ -43,6 +43,9 @@ struct GraphDumpConfig {
//! whether to keep operator priorities
//! whether to keep operator priorities
bool
keep_opr_priority
;
bool
keep_opr_priority
;
//! whether to keep operator names
bool
keep_op_name
;
//! extra user data to be passed by dump caller into opr dump
//! extra user data to be passed by dump caller into opr dump
//! implementations; useful for implementing nested opr dump
//! implementations; useful for implementing nested opr dump
std
::
shared_ptr
<
UserDataContainer
>
user_data
;
std
::
shared_ptr
<
UserDataContainer
>
user_data
;
...
@@ -57,12 +60,14 @@ struct GraphDumpConfig {
...
@@ -57,12 +60,14 @@ struct GraphDumpConfig {
GraphDumpConfig
(
int
keep_var_name_
=
1
,
bool
keep_param_name_
=
false
,
GraphDumpConfig
(
int
keep_var_name_
=
1
,
bool
keep_param_name_
=
false
,
bool
keep_opr_priority_
=
false
,
bool
keep_opr_priority_
=
false
,
bool
keep_op_name_
=
true
,
const
std
::
shared_ptr
<
UserDataContainer
>&
user_data_
=
const
std
::
shared_ptr
<
UserDataContainer
>&
user_data_
=
std
::
make_shared
<
UserDataContainer
>
(),
std
::
make_shared
<
UserDataContainer
>
(),
const
TensorValueDumper
&
tensor_value_dumper_
=
{})
const
TensorValueDumper
&
tensor_value_dumper_
=
{})
:
keep_var_name
{
keep_var_name_
},
:
keep_var_name
{
keep_var_name_
},
keep_param_name
{
keep_param_name_
},
keep_param_name
{
keep_param_name_
},
keep_opr_priority
{
keep_opr_priority_
},
keep_opr_priority
{
keep_opr_priority_
},
keep_op_name
{
keep_op_name_
},
user_data
{
user_data_
},
user_data
{
user_data_
},
tensor_value_dumper
{
tensor_value_dumper_
}
{}
tensor_value_dumper
{
tensor_value_dumper_
}
{}
};
};
...
...
src/serialization/test/serializer_oss.cpp
浏览文件 @
09de5a07
...
@@ -711,6 +711,39 @@ TEST(TestSerializer2, ParamerizedDType) {
...
@@ -711,6 +711,39 @@ TEST(TestSerializer2, ParamerizedDType) {
load
();
load
();
}
}
TEST
(
TestSerializer2
,
OperatorName
)
{
auto
fname
=
GET_OUTPUT_FILE
();
TensorShape
shape
{
2
,
3
};
auto
dump
=
[
&
]()
{
auto
cn
=
CompNode
::
load
(
"xpu0"
);
auto
host_x
=
std
::
make_shared
<
HostTensorND
>
(
cn
,
shape
),
host_y
=
std
::
make_shared
<
HostTensorND
>
(
cn
,
shape
);
auto
graph
=
ComputingGraph
::
make
();
auto
x
=
opr
::
Host2DeviceCopy
::
make
(
*
graph
,
host_x
,
{
"x"
}),
y
=
opr
::
Host2DeviceCopy
::
make
(
*
graph
,
host_y
,
{
"y"
});
using
Mode
=
opr
::
Elemwise
::
Mode
;
auto
z
=
opr
::
Elemwise
::
make
({
x
,
y
},
Mode
::
ADD
,
{
"add(x, y)"
});
auto
dumper
=
GraphDumper
::
make
(
OutputFile
::
make_fs
(
fname
.
c_str
()),
GraphDumpFormat
::
FLATBUFFERS
);
auto
rst
=
dumper
->
dump
({
z
.
rename
(
"z"
)});
};
auto
load
=
[
&
]()
{
HostTensorGenerator
<>
gen
;
auto
loader
=
GraphLoader
::
make
(
InputFile
::
make_fs
(
fname
.
c_str
()),
GraphDumpFormat
::
FLATBUFFERS
);
auto
rst
=
loader
->
load
();
auto
z
=
rst
.
output_var_map
.
at
(
"z"
);
auto
op_name
=
z
.
node
()
->
owner_opr
()
->
cname
();
int
cmp
=
strcmp
(
op_name
,
"add(x, y)"
);
EXPECT_EQ
(
cmp
,
0
);
};
dump
();
load
();
}
TEST
(
TestSerializer2
,
HasOutputDtype
)
{
TEST
(
TestSerializer2
,
HasOutputDtype
)
{
auto
fname
=
GET_OUTPUT_FILE
();
auto
fname
=
GET_OUTPUT_FILE
();
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录