diff --git a/imperative/python/megengine/module/module.py b/imperative/python/megengine/module/module.py index 27efe8dbec4fb992cb4d50ad55f6fd87720ee370..593aab602dae56f35ea76ff8e4ba8d50c354aa53 100644 --- a/imperative/python/megengine/module/module.py +++ b/imperative/python/megengine/module/module.py @@ -67,6 +67,10 @@ def _is_parameter(obj): return isinstance(obj, Parameter) +def _is_tensor(obj): + return isinstance(obj, Tensor) + + def _is_buffer(obj): return isinstance(obj, Tensor) and not isinstance(obj, Parameter) @@ -309,8 +313,9 @@ class Module(metaclass=ABCMeta): module, else only returns :class:`~.Tensor` that are direct attributes of this module. """ - - yield from self._flatten(with_key=False, recursive=recursive, **kwargs) + yield from self._flatten( + with_key=False, predicate=_is_tensor, recursive=recursive, **kwargs + ) def named_tensors( self, prefix: Optional[str] = None, recursive: bool = True, **kwargs @@ -325,7 +330,11 @@ class Module(metaclass=ABCMeta): of this module. """ yield from self._flatten( - with_key=True, prefix=prefix, recursive=recursive, **kwargs, + with_key=True, + prefix=prefix, + predicate=_is_tensor, + recursive=recursive, + **kwargs, ) def children(self, **kwargs) -> "Iterable[Module]": diff --git a/imperative/python/test/unit/module/test_module.py b/imperative/python/test/unit/module/test_module.py index b7f7890467f702adf19a046877303f1060ef18b6..c573e224d79be6109ff6df9a00279178766bdead 100644 --- a/imperative/python/test/unit/module/test_module.py +++ b/imperative/python/test/unit/module/test_module.py @@ -124,6 +124,30 @@ def test_module_api(test_traced_module): ("i.bn.weight", m.i.bn.weight), ("param", m.param), ] + assert list(m.tensors()) == [ + m.bn.bias, + m.bn.running_mean, + m.bn.running_var, + m.bn.weight, + m.buff, + m.i.bn.bias, + m.i.bn.running_mean, + m.i.bn.running_var, + m.i.bn.weight, + m.param, + ] + assert list(m.named_tensors()) == [ + ("bn.bias", m.bn.bias), + ("bn.running_mean", m.bn.running_mean), + ("bn.running_var", m.bn.running_var), + ("bn.weight", m.bn.weight), + ("buff", m.buff), + ("i.bn.bias", m.i.bn.bias), + ("i.bn.running_mean", m.i.bn.running_mean), + ("i.bn.running_var", m.i.bn.running_var), + ("i.bn.weight", m.i.bn.weight), + ("param", m.param), + ] m.eval() assert ( m.training == False