Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
90107b6d
MegEngine
项目概览
MegEngine 天元
/
MegEngine
1 年多 前同步成功
通知
404
Star
4705
Fork
582
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
提交
90107b6d
编写于
5月 11, 2020
作者:
M
Megvii Engine Team
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
feat(mge/cgtools): add network vistior interface with optional pruning
GitOrigin-RevId: cfa69e3e83ecbf32d5c4827b4562dc8f65b5d674
上级
270b7488
变更
3
隐藏空白更改
内联
并排
Showing
3 changed file
with
187 addition
and
0 deletion
+187
-0
python_module/megengine/_internal/comp_graph_tools.py
python_module/megengine/_internal/comp_graph_tools.py
+139
-0
python_module/src/swig/comp_graph_tools.i
python_module/src/swig/comp_graph_tools.i
+4
-0
python_module/test/unit/jit/test_jit.py
python_module/test/unit/jit/test_jit.py
+44
-0
未找到文件。
python_module/megengine/_internal/comp_graph_tools.py
浏览文件 @
90107b6d
...
@@ -67,6 +67,145 @@ def get_type(var):
...
@@ -67,6 +67,145 @@ def get_type(var):
return
_mgb
.
_get_owner_opr_type
(
var
)
return
_mgb
.
_get_owner_opr_type
(
var
)
def
get_opr_type
(
opr
):
"""get the type of a opr
:type var: :class:`.Operator`
:rtype: ``str``
"""
assert
isinstance
(
opr
,
_mgb
.
Operator
)
return
_mgb
.
_get_opr_type
(
opr
)
def
graph_traversal
(
outputs
):
"""helper function to traverse the computing graph and reeturn enough useful information
:param outputs: model outputs
:type outputs: :class:`.Symbolvar`
:return: tuple (map_oprs, map_vars, var2oprs, opr2receivers, indegree2opr, opr2indegree)
WHERE
map_oprs is dict from opr_id to actual opr
map_vars is dict from var_id to actual var
var2oprs is dict from var to dest oprs along with index
opr2receivers is dict from current opr to next opr
indegree2opr is dict from in_degree to opr in computing graph
opr2indegree is dict from opr in computing graph to in_degree
(indegree2opr, opr2indegree) are only used in topological sort in get_oprs_seq function
"""
# meta information for comp graph
map_oprs
=
collections
.
defaultdict
(
set
)
map_vars
=
collections
.
defaultdict
(
set
)
var2oprs
=
collections
.
defaultdict
(
list
)
opr2receivers
=
collections
.
defaultdict
(
list
)
queue
=
list
(
map
(
lambda
x
:
x
.
owner_opr
,
outputs
))
visited
=
set
(
map
(
lambda
x
:
x
.
id
,
queue
))
# iterate through whole comp_graph, fill in meta information
indegree2opr
=
collections
.
defaultdict
(
set
)
opr2indegree
=
{}
idx
=
0
while
idx
<
len
(
queue
):
cur_opr
=
queue
[
idx
]
map_oprs
[
cur_opr
.
id
]
=
cur_opr
idx
+=
1
indegree
=
0
for
var_idx
,
var
in
enumerate
(
cur_opr
.
inputs
):
map_vars
[
var
.
id
]
=
var
var2oprs
[
var
.
id
].
append
((
cur_opr
.
id
,
var_idx
))
pre_opr
=
var
.
owner_opr
if
pre_opr
.
id
not
in
visited
:
visited
.
add
(
pre_opr
.
id
)
queue
.
append
(
pre_opr
)
indegree
+=
1
opr2receivers
[
pre_opr
.
id
].
append
(
cur_opr
.
id
)
indegree2opr
[
indegree
].
add
(
cur_opr
.
id
)
opr2indegree
[
cur_opr
.
id
]
=
indegree
return
map_oprs
,
map_vars
,
var2oprs
,
opr2receivers
,
indegree2opr
,
opr2indegree
def
get_oprs_seq
(
outputs
,
prune_reshape
=
False
):
"""get oprs in some topological order for a dumped model
:param outputs: model outputs
:param prune_reshape: whether to prune the operators useless during inference
:return: opr list with some correct execution order
"""
def
topological_sort
(
map_oprs
,
opr2receivers
,
indegree2opr
,
opr2indegree
):
# generate an execution order with topological sort algorithm
oprs_seq
=
[]
nr_remain
=
len
(
map_oprs
)
while
indegree2opr
[
0
]:
opr_id
=
indegree2opr
[
0
].
pop
()
opr
=
map_oprs
[
opr_id
]
nr_remain
-=
1
# skip const value generation operator
if
get_opr_type
(
opr
)
!=
"ImmutableTensor"
:
oprs_seq
.
append
(
opr
)
for
post_id
in
opr2receivers
[
opr_id
]:
indegree
=
opr2indegree
[
post_id
]
indegree2opr
[
indegree
].
remove
(
post_id
)
indegree
-=
1
indegree2opr
[
indegree
].
add
(
post_id
)
opr2indegree
[
post_id
]
=
indegree
assert
nr_remain
==
0
,
"there are {} remaining nodes; cyclic graph?"
.
format
(
nr_remain
)
return
oprs_seq
# reshape op definition: reshape(input_tensor, dest_shape) -> output_tensor
# when inferencing, shape of output_tensor is already known, so one can prune some operators related to dest_shape in the loaded graph
def
prune_reshape_oprs
(
outputs
,
oprs_seq
,
var2oprs
):
def
iterative_pruning
(
cur_opr
,
post_opr
,
marked_opr_ids
):
useless
=
True
for
oup
in
cur_opr
.
outputs
:
if
"workspace"
not
in
oup
.
name
:
var_idx
=
post_opr
.
inputs
.
index
(
oup
)
var2oprs
[
oup
.
id
].
remove
((
post_opr
.
id
,
var_idx
))
useless
=
useless
and
(
len
(
var2oprs
[
oup
.
id
])
==
0
)
if
useless
:
marked_opr_ids
.
append
(
cur_opr
.
id
)
for
inp
in
cur_opr
.
inputs
:
iterative_pruning
(
inp
.
owner_opr
,
cur_opr
,
marked_opr_ids
)
reshape_vars
=
get_dep_vars
(
outputs
,
"Reshape"
)
reshape_oprs
=
[
var
.
owner_opr
for
var
in
reshape_vars
]
marked_opr_ids
=
[]
for
reshape_opr
in
reshape_oprs
:
iterative_pruning
(
reshape_opr
.
inputs
[
1
].
owner_opr
,
reshape_opr
,
marked_opr_ids
)
# filter out all marked oprs
return
list
(
filter
(
lambda
x
:
x
.
id
not
in
marked_opr_ids
,
oprs_seq
))
map_oprs
,
_
,
var2oprs
,
opr2receivers
,
indegree2opr
,
opr2indegree
=
graph_traversal
(
outputs
)
oprs_seq
=
topological_sort
(
map_oprs
,
opr2receivers
,
indegree2opr
,
opr2indegree
)
if
prune_reshape
is
True
:
oprs_seq
=
prune_reshape_oprs
(
outputs
,
oprs_seq
,
var2oprs
.
copy
())
return
oprs_seq
def
replace_vars
(
dst
,
varmap
):
def
replace_vars
(
dst
,
varmap
):
"""replace vars in the graph
"""replace vars in the graph
...
...
python_module/src/swig/comp_graph_tools.i
浏览文件 @
90107b6d
...
@@ -10,6 +10,10 @@
...
@@ -10,6 +10,10 @@
return
var
.
node
()
->
owner_opr
()
->
dyn_typeinfo
()
->
name
;
return
var
.
node
()
->
owner_opr
()
->
dyn_typeinfo
()
->
name
;
}
}
std
::
string
_get_opr_type
(
Operator
opr
)
{
return
opr
.
node
()
->
dyn_typeinfo
()
->
name
;
}
SymbolVarArray
_replace_vars
(
const
SymbolVarArray&
repl_src
,
SymbolVarArray
_replace_vars
(
const
SymbolVarArray&
repl_src
,
const
SymbolVarArray&
repl_dst
,
const
SymbolVarArray&
repl_dst
,
const
SymbolVarArray&
vars
)
{
const
SymbolVarArray&
vars
)
{
...
...
python_module/test/unit/jit/test_jit.py
浏览文件 @
90107b6d
...
@@ -15,6 +15,7 @@ import pytest
...
@@ -15,6 +15,7 @@ import pytest
import
megengine
as
mge
import
megengine
as
mge
import
megengine._internal
as
mgb
import
megengine._internal
as
mgb
import
megengine.functional
as
F
import
megengine.module
as
M
import
megengine.module
as
M
from
megengine
import
functional
as
F
from
megengine
import
functional
as
F
from
megengine
import
jit
,
tensor
from
megengine
import
jit
,
tensor
...
@@ -148,6 +149,49 @@ def test_dump_volatile():
...
@@ -148,6 +149,49 @@ def test_dump_volatile():
assert
mgb
.
cgtools
.
get_type
(
mgb
.
cgtools
.
get_inputs
(
out
)[
1
])
==
"SharedDeviceTensor"
assert
mgb
.
cgtools
.
get_type
(
mgb
.
cgtools
.
get_inputs
(
out
)[
1
])
==
"SharedDeviceTensor"
def
test_graph_traversal
():
net
=
M
.
Conv2d
(
3
,
4
,
3
,
1
,
1
,
groups
=
1
,
bias
=
False
)
net
.
eval
()
@
jit
.
trace
(
symbolic
=
True
)
def
fun
(
data
):
return
net
(
data
)
data
=
np
.
random
.
random
([
1
,
3
,
224
,
224
]).
astype
(
np
.
float32
)
fun
.
trace
(
data
)
with
mkstemp
()
as
out
:
fun
.
dump
(
out
)
*
_
,
outputs
=
mgb
.
load_comp_graph_from_file
(
out
)
_
,
map_vars
,
var2oprs
,
*
_
=
mgb
.
cgtools
.
graph_traversal
(
outputs
)
input_var
=
map_vars
[
1
]
_
,
var_idx
=
var2oprs
[
input_var
.
id
][
0
]
assert
var_idx
==
0
def
test_network_visitor
():
@
jit
.
trace
(
symbolic
=
True
)
def
f
(
x
):
# this line will produce shape_of, subtensor and concat op
# after pruning, they will be deleted
target_shape
=
(
x
.
shape
[
0
],
-
1
)
return
x
.
reshape
(
*
target_shape
)
f
.
trace
(
tensor
(
np
.
random
.
random
([
2
,
3
,
4
,
5
]).
astype
(
np
.
float32
)))
with
mkstemp
()
as
out
:
f
.
dump
(
out
)
*
_
,
outputs
=
mgb
.
load_comp_graph_from_file
(
out
)
all_oprs
=
mgb
.
cgtools
.
get_oprs_seq
(
outputs
)
pruned_oprs
=
mgb
.
cgtools
.
get_oprs_seq
(
outputs
,
prune_reshape
=
True
)
assert
len
(
all_oprs
)
==
len
(
pruned_oprs
)
+
3
def
test_shape_tracing
():
def
test_shape_tracing
():
for
symbolic
in
[
False
,
True
]:
for
symbolic
in
[
False
,
True
]:
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录