From d2f5874a5257e8a961ff2b284e26db1a1ce17326 Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Tue, 26 May 2020 18:19:46 +0800 Subject: [PATCH] fix(mge/module): fix non-str key error of dict in module GitOrigin-RevId: f82cd48230b2cfcf9c8da7442d3eb1e4bdbe3aee --- python_module/megengine/module/module.py | 24 ++++++++++----- python_module/test/unit/module/test_module.py | 30 +++++++++++++++++-- 2 files changed, 44 insertions(+), 10 deletions(-) diff --git a/python_module/megengine/module/module.py b/python_module/megengine/module/module.py index 183e3e42b..c0732fc3d 100644 --- a/python_module/megengine/module/module.py +++ b/python_module/megengine/module/module.py @@ -18,17 +18,25 @@ logger = get_logger(__name__) def _expand_structure(key, obj): - if isinstance(obj, (list, tuple, dict)): + if isinstance(obj, (Tensor, Module)): + return [(key, obj)] + elif isinstance(obj, (list, tuple, dict)): ret = [] if isinstance(obj, dict): targets = ((k, obj[k]) for k in sorted(obj)) else: targets = ((str(k), v) for k, v in enumerate(obj)) for k, o in targets: - ret.extend(_expand_structure(key + "." + k, o)) + sub_ret = _expand_structure(k, o) + if sub_ret and not isinstance(k, str): + raise AssertionError( + "keys for Tensor and Module must be str, error key: {}".format(k) + ) + for kt, vt in sub_ret: + ret.extend([(key + "." + kt, vt)]) return ret else: - return [(key, obj)] + return [] def _is_parameter(obj): @@ -72,11 +80,11 @@ class Module(metaclass=ABCMeta): predicate: Callable[[Any], bool] = lambda _: True, seen: Optional[Set[int]] = None ) -> Union[Iterable[Any], Iterable[Tuple[str, Any]]]: - """Scans the module object and returns an iterable for the attributes that - agree with the ``predicate``. For multiple calls of this function with same - arguments, the order of objects within the returned iterable is guaranteed to be - identical, as long as all the involved module objects' ``__dict__`` does not - change thoughout those calls. + """Scans the module object and returns an iterable for the :class:`~.Tensor` + and :class:`~.Module` attributes that agree with the ``predicate``. For multiple + calls of this function with same arguments, the order of objects within the + returned iterable is guaranteed to be identical, as long as all the involved + module objects' ``__dict__`` does not change thoughout those calls. :param recursive: Whether to recursively scan all the submodules. :param with_key: Whether to yield keys along with yielded objects. diff --git a/python_module/test/unit/module/test_module.py b/python_module/test/unit/module/test_module.py index 1c72e2dd1..16aaf08f0 100644 --- a/python_module/test/unit/module/test_module.py +++ b/python_module/test/unit/module/test_module.py @@ -14,7 +14,7 @@ import pytest from helpers import MLP import megengine as mge -from megengine.core import Buffer, Parameter, tensor +from megengine.core import Buffer, Parameter, Tensor, tensor from megengine.module import BatchNorm1d, BatchNorm2d, Conv2d, Module, Sequential from megengine.test import assertTensorClose @@ -139,6 +139,7 @@ class MyModule2(Module): def __init__(self): super().__init__() self.bn = BatchNorm2d(4) + self.test_bool_key = {True: 1, False: 0} def forward(self, x): x = self.bn(x) @@ -148,7 +149,7 @@ class MyModule2(Module): self.bn = BatchNorm2d(4) self.a = [ BatchNorm2d(4), - {"x": BatchNorm2d(4), "y": [BatchNorm2d(4), self.InnerModule()]}, + {"x": BatchNorm2d(4), "y": [BatchNorm2d(4), self.InnerModule()], "z": 0}, (self.InnerModule(),), ] @@ -171,6 +172,14 @@ def test_expand_structure(): ] +def test_flatten_others(): + def be_others(obj): + return not isinstance(obj, (Tensor, Module)) + + m = MyModule2() + assert len(list(m._flatten(with_key=True, predicate=be_others))) == 0 + + def test_flatten_with_parent(): m = MyModule2() assert list(m.named_modules(with_parent=True)) == [ @@ -251,6 +260,23 @@ def test_state_dict(): mlp1.load_state_dict(state_dict) +class AssertModule(Module): + def __init__(self): + super().__init__() + self.error_tensor_key = {True: tensor(), False: 0} + + def forward(self, x): + return x + + +def test_assert_message(): + m = AssertModule() + with pytest.raises( + AssertionError, match="keys for Tensor and Module must be str, error key: True" + ): + list(m._flatten()) + + class Simple(Module): def __init__(self): super().__init__() -- GitLab