diff --git a/imperative/python/megengine/functional/tensor.py b/imperative/python/megengine/functional/tensor.py index 040a498128e4040e08cabc63eb259bc7e60d748b..0b50c26aa3d644fab2b6c605b8f16a389a2e1082 100755 --- a/imperative/python/megengine/functional/tensor.py +++ b/imperative/python/megengine/functional/tensor.py @@ -851,34 +851,33 @@ def transpose(inp: Tensor, pattern: Iterable[int]) -> Tensor: return inp.transpose(list(-1 if _ == "x" else _ for _ in pattern)) -def reshape(inp: Tensor, target_shape: Iterable[int]) -> Tensor: - r"""Reshapes a tensor to given target shape; total number of logical elements must - remain unchanged +def reshape(inp: Tensor, shape: Iterable[int]) -> Tensor: + r"""Reshapes a tensor without changing its data. Args: - inp: input tensor. - target_shape: target shape, it can contain an element of -1 representing ``unspec_axis``. - - Examples: - - .. testcode:: - - import numpy as np - from megengine import tensor - import megengine.functional as F - x = tensor(np.arange(12, dtype=np.int32)) - out = F.reshape(x, (3, 4)) - print(out.numpy()) + inp (Tensor): input tensor to reshape. + shape (sequence of ints): target shape compatible with the original shape. One shape dimension is allowed + to be ``-1``. When a shape dimension is ``-1``, the corresponding output tensor shape dimension + must be inferred from the length of the tensor and the remaining dimensions. - Outputs: + Returns: + an output tensor having the same data type, elements, and underlying element order as ``inp``. - .. testoutput:: + Examples: - [[ 0 1 2 3] - [ 4 5 6 7] - [ 8 9 10 11]] + >>> x = F.arange(12) + >>> x + Tensor([ 0. 1. 2. 3. 4. 5. 6. 7. 8. 9. 10. 11.], device=xpux:0) + >>> F.reshape(x, (3, 4)) + Tensor([[ 0. 1. 2. 3.] + [ 4. 5. 6. 7.] + [ 8. 9. 10. 11.]], device=xpux:0) + >>> F.reshape(x, (2, -1)) + Tensor([[ 0. 1. 2. 3. 4. 5.] + [ 6. 7. 8. 9. 10. 11.]], device=xpux:0) + """ - return inp.reshape(target_shape) + return inp.reshape(shape) def flatten(inp: Tensor, start_axis: int = 0, end_axis: int = -1) -> Tensor: