提交 d2f5874a 编写于 作者: M Megvii Engine Team 提交者: Xu Xinran

fix(mge/module): fix non-str key error of dict in module

GitOrigin-RevId: f82cd48230b2cfcf9c8da7442d3eb1e4bdbe3aee
上级 30b3d3aa
...@@ -18,17 +18,25 @@ logger = get_logger(__name__) ...@@ -18,17 +18,25 @@ logger = get_logger(__name__)
def _expand_structure(key, obj): def _expand_structure(key, obj):
if isinstance(obj, (list, tuple, dict)): if isinstance(obj, (Tensor, Module)):
return [(key, obj)]
elif isinstance(obj, (list, tuple, dict)):
ret = [] ret = []
if isinstance(obj, dict): if isinstance(obj, dict):
targets = ((k, obj[k]) for k in sorted(obj)) targets = ((k, obj[k]) for k in sorted(obj))
else: else:
targets = ((str(k), v) for k, v in enumerate(obj)) targets = ((str(k), v) for k, v in enumerate(obj))
for k, o in targets: for k, o in targets:
ret.extend(_expand_structure(key + "." + k, o)) sub_ret = _expand_structure(k, o)
if sub_ret and not isinstance(k, str):
raise AssertionError(
"keys for Tensor and Module must be str, error key: {}".format(k)
)
for kt, vt in sub_ret:
ret.extend([(key + "." + kt, vt)])
return ret return ret
else: else:
return [(key, obj)] return []
def _is_parameter(obj): def _is_parameter(obj):
...@@ -72,11 +80,11 @@ class Module(metaclass=ABCMeta): ...@@ -72,11 +80,11 @@ class Module(metaclass=ABCMeta):
predicate: Callable[[Any], bool] = lambda _: True, predicate: Callable[[Any], bool] = lambda _: True,
seen: Optional[Set[int]] = None seen: Optional[Set[int]] = None
) -> Union[Iterable[Any], Iterable[Tuple[str, Any]]]: ) -> Union[Iterable[Any], Iterable[Tuple[str, Any]]]:
"""Scans the module object and returns an iterable for the attributes that """Scans the module object and returns an iterable for the :class:`~.Tensor`
agree with the ``predicate``. For multiple calls of this function with same and :class:`~.Module` attributes that agree with the ``predicate``. For multiple
arguments, the order of objects within the returned iterable is guaranteed to be calls of this function with same arguments, the order of objects within the
identical, as long as all the involved module objects' ``__dict__`` does not returned iterable is guaranteed to be identical, as long as all the involved
change thoughout those calls. module objects' ``__dict__`` does not change thoughout those calls.
:param recursive: Whether to recursively scan all the submodules. :param recursive: Whether to recursively scan all the submodules.
:param with_key: Whether to yield keys along with yielded objects. :param with_key: Whether to yield keys along with yielded objects.
......
...@@ -14,7 +14,7 @@ import pytest ...@@ -14,7 +14,7 @@ import pytest
from helpers import MLP from helpers import MLP
import megengine as mge import megengine as mge
from megengine.core import Buffer, Parameter, tensor from megengine.core import Buffer, Parameter, Tensor, tensor
from megengine.module import BatchNorm1d, BatchNorm2d, Conv2d, Module, Sequential from megengine.module import BatchNorm1d, BatchNorm2d, Conv2d, Module, Sequential
from megengine.test import assertTensorClose from megengine.test import assertTensorClose
...@@ -139,6 +139,7 @@ class MyModule2(Module): ...@@ -139,6 +139,7 @@ class MyModule2(Module):
def __init__(self): def __init__(self):
super().__init__() super().__init__()
self.bn = BatchNorm2d(4) self.bn = BatchNorm2d(4)
self.test_bool_key = {True: 1, False: 0}
def forward(self, x): def forward(self, x):
x = self.bn(x) x = self.bn(x)
...@@ -148,7 +149,7 @@ class MyModule2(Module): ...@@ -148,7 +149,7 @@ class MyModule2(Module):
self.bn = BatchNorm2d(4) self.bn = BatchNorm2d(4)
self.a = [ self.a = [
BatchNorm2d(4), BatchNorm2d(4),
{"x": BatchNorm2d(4), "y": [BatchNorm2d(4), self.InnerModule()]}, {"x": BatchNorm2d(4), "y": [BatchNorm2d(4), self.InnerModule()], "z": 0},
(self.InnerModule(),), (self.InnerModule(),),
] ]
...@@ -171,6 +172,14 @@ def test_expand_structure(): ...@@ -171,6 +172,14 @@ def test_expand_structure():
] ]
def test_flatten_others():
def be_others(obj):
return not isinstance(obj, (Tensor, Module))
m = MyModule2()
assert len(list(m._flatten(with_key=True, predicate=be_others))) == 0
def test_flatten_with_parent(): def test_flatten_with_parent():
m = MyModule2() m = MyModule2()
assert list(m.named_modules(with_parent=True)) == [ assert list(m.named_modules(with_parent=True)) == [
...@@ -251,6 +260,23 @@ def test_state_dict(): ...@@ -251,6 +260,23 @@ def test_state_dict():
mlp1.load_state_dict(state_dict) mlp1.load_state_dict(state_dict)
class AssertModule(Module):
def __init__(self):
super().__init__()
self.error_tensor_key = {True: tensor(), False: 0}
def forward(self, x):
return x
def test_assert_message():
m = AssertModule()
with pytest.raises(
AssertionError, match="keys for Tensor and Module must be str, error key: True"
):
list(m._flatten())
class Simple(Module): class Simple(Module):
def __init__(self): def __init__(self):
super().__init__() super().__init__()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册