From 5b8c5b7bc0fbf0a0e8a70442eefd7432011dfbf5 Mon Sep 17 00:00:00 2001 From: 0x45f <23097963+0x45f@users.noreply.github.com> Date: Mon, 4 Apr 2022 15:51:11 +0800 Subject: [PATCH] Fix some PaddleTest UT (#41373) * Fix some PaddleTest UT * refine code * set default value --- python/paddle/tensor/logic.py | 23 ++++++++++++++--------- 1 file changed, 14 insertions(+), 9 deletions(-) diff --git a/python/paddle/tensor/logic.py b/python/paddle/tensor/logic.py index 3896fa535ff..a4ff8724663 100755 --- a/python/paddle/tensor/logic.py +++ b/python/paddle/tensor/logic.py @@ -182,7 +182,8 @@ def equal(x, y, name=None): y = full(shape=[1], dtype=x.dtype, fill_value=y) if in_dygraph_mode(): - return _C_ops.final_state_equal(x, y) + axis = -1 + return _C_ops.final_state_equal(x, y, axis) else: if _in_legacy_dygraph(): return _C_ops.equal(x, y) @@ -231,7 +232,8 @@ def greater_equal(x, y, name=None): print(result1) # result1 = [True False True] """ if in_dygraph_mode(): - return _C_ops.final_state_greater_equal(x, y) + axis = -1 + return _C_ops.final_state_greater_equal(x, y, axis) else: if _in_legacy_dygraph(): return _C_ops.greater_equal(x, y) @@ -331,7 +333,8 @@ def less_equal(x, y, name=None): print(result1) # result1 = [True True False] """ if in_dygraph_mode(): - return _C_ops.final_state_less_equal(x, y) + axis = -1 + return _C_ops.final_state_less_equal(x, y, axis) else: if _in_legacy_dygraph(): return _C_ops.less_equal(x, y) @@ -381,7 +384,8 @@ def less_than(x, y, name=None): print(result1) # result1 = [False True False] """ if in_dygraph_mode(): - return _C_ops.final_state_less_than(x, y) + axis = -1 + return _C_ops.final_state_less_than(x, y, axis) else: if _in_legacy_dygraph(): return _C_ops.less_than(x, y) @@ -431,7 +435,8 @@ def not_equal(x, y, name=None): print(result1) # result1 = [False True True] """ if in_dygraph_mode(): - return _C_ops.final_state_not_equal(x, y) + axis = -1 + return _C_ops.final_state_not_equal(x, y, axis) else: if _in_legacy_dygraph(): return _C_ops.not_equal(x, y) @@ -538,7 +543,7 @@ def bitwise_and(x, y, out=None, name=None): res = paddle.bitwise_and(x, y) print(res) # [0, 2, 1] """ - if in_dygraph_mode() and out == None: + if in_dygraph_mode() and out is None: return _C_ops.final_state_bitwise_and(x, y) return _bitwise_op( op_name="bitwise_and", x=x, y=y, name=name, out=out, binary_op=True) @@ -566,7 +571,7 @@ def bitwise_or(x, y, out=None, name=None): res = paddle.bitwise_or(x, y) print(res) # [-1, -1, -3] """ - if in_dygraph_mode() and out == None: + if in_dygraph_mode() and out is None: return _C_ops.final_state_bitwise_or(x, y) return _bitwise_op( @@ -595,7 +600,7 @@ def bitwise_xor(x, y, out=None, name=None): res = paddle.bitwise_xor(x, y) print(res) # [-1, -3, -4] """ - if in_dygraph_mode() and out == None: + if in_dygraph_mode() and out is None: return _C_ops.final_state_bitwise_xor(x, y) return _bitwise_op( op_name="bitwise_xor", x=x, y=y, name=name, out=out, binary_op=True) @@ -621,7 +626,7 @@ def bitwise_not(x, out=None, name=None): res = paddle.bitwise_not(x) print(res) # [4, 0, -2] """ - if in_dygraph_mode() and out == None: + if in_dygraph_mode() and out is None: return _C_ops.final_state_bitwise_not(x) return _bitwise_op( -- GitLab