diff --git a/python/paddle/fluid/tests/unittests/test_rot90_op.py b/python/paddle/fluid/tests/unittests/test_rot90_op.py index 4ab7c4f14f96ba85326b3289d42a1ac40c50f039..404bb3ae1eb67604f0505764df29f2d7c786a088 100644 --- a/python/paddle/fluid/tests/unittests/test_rot90_op.py +++ b/python/paddle/fluid/tests/unittests/test_rot90_op.py @@ -208,6 +208,32 @@ class TestRot90_API(unittest.TestCase): (out_np == out_ref).all(), msg='rot90 output is wrong, out =' + str(out_np)) + def test_static_neg_k_4(self): + paddle.enable_static() + input = fluid.data(name='input', dtype='float32', shape=[2, 3]) + startup_program = fluid.Program() + train_program = fluid.Program() + with fluid.program_guard(train_program, startup_program): + input = fluid.data(name='input', dtype='float32', shape=[2, 3]) + output = paddle.rot90(input, k=-4, axes=[0, 1]) + place = fluid.CPUPlace() + if fluid.core.is_compiled_with_cuda(): + place = fluid.CUDAPlace(0) + exe = fluid.Executor(place) + exe.run(startup_program) + + img = np.array([[1, 2, 3], [4, 5, 6]]).astype(np.float32) + res = exe.run(train_program, + feed={'input': img}, + fetch_list=[output]) + + out_np = np.array(res[0]) + out_ref = np.array([[1, 2, 3], [4, 5, 6]]).astype(np.float32) + + self.assertTrue( + (out_np == out_ref).all(), + msg='rot90 output is wrong, out =' + str(out_np)) + def test_error_api(self): paddle.enable_static() diff --git a/python/paddle/tensor/manipulation.py b/python/paddle/tensor/manipulation.py index a81d8c54ffc42fd4e571c1901db5ed4e42ab23f0..a77f17ca55a1ad4f32169eba2ae03c203e952529 100644 --- a/python/paddle/tensor/manipulation.py +++ b/python/paddle/tensor/manipulation.py @@ -502,7 +502,7 @@ def rot90(x, k=1, axes=[0, 1], name=None): Args: x (Tensor): The input Tensor(or LoDTensor). The data type of the input Tensor x - should be float16, float32, float64, int32, int64, bool. + should be float16, float32, float64, int32, int64, bool. float16 is only supported on gpu. k (int, optional): Direction and number of times to rotate, default value: 1. axes (list|tuple, optional): Axes to rotate, dimension must be 2. default value: [0, 1]. name (str, optional): The default value is None. Normally there is no need for user to set this property. @@ -577,8 +577,7 @@ def rot90(x, k=1, axes=[0, 1], name=None): raise ValueError("Rotation axis1 out of range, axis1 = {}".format(axes[ 1])) - ## k % 4 - k = k % 4 if k >= 0 else 4 - (-k % 4) + k %= 4 if k == 0: return x if k == 2: