From df27c264a2b2cf798ce4548ece4a0e49b69baf58 Mon Sep 17 00:00:00 2001 From: Roc <30228238+sljlp@users.noreply.github.com> Date: Mon, 26 Jul 2021 18:14:31 +0800 Subject: [PATCH] Fix and omptimize flip API (#34379) --- python/paddle/fluid/tests/unittests/test_flip.py | 11 +++++++++-- python/paddle/tensor/manipulation.py | 13 ++++++++++--- 2 files changed, 19 insertions(+), 5 deletions(-) diff --git a/python/paddle/fluid/tests/unittests/test_flip.py b/python/paddle/fluid/tests/unittests/test_flip.py index 6feee9ce573..5e2aacf9cef 100644 --- a/python/paddle/fluid/tests/unittests/test_flip.py +++ b/python/paddle/fluid/tests/unittests/test_flip.py @@ -33,6 +33,8 @@ class TestFlipOp_API(unittest.TestCase): axis = [0] input = fluid.data(name='input', dtype='float32', shape=[2, 3]) output = paddle.flip(input, axis) + output = paddle.flip(output, -1) + output = output.flip(0) place = fluid.CPUPlace() if fluid.core.is_compiled_with_cuda(): place = fluid.CUDAPlace(0) @@ -43,7 +45,7 @@ class TestFlipOp_API(unittest.TestCase): feed={'input': img}, fetch_list=[output]) out_np = np.array(res[0]) - out_ref = np.array([[4, 5, 6], [1, 2, 3]]).astype(np.float32) + out_ref = np.array([[3, 2, 1], [6, 5, 4]]).astype(np.float32) self.assertTrue( (out_np == out_ref).all(), msg='flip output is wrong, out =' + str(out_np)) @@ -53,7 +55,10 @@ class TestFlipOp_API(unittest.TestCase): with fluid.dygraph.guard(): inputs = fluid.dygraph.to_variable(img) ret = paddle.flip(inputs, [0]) - out_ref = np.array([[4, 5, 6], [1, 2, 3]]).astype(np.float32) + ret = ret.flip(0) + ret = paddle.flip(ret, 1) + out_ref = np.array([[3, 2, 1], [6, 5, 4]]).astype(np.float32) + self.assertTrue( (ret.numpy() == out_ref).all(), msg='flip output is wrong, out =' + str(ret.numpy())) @@ -82,6 +87,8 @@ class TestFlipOp(OpTest): def calc_ref_res(self): res = self.inputs['X'] + if isinstance(self.axis, int): + return np.flip(res, self.axis) for axis in self.axis: res = np.flip(res, axis) return res diff --git a/python/paddle/tensor/manipulation.py b/python/paddle/tensor/manipulation.py index 2c02f023bf2..434069fe74b 100644 --- a/python/paddle/tensor/manipulation.py +++ b/python/paddle/tensor/manipulation.py @@ -223,7 +223,7 @@ def flip(x, axis, name=None): Args: x (Tensor): A Tensor(or LoDTensor) with shape :math:`[N_1, N_2,..., N_k]` . The data type of the input Tensor x should be float32, float64, int32, int64, bool. - axis (list|tuple): The axis(axes) to flip on. Negative indices for indexing from the end are accepted. + axis (list|tuple|int): The axis(axes) to flip on. Negative indices for indexing from the end are accepted. name (str, optional): The default value is None. Normally there is no need for user to set this property. For more information, please refer to :ref:`api_guide_Name` . @@ -240,10 +240,17 @@ def flip(x, axis, name=None): x = np.arange(image_shape[0] * image_shape[1] * image_shape[2]).reshape(image_shape) x = x.astype('float32') img = paddle.to_tensor(x) - out = paddle.flip(img, [0,1]) + tmp = paddle.flip(img, [0,1]) + print(tmp) # [[[10,11],[8, 9]], [[6, 7],[4, 5]], [[2, 3],[0, 1]]] - print(out) # [[[10,11][8, 9]],[[6, 7],[4, 5]] [[2, 3],[0, 1]]] + out = paddle.flip(tmp,-1) + print(out) # [[[11,10],[9, 8]], [[7, 6],[5, 4]], [[3, 2],[1, 0]]] """ + if isinstance(axis, int): + axis = [axis] + if in_dygraph_mode(): + return core.ops.flip(x, "axis", axis) + helper = LayerHelper("flip", **locals()) check_type(x, 'X', (Variable), 'flip') dtype = helper.input_dtype('x') -- GitLab