diff --git a/imperative/python/megengine/module/module.py b/imperative/python/megengine/module/module.py index 275a36e56be96daaa5ef1acfb36728991ff1dd1c..27efe8dbec4fb992cb4d50ad55f6fd87720ee370 100644 --- a/imperative/python/megengine/module/module.py +++ b/imperative/python/megengine/module/module.py @@ -301,6 +301,33 @@ class Module(metaclass=ABCMeta): **kwargs, ) + def tensors(self, recursive: bool = True, **kwargs) -> Iterable[Parameter]: + r""" + Returns an iterable for the :class:`~.Tensor` of the module. + + :param recursive: If ``True``, returns all :class:`~.Tensor` within this + module, else only returns :class:`~.Tensor` that are direct attributes + of this module. + """ + + yield from self._flatten(with_key=False, recursive=recursive, **kwargs) + + def named_tensors( + self, prefix: Optional[str] = None, recursive: bool = True, **kwargs + ) -> Iterable[Tuple[str, Tensor]]: + """ + Returns an iterable for key tensor pairs of the module, where + ``key`` is the dotted path from this module to the tensor. + + :param prefix: prefix prepended to the keys. + :param recursive: if ``True``, returns all tensors within this + module, else only returns tensors that are direct attributes + of this module. + """ + yield from self._flatten( + with_key=True, prefix=prefix, recursive=recursive, **kwargs, + ) + def children(self, **kwargs) -> "Iterable[Module]": r"""Returns an iterable for all the submodules that are direct attributes of this module.