From 1019865071a6a2b1369915da6b0775d54de92f45 Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Tue, 7 Sep 2021 21:16:52 +0800 Subject: [PATCH] feat(module): add tensors and named_tensors GitOrigin-RevId: cb56d65d38154bb437100527833a35c68e72e2df --- imperative/python/megengine/module/module.py | 27 ++++++++++++++++++++ 1 file changed, 27 insertions(+) diff --git a/imperative/python/megengine/module/module.py b/imperative/python/megengine/module/module.py index 275a36e56..27efe8dbe 100644 --- a/imperative/python/megengine/module/module.py +++ b/imperative/python/megengine/module/module.py @@ -301,6 +301,33 @@ class Module(metaclass=ABCMeta): **kwargs, ) + def tensors(self, recursive: bool = True, **kwargs) -> Iterable[Parameter]: + r""" + Returns an iterable for the :class:`~.Tensor` of the module. + + :param recursive: If ``True``, returns all :class:`~.Tensor` within this + module, else only returns :class:`~.Tensor` that are direct attributes + of this module. + """ + + yield from self._flatten(with_key=False, recursive=recursive, **kwargs) + + def named_tensors( + self, prefix: Optional[str] = None, recursive: bool = True, **kwargs + ) -> Iterable[Tuple[str, Tensor]]: + """ + Returns an iterable for key tensor pairs of the module, where + ``key`` is the dotted path from this module to the tensor. + + :param prefix: prefix prepended to the keys. + :param recursive: if ``True``, returns all tensors within this + module, else only returns tensors that are direct attributes + of this module. + """ + yield from self._flatten( + with_key=True, prefix=prefix, recursive=recursive, **kwargs, + ) + def children(self, **kwargs) -> "Iterable[Module]": r"""Returns an iterable for all the submodules that are direct attributes of this module. -- GitLab