Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
c7a8d945
MegEngine
项目概览
MegEngine 天元
/
MegEngine
大约 1 年 前同步成功
通知
399
Star
4705
Fork
582
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
提交
c7a8d945
编写于
9月 16, 2021
作者:
M
Megvii Engine Team
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix(mge/traced_module): let graph record total_id
GitOrigin-RevId: f99178f3ac45b2fd12828fc65ef995ee369d308c
上级
8b40f577
变更
3
隐藏空白更改
内联
并排
Showing
3 changed file
with
64 addition
and
26 deletion
+64
-26
imperative/python/megengine/traced_module/expr.py
imperative/python/megengine/traced_module/expr.py
+9
-0
imperative/python/megengine/traced_module/node.py
imperative/python/megengine/traced_module/node.py
+9
-4
imperative/python/megengine/traced_module/traced_module.py
imperative/python/megengine/traced_module/traced_module.py
+46
-22
未找到文件。
imperative/python/megengine/traced_module/expr.py
浏览文件 @
c7a8d945
...
...
@@ -167,6 +167,15 @@ class Expr:
state
.
pop
(
"_top_graph"
)
return
state
@
classmethod
def
get_total_id
(
cls
):
return
cls
.
__total_id
@
classmethod
def
set_total_id
(
cls
,
id
:
int
=
0
):
assert
isinstance
(
id
,
int
)
cls
.
__total_id
=
id
# expr: None (i.e. fake expression which is used to mark input)
class
Input
(
Expr
):
...
...
imperative/python/megengine/traced_module/node.py
浏览文件 @
c7a8d945
...
...
@@ -42,10 +42,6 @@ class Node:
self
.
_orig_name
=
orig_name
self
.
actual_node
=
[]
# type: List[Node]
def
__setstate__
(
self
,
d
):
self
.
__dict__
=
d
Node
.
__total_id
=
max
(
Node
.
__total_id
,
self
.
_id
)
+
1
def
__repr__
(
self
):
format_spec
=
Node
.
_format_spec
return
self
.
__format__
(
format_spec
)
...
...
@@ -89,6 +85,15 @@ class Node:
cls
.
_format_spec
=
str
return
old_format_spec
@
classmethod
def
get_total_id
(
cls
):
return
cls
.
__total_id
@
classmethod
def
set_total_id
(
cls
,
id
:
int
=
0
):
assert
isinstance
(
id
,
int
)
cls
.
__total_id
=
id
class
ModuleNode
(
Node
):
r
"""``ModuleNode`` represents the Module objects."""
...
...
imperative/python/megengine/traced_module/traced_module.py
浏览文件 @
c7a8d945
...
...
@@ -247,6 +247,10 @@ def _init_id2name(mod: Module, prefix: str = ""):
class
_InsertExprs
:
def
__init__
(
self
,
graph
,
expr
:
Optional
[
Expr
]
=
None
):
self
.
graph
=
graph
while
graph
.
top_graph
is
not
None
:
graph
=
graph
.
top_graph
assert
graph
.
inputs
[
0
].
owner
.
_is_top
self
.
root_graph
=
graph
self
.
global_scope
=
InternalGraph
(
graph
.
_name
,
graph
.
_prefix_name
,
graph
.
_module_name
)
...
...
@@ -256,6 +260,9 @@ class _InsertExprs:
def
__enter__
(
self
):
self
.
use_sym_shape
=
set_symbolic_shape
(
True
)
node_id
,
expr_id
=
self
.
root_graph
.
_total_ids
Node
.
set_total_id
(
node_id
)
Expr
.
set_total_id
(
expr_id
)
set_module_tracing
()
_set_convert_node_flag
(
True
)
assert
active_module_tracer
()
is
None
...
...
@@ -334,10 +341,8 @@ class _InsertExprs:
insert_index
+=
1
self
.
graph
.
_used_names
.
update
(
self
.
global_scope
.
_used_names
)
graph
=
self
.
graph
while
graph
.
top_graph
is
not
None
:
graph
=
graph
.
top_graph
graph
.
inputs
[
0
].
owner
.
_update_ref
()
self
.
root_graph
.
_total_ids
=
(
Node
.
get_total_id
(),
Expr
.
get_total_id
())
self
.
root_graph
.
inputs
[
0
].
owner
.
_update_ref
()
return
True
...
...
@@ -353,7 +358,8 @@ class InternalGraph:
_exprs
=
None
# type: List[Expr]
_inputs
=
None
# type: List[Node]
_outputs
=
None
# type: List[Node]
_top_graph
=
None
_top_graph
=
None
# type: InternalGraph
_total_ids
=
None
# type: List[int]
def
__init__
(
self
,
name
:
str
=
None
,
prefix_name
:
str
=
""
,
module_name
:
str
=
""
):
self
.
_exprs
=
[]
...
...
@@ -704,8 +710,12 @@ class InternalGraph:
def
replace_node
(
self
,
repl_dict
:
Dict
[
Node
,
Node
]):
while
repl_dict
:
node
,
repl_node
=
repl_dict
.
popitem
()
assert
type
(
node
)
==
type
(
repl_node
),
"The type of {}({}) and {}({}) are not the same"
.
format
(
node
,
type
(
node
).
__name__
,
repl_node
,
type
(
repl_node
).
__name__
)
# check graph inputs and outputs
# assert node not in self.inputs, "Cannot replace inputs"
for
i
,
n
in
enumerate
(
self
.
outputs
):
if
n
is
node
:
self
.
outputs
[
i
]
=
repl_node
...
...
@@ -713,7 +723,10 @@ class InternalGraph:
# update inputs of expr in node.users
graph
=
repl_node
.
top_graph
assert
graph
is
not
None
index
=
graph
.
_exprs
.
index
(
repl_node
.
expr
)
assert
graph
is
self
index
=
-
1
if
not
isinstance
(
repl_node
.
expr
,
Input
):
index
=
graph
.
_exprs
.
index
(
repl_node
.
expr
)
dep_exprs
=
self
.
get_dep_exprs
(
repl_node
)
i
=
0
while
i
<
len
(
node
.
users
):
...
...
@@ -745,6 +758,13 @@ class InternalGraph:
n
.
users
.
remove
(
expr
)
self
.
_exprs
.
remove
(
expr
)
def
_reset_ids
(
self
):
for
total_expr_id
,
expr
in
enumerate
(
self
.
exprs
()):
expr
.
_id
=
total_expr_id
for
total_node_id
,
node
in
enumerate
(
self
.
nodes
()):
node
.
_id
=
total_node_id
self
.
_total_ids
=
(
total_node_id
+
1
,
total_expr_id
+
1
)
def
interpret
(
self
,
*
inputs
):
node2value
=
{}
end_nodes_set
=
set
(
self
.
_end_point
)
...
...
@@ -989,6 +1009,8 @@ class TracedModuleBuilder(NodeMixin):
)
for
_
,
g
in
self
.
_argdef_graph_map
.
items
():
g
.
compile
()
if
self
.
_is_top
:
g
.
_total_ids
=
(
Node
.
get_total_id
(),
Expr
.
get_total_id
())
for
k
,
v
in
self
.
__dict__
.
items
():
if
k
not
in
TracedModuleBuilder
.
__builder_attributes__
:
...
...
@@ -1247,6 +1269,8 @@ class _expr_iter:
self
.
recursive
=
recursive
def
__iter__
(
self
):
for
inp_node
in
self
.
graph
.
inputs
:
yield
inp_node
.
expr
for
expr
in
self
.
graph
.
_exprs
:
if
isinstance
(
expr
,
CallMethod
)
and
isinstance
(
expr
.
inputs
[
0
],
ModuleNode
):
yield
expr
...
...
@@ -1262,10 +1286,10 @@ class _node_iter:
node_ids
=
set
()
for
expr
in
graph
.
exprs
(
recursive
):
for
n
in
expr
.
inputs
+
expr
.
outputs
:
if
n
.
_id
in
node_ids
:
if
id
(
n
)
in
node_ids
:
continue
nodes
.
append
(
n
)
node_ids
.
add
(
n
.
_id
)
node_ids
.
add
(
id
(
n
)
)
self
.
nodes
=
list
(
sorted
(
nodes
,
key
=
lambda
x
:
x
.
_id
))
def
__iter__
(
self
):
...
...
@@ -1546,6 +1570,7 @@ class TracedModule(Module):
active_module_tracer
().
push_scope
(
new_module
.
graph
)
def
_flatten_subgraph
(
parent_graph
:
InternalGraph
,
graph
:
InternalGraph
,
module
:
Module
,
call
=
None
,
...
...
@@ -1590,7 +1615,10 @@ class TracedModule(Module):
if
inp
is
call_out
:
expr
.
inputs
[
index
]
=
repl_dict
[
out
]
repl_dict
[
out
].
users
.
append
(
expr
)
if
parent_graph
is
not
None
:
for
index
,
parent_out
in
enumerate
(
parent_graph
.
_outputs
):
if
parent_out
is
call_out
:
parent_graph
.
_outputs
[
index
]
=
repl_dict
[
out
]
continue
repl_dict
[
out
]
=
call
.
outputs
[
ind
]
...
...
@@ -1622,6 +1650,7 @@ class TracedModule(Module):
)
exprs
.
extend
(
_flatten_subgraph
(
graph
,
expr_graph
,
obj
,
expr
,
...
...
@@ -1643,19 +1672,10 @@ class TracedModule(Module):
i
.
users
.
remove
(
call
)
return
exprs
new_module
.
graph
.
_exprs
=
_flatten_subgraph
(
new_module
.
graph
,
new_module
)
new_module
.
graph
.
_exprs
=
_flatten_subgraph
(
None
,
new_module
.
graph
,
new_module
)
new_module
.
graph
.
compile
()
set_active_module_tracer
(
None
)
for
_id
,
expr
in
enumerate
(
new_module
.
graph
.
_exprs
):
expr
.
_id
=
_id
total_node_id
=
0
for
i
in
new_module
.
graph
.
_inputs
:
i
.
_id
=
total_node_id
total_node_id
+=
1
for
expr
in
new_module
.
graph
.
_exprs
:
for
o
in
expr
.
outputs
:
o
.
_id
=
total_node_id
total_node_id
+=
1
new_module
.
graph
.
_reset_ids
()
return
new_module
def
__getstate__
(
self
):
...
...
@@ -1735,6 +1755,8 @@ def trace_module(mod: Module, *args: Tensor, **kwargs: Tensor) -> TracedModule:
set_active_module_tracer
(
module_tracer
(
_wrapped_function
,
_init_id2name
(
mod
,
"self"
))
)
for
cls
in
[
Expr
,
Node
]:
cls
.
set_total_id
(
0
)
with
active_module_tracer
().
patcher
:
global_scope
=
InternalGraph
(
name
=
""
)
active_module_tracer
().
push_scope
(
global_scope
)
...
...
@@ -1750,7 +1772,9 @@ def trace_module(mod: Module, *args: Tensor, **kwargs: Tensor) -> TracedModule:
)
builder
(
*
args
,
**
kwargs
)
active_module_tracer
().
pop_scope
()
return
builder
.
build
()
traced_mod
=
builder
.
build
()
traced_mod
.
graph
.
_reset_ids
()
return
traced_mod
finally
:
set_symbolic_shape
(
use_sym_shape
)
set_active_module_tracer
(
None
)
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录