Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
91efd67d
MegEngine
项目概览
MegEngine 天元
/
MegEngine
10 个月 前同步成功
通知
392
Star
4702
Fork
582
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
前往新版Gitcode,体验更适合开发者的 AI 搜索 >>
提交
91efd67d
编写于
9月 14, 2020
作者:
M
Megvii Engine Team
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
refactor(mge/jit): change dump options, add test
GitOrigin-RevId: fbc0d51c2be1fd51aaea121f6afa48b25abf661a
上级
099ffeac
变更
5
隐藏空白更改
内联
并排
Showing
5 changed file
with
60 addition
and
66 deletion
+60
-66
imperative/python/megengine/core/tensor/megbrain_graph.py
imperative/python/megengine/core/tensor/megbrain_graph.py
+25
-26
imperative/python/megengine/jit/tracing.py
imperative/python/megengine/jit/tracing.py
+14
-2
imperative/python/src/graph_rt.cpp
imperative/python/src/graph_rt.cpp
+6
-0
imperative/python/test/unit/test_cgtools.py
imperative/python/test/unit/test_cgtools.py
+2
-3
imperative/python/test/unit/test_tracing.py
imperative/python/test/unit/test_tracing.py
+13
-35
未找到文件。
imperative/python/megengine/core/tensor/megbrain_graph.py
浏览文件 @
91efd67d
...
...
@@ -130,32 +130,31 @@ def optimize_for_inference(dest_vars, **kwargs):
inference)
"""
inference_options
=
GraphOptimizeOptions
()
if
optimize_for_inference
:
inference_optimize_layout_transform_map
=
{
"enable_hwcd4"
:
GraphOptimizeOptions
.
LayoutTransform
.
NHWCD4
,
"enable_nchw4"
:
GraphOptimizeOptions
.
LayoutTransform
.
NCHW4
,
"enable_nchw88"
:
GraphOptimizeOptions
.
LayoutTransform
.
NCHW88
,
"enable_nchw32"
:
GraphOptimizeOptions
.
LayoutTransform
.
NCHW32
,
"enable_nchw44"
:
GraphOptimizeOptions
.
LayoutTransform
.
NCHW44
,
"enable_nchw44_dot"
:
GraphOptimizeOptions
.
LayoutTransform
.
NCHW44_DOT
,
"enable_chwn4"
:
GraphOptimizeOptions
.
LayoutTransform
.
CHWN4
,
}
for
k
,
v
in
inference_optimize_layout_transform_map
.
items
():
if
kwargs
.
pop
(
k
,
False
):
inference_options
.
layout_transform
=
v
if
kwargs
.
pop
(
"enable_io16xc32"
,
False
):
inference_options
.
f16_io_f32_comp
=
True
if
kwargs
.
pop
(
"enable_ioc16"
,
False
):
inference_options
.
f16_io_comp
=
True
if
kwargs
.
pop
(
"enable_fuse_conv_bias_nonlinearity"
,
False
):
inference_options
.
fuse_conv_bias_nonlinearity
=
True
if
kwargs
.
pop
(
"enable_fuse_conv_bias_with_z"
,
False
):
inference_options
.
fuse_conv_bias_with_z
=
True
if
kwargs
:
raise
ValueError
(
"unknown options: %s"
%
list
(
kwargs
))
inference_optimize_layout_transform_map
=
{
"enable_hwcd4"
:
GraphOptimizeOptions
.
LayoutTransform
.
NHWCD4
,
"enable_nchw4"
:
GraphOptimizeOptions
.
LayoutTransform
.
NCHW4
,
"enable_nchw88"
:
GraphOptimizeOptions
.
LayoutTransform
.
NCHW88
,
"enable_nchw32"
:
GraphOptimizeOptions
.
LayoutTransform
.
NCHW32
,
"enable_nchw44"
:
GraphOptimizeOptions
.
LayoutTransform
.
NCHW44
,
"enable_nchw44_dot"
:
GraphOptimizeOptions
.
LayoutTransform
.
NCHW44_DOT
,
"enable_chwn4"
:
GraphOptimizeOptions
.
LayoutTransform
.
CHWN4
,
}
for
k
,
v
in
inference_optimize_layout_transform_map
.
items
():
if
kwargs
.
pop
(
k
,
False
):
inference_options
.
layout_transform
=
v
if
kwargs
.
pop
(
"enable_io16xc32"
,
False
):
inference_options
.
f16_io_f32_comp
=
True
if
kwargs
.
pop
(
"enable_ioc16"
,
False
):
inference_options
.
f16_io_comp
=
True
if
kwargs
.
pop
(
"enable_fuse_conv_bias_nonlinearity"
,
False
):
inference_options
.
fuse_conv_bias_nonlinearity
=
True
if
kwargs
.
pop
(
"enable_fuse_conv_bias_with_z"
,
False
):
inference_options
.
fuse_conv_bias_with_z
=
True
if
kwargs
:
raise
ValueError
(
"unknown options: %s"
%
list
(
kwargs
))
res_vars
=
_imperative_rt
.
optimize_for_inference
(
[
i
.
_node
for
i
in
dest_vars
],
inference_options
...
...
imperative/python/megengine/jit/tracing.py
浏览文件 @
91efd67d
...
...
@@ -458,7 +458,16 @@ class trace:
self
.
_process_outputs
(
outputs
)
return
outputs
def
dump
(
self
,
file
,
*
,
arg_names
=
None
,
output_names
=
None
,
append
=
False
,
**
kwargs
):
def
dump
(
self
,
file
,
*
,
arg_names
=
None
,
output_names
=
None
,
append
=
False
,
optimize_for_inference
=
True
,
**
kwargs
):
r
"""Serializes trace to file system.
:param file: output file, could be file object or filename.
...
...
@@ -467,6 +476,8 @@ class trace:
use the default name if not specified.
:param append: whether output is appended to ``file``.
Only works when ``file`` is str.
:param optimize_for_inference: enable optimizations;
all optimize options are skipped if this is False. Default: True
:Keyword Arguments:
...
...
@@ -572,7 +583,8 @@ class trace:
v
.
name
=
output_names
[
i
]
dest_vars
.
append
(
v
)
dest_vars
=
G
.
optimize_for_inference
(
dest_vars
,
**
kwargs
)
if
optimize_for_inference
:
dest_vars
=
G
.
optimize_for_inference
(
dest_vars
,
**
kwargs
)
if
isinstance
(
file
,
str
):
permission
=
"wb"
if
append
==
False
else
"ab"
...
...
imperative/python/src/graph_rt.cpp
浏览文件 @
91efd67d
...
...
@@ -155,6 +155,9 @@ void init_graph_rt(py::module m) {
})
.
def_property_readonly
(
"id"
,[](
cg
::
VarNode
*
v
){
return
(
v
->
id
());
})
.
def
(
"__repr__"
,
[](
cg
::
VarNode
*
v
)
{
return
"Var:"
+
v
->
name
();
});
py
::
class_
<
cg
::
OperatorNodeBase
,
GraphNodePtr
<
cg
::
OperatorNodeBase
>>
(
m
,
"OperatorNode"
)
...
...
@@ -175,6 +178,9 @@ void init_graph_rt(py::module m) {
})
.
def_property_readonly
(
"type"
,[](
cg
::
OperatorNodeBase
*
opr
){
return
opr
->
dyn_typeinfo
()
->
name
;
})
.
def
(
"__repr__"
,
[](
cg
::
OperatorNodeBase
*
opr
){
return
"Opr:"
+
opr
->
name
();
});
...
...
imperative/python/test/unit/test_cgtools.py
浏览文件 @
91efd67d
...
...
@@ -67,7 +67,6 @@ def test_replace_oprs():
np
.
testing
.
assert_equal
(
res
,
np
.
array
([
5.0
*
5.0
*
1.25
]))
@
pytest
.
mark
.
skip
(
reason
=
"Please check opr index"
)
def
test_graph_traversal
():
net
=
M
.
Conv2d
(
3
,
32
,
3
)
...
...
@@ -77,11 +76,11 @@ def test_graph_traversal():
return
x
data
=
np
.
random
.
random
([
1
,
3
,
224
,
224
]).
astype
(
np
.
float32
)
for
i
in
range
(
3
):
for
_
in
range
(
3
):
fun
(
megengine
.
tensor
(
data
))
file
=
io
.
BytesIO
()
fun
.
dump
(
file
)
fun
.
dump
(
file
,
optimize_for_inference
=
False
)
file
.
seek
(
0
)
cg
,
_
,
outputs
=
mgb_graph
.
load_graph
(
file
)
...
...
imperative/python/test/unit/test_tracing.py
浏览文件 @
91efd67d
...
...
@@ -13,7 +13,6 @@ import numpy as np
import
pytest
import
megengine
import
megengine.core.tensor.megbrain_graph
as
G
import
megengine.module
as
M
from
megengine
import
cgtools
,
tensor
from
megengine.core._trace_option
import
set_tensor_shape
...
...
@@ -150,7 +149,6 @@ def test_capture_dump():
np
.
testing
.
assert_equal
(
result
[
0
],
y
)
@
pytest
.
mark
.
skip
(
reason
=
"get MultipleDeviceTensorHolder instead of SharedDeviceTensor"
)
def
test_dump_volatile
():
p
=
as_raw_tensor
([
2
])
...
...
@@ -167,7 +165,7 @@ def test_dump_volatile():
np
.
testing
.
assert_equal
(
f
(
as_raw_tensor
(
x
)).
numpy
(),
y
)
file
=
io
.
BytesIO
()
f
.
dump
(
file
)
f
.
dump
(
file
,
optimize_for_inference
=
False
)
file
.
seek
(
0
)
cg
,
_
,
outputs
=
G
.
load_graph
(
file
)
(
out
,)
=
outputs
...
...
@@ -196,26 +194,7 @@ def test_trace_profiler():
assert
out
.
get
(
"profiler"
)
@
pytest
.
mark
.
skip
(
reason
=
"eq_to_unit failed in inplace.cpp"
)
def
test_goptions_div_zero
():
@
trace
(
symbolic
=
True
,
opt_level
=
0
)
def
f
(
x
):
return
x
/
x
@
trace
(
symbolic
=
True
,
opt_level
=
1
)
def
g
(
x
):
return
x
/
x
out
=
f
(
tensor
(
0.0
))
if
out
==
out
:
raise
ValueError
(
"actual result should be nan"
)
out
=
g
(
tensor
(
0.0
))
if
out
!=
out
:
raise
ValueError
(
"actual result should be 1"
)
@
pytest
.
mark
.
skip
(
reason
=
"cast to Elemwise failed in inplace.cpp"
)
@
pytest
.
mark
.
skip
(
reason
=
"could not disable opt_level"
)
def
test_goptions_log_exp
():
@
trace
(
symbolic
=
True
,
opt_level
=
0
,
capture_as_const
=
True
)
def
f
(
x
):
...
...
@@ -227,19 +206,19 @@ def test_goptions_log_exp():
f
(
tensor
(
1.0
))
_
,
out
=
mkstemp
()
f
.
dump
(
out
)
*
_
,
outputs
=
G
.
load_
comp_graph_from_file
(
out
)
f
.
dump
(
out
,
optimize_for_inference
=
False
)
*
_
,
outputs
=
G
.
load_
graph
(
out
)
oprs_1
=
cgtools
.
get_oprs_seq
(
outputs
)
g
(
tensor
(
1.0
))
g
.
dump
(
out
)
*
_
,
outputs
=
G
.
load_
comp_graph_from_file
(
out
)
g
.
dump
(
out
,
optimize_for_inference
=
False
)
*
_
,
outputs
=
G
.
load_
graph
(
out
)
oprs_2
=
cgtools
.
get_oprs_seq
(
outputs
)
assert
len
(
oprs_1
)
-
len
(
oprs_2
)
==
2
@
pytest
.
mark
.
skip
(
reason
=
"
need cgtools to check final oprs
"
)
@
pytest
.
mark
.
skip
(
reason
=
"
could not disable opt_level
"
)
def
test_goptions_log_sum_exp
():
@
trace
(
symbolic
=
True
,
opt_level
=
0
,
capture_as_const
=
True
)
def
f
(
x
,
y
):
...
...
@@ -251,19 +230,18 @@ def test_goptions_log_sum_exp():
f
(
tensor
(
1.0
),
tensor
(
2.0
))
_
,
out
=
mkstemp
()
f
.
dump
(
out
)
*
_
,
outputs
=
G
.
load_
comp_graph_from_file
(
out
)
f
.
dump
(
out
,
optimize_for_inference
=
False
)
*
_
,
outputs
=
G
.
load_
graph
(
out
)
oprs_1
=
cgtools
.
get_oprs_seq
(
outputs
)
g
(
tensor
(
1.0
),
tensor
(
2.0
))
g
.
dump
(
out
)
*
_
,
outputs
=
G
.
load_
comp_graph_from_file
(
out
)
g
.
dump
(
out
,
optimize_for_inference
=
False
)
*
_
,
outputs
=
G
.
load_
graph
(
out
)
oprs_2
=
cgtools
.
get_oprs_seq
(
outputs
)
assert
len
(
oprs_1
)
-
len
(
oprs_2
)
==
2
@
pytest
.
mark
.
skip
(
reason
=
"need cgtools to check computing input dtype"
)
def
test_optimize_for_inference
():
@
trace
(
symbolic
=
True
,
capture_as_const
=
True
)
def
f
(
x
):
...
...
@@ -271,9 +249,9 @@ def test_optimize_for_inference():
_
,
out
=
mkstemp
()
f
(
tensor
(
5.0
))
f
.
dump
(
out
,
optimize_for_inference
=
True
,
optimize_options
=
{
"enable_io16xc32"
:
True
}
)
f
.
dump
(
out
,
enable_io16xc32
=
True
)
res
=
G
.
load_
comp_graph_from_file
(
out
)
res
=
G
.
load_
graph
(
out
)
computing_input
=
res
.
output_vars_list
[
0
].
owner
.
inputs
[
0
]
assert
computing_input
.
dtype
==
np
.
float16
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录