Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
4bb25369
MegEngine
项目概览
MegEngine 天元
/
MegEngine
1 年多 前同步成功
通知
404
Star
4705
Fork
582
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
提交
4bb25369
编写于
7月 07, 2021
作者:
M
Megvii Engine Team
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
feat(traced_module): let CallFunction own graph
GitOrigin-RevId: 66cdbca7e54df07576a984c3fd48d3bcafb678f1
上级
9a6a3793
变更
4
隐藏空白更改
内联
并排
Showing
4 changed file
with
115 addition
and
65 deletion
+115
-65
imperative/python/megengine/experimental/traced_module/expr.py
...ative/python/megengine/experimental/traced_module/expr.py
+13
-2
imperative/python/megengine/experimental/traced_module/node.py
...ative/python/megengine/experimental/traced_module/node.py
+6
-2
imperative/python/megengine/experimental/traced_module/pytree.py
...ive/python/megengine/experimental/traced_module/pytree.py
+33
-5
imperative/python/megengine/experimental/traced_module/traced_module.py
...hon/megengine/experimental/traced_module/traced_module.py
+63
-56
未找到文件。
imperative/python/megengine/experimental/traced_module/expr.py
浏览文件 @
4bb25369
...
...
@@ -17,7 +17,7 @@ from ...core._imperative_rt.core2 import apply, set_module_tracing, unset_module
from
...core.ops.special
import
Const
from
...module
import
Module
from
...tensor
import
Tensor
from
.module_tracer
import
active_module_tracer
from
.module_tracer
import
active_module_tracer
,
module_tracer
from
.node
import
ModuleNode
,
Node
,
NodeMixin
,
TensorNode
from
.pytree
import
TreeDef
...
...
@@ -148,6 +148,15 @@ class CallMethod(Expr):
active_module_tracer
().
current_scope
().
insert
(
expr
)
return
expr
@
property
def
graph
(
self
):
if
isinstance
(
self
.
inputs
[
0
],
ModuleNode
):
m_node
=
self
.
inputs
[
0
]
if
m_node
.
argdef_graph_map
:
assert
self
.
arg_def
in
m_node
.
argdef_graph_map
return
m_node
.
argdef_graph_map
[
self
.
arg_def
]
return
None
def
interpret
(
self
,
*
inputs
):
args
,
kwargs
=
self
.
unflatten_args
(
inputs
)
obj
=
args
[
0
]
...
...
@@ -252,7 +261,9 @@ class Constant(Expr):
_constant_cache
=
{}
def
__init__
(
self
,
c
):
# TODO: type check, since not all types should be captured as constant
assert
isinstance
(
c
,
(
RawTensor
,
Module
))
if
isinstance
(
c
,
Module
):
assert
module_tracer
.
is_builtin
(
c
)
self
.
value
=
c
self
.
inputs
=
[]
node_cls
=
NodeMixin
.
get_wrapped_type
(
c
)
...
...
imperative/python/megengine/experimental/traced_module/node.py
浏览文件 @
4bb25369
...
...
@@ -57,9 +57,13 @@ class ModuleNode(Node):
"""
module_type
=
Module
# type: Type[Module]
graph
=
None
attr_type_map
=
None
# type: Dict[str, Type[Any]]
arg_def
=
None
# type: TreeDef
argdef_graph_map
=
None
# type: Dict[Treedef, "InternalGraph"]
def
__init__
(
self
,
expr
:
"Expr"
,
name
:
str
=
None
):
super
().
__init__
(
expr
,
name
)
self
.
attr_type_map
=
{}
self
.
argdef_graph_map
=
{}
def
__repr__
(
self
):
if
self
.
_name
is
None
:
...
...
imperative/python/megengine/experimental/traced_module/pytree.py
浏览文件 @
4bb25369
...
...
@@ -25,7 +25,7 @@ def _dict_flatten(inp):
for
key
,
value
in
sorted
(
inp
.
items
()):
results
.
append
(
value
)
aux_data
.
append
(
key
)
return
results
,
aux_data
return
results
,
tuple
(
aux_data
)
def
_dict_unflatten
(
inps
,
aux_data
):
...
...
@@ -43,16 +43,23 @@ register_supported_type(
def
tree_flatten
(
values
,
leaf_type
:
Callable
=
lambda
x
:
type
(
x
),
is_leaf
:
Callable
=
lambda
x
:
True
values
,
leaf_type
:
Callable
=
lambda
x
:
type
(
x
),
is_leaf
:
Callable
=
lambda
_
:
True
,
is_const_leaf
:
Callable
=
lambda
_
:
False
,
):
if
type
(
values
)
not
in
SUPPORTED_TYPE
:
assert
is_leaf
(
values
)
return
[
values
,],
LeafDef
(
leaf_type
(
values
))
node
=
LeafDef
(
leaf_type
(
values
))
if
is_const_leaf
(
values
):
node
.
const_val
=
values
return
[
values
,],
node
rst
=
[]
children_defs
=
[]
children_values
,
aux_data
=
SUPPORTED_TYPE
[
type
(
values
)].
flatten
(
values
)
for
v
in
children_values
:
v_list
,
treedef
=
tree_flatten
(
v
,
leaf_type
)
v_list
,
treedef
=
tree_flatten
(
v
,
leaf_type
,
is_leaf
,
is_const_leaf
)
rst
.
extend
(
v_list
)
children_defs
.
append
(
treedef
)
...
...
@@ -75,6 +82,18 @@ class TreeDef:
start
+=
ch
.
num_leaves
return
SUPPORTED_TYPE
[
self
.
type
].
unflatten
(
children
,
self
.
aux_data
)
def
__hash__
(
self
):
return
hash
(
tuple
(
[
self
.
type
,
self
.
aux_data
,
self
.
num_leaves
,
tuple
([
hash
(
x
)
for
x
in
self
.
children_defs
]),
]
)
)
def
__eq__
(
self
,
other
):
return
(
self
.
type
==
other
.
type
...
...
@@ -93,11 +112,20 @@ class LeafDef(TreeDef):
type
=
(
type
,)
super
().
__init__
(
type
,
None
,
[])
self
.
num_leaves
=
1
self
.
const_val
=
None
def
unflatten
(
self
,
leaves
):
assert
len
(
leaves
)
==
1
assert
isinstance
(
leaves
[
0
],
self
.
type
),
self
.
type
return
leaves
[
0
]
def
__eq__
(
self
,
other
):
return
self
.
type
==
other
.
type
and
self
.
const_val
==
other
.
const_val
def
__hash__
(
self
):
return
hash
(
tuple
([
self
.
type
,
self
.
const_val
]))
def
__repr__
(
self
):
return
"Leaf({})"
.
format
(
", "
.
join
(
t
.
__name__
for
t
in
self
.
type
))
return
"Leaf({}[{}])"
.
format
(
", "
.
join
(
t
.
__name__
for
t
in
self
.
type
),
self
.
const_val
)
imperative/python/megengine/experimental/traced_module/traced_module.py
浏览文件 @
4bb25369
...
...
@@ -42,6 +42,12 @@ def _leaf_type(node):
return
type
(
node
)
def
_is_const_leaf
(
node
):
if
isinstance
(
node
,
(
RawTensor
,
NodeMixin
,
Module
)):
return
False
return
True
class
InternalGraph
:
"""
``InternalGraph`` is a graph consist of ``Node`` and ``Expr``, it is used to represent the execution procedure of Module's forward method.
...
...
@@ -72,6 +78,10 @@ class InternalGraph:
def
outputs
(
self
):
return
self
.
_outputs
@
property
def
exprs
(
self
):
return
_expr_list
(
self
)
def
add_input
(
self
,
i
):
self
.
_inputs
.
append
(
i
)
...
...
@@ -111,7 +121,9 @@ def _wrapped_function(orig_func):
def
wrapped_fn
(
*
args
,
**
kwargs
):
if
is_tracing_module
():
unset_module_tracing
()
inputs
,
tree_def
=
tree_flatten
((
args
,
kwargs
),
leaf_type
=
_leaf_type
)
inputs
,
tree_def
=
tree_flatten
(
(
args
,
kwargs
),
leaf_type
=
_leaf_type
,
is_const_leaf
=
_is_const_leaf
)
for
i
in
inputs
:
if
not
NodeMixin
.
get
(
i
,
None
):
if
isinstance
(
i
,
(
RawTensor
,
NodeMixin
)):
...
...
@@ -140,21 +152,18 @@ class TracedModuleBuilder(NodeMixin):
_mod
=
None
# type: Module
_body
=
None
# type: InternalGraph
_is_builtin
=
None
# type: bool
_arg_def
=
None
# type: TreeDef
__builder_attributes__
=
[
"_mod"
,
"_body"
,
"_NodeMixin__node"
,
"_is_builtin"
,
"_is_traced"
,
"_arg_def"
"build"
,
"build"
,
]
def
__init__
(
self
,
mod
):
def
__init__
(
self
,
mod
,
is_top_module
=
False
):
super
(
TracedModuleBuilder
,
self
).
__init__
()
self
.
_mod
=
mod
self
.
_body
=
InternalGraph
()
self
.
_is_traced
=
False
self
.
_body
=
None
self
.
_is_builtin
=
module_tracer
.
is_builtin
(
mod
)
def
build
(
self
):
...
...
@@ -164,9 +173,6 @@ class TracedModuleBuilder(NodeMixin):
return
self
.
_mod
else
:
node
=
NodeMixin
.
get
(
self
)
node
.
graph
=
self
.
_body
node
.
attr_type_map
=
{}
node
.
arg_def
=
self
.
_arg_def
traced_module
=
TracedModule
(
node
)
for
k
,
v
in
self
.
__dict__
.
items
():
if
k
not
in
TracedModuleBuilder
.
__builder_attributes__
:
...
...
@@ -178,21 +184,15 @@ class TracedModuleBuilder(NodeMixin):
def
__call__
(
self
,
*
args
,
**
kwargs
):
assert
isinstance
(
self
.
_mod
,
Module
)
for
arg
in
args
:
assert
isinstance
(
arg
,
RawTensor
)
for
k
,
v
in
kwargs
.
items
():
assert
isinstance
(
v
,
RawTensor
)
# prepare args and kwargs for inner graph
def
mark_constant
(
x
):
node
=
NodeMixin
.
get
(
x
,
None
)
if
node
is
None
:
# capture as constant
NodeMixin
.
wrap
(
x
,
lambda
:
Constant
.
make
(
x
))
inputs
,
tree_def
=
tree_flatten
(((
self
,
*
args
),
kwargs
),
leaf_type
=
_leaf_type
)
if
self
.
_arg_def
is
None
:
self
.
_arg_def
=
tree_def
assert
self
.
_arg_def
==
tree_def
inputs
,
tree_def
=
tree_flatten
(
((
self
,
*
args
),
kwargs
),
leaf_type
=
_leaf_type
,
is_const_leaf
=
_is_const_leaf
)
for
i
in
inputs
:
mark_constant
(
i
)
callnode
=
CallMethod
.
make
(
NodeMixin
.
get
(
self
))
...
...
@@ -201,13 +201,14 @@ class TracedModuleBuilder(NodeMixin):
callnode
.
arg_def
=
tree_def
if
self
.
_is_builtin
or
self
.
_is_traced
:
if
self
.
_is_builtin
:
unset_module_tracing
()
outputs
=
self
.
_mod
(
*
args
,
**
kwargs
)
set_module_tracing
()
if
self
.
_is_builtin
:
self
.
_body
=
None
else
:
self
.
_body
=
InternalGraph
()
active_module_tracer
().
push_scope
(
self
.
_body
)
# rebind self to new input node
orig_self
=
NodeMixin
.
get
(
self
)
...
...
@@ -238,11 +239,12 @@ class TracedModuleBuilder(NodeMixin):
active_module_tracer
().
current_scope
().
add_output
(
NodeMixin
.
get
(
i
))
NodeMixin
.
wrap_safe
(
self
,
orig_self
)
self
.
_is_traced
=
True
active_module_tracer
().
pop_scope
()
# rebind output to outer graph
callnode
.
add_outputs
(
outputs
)
self_node
=
NodeMixin
.
get
(
self
)
self_node
.
argdef_graph_map
[
callnode
.
arg_def
]
=
self
.
_body
return
outputs
def
__getattr__
(
self
,
name
):
...
...
@@ -280,24 +282,23 @@ class TracedModuleBuilder(NodeMixin):
class
_expr_list
:
def
__init__
(
self
,
module
:
"TracedModule"
):
self
.
module
=
module
def
__init__
(
self
,
graph
:
InternalGraph
):
self
.
graph
=
graph
def
__iter__
(
self
):
graph
=
self
.
module
.
m_node
.
graph
for
expr
in
graph
.
_exprs
:
for
expr
in
self
.
graph
.
_exprs
:
if
isinstance
(
expr
,
CallMethod
)
and
isinstance
(
expr
.
inputs
[
0
],
ModuleNode
):
yield
expr
assert
isinstance
(
expr
.
inputs
[
0
].
expr
,
GetAttr
)
(
obj
,)
=
expr
.
inputs
[
0
].
expr
.
interpret
(
self
.
module
)
if
isinstance
(
obj
,
TracedModule
):
yield
from
obj
.
exprs
yield
expr
if
expr
.
graph
is
not
None
:
yield
from
expr
.
graph
.
exprs
else
:
yield
expr
class
TracedModule
(
Module
):
"""
`TracedModule` is the Module created by tracing normal module. It owns a ModuleNode(m_node), and will interpret the m_node.graph when it is called.
`TracedModule` is the Module created by tracing normal module. It owns a ModuleNode(m_node). `TracedModule` can not be called directly. It can be
interpreted by CallMethod Expr.
"""
m_node
=
None
# type: ModuleNode
...
...
@@ -307,21 +308,24 @@ class TracedModule(Module):
self
.
m_node
=
node
def
forward
(
self
,
*
args
,
**
kwargs
):
inputs
,
treedef
=
tree_flatten
(((
self
,
*
args
),
kwargs
),
leaf_type
=
_leaf_type
)
assert
treedef
==
self
.
m_node
.
arg_def
rst
=
self
.
m_node
.
graph
.
interpret
(
*
inputs
)
if
len
(
rst
)
==
1
:
rst
=
rst
[
0
]
return
rst
inputs
,
treedef
=
tree_flatten
(
((
self
,
*
args
),
kwargs
),
_leaf_type
,
is_const_leaf
=
_is_const_leaf
)
assert
treedef
in
self
.
m_node
.
argdef_graph_map
inputs
=
[
i
for
i
in
inputs
if
isinstance
(
i
,
(
Module
,
RawTensor
))]
outputs
=
self
.
m_node
.
argdef_graph_map
[
treedef
].
interpret
(
*
inputs
)
if
len
(
outputs
)
==
1
:
return
outputs
[
0
]
return
outputs
@
property
def
exprs
(
self
):
"""
Get all ``Expr`` s recursively.
def
graph
(
self
):
assert
len
(
self
.
m_node
.
argdef_graph_map
)
==
1
return
list
(
self
.
m_node
.
argdef_graph_map
.
values
())[
0
]
:return: Iterator[Expr]
"""
return
_expr_list
(
self
)
@
property
def
exprs
(
self
):
return
self
.
graph
.
exprs
def
flatten
(
self
):
"""
...
...
@@ -331,24 +335,26 @@ class TracedModule(Module):
"""
new_module
=
copy
.
deepcopy
(
self
)
def
_flatten_submodule
(
module
,
call
=
None
):
if
not
isinstance
(
module
,
TracedModule
):
call
.
inputs
[
0
]
=
module
return
(
call
,)
def
_flatten_subgraph
(
graph
,
module
,
call
=
None
):
if
graph
is
None
:
assert
not
isinstance
(
module
,
TracedModule
)
const
=
Constant
(
module
)
modulenode
=
const
.
outputs
[
0
]
modulenode
.
module_type
=
type
(
module
)
call
.
inputs
[
0
]
=
modulenode
return
[
const
,
call
]
exprs
=
[]
graph
=
module
.
m_node
.
graph
for
expr
in
graph
.
_exprs
:
# replace inputs for submodule's expr
for
idx
,
inp
in
enumerate
(
expr
.
inputs
):
if
call
and
inp
in
graph
.
_inputs
:
expr
.
inputs
[
idx
]
=
call
.
inputs
[
idx
]
inp_idx
=
graph
.
_inputs
.
index
(
inp
)
expr
.
inputs
[
idx
]
=
call
.
inputs
[
inp_idx
]
# replace outputs for submodule's expr
for
idx
,
outp
in
enumerate
(
expr
.
outputs
):
if
call
and
outp
in
graph
.
_outputs
:
expr
.
outputs
[
idx
]
=
call
.
outputs
[
idx
]
oup_idx
=
graph
.
_outputs
.
index
(
outp
)
expr
.
outputs
[
idx
]
=
call
.
outputs
[
oup_idx
]
if
isinstance
(
expr
,
GetAttr
):
# replace GetAttr with Constant
...
...
@@ -356,12 +362,13 @@ class TracedModule(Module):
const
=
Constant
(
getattr
(
module
,
expr
.
name
))
const
.
outputs
=
expr
.
outputs
exprs
.
append
(
const
)
elif
isinstance
(
expr
,
CallMethod
):
obj_node
=
expr
.
inputs
[
0
]
if
isinstance
(
obj_node
,
ModuleNode
):
assert
isinstance
(
expr
.
inputs
[
0
].
expr
,
GetAttr
)
(
obj
,)
=
expr
.
inputs
[
0
].
expr
.
interpret
(
module
)
exprs
.
extend
(
_flatten_sub
module
(
obj
,
expr
))
exprs
.
extend
(
_flatten_sub
graph
(
expr
.
graph
,
obj
,
expr
))
else
:
exprs
.
append
(
expr
)
else
:
...
...
@@ -369,7 +376,7 @@ class TracedModule(Module):
return
exprs
new_module
.
m_node
.
graph
.
_exprs
=
_flatten_submodule
(
new_module
)
new_module
.
graph
.
_exprs
=
_flatten_subgraph
(
new_module
.
graph
,
new_module
)
return
new_module
...
...
@@ -421,7 +428,7 @@ def trace_module(mod: Module, *args: Tensor, **kwargs: Tensor) -> TracedModule:
global_scope
=
InternalGraph
()
active_module_tracer
().
push_scope
(
global_scope
)
builder
=
TracedModuleBuilder
(
mod
)
builder
=
TracedModuleBuilder
(
mod
,
True
)
NodeMixin
.
wrap_safe
(
builder
,
Input
.
make
(
"TopModule"
,
ModuleNode
))
inputs
,
_
=
tree_flatten
((
args
,
kwargs
))
for
_
,
i
in
enumerate
(
inputs
):
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录