Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
9a6a3793
MegEngine
项目概览
MegEngine 天元
/
MegEngine
1 年多 前同步成功
通知
404
Star
4705
Fork
582
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
提交
9a6a3793
编写于
7月 06, 2021
作者:
M
Megvii Engine Team
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
feat(traced_module): add visit method
GitOrigin-RevId: 251ecebf87c94fd5b60c27596a45149d479603e9
上级
442b4f6c
变更
1
显示空白变更内容
内联
并排
Showing
1 changed file
with
40 addition
and
6 deletion
+40
-6
imperative/python/megengine/experimental/traced_module/traced_module.py
...hon/megengine/experimental/traced_module/traced_module.py
+40
-6
未找到文件。
imperative/python/megengine/experimental/traced_module/traced_module.py
浏览文件 @
9a6a3793
...
@@ -10,7 +10,7 @@ import collections
...
@@ -10,7 +10,7 @@ import collections
import
copy
import
copy
import
functools
import
functools
from
inspect
import
getmembers
,
isclass
,
ismethod
from
inspect
import
getmembers
,
isclass
,
ismethod
from
typing
import
List
,
Type
from
typing
import
Dict
,
List
,
Type
from
...
import
module
as
M
from
...
import
module
as
M
from
...core._imperative_rt.core2
import
Tensor
as
RawTensor
from
...core._imperative_rt.core2
import
Tensor
as
RawTensor
...
@@ -64,6 +64,14 @@ class InternalGraph:
...
@@ -64,6 +64,14 @@ class InternalGraph:
def
insert
(
self
,
expr
):
def
insert
(
self
,
expr
):
self
.
_exprs
.
append
(
expr
)
self
.
_exprs
.
append
(
expr
)
@
property
def
inputs
(
self
):
return
self
.
_inputs
@
property
def
outputs
(
self
):
return
self
.
_outputs
def
add_input
(
self
,
i
):
def
add_input
(
self
,
i
):
self
.
_inputs
.
append
(
i
)
self
.
_inputs
.
append
(
i
)
...
@@ -271,6 +279,22 @@ class TracedModuleBuilder(NodeMixin):
...
@@ -271,6 +279,22 @@ class TracedModuleBuilder(NodeMixin):
return
wrapped
return
wrapped
class
_expr_list
:
def
__init__
(
self
,
module
:
"TracedModule"
):
self
.
module
=
module
def
__iter__
(
self
):
graph
=
self
.
module
.
m_node
.
graph
for
expr
in
graph
.
_exprs
:
if
isinstance
(
expr
,
CallMethod
)
and
isinstance
(
expr
.
inputs
[
0
],
ModuleNode
):
yield
expr
assert
isinstance
(
expr
.
inputs
[
0
].
expr
,
GetAttr
)
(
obj
,)
=
expr
.
inputs
[
0
].
expr
.
interpret
(
self
.
module
)
if
isinstance
(
obj
,
TracedModule
):
yield
from
obj
.
exprs
yield
expr
class
TracedModule
(
Module
):
class
TracedModule
(
Module
):
"""
"""
`TracedModule` is the Module created by tracing normal module. It owns a ModuleNode(m_node), and will interpret the m_node.graph when it is called.
`TracedModule` is the Module created by tracing normal module. It owns a ModuleNode(m_node), and will interpret the m_node.graph when it is called.
...
@@ -291,14 +315,21 @@ class TracedModule(Module):
...
@@ -291,14 +315,21 @@ class TracedModule(Module):
return
rst
return
rst
@
property
@
property
def
all_exprs
(
self
):
def
exprs
(
self
):
"""
Get all ``Expr`` s recursively.
:return: Iterator[Expr]
"""
"""
Visit all ``Expr``s in the graph recursively.
return
_expr_list
(
self
)
:return: List[Expr]
def
flatten
(
self
):
"""
"""
Get a new module, which eliminates ``GetAttr`` and has no hierarchy.
in_nodes
=
[
i
.
expr
for
i
in
self
.
m_node
.
graph
.
_inputs
if
not
i
is
self
]
:return: :class:`TracedModule`
"""
new_module
=
copy
.
deepcopy
(
self
)
def
_flatten_submodule
(
module
,
call
=
None
):
def
_flatten_submodule
(
module
,
call
=
None
):
if
not
isinstance
(
module
,
TracedModule
):
if
not
isinstance
(
module
,
TracedModule
):
...
@@ -328,6 +359,7 @@ class TracedModule(Module):
...
@@ -328,6 +359,7 @@ class TracedModule(Module):
elif
isinstance
(
expr
,
CallMethod
):
elif
isinstance
(
expr
,
CallMethod
):
obj_node
=
expr
.
inputs
[
0
]
obj_node
=
expr
.
inputs
[
0
]
if
isinstance
(
obj_node
,
ModuleNode
):
if
isinstance
(
obj_node
,
ModuleNode
):
assert
isinstance
(
expr
.
inputs
[
0
].
expr
,
GetAttr
)
(
obj
,)
=
expr
.
inputs
[
0
].
expr
.
interpret
(
module
)
(
obj
,)
=
expr
.
inputs
[
0
].
expr
.
interpret
(
module
)
exprs
.
extend
(
_flatten_submodule
(
obj
,
expr
))
exprs
.
extend
(
_flatten_submodule
(
obj
,
expr
))
else
:
else
:
...
@@ -337,7 +369,9 @@ class TracedModule(Module):
...
@@ -337,7 +369,9 @@ class TracedModule(Module):
return
exprs
return
exprs
return
in_nodes
+
_flatten_submodule
(
self
)
new_module
.
m_node
.
graph
.
_exprs
=
_flatten_submodule
(
new_module
)
return
new_module
def
__getstate__
(
self
):
def
__getstate__
(
self
):
d
=
self
.
__dict__
d
=
self
.
__dict__
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录