diff --git a/imperative/python/megengine/functional/nn.py b/imperative/python/megengine/functional/nn.py index 36f13632e75abc32ff78285561355ddecb69797e..9ababcebee4d4fc420072887084b9480396d2df4 100644 --- a/imperative/python/megengine/functional/nn.py +++ b/imperative/python/megengine/functional/nn.py @@ -1587,7 +1587,9 @@ def one_hot(inp: Tensor, num_classes: int) -> Tensor: [0 0 1 0] [0 0 0 1]] """ - zeros_tensor = zeros(list(inp.shape) + [num_classes], dtype=inp.dtype, device=inp.device) + zeros_tensor = zeros( + list(inp.shape) + [num_classes], dtype=inp.dtype, device=inp.device + ) ones_tensor = ones(list(inp.shape) + [1], dtype=inp.dtype, device=inp.device) op = builtin.IndexingSetOneHot(axis=inp.ndim) diff --git a/imperative/python/megengine/functional/tensor.py b/imperative/python/megengine/functional/tensor.py index c1369abfb81f463e38ef5650b067cd2f27f5a08d..e8ff2e6a2da9addf02860e967aeaa3daae2c5dc5 100755 --- a/imperative/python/megengine/functional/tensor.py +++ b/imperative/python/megengine/functional/tensor.py @@ -195,28 +195,14 @@ def ones( return full(shape, 1.0, dtype=dtype, device=device) -def zeros( - shape: Union[int, Tuple[int, ...]], - *, - dtype="float32", - device: Optional[CompNode] = None -) -> Tensor: - r"""Returns a new tensor having a specified shape and filled with zeros. +def zeros(shape, dtype="float32", device=None) -> Tensor: + r"""Returns a zero tensor with given shape. Args: - shape (int or sequence of ints): the shape of the output tensor. - - Keyword args: - dtype (:attr:`.Tensor.dtype`): output tensor data type. Default: ``float32``. - device (:attr:`.Tensor.device`): device on which to place the created tensor. Default: ``None``. - - Returns: - a tensor containing zeros. - - Examples: - >>> F.zeros((2, 1)) - Tensor([[0.] - [0.]], device=xpux:0) + shape: a list, tuple or integer defining the shape of the output tensor. + dtype: the desired data type of the output tensor. Default: ``float32``. + device: the desired device of the output tensor. Default: if ``None``, + use the default device (see :func:`~.megengine.get_default_device`). """ return full(shape, 0.0, dtype=dtype, device=device)