Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
a3f9073c
MegEngine
项目概览
MegEngine 天元
/
MegEngine
大约 1 年 前同步成功
通知
399
Star
4705
Fork
582
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
提交
a3f9073c
编写于
8月 10, 2021
作者:
M
Megvii Engine Team
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
feat(traced_module): update graph transform and add _module_name
GitOrigin-RevId: ef63ae0fd0dcdd69c3566e19f8a34d85422a1e1e
上级
b3d0affa
变更
7
展开全部
隐藏空白更改
内联
并排
Showing
7 changed file
with
582 addition
and
200 deletion
+582
-200
imperative/python/megengine/experimental/traced_module/__init__.py
...e/python/megengine/experimental/traced_module/__init__.py
+0
-1
imperative/python/megengine/experimental/traced_module/expr.py
...ative/python/megengine/experimental/traced_module/expr.py
+66
-40
imperative/python/megengine/experimental/traced_module/module_tracer.py
...hon/megengine/experimental/traced_module/module_tracer.py
+39
-3
imperative/python/megengine/experimental/traced_module/node.py
...ative/python/megengine/experimental/traced_module/node.py
+64
-17
imperative/python/megengine/experimental/traced_module/pytree.py
...ive/python/megengine/experimental/traced_module/pytree.py
+1
-1
imperative/python/megengine/experimental/traced_module/traced_module.py
...hon/megengine/experimental/traced_module/traced_module.py
+408
-135
imperative/python/test/unit/traced_module/test_modification.py
...ative/python/test/unit/traced_module/test_modification.py
+4
-3
未找到文件。
imperative/python/megengine/experimental/traced_module/__init__.py
浏览文件 @
a3f9073c
...
...
@@ -14,7 +14,6 @@ from .traced_module import (
register_as_builtin
,
trace_module
,
wrap
,
wrap_tensors
,
)
_register_all_builtin_module
()
...
...
imperative/python/megengine/experimental/traced_module/expr.py
浏览文件 @
a3f9073c
...
...
@@ -33,17 +33,6 @@ def rstrip(s: str, __chars: str):
return
s
def
lstrip
(
s
:
str
,
__chars
:
str
):
__chars
=
re
.
escape
(
__chars
)
s
=
re
.
sub
(
r
"^(?:%s)+(?P<right>.*)$"
%
__chars
,
"\g<right>"
,
s
)
return
s
def
strip
(
s
:
str
,
__chars
:
str
):
s
=
lstrip
(
rstrip
(
s
,
__chars
),
__chars
)
return
s
class
Expr
:
"""
``Expr`` represents the operations(i.e. CallMethod, CallFunction, Apply, GetAttr, Input, Constant) on ``Node``.
...
...
@@ -89,27 +78,40 @@ class Expr:
outputs
=
(
outputs
,)
name
=
None
orig_name
=
None
if
isinstance
(
self
,
CallMethod
):
name
=
self
.
inputs
[
0
].
_name
assert
name
is
not
None
orig_name
=
self
.
inputs
[
0
].
_orig_name
assert
isinstance
(
name
,
str
),
"The name of ({}) must be a str"
.
format
(
self
.
inputs
[
0
]
)
assert
isinstance
(
orig_name
,
str
),
"The orig_name of ({}) must be a str"
.
format
(
self
.
inputs
[
0
])
name
=
rstrip
(
name
,
"_out"
)
if
self
.
method
==
"__call__"
:
name
+=
"_out"
orig_name
+=
"_out"
else
:
strip_method
=
s
trip
(
self
.
method
,
"_"
)
strip_method
=
s
elf
.
method
.
strip
(
"_"
)
name
=
"%s_out"
%
strip_method
orig_name
=
name
elif
isinstance
(
self
,
CallFunction
):
name
=
self
.
func
.
__name__
+
"_out"
elif
isinstance
(
self
,
Apply
):
name
=
str
(
self
.
opdef
).
lower
()
+
"_out"
for
i
in
outputs
:
assert
isinstance
(
i
,
RawTensor
)
assert
isinstance
(
i
,
RawTensor
)
,
"The output must be a Tensor"
o_name
=
(
active_module_tracer
().
current_scope
().
_create_unique_name
(
name
)
)
self
.
outputs
.
append
(
NodeMixin
.
get_wrapped_type
(
i
)(
expr
=
self
,
name
=
o_name
)
NodeMixin
.
get_wrapped_type
(
i
)(
expr
=
self
,
name
=
o_name
,
orig_name
=
orig_name
if
orig_name
else
o_name
,
)
)
for
i
,
node
in
zip
(
outputs
,
self
.
outputs
,):
...
...
@@ -125,21 +127,26 @@ class Expr:
else
:
return
inputs
,
{}
def
_replace_nodes
(
self
,
repl_dict
:
Dict
[
Node
,
Node
],
nodes
:
List
[
Node
]):
def
replace_inputs
(
self
,
repl_dict
:
Dict
[
Node
,
Node
]):
while
repl_dict
:
node
,
repl_node
=
repl_dict
.
popitem
()
assert
type
(
node
)
==
type
(
repl_node
)
assert
node
in
nodes
index
=
nodes
.
index
(
node
)
nodes
[
index
]
=
repl_node
assert
node
in
self
.
inputs
,
"({}) is not in the ({})"
.
format
(
node
,
self
)
assert
(
repl_node
.
top_graph
==
node
.
top_graph
),
"({}) and ({}) are not in the same graph"
.
format
(
node
,
repl_node
)
graph
=
self
.
top_graph
repl_expr_idx
=
graph
.
_exprs
.
index
(
repl_node
.
expr
)
self_idx
=
graph
.
_exprs
.
index
(
self
)
assert
(
repl_expr_idx
<
self_idx
),
"({}) must be generated before ({})"
.
format
(
repl_node
,
self
)
idx
=
self
.
inputs
.
index
(
node
)
self
.
inputs
[
idx
]
=
repl_node
user_idx
=
node
.
users
.
index
(
self
)
assert
user_idx
>=
0
node
.
users
.
pop
(
user_idx
)
repl_node
.
users
.
append
(
self
)
node
.
users
.
pop
(
self
)
def
replace_inputs
(
self
,
repl_dict
:
Dict
[
Node
,
Node
]):
self
.
_replace_nodes
(
repl_dict
,
self
.
inputs
)
def
replace_outputs
(
self
,
repl_dict
:
Dict
[
Node
,
Node
]):
self
.
_replace_nodes
(
repl_dict
,
self
.
outputs
)
@
property
def
kwargs
(
self
):
...
...
@@ -159,7 +166,8 @@ class Expr:
def
__getstate__
(
self
):
state
=
self
.
__dict__
.
copy
()
state
.
pop
(
"_top_graph"
,
None
)
if
"_top_graph"
in
state
:
state
.
pop
(
"_top_graph"
)
return
state
...
...
@@ -167,12 +175,14 @@ class Expr:
class
Input
(
Expr
):
name
=
None
def
__init__
(
self
,
name
=
None
,
type
=
None
):
def
__init__
(
self
,
name
=
None
,
type
=
None
,
orig_name
=
None
):
super
().
__init__
()
self
.
inputs
=
[]
node_cls
=
type
if
type
else
Node
if
orig_name
is
None
:
orig_name
=
name
self
.
outputs
=
[
node_cls
(
self
,
name
=
name
),
node_cls
(
self
,
name
=
name
,
orig_name
=
orig_name
),
]
self
.
name
=
name
...
...
@@ -184,7 +194,7 @@ class Input(Expr):
active_module_tracer
().
current_scope
().
_create_unique_name
(
oup_node
.
_name
)
)
oup_node
.
_name
=
name
active_module_tracer
().
current_scope
().
add_input
(
oup_node
)
active_module_tracer
().
current_scope
().
_
add_input
(
oup_node
)
return
expr
.
outputs
[
0
]
def
__repr__
(
self
):
...
...
@@ -195,7 +205,7 @@ class Input(Expr):
class
GetAttr
(
Expr
):
name
=
None
def
__init__
(
self
,
module
,
name
,
type
=
None
):
def
__init__
(
self
,
module
,
name
,
type
=
None
,
orig_name
=
None
):
super
().
__init__
()
assert
isinstance
(
module
,
ModuleNode
)
self
.
inputs
=
[
...
...
@@ -205,7 +215,7 @@ class GetAttr(Expr):
self
.
name
=
name
node_cls
=
type
if
type
else
Node
self
.
outputs
=
[
node_cls
(
self
,
name
=
name
),
node_cls
(
self
,
name
=
name
,
orig_name
=
orig_name
),
]
@
classmethod
...
...
@@ -218,7 +228,7 @@ class GetAttr(Expr):
module
=
module
.
expr
.
inputs
[
0
]
oup_name
=
active_module_tracer
().
current_scope
().
_create_unique_name
(
oup_name
)
expr
.
outputs
[
0
].
_name
=
oup_name
active_module_tracer
().
current_scope
().
insert
(
expr
)
active_module_tracer
().
current_scope
().
_
insert
(
expr
)
return
expr
.
outputs
[
0
]
def
interpret
(
self
,
*
inputs
):
...
...
@@ -255,7 +265,7 @@ class CallMethod(Expr):
@
classmethod
def
make
(
cls
,
*
args
,
**
kwargs
):
expr
=
cls
(
*
args
,
**
kwargs
)
active_module_tracer
().
current_scope
().
insert
(
expr
)
active_module_tracer
().
current_scope
().
_
insert
(
expr
)
return
expr
@
property
...
...
@@ -315,7 +325,7 @@ class Apply(Expr):
@
classmethod
def
make
(
cls
,
*
args
,
**
kwargs
):
expr
=
cls
(
*
args
,
**
kwargs
)
active_module_tracer
().
current_scope
().
insert
(
expr
)
active_module_tracer
().
current_scope
().
_
insert
(
expr
)
return
expr
def
interpret
(
self
,
*
inputs
):
...
...
@@ -382,7 +392,7 @@ class CallFunction(Expr):
@
classmethod
def
make
(
cls
,
*
args
,
**
kwargs
):
expr
=
cls
(
*
args
,
**
kwargs
)
active_module_tracer
().
current_scope
().
insert
(
expr
)
active_module_tracer
().
current_scope
().
_
insert
(
expr
)
return
expr
def
interpret
(
self
,
*
inputs
):
...
...
@@ -423,7 +433,7 @@ class Constant(Expr):
self
.
inputs
=
[]
node_cls
=
NodeMixin
.
get_wrapped_type
(
c
)
self
.
outputs
=
[
node_cls
(
self
,
name
=
name
),
node_cls
(
self
,
name
=
name
,
orig_name
=
name
),
]
self
.
outputs
[
0
].
_name
=
name
if
name
else
"const_"
+
str
(
self
.
_id
)
...
...
@@ -431,9 +441,23 @@ class Constant(Expr):
def
make
(
cls
,
*
args
,
**
kwargs
):
expr
=
cls
(
*
args
,
**
kwargs
)
name
=
"const_module"
if
isinstance
(
expr
.
value
,
Module
)
else
"const_tensor"
name
=
active_module_tracer
().
current_scope
().
_create_unique_name
(
name
)
full_name
=
name
if
(
isinstance
(
expr
.
value
,
RawTensor
)
and
id
(
expr
.
value
)
in
active_module_tracer
().
id2name
):
full_name
=
active_module_tracer
().
id2name
[
id
(
expr
.
value
)]
scope_name
=
active_module_tracer
().
current_scope
().
_module_name
if
full_name
and
scope_name
:
full_name
=
(
"self."
+
full_name
)[
len
(
scope_name
)
+
1
:]
else
:
full_name
=
name
else
:
full_name
=
name
name
=
active_module_tracer
().
current_scope
().
_create_unique_name
(
full_name
)
expr
.
outputs
[
0
].
_name
=
name
active_module_tracer
().
current_scope
().
insert
(
expr
)
expr
.
outputs
[
0
].
_orig_name
=
full_name
active_module_tracer
().
current_scope
().
_insert
(
expr
)
return
expr
.
outputs
[
0
]
def
interpret
(
self
,
*
inputs
):
...
...
@@ -453,7 +477,9 @@ class Constant(Expr):
)
def
__getstate__
(
self
):
state
=
super
().
__getstate__
()
state
=
self
.
__dict__
.
copy
()
if
"_top_graph"
in
state
:
state
.
pop
(
"_top_graph"
)
if
isinstance
(
self
.
value
,
RawTensor
):
state
[
"value"
]
=
Tensor
(
self
.
value
)
return
state
imperative/python/megengine/experimental/traced_module/module_tracer.py
浏览文件 @
a3f9073c
...
...
@@ -84,6 +84,34 @@ BUILTIN_ARRAY_METHOD = [
"__setitem__"
,
]
BUILTIN_TENSOR_WRAP_METHOD
=
[
"T"
,
"to"
,
"size"
,
"shape"
,
"detach"
,
"device"
,
"dtype"
,
"grad"
,
"item"
,
"name"
,
"ndim"
,
"numpy"
,
"qparams"
,
"set_value"
,
"reset_zero"
,
"requires_grad"
,
"_reset"
,
"_isscalar"
,
"_setscalar"
,
"_tuple_shape"
,
"_unsetscalar"
,
]
def
get_tensor_wrapable_method
():
return
BUILTIN_TENSOR_WRAP_METHOD
+
BUILTIN_ARRAY_METHOD
def
active_module_tracer
():
return
_active_module_tracer
...
...
@@ -101,9 +129,10 @@ class module_tracer:
_active_scopes
=
None
def
__init__
(
self
,
wrap_fn
):
def
__init__
(
self
,
wrap_fn
,
id2name
):
self
.
_active_scopes
=
[]
self
.
patcher
=
Patcher
(
wrap_fn
)
self
.
id2name
=
id2name
@
classmethod
def
register_as_builtin
(
cls
,
mod
):
...
...
@@ -127,6 +156,10 @@ class module_tracer:
return
None
class
NotExist
:
pass
class
PatchedFn
:
frame_dict
=
None
name
=
None
...
...
@@ -138,14 +171,17 @@ class PatchedFn:
self
.
origin_fn
=
(
self
.
frame_dict
[
name
]
if
isinstance
(
frame_dict
,
collections
.
abc
.
Mapping
)
else
getattr
(
frame_dict
,
name
)
else
getattr
(
frame_dict
,
name
,
NotExist
)
)
def
set_func
(
self
,
func
):
if
isinstance
(
self
.
frame_dict
,
collections
.
abc
.
Mapping
):
self
.
frame_dict
[
self
.
name
]
=
func
else
:
setattr
(
self
.
frame_dict
,
self
.
name
,
func
)
if
func
is
not
NotExist
:
setattr
(
self
.
frame_dict
,
self
.
name
,
func
)
else
:
delattr
(
self
.
frame_dict
,
self
.
name
)
class
Patcher
:
...
...
imperative/python/megengine/experimental/traced_module/node.py
浏览文件 @
a3f9073c
...
...
@@ -30,14 +30,17 @@ class Node:
_id
=
None
_top_graph
=
None
# type: weakref.ReferenceType
_name
=
None
_orig_name
=
None
_format_spec
=
""
def
__init__
(
self
,
expr
:
"Expr"
,
name
:
str
=
None
):
def
__init__
(
self
,
expr
:
"Expr"
,
name
:
str
=
None
,
orig_name
:
str
=
None
):
self
.
expr
=
expr
self
.
users
=
[]
# List[Expr]
self
.
_id
=
Node
.
__total_id
Node
.
__total_id
+=
1
self
.
_name
=
name
self
.
_orig_name
=
orig_name
self
.
actual_node
=
[]
# type: List[Node]
def
__setstate__
(
self
,
d
):
self
.
__dict__
=
d
...
...
@@ -48,7 +51,7 @@ class Node:
return
self
.
__format__
(
format_spec
)
def
__format__
(
self
,
format_spec
:
str
)
->
str
:
if
format_spec
==
""
or
format_spec
is
None
:
if
not
format_spec
:
format_spec
=
Node
.
_format_spec
name
=
self
.
_name
if
name
is
None
:
...
...
@@ -100,9 +103,8 @@ class ModuleNode(Node):
module_type
=
Module
# type: Type[Module]
_owner
=
None
# type: weakref.ReferenceType
def
__init__
(
self
,
expr
:
"Expr"
,
name
:
str
=
None
):
super
().
__init__
(
expr
,
name
)
self
.
actual_mnode
=
[]
def
__init__
(
self
,
expr
:
"Expr"
,
name
:
str
=
None
,
orig_name
:
str
=
None
):
super
().
__init__
(
expr
,
name
,
orig_name
)
def
__getstate__
(
self
):
return
{
...
...
@@ -110,6 +112,7 @@ class ModuleNode(Node):
"users"
:
self
.
users
,
"_id"
:
self
.
_id
,
"_name"
:
self
.
_name
,
"_orig_name"
:
self
.
_orig_name
,
"module_type"
:
self
.
module_type
,
}
...
...
@@ -125,23 +128,67 @@ class TensorNode(Node):
``TensorNode`` represents the Tensor objects.
"""
shape
=
None
# type: Tuple[int]
dtype
=
None
# type: numpy.dtype
qparams
=
None
device
=
None
_shape
=
None
# type: Tuple[int]
_dtype
=
None
# type: numpy.dtype
_qparams
=
None
_device
=
None
_value
=
None
# type: Tensor
def
__getstate__
(
self
):
return
{
"expr"
:
self
.
expr
,
"users"
:
self
.
users
,
"_id"
:
self
.
_id
,
"
qparams"
:
self
.
qparams
,
"
shape"
:
self
.
shape
,
"
dtype"
:
self
.
dtype
,
"
device"
:
self
.
device
,
"
_qparams"
:
self
.
_
qparams
,
"
_shape"
:
self
.
_
shape
,
"
_dtype"
:
self
.
_
dtype
,
"
_device"
:
self
.
_
device
,
"_name"
:
self
.
_name
,
"_orig_name"
:
self
.
_orig_name
,
}
@
property
def
shape
(
self
):
return
self
.
_shape
@
shape
.
setter
def
shape
(
self
,
shape
):
self
.
_shape
=
shape
@
property
def
dtype
(
self
):
return
self
.
_dtype
@
dtype
.
setter
def
dtype
(
self
,
dtype
):
self
.
_dtype
=
dtype
@
property
def
device
(
self
):
return
self
.
_device
@
device
.
setter
def
device
(
self
,
device
):
self
.
_device
=
device
@
property
def
qparams
(
self
):
return
self
.
_qparams
@
qparams
.
setter
def
qparams
(
self
,
qparams
):
self
.
_qparams
=
qparams
@
property
def
value
(
self
):
return
self
.
_value
@
value
.
setter
def
value
(
self
,
value
):
if
isinstance
(
value
,
RawTensor
)
and
NodeMixin
.
get
(
value
,
None
)
is
not
None
:
setattr
(
value
,
"_NodeMixin__node"
,
None
)
self
.
_value
=
value
class
NodeMixin
(
abc
.
ABC
):
__node
=
None
...
...
@@ -156,13 +203,13 @@ class NodeMixin(abc.ABC):
assert
isinstance
(
node
,
TensorNode
)
assert
isinstance
(
value
,
RawTensor
)
if
isinstance
(
value
,
RawTensor
):
node
.
dtype
=
value
.
dtype
node
.
shape
=
(
node
.
_
dtype
=
value
.
dtype
node
.
_
shape
=
(
value
.
_tuple_shape
if
isinstance
(
value
,
Tensor
)
else
value
.
shape
)
node
.
device
=
value
.
device
node
.
_
device
=
value
.
device
if
hasattr
(
value
,
"_qparams"
)
and
value
.
_qparams
is
not
None
:
node
.
qparams
=
value
.
qparams
node
.
_
qparams
=
value
.
qparams
@
classmethod
def
wrap
(
cls
,
value
,
node
):
...
...
imperative/python/megengine/experimental/traced_module/pytree.py
浏览文件 @
a3f9073c
...
...
@@ -133,7 +133,7 @@ def _is_leaf(obj):
def
_leaf_type
(
node
):
if
isinstance
(
node
,
(
RawTensor
,
TensorNode
)):
return
(
Tensor
,
TensorNode
,
ArgsIndex
)
elif
isinstance
(
node
,
(
NodeMixin
,
Module
)):
elif
isinstance
(
node
,
(
NodeMixin
,
Module
,
ModuleNode
)):
return
(
Module
,
ModuleNode
,
NodeMixin
,
ArgsIndex
)
else
:
return
(
type
(
node
),
ArgsIndex
)
...
...
imperative/python/megengine/experimental/traced_module/traced_module.py
浏览文件 @
a3f9073c
此差异已折叠。
点击以展开。
imperative/python/test/unit/traced_module/test_modification.py
浏览文件 @
a3f9073c
...
...
@@ -64,9 +64,10 @@ def test_search():
def
test_insert
():
traced_module
,
x
,
expect
=
_init_block
()
graph
=
traced_module
.
graph
relu_node
=
graph
.
get_function_by_type
(
F
.
relu
).
as_unique
().
outputs
neg_node
=
graph
.
insert_function
(
lambda
x
:
F
.
neg
(
x
),
*
relu_node
)
graph
.
replace_node
({
relu_node
[
0
]:
neg_node
})
relu_out
=
graph
.
get_function_by_type
(
F
.
relu
).
as_unique
().
outputs
[
0
]
with
graph
.
insert_exprs
():
neg_out
=
F
.
neg
(
relu_out
)
graph
.
replace_node
({
relu_out
:
neg_out
})
graph
.
compile
()
np
.
testing
.
assert_allclose
(
expect
-
1
,
1
-
traced_module
(
x
),
atol
=
1e-6
)
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录