Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
8118a594
MegEngine
项目概览
MegEngine 天元
/
MegEngine
9 个月 前同步成功
通知
392
Star
4702
Fork
582
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
前往新版Gitcode,体验更适合开发者的 AI 搜索 >>
提交
8118a594
编写于
11月 13, 2020
作者:
M
Megvii Engine Team
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix(mge/utils): fix get_oprs_seq of cgtools
GitOrigin-RevId: 366a56f4d5b7b607d4f14b9102c387392a3a5936
上级
ae8c3c81
变更
4
隐藏空白更改
内联
并排
Showing
4 changed file
with
61 addition
and
7 deletion
+61
-7
imperative/python/megengine/__init__.py
imperative/python/megengine/__init__.py
+0
-1
imperative/python/megengine/utils/comp_graph_tools.py
imperative/python/megengine/utils/comp_graph_tools.py
+22
-4
imperative/python/test/unit/test_cgtools.py
imperative/python/test/unit/test_cgtools.py
+37
-1
imperative/python/test/unit/test_tracing.py
imperative/python/test/unit/test_tracing.py
+2
-1
未找到文件。
imperative/python/megengine/__init__.py
浏览文件 @
8118a594
...
...
@@ -78,7 +78,6 @@ from .logger import enable_debug_log, get_logger, set_log_file, set_log_level
from
.serialization
import
load
,
save
from
.tensor
import
Parameter
,
Tensor
,
tensor
from
.version
import
__version__
from
.utils
import
comp_graph_tools
as
cgtools
_set_fork_exec_path_for_timed_func
(
sys
.
executable
,
...
...
imperative/python/megengine/utils/comp_graph_tools.py
浏览文件 @
8118a594
...
...
@@ -15,6 +15,19 @@ from ..core._imperative_rt import OperatorNode, VarNode
from
..core.tensor
import
megbrain_graph
as
G
from
..core.tensor.raw_tensor
import
as_raw_tensor
__all__
=
[
"get_dep_vars"
,
"get_owner_opr_inputs"
,
"get_owner_opr_type"
,
"get_opr_type"
,
"graph_traversal"
,
"get_oprs_seq"
,
"replace_vars"
,
"replace_oprs"
,
"set_priority_to_id"
,
"load_and_inference"
,
]
def
get_dep_vars
(
var
:
VarNode
,
var_type
:
str
=
None
)
->
List
[
VarNode
]:
"""
...
...
@@ -166,7 +179,7 @@ def get_oprs_seq(outputs: List[VarNode], prune_reshape=False) -> List[OperatorNo
# reshape op definition: reshape(input_tensor, dest_shape) -> output_tensor
# when inferencing, shape of output_tensor is already known, so one can prune some operators related to dest_shape in the loaded graph
def
prune_reshape_oprs
(
outputs
,
oprs_seq
,
var2oprs
):
def
iterative_pruning
(
cur_opr
,
post_opr
,
marked_opr_ids
):
def
iterative_pruning
(
cur_opr
,
post_opr
,
marked_opr_ids
,
visited
):
useless
=
True
for
oup
in
cur_opr
.
outputs
:
if
"workspace"
not
in
oup
.
name
:
...
...
@@ -177,15 +190,20 @@ def get_oprs_seq(outputs: List[VarNode], prune_reshape=False) -> List[OperatorNo
if
useless
:
marked_opr_ids
.
append
(
cur_opr
.
id
)
for
inp
in
cur_opr
.
inputs
:
iterative_pruning
(
inp
.
owner
,
cur_opr
,
marked_opr_ids
)
for
opr
in
set
([
var
.
owner
for
var
in
cur_opr
.
inputs
]):
if
(
opr
.
id
,
cur_opr
.
id
)
not
in
visited
:
visited
.
add
((
opr
.
id
,
cur_opr
.
id
))
iterative_pruning
(
opr
,
cur_opr
,
marked_opr_ids
,
visited
)
reshape_vars
=
get_dep_vars
(
outputs
,
"Reshape"
)
reshape_oprs
=
[
var
.
owner
for
var
in
reshape_vars
]
marked_opr_ids
=
[]
visited
=
set
()
for
reshape_opr
in
reshape_oprs
:
iterative_pruning
(
reshape_opr
.
inputs
[
1
].
owner
,
reshape_opr
,
marked_opr_ids
)
iterative_pruning
(
reshape_opr
.
inputs
[
1
].
owner
,
reshape_opr
,
marked_opr_ids
,
visited
)
# filter out all marked oprs
return
list
(
filter
(
lambda
x
:
x
.
id
not
in
marked_opr_ids
,
oprs_seq
))
...
...
imperative/python/test/unit/test_cgtools.py
浏览文件 @
8118a594
...
...
@@ -13,9 +13,10 @@ import pytest
import
megengine
import
megengine.functional
as
F
import
megengine.module
as
M
from
megengine
import
cgtools
import
megengine.utils.comp_graph_tools
as
cgtools
from
megengine.core.tensor
import
megbrain_graph
as
mgb_graph
from
megengine.core.tensor.raw_tensor
import
as_raw_tensor
from
megengine.core.tensor.utils
import
astensor1d
from
megengine.jit
import
trace
...
...
@@ -98,3 +99,38 @@ def test_load_refcnt():
graph
,
_
,
(
varnode
,)
=
mgb_graph
.
load_graph
(
io
.
BytesIO
(
buf
))
del
graph
varnode
.
owner
def
test_get_opr_seq
():
class
Net
(
M
.
Module
):
def
__init__
(
self
):
super
().
__init__
()
self
.
data
=
megengine
.
tensor
(
np
.
random
.
random
((
1
,
1
,
4
,
4
)),
dtype
=
np
.
float32
)
def
forward
(
self
,
input
):
A
=
input
.
shape
[
0
]
shape
=
astensor1d
((
A
,
A
),
self
.
data
,
dtype
=
"int32"
,
device
=
input
.
device
)
x
=
F
.
reshape
(
self
.
data
,
shape
)
o
=
input
+
x
return
o
net
=
Net
()
input
=
megengine
.
tensor
(
np
.
random
.
random
((
4
,
4
)),
dtype
=
np
.
float32
)
@
trace
(
symbolic
=
True
,
capture_as_const
=
True
)
def
func
(
inp
,
*
,
net
=
None
):
return
net
(
inp
)
func
(
input
,
net
=
net
)
file
=
io
.
BytesIO
()
func
.
dump
(
file
,
optimize_for_inference
=
False
)
file
.
seek
(
0
)
*
_
,
outputs
=
mgb_graph
.
load_graph
(
file
)
seq_1
=
cgtools
.
get_oprs_seq
(
outputs
,
True
)
assert
len
(
seq_1
)
==
5
seq_2
=
cgtools
.
get_oprs_seq
(
outputs
,
False
)
assert
len
(
seq_2
)
==
6
imperative/python/test/unit/test_tracing.py
浏览文件 @
8118a594
...
...
@@ -14,7 +14,8 @@ import pytest
import
megengine.core.tensor.megbrain_graph
as
G
import
megengine.functional
as
F
from
megengine
import
cgtools
,
tensor
import
megengine.utils.comp_graph_tools
as
cgtools
from
megengine
import
tensor
from
megengine.core._trace_option
import
set_symbolic_shape
from
megengine.core.ops
import
builtin
as
ops
from
megengine.core.ops.builtin
import
Elemwise
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录