Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
feea43bc
MegEngine
项目概览
MegEngine 天元
/
MegEngine
1 年多 前同步成功
通知
403
Star
4705
Fork
582
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
提交
feea43bc
编写于
10月 26, 2021
作者:
M
Megvii Engine Team
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix(mge/traced_module): associate name with node
GitOrigin-RevId: 8d9a59bade03b62aa8d4821dfb74cc828bf7312c
上级
a6fe7f7f
变更
3
隐藏空白更改
内联
并排
Showing
3 changed file
with
68 addition
and
21 deletion
+68
-21
imperative/python/megengine/traced_module/expr.py
imperative/python/megengine/traced_module/expr.py
+7
-1
imperative/python/megengine/traced_module/node.py
imperative/python/megengine/traced_module/node.py
+1
-2
imperative/python/megengine/traced_module/traced_module.py
imperative/python/megengine/traced_module/traced_module.py
+60
-18
未找到文件。
imperative/python/megengine/traced_module/expr.py
浏览文件 @
feea43bc
...
...
@@ -69,6 +69,10 @@ def is_apply_def(expr):
return
isinstance
(
expr
,
Apply
)
def
is_input
(
expr
):
return
isinstance
(
expr
,
Input
)
class
Expr
:
r
"""``Expr`` represents the operations (i.e. ``CallMethod``, ``CallFunction``, ``Apply``,
``GetAttr``, ``Input``, ``Constant``) on ``Node``.
...
...
@@ -215,9 +219,11 @@ class Input(Expr):
@
classmethod
def
make
(
cls
,
*
args
,
**
kwargs
):
assert
active_module_tracer
()
is
not
None
current_graph
=
active_module_tracer
().
current_scope
()
expr
=
cls
(
*
args
,
**
kwargs
)
out_node
=
expr
.
outputs
[
0
]
active_module_tracer
().
current_scope
().
_add_input
(
out_node
)
current_graph
.
_namespace
.
auto_naming_for_outputs
(
expr
)
current_graph
.
_add_input
(
out_node
)
return
expr
.
outputs
[
0
]
def
__repr__
(
self
):
...
...
imperative/python/megengine/traced_module/node.py
浏览文件 @
feea43bc
...
...
@@ -74,8 +74,7 @@ class Node:
"The name(%s) is already in use. Please try a different one again."
%
(
new_name
)
)
new_name
=
graph
.
_namespace
.
create_unique_name
(
new_name
)
self
.
_name
=
new_name
self
.
_name
=
graph
.
_namespace
.
create_unique_name
(
new_name
,
self
)
@
property
def
qualname
(
self
):
...
...
imperative/python/megengine/traced_module/traced_module.py
浏览文件 @
feea43bc
...
...
@@ -68,6 +68,7 @@ from .expr import (
is_call_tensor_method
,
is_constant
,
is_getattr
,
is_input
,
)
from
.fake_quant
import
FakeQuantize
as
TM_FakeQuant
from
.module_tracer
import
(
...
...
@@ -342,13 +343,19 @@ class NameSpace:
self
.
qualname
=
qualname
self
.
_used_names
=
{}
def
create_unique_name
(
self
,
name
:
str
)
->
str
:
def
create_unique_name
(
self
,
name
:
str
,
node
:
Any
=
None
)
->
str
:
assert
isinstance
(
name
,
str
),
"The name must be a string"
if
name
in
self
.
_used_names
and
self
.
_used_names
[
name
]
is
node
:
return
name
name
=
re
.
sub
(
"[^0-9a-zA-Z_]+"
,
"_"
,
name
)
if
name
[
0
].
isdigit
():
name
=
"_{}"
.
format
(
name
)
while
name
in
self
.
_used_names
or
_is_builtin_name
(
name
):
while
(
name
in
self
.
_used_names
and
self
.
_used_names
[
name
]
is
not
None
)
or
_is_builtin_name
(
name
):
match
=
re
.
match
(
r
"(.*)_(\d+)$"
,
name
)
if
match
is
None
:
name
=
name
+
"_1"
...
...
@@ -357,6 +364,10 @@ class NameSpace:
name
=
"{}_{}"
.
format
(
base
,
int
(
num
)
+
1
)
self
.
_used_names
.
setdefault
(
name
)
if
node
is
not
None
:
self
.
associate_name_with_obj
(
name
,
node
)
return
name
def
auto_naming_for_outputs
(
self
,
expr
:
Expr
):
...
...
@@ -384,7 +395,7 @@ class NameSpace:
qualname
=
"{}.{}"
.
format
(
expr
.
inputs
[
0
].
qualname
,
expr
.
name
)
name
=
get_suffix_name
(
self
.
qualname
,
qualname
)
_add_suffix
=
lambda
x
:
x
elif
is_constant
(
expr
):
elif
is_constant
(
expr
)
or
is_input
(
expr
)
:
name
=
(
expr
.
name
if
expr
.
name
else
"const_"
+
type
(
expr
.
value
).
__name__
.
lower
()
)
...
...
@@ -392,16 +403,25 @@ class NameSpace:
_add_suffix
=
lambda
x
:
x
for
node
in
expr
.
outputs
:
if
node
.
_name
==
""
or
node
.
_name
in
self
.
used_names
:
assert
_add_suffix
(
name
)
==
name
or
isinstance
(
node
,
TensorNode
)
node
.
_name
=
self
.
create_unique_name
(
_add_suffix
(
name
))
cur_name
=
node
.
_name
if
node
.
_name
else
_add_suffix
(
name
)
node
.
_name
=
self
.
create_unique_name
(
cur_name
,
node
)
if
node
.
_qualname
==
""
:
node
.
_qualname
=
qualname
assert
get_suffix_name
(
self
.
qualname
,
qualname
)
assert
get_suffix_name
(
self
.
qualname
,
qualname
)
is
not
None
def
merge
(
self
,
other
:
"NameSpace"
):
self
.
_used_names
.
update
(
other
.
used_names
)
def
associate_name_with_obj
(
self
,
name
:
str
,
node
:
Node
):
assert
name
in
self
.
used_names
assert
self
.
used_names
[
name
]
is
None
,
"The name(%s) is already in use"
%
(
name
)
self
.
_used_names
[
name
]
=
node
def
unassociate_name_with_obj
(
self
,
node
:
Node
):
assert
node
.
name
in
self
.
used_names
assert
self
.
used_names
[
node
.
name
]
is
node
self
.
_used_names
[
node
.
name
]
=
None
@
property
def
used_names
(
self
):
return
self
.
_used_names
...
...
@@ -487,7 +507,7 @@ class InternalGraph:
"The name(%s) is already in use. Please try a different one again."
%
(
new_name
)
)
new_name
=
self
.
_namespace
.
create_unique_name
(
new_name
)
new_name
=
self
.
_namespace
.
create_unique_name
(
new_name
,
self
)
self
.
_name
=
new_name
@
property
...
...
@@ -726,6 +746,7 @@ class InternalGraph:
node
=
Input
(
type
=
TensorNode
,
name
=
name
,
qualname
=
"%s.[%s]"
%
(
self
.
_qualname
,
name
)
).
outputs
[
0
]
self
.
_namespace
.
associate_name_with_obj
(
node
.
name
,
node
)
node
.
shape
=
val
.
shape
node
.
dtype
=
val
.
dtype
return
node
...
...
@@ -764,9 +785,11 @@ class InternalGraph:
assert
moudle
.
_is_top
,
"add_input_node only supports top graph"
def
create_node
(
name
=
None
):
name
=
self
.
_namespace
.
create_unique_name
(
name
)
node
=
Input
(
type
=
TensorNode
,
name
=
name
,
qualname
=
"%s.[%s]"
%
(
self
.
_qualname
,
name
)
).
outputs
[
0
]
self
.
_namespace
.
associate_name_with_obj
(
node
.
name
,
node
)
node
.
shape
=
shape
node
.
dtype
=
dtype
return
node
...
...
@@ -774,7 +797,7 @@ class InternalGraph:
org_argdef
=
list
(
moudle
.
argdef_graph_map
.
keys
())[
0
]
args
,
kwargs
=
org_argdef
.
unflatten
(
self
.
_inputs
)
formal_inp_node
=
create_node
(
self
.
_namespace
.
create_unique_name
(
name
)
)
formal_inp_node
=
create_node
(
name
)
inputs
,
tree_def
=
tree_flatten
(
((
*
args
,
formal_inp_node
),
kwargs
),
is_const_leaf
=
lambda
x
:
not
isinstance
(
x
,
(
TensorNode
,
ModuleNode
)),
...
...
@@ -1006,6 +1029,8 @@ class InternalGraph:
for
n
in
expr
.
inputs
:
n
.
users
.
remove
(
expr
)
self
.
_exprs
.
remove
(
expr
)
for
n
in
expr
.
outputs
:
self
.
_namespace
.
unassociate_name_with_obj
(
n
)
def
_reset_ids
(
self
):
for
total_expr_id
,
expr
in
enumerate
(
self
.
exprs
()):
...
...
@@ -1014,6 +1039,11 @@ class InternalGraph:
node
.
_id
=
total_node_id
self
.
_total_ids
=
(
total_node_id
+
1
,
total_expr_id
+
1
)
def
_re_associate_name
(
self
):
self
.
_namespace
.
used_names
.
clear
()
for
node
in
self
.
nodes
(
False
):
node
.
_name
=
self
.
_namespace
.
create_unique_name
(
node
.
name
,
node
)
def
interpret
(
self
,
*
inputs
):
node2value
=
{}
end_nodes_set
=
set
(
self
.
_end_point
)
...
...
@@ -1108,6 +1138,8 @@ class InternalGraph:
if
n
.
_qualname
:
qualname
=
"{}.{}"
.
format
(
qualname
,
n
.
_qualname
)
n
.
_qualname
=
qualname
self
.
_namespace
=
NameSpace
(
self
.
_name
,
self
.
_qualname
)
self
.
_re_associate_name
()
def
_get_meth_name
(
obj
,
func
):
...
...
@@ -1372,6 +1404,7 @@ class TracedModuleBuilder(NodeMixin):
continue
for
g
in
mod
.
argdef_graph_map
.
values
():
replace_qualname
(
g
)
g
.
_namespace
.
qualname
=
g
.
qualname
for
n
in
g
.
nodes
(
False
):
replace_qualname
(
n
)
else
:
...
...
@@ -1383,6 +1416,7 @@ class TracedModuleBuilder(NodeMixin):
name
=
parent_graph
.
_namespace
.
create_unique_name
(
module_qualname
),
qualname
=
module_qualname
,
)
parent_graph
.
_namespace
.
associate_name_with_obj
(
self
.
_body
.
name
,
self
.
_body
)
active_module_tracer
().
push_scope
(
self
.
_body
)
# rebind self to new input node
...
...
@@ -1552,6 +1586,7 @@ class _expr_iter:
def
__init__
(
self
,
graph
:
InternalGraph
,
recursive
:
bool
=
True
):
self
.
graph
=
graph
self
.
recursive
=
recursive
self
.
_visited_graph
=
set
()
def
__iter__
(
self
):
for
inp_node
in
self
.
graph
.
inputs
:
...
...
@@ -1559,8 +1594,13 @@ class _expr_iter:
for
expr
in
self
.
graph
.
_exprs
:
if
isinstance
(
expr
,
CallMethod
)
and
isinstance
(
expr
.
inputs
[
0
],
ModuleNode
):
yield
expr
if
self
.
recursive
and
expr
.
graph
is
not
None
:
if
(
self
.
recursive
and
expr
.
graph
is
not
None
and
id
(
expr
.
graph
)
not
in
self
.
_visited_graph
):
yield
from
expr
.
graph
.
exprs
(
self
.
recursive
)
self
.
_visited_graph
.
add
(
id
(
expr
.
graph
))
else
:
yield
expr
...
...
@@ -1570,12 +1610,11 @@ class _node_iter:
nodes
=
[]
node_ids
=
set
()
for
expr
in
graph
.
exprs
(
recursive
):
for
n
in
expr
.
inputs
+
expr
.
outputs
:
if
id
(
n
)
in
node_ids
:
continue
for
n
in
expr
.
outputs
:
assert
id
(
n
)
not
in
node_ids
nodes
.
append
(
n
)
node_ids
.
add
(
id
(
n
))
self
.
nodes
=
list
(
sorted
(
nodes
,
key
=
lambda
x
:
x
.
_id
))
self
.
nodes
=
nodes
def
__iter__
(
self
):
for
node
in
self
.
nodes
:
...
...
@@ -2076,10 +2115,12 @@ class TracedModule(Module):
if
parent_graph
is
not
None
:
for
node
in
expr
.
outputs
:
if
node
in
rename_blacklist
:
continue
name
=
"{}_{}"
.
format
(
prefix_name
,
node
.
_name
)
node
.
_name
=
parent_graph
.
_namespace
.
create_unique_name
(
name
)
name
=
node
.
_name
if
node
not
in
rename_blacklist
:
name
=
"{}_{}"
.
format
(
prefix_name
,
name
)
node
.
_name
=
parent_graph
.
_namespace
.
create_unique_name
(
name
,
node
)
exprs
.
append
(
expr
)
...
...
@@ -2092,6 +2133,7 @@ class TracedModule(Module):
new_module
.
graph
.
_exprs
=
_flatten_subgraph
(
None
,
new_module
.
graph
,
None
,
new_module
)
new_module
.
graph
.
_re_associate_name
()
new_module
.
graph
.
compile
()
new_module
.
graph
.
_reset_ids
()
return
new_module
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录