diff --git a/python/paddle/tensor/manipulation.py b/python/paddle/tensor/manipulation.py index 4802145bc3d8c35908a86bf67f2d76dfb966dc81..22e3b7cd29e0ec7c505806f5a41db8788fd9e569 100755 --- a/python/paddle/tensor/manipulation.py +++ b/python/paddle/tensor/manipulation.py @@ -1489,7 +1489,8 @@ def roll(x, shifts, axis=None, name=None): x (Tensor): The x tensor as input. shifts (int|list|tuple): The number of places by which the elements of the `x` tensor are shifted. - axis (int|list|tuple|None): axis(axes) along which to roll. + axis (int|list|tuple, optional): axis(axes) along which to roll. Default: None + name (str, optional): Name for the operation. Default: None Returns: Tensor: A Tensor with same data type as `x`. @@ -1512,6 +1513,11 @@ def roll(x, shifts, axis=None, name=None): #[[7. 8. 9.] # [1. 2. 3.] # [4. 5. 6.]] + out_z3 = paddle.roll(x, shifts=1, axis=1) + print(out_z3) + #[[3. 1. 2.] + # [6. 4. 5.] + # [9. 7. 8.]] """ origin_shape = x.shape if type(shifts) == int: