未验证 提交 201b7c9d 编写于 作者: Z zmxdream 提交者: GitHub

[rot90] fix rot90 (#38042)

* [rot90] fix rot90

* fix rot90

* fix for ci. test=develop

* fix rot90. test=develop

* update. test=develop

* update. test=develop
上级 b76ef045
...@@ -208,6 +208,32 @@ class TestRot90_API(unittest.TestCase): ...@@ -208,6 +208,32 @@ class TestRot90_API(unittest.TestCase):
(out_np == out_ref).all(), (out_np == out_ref).all(),
msg='rot90 output is wrong, out =' + str(out_np)) msg='rot90 output is wrong, out =' + str(out_np))
def test_static_neg_k_4(self):
paddle.enable_static()
input = fluid.data(name='input', dtype='float32', shape=[2, 3])
startup_program = fluid.Program()
train_program = fluid.Program()
with fluid.program_guard(train_program, startup_program):
input = fluid.data(name='input', dtype='float32', shape=[2, 3])
output = paddle.rot90(input, k=-4, axes=[0, 1])
place = fluid.CPUPlace()
if fluid.core.is_compiled_with_cuda():
place = fluid.CUDAPlace(0)
exe = fluid.Executor(place)
exe.run(startup_program)
img = np.array([[1, 2, 3], [4, 5, 6]]).astype(np.float32)
res = exe.run(train_program,
feed={'input': img},
fetch_list=[output])
out_np = np.array(res[0])
out_ref = np.array([[1, 2, 3], [4, 5, 6]]).astype(np.float32)
self.assertTrue(
(out_np == out_ref).all(),
msg='rot90 output is wrong, out =' + str(out_np))
def test_error_api(self): def test_error_api(self):
paddle.enable_static() paddle.enable_static()
......
...@@ -502,7 +502,7 @@ def rot90(x, k=1, axes=[0, 1], name=None): ...@@ -502,7 +502,7 @@ def rot90(x, k=1, axes=[0, 1], name=None):
Args: Args:
x (Tensor): The input Tensor(or LoDTensor). The data type of the input Tensor x x (Tensor): The input Tensor(or LoDTensor). The data type of the input Tensor x
should be float16, float32, float64, int32, int64, bool. should be float16, float32, float64, int32, int64, bool. float16 is only supported on gpu.
k (int, optional): Direction and number of times to rotate, default value: 1. k (int, optional): Direction and number of times to rotate, default value: 1.
axes (list|tuple, optional): Axes to rotate, dimension must be 2. default value: [0, 1]. axes (list|tuple, optional): Axes to rotate, dimension must be 2. default value: [0, 1].
name (str, optional): The default value is None. Normally there is no need for user to set this property. name (str, optional): The default value is None. Normally there is no need for user to set this property.
...@@ -577,8 +577,7 @@ def rot90(x, k=1, axes=[0, 1], name=None): ...@@ -577,8 +577,7 @@ def rot90(x, k=1, axes=[0, 1], name=None):
raise ValueError("Rotation axis1 out of range, axis1 = {}".format(axes[ raise ValueError("Rotation axis1 out of range, axis1 = {}".format(axes[
1])) 1]))
## k % 4 k %= 4
k = k % 4 if k >= 0 else 4 - (-k % 4)
if k == 0: if k == 0:
return x return x
if k == 2: if k == 2:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册