提交 d2f5874a 编写于 作者: M Megvii Engine Team 提交者: Xu Xinran

fix(mge/module): fix non-str key error of dict in module

GitOrigin-RevId: f82cd48230b2cfcf9c8da7442d3eb1e4bdbe3aee
上级 30b3d3aa
......@@ -18,17 +18,25 @@ logger = get_logger(__name__)
def _expand_structure(key, obj):
if isinstance(obj, (list, tuple, dict)):
if isinstance(obj, (Tensor, Module)):
return [(key, obj)]
elif isinstance(obj, (list, tuple, dict)):
ret = []
if isinstance(obj, dict):
targets = ((k, obj[k]) for k in sorted(obj))
else:
targets = ((str(k), v) for k, v in enumerate(obj))
for k, o in targets:
ret.extend(_expand_structure(key + "." + k, o))
sub_ret = _expand_structure(k, o)
if sub_ret and not isinstance(k, str):
raise AssertionError(
"keys for Tensor and Module must be str, error key: {}".format(k)
)
for kt, vt in sub_ret:
ret.extend([(key + "." + kt, vt)])
return ret
else:
return [(key, obj)]
return []
def _is_parameter(obj):
......@@ -72,11 +80,11 @@ class Module(metaclass=ABCMeta):
predicate: Callable[[Any], bool] = lambda _: True,
seen: Optional[Set[int]] = None
) -> Union[Iterable[Any], Iterable[Tuple[str, Any]]]:
"""Scans the module object and returns an iterable for the attributes that
agree with the ``predicate``. For multiple calls of this function with same
arguments, the order of objects within the returned iterable is guaranteed to be
identical, as long as all the involved module objects' ``__dict__`` does not
change thoughout those calls.
"""Scans the module object and returns an iterable for the :class:`~.Tensor`
and :class:`~.Module` attributes that agree with the ``predicate``. For multiple
calls of this function with same arguments, the order of objects within the
returned iterable is guaranteed to be identical, as long as all the involved
module objects' ``__dict__`` does not change thoughout those calls.
:param recursive: Whether to recursively scan all the submodules.
:param with_key: Whether to yield keys along with yielded objects.
......
......@@ -14,7 +14,7 @@ import pytest
from helpers import MLP
import megengine as mge
from megengine.core import Buffer, Parameter, tensor
from megengine.core import Buffer, Parameter, Tensor, tensor
from megengine.module import BatchNorm1d, BatchNorm2d, Conv2d, Module, Sequential
from megengine.test import assertTensorClose
......@@ -139,6 +139,7 @@ class MyModule2(Module):
def __init__(self):
super().__init__()
self.bn = BatchNorm2d(4)
self.test_bool_key = {True: 1, False: 0}
def forward(self, x):
x = self.bn(x)
......@@ -148,7 +149,7 @@ class MyModule2(Module):
self.bn = BatchNorm2d(4)
self.a = [
BatchNorm2d(4),
{"x": BatchNorm2d(4), "y": [BatchNorm2d(4), self.InnerModule()]},
{"x": BatchNorm2d(4), "y": [BatchNorm2d(4), self.InnerModule()], "z": 0},
(self.InnerModule(),),
]
......@@ -171,6 +172,14 @@ def test_expand_structure():
]
def test_flatten_others():
def be_others(obj):
return not isinstance(obj, (Tensor, Module))
m = MyModule2()
assert len(list(m._flatten(with_key=True, predicate=be_others))) == 0
def test_flatten_with_parent():
m = MyModule2()
assert list(m.named_modules(with_parent=True)) == [
......@@ -251,6 +260,23 @@ def test_state_dict():
mlp1.load_state_dict(state_dict)
class AssertModule(Module):
def __init__(self):
super().__init__()
self.error_tensor_key = {True: tensor(), False: 0}
def forward(self, x):
return x
def test_assert_message():
m = AssertModule()
with pytest.raises(
AssertionError, match="keys for Tensor and Module must be str, error key: True"
):
list(m._flatten())
class Simple(Module):
def __init__(self):
super().__init__()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册