未验证 提交 df27c264 编写于 作者: R Roc 提交者: GitHub

Fix and omptimize flip API (#34379)

上级 d3d174f7
...@@ -33,6 +33,8 @@ class TestFlipOp_API(unittest.TestCase): ...@@ -33,6 +33,8 @@ class TestFlipOp_API(unittest.TestCase):
axis = [0] axis = [0]
input = fluid.data(name='input', dtype='float32', shape=[2, 3]) input = fluid.data(name='input', dtype='float32', shape=[2, 3])
output = paddle.flip(input, axis) output = paddle.flip(input, axis)
output = paddle.flip(output, -1)
output = output.flip(0)
place = fluid.CPUPlace() place = fluid.CPUPlace()
if fluid.core.is_compiled_with_cuda(): if fluid.core.is_compiled_with_cuda():
place = fluid.CUDAPlace(0) place = fluid.CUDAPlace(0)
...@@ -43,7 +45,7 @@ class TestFlipOp_API(unittest.TestCase): ...@@ -43,7 +45,7 @@ class TestFlipOp_API(unittest.TestCase):
feed={'input': img}, feed={'input': img},
fetch_list=[output]) fetch_list=[output])
out_np = np.array(res[0]) out_np = np.array(res[0])
out_ref = np.array([[4, 5, 6], [1, 2, 3]]).astype(np.float32) out_ref = np.array([[3, 2, 1], [6, 5, 4]]).astype(np.float32)
self.assertTrue( self.assertTrue(
(out_np == out_ref).all(), (out_np == out_ref).all(),
msg='flip output is wrong, out =' + str(out_np)) msg='flip output is wrong, out =' + str(out_np))
...@@ -53,7 +55,10 @@ class TestFlipOp_API(unittest.TestCase): ...@@ -53,7 +55,10 @@ class TestFlipOp_API(unittest.TestCase):
with fluid.dygraph.guard(): with fluid.dygraph.guard():
inputs = fluid.dygraph.to_variable(img) inputs = fluid.dygraph.to_variable(img)
ret = paddle.flip(inputs, [0]) ret = paddle.flip(inputs, [0])
out_ref = np.array([[4, 5, 6], [1, 2, 3]]).astype(np.float32) ret = ret.flip(0)
ret = paddle.flip(ret, 1)
out_ref = np.array([[3, 2, 1], [6, 5, 4]]).astype(np.float32)
self.assertTrue( self.assertTrue(
(ret.numpy() == out_ref).all(), (ret.numpy() == out_ref).all(),
msg='flip output is wrong, out =' + str(ret.numpy())) msg='flip output is wrong, out =' + str(ret.numpy()))
...@@ -82,6 +87,8 @@ class TestFlipOp(OpTest): ...@@ -82,6 +87,8 @@ class TestFlipOp(OpTest):
def calc_ref_res(self): def calc_ref_res(self):
res = self.inputs['X'] res = self.inputs['X']
if isinstance(self.axis, int):
return np.flip(res, self.axis)
for axis in self.axis: for axis in self.axis:
res = np.flip(res, axis) res = np.flip(res, axis)
return res return res
......
...@@ -223,7 +223,7 @@ def flip(x, axis, name=None): ...@@ -223,7 +223,7 @@ def flip(x, axis, name=None):
Args: Args:
x (Tensor): A Tensor(or LoDTensor) with shape :math:`[N_1, N_2,..., N_k]` . The data type of the input Tensor x x (Tensor): A Tensor(or LoDTensor) with shape :math:`[N_1, N_2,..., N_k]` . The data type of the input Tensor x
should be float32, float64, int32, int64, bool. should be float32, float64, int32, int64, bool.
axis (list|tuple): The axis(axes) to flip on. Negative indices for indexing from the end are accepted. axis (list|tuple|int): The axis(axes) to flip on. Negative indices for indexing from the end are accepted.
name (str, optional): The default value is None. Normally there is no need for user to set this property. name (str, optional): The default value is None. Normally there is no need for user to set this property.
For more information, please refer to :ref:`api_guide_Name` . For more information, please refer to :ref:`api_guide_Name` .
...@@ -240,10 +240,17 @@ def flip(x, axis, name=None): ...@@ -240,10 +240,17 @@ def flip(x, axis, name=None):
x = np.arange(image_shape[0] * image_shape[1] * image_shape[2]).reshape(image_shape) x = np.arange(image_shape[0] * image_shape[1] * image_shape[2]).reshape(image_shape)
x = x.astype('float32') x = x.astype('float32')
img = paddle.to_tensor(x) img = paddle.to_tensor(x)
out = paddle.flip(img, [0,1]) tmp = paddle.flip(img, [0,1])
print(tmp) # [[[10,11],[8, 9]], [[6, 7],[4, 5]], [[2, 3],[0, 1]]]
print(out) # [[[10,11][8, 9]],[[6, 7],[4, 5]] [[2, 3],[0, 1]]] out = paddle.flip(tmp,-1)
print(out) # [[[11,10],[9, 8]], [[7, 6],[5, 4]], [[3, 2],[1, 0]]]
""" """
if isinstance(axis, int):
axis = [axis]
if in_dygraph_mode():
return core.ops.flip(x, "axis", axis)
helper = LayerHelper("flip", **locals()) helper = LayerHelper("flip", **locals())
check_type(x, 'X', (Variable), 'flip') check_type(x, 'X', (Variable), 'flip')
dtype = helper.input_dtype('x') dtype = helper.input_dtype('x')
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册