Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
d2f5874a
MegEngine
项目概览
MegEngine 天元
/
MegEngine
1 年多 前同步成功
通知
403
Star
4705
Fork
582
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
提交
d2f5874a
编写于
5月 26, 2020
作者:
M
Megvii Engine Team
提交者:
Xu Xinran
6月 19, 2020
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix(mge/module): fix non-str key error of dict in module
GitOrigin-RevId: f82cd48230b2cfcf9c8da7442d3eb1e4bdbe3aee
上级
30b3d3aa
变更
2
隐藏空白更改
内联
并排
Showing
2 changed file
with
44 addition
and
10 deletion
+44
-10
python_module/megengine/module/module.py
python_module/megengine/module/module.py
+16
-8
python_module/test/unit/module/test_module.py
python_module/test/unit/module/test_module.py
+28
-2
未找到文件。
python_module/megengine/module/module.py
浏览文件 @
d2f5874a
...
...
@@ -18,17 +18,25 @@ logger = get_logger(__name__)
def
_expand_structure
(
key
,
obj
):
if
isinstance
(
obj
,
(
list
,
tuple
,
dict
)):
if
isinstance
(
obj
,
(
Tensor
,
Module
)):
return
[(
key
,
obj
)]
elif
isinstance
(
obj
,
(
list
,
tuple
,
dict
)):
ret
=
[]
if
isinstance
(
obj
,
dict
):
targets
=
((
k
,
obj
[
k
])
for
k
in
sorted
(
obj
))
else
:
targets
=
((
str
(
k
),
v
)
for
k
,
v
in
enumerate
(
obj
))
for
k
,
o
in
targets
:
ret
.
extend
(
_expand_structure
(
key
+
"."
+
k
,
o
))
sub_ret
=
_expand_structure
(
k
,
o
)
if
sub_ret
and
not
isinstance
(
k
,
str
):
raise
AssertionError
(
"keys for Tensor and Module must be str, error key: {}"
.
format
(
k
)
)
for
kt
,
vt
in
sub_ret
:
ret
.
extend
([(
key
+
"."
+
kt
,
vt
)])
return
ret
else
:
return
[
(
key
,
obj
)
]
return
[]
def
_is_parameter
(
obj
):
...
...
@@ -72,11 +80,11 @@ class Module(metaclass=ABCMeta):
predicate
:
Callable
[[
Any
],
bool
]
=
lambda
_
:
True
,
seen
:
Optional
[
Set
[
int
]]
=
None
)
->
Union
[
Iterable
[
Any
],
Iterable
[
Tuple
[
str
,
Any
]]]:
"""Scans the module object and returns an iterable for the
attributes that
a
gree with the ``predicate``. For multiple calls of this function with sam
e
arguments, the order of objects within the returned iterable is guaranteed to b
e
identical, as long as all the involved module objects' ``__dict__`` does not
change thoughout those calls.
"""Scans the module object and returns an iterable for the
:class:`~.Tensor`
a
nd :class:`~.Module` attributes that agree with the ``predicate``. For multipl
e
calls of this function with same arguments, the order of objects within th
e
returned iterable is guaranteed to be identical, as long as all the involved
module objects' ``__dict__`` does not
change thoughout those calls.
:param recursive: Whether to recursively scan all the submodules.
:param with_key: Whether to yield keys along with yielded objects.
...
...
python_module/test/unit/module/test_module.py
浏览文件 @
d2f5874a
...
...
@@ -14,7 +14,7 @@ import pytest
from
helpers
import
MLP
import
megengine
as
mge
from
megengine.core
import
Buffer
,
Parameter
,
tensor
from
megengine.core
import
Buffer
,
Parameter
,
Tensor
,
tensor
from
megengine.module
import
BatchNorm1d
,
BatchNorm2d
,
Conv2d
,
Module
,
Sequential
from
megengine.test
import
assertTensorClose
...
...
@@ -139,6 +139,7 @@ class MyModule2(Module):
def
__init__
(
self
):
super
().
__init__
()
self
.
bn
=
BatchNorm2d
(
4
)
self
.
test_bool_key
=
{
True
:
1
,
False
:
0
}
def
forward
(
self
,
x
):
x
=
self
.
bn
(
x
)
...
...
@@ -148,7 +149,7 @@ class MyModule2(Module):
self
.
bn
=
BatchNorm2d
(
4
)
self
.
a
=
[
BatchNorm2d
(
4
),
{
"x"
:
BatchNorm2d
(
4
),
"y"
:
[
BatchNorm2d
(
4
),
self
.
InnerModule
()]},
{
"x"
:
BatchNorm2d
(
4
),
"y"
:
[
BatchNorm2d
(
4
),
self
.
InnerModule
()]
,
"z"
:
0
},
(
self
.
InnerModule
(),),
]
...
...
@@ -171,6 +172,14 @@ def test_expand_structure():
]
def
test_flatten_others
():
def
be_others
(
obj
):
return
not
isinstance
(
obj
,
(
Tensor
,
Module
))
m
=
MyModule2
()
assert
len
(
list
(
m
.
_flatten
(
with_key
=
True
,
predicate
=
be_others
)))
==
0
def
test_flatten_with_parent
():
m
=
MyModule2
()
assert
list
(
m
.
named_modules
(
with_parent
=
True
))
==
[
...
...
@@ -251,6 +260,23 @@ def test_state_dict():
mlp1
.
load_state_dict
(
state_dict
)
class
AssertModule
(
Module
):
def
__init__
(
self
):
super
().
__init__
()
self
.
error_tensor_key
=
{
True
:
tensor
(),
False
:
0
}
def
forward
(
self
,
x
):
return
x
def
test_assert_message
():
m
=
AssertModule
()
with
pytest
.
raises
(
AssertionError
,
match
=
"keys for Tensor and Module must be str, error key: True"
):
list
(
m
.
_flatten
())
class
Simple
(
Module
):
def
__init__
(
self
):
super
().
__init__
()
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录