From f7d05db7d69f19715fd80f164657c8f509b5f08b Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Tue, 12 Oct 2021 13:12:11 +0800 Subject: [PATCH] fix(mge/module): fix named_tensors GitOrigin-RevId: bb5aa1f41d3577c6b346c913337839d12f8c1559 --- imperative/python/megengine/module/module.py | 15 +++++++++--- .../python/test/unit/module/test_module.py | 24 +++++++++++++++++++ 2 files changed, 36 insertions(+), 3 deletions(-) diff --git a/imperative/python/megengine/module/module.py b/imperative/python/megengine/module/module.py index 27efe8dbe..593aab602 100644 --- a/imperative/python/megengine/module/module.py +++ b/imperative/python/megengine/module/module.py @@ -67,6 +67,10 @@ def _is_parameter(obj): return isinstance(obj, Parameter) +def _is_tensor(obj): + return isinstance(obj, Tensor) + + def _is_buffer(obj): return isinstance(obj, Tensor) and not isinstance(obj, Parameter) @@ -309,8 +313,9 @@ class Module(metaclass=ABCMeta): module, else only returns :class:`~.Tensor` that are direct attributes of this module. """ - - yield from self._flatten(with_key=False, recursive=recursive, **kwargs) + yield from self._flatten( + with_key=False, predicate=_is_tensor, recursive=recursive, **kwargs + ) def named_tensors( self, prefix: Optional[str] = None, recursive: bool = True, **kwargs @@ -325,7 +330,11 @@ class Module(metaclass=ABCMeta): of this module. """ yield from self._flatten( - with_key=True, prefix=prefix, recursive=recursive, **kwargs, + with_key=True, + prefix=prefix, + predicate=_is_tensor, + recursive=recursive, + **kwargs, ) def children(self, **kwargs) -> "Iterable[Module]": diff --git a/imperative/python/test/unit/module/test_module.py b/imperative/python/test/unit/module/test_module.py index b7f789046..c573e224d 100644 --- a/imperative/python/test/unit/module/test_module.py +++ b/imperative/python/test/unit/module/test_module.py @@ -124,6 +124,30 @@ def test_module_api(test_traced_module): ("i.bn.weight", m.i.bn.weight), ("param", m.param), ] + assert list(m.tensors()) == [ + m.bn.bias, + m.bn.running_mean, + m.bn.running_var, + m.bn.weight, + m.buff, + m.i.bn.bias, + m.i.bn.running_mean, + m.i.bn.running_var, + m.i.bn.weight, + m.param, + ] + assert list(m.named_tensors()) == [ + ("bn.bias", m.bn.bias), + ("bn.running_mean", m.bn.running_mean), + ("bn.running_var", m.bn.running_var), + ("bn.weight", m.bn.weight), + ("buff", m.buff), + ("i.bn.bias", m.i.bn.bias), + ("i.bn.running_mean", m.i.bn.running_mean), + ("i.bn.running_var", m.i.bn.running_var), + ("i.bn.weight", m.i.bn.weight), + ("param", m.param), + ] m.eval() assert ( m.training == False -- GitLab