提交 10198650 编写于 作者: M Megvii Engine Team

feat(module): add tensors and named_tensors

GitOrigin-RevId: cb56d65d38154bb437100527833a35c68e72e2df
上级 bd817f3a
......@@ -301,6 +301,33 @@ class Module(metaclass=ABCMeta):
**kwargs,
)
def tensors(self, recursive: bool = True, **kwargs) -> Iterable[Parameter]:
r"""
Returns an iterable for the :class:`~.Tensor` of the module.
:param recursive: If ``True``, returns all :class:`~.Tensor` within this
module, else only returns :class:`~.Tensor` that are direct attributes
of this module.
"""
yield from self._flatten(with_key=False, recursive=recursive, **kwargs)
def named_tensors(
self, prefix: Optional[str] = None, recursive: bool = True, **kwargs
) -> Iterable[Tuple[str, Tensor]]:
"""
Returns an iterable for key tensor pairs of the module, where
``key`` is the dotted path from this module to the tensor.
:param prefix: prefix prepended to the keys.
:param recursive: if ``True``, returns all tensors within this
module, else only returns tensors that are direct attributes
of this module.
"""
yield from self._flatten(
with_key=True, prefix=prefix, recursive=recursive, **kwargs,
)
def children(self, **kwargs) -> "Iterable[Module]":
r"""Returns an iterable for all the submodules that are direct attributes of this
module.
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册