From 201b7c9dedf93269593f9799486e8e789ed9cffb Mon Sep 17 00:00:00 2001 From: zmxdream Date: Mon, 13 Dec 2021 11:07:34 +0800 Subject: [PATCH] [rot90] fix rot90 (#38042) * [rot90] fix rot90 * fix rot90 * fix for ci. test=develop * fix rot90. test=develop * update. test=develop * update. test=develop --- .../fluid/tests/unittests/test_rot90_op.py | 26 +++++++++++++++++++ python/paddle/tensor/manipulation.py | 5 ++-- 2 files changed, 28 insertions(+), 3 deletions(-) diff --git a/python/paddle/fluid/tests/unittests/test_rot90_op.py b/python/paddle/fluid/tests/unittests/test_rot90_op.py index 4ab7c4f14f..404bb3ae1e 100644 --- a/python/paddle/fluid/tests/unittests/test_rot90_op.py +++ b/python/paddle/fluid/tests/unittests/test_rot90_op.py @@ -208,6 +208,32 @@ class TestRot90_API(unittest.TestCase): (out_np == out_ref).all(), msg='rot90 output is wrong, out =' + str(out_np)) + def test_static_neg_k_4(self): + paddle.enable_static() + input = fluid.data(name='input', dtype='float32', shape=[2, 3]) + startup_program = fluid.Program() + train_program = fluid.Program() + with fluid.program_guard(train_program, startup_program): + input = fluid.data(name='input', dtype='float32', shape=[2, 3]) + output = paddle.rot90(input, k=-4, axes=[0, 1]) + place = fluid.CPUPlace() + if fluid.core.is_compiled_with_cuda(): + place = fluid.CUDAPlace(0) + exe = fluid.Executor(place) + exe.run(startup_program) + + img = np.array([[1, 2, 3], [4, 5, 6]]).astype(np.float32) + res = exe.run(train_program, + feed={'input': img}, + fetch_list=[output]) + + out_np = np.array(res[0]) + out_ref = np.array([[1, 2, 3], [4, 5, 6]]).astype(np.float32) + + self.assertTrue( + (out_np == out_ref).all(), + msg='rot90 output is wrong, out =' + str(out_np)) + def test_error_api(self): paddle.enable_static() diff --git a/python/paddle/tensor/manipulation.py b/python/paddle/tensor/manipulation.py index a81d8c54ff..a77f17ca55 100644 --- a/python/paddle/tensor/manipulation.py +++ b/python/paddle/tensor/manipulation.py @@ -502,7 +502,7 @@ def rot90(x, k=1, axes=[0, 1], name=None): Args: x (Tensor): The input Tensor(or LoDTensor). The data type of the input Tensor x - should be float16, float32, float64, int32, int64, bool. + should be float16, float32, float64, int32, int64, bool. float16 is only supported on gpu. k (int, optional): Direction and number of times to rotate, default value: 1. axes (list|tuple, optional): Axes to rotate, dimension must be 2. default value: [0, 1]. name (str, optional): The default value is None. Normally there is no need for user to set this property. @@ -577,8 +577,7 @@ def rot90(x, k=1, axes=[0, 1], name=None): raise ValueError("Rotation axis1 out of range, axis1 = {}".format(axes[ 1])) - ## k % 4 - k = k % 4 if k >= 0 else 4 - (-k % 4) + k %= 4 if k == 0: return x if k == 2: -- GitLab