未验证 提交 cc4bbfc0 编写于 作者: R Roc 提交者: GitHub

Fix and omptimize flip API (#34379) (#34477)

上级 2dffe32b
......@@ -33,6 +33,8 @@ class TestFlipOp_API(unittest.TestCase):
axis = [0]
input = fluid.data(name='input', dtype='float32', shape=[2, 3])
output = paddle.flip(input, axis)
output = paddle.flip(output, -1)
output = output.flip(0)
place = fluid.CPUPlace()
if fluid.core.is_compiled_with_cuda():
place = fluid.CUDAPlace(0)
......@@ -43,7 +45,7 @@ class TestFlipOp_API(unittest.TestCase):
feed={'input': img},
fetch_list=[output])
out_np = np.array(res[0])
out_ref = np.array([[4, 5, 6], [1, 2, 3]]).astype(np.float32)
out_ref = np.array([[3, 2, 1], [6, 5, 4]]).astype(np.float32)
self.assertTrue(
(out_np == out_ref).all(),
msg='flip output is wrong, out =' + str(out_np))
......@@ -53,7 +55,10 @@ class TestFlipOp_API(unittest.TestCase):
with fluid.dygraph.guard():
inputs = fluid.dygraph.to_variable(img)
ret = paddle.flip(inputs, [0])
out_ref = np.array([[4, 5, 6], [1, 2, 3]]).astype(np.float32)
ret = ret.flip(0)
ret = paddle.flip(ret, 1)
out_ref = np.array([[3, 2, 1], [6, 5, 4]]).astype(np.float32)
self.assertTrue(
(ret.numpy() == out_ref).all(),
msg='flip output is wrong, out =' + str(ret.numpy()))
......@@ -82,6 +87,8 @@ class TestFlipOp(OpTest):
def calc_ref_res(self):
res = self.inputs['X']
if isinstance(self.axis, int):
return np.flip(res, self.axis)
for axis in self.axis:
res = np.flip(res, axis)
return res
......
......@@ -128,7 +128,7 @@ def flip(x, axis, name=None):
Args:
x (Tensor): A Tensor(or LoDTensor) with shape :math:`[N_1, N_2,..., N_k]` . The data type of the input Tensor x
should be float32, float64, int32, int64, bool.
axis (list|tuple): The axis(axes) to flip on. Negative indices for indexing from the end are accepted.
axis (list|tuple|int): The axis(axes) to flip on. Negative indices for indexing from the end are accepted.
name (str, optional): The default value is None. Normally there is no need for user to set this property.
For more information, please refer to :ref:`api_guide_Name` .
......@@ -145,10 +145,17 @@ def flip(x, axis, name=None):
x = np.arange(image_shape[0] * image_shape[1] * image_shape[2]).reshape(image_shape)
x = x.astype('float32')
img = paddle.to_tensor(x)
out = paddle.flip(img, [0,1])
tmp = paddle.flip(img, [0,1])
print(tmp) # [[[10,11],[8, 9]], [[6, 7],[4, 5]], [[2, 3],[0, 1]]]
print(out) # [[[10,11][8, 9]],[[6, 7],[4, 5]] [[2, 3],[0, 1]]]
out = paddle.flip(tmp,-1)
print(out) # [[[11,10],[9, 8]], [[7, 6],[5, 4]], [[3, 2],[1, 0]]]
"""
if isinstance(axis, int):
axis = [axis]
if in_dygraph_mode():
return core.ops.flip(x, "axis", axis)
helper = LayerHelper("flip", **locals())
check_type(x, 'X', (Variable), 'flip')
dtype = helper.input_dtype('x')
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册