From d814f196061e3555fd41fa408ab730738c201c92 Mon Sep 17 00:00:00 2001 From: TaoweiZhang <2501129113@qq.com> Date: Wed, 5 Jan 2022 13:24:19 +0800 Subject: [PATCH] docs(mge/functional): update functional.zeros docstring --- imperative/python/megengine/functional/nn.py | 4 +-- .../python/megengine/functional/tensor.py | 26 ++++++++++++++----- 2 files changed, 21 insertions(+), 9 deletions(-) diff --git a/imperative/python/megengine/functional/nn.py b/imperative/python/megengine/functional/nn.py index 9ababcebe..36f13632e 100644 --- a/imperative/python/megengine/functional/nn.py +++ b/imperative/python/megengine/functional/nn.py @@ -1587,9 +1587,7 @@ def one_hot(inp: Tensor, num_classes: int) -> Tensor: [0 0 1 0] [0 0 0 1]] """ - zeros_tensor = zeros( - list(inp.shape) + [num_classes], dtype=inp.dtype, device=inp.device - ) + zeros_tensor = zeros(list(inp.shape) + [num_classes], dtype=inp.dtype, device=inp.device) ones_tensor = ones(list(inp.shape) + [1], dtype=inp.dtype, device=inp.device) op = builtin.IndexingSetOneHot(axis=inp.ndim) diff --git a/imperative/python/megengine/functional/tensor.py b/imperative/python/megengine/functional/tensor.py index e8ff2e6a2..c1369abfb 100755 --- a/imperative/python/megengine/functional/tensor.py +++ b/imperative/python/megengine/functional/tensor.py @@ -195,14 +195,28 @@ def ones( return full(shape, 1.0, dtype=dtype, device=device) -def zeros(shape, dtype="float32", device=None) -> Tensor: - r"""Returns a zero tensor with given shape. +def zeros( + shape: Union[int, Tuple[int, ...]], + *, + dtype="float32", + device: Optional[CompNode] = None +) -> Tensor: + r"""Returns a new tensor having a specified shape and filled with zeros. Args: - shape: a list, tuple or integer defining the shape of the output tensor. - dtype: the desired data type of the output tensor. Default: ``float32``. - device: the desired device of the output tensor. Default: if ``None``, - use the default device (see :func:`~.megengine.get_default_device`). + shape (int or sequence of ints): the shape of the output tensor. + + Keyword args: + dtype (:attr:`.Tensor.dtype`): output tensor data type. Default: ``float32``. + device (:attr:`.Tensor.device`): device on which to place the created tensor. Default: ``None``. + + Returns: + a tensor containing zeros. + + Examples: + >>> F.zeros((2, 1)) + Tensor([[0.] + [0.]], device=xpux:0) """ return full(shape, 0.0, dtype=dtype, device=device) -- GitLab