Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
2d42455f
MegEngine
项目概览
MegEngine 天元
/
MegEngine
1 年多 前同步成功
通知
404
Star
4705
Fork
582
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
提交
2d42455f
编写于
6月 01, 2021
作者:
M
Megvii Engine Team
提交者:
huangxinda
7月 19, 2021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix(mge/utils): fix toposort to get definition order
GitOrigin-RevId: 47a26dd6dda31d349f7439c64a72027e6a9a7391
上级
0c97b2a3
变更
6
显示空白变更内容
内联
并排
Showing
6 changed file
with
170 addition
and
27 deletion
+170
-27
imperative/python/megengine/jit/tracing.py
imperative/python/megengine/jit/tracing.py
+4
-0
imperative/python/megengine/utils/comp_graph_tools.py
imperative/python/megengine/utils/comp_graph_tools.py
+71
-7
imperative/python/megengine/utils/network.py
imperative/python/megengine/utils/network.py
+4
-3
imperative/python/megengine/utils/network_node.py
imperative/python/megengine/utils/network_node.py
+42
-16
imperative/python/src/graph_rt.cpp
imperative/python/src/graph_rt.cpp
+8
-1
imperative/python/test/unit/utils/test_cgtools.py
imperative/python/test/unit/utils/test_cgtools.py
+41
-0
未找到文件。
imperative/python/megengine/jit/tracing.py
浏览文件 @
2d42455f
...
@@ -893,6 +893,10 @@ class trace:
...
@@ -893,6 +893,10 @@ class trace:
if
isinstance
(
file
,
str
):
if
isinstance
(
file
,
str
):
permission
=
"wb"
if
append
==
False
else
"ab"
permission
=
"wb"
if
append
==
False
else
"ab"
file
=
open
(
file
,
permission
)
file
=
open
(
file
,
permission
)
if
keep_opr_priority
:
graph
.
_set_priority_to_id
(
dest_vars
)
dump_content
,
dump_info
=
G
.
dump_graph
(
dump_content
,
dump_info
=
G
.
dump_graph
(
dest_vars
,
dest_vars
,
keep_var_name
=
keep_var_name
,
keep_var_name
=
keep_var_name
,
...
...
imperative/python/megengine/utils/comp_graph_tools.py
浏览文件 @
2d42455f
...
@@ -6,6 +6,7 @@
...
@@ -6,6 +6,7 @@
# software distributed under the License is distributed on an
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import
collections
import
collections
import
heapq
from
collections
import
OrderedDict
from
collections
import
OrderedDict
from
typing
import
Dict
,
List
,
Tuple
,
Union
from
typing
import
Dict
,
List
,
Tuple
,
Union
...
@@ -88,6 +89,41 @@ def get_opr_type(opr: _OpNode) -> str:
...
@@ -88,6 +89,41 @@ def get_opr_type(opr: _OpNode) -> str:
return
opr
.
type
return
opr
.
type
class
_OprStableOrderHeapq
:
"""heap implementation for operator comparison in stable order"""
_list
=
None
_extra_priority
=
None
_used_id_name_pairs
=
None
def
__init__
(
self
,
extra_priority
):
assert
isinstance
(
extra_priority
,
collections
.
Callable
)
self
.
_list
=
[]
self
.
_extra_priority
=
extra_priority
self
.
_used_id_name_pairs
=
{}
def
pop_min
(
self
):
return
heapq
.
heappop
(
self
.
_list
)[
-
1
]
def
add
(
self
,
opr
):
# named as add to mimic set() interface
id_
=
opr
.
id
name
=
opr
.
name
other
=
self
.
_used_id_name_pairs
.
setdefault
((
id_
,
name
),
opr
)
if
other
is
not
opr
:
raise
RuntimeError
(
"duplicated (id, name) pair: opr0={} opr1={}"
.
format
(
other
,
opr
)
)
item
=
self
.
_extra_priority
(
opr
)
+
(
id_
,
name
,
opr
)
heapq
.
heappush
(
self
.
_list
,
item
)
def
__bool__
(
self
):
return
bool
(
self
.
_list
)
def
graph_traversal
(
outputs
:
_VarNode
):
def
graph_traversal
(
outputs
:
_VarNode
):
"""
"""
Helper function to traverse the computing graph and return enough useful information.
Helper function to traverse the computing graph and return enough useful information.
...
@@ -110,12 +146,13 @@ def graph_traversal(outputs: _VarNode):
...
@@ -110,12 +146,13 @@ def graph_traversal(outputs: _VarNode):
var2oprs
=
collections
.
defaultdict
(
list
)
var2oprs
=
collections
.
defaultdict
(
list
)
opr2receivers
=
collections
.
defaultdict
(
list
)
opr2receivers
=
collections
.
defaultdict
(
list
)
queue
=
[]
queue
=
list
(
set
(
map
(
lambda
x
:
x
.
owner
,
outputs
)))
[
queue
.
append
(
o
)
for
o
in
[
x
.
owner
for
x
in
outputs
]
if
o
not
in
queue
]
visited
=
set
(
map
(
lambda
x
:
x
.
id
,
queue
))
visited
=
set
(
map
(
lambda
x
:
x
.
id
,
queue
))
# iterate through whole comp_graph, fill in meta information
# iterate through whole comp_graph, fill in meta information
indegree2opr
=
collections
.
defaultdict
(
set
)
indegree2opr
=
collections
.
defaultdict
(
set
)
indegree2opr
[
0
]
=
_OprStableOrderHeapq
(
lambda
op
:
(
op
.
priority
,))
opr2indegree
=
{}
opr2indegree
=
{}
idx
=
0
idx
=
0
...
@@ -138,8 +175,8 @@ def graph_traversal(outputs: _VarNode):
...
@@ -138,8 +175,8 @@ def graph_traversal(outputs: _VarNode):
indegree
+=
1
indegree
+=
1
opr2receivers
[
pre_opr
.
id
].
append
(
cur_opr
.
id
)
opr2receivers
[
pre_opr
.
id
].
append
(
cur_opr
.
id
)
opr
=
cur_opr
if
indegree
==
0
else
cur_opr
.
id
indegree2opr
[
indegree
].
add
(
cur_opr
.
id
)
indegree2opr
[
indegree
].
add
(
opr
)
opr2indegree
[
cur_opr
.
id
]
=
indegree
opr2indegree
[
cur_opr
.
id
]
=
indegree
return
map_oprs
,
map_vars
,
var2oprs
,
opr2receivers
,
indegree2opr
,
opr2indegree
return
map_oprs
,
map_vars
,
var2oprs
,
opr2receivers
,
indegree2opr
,
opr2indegree
...
@@ -162,8 +199,8 @@ def get_oprs_seq(
...
@@ -162,8 +199,8 @@ def get_oprs_seq(
oprs_seq
=
[]
oprs_seq
=
[]
nr_remain
=
len
(
map_oprs
)
nr_remain
=
len
(
map_oprs
)
while
indegree2opr
[
0
]:
while
indegree2opr
[
0
]:
opr
_id
=
indegree2opr
[
0
].
pop
()
opr
=
indegree2opr
[
0
].
pop_min
()
opr
=
map_oprs
[
opr_id
]
opr
_id
=
opr
.
id
nr_remain
-=
1
nr_remain
-=
1
if
opr
.
type
!=
"ImmutableTensor"
or
not
prune_immtensor
:
if
opr
.
type
!=
"ImmutableTensor"
or
not
prune_immtensor
:
oprs_seq
.
append
(
opr
)
oprs_seq
.
append
(
opr
)
...
@@ -173,6 +210,9 @@ def get_oprs_seq(
...
@@ -173,6 +210,9 @@ def get_oprs_seq(
indegree2opr
[
indegree
].
remove
(
post_id
)
indegree2opr
[
indegree
].
remove
(
post_id
)
indegree
-=
1
indegree
-=
1
if
indegree
==
0
:
indegree2opr
[
indegree
].
add
(
map_oprs
[
post_id
])
else
:
indegree2opr
[
indegree
].
add
(
post_id
)
indegree2opr
[
indegree
].
add
(
post_id
)
opr2indegree
[
post_id
]
=
indegree
opr2indegree
[
post_id
]
=
indegree
...
@@ -213,10 +253,34 @@ def get_oprs_seq(
...
@@ -213,10 +253,34 @@ def get_oprs_seq(
# filter out all marked oprs
# filter out all marked oprs
return
list
(
filter
(
lambda
x
:
x
.
id
not
in
marked_opr_ids
,
oprs_seq
))
return
list
(
filter
(
lambda
x
:
x
.
id
not
in
marked_opr_ids
,
oprs_seq
))
# adjust the order of oprs, let param/data privoder oprs close to the oprs which use them as inputs.
def
reorder_oprs_seq
(
oprs
):
rst
=
[]
param_or_data_provider_oprs
=
[]
other_oprs
=
[]
for
o
in
oprs
:
if
o
.
type
in
[
"ImmutableTensor"
,
"Host2DeviceCopy"
]:
param_or_data_provider_oprs
.
append
(
o
)
else
:
other_oprs
.
append
(
o
)
for
o
in
other_oprs
:
for
inp
in
o
.
inputs
:
if
inp
.
owner
.
type
in
[
"ImmutableTensor"
,
"Host2DeviceCopy"
]:
if
inp
.
owner
in
param_or_data_provider_oprs
:
rst
.
append
(
inp
.
owner
)
param_or_data_provider_oprs
.
remove
(
inp
.
owner
)
rst
.
append
(
o
)
rst
=
rst
+
param_or_data_provider_oprs
assert
len
(
rst
)
==
len
(
oprs
)
return
rst
map_oprs
,
_
,
var2oprs
,
opr2receivers
,
indegree2opr
,
opr2indegree
=
graph_traversal
(
map_oprs
,
_
,
var2oprs
,
opr2receivers
,
indegree2opr
,
opr2indegree
=
graph_traversal
(
outputs
outputs
)
)
oprs_seq
=
topological_sort
(
map_oprs
,
opr2receivers
,
indegree2opr
,
opr2indegree
)
oprs_seq
=
topological_sort
(
map_oprs
,
opr2receivers
,
indegree2opr
,
opr2indegree
)
oprs_seq
=
reorder_oprs_seq
(
oprs_seq
)
if
prune_reshape
is
True
:
if
prune_reshape
is
True
:
oprs_seq
=
prune_reshape_oprs
(
outputs
,
oprs_seq
,
var2oprs
.
copy
())
oprs_seq
=
prune_reshape_oprs
(
outputs
,
oprs_seq
,
var2oprs
.
copy
())
return
oprs_seq
return
oprs_seq
...
...
imperative/python/megengine/utils/network.py
浏览文件 @
2d42455f
...
@@ -241,6 +241,7 @@ class Network:
...
@@ -241,6 +241,7 @@ class Network:
if
optimize_for_inference
:
if
optimize_for_inference
:
metadata
.
optimize_options
=
optimize_options
metadata
.
optimize_options
=
optimize_options
G
.
set_priority_to_id
([
o
.
_node
if
isinstance
(
o
,
G
.
VarNode
)
else
o
for
o
in
out
])
dump_content
,
_
=
G
.
dump_graph
(
dump_content
,
_
=
G
.
dump_graph
(
out
,
out
,
keep_var_name
=
keep_var_name
,
keep_var_name
=
keep_var_name
,
...
@@ -353,7 +354,7 @@ class Network:
...
@@ -353,7 +354,7 @@ class Network:
)
)
shp
[
0
]
=
batchsize
shp
[
0
]
=
batchsize
i
.
shape
=
tuple
(
shp
)
i
.
shape
=
tuple
(
shp
)
self
.
_compile
()
assert
prev_batchsize
is
not
None
,
"no data provider found"
assert
prev_batchsize
is
not
None
,
"no data provider found"
assert
not
blacklist
,
"unused items in blacklist: {}"
.
format
(
blacklist
)
assert
not
blacklist
,
"unused items in blacklist: {}"
.
format
(
blacklist
)
...
@@ -363,7 +364,6 @@ class Network:
...
@@ -363,7 +364,6 @@ class Network:
:param repl_dict: the map {old_var: new_var} that specifies how to replace the vars.
:param repl_dict: the map {old_var: new_var} that specifies how to replace the vars.
"""
"""
if
not
all
([
var
.
owner
for
var
in
repl_dict
.
values
()]):
if
not
all
([
var
.
owner
for
var
in
repl_dict
.
values
()]):
print
(
repl_dict
.
values
())
self
.
add_dep_oprs
(
*
list
(
repl_dict
.
values
()))
self
.
add_dep_oprs
(
*
list
(
repl_dict
.
values
()))
for
var
in
self
.
all_vars
:
for
var
in
self
.
all_vars
:
if
var
in
repl_dict
:
if
var
in
repl_dict
:
...
@@ -373,6 +373,7 @@ class Network:
...
@@ -373,6 +373,7 @@ class Network:
owner
.
outputs
[
idx
]
=
var
owner
.
outputs
[
idx
]
=
var
var
.
__dict__
.
update
(
repl_var
.
__dict__
)
var
.
__dict__
.
update
(
repl_var
.
__dict__
)
var
.
var
=
repl_var
.
var
var
.
var
=
repl_var
.
var
self
.
_compile
()
def
replace_oprs
(
self
,
repl_dict
:
Dict
[
OpNode
,
OpNode
]):
def
replace_oprs
(
self
,
repl_dict
:
Dict
[
OpNode
,
OpNode
]):
"""
"""
...
@@ -384,11 +385,11 @@ class Network:
...
@@ -384,11 +385,11 @@ class Network:
assert
len
(
opr
.
outputs
)
==
len
(
assert
len
(
opr
.
outputs
)
==
len
(
repl_dict
[
opr
].
outputs
repl_dict
[
opr
].
outputs
),
"can not replace {} with {}"
.
format
(
type
(
opr
),
type
(
repl_dict
[
opr
]))
),
"can not replace {} with {}"
.
format
(
type
(
opr
),
type
(
repl_dict
[
opr
]))
repl_dict
[
opr
].
outputs
=
opr
.
outputs
for
ind
,
var
in
enumerate
(
opr
.
outputs
):
for
ind
,
var
in
enumerate
(
opr
.
outputs
):
var
.
owner
=
repl_dict
[
opr
]
var
.
owner
=
repl_dict
[
opr
]
var
.
__dict__
.
update
(
repl_dict
[
opr
].
outputs
[
ind
].
__dict__
)
var
.
__dict__
.
update
(
repl_dict
[
opr
].
outputs
[
ind
].
__dict__
)
var
.
var
=
repl_dict
[
opr
].
outputs
[
ind
].
var
var
.
var
=
repl_dict
[
opr
].
outputs
[
ind
].
var
self
.
_compile
()
def
get_opr_by_type
(
self
,
oprcls
,
unique
=
True
):
def
get_opr_by_type
(
self
,
oprcls
,
unique
=
True
):
assert
issubclass
(
oprcls
,
OpNode
)
assert
issubclass
(
oprcls
,
OpNode
)
...
...
imperative/python/megengine/utils/network_node.py
浏览文件 @
2d42455f
...
@@ -90,6 +90,10 @@ class VarNode(NetworkNode, SymbolVar, ArrayMethodMixin, metaclass=VarNodeMeta):
...
@@ -90,6 +90,10 @@ class VarNode(NetworkNode, SymbolVar, ArrayMethodMixin, metaclass=VarNodeMeta):
def
dtype
(
self
):
def
dtype
(
self
):
return
self
.
var
.
dtype
if
self
.
var
else
None
return
self
.
var
.
dtype
if
self
.
var
else
None
@
property
def
ndim
(
self
):
return
super
().
ndim
def
__bool__
(
self
):
def
__bool__
(
self
):
return
False
return
False
...
@@ -134,7 +138,18 @@ class OpNode(NetworkNode):
...
@@ -134,7 +138,18 @@ class OpNode(NetworkNode):
self
.
outputs
=
[]
self
.
outputs
=
[]
self
.
params
=
{}
self
.
params
=
{}
self
.
_opr
=
None
# mgb opnode
self
.
_opr
=
None
# mgb opnode
self
.
id
=
id
(
self
)
@
property
def
id
(
self
):
if
self
.
_opr
is
not
None
:
return
self
.
_opr
.
id
return
id
(
self
)
@
property
def
priority
(
self
):
if
self
.
_opr
is
not
None
:
return
self
.
_opr
.
priority
return
0
@
classmethod
@
classmethod
def
load
(
cls
,
opr
):
def
load
(
cls
,
opr
):
...
@@ -144,7 +159,12 @@ class OpNode(NetworkNode):
...
@@ -144,7 +159,12 @@ class OpNode(NetworkNode):
obj
.
_opr
=
opr
obj
.
_opr
=
opr
return
obj
return
obj
def
compile
(
self
,
graph
=
None
):
def
compile
(
self
):
if
(
self
.
_opr
is
None
or
len
(
self
.
_opr
.
inputs
)
!=
len
(
self
.
inputs
)
or
any
([
i
!=
j
.
var
for
i
,
j
in
zip
(
self
.
_opr
.
inputs
,
self
.
inputs
)])
):
op
=
self
.
opdef
(
**
self
.
params
)
op
=
self
.
opdef
(
**
self
.
params
)
args
=
[
i
.
var
for
i
in
self
.
inputs
]
args
=
[
i
.
var
for
i
in
self
.
inputs
]
outputs
=
rt
.
invoke_op
(
op
,
args
)
outputs
=
rt
.
invoke_op
(
op
,
args
)
...
@@ -197,6 +217,12 @@ class Host2DeviceCopy(OpNode):
...
@@ -197,6 +217,12 @@ class Host2DeviceCopy(OpNode):
return
self
return
self
def
compile
(
self
,
graph
):
def
compile
(
self
,
graph
):
if
(
self
.
_opr
is
None
or
self
.
_opr
.
outputs
[
0
].
comp_node
!=
self
.
device
or
self
.
_opr
.
outputs
[
0
].
shape
!=
self
.
shape
or
self
.
_opr
.
outputs
[
0
].
dtype
!=
self
.
dtype
):
outputs
=
rt
.
make_h2d
(
graph
,
self
.
device
,
self
.
dtype
,
self
.
shape
,
self
.
name
)
outputs
=
rt
.
make_h2d
(
graph
,
self
.
device
,
self
.
dtype
,
self
.
shape
,
self
.
name
)
self
.
_opr
=
outputs
.
owner
self
.
_opr
=
outputs
.
owner
if
len
(
self
.
outputs
)
==
0
:
if
len
(
self
.
outputs
)
==
0
:
...
...
imperative/python/src/graph_rt.cpp
浏览文件 @
2d42455f
...
@@ -192,6 +192,13 @@ void init_graph_rt(py::module m) {
...
@@ -192,6 +192,13 @@ void init_graph_rt(py::module m) {
})
})
.
def
(
"__repr__"
,
[](
cg
::
OperatorNodeBase
*
opr
){
.
def
(
"__repr__"
,
[](
cg
::
OperatorNodeBase
*
opr
){
return
"Opr:"
+
opr
->
name
();
return
"Opr:"
+
opr
->
name
();
})
.
def_property
(
"priority"
,
[](
cg
::
OperatorNodeBase
*
opr
)
{
return
opr
->
node_prop
().
attribute
().
priority
;
},
[](
cg
::
OperatorNodeBase
*
opr
,
int
priority
)
{
opr
->
node_prop
().
attribute
().
priority
=
priority
;
});
});
py
::
class_
<
cg
::
AsyncExecutable
>
(
m
,
"AsyncExecutable"
)
py
::
class_
<
cg
::
AsyncExecutable
>
(
m
,
"AsyncExecutable"
)
...
...
imperative/python/test/unit/utils/test_cgtools.py
浏览文件 @
2d42455f
...
@@ -19,6 +19,7 @@ from megengine.core.tensor import megbrain_graph as mgb_graph
...
@@ -19,6 +19,7 @@ from megengine.core.tensor import megbrain_graph as mgb_graph
from
megengine.core.tensor.megbrain_graph
import
apply_normal_varnode
from
megengine.core.tensor.megbrain_graph
import
apply_normal_varnode
from
megengine.core.tensor.utils
import
astensor1d
from
megengine.core.tensor.utils
import
astensor1d
from
megengine.jit
import
trace
from
megengine.jit
import
trace
from
megengine.utils.network
import
Network
def
make_dev_tensor
(
value
,
dtype
=
None
,
device
=
None
):
def
make_dev_tensor
(
value
,
dtype
=
None
,
device
=
None
):
...
@@ -143,6 +144,46 @@ def test_get_opr_seq():
...
@@ -143,6 +144,46 @@ def test_get_opr_seq():
assert
len
(
seq_2
)
==
6
assert
len
(
seq_2
)
==
6
def
test_topological_sort
():
@
trace
(
symbolic
=
True
,
capture_as_const
=
True
)
def
func
(
x
,
y
):
a
=
x
+
y
a1
=
F
.
relu
(
a
)
a2
=
F
.
abs
(
a
)
a3
=
F
.
ceil
(
a
)
*
2
a4
=
F
.
floor
(
a
)
r
=
a1
-
a2
r1
=
a3
/
a4
return
r
,
r1
file
=
io
.
BytesIO
()
func
(
megengine
.
tensor
(
1.0
),
megengine
.
tensor
(
2.0
))
func
.
dump
(
file
,
optimize_for_inference
=
False
,
keep_opr_name
=
True
,
keep_opr_priority
=
True
)
file
.
seek
(
0
)
g
=
Network
.
load
(
file
)
oprseq1
=
g
.
all_oprs
gt
=
[
"Host2DeviceCopy"
,
"Host2DeviceCopy"
,
"ADD"
,
"RELU"
,
"ABS"
,
"CEIL"
,
"ImmutableTensor"
,
"MUL"
,
"FLOOR"
,
"SUB"
,
"TRUE_DIV"
,
]
for
op
,
mode
in
zip
(
oprseq1
,
gt
):
if
op
.
type
==
"Elemwise"
:
assert
op
.
params
[
"mode"
]
==
mode
else
:
assert
op
.
type
==
mode
def
test_graph_function
():
def
test_graph_function
():
class
Net
(
M
.
Module
):
class
Net
(
M
.
Module
):
def
forward
(
self
,
a
,
b
):
def
forward
(
self
,
a
,
b
):
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录