Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
e6c271ae
MegEngine
项目概览
MegEngine 天元
/
MegEngine
1 年多 前同步成功
通知
403
Star
4705
Fork
582
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
提交
e6c271ae
编写于
11月 24, 2021
作者:
M
Megvii Engine Team
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix(mge/traced_module): fix some bugs for graph surgery
GitOrigin-RevId: 6328a84cbc8554c847530ce990331145ca82e043
上级
2d54ad18
变更
2
隐藏空白更改
内联
并排
Showing
2 changed file
with
139 addition
and
77 deletion
+139
-77
imperative/python/megengine/traced_module/traced_module.py
imperative/python/megengine/traced_module/traced_module.py
+112
-75
imperative/python/test/unit/traced_module/test_modification.py
...ative/python/test/unit/traced_module/test_modification.py
+27
-2
未找到文件。
imperative/python/megengine/traced_module/traced_module.py
浏览文件 @
e6c271ae
...
...
@@ -122,18 +122,18 @@ def _is_leaf(node):
return
isinstance
(
node
,
RawTensor
)
_enable_
node_to_tensor
=
False
_enable_
graph_surgery_mode
=
False
def
_
convert_node_flag
():
return
_enable_
node_to_tensor
def
_
graph_surgery_mode
():
return
_enable_
graph_surgery_mode
def
_set_
convert_node_flag
(
flag
:
bool
=
False
):
global
_enable_
node_to_tensor
pre_
flag
=
_enable_node_to_tensor
_enable_
node_to_tensor
=
flag
return
pre_
flag
def
_set_
graph_surgery_mode
(
mode
:
bool
):
global
_enable_
graph_surgery_mode
pre_
mode
=
_enable_graph_surgery_mode
_enable_
graph_surgery_mode
=
mode
return
pre_
mode
def
_node_to_tensor
(
*
args
,
**
kwargs
):
...
...
@@ -145,11 +145,11 @@ def _node_to_tensor(*args, **kwargs):
active_module_tracer
().
current_scope
().
_add_input
(
n
)
value
=
n
.
value
if
value
is
None
:
flag
=
_set_
convert_node_flag
(
False
)
flag
=
_set_
graph_surgery_mode
(
False
)
unset_module_tracing
()
value
=
F
.
zeros
(
shape
=
n
.
_shape
,
dtype
=
n
.
_dtype
)
set_module_tracing
()
_set_
convert_node_flag
(
flag
)
_set_
graph_surgery_mode
(
flag
)
orig_n
=
NodeMixin
.
get
(
value
,
None
)
if
orig_n
is
None
or
"setitem"
not
in
orig_n
.
_name
:
NodeMixin
.
wrap_safe
(
value
,
n
)
...
...
@@ -180,17 +180,25 @@ def _tensor_to_node(tensors):
def
_wrap_method_to_tensor_node
():
def
_any_method
(
name
):
def
_any_method
(
name
,
func
):
def
_any
(
*
args
,
**
kwargs
):
args
,
kwargs
=
_node_to_tensor
(
*
args
,
**
kwargs
)
attr
=
getattr
(
args
[
0
],
name
)
outs
=
attr
if
callable
(
attr
):
outs
=
attr
(
*
(
args
[
1
:]),
**
kwargs
)
if
name
==
"__setitem__"
:
_node_to_tensor
(
outs
)
return
None
outs
=
_tensor_to_node
(
outs
)
if
is_tracing_module
()
and
_graph_surgery_mode
():
args
,
kwargs
=
_node_to_tensor
(
*
args
,
**
kwargs
)
attr
=
getattr
(
args
[
0
],
name
)
outs
=
attr
if
callable
(
attr
):
outs
=
attr
(
*
(
args
[
1
:]),
**
kwargs
)
if
name
==
"__setitem__"
:
_node_to_tensor
(
outs
)
return
None
outs
=
_tensor_to_node
(
outs
)
return
outs
else
:
outs
=
func
if
callable
(
func
):
outs
=
func
(
*
args
,
**
kwargs
)
if
isinstance
(
func
,
property
):
outs
=
func
.
__get__
(
*
args
,
**
kwargs
)
return
outs
return
_any
...
...
@@ -199,9 +207,9 @@ def _wrap_method_to_tensor_node():
for
method
in
get_tensor_wrapable_method
():
patch
=
PatchedFn
(
TensorNode
,
method
)
if
type
(
getattr
(
Tensor
,
method
))
==
property
:
patch
.
set_func
(
property
(
_any_method
(
method
)))
patch
.
set_func
(
property
(
_any_method
(
method
,
patch
.
origin_fn
)))
else
:
patch
.
set_func
(
_any_method
(
method
))
patch
.
set_func
(
_any_method
(
method
,
patch
.
origin_fn
))
tensor_method_patch
.
append
(
patch
)
return
tensor_method_patch
...
...
@@ -209,7 +217,7 @@ def _wrap_method_to_tensor_node():
def
_convert_node_and_tensor
(
orig_func
):
@
functools
.
wraps
(
orig_func
)
def
_convert
(
*
args
,
**
kwargs
):
if
_convert_node_flag
()
and
is_tracing_modul
e
():
if
is_tracing_module
()
and
_graph_surgery_mod
e
():
args
,
kwargs
=
_node_to_tensor
(
*
args
,
**
kwargs
)
rst
=
orig_func
(
*
args
,
**
kwargs
,
method_func
=
_convert
)
rst
=
_tensor_to_node
(
rst
)
...
...
@@ -224,31 +232,35 @@ def _convert_node_and_tensor(orig_func):
def
_wrap_mnode_getattr
(
orig_getattr
):
@
functools
.
wraps
(
orig_getattr
)
def
wraped_fn
(
self
,
name
):
obj
=
self
.
owner
current_graph
=
active_module_tracer
().
current_scope
()
if
self
.
top_graph
is
not
None
:
current_graph
.
_add_input
(
self
)
attr
=
getattr
(
obj
,
name
)
node
=
attr
if
not
isinstance
(
attr
,
TracedModuleBuilder
):
if
isinstance
(
attr
,
Module
):
attr
=
TracedModuleBuilder
(
attr
)
setattr
(
obj
,
name
,
attr
)
if
is_tracing_module
()
and
_graph_surgery_mode
():
obj
=
self
.
owner
current_graph
=
active_module_tracer
().
current_scope
()
if
self
.
top_graph
is
not
None
:
current_graph
.
_add_input
(
self
)
attr
=
getattr
(
obj
,
name
)
node
=
attr
if
not
isinstance
(
attr
,
TracedModuleBuilder
):
if
isinstance
(
attr
,
Module
):
attr
=
TracedModuleBuilder
(
attr
)
setattr
(
obj
,
name
,
attr
)
if
isinstance
(
attr
,
(
NodeMixin
,
RawTensor
)):
NodeMixin
.
wrap
(
attr
,
lambda
:
GetAttr
.
make
(
self
,
type
=
NodeMixin
.
get_wrapped_type
(
attr
),
attr_name
=
name
,
name
=
""
,
),
)
if
isinstance
(
attr
,
(
NodeMixin
,
RawTensor
)):
NodeMixin
.
wrap
(
attr
,
lambda
:
GetAttr
.
make
(
self
,
type
=
NodeMixin
.
get_wrapped_type
(
attr
),
attr_name
=
name
,
name
=
""
,
),
)
if
isinstance
(
attr
,
(
NodeMixin
,
RawTensor
)):
node
=
NodeMixin
.
get
(
attr
)
if
isinstance
(
node
,
ModuleNode
):
node
.
_owner
=
weakref
.
ref
(
attr
)
node
=
NodeMixin
.
get
(
attr
)
if
isinstance
(
node
,
ModuleNode
)
and
isinstance
(
attr
,
(
NodeMixin
,
Module
)):
node
.
_owner
=
weakref
.
ref
(
attr
)
return
node
else
:
node
=
object
.
__getattribute__
(
self
,
name
)
return
node
return
wraped_fn
...
...
@@ -257,10 +269,13 @@ def _wrap_mnode_getattr(orig_getattr):
def
_wrap_mnode_call
(
orig_call
):
@
functools
.
wraps
(
orig_call
)
def
wraped_fn
(
self
,
*
args
,
**
kwargs
):
obj
=
self
.
owner
if
self
.
top_graph
is
not
None
:
active_module_tracer
().
current_scope
().
_add_input
(
self
)
rst
=
obj
(
*
args
,
**
kwargs
)
if
is_tracing_module
()
and
_graph_surgery_mode
():
obj
=
self
.
owner
if
self
.
top_graph
is
not
None
:
active_module_tracer
().
current_scope
().
_add_input
(
self
)
rst
=
obj
(
*
args
,
**
kwargs
)
else
:
raise
TypeError
(
"'ModuleNode' object is not callable"
)
return
rst
return
wraped_fn
...
...
@@ -284,7 +299,7 @@ class _InsertExprs:
Node
.
_set_next_id
(
node_id
)
Expr
.
_set_next_id
(
expr_id
)
set_module_tracing
()
_set_
convert_node_flag
(
True
)
_set_
graph_surgery_mode
(
True
)
assert
active_module_tracer
()
is
None
set_active_module_tracer
(
module_tracer
(
lambda
x
:
_convert_node_and_tensor
(
_wrapped_function
(
x
)))
...
...
@@ -303,20 +318,30 @@ class _InsertExprs:
if
va
is
not
None
:
return
False
active_module_tracer
().
patcher
.
__exit__
(
ty
,
va
,
tr
)
_set_convert_node_flag
(
False
)
while
self
.
_tensor_method_patch
:
pf
=
self
.
_tensor_method_patch
.
pop
()
pf
.
set_func
(
pf
.
origin_fn
)
# delete ModuleNode.__call__ to avoid entering the
# ModuleNode.__init__ method when call a ModuleNode object.
delattr
(
ModuleNode
,
"__call__"
)
module
=
self
.
graph
.
inputs
[
0
].
owner
for
k
,
v
in
module
.
__dict__
.
items
():
if
isinstance
(
v
,
TracedModuleBuilder
):
v
=
v
.
build
()
setattr
(
module
,
k
,
v
)
def
build_traced_module
(
module
:
TracedModuleBuilder
,
target_module
:
TracedModule
):
for
k
,
v
in
module
.
__dict__
.
items
():
if
isinstance
(
v
,
TracedModuleBuilder
):
traced_v
=
v
.
build
()
build_traced_module
(
v
,
traced_v
)
setattr
(
target_module
,
k
,
traced_v
)
build_traced_module
(
module
,
module
)
set_symbolic_shape
(
self
.
use_sym_shape
)
_set_graph_surgery_mode
(
False
)
set_active_module_tracer
(
None
)
unset_module_tracing
()
...
...
@@ -435,7 +460,7 @@ class NameSpace:
def
unassociate_name_with_obj
(
self
,
node
:
Node
):
assert
node
.
name
in
self
.
used_names
assert
self
.
used_names
[
node
.
name
]
is
node
#
assert self.used_names[node.name] is node
self
.
_used_names
[
node
.
name
]
=
None
@
property
...
...
@@ -1364,6 +1389,8 @@ class TracedModuleBuilder(NodeMixin):
for
node
in
self
.
nodes
:
node
.
module_type
=
mod_type
return
self
.
_mod
elif
isinstance
(
self
.
_mod
,
TracedModule
)
and
_graph_surgery_mode
():
return
self
.
_mod
else
:
is_qat
=
isinstance
(
self
.
_mod
,
QATModule
)
or
(
...
...
@@ -1409,6 +1436,10 @@ class TracedModuleBuilder(NodeMixin):
def
__call__
(
self
,
*
args
,
**
kwargs
):
assert
isinstance
(
self
.
_mod
,
Module
)
is_graph_surgery_mode
=
_graph_surgery_mode
()
if
isinstance
(
self
.
_mod
,
TracedModule
)
and
is_graph_surgery_mode
:
_set_graph_surgery_mode
(
False
)
# prepare args and kwargs for inner graph
if
"method_func"
in
kwargs
:
kwargs
.
pop
(
"method_func"
)
...
...
@@ -1514,7 +1545,7 @@ class TracedModuleBuilder(NodeMixin):
)
rst
=
type
(
self
.
_mod
).
forward
(
*
args
,
**
kwargs
)
if
_
convert_node_flag
():
if
_
graph_surgery_mode
():
rst
=
_node_to_tensor
(
rst
)[
0
][
0
]
outputs
,
out_def
=
tree_flatten
(
rst
,
is_leaf
=
_is_leaf
)
...
...
@@ -1536,6 +1567,7 @@ class TracedModuleBuilder(NodeMixin):
callnode
.
add_outputs
(
outputs
)
self
.
_argdef_graph_map
[
callnode
.
arg_def
]
=
self
.
_body
self
.
_argdef_outdef_map
[
callnode
.
arg_def
]
=
out_def
_set_graph_surgery_mode
(
is_graph_surgery_mode
)
return
rst
def
__setattr__
(
self
,
name
,
value
):
...
...
@@ -1556,7 +1588,7 @@ class TracedModuleBuilder(NodeMixin):
return
active_module_tracer
().
patcher
.
wrap_fn
(
attr
)
if
isinstance
(
attr
,
(
List
,
Dict
)):
flag
=
_set_
convert_node_flag
(
False
)
flag
=
_set_
graph_surgery_mode
(
False
)
unset_module_tracing
()
has_module
,
m_container
=
replace_container_with_module_container
(
attr
)
if
m_container
:
...
...
@@ -1567,7 +1599,7 @@ class TracedModuleBuilder(NodeMixin):
" Module and Non-Module objects."
)
set_module_tracing
()
_set_
convert_node_flag
(
flag
)
_set_
graph_surgery_mode
(
flag
)
if
isinstance
(
attr
,
Module
):
attr
=
TracedModuleBuilder
(
attr
)
...
...
@@ -1628,20 +1660,25 @@ class _expr_iter:
self
.
_visited_graph
=
set
()
def
__iter__
(
self
):
for
inp_node
in
self
.
graph
.
inputs
:
yield
inp_node
.
expr
for
expr
in
self
.
graph
.
_exprs
:
if
isinstance
(
expr
,
CallMethod
)
and
isinstance
(
expr
.
inputs
[
0
],
ModuleNode
):
yield
expr
if
(
self
.
recursive
and
expr
.
graph
is
not
None
and
id
(
expr
.
graph
)
not
in
self
.
_visited_graph
):
yield
from
expr
.
graph
.
exprs
(
self
.
recursive
)
self
.
_visited_graph
.
add
(
id
(
expr
.
graph
))
else
:
yield
expr
yield
from
self
.
_gen_expr
(
self
.
graph
)
def
_gen_expr
(
self
,
graph
:
InternalGraph
):
visit_inp
=
set
()
for
inp_node
in
graph
.
inputs
:
if
inp_node
not
in
visit_inp
:
yield
inp_node
.
expr
visit_inp
.
add
(
inp_node
)
for
expr
in
graph
.
_exprs
:
yield
expr
if
(
self
.
recursive
and
hasattr
(
expr
,
"graph"
)
and
expr
.
graph
is
not
None
and
id
(
expr
.
graph
)
not
in
self
.
_visited_graph
):
self
.
_visited_graph
.
add
(
id
(
expr
.
graph
))
yield
from
self
.
_gen_expr
(
expr
.
graph
)
class
_node_iter
:
...
...
imperative/python/test/unit/traced_module/test_modification.py
浏览文件 @
e6c271ae
...
...
@@ -15,7 +15,7 @@ import megengine.functional as F
import
megengine.module
as
M
import
megengine.module.qat
as
qat
from
megengine.module.identity
import
Identity
from
megengine.traced_module
import
trace_module
from
megengine.traced_module
import
TracedModule
,
trace_module
from
megengine.traced_module.expr
import
CallFunction
,
CallMethod
,
Expr
,
GetAttr
,
Input
from
megengine.traced_module.node
import
ModuleNode
,
Node
,
TensorNode
...
...
@@ -182,7 +182,6 @@ def test_insert_module():
setattr
(
traced_module
,
"neg"
,
Neg
(
name
=
"neg"
))
setattr
(
traced_module
,
"neg2"
,
Neg
(
name
=
"neg"
))
setattr
(
traced_module
,
"param"
,
F
.
zeros
((
1
,)))
with
graph
.
insert_exprs
():
neg_out
=
self
.
neg
(
relu_out
)
neg_out
=
self
.
neg2
(
relu_out
)
...
...
@@ -199,6 +198,32 @@ def test_insert_module():
if
isinstance
(
n
,
TensorNode
):
assert
n
.
value
is
None
traced_module
,
x
,
expect
=
_init_module
()
setattr
(
traced_module
.
block0
,
"neg"
,
Neg
(
name
=
None
))
graph
=
traced_module
.
graph
self
=
graph
.
inputs
[
0
]
out_node
=
graph
.
outputs
[
0
]
with
graph
.
insert_exprs
():
neg_out
=
self
.
block0
.
neg
(
out_node
)
graph
.
replace_node
({
out_node
:
neg_out
})
graph
.
compile
()
np
.
testing
.
assert_allclose
(
expect
,
-
traced_module
(
x
),
atol
=
1e-6
)
assert
isinstance
(
traced_module
.
block0
.
neg
,
TracedModule
)
assert
traced_module
.
block0
.
neg
.
graph
is
not
None
setattr
(
traced_module
.
block0
.
neg
,
"neg"
,
Neg
(
name
=
None
))
setattr
(
traced_module
.
block0
.
neg
.
neg
,
"relu"
,
M
.
ReLU
())
out_node
=
graph
.
outputs
[
0
]
with
graph
.
insert_exprs
():
neg_out
=
self
.
block0
.
neg
.
neg
(
out_node
)
neg_out
=
self
.
block0
.
neg
.
neg
(
neg_out
)
relu_out
=
self
.
block0
.
neg
.
neg
.
relu
(
neg_out
)
graph
.
replace_node
({
out_node
:
relu_out
})
graph
.
compile
()
np
.
testing
.
assert_allclose
(
F
.
relu
(
-
expect
),
traced_module
(
x
),
atol
=
1e-6
)
assert
isinstance
(
traced_module
.
block0
.
neg
.
neg
,
TracedModule
)
assert
traced_module
.
block0
.
neg
.
neg
.
graph
is
not
None
def
test_insert_qat_module
():
class
concat
(
qat
.
Concat
):
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录