未验证 提交 5b8c5b7b 编写于 作者: 0 0x45f 提交者: GitHub

Fix some PaddleTest UT (#41373)

* Fix some PaddleTest UT

* refine code

* set default value
上级 c02eeb96
...@@ -182,7 +182,8 @@ def equal(x, y, name=None): ...@@ -182,7 +182,8 @@ def equal(x, y, name=None):
y = full(shape=[1], dtype=x.dtype, fill_value=y) y = full(shape=[1], dtype=x.dtype, fill_value=y)
if in_dygraph_mode(): if in_dygraph_mode():
return _C_ops.final_state_equal(x, y) axis = -1
return _C_ops.final_state_equal(x, y, axis)
else: else:
if _in_legacy_dygraph(): if _in_legacy_dygraph():
return _C_ops.equal(x, y) return _C_ops.equal(x, y)
...@@ -231,7 +232,8 @@ def greater_equal(x, y, name=None): ...@@ -231,7 +232,8 @@ def greater_equal(x, y, name=None):
print(result1) # result1 = [True False True] print(result1) # result1 = [True False True]
""" """
if in_dygraph_mode(): if in_dygraph_mode():
return _C_ops.final_state_greater_equal(x, y) axis = -1
return _C_ops.final_state_greater_equal(x, y, axis)
else: else:
if _in_legacy_dygraph(): if _in_legacy_dygraph():
return _C_ops.greater_equal(x, y) return _C_ops.greater_equal(x, y)
...@@ -331,7 +333,8 @@ def less_equal(x, y, name=None): ...@@ -331,7 +333,8 @@ def less_equal(x, y, name=None):
print(result1) # result1 = [True True False] print(result1) # result1 = [True True False]
""" """
if in_dygraph_mode(): if in_dygraph_mode():
return _C_ops.final_state_less_equal(x, y) axis = -1
return _C_ops.final_state_less_equal(x, y, axis)
else: else:
if _in_legacy_dygraph(): if _in_legacy_dygraph():
return _C_ops.less_equal(x, y) return _C_ops.less_equal(x, y)
...@@ -381,7 +384,8 @@ def less_than(x, y, name=None): ...@@ -381,7 +384,8 @@ def less_than(x, y, name=None):
print(result1) # result1 = [False True False] print(result1) # result1 = [False True False]
""" """
if in_dygraph_mode(): if in_dygraph_mode():
return _C_ops.final_state_less_than(x, y) axis = -1
return _C_ops.final_state_less_than(x, y, axis)
else: else:
if _in_legacy_dygraph(): if _in_legacy_dygraph():
return _C_ops.less_than(x, y) return _C_ops.less_than(x, y)
...@@ -431,7 +435,8 @@ def not_equal(x, y, name=None): ...@@ -431,7 +435,8 @@ def not_equal(x, y, name=None):
print(result1) # result1 = [False True True] print(result1) # result1 = [False True True]
""" """
if in_dygraph_mode(): if in_dygraph_mode():
return _C_ops.final_state_not_equal(x, y) axis = -1
return _C_ops.final_state_not_equal(x, y, axis)
else: else:
if _in_legacy_dygraph(): if _in_legacy_dygraph():
return _C_ops.not_equal(x, y) return _C_ops.not_equal(x, y)
...@@ -538,7 +543,7 @@ def bitwise_and(x, y, out=None, name=None): ...@@ -538,7 +543,7 @@ def bitwise_and(x, y, out=None, name=None):
res = paddle.bitwise_and(x, y) res = paddle.bitwise_and(x, y)
print(res) # [0, 2, 1] print(res) # [0, 2, 1]
""" """
if in_dygraph_mode() and out == None: if in_dygraph_mode() and out is None:
return _C_ops.final_state_bitwise_and(x, y) return _C_ops.final_state_bitwise_and(x, y)
return _bitwise_op( return _bitwise_op(
op_name="bitwise_and", x=x, y=y, name=name, out=out, binary_op=True) op_name="bitwise_and", x=x, y=y, name=name, out=out, binary_op=True)
...@@ -566,7 +571,7 @@ def bitwise_or(x, y, out=None, name=None): ...@@ -566,7 +571,7 @@ def bitwise_or(x, y, out=None, name=None):
res = paddle.bitwise_or(x, y) res = paddle.bitwise_or(x, y)
print(res) # [-1, -1, -3] print(res) # [-1, -1, -3]
""" """
if in_dygraph_mode() and out == None: if in_dygraph_mode() and out is None:
return _C_ops.final_state_bitwise_or(x, y) return _C_ops.final_state_bitwise_or(x, y)
return _bitwise_op( return _bitwise_op(
...@@ -595,7 +600,7 @@ def bitwise_xor(x, y, out=None, name=None): ...@@ -595,7 +600,7 @@ def bitwise_xor(x, y, out=None, name=None):
res = paddle.bitwise_xor(x, y) res = paddle.bitwise_xor(x, y)
print(res) # [-1, -3, -4] print(res) # [-1, -3, -4]
""" """
if in_dygraph_mode() and out == None: if in_dygraph_mode() and out is None:
return _C_ops.final_state_bitwise_xor(x, y) return _C_ops.final_state_bitwise_xor(x, y)
return _bitwise_op( return _bitwise_op(
op_name="bitwise_xor", x=x, y=y, name=name, out=out, binary_op=True) op_name="bitwise_xor", x=x, y=y, name=name, out=out, binary_op=True)
...@@ -621,7 +626,7 @@ def bitwise_not(x, out=None, name=None): ...@@ -621,7 +626,7 @@ def bitwise_not(x, out=None, name=None):
res = paddle.bitwise_not(x) res = paddle.bitwise_not(x)
print(res) # [4, 0, -2] print(res) # [4, 0, -2]
""" """
if in_dygraph_mode() and out == None: if in_dygraph_mode() and out is None:
return _C_ops.final_state_bitwise_not(x) return _C_ops.final_state_bitwise_not(x)
return _bitwise_op( return _bitwise_op(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册