From b3975685c125ed0d59118c9ffd6ba6f1ef391a91 Mon Sep 17 00:00:00 2001 From: DBJ <974658390@qq.com> Date: Tue, 7 Dec 2021 13:41:28 +0800 Subject: [PATCH] docs(mge/functional): update functional.tensor.reshape docstring --- .../python/megengine/functional/tensor.py | 43 +++++++++---------- 1 file changed, 21 insertions(+), 22 deletions(-) diff --git a/imperative/python/megengine/functional/tensor.py b/imperative/python/megengine/functional/tensor.py index 040a49812..0b50c26aa 100755 --- a/imperative/python/megengine/functional/tensor.py +++ b/imperative/python/megengine/functional/tensor.py @@ -851,34 +851,33 @@ def transpose(inp: Tensor, pattern: Iterable[int]) -> Tensor: return inp.transpose(list(-1 if _ == "x" else _ for _ in pattern)) -def reshape(inp: Tensor, target_shape: Iterable[int]) -> Tensor: - r"""Reshapes a tensor to given target shape; total number of logical elements must - remain unchanged +def reshape(inp: Tensor, shape: Iterable[int]) -> Tensor: + r"""Reshapes a tensor without changing its data. Args: - inp: input tensor. - target_shape: target shape, it can contain an element of -1 representing ``unspec_axis``. - - Examples: - - .. testcode:: - - import numpy as np - from megengine import tensor - import megengine.functional as F - x = tensor(np.arange(12, dtype=np.int32)) - out = F.reshape(x, (3, 4)) - print(out.numpy()) + inp (Tensor): input tensor to reshape. + shape (sequence of ints): target shape compatible with the original shape. One shape dimension is allowed + to be ``-1``. When a shape dimension is ``-1``, the corresponding output tensor shape dimension + must be inferred from the length of the tensor and the remaining dimensions. - Outputs: + Returns: + an output tensor having the same data type, elements, and underlying element order as ``inp``. - .. testoutput:: + Examples: - [[ 0 1 2 3] - [ 4 5 6 7] - [ 8 9 10 11]] + >>> x = F.arange(12) + >>> x + Tensor([ 0. 1. 2. 3. 4. 5. 6. 7. 8. 9. 10. 11.], device=xpux:0) + >>> F.reshape(x, (3, 4)) + Tensor([[ 0. 1. 2. 3.] + [ 4. 5. 6. 7.] + [ 8. 9. 10. 11.]], device=xpux:0) + >>> F.reshape(x, (2, -1)) + Tensor([[ 0. 1. 2. 3. 4. 5.] + [ 6. 7. 8. 9. 10. 11.]], device=xpux:0) + """ - return inp.reshape(target_shape) + return inp.reshape(shape) def flatten(inp: Tensor, start_axis: int = 0, end_axis: int = -1) -> Tensor: -- GitLab