Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
6d1a4f20
MegEngine
项目概览
MegEngine 天元
/
MegEngine
大约 1 年 前同步成功
通知
399
Star
4705
Fork
582
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
提交
6d1a4f20
编写于
8月 21, 2021
作者:
M
Megvii Engine Team
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
feat(traced_module): support tracing submodules in list/dict
GitOrigin-RevId: 4076b47a89ff5fdbe7c94778a649f8a01d6cc0b6
上级
a3f9073c
变更
3
隐藏空白更改
内联
并排
Showing
3 changed file
with
243 addition
and
14 deletion
+243
-14
imperative/python/megengine/experimental/traced_module/traced_module.py
...hon/megengine/experimental/traced_module/traced_module.py
+26
-10
imperative/python/megengine/experimental/traced_module/utils.py
...tive/python/megengine/experimental/traced_module/utils.py
+186
-0
imperative/python/test/unit/traced_module/test_trace_module.py
...ative/python/test/unit/traced_module/test_trace_module.py
+31
-4
未找到文件。
imperative/python/megengine/experimental/traced_module/traced_module.py
浏览文件 @
6d1a4f20
...
...
@@ -58,6 +58,7 @@ from .module_tracer import (
)
from
.node
import
ModuleNode
,
Node
,
NodeMixin
,
TensorNode
from
.pytree
import
ArgsIndex
,
tree_flatten
from
.utils
import
replace_container_with_module_container
logger
=
get_logger
(
__name__
)
...
...
@@ -988,7 +989,9 @@ class TracedModuleBuilder(NodeMixin):
if
k
not
in
TracedModuleBuilder
.
__builder_attributes__
:
if
isinstance
(
v
,
TracedModuleBuilder
):
v
=
v
.
build
()
setattr
(
traced_module
,
k
,
v
)
setattr
(
traced_module
,
k
,
v
)
elif
isinstance
(
v
,
RawTensor
):
setattr
(
traced_module
,
k
,
v
)
if
isinstance
(
self
.
_mod
,
QATModule
):
unset_module_tracing
()
...
...
@@ -1146,7 +1149,16 @@ class TracedModuleBuilder(NodeMixin):
if
id
(
attr
)
in
active_module_tracer
().
id2name
:
full_name
=
active_module_tracer
().
id2name
[
id
(
attr
)]
if
isinstance
(
attr
,
(
List
,
Dict
)):
unset_module_tracing
()
has_module
,
m_container
=
replace_container_with_module_container
(
attr
)
if
m_container
:
attr
=
m_container
if
has_module
and
not
m_container
:
raise
ValueError
(
"Can not trace the module that uses the same container to store Module and Non-Module objects "
)
set_module_tracing
()
if
isinstance
(
attr
,
Module
):
attr
=
TracedModuleBuilder
(
attr
)
...
...
@@ -1178,17 +1190,22 @@ class TracedModuleBuilder(NodeMixin):
return
object
.
__getattribute__
(
self
,
name
)
else
:
wrapped
=
object
.
__getattribute__
(
self
,
name
)
class_members
=
dict
(
inspect
.
getmembers
(
self
.
__class__
))
if
name
in
self
.
_mod
.
__dict__
:
mod_attr
=
getattr
(
self
.
_mod
,
name
)
if
not
isinstance
(
mod_attr
,
Module
)
and
wrapped
is
not
mod_attr
:
wrapped
=
mod_attr
setattr
(
self
,
name
,
wrapped
)
if
isinstance
(
mod_attr
,
Module
):
assert
mod_attr
is
wrapped
.
_mod
if
name
in
class_members
:
if
(
not
isinstance
(
wrapped
,
TracedModuleBuilder
)
and
wrapped
is
not
mod_attr
):
wrapped
=
self
.
__getattr__
(
name
)
if
isinstance
(
wrapped
,
TracedModuleBuilder
):
if
not
isinstance
(
mod_attr
,
(
List
,
Dict
)):
assert
mod_attr
is
wrapped
.
_mod
else
:
assert
mod_attr
is
wrapped
full_name
=
None
if
id
(
mod_attr
)
in
active_module_tracer
().
id2name
:
full_name
=
active_module_tracer
().
id2name
[
id
(
mod_attr
)]
...
...
@@ -1679,7 +1696,6 @@ def _register_all_builtin_module():
isclass
(
m
[
1
])
and
issubclass
(
m
[
1
],
M
.
Module
)
and
m
[
1
]
is
not
M
.
Sequential
and
m
[
1
]
is
not
M
.
ModuleList
):
module_tracer
.
register_as_builtin
(
m
[
1
])
...
...
imperative/python/megengine/experimental/traced_module/utils.py
0 → 100644
浏览文件 @
6d1a4f20
# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import
copy
from
collections.abc
import
MutableMapping
,
MutableSequence
from
typing
import
Dict
,
Iterable
,
List
,
Optional
,
Sequence
from
...module
import
Module
def
replace_container_with_module_container
(
container
):
has_module
=
False
module_container
=
None
if
isinstance
(
container
,
Dict
):
m_dic
=
copy
.
copy
(
container
)
for
key
,
value
in
container
.
items
():
if
isinstance
(
value
,
Module
):
has_module
=
True
elif
isinstance
(
value
,
(
List
,
Dict
)):
(
_has_module
,
_module_container
,
)
=
replace_container_with_module_container
(
value
)
m_dic
[
key
]
=
_module_container
if
_has_module
:
has_module
=
True
if
not
all
(
isinstance
(
v
,
Module
)
for
v
in
m_dic
.
values
()):
return
has_module
,
None
else
:
return
has_module
,
_ModuleDict
(
m_dic
)
elif
isinstance
(
container
,
List
):
m_list
=
copy
.
copy
(
container
)
for
ind
,
value
in
enumerate
(
container
):
if
isinstance
(
value
,
Module
):
has_module
=
True
elif
isinstance
(
value
,
(
List
,
Dict
)):
(
_has_module
,
_module_container
,
)
=
replace_container_with_module_container
(
value
)
m_list
[
ind
]
=
_module_container
if
_has_module
:
has_module
=
True
if
not
all
(
isinstance
(
v
,
Module
)
for
v
in
m_list
):
return
has_module
,
None
else
:
return
has_module
,
_ModuleList
(
m_list
)
return
has_module
,
module_container
class
_ModuleList
(
Module
,
MutableSequence
):
r
"""
A List-like container.
Using a ``ModuleList``, one can visit, add, delete and modify submodules
just like an ordinary python list.
"""
def
__init__
(
self
,
modules
:
Optional
[
Iterable
[
Module
]]
=
None
):
super
().
__init__
()
self
.
_size
=
0
if
modules
is
None
:
return
for
mod
in
modules
:
self
.
append
(
mod
)
@
classmethod
def
_ikey
(
cls
,
idx
):
return
"{}"
.
format
(
idx
)
def
_check_idx
(
self
,
idx
):
L
=
len
(
self
)
if
idx
<
0
:
idx
=
L
+
idx
if
idx
<
0
or
idx
>=
L
:
raise
IndexError
(
"list index out of range"
)
return
idx
def
__getitem__
(
self
,
idx
:
int
):
if
isinstance
(
idx
,
slice
):
idx
=
range
(
self
.
_size
)[
idx
]
if
not
isinstance
(
idx
,
Sequence
):
idx
=
[
idx
,
]
rst
=
[]
for
i
in
idx
:
i
=
self
.
_check_idx
(
i
)
key
=
self
.
_ikey
(
i
)
try
:
rst
.
append
(
getattr
(
self
,
key
))
except
AttributeError
:
raise
IndexError
(
"list index out of range"
)
return
rst
if
len
(
rst
)
>
1
else
rst
[
0
]
def
__setitem__
(
self
,
idx
:
int
,
mod
:
Module
):
if
not
isinstance
(
mod
,
Module
):
raise
ValueError
(
"invalid sub-module"
)
idx
=
self
.
_check_idx
(
idx
)
setattr
(
self
,
self
.
_ikey
(
idx
),
mod
)
def
__delitem__
(
self
,
idx
):
idx
=
self
.
_check_idx
(
idx
)
L
=
len
(
self
)
for
orig_idx
in
range
(
idx
+
1
,
L
):
new_idx
=
orig_idx
-
1
self
[
new_idx
]
=
self
[
orig_idx
]
delattr
(
self
,
self
.
_ikey
(
L
-
1
))
self
.
_size
-=
1
def
__len__
(
self
):
return
self
.
_size
def
insert
(
self
,
idx
,
mod
:
Module
):
assert
isinstance
(
mod
,
Module
)
L
=
len
(
self
)
if
idx
<
0
:
idx
=
L
-
idx
# clip idx to (0, L)
if
idx
>
L
:
idx
=
L
elif
idx
<
0
:
idx
=
0
for
new_idx
in
range
(
L
,
idx
,
-
1
):
orig_idx
=
new_idx
-
1
key
=
self
.
_ikey
(
new_idx
)
setattr
(
self
,
key
,
self
[
orig_idx
])
key
=
self
.
_ikey
(
idx
)
setattr
(
self
,
key
,
mod
)
self
.
_size
+=
1
def
forward
(
self
):
raise
RuntimeError
(
"ModuleList is not callable"
)
class
_ModuleDict
(
Module
,
MutableMapping
):
r
"""
A Dict-like container.
Using a ``ModuleDict``, one can visit, add, delete and modify submodules
just like an ordinary python dict.
"""
def
__init__
(
self
,
modules
:
Optional
[
Dict
[
str
,
Module
]]
=
None
):
super
().
__init__
()
self
.
_size
=
0
if
modules
is
not
None
:
self
.
update
(
modules
)
def
__delitem__
(
self
,
key
):
delattr
(
self
,
key
)
self
.
_size
-=
1
def
__getitem__
(
self
,
key
):
return
getattr
(
self
,
key
)
def
__setitem__
(
self
,
key
,
value
):
if
not
isinstance
(
value
,
Module
):
raise
ValueError
(
"invalid sub-module"
)
setattr
(
self
,
key
,
value
)
self
.
_size
+=
1
def
__iter__
(
self
):
return
iter
(
self
.
keys
())
def
__len__
(
self
):
return
self
.
_size
def
items
(
self
):
return
dict
(
self
.
named_children
()).
items
()
def
values
(
self
):
return
dict
(
self
.
named_children
()).
values
()
def
keys
(
self
):
return
dict
(
self
.
named_children
()).
keys
()
def
forward
(
self
):
raise
RuntimeError
(
"ModuleList is not callable"
)
imperative/python/test/unit/traced_module/test_trace_module.py
浏览文件 @
6d1a4f20
import
numpy
as
np
import
megengine.module
as
M
from
megengine
import
Tensor
from
megengine.experimental.traced_module
import
trace_module
from
megengine.module
import
Module
as
M
from
megengine.experimental.traced_module
import
TracedModule
,
trace_module
class
MyModule1
(
M
):
class
MyModule1
(
M
.
Module
):
def
forward
(
self
,
x
):
y
=
Tensor
(
x
)
y
+=
1
...
...
@@ -13,7 +13,7 @@ class MyModule1(M):
return
x
,
y
class
MyModule2
(
M
):
class
MyModule2
(
M
.
Module
):
def
forward
(
self
,
x
):
y
=
Tensor
([
1
,
x
,
1
])
y
+=
1
...
...
@@ -21,6 +21,23 @@ class MyModule2(M):
return
x
,
y
class
MyModule3
(
M
.
Module
):
def
__init__
(
self
):
super
().
__init__
()
self
.
modules
=
[
M
.
Elemwise
(
"ADD"
),
M
.
Elemwise
(
"ADD"
),
{
"a"
:
M
.
Elemwise
(
"ADD"
),
"b"
:
M
.
Elemwise
(
"ADD"
)},
]
def
forward
(
self
,
a
,
b
):
x
=
self
.
modules
[
0
](
a
,
b
)
y
=
self
.
modules
[
1
](
a
,
b
)
y
=
self
.
modules
[
2
][
"a"
](
x
,
y
)
y
=
self
.
modules
[
2
][
"b"
](
x
,
y
)
return
y
def
test_trace_module
():
x
=
Tensor
(
1
)
...
...
@@ -40,3 +57,13 @@ def test_trace_module():
for
a
,
b
in
zip
(
output1
,
gt1
):
np
.
testing
.
assert_equal
(
a
.
numpy
(),
b
.
numpy
())
a
,
b
=
Tensor
(
1
),
Tensor
(
2
)
m3
=
MyModule3
()
gt
=
m3
(
a
,
b
)
tm3
=
trace_module
(
m3
,
a
,
b
)
out
=
tm3
(
a
,
b
)
np
.
testing
.
assert_equal
(
out
.
numpy
(),
gt
.
numpy
())
assert
isinstance
(
tm3
.
modules
.
__dict__
[
"0"
],
M
.
Elemwise
)
assert
isinstance
(
tm3
.
modules
.
__dict__
[
"2"
],
TracedModule
)
assert
isinstance
(
tm3
.
modules
.
__dict__
[
"2"
].
a
,
M
.
Elemwise
)
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录