提交 10198650 编写于 作者: M Megvii Engine Team

feat(module): add tensors and named_tensors

GitOrigin-RevId: cb56d65d38154bb437100527833a35c68e72e2df
上级 bd817f3a
...@@ -301,6 +301,33 @@ class Module(metaclass=ABCMeta): ...@@ -301,6 +301,33 @@ class Module(metaclass=ABCMeta):
**kwargs, **kwargs,
) )
def tensors(self, recursive: bool = True, **kwargs) -> Iterable[Parameter]:
r"""
Returns an iterable for the :class:`~.Tensor` of the module.
:param recursive: If ``True``, returns all :class:`~.Tensor` within this
module, else only returns :class:`~.Tensor` that are direct attributes
of this module.
"""
yield from self._flatten(with_key=False, recursive=recursive, **kwargs)
def named_tensors(
self, prefix: Optional[str] = None, recursive: bool = True, **kwargs
) -> Iterable[Tuple[str, Tensor]]:
"""
Returns an iterable for key tensor pairs of the module, where
``key`` is the dotted path from this module to the tensor.
:param prefix: prefix prepended to the keys.
:param recursive: if ``True``, returns all tensors within this
module, else only returns tensors that are direct attributes
of this module.
"""
yield from self._flatten(
with_key=True, prefix=prefix, recursive=recursive, **kwargs,
)
def children(self, **kwargs) -> "Iterable[Module]": def children(self, **kwargs) -> "Iterable[Module]":
r"""Returns an iterable for all the submodules that are direct attributes of this r"""Returns an iterable for all the submodules that are direct attributes of this
module. module.
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册