Commit a3f9073c
Authored on Aug 10, 2021 by Megvii Engine Team

feat(traced_module): update graph transform and add _module_name

GitOrigin-RevId: ef63ae0fd0dcdd69c3566e19f8a34d85422a1e1e
Parent: b3d0affa

Showing 7 changed files with 582 additions and 200 deletions (+582 −200)
Changed files:
- imperative/python/megengine/experimental/traced_module/__init__.py (+0 −1)
- imperative/python/megengine/experimental/traced_module/expr.py (+66 −40)
- imperative/python/megengine/experimental/traced_module/module_tracer.py (+39 −3)
- imperative/python/megengine/experimental/traced_module/node.py (+64 −17)
- imperative/python/megengine/experimental/traced_module/pytree.py (+1 −1)
- imperative/python/megengine/experimental/traced_module/traced_module.py (+408 −135)
- imperative/python/test/unit/traced_module/test_modification.py (+4 −3)
imperative/python/megengine/experimental/traced_module/__init__.py

@@ -14,7 +14,6 @@ from .traced_module import (
     register_as_builtin,
     trace_module,
     wrap,
-    wrap_tensors,
 )

 _register_all_builtin_module()
imperative/python/megengine/experimental/traced_module/expr.py

@@ -33,17 +33,6 @@ def rstrip(s: str, __chars: str):
     return s
 
 
-def lstrip(s: str, __chars: str):
-    __chars = re.escape(__chars)
-    s = re.sub(r"^(?:%s)+(?P<right>.*)$" % __chars, "\g<right>", s)
-    return s
-
-
-def strip(s: str, __chars: str):
-    s = lstrip(rstrip(s, __chars), __chars)
-    return s
-
-
 class Expr:
     """
     ``Expr`` represents the operations(i.e. CallMethod, CallFunction, Apply, GetAttr, Input, Constant) on ``Node``.

@@ -89,27 +78,40 @@ class Expr:
             outputs = (outputs,)
         name = None
+        orig_name = None
         if isinstance(self, CallMethod):
             name = self.inputs[0]._name
-            assert name is not None
+            orig_name = self.inputs[0]._orig_name
+            assert isinstance(name, str), "The name of ({}) must be a str".format(
+                self.inputs[0]
+            )
+            assert isinstance(orig_name, str), "The orig_name of ({}) must be a str".format(
+                self.inputs[0]
+            )
             name = rstrip(name, "_out")
             if self.method == "__call__":
                 name += "_out"
+                orig_name += "_out"
             else:
-                strip_method = strip(self.method, "_")
+                strip_method = self.method.strip("_")
                 name = "%s_out" % strip_method
+                orig_name = name
         elif isinstance(self, CallFunction):
             name = self.func.__name__ + "_out"
         elif isinstance(self, Apply):
             name = str(self.opdef).lower() + "_out"
 
         for i in outputs:
-            assert isinstance(i, RawTensor)
+            assert isinstance(i, RawTensor), "The output must be a Tensor"
             o_name = (
                 active_module_tracer().current_scope()._create_unique_name(name)
             )
             self.outputs.append(
-                NodeMixin.get_wrapped_type(i)(expr=self, name=o_name)
+                NodeMixin.get_wrapped_type(i)(
+                    expr=self,
+                    name=o_name,
+                    orig_name=orig_name if orig_name else o_name,
+                )
             )
         for i, node in zip(outputs, self.outputs,):

@@ -125,21 +127,26 @@ class Expr:
         else:
             return inputs, {}
 
-    def _replace_nodes(self, repl_dict: Dict[Node, Node], nodes: List[Node]):
+    def replace_inputs(self, repl_dict: Dict[Node, Node]):
         while repl_dict:
             node, repl_node = repl_dict.popitem()
             assert type(node) == type(repl_node)
-            assert node in nodes
-            index = nodes.index(node)
-            nodes[index] = repl_node
+            assert node in self.inputs, "({}) is not in the ({})".format(node, self)
+            assert (
+                repl_node.top_graph == node.top_graph
+            ), "({}) and ({}) are not in the same graph".format(node, repl_node)
+            graph = self.top_graph
+            repl_expr_idx = graph._exprs.index(repl_node.expr)
+            self_idx = graph._exprs.index(self)
+            assert (
+                repl_expr_idx < self_idx
+            ), "({}) must be generated before ({})".format(repl_node, self)
+            idx = self.inputs.index(node)
+            self.inputs[idx] = repl_node
+            user_idx = node.users.index(self)
+            assert user_idx >= 0
+            node.users.pop(user_idx)
             repl_node.users.append(self)
-            node.users.pop(self)
-
-    def replace_inputs(self, repl_dict: Dict[Node, Node]):
-        self._replace_nodes(repl_dict, self.inputs)
-
-    def replace_outputs(self, repl_dict: Dict[Node, Node]):
-        self._replace_nodes(repl_dict, self.outputs)
 
     @property
     def kwargs(self):

@@ -159,7 +166,8 @@ class Expr:
     def __getstate__(self):
         state = self.__dict__.copy()
-        state.pop("_top_graph", None)
+        if "_top_graph" in state:
+            state.pop("_top_graph")
         return state

@@ -167,12 +175,14 @@ class Expr:
 class Input(Expr):
     name = None
 
-    def __init__(self, name=None, type=None):
+    def __init__(self, name=None, type=None, orig_name=None):
         super().__init__()
         self.inputs = []
         node_cls = type if type else Node
+        if orig_name is None:
+            orig_name = name
         self.outputs = [
-            node_cls(self, name=name),
+            node_cls(self, name=name, orig_name=orig_name),
         ]
         self.name = name

@@ -184,7 +194,7 @@ class Input(Expr):
             active_module_tracer().current_scope()._create_unique_name(oup_node._name)
         )
         oup_node._name = name
-        active_module_tracer().current_scope().add_input(oup_node)
+        active_module_tracer().current_scope()._add_input(oup_node)
         return expr.outputs[0]
 
     def __repr__(self):

@@ -195,7 +205,7 @@ class Input(Expr):
 class GetAttr(Expr):
     name = None
 
-    def __init__(self, module, name, type=None):
+    def __init__(self, module, name, type=None, orig_name=None):
         super().__init__()
         assert isinstance(module, ModuleNode)
         self.inputs = [

@@ -205,7 +215,7 @@ class GetAttr(Expr):
         self.name = name
         node_cls = type if type else Node
         self.outputs = [
-            node_cls(self, name=name),
+            node_cls(self, name=name, orig_name=orig_name),
         ]
 
     @classmethod

@@ -218,7 +228,7 @@ class GetAttr(Expr):
             module = module.expr.inputs[0]
         oup_name = active_module_tracer().current_scope()._create_unique_name(oup_name)
         expr.outputs[0]._name = oup_name
-        active_module_tracer().current_scope().insert(expr)
+        active_module_tracer().current_scope()._insert(expr)
         return expr.outputs[0]
 
     def interpret(self, *inputs):

@@ -255,7 +265,7 @@ class CallMethod(Expr):
     @classmethod
     def make(cls, *args, **kwargs):
         expr = cls(*args, **kwargs)
-        active_module_tracer().current_scope().insert(expr)
+        active_module_tracer().current_scope()._insert(expr)
         return expr
 
     @property

@@ -315,7 +325,7 @@ class Apply(Expr):
     @classmethod
     def make(cls, *args, **kwargs):
         expr = cls(*args, **kwargs)
-        active_module_tracer().current_scope().insert(expr)
+        active_module_tracer().current_scope()._insert(expr)
         return expr
 
     def interpret(self, *inputs):

@@ -382,7 +392,7 @@ class CallFunction(Expr):
     @classmethod
     def make(cls, *args, **kwargs):
         expr = cls(*args, **kwargs)
-        active_module_tracer().current_scope().insert(expr)
+        active_module_tracer().current_scope()._insert(expr)
         return expr
 
     def interpret(self, *inputs):

@@ -423,7 +433,7 @@ class Constant(Expr):
         self.inputs = []
         node_cls = NodeMixin.get_wrapped_type(c)
         self.outputs = [
-            node_cls(self, name=name),
+            node_cls(self, name=name, orig_name=name),
         ]
         self.outputs[0]._name = name if name else "const_" + str(self._id)

@@ -431,9 +441,23 @@ class Constant(Expr):
     def make(cls, *args, **kwargs):
         expr = cls(*args, **kwargs)
         name = "const_module" if isinstance(expr.value, Module) else "const_tensor"
-        name = active_module_tracer().current_scope()._create_unique_name(name)
+        full_name = name
+        if (
+            isinstance(expr.value, RawTensor)
+            and id(expr.value) in active_module_tracer().id2name
+        ):
+            full_name = active_module_tracer().id2name[id(expr.value)]
+            scope_name = active_module_tracer().current_scope()._module_name
+            if full_name and scope_name:
+                full_name = ("self." + full_name)[len(scope_name) + 1 :]
+            else:
+                full_name = name
+        else:
+            full_name = name
+        name = active_module_tracer().current_scope()._create_unique_name(full_name)
         expr.outputs[0]._name = name
-        active_module_tracer().current_scope().insert(expr)
+        expr.outputs[0]._orig_name = full_name
+        active_module_tracer().current_scope()._insert(expr)
         return expr.outputs[0]
 
     def interpret(self, *inputs):

@@ -453,7 +477,9 @@ class Constant(Expr):
         )
 
     def __getstate__(self):
-        state = super().__getstate__()
+        state = self.__dict__.copy()
+        if "_top_graph" in state:
+            state.pop("_top_graph")
         if isinstance(self.value, RawTensor):
             state["value"] = Tensor(self.value)
         return state
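The renaming logic in the @@ -89 hunk derives each output node's name from the method being traced. A minimal standalone sketch of that rule (derive_out_name is a hypothetical helper, not part of the diff):

    # For "__call__" the calling node's name gains an "_out" suffix;
    # any other method contributes its underscore-stripped name.
    def derive_out_name(method: str, caller_name: str) -> str:
        if method == "__call__":
            return caller_name + "_out"      # e.g. "conv1" -> "conv1_out"
        return "%s_out" % method.strip("_")  # e.g. "__add__" -> "add_out"

    assert derive_out_name("__call__", "conv1") == "conv1_out"
    assert derive_out_name("__add__", "x") == "add_out"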
imperative/python/megengine/experimental/traced_module/module_tracer.py

@@ -84,6 +84,34 @@ BUILTIN_ARRAY_METHOD = [
     "__setitem__",
 ]
 
+BUILTIN_TENSOR_WRAP_METHOD = [
+    "T",
+    "to",
+    "size",
+    "shape",
+    "detach",
+    "device",
+    "dtype",
+    "grad",
+    "item",
+    "name",
+    "ndim",
+    "numpy",
+    "qparams",
+    "set_value",
+    "reset_zero",
+    "requires_grad",
+    "_reset",
+    "_isscalar",
+    "_setscalar",
+    "_tuple_shape",
+    "_unsetscalar",
+]
+
+
+def get_tensor_wrapable_method():
+    return BUILTIN_TENSOR_WRAP_METHOD + BUILTIN_ARRAY_METHOD
+
 
 def active_module_tracer():
     return _active_module_tracer

@@ -101,9 +129,10 @@ class module_tracer:
     _active_scopes = None
 
-    def __init__(self, wrap_fn):
+    def __init__(self, wrap_fn, id2name):
         self._active_scopes = []
         self.patcher = Patcher(wrap_fn)
+        self.id2name = id2name
 
     @classmethod
     def register_as_builtin(cls, mod):

@@ -127,6 +156,10 @@ class module_tracer:
         return None
 
 
+class NotExist:
+    pass
+
+
 class PatchedFn:
     frame_dict = None
     name = None

@@ -138,14 +171,17 @@ class PatchedFn:
         self.origin_fn = (
             self.frame_dict[name]
             if isinstance(frame_dict, collections.abc.Mapping)
-            else getattr(frame_dict, name)
+            else getattr(frame_dict, name, NotExist)
         )
 
     def set_func(self, func):
         if isinstance(self.frame_dict, collections.abc.Mapping):
             self.frame_dict[self.name] = func
         else:
-            setattr(self.frame_dict, self.name, func)
+            if func is not NotExist:
+                setattr(self.frame_dict, self.name, func)
+            else:
+                delattr(self.frame_dict, self.name)
 
 
 class Patcher:
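The new NotExist sentinel lets PatchedFn tell "attribute was absent" apart from "attribute was None": restoring a patch over a name that never existed deletes the attribute rather than assigning the sentinel. The same save/patch/restore cycle in isolation (Target and foo are illustrative names only):

    class NotExist:
        pass

    class Target:
        pass

    origin = getattr(Target, "foo", NotExist)  # "foo" is absent, so we get the sentinel
    Target.foo = lambda self: "patched"        # apply a patch
    if origin is not NotExist:
        setattr(Target, "foo", origin)         # restore the original function
    else:
        delattr(Target, "foo")                 # it never existed, so remove it
    assert not hasattr(Target, "foo")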
imperative/python/megengine/experimental/traced_module/node.py

@@ -30,14 +30,17 @@ class Node:
     _id = None
     _top_graph = None  # type: weakref.ReferenceType
     _name = None
+    _orig_name = None
     _format_spec = ""
 
-    def __init__(self, expr: "Expr", name: str = None):
+    def __init__(self, expr: "Expr", name: str = None, orig_name: str = None):
         self.expr = expr
         self.users = []  # List[Expr]
         self._id = Node.__total_id
         Node.__total_id += 1
         self._name = name
+        self._orig_name = orig_name
+        self.actual_node = []  # type: List[Node]
 
     def __setstate__(self, d):
         self.__dict__ = d

@@ -48,7 +51,7 @@ class Node:
         return self.__format__(format_spec)
 
     def __format__(self, format_spec: str) -> str:
-        if format_spec == "" or format_spec is None:
+        if not format_spec:
             format_spec = Node._format_spec
         name = self._name
         if name is None:

@@ -100,9 +103,8 @@ class ModuleNode(Node):
     module_type = Module  # type: Type[Module]
     _owner = None  # type: weakref.ReferenceType
 
-    def __init__(self, expr: "Expr", name: str = None):
-        super().__init__(expr, name)
-        self.actual_mnode = []
+    def __init__(self, expr: "Expr", name: str = None, orig_name: str = None):
+        super().__init__(expr, name, orig_name)
 
     def __getstate__(self):
         return {

@@ -110,6 +112,7 @@ class ModuleNode(Node):
             "users": self.users,
             "_id": self._id,
             "_name": self._name,
+            "_orig_name": self._orig_name,
             "module_type": self.module_type,
         }

@@ -125,23 +128,67 @@ class TensorNode(Node):
     ``TensorNode`` represents the Tensor objects.
     """
 
-    shape = None  # type: Tuple[int]
-    dtype = None  # type: numpy.dtype
-    qparams = None
-    device = None
+    _shape = None  # type: Tuple[int]
+    _dtype = None  # type: numpy.dtype
+    _qparams = None
+    _device = None
+    _value = None  # type: Tensor
 
     def __getstate__(self):
         return {
             "expr": self.expr,
             "users": self.users,
             "_id": self._id,
-            "qparams": self.qparams,
-            "shape": self.shape,
-            "dtype": self.dtype,
-            "device": self.device,
+            "_qparams": self._qparams,
+            "_shape": self._shape,
+            "_dtype": self._dtype,
+            "_device": self._device,
             "_name": self._name,
+            "_orig_name": self._orig_name,
         }
 
+    @property
+    def shape(self):
+        return self._shape
+
+    @shape.setter
+    def shape(self, shape):
+        self._shape = shape
+
+    @property
+    def dtype(self):
+        return self._dtype
+
+    @dtype.setter
+    def dtype(self, dtype):
+        self._dtype = dtype
+
+    @property
+    def device(self):
+        return self._device
+
+    @device.setter
+    def device(self, device):
+        self._device = device
+
+    @property
+    def qparams(self):
+        return self._qparams
+
+    @qparams.setter
+    def qparams(self, qparams):
+        self._qparams = qparams
+
+    @property
+    def value(self):
+        return self._value
+
+    @value.setter
+    def value(self, value):
+        if isinstance(value, RawTensor) and NodeMixin.get(value, None) is not None:
+            setattr(value, "_NodeMixin__node", None)
+        self._value = value
+
 
 class NodeMixin(abc.ABC):
     __node = None

@@ -156,13 +203,13 @@ class NodeMixin(abc.ABC):
         assert isinstance(node, TensorNode)
         assert isinstance(value, RawTensor)
         if isinstance(value, RawTensor):
-            node.dtype = value.dtype
-            node.shape = (
+            node._dtype = value.dtype
+            node._shape = (
                 value._tuple_shape if isinstance(value, Tensor) else value.shape
             )
-            node.device = value.device
+            node._device = value.device
             if hasattr(value, "_qparams") and value._qparams is not None:
-                node.qparams = value.qparams
+                node._qparams = value.qparams
 
     @classmethod
     def wrap(cls, value, node):
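TensorNode's metadata now lives in underscore-prefixed backing fields exposed through properties, which is why __getstate__ serializes "_shape"/"_dtype"/"_device"/"_qparams" rather than the bare names. The pattern in isolation (a sketch, not the MegEngine class itself):

    class Sketch:
        _shape = None  # backing field; this is what gets pickled

        @property
        def shape(self):
            return self._shape

        @shape.setter
        def shape(self, shape):
            self._shape = shape

    n = Sketch()
    n.shape = (2, 3)
    assert n.shape == (2, 3) and n._shape == (2, 3)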
imperative/python/megengine/experimental/traced_module/pytree.py

@@ -133,7 +133,7 @@ def _is_leaf(obj):
 def _leaf_type(node):
     if isinstance(node, (RawTensor, TensorNode)):
         return (Tensor, TensorNode, ArgsIndex)
-    elif isinstance(node, (NodeMixin, Module)):
+    elif isinstance(node, (NodeMixin, Module, ModuleNode)):
         return (Module, ModuleNode, NodeMixin, ArgsIndex)
     else:
         return (type(node), ArgsIndex)
imperative/python/megengine/experimental/traced_module/traced_module.py

@@ -9,14 +9,19 @@
 import builtins
 import collections
 import copy
 import ctypes
 import fnmatch
 import functools
+import inspect
+import keyword
 import re
 import weakref
 from inspect import getcallargs, getmembers, isclass, ismethod
+from itertools import chain
 from typing import Callable, Dict, Iterable, List, Optional, Sequence, Type, Union
 
+from megengine import tensor
 from ... import functional as F
 from ... import get_logger
 from ... import module as M

@@ -44,8 +49,10 @@ from ...tensor import Tensor
 from .expr import Apply, CallFunction, CallMethod, Constant, Expr, GetAttr, Input
 from .fake_quant import FakeQuantize as TM_FakeQuant
 from .module_tracer import (
+    PatchedFn,
     Patcher,
     active_module_tracer,
+    get_tensor_wrapable_method,
     module_tracer,
     set_active_module_tracer,
 )
@@ -70,46 +77,267 @@ def _is_leaf(node):
     return isinstance(node, RawTensor)
 
 
-def wrap_tensors(tensors: Tensor, nodes: TensorNode):
-    inp_tensors = copy.deepcopy(tensors)
-    inp_tensors, inp_def_v = tree_flatten(inp_tensors)
-    inp_nodes, inp_def_n = tree_flatten(nodes)
-    for v, n in zip(inp_tensors, inp_nodes):
-        if isinstance(n, TensorNode) and isinstance(v, Tensor):
-            NodeMixin.wrap_safe(v, n)
-    return inp_def_v.unflatten(inp_tensors)
+_enable_node_to_tensor = False
+
+
+def _convert_node_flag():
+    return _enable_node_to_tensor
+
+
+def _set_convert_node_flag(flag: bool = False):
+    global _enable_node_to_tensor
+    pre_flag = _enable_node_to_tensor
+    _enable_node_to_tensor = flag
+    return pre_flag
+
+
+def _node_to_tensor(*args, **kwargs):
+    tensors = []
+    nodes, tree_def = tree_flatten((args, kwargs))
+    for n in nodes:
+        if isinstance(n, TensorNode):
+            if n.top_graph is not None:
+                active_module_tracer().current_scope()._add_input(n)
+            value = n.value
+            if value is None:
+                flag = _set_convert_node_flag(False)
+                unset_module_tracing()
+                value = F.zeros(shape=n._shape, dtype=n._dtype)
+                set_module_tracing()
+                _set_convert_node_flag(flag)
+            orig_n = NodeMixin.get(value, None)
+            if orig_n is None or "setitem" not in orig_n._name:
+                NodeMixin.wrap_safe(value, n)
+            tensors.append(value)
+        else:
+            tensors.append(n)
+    tensors = tree_def.unflatten(tensors)
+    return tensors
+
+
+def _tensor_to_node(tensors):
+    if tensors is None:
+        return None
+    nodes = []
+    tensors, out_def = tree_flatten(tensors)
+    for t in tensors:
+        if isinstance(t, Tensor):
+            n = NodeMixin.get(t, None)
+            if isinstance(n, TensorNode):
+                n.value = t
+                nodes.append(n)
+            else:
+                nodes.append(t)
+        else:
+            nodes.append(t)
+    nodes = out_def.unflatten(nodes)
+    return nodes
+
+
+def _wrap_method_to_tensor_node():
+    def _any_method(name):
+        def _any(*args, **kwargs):
+            args, kwargs = _node_to_tensor(*args, **kwargs)
+            attr = getattr(args[0], name)
+            outs = attr
+            if callable(attr):
+                outs = attr(*(args[1:]), **kwargs)
+            if name == "__setitem__":
+                _node_to_tensor(outs)
+                return None
+            outs = _tensor_to_node(outs)
+            return outs
+
+        return _any
+
+    tensor_method_patch = []
+    for method in get_tensor_wrapable_method():
+        patch = PatchedFn(TensorNode, method)
+        if type(getattr(Tensor, method)) == property:
+            patch.set_func(property(_any_method(method)))
+        else:
+            patch.set_func(_any_method(method))
+        tensor_method_patch.append(patch)
+    return tensor_method_patch
+
+
+def _convert_node_and_tensor(orig_func):
+    @functools.wraps(orig_func)
+    def _convert(*args, **kwargs):
+        if _convert_node_flag() and is_tracing_module():
+            args, kwargs = _node_to_tensor(*args, **kwargs)
+            rst = orig_func(*args, **kwargs, method_func=_convert)
+            rst = _tensor_to_node(rst)
+            return rst
+        else:
+            rst = orig_func(*args, **kwargs)
+            return rst
+
+    return _convert
+
+
+def _wrap_mnode_getattr(orig_getattr):
+    @functools.wraps(orig_getattr)
+    def wraped_fn(self, name):
+        obj = self.owner
+        if self.top_graph is not None:
+            active_module_tracer().current_scope()._add_input(self)
+        attr = getattr(obj, name)
+        node = attr
+        full_name = None
+        if id(attr) in active_module_tracer().id2name:
+            full_name = active_module_tracer().id2name[id(attr)]
+        if not isinstance(attr, TracedModuleBuilder):
+            if isinstance(attr, Module):
+                attr = TracedModuleBuilder(attr)
+                setattr(obj, name, attr)
+                active_module_tracer().id2name[id(attr)] = full_name
+            if isinstance(attr, (NodeMixin, RawTensor)):
+                if full_name:
+                    scope_name = active_module_tracer().current_scope()._module_name
+                    if scope_name:
+                        full_name = full_name[len(scope_name) + 1 :]
+                    else:
+                        full_name = name
+                else:
+                    full_name = name
+                NodeMixin.wrap(
+                    attr,
+                    lambda: GetAttr.make(
+                        self,
+                        name,
+                        type=NodeMixin.get_wrapped_type(attr),
+                        orig_name=full_name,
+                    ),
+                )
+        if isinstance(attr, (NodeMixin, RawTensor)):
+            node = NodeMixin.get(attr)
+        if isinstance(node, ModuleNode):
+            node._owner = weakref.ref(attr)
+        return node
+
+    return wraped_fn
+
+
+def _wrap_mnode_call(orig_call):
+    @functools.wraps(orig_call)
+    def wraped_fn(self, *args, **kwargs):
+        obj = self.owner
+        if self.top_graph is not None:
+            active_module_tracer().current_scope()._add_input(self)
+        rst = obj(*args, **kwargs)
+        return rst
+
+    return wraped_fn
+
+
+def _init_id2name(mod: Module, prefix: str = ""):
+    id2name = {
+        id(m): "%s.%s" % (prefix, key)
+        for key, m in chain(
+            mod.named_modules(), mod.named_parameters(), mod.named_buffers()
+        )
+    }
+    return id2name
 
 
 class _InsertExprs:
-    def __init__(self, graph, expr: Optional[Expr] = None, after: bool = True):
+    def __init__(self, graph, expr: Optional[Expr] = None):
         self.graph = graph
-        self.global_scope = InternalGraph()
+        self.global_scope = InternalGraph(
+            graph._name, graph._prefix_name, graph._module_name
+        )
+        self.global_scope._used_names.update(graph._used_names)
         self.expr = expr
-        self.after = after
+        self._tensor_method_patch = None
 
     def __enter__(self):
         self.use_sym_shape = set_symbolic_shape(True)
         set_module_tracing()
+        _set_convert_node_flag(True)
         assert active_module_tracer() is None
-        set_active_module_tracer(module_tracer(_wrapped_function))
+        module = self.graph.inputs[0].owner
+        _wrap_func = lambda x: _convert_node_and_tensor(_wrapped_function(x))
+        set_active_module_tracer(
+            module_tracer(_wrap_func, _init_id2name(module, self.graph._module_name))
+        )
+        active_module_tracer().patcher.__enter__()
+        for cls, name, func in [
+            [ModuleNode, "__getattr__", _wrap_mnode_getattr],
+            [ModuleNode, "__call__", _wrap_mnode_call],
+            [TracedModuleBuilder, "__call__", _convert_node_and_tensor],
+        ]:
+            active_module_tracer().patcher.patch_function(cls, name, func)
+        self._tensor_method_patch = _wrap_method_to_tensor_node()
         active_module_tracer().push_scope(self.global_scope)
 
     def __exit__(self, ty, va, tr):
+        if va is not None:
+            return False
         set_symbolic_shape(self.use_sym_shape)
         unset_module_tracing()
+        active_module_tracer().patcher.__exit__(ty, va, tr)
+        _set_convert_node_flag(False)
+        while self._tensor_method_patch:
+            pf = self._tensor_method_patch.pop()
+            pf.set_func(pf.origin_fn)
+        module = self.graph.inputs[0].owner
+        for mod, parent in module.modules(with_parent=True):
+            name = mod._name
+            if isinstance(mod, TracedModuleBuilder):
+                mod = mod.build()
+                if hasattr(mod, "graph"):
+                    for node in mod.graph.nodes():
+                        node.value = None
+                setattr(parent, name, mod)
         set_active_module_tracer(None)
-        index = len(self.graph._exprs) if self.after else 0
+        for node in self.global_scope.nodes():
+            node.value = None
+        extra_inp_nodes = set(self.global_scope.inputs)
+        max_inp_expr_idx = -1
+        for node in extra_inp_nodes:
+            assert (
+                node.top_graph == self.graph
+            ), "The input node ({}) is not in the graph ({})".format(node, self.graph)
+            if isinstance(node, TensorNode) and node.expr in self.graph._exprs:
+                max_inp_expr_idx = max(
+                    max_inp_expr_idx, self.graph._exprs.index(node.expr)
+                )
+        max_inp_expr_idx += 1
+        insert_index = -1
         if self.expr is not None:
-            index = self.graph._exprs.index(self.expr)
-            if self.after:
-                index += 1
+            insert_index = self.graph._exprs.index(self.expr)
+        insert_index += 1
+        if insert_index < max_inp_expr_idx:
+            insert_index = max_inp_expr_idx
+        anchor_index = insert_index - 1
+        if anchor_index >= 0:
+            logger.info(
+                "The new expr will be inserted after ( {} )".format(
+                    self.graph._exprs[anchor_index]
+                )
+            )
         for expr in self.global_scope._exprs:
-            self.graph._exprs.insert(index, expr)
-            index += 1
+            self.graph._exprs.insert(insert_index, expr)
+            insert_index += 1
+        self.graph._used_names.update(self.global_scope._used_names)
         graph = self.graph
         while graph.top_graph is not None:
             graph = graph.top_graph
         graph.inputs[0].owner._update_ref()
+        return True
 
 
 class InternalGraph:
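_set_convert_node_flag returns the previous flag value so a caller can save and restore it around a region that must not convert, which is exactly what _node_to_tensor does while materializing a placeholder with F.zeros. The same idiom in a standalone form:

    _enable = False

    def _set_flag(flag: bool = False):
        global _enable
        pre = _enable    # remember the previous state
        _enable = flag
        return pre

    prev = _set_flag(True)   # switch on, keeping the old value
    # ... work that requires the flag ...
    _set_flag(prev)          # restore whatever was set before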
@@ -125,8 +353,9 @@ class InternalGraph:
     _exprs = None  # type: List[Expr]
     _inputs = None  # type: List[Node]
     _outputs = None  # type: List[Node]
+    _top_graph = None
 
-    def __init__(self, name: str = None, prefix_name: str = ""):
+    def __init__(self, name: str = None, prefix_name: str = "", module_name: str = ""):
         self._exprs = []
         self._inputs = []
         self._outputs = []
@@ -136,12 +365,13 @@ class InternalGraph:
         self._rst = collections.defaultdict(list)
         self._name = name
         self._prefix_name = prefix_name
+        self._module_name = module_name
 
-    def insert(self, expr):
+    def _insert(self, expr):
         self._exprs.append(expr)
 
     def _create_unique_name(self, name: str) -> str:
-        assert isinstance(name, str)
+        assert isinstance(name, str), "The name must be a str"
         name = re.sub("[^0-9a-zA-Z_]+", "_", name)
         if name[0].isdigit():
             name = "_{}".format(name)
@@ -166,40 +396,45 @@ class InternalGraph:
         return self._outputs
 
     @property
-    def expr_filter(self):
-        return ExprFilter(_expr_iter(self))
+    def top_graph(self):
+        if self._top_graph:
+            return self._top_graph()
+        return None
 
-    @property
-    def node_filter(self):
-        return NodeFilter(_node_iter(self))
+    def exprs(self, recursive=True):
+        return ExprFilter(_expr_iter(self, recursive))
+
+    def nodes(self, recursive=True):
+        return NodeFilter(_node_iter(self, recursive))
 
-    def get_function_by_type(self, func: Callable = None):
-        return self.expr_filter.call_function(func)
+    def get_function_by_type(self, func: Callable = None, recursive=True):
+        return self.exprs(recursive).call_function(func)
 
-    def get_method_by_type(self, method: str = None):
-        return self.expr_filter.call_method(method)
+    def get_method_by_type(self, method: str = None, recursive=True):
+        return self.exprs(recursive).call_method(method)
 
-    def get_expr_by_id(self, expr_id: List[int] = None):
-        return self.expr_filter.expr_id(expr_id)
+    def get_expr_by_id(self, expr_id: List[int] = None, recursive=True):
+        return self.exprs(recursive).expr_id(expr_id)
 
-    def get_module_by_type(self, module_cls: Module):
+    def get_module_by_type(self, module_cls: Module, recursive=True):
         assert issubclass(module_cls, Module)
-        return self.node_filter.type(module_cls, ModuleNode)
+        return self.nodes(recursive).type(module_cls, ModuleNode)
 
-    def get_node_by_id(self, node_id: List[int] = None):
-        return self.node_filter.node_id(node_id)
+    def get_node_by_id(self, node_id: List[int] = None, recursive=True):
+        return self.nodes(recursive).node_id(node_id)
 
-    def get_node_by_name(self, name: str = None, ignorecase: bool = True):
-        return self.node_filter.name(name, ignorecase)
+    def get_node_by_name(self, name: str = None, ignorecase: bool = True, recursive=True):
+        return self.nodes(recursive).name(name, ignorecase)
 
-    def add_input(self, i):
+    def _add_input(self, i):
         self._inputs.append(i)
 
-    def add_output(self, o):
+    def _add_output(self, o):
         self._outputs.append(o)
 
-    def _replace_inputs_outputs_and_add_prefixname(self, repl_dict, prefix_name=""):
+    def _replace_inputs_outputs(self, repl_dict, prefix_name="", module_name=""):
         for node, repl_node in repl_dict.items():
             assert node in self._inputs or node in self._outputs
             for i in node.users:
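The query helpers now take a recursive flag (default True, matching the old expr_filter/node_filter behavior); recursive=False restricts a search to the current graph without descending into sub-module graphs. A hedged usage sketch, with graph, F, and M assumed from a traced-module context as in the tests:

    relu_exprs = graph.get_function_by_type(F.relu, recursive=False)
    local_nodes = graph.nodes(recursive=False)
    all_convs = graph.get_module_by_type(M.Conv2d)  # recursive by default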
@@ -212,12 +447,15 @@ class InternalGraph:
         for idx, o in enumerate(self._outputs):
             if o in repl_dict:
+                repl_dict[o]._orig_name = "{}{}".format(module_name, o._orig_name)
                 self._outputs[idx] = repl_dict[o]
 
         for expr in self._exprs:
             for idx, i in enumerate(expr.inputs):
-                assert i._name is not None
+                assert isinstance(i._name, str), "The node ({}) name must be a str".format(i)
                 if i in repl_dict:
                     expr.inputs[idx] = repl_dict[i]
                 elif isinstance(i, TensorNode) and prefix_name not in i._name:
@@ -227,9 +465,12 @@ class InternalGraph:
                         .current_scope()
                         ._create_unique_name(prefix_name + i._name.lstrip("_"))
                     )
+                    i._orig_name = "{}{}".format(module_name, i._orig_name)
 
             for idx, o in enumerate(expr.outputs):
-                assert o._name is not None
+                assert isinstance(o._name, str), "The node ({}) name must be a str".format(i)
                 if o in repl_dict:
                     expr.outputs[idx] = repl_dict[o]
                     expr.outputs[idx].expr = expr

@@ -240,6 +481,7 @@ class InternalGraph:
                         .current_scope()
                         ._create_unique_name(prefix_name + o._name.lstrip("_"))
                     )
+                    o._orig_name = "{}{}".format(module_name, o._orig_name)
 
     def get_dep_exprs(self, nodes: Sequence[Node]) -> List[Expr]:
         if not isinstance(nodes, Sequence):
@@ -263,7 +505,7 @@ class InternalGraph:
     def reset_inputs(self, *args, **kwargs):
         forma_mnode = self.inputs[0]
-        actual_mnodes = forma_mnode.actual_mnode
+        actual_mnodes = forma_mnode.actual_node
         call_nodes = []
         for n in actual_mnodes:
             for c_expr in n.users:
@@ -318,7 +560,7 @@ class InternalGraph:
     def add_input_node(self, shape, dtype="float32", name="args"):
         forma_mnode = self.inputs[0]
-        actual_mnodes = forma_mnode.actual_mnode
+        actual_mnodes = forma_mnode.actual_node
         moudle = forma_mnode.owner
         assert moudle._is_top, "add_input_node only support the top-level graph"
@@ -378,7 +620,7 @@ class InternalGraph:
         moudle = forma_mnode.owner
         assert moudle._is_top, "reset_outputs only support the top-level graph"
-        actual_mnodes = forma_mnode.actual_mnode
+        actual_mnodes = forma_mnode.actual_node
         call_nodes = []
         for n in actual_mnodes:
             for c_expr in n.users:
@@ -406,7 +648,6 @@ class InternalGraph:
         self._outputs[:] = outputs
         moudle.argdef_outdef_map[tree_def] = out_def
-
         return actual_nodes
 
     def add_output_node(self, node: TensorNode):
@@ -415,7 +656,7 @@ class InternalGraph:
         moudle = forma_mnode.owner
         assert moudle._is_top, "add_output_node only support the top-level graph"
-        actual_mnodes = forma_mnode.actual_mnode
+        actual_mnodes = forma_mnode.actual_node
         call_nodes = []
         for n in actual_mnodes:
@@ -455,74 +696,35 @@ class InternalGraph:
         return actual_out_nodes
 
-    def insert_function(self, func: Callable, *args, **kwargs):
-        assert isinstance(func, Callable)
-
-        inp_nodes, inp_def = tree_flatten((args, kwargs))
-
-        insert_idx = -1
-        for i in inp_nodes:
-            if isinstance(i, TensorNode) and i.expr in self._exprs:
-                insert_idx = max(insert_idx, self._exprs.index(i.expr))
-
-        fake_inp_val = list(
-            F.zeros(shape=i.shape, dtype=i.dtype) if isinstance(i, TensorNode) else i
-            for i in inp_nodes
-        )
-
-        for v, n in zip(fake_inp_val, inp_nodes):
-            if isinstance(n, TensorNode):
-                NodeMixin.wrap_safe(v, n)
-
-        fake_args, fake_kwargs = inp_def.unflatten(fake_inp_val)
-
-        insert_point = self.insert_exprs_before()
-        if insert_idx != -1:
-            insert_point = self.insert_exprs_after(self._exprs[insert_idx])
-
-        with insert_point:
-            rst = func(*fake_args, **fake_kwargs)
-
-        if rst is None:
-            return None
-
-        outputs, out_def = tree_flatten(rst, is_leaf=_is_leaf)
-        node_outputs = []
-        for out in outputs:
-            assert isinstance(out, RawTensor)
-            node_outputs.append(NodeMixin.get(out, None))
-
-        node_outputs = out_def.unflatten(node_outputs)
-        return node_outputs
-
-    def insert_exprs_after(self, expr: Optional[Expr] = None):
-        if expr is not None:
-            assert expr.top_graph == self, "Expr to insert after is not in graph."
-        return _InsertExprs(self, expr, after=True)
-
-    def insert_exprs_before(self, expr: Optional[Expr] = None):
-        if expr is not None:
-            assert expr.top_graph == self, "Expr to insert before is not in graph."
-        return _InsertExprs(self, expr, after=False)
+    def insert_exprs(self, expr: Optional[Expr] = None):
+        if expr is not None:
+            assert expr.top_graph == self, "Expr to insert after is not in graph."
+        return _InsertExprs(self, expr)
 
     def replace_node(self, repl_dict: Dict[Node, Node]):
         while repl_dict:
             node, repl_node = repl_dict.popitem()
             # check graph inputs and outputs
-            assert node not in self.inputs, "Cannot replace inputs"
+            # assert node not in self.inputs, "Cannot replace inputs"
             for i, n in enumerate(self.outputs):
                 if n is node:
                     self.outputs[i] = repl_node
             # update users of node and repl_node
             # update inputs of expr in node.users
+            graph = repl_node.top_graph
+            assert graph is not None
+            index = graph._exprs.index(repl_node.expr)
             dep_exprs = self.get_dep_exprs(repl_node)
             i = 0
             while i < len(node.users):
                 n = node.users[i]
+                if n in graph._exprs and index >= graph._exprs.index(n):
+                    i += 1
+                    continue
                 if n in dep_exprs:
                     logger.info("Find a loop: ignore this replacement once")
                     logger.info("node: %s" % node.__repr__())
+                    logger.info("repl_node: %s" % repl_node.__repr__())
+                    logger.info("expr: %s" % n.__repr__())
                     i += 1
                     continue
                 repl_node.users.append(n)
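insert_exprs replaces the old insert_function/insert_exprs_before/insert_exprs_after trio: inside the with block, TensorNodes can be passed to ordinary functional ops, and the resulting Exprs are spliced in after the producing Expr (or after the given expr argument). A usage sketch mirroring the updated test at the end of this commit:

    relu_out = graph.get_function_by_type(F.relu).as_unique().outputs[0]
    with graph.insert_exprs():
        neg_out = F.neg(relu_out)  # a TensorNode used like a Tensor
    graph.replace_node({relu_out: neg_out})
    graph.compile()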
@@ -598,6 +800,12 @@ class InternalGraph:
         Node.set_format_spec(saved_format_spec)
         return res
 
+    def __getstate__(self):
+        state = self.__dict__.copy()
+        if "_top_graph" in state:
+            state.pop("_top_graph")
+        return state
+
 
 def _get_meth_name(obj, func):
     tp = obj if isinstance(obj, type) else type(obj)
@@ -611,6 +819,9 @@ def _get_meth_name(obj, func):
 def _wrapped_function(orig_func):
     @functools.wraps(orig_func)
     def wrapped_fn(*args, **kwargs):
+        method_func = wrapped_fn
+        if "method_func" in kwargs:
+            method_func = kwargs.pop("method_func")
         if is_tracing_module():
             unset_module_tracing()
             inputs, tree_def = tree_flatten((args, kwargs))

@@ -618,9 +829,11 @@ def _wrapped_function(orig_func):
                 if not NodeMixin.get(i, None):
                     if isinstance(i, (RawTensor, NodeMixin)):
                         NodeMixin.wrap_safe(i, Constant.make(i))
-            meth_name = _get_meth_name(args[0], wrapped_fn) if args else None
-            arg_type = args[0] if isinstance(args[0], type) else type(args[0])
-            if meth_name and issubclass(arg_type, RawTensor):
+            meth_name, arg_type = None, None
+            if args:
+                meth_name = _get_meth_name(args[0], method_func)
+                arg_type = args[0] if isinstance(args[0], type) else type(args[0])
+            if meth_name and arg_type and issubclass(arg_type, RawTensor):
                 self = inputs[0]
                 if meth_name == "__new__":
                     if all([not isinstance(i, RawTensor) for i in inputs]):
@@ -799,6 +1012,9 @@ class TracedModuleBuilder(NodeMixin):
     def __call__(self, *args, **kwargs):
         assert isinstance(self._mod, Module)
 
+        if "method_func" in kwargs:
+            kwargs.pop("method_func")
         # prepare args and kwargs for inner graph
         def mark_constant(x):
             node = NodeMixin.get(x, None)
             if node is None:  # capture as constant

@@ -829,9 +1045,6 @@ class TracedModuleBuilder(NodeMixin):
             else:
                 self._mod._is_top = False
                 self._body = self._mod.graph
-                name = NodeMixin.get(self)._name
-                if name:
-                    self._body._name = name
         else:
             self_node = None
             orig_self = NodeMixin.get(self)

@@ -841,19 +1054,24 @@ class TracedModuleBuilder(NodeMixin):
                 graph_prefix_name = "{}_{}".format(
                     top_graph._prefix_name, graph_prefix_name.lstrip("_")
                 )
-            self._body = InternalGraph(orig_self._name, prefix_name=graph_prefix_name)
+            module_name = orig_self._orig_name
+            if top_graph._module_name:
+                module_name = "{}.{}".format(top_graph._module_name, module_name)
+            self._body = InternalGraph(
+                orig_self._name, prefix_name=graph_prefix_name, module_name=module_name
+            )
             active_module_tracer().push_scope(self._body)
             # rebind self to new input node
             if self_node:
                 NodeMixin.wrap_safe(self, self_node)
-                active_module_tracer().current_scope().add_input(self_node)
+                active_module_tracer().current_scope()._add_input(self_node)
             else:
                 NodeMixin.wrap_safe(
                     self,
                     self_node
                     if self_node
-                    else Input.make("self", NodeMixin.get_wrapped_type(self)),
+                    else Input.make("self", NodeMixin.get_wrapped_type(self), ""),
                 )
             origin_inp_node = [NodeMixin.get(i, None) for i in inputs[1:]]
             # prepare args and kwargs for inner graph

@@ -893,12 +1111,13 @@ class TracedModuleBuilder(NodeMixin):
                 getattr(getattr(self._mod, "forward", self._mod), "__globals__", {})
             )
             rst = type(self._mod).forward(*args, **kwargs)
+            if _convert_node_flag():
+                rst = _node_to_tensor(rst)[0][0]
             outputs, out_def = tree_flatten(rst, is_leaf=_is_leaf)
             for i in (
                 outputs if isinstance(outputs, collections.abc.Sequence) else (outputs,)
             ):
-                active_module_tracer().current_scope().add_output(NodeMixin.get(i))
-            NodeMixin.get(self, None).actual_mnode.append(orig_self)
+                active_module_tracer().current_scope()._add_output(NodeMixin.get(i))
             NodeMixin.wrap_safe(self, orig_self)
             for arg, node in zip(inputs[1:], origin_inp_node):
                 if node:

@@ -923,14 +1142,33 @@ class TracedModuleBuilder(NodeMixin):
             attr = getattr(type(self._mod), name).__get__(self, type(self))
         else:
             attr = getattr(self._mod, name)
+            full_name = None
+            if id(attr) in active_module_tracer().id2name:
+                full_name = active_module_tracer().id2name[id(attr)]
             if isinstance(attr, Module):
                 attr = TracedModuleBuilder(attr)
 
             if isinstance(attr, (Module, RawTensor)):
                 setattr(self, name, attr)
+                active_module_tracer().id2name[id(attr)] = full_name
+
+            if full_name:
+                scope_name = active_module_tracer().current_scope()._module_name
+                if scope_name:
+                    full_name = full_name[len(scope_name) + 1 :]
+                else:
+                    full_name = name
+            else:
+                full_name = name
 
             NodeMixin.wrap(
                 attr,
                 lambda: GetAttr.make(
-                    NodeMixin.get(self), name, type=NodeMixin.get_wrapped_type(attr)
+                    NodeMixin.get(self),
+                    name,
+                    type=NodeMixin.get_wrapped_type(attr),
+                    orig_name=full_name,
                 ),
             )
         return attr
@@ -951,7 +1189,16 @@ class TracedModuleBuilder(NodeMixin):
                 assert mod_attr is wrapped._mod
             else:
                 assert mod_attr is wrapped
 
+            full_name = None
+            if id(mod_attr) in active_module_tracer().id2name:
+                full_name = active_module_tracer().id2name[id(mod_attr)]
+                scope_name = active_module_tracer().current_scope()._module_name
+                if full_name and scope_name:
+                    full_name = full_name[len(scope_name) + 1 :]
+                else:
+                    full_name = name
+            else:
+                full_name = name
+
             # assert not self._is_builtin
             if isinstance(wrapped, (NodeMixin, RawTensor)):
                 NodeMixin.wrap(

@@ -960,6 +1207,7 @@ class TracedModuleBuilder(NodeMixin):
                     NodeMixin.get(self),
                     name,
                     type=NodeMixin.get_wrapped_type(wrapped),
+                    orig_name=full_name,
                 ),
             )
@@ -967,24 +1215,25 @@
 class _expr_iter:
-    def __init__(self, graph: InternalGraph):
+    def __init__(self, graph: InternalGraph, recursive: bool = True):
         self.graph = graph
+        self.recursive = recursive
 
     def __iter__(self):
         for expr in self.graph._exprs:
             if isinstance(expr, CallMethod) and isinstance(expr.inputs[0], ModuleNode):
                 yield expr
-                if expr.graph is not None:
-                    yield from expr.graph.expr_filter
+                if self.recursive and expr.graph is not None:
+                    yield from expr.graph.exprs(self.recursive)
             else:
                 yield expr
 
 
 class _node_iter:
-    def __init__(self, graph: InternalGraph) -> None:
+    def __init__(self, graph: InternalGraph, recursive: bool = True) -> None:
         nodes = []
         node_ids = set()
-        for expr in graph.expr_filter:
+        for expr in graph.exprs(recursive):
             for n in expr.inputs + expr.outputs:
                 if n._id in node_ids:
                     continue
@@ -1210,14 +1459,17 @@ class TracedModule(Module):
         assert len(self.argdef_graph_map) == 1
         return list(self.argdef_graph_map.values())[0]
 
-    def _update_ref(self, actual_node_map: Union[Dict] = None):
+    def _update_ref(self, actual_node_map: Union[Dict] = None, top_graph=None):
         for inp_def, graph in self.argdef_graph_map.items():
+            if top_graph is not None:
+                graph._top_graph = weakref.ref(top_graph)
             for n in graph._inputs + graph.outputs:
                 n._top_graph = weakref.ref(graph)
             graph._inputs[0]._owner = weakref.ref(self)
-            graph._inputs[0].actual_mnode = []
-            if actual_node_map is not None and inp_def in actual_node_map.keys():
-                graph._inputs[0].actual_mnode = actual_node_map[inp_def]
+            for i, n in enumerate(graph._inputs):
+                n.actual_node = []
+                if actual_node_map is not None and inp_def in actual_node_map.keys():
+                    n.actual_node = list(list(zip(*(actual_node_map[inp_def])))[i])
             node2obj = {}
             next_actual_node_map = collections.defaultdict(
                 lambda: collections.defaultdict(list)

@@ -1246,7 +1498,7 @@ class TracedModule(Module):
             ):
                 obj = node2obj[expr.inputs[0]]
                 if expr.arg_def is not None:
-                    next_actual_node_map[obj][expr.arg_def].append(expr.inputs[0])
+                    next_actual_node_map[obj][expr.arg_def].append(expr.inputs)
 
             for obj in node2obj.values():
                 if obj is self:

@@ -1255,7 +1507,7 @@ class TracedModule(Module):
                 if obj in next_actual_node_map.keys():
                     mnode_map = next_actual_node_map[obj]
                 if isinstance(obj, TracedModule):
-                    obj._update_ref(mnode_map)
+                    obj._update_ref(mnode_map, graph)
 
     def flatten(self):
         """
@@ -1264,21 +1516,25 @@ class TracedModule(Module):
         :return: :class:`TracedModule`
         """
         new_module = copy.deepcopy(self)
-        module2name = {}
         assert active_module_tracer() is None
-        set_active_module_tracer(module_tracer(lambda x: x))
+        id2name = _init_id2name(new_module, "self")
+        set_active_module_tracer(module_tracer(lambda x: x, {}))
         active_module_tracer().push_scope(new_module.graph)
-        for n, m in new_module.named_modules():
-            module2name[id(m)] = n
 
         def _flatten_subgraph(
-            graph: InternalGraph, module: Module, call=None, prefix_name=""
+            graph: InternalGraph,
+            module: Module,
+            call=None,
+            prefix_name="",
+            module_name="",
         ):
-            if graph is not None and prefix_name and prefix_name[-1] != "_":
+            if isinstance(prefix_name, str) and prefix_name and prefix_name[-1] != "_":
                 prefix_name += "_"
+            if isinstance(module_name, str) and module_name:
+                module_name += "."
             if graph is None or module.is_qat:
                 assert not isinstance(module, TracedModule) or module.is_qat
-                const = Constant(module, "self.%s" % module2name[id(module)])
+                const = Constant(module, id2name[id(module)])
                 m_node = call.inputs[0]
                 if m_node.top_graph != active_module_tracer().current_scope():
                     m_node._name = (

@@ -1286,6 +1542,7 @@ class TracedModule(Module):
                         .current_scope()
                         ._create_unique_name(prefix_name)
                     )
+                    m_node._orig_name = id2name[id(module)][5:]
                 const.outputs[0] = m_node
                 const.outputs[0].expr = const
                 return [const, call]

@@ -1312,7 +1569,7 @@ class TracedModule(Module):
                     continue
                 repl_dict[out] = call.outputs[ind]
 
-            graph._replace_inputs_outputs_and_add_prefixname(repl_dict, prefix_name)
+            graph._replace_inputs_outputs(repl_dict, prefix_name, module_name)
 
             for expr in graph._exprs:
                 if isinstance(expr, GetAttr):

@@ -1344,6 +1601,7 @@ class TracedModule(Module):
                             obj,
                             expr,
                             prefix_name + obj_node._name.lstrip("_"),
+                            module_name + obj_node._orig_name,
                         )
                     )
                 else:

@@ -1358,7 +1616,6 @@ class TracedModule(Module):
             if call is not None:
                 for i in call.inputs:
                     i.users.remove(call)
-
             return exprs
 
         new_module.graph._exprs = _flatten_subgraph(new_module.graph, new_module)
@@ -1396,7 +1653,22 @@ def register_as_builtin(mod_cls: Type[Module]) -> None:
     module_tracer.register_as_builtin(mod_cls)
 
 
-wrap = _wrapped_function
+def wrap(func: Callable):
+    """
+    Call this function to register func as a builtin function.
+    """
+    assert callable(func), "func must be a callable"
+    assert hasattr(func, "__code__")
+    fn_name = func.__code__.co_name
+    currentframe = inspect.currentframe()
+    assert currentframe is not None
+    f = currentframe.f_back
+    assert f is not None
+    assert (
+        f.f_code.co_name == "<module>"
+    ), "wrap must be called at the top level of a module"
+    Patcher._builtin_functions.append((f.f_globals, fn_name))
+    return func
 
 
 def _register_all_builtin_module():
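wrap is now a real registration function with a top-of-module guard instead of an alias for _wrapped_function. Because it returns func, it also works as a decorator; a hedged sketch (my_op is a hypothetical user function):

    from megengine.experimental.traced_module import wrap

    @wrap  # must run at module top level, per the assertion above
    def my_op(x):
        return x * 2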
@@ -1438,14 +1710,15 @@ def trace_module(mod: Module, *args: Tensor, **kwargs: Tensor) -> TracedModule:
     try:
         use_sym_shape = set_symbolic_shape(True)
         set_module_tracing()
-        set_active_module_tracer(module_tracer(_wrapped_function))
+        set_active_module_tracer(
+            module_tracer(_wrapped_function, _init_id2name(mod, "self"))
+        )
         with active_module_tracer().patcher:
             global_scope = InternalGraph(name="")
             active_module_tracer().push_scope(global_scope)
             builder = TracedModuleBuilder(mod, True)
 
             name = mod._name if mod._name else mod.__class__.__name__
-            NodeMixin.wrap_safe(builder, Input.make(name, ModuleNode))
+            NodeMixin.wrap_safe(builder, Input.make(name, ModuleNode, orig_name="self"))
             inputs, _ = tree_flatten((args, kwargs))
             for _, i in enumerate(inputs):
                 # assert isinstance(i, Tensor), "not support "
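trace_module now seeds the tracer with an id2name map rooted at "self", so traced nodes carry an _orig_name attribute path alongside their display name. A hedged end-to-end sketch (Net is an arbitrary user module, not part of this commit):

    import megengine.functional as F
    import megengine.module as M
    from megengine import Tensor
    from megengine.experimental.traced_module import trace_module

    class Net(M.Module):
        def forward(self, x):
            return F.relu(x)

    traced = trace_module(Net(), Tensor([1.0, -1.0]))
    graph = traced.graph  # an InternalGraph whose nodes record _orig_name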
imperative/python/test/unit/traced_module/test_modification.py

@@ -64,9 +64,10 @@ def test_search():
 def test_insert():
     traced_module, x, expect = _init_block()
     graph = traced_module.graph
-    relu_node = graph.get_function_by_type(F.relu).as_unique().outputs
-    neg_node = graph.insert_function(lambda x: F.neg(x), *relu_node)
-    graph.replace_node({relu_node[0]: neg_node})
+    relu_out = graph.get_function_by_type(F.relu).as_unique().outputs[0]
+    with graph.insert_exprs():
+        neg_out = F.neg(relu_out)
+    graph.replace_node({relu_out: neg_out})
     graph.compile()
     np.testing.assert_allclose(expect - 1, 1 - traced_module(x), atol=1e-6)