Commit f2691566

Authored Jul 05, 2021 by Megvii Engine Team

feat(traced_module): add pytree

GitOrigin-RevId: 6c6e53521c71474c67590e0a94723a1d6be89218

Parent: bee305be

Showing 5 changed files with 266 additions and 120 deletions (+266 -120)
  imperative/python/megengine/experimental/traced_module/expr.py            +51  -64
  imperative/python/megengine/experimental/traced_module/module_tracer.py   +68  -3
  imperative/python/megengine/experimental/traced_module/node.py            +2   -0
  imperative/python/megengine/experimental/traced_module/pytree.py          +80  -0
  imperative/python/megengine/experimental/traced_module/traced_module.py   +65  -53
imperative/python/megengine/experimental/traced_module/expr.py

@@ -7,7 +7,7 @@
 # software distributed under the License is distributed on an
 # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+import builtins
 import collections
 from typing import Callable, List

@@ -19,6 +19,7 @@ from ...module import Module
 from ...tensor import Tensor
 from .module_tracer import active_module_tracer
 from .node import ModuleNode, Node, NodeMixin, TensorNode
+from .pytree import TreeDef


 class Expr:

@@ -28,9 +29,22 @@ class Expr:
     inputs = None  # type: List[Node]
     outputs = None  # type: List[Node]
+    const_val = None  # type: List[Any]
+    arg_def = None  # type: TreeDef

-    def add_input(self, node):
-        self.inputs.append(node)
+    def add_inputs(self, vals):
+        if not isinstance(vals, collections.abc.Sequence):
+            vals = (vals,)
+        for val in vals:
+            node = NodeMixin.get(val, None)
+            if isinstance(node, (TensorNode, ModuleNode)):
+                if node not in self.inputs:
+                    self.inputs.append(node)
+            else:
+                assert node is None
+                assert type(val) in builtins.__dict__.values()
+                idx = len(self.inputs) + len(self.const_val)
+                self.const_val.append((idx, val))

     def add_outputs(self, outputs):
         self.outputs = []

@@ -38,50 +52,31 @@
             outputs = (outputs,)
         for i in outputs:
             assert isinstance(i, RawTensor)
             self.outputs.append(NodeMixin.get_wrapped_type(i)(self))
         for i, node in zip(outputs, self.outputs,):
             NodeMixin.wrap_safe(i, node)

-    @classmethod
-    def get_args_node(cls, arg):
-        """
-        Create nodes by ``arg``, which may be a container.
-        Return the same structure with arg.
-
-        If ``arg`` was not Tensor or Module, it will be stored as const.
-
-        :param arg: tensor, module or const.
-        """
-        if isinstance(arg, (RawTensor, Module)):
-            if not NodeMixin.get(arg, None):
-                NodeMixin.wrap_safe(arg, Constant.make(arg))
-            return NodeMixin.get(arg)
-        elif isinstance(arg, collections.abc.Sequence):
-            seq_cls = type(arg)
-            return seq_cls([Expr.get_args_node(a) for a in arg])
-        else:
-            # TODO: assert arg type
-            return arg  # as const
+    def unflatten_args(self, inputs):
+        if self.arg_def is not None:
+            inputs = list(inputs)
+            for idx, val in self.const_val:
+                inputs.insert(idx, val)
+            args, kwargs = self.arg_def.unflatten(inputs)
+            return args, kwargs
+        else:
+            return inputs, {}

-    @classmethod
-    def get_arg_value(cls, inp_node, node2value):
-        """
-        Get values from node2value by inp_node, which may be a container.
-        Return the same structure with inp_node.
-
-        If ``inp_node`` was not in node2value, it is a const.
-
-        :param inp_node: nodes.
-        :param node2value: dict from node to tensor and module.
-        """
-        if inp_node in node2value:
-            return node2value[inp_node]
-        elif isinstance(inp_node, collections.abc.Sequence):
-            seq_cls = type(inp_node)
-            return seq_cls([Expr.get_arg_value(i, node2value) for i in inp_node])
-        else:
-            return inp_node
+    @property
+    def kwargs(self):
+        _, kwargs = self.unflatten_args(self.inputs)
+        return kwargs
+
+    @property
+    def args(self):
+        args, _ = self.unflatten_args(self.inputs)
+        return args

 # expr: None (i.e. fake expression which is used to mark input)

@@ -144,16 +139,8 @@ class CallMethod(Expr):
         self.inputs = [
             module,
         ]
+        self.const_val = []
         self.method = method
-        self.arg_names = []
-        self.kwargs = {}  # const kwargs
-
-    def add_input(self, node, arg_name=None):
-        if arg_name == "self":  # FIXME: <XP>
-            return
-        self.inputs.append(node)
-        if arg_name is not None:
-            self.arg_names.append(arg_name)

     @classmethod
     def make(cls, *args, **kwargs):

@@ -162,19 +149,22 @@ class CallMethod(Expr):
         return expr

     def interpret(self, *inputs):
-        mod = inputs[0]
-        args = inputs[1:]
-        outputs = getattr(mod, self.method)(*args, **self.kwargs)
+        args, kwargs = self.unflatten_args(inputs)
+        obj = args[0]
+        args = args[1:]
+        outputs = getattr(obj, self.method)(*args, **kwargs)
         if isinstance(outputs, RawTensor):
             outputs = (outputs,)
         return outputs

     def __repr__(self):
-        return "{} = CallMethod({}, {})({})".format(
+        args = ", ".join(str(i) for i in self.args[1:])
+        kwargs = ", ".join("{}={}".format(k, v) for k, v in self.kwargs.items())
+        return "{} = {}.{}({})".format(
             ", ".join(str(i) for i in self.outputs),
             self.inputs[0],
             self.method,
-            ", ".join(str(i) for i in self.inputs[1:]),
+            ", ".join([args, kwargs]),
         )

@@ -227,13 +217,8 @@ class CallFunction(Expr):
     def __init__(self, func):
         assert isinstance(func, Callable)
         self.func = func
+        self.const_val = []
         self.inputs = []
-        self.arg_names = []
-        self.kwargs = {}  # const kwargs
-
-    def add_input(self, node, arg_name):
-        self.inputs.append(node)
-        self.arg_names.append(arg_name)

     @classmethod
     def make(cls, *args, **kwargs):

@@ -242,18 +227,20 @@ class CallFunction(Expr):
         return expr

     def interpret(self, *inputs):
-        inp_dict = dict([(name, node) for node, name in zip(inputs, self.arg_names)])
-        outputs = self.func(**inp_dict, **self.kwargs)
+        args, kwargs = self.unflatten_args(inputs)
+        outputs = self.func(*args, **kwargs)
         outputs = (
             outputs if isinstance(outputs, collections.abc.Sequence) else (outputs,)
         )
         return outputs

     def __repr__(self):
+        args = ", ".join(str(i) for i in self.args)
+        kwargs = ", ".join("{}={}".format(k, v) for k, v in self.kwargs.items())
         return "{} = {}({})".format(
             ", ".join(str(i) for i in self.outputs),
             self.func.__module__ + "." + self.func.__name__,
-            ", ".join(str(i) for i in self.inputs),
+            ", ".join([args, kwargs]),
         )
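The new `const_val` bookkeeping is the core of the `Expr` change: `add_inputs` records each non-Node argument as an `(index, value)` pair, and `unflatten_args` splices those literals back into the flat runtime inputs before `arg_def.unflatten` rebuilds `(args, kwargs)`. A minimal standalone sketch of the splicing step, with illustrative harness values:

    # Values flowing through graph nodes at interpret time, plus literals
    # captured at trace time keyed by their position in the combined
    # (nodes + constants) argument sequence, as recorded by add_inputs.
    runtime_inputs = ["tensor_a", "tensor_b"]
    const_val = [(1, 0.5), (3, "mean")]

    inputs = list(runtime_inputs)
    for idx, val in const_val:
        inputs.insert(idx, val)

    print(inputs)  # ['tensor_a', 0.5, 'tensor_b', 'mean']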
imperative/python/megengine/experimental/traced_module/module_tracer.py

@@ -15,6 +15,72 @@ from ...module import Module
 _active_module_tracer = None

+BUILTIN_ARRAY_METHOD = [
+    "__lt__",
+    "__le__",
+    "__gt__",
+    "__ge__",
+    "__eq__",
+    "__ne__",
+    "__neg__",
+    "__pos__",
+    "__abs__",
+    "__invert__",
+    "__round__",
+    "__floor__",
+    "__ceil__",
+    "__add__",
+    "__sub__",
+    "__mul__",
+    "__matmul__",
+    "__truediv__",
+    "__floordiv__",
+    "__mod__",
+    "__pow__",
+    "__lshift__",
+    "__rshift__",
+    "__and__",
+    "__or__",
+    "__xor__",
+    "__radd__",
+    "__rsub__",
+    "__rmul__",
+    "__rmatmul__",
+    "__rtruediv__",
+    "__rfloordiv__",
+    "__rmod__",
+    "__rpow__",
+    "__rlshift__",
+    "__rrshift__",
+    "__rand__",
+    "__ror__",
+    "__rxor__",
+    "__iadd__",
+    "__isub__",
+    "__imul__",
+    "__imatmul__",
+    "__itruediv__",
+    "__ifloordiv__",
+    "__imod__",
+    "__ipow__",
+    "__ilshift__",
+    "__irshift__",
+    "__iand__",
+    "__ior__",
+    "__ixor__",
+    "T",
+    "astype",
+    "reshape",
+    "_broadcast",
+    "transpose",
+    "flatten",
+    "sum",
+    "prod",
+    "min",
+    "max",
+    "mean",
+]
+

 def active_module_tracer():
     return _active_module_tracer

@@ -108,9 +174,8 @@ class Patcher:
         self.wrap_fn = wrap_fn
         for module in self._builtin_modules:
             self.patch_module(module)
-        for cls in self._builtin_methods:
-            self.patch_cls(cls)
+        for meth in BUILTIN_ARRAY_METHOD:
+            self.patch_method(ArrayMethodMixin, meth, self.wrap_fn)
         for i, j in self._builtin_functions:
             if id(i) not in self.visited_frames_ids:
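This hunk replaces whole-class patching with per-method patching: each name in `BUILTIN_ARRAY_METHOD` is wrapped on `ArrayMethodMixin` so tensor operators go through the tracer. A hypothetical sketch of what such per-method patching amounts to (`patch_method`'s body is not shown in this diff, and the real `Patcher` presumably also records originals so they can be restored):

    def patch_method(cls, name, wrap_fn):
        # Replace cls.name with a wrapped version; dunder lookup happens on
        # the type, so assigning on the class intercepts operators too.
        setattr(cls, name, wrap_fn(getattr(cls, name)))

    class Demo:
        def __add__(self, other):
            return 42

    def tracing_wrapper(fn):
        def wrapped(*args, **kwargs):
            print("traced:", fn.__name__)
            return fn(*args, **kwargs)
        return wrapped

    patch_method(Demo, "__add__", tracing_wrapper)
    print(Demo() + Demo())  # prints "traced: __add__", then 42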
imperative/python/megengine/experimental/traced_module/node.py

@@ -13,6 +13,7 @@ import numpy
 from ...core._imperative_rt.core2 import Tensor as RawTensor
 from ...module import Module
 from ...tensor import Tensor
+from .pytree import TreeDef


 class Node:

@@ -58,6 +59,7 @@ class ModuleNode(Node):
     module_type = Module  # type: Type[Module]
     graph = None
     attr_type_map = None  # type: Dict[str, Type[Any]]
+    arg_def = None  # type: TreeDef

     def __repr__(self):
         if self._name is None:
imperative/python/megengine/experimental/traced_module/pytree.py (new file, 0 → 100644)

+from typing import Callable, NamedTuple
+
+SUPPORTED_TYPE = {}
+
+NodeType = NamedTuple("NodeType", [("flatten", Callable), ("unflatten", Callable)])
+
+
+def register_supported_type(type, flatten, unflatten):
+    SUPPORTED_TYPE[type] = NodeType(flatten, unflatten)
+
+
+register_supported_type(list, lambda x: (x, None), lambda x, aux_data: list(x))
+register_supported_type(tuple, lambda x: (x, None), lambda x, aux_data: list(x))
+register_supported_type(
+    dict, lambda x: (list(x.values()), list(x.keys())), lambda x, y: dict(zip(y, x))
+)
+register_supported_type(
+    slice,
+    lambda x: ([x.start, x.stop, x.step], None),
+    lambda x, aux_data: slice(x[0], x[1], x[2]),
+)
+
+
+def tree_flatten(
+    values,
+    leaf_type: Callable = lambda x: type(x),
+    is_leaf: Callable = lambda x: True,
+):
+    if type(values) not in SUPPORTED_TYPE:
+        assert is_leaf(values)
+        return [values,], LeafDef(leaf_type(values))
+    rst = []
+    children_defs = []
+    children_values, aux_data = SUPPORTED_TYPE[type(values)].flatten(values)
+    for v in children_values:
+        v_list, treedef = tree_flatten(v, leaf_type)
+        rst.extend(v_list)
+        children_defs.append(treedef)
+
+    return rst, TreeDef(type(values), aux_data, children_defs)
+
+
+class TreeDef:
+    def __init__(self, type, aux_data, children_defs):
+        self.type = type
+        self.aux_data = aux_data
+        self.children_defs = children_defs
+        self.num_leaves = sum(ch.num_leaves for ch in children_defs)
+
+    def unflatten(self, leaves):
+        assert len(leaves) == self.num_leaves
+        start = 0
+        children = []
+        for ch in self.children_defs:
+            children.append(ch.unflatten(leaves[start : start + ch.num_leaves]))
+            start += ch.num_leaves
+        return SUPPORTED_TYPE[self.type].unflatten(children, self.aux_data)
+
+    def __eq__(self, other):
+        return (
+            self.type == other.type
+            and self.aux_data == other.aux_data
+            and self.num_leaves == other.num_leaves
+            and self.children_defs == other.children_defs
+        )
+
+    def __repr__(self):
+        return "{}[{}]".format(self.type.__name__, self.children_defs)
+
+
+class LeafDef(TreeDef):
+    def __init__(self, type):
+        super().__init__(type, None, [])
+        self.num_leaves = 1
+
+    def unflatten(self, leaves):
+        assert len(leaves) == 1
+        assert isinstance(leaves[0], self.type), self.type
+        return leaves[0]
+
+    def __repr__(self):
+        return "Leaf({})".format(self.type.__name__)
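A round-trip sketch for the new helpers (import path as added by this commit; any type not registered in `SUPPORTED_TYPE` is treated as a leaf):

    from megengine.experimental.traced_module.pytree import tree_flatten

    values = {"x": [1, 2, 3], "y": slice(0, 10, 2)}

    # tree_flatten returns the leaves plus a TreeDef describing the structure;
    # dict keys and container types live in the TreeDef, not in the leaves.
    leaves, treedef = tree_flatten(values)
    print(leaves)  # [1, 2, 3, 0, 10, 2]

    # unflatten() rebuilds the original structure from a flat leaf list.
    assert treedef.unflatten(leaves) == values

Note that `tuple` is registered with unflatten `lambda x, aux_data: list(x)`, so a flattened tuple is rebuilt as a list rather than a tuple.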
imperative/python/megengine/experimental/traced_module/traced_module.py

@@ -9,9 +9,11 @@
 import collections
 import copy
 import functools
+from inspect import getmembers, isclass, ismethod
+from typing import List, Type

 from ... import module as M
 from ...core._imperative_rt.core2 import Tensor as RawTensor
 from ...core._imperative_rt.core2 import (
     is_tracing_module,
     set_module_tracing,

@@ -28,6 +30,16 @@ from .module_tracer import (
     set_active_module_tracer,
 )
 from .node import ModuleNode, Node, NodeMixin, TensorNode
+from .pytree import tree_flatten
+
+
+def _leaf_type(node):
+    if isinstance(node, RawTensor):
+        return (Tensor, TensorNode)
+    elif isinstance(node, (NodeMixin, Module)):
+        return (Module, ModuleNode, NodeMixin)
+    else:
+        return type(node)


 class InternalGraph:

@@ -65,9 +77,7 @@ class InternalGraph:
         for n, v in zip(self._inputs, inputs):
             node2value[n] = v
         for expr in self._exprs:
-            values = expr.interpret(
-                *list(Expr.get_arg_value(i, node2value) for i in expr.inputs)
-            )
+            values = expr.interpret(*list(node2value[i] for i in expr.inputs))
             for n, v in zip(expr.outputs, values):
                 node2value[n] = v
         return list(node2value[i] for i in self._outputs)

@@ -80,37 +90,39 @@ class InternalGraph:
         )


+def _get_meth_name(obj, func):
+    for cls in type(obj).mro():
+        for k, v in cls.__dict__.items():
+            if v == func:
+                return k
+    return None
+
+
 def _wrapped_function(orig_func):
     @functools.wraps(orig_func)
-    def wrapped_fn(*inputs, **kwargs):
+    def wrapped_fn(*args, **kwargs):
         if is_tracing_module():
             unset_module_tracing()
-            const_kwargs = {}
-            arg_names = orig_func.__code__.co_varnames
-            if orig_func.__qualname__.split(".").__len__() > 1:
-                # FIXME: a robust way to distinguish method and function. <XP>
+            inputs, tree_def = tree_flatten((args, kwargs), leaf_type=_leaf_type)
+            for i in inputs:
+                if not NodeMixin.get(i, None):
+                    if isinstance(i, (RawTensor, NodeMixin)):
+                        NodeMixin.wrap_safe(i, Constant.make(i))
+            meth_name = _get_meth_name(args[0], wrapped_fn)
+            if meth_name:
                 self = inputs[0]
-                call_node = CallMethod.make(NodeMixin.get(self), orig_func.__name__)
+                call_node = CallMethod.make(NodeMixin.get(self), meth_name)
             else:
                 call_node = CallFunction.make(orig_func)

-            def add_input(inp, varname=None):
-                node = Expr.get_args_node(inp)
-                if node is not None:
-                    call_node.add_input(node, varname)
-                else:
-                    const_kwargs[varname] = inp
-
-            for ind, inp in enumerate(inputs):
-                add_input(inp, arg_names[ind])
-            for k, v in kwargs.items():
-                add_input(v, k)
-            call_node.kwargs = const_kwargs
-            outputs = orig_func(*inputs, **kwargs)
+            call_node.add_inputs(inputs)
+            call_node.arg_def = tree_def
+            outputs = orig_func(*args, **kwargs)
             call_node.add_outputs(outputs)
             set_module_tracing()
             return outputs
-        return orig_func(*inputs, **kwargs)
+        return orig_func(*args, **kwargs)

     return wrapped_fn

@@ -120,14 +132,14 @@ class TracedModuleBuilder(NodeMixin):
     _mod = None  # type: Module
     _body = None  # type: InternalGraph
     _is_builtin = None  # type: bool
+    _arg_def = None  # type: TreeDef
     __builder_attributes__ = [
         "_mod",
         "_body",
         "_NodeMixin__node",
         "_is_builtin",
         "_is_traced",
+        "_arg_def",
         "build",
     ]

     def __init__(self, mod):

@@ -146,6 +158,7 @@ class TracedModuleBuilder(NodeMixin):
         node = NodeMixin.get(self)
         node.graph = self._body
         node.attr_type_map = {}
+        node.arg_def = self._arg_def
         traced_module = TracedModule(node)
         for k, v in self.__dict__.items():
             if k not in TracedModuleBuilder.__builder_attributes__:

@@ -155,32 +168,34 @@ class TracedModuleBuilder(NodeMixin):
                 traced_module.m_node.attr_type_map[k] = type(v)
         return traced_module

-    def __call__(self, *inputs, **kwargs):
+    def __call__(self, *args, **kwargs):
         assert isinstance(self._mod, Module)
-        for arg in args:
-            assert isinstance(arg, RawTensor)
-        for k, v in kwargs.items():
-            assert isinstance(v, RawTensor)
         # prepare args and kwargs for inner graph
         def mark_constant(x):
             node = NodeMixin.get(x, None)
             if node is None:  # capture as constant
                 NodeMixin.wrap(x, lambda: Constant.make(x))

+        inputs, tree_def = tree_flatten(((self, *args), kwargs), leaf_type=_leaf_type)
+        if self._arg_def is None:
+            self._arg_def = tree_def
+        assert self._arg_def == tree_def
         for i in inputs:
             mark_constant(i)
-        for k, v in kwargs.items():
-            mark_constant(v)
         callnode = CallMethod.make(NodeMixin.get(self))

-        def add_input(x):
-            callnode.add_input(NodeMixin.get(x))
-
-        for i in inputs:
-            add_input(i)
-        for k, v in kwargs.items():
-            add_input(v)
+        callnode.add_inputs(inputs)
+        callnode.arg_def = tree_def
         if self._is_builtin or self._is_traced:
             unset_module_tracing()
-            outputs = self._mod(*inputs, **kwargs)
+            outputs = self._mod(*args, **kwargs)
             set_module_tracing()
         if self._is_builtin:
             self._body = None

@@ -193,23 +208,21 @@ class TracedModuleBuilder(NodeMixin):
             )
             # prepare args and kwargs for inner graph
             def wrap(x):
-                # wrapped = copy.copy(x)  # FIXME
-                wrapped = x  # FIXME: <XP>
+                wrapped = copy.copy(x)  # FIXME
                 NodeMixin.wrap(
                     wrapped,
                     lambda: Input.make(type=NodeMixin.get_wrapped_type(wrapped)),
                 )
                 return wrapped

-            args = []
-            for i in inputs:
+            args = [self]
+            for i in inputs[1:]:
                 args.append(wrap(i))
-            for k, v in kwargs.items():
-                kwargs[k] = wrap(v)
+            args, kwargs = tree_def.unflatten(args)
             active_module_tracer().patcher.auto_patch(
                 getattr(getattr(self._mod, "forward", self._mod), "__globals__", {})
             )
-            outputs = type(self._mod).forward(self, *args, **kwargs)
+            outputs = type(self._mod).forward(*args, **kwargs)
             for i in (
                 outputs if isinstance(outputs, collections.abc.Sequence) else (outputs,)

@@ -269,8 +282,10 @@ class TracedModule(Module):
         super(TracedModule, self).__init__()
         self.m_node = node

-    def forward(self, *inputs):
-        rst = self.m_node.graph.interpret(self, *inputs)
+    def forward(self, *args, **kwargs):
+        inputs, treedef = tree_flatten(((self, *args), kwargs), leaf_type=_leaf_type)
+        assert treedef == self.m_node.arg_def
+        rst = self.m_node.graph.interpret(*inputs)
         if len(rst) == 1:
             rst = rst[0]
         return rst

@@ -345,7 +360,6 @@ def register_as_builtin(mod_cls: Type[Module]) -> None:
 def _register_all_builtin_module():
-    from inspect import getmembers, isclass

     for sub_mod in [M, M.qat, M.quantized]:
         for m in getmembers(sub_mod):

@@ -357,7 +371,7 @@ def _register_all_builtin_module():
             module_tracer.register_as_builtin(m[1])


-def trace_module(mod: Module, *inputs: Tensor, **kwargs: Tensor) -> TracedModule:
+def trace_module(mod: Module, *args: Tensor, **kwargs: Tensor) -> TracedModule:
     """
     Traces module ``mod`` and returns corresponding TracedModule.

@@ -375,15 +389,13 @@ def trace_module(mod: Module, *args: Tensor, **kwargs: Tensor) -> TracedModule:
         builder = TracedModuleBuilder(mod)
         NodeMixin.wrap_safe(builder, Input.make("TopModule", ModuleNode))
+        inputs, _ = tree_flatten((args, kwargs))
         for _, i in enumerate(inputs):
-            NodeMixin.wrap_safe(i, Input.make("arg_{}".format(_)))
-        for k, v in kwargs.items():
-            NodeMixin.wrap_safe(v, Input.make("kwarg_{}".format(k)))
-        builder(*inputs, **kwargs)
+            NodeMixin.wrap_safe(
+                i, Input.make("arg_{}".format(_), NodeMixin.get_wrapped_type(i))
+            )
+        builder(*args, **kwargs)
         active_module_tracer().pop_scope()
         return builder.build()
     finally:
         set_active_module_tracer(None)
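Throughout this file the calling convention is canonicalized the same way: `tree_flatten(((self, *args), kwargs), leaf_type=_leaf_type)` yields the flat leaf list the graph records, and the stored `tree_def`/`arg_def` restores the structure at interpret time. A sketch with plain stand-ins, since any unregistered type flattens to a leaf (the real tracer passes Modules and Tensors here; import path as added by this commit):

    from megengine.experimental.traced_module.pytree import tree_flatten

    mod, x, y = "module", 1.0, 2.0  # stand-ins for a Module and two Tensors
    leaves, arg_def = tree_flatten(((mod, x, [y]), {"inplace": True}))
    print(leaves)  # ['module', 1.0, 2.0, True]

    args, kwargs = arg_def.unflatten(leaves)
    print(args)    # ['module', 1.0, [2.0]]  (tuples come back as lists)
    print(kwargs)  # {'inplace': True}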