未验证 提交 9b2d53fc 编写于 作者: Y yeliang2258 提交者: GitHub

增强equal API,输入Y支持int,float,bool或者tensor类型 (#35695)

* update equal op, input Y can be float,int,bool or tensor

* update test

* update code style

* update code style

* update doc

* update str check

* remote str

* add type check
上级 28fffef6
...@@ -96,6 +96,21 @@ def create_paddle_case(op_type, callback): ...@@ -96,6 +96,21 @@ def create_paddle_case(op_type, callback):
fetch_list=[out]) fetch_list=[out])
self.assertEqual((res == self.real_result).all(), True) self.assertEqual((res == self.real_result).all(), True)
def test_api_float(self):
if self.op_type == "equal":
paddle.enable_static()
with program_guard(Program(), Program()):
x = fluid.data(name='x', shape=[4], dtype='int64')
y = fluid.data(name='y', shape=[1], dtype='int64')
op = eval("paddle.%s" % (self.op_type))
out = op(x, y)
exe = fluid.Executor(self.place)
res, = exe.run(feed={"x": self.input_x,
"y": 1.0},
fetch_list=[out])
self.real_result = np.array([1, 0, 0, 0]).astype(np.int64)
self.assertEqual((res == self.real_result).all(), True)
def test_dynamic_api(self): def test_dynamic_api(self):
paddle.disable_static() paddle.disable_static()
x = paddle.to_tensor(self.input_x) x = paddle.to_tensor(self.input_x)
...@@ -105,6 +120,47 @@ def create_paddle_case(op_type, callback): ...@@ -105,6 +120,47 @@ def create_paddle_case(op_type, callback):
self.assertEqual((out.numpy() == self.real_result).all(), True) self.assertEqual((out.numpy() == self.real_result).all(), True)
paddle.enable_static() paddle.enable_static()
def test_dynamic_api_int(self):
if self.op_type == "equal":
paddle.disable_static()
x = paddle.to_tensor(self.input_x)
op = eval("paddle.%s" % (self.op_type))
out = op(x, 1)
self.real_result = np.array([1, 0, 0, 0]).astype(np.int64)
self.assertEqual((out.numpy() == self.real_result).all(), True)
paddle.enable_static()
def test_dynamic_api_float(self):
if self.op_type == "equal":
paddle.disable_static()
x = paddle.to_tensor(self.input_x)
op = eval("paddle.%s" % (self.op_type))
out = op(x, 1.0)
self.real_result = np.array([1, 0, 0, 0]).astype(np.int64)
self.assertEqual((out.numpy() == self.real_result).all(), True)
paddle.enable_static()
def test_assert(self):
def test_dynamic_api_string(self):
if self.op_type == "equal":
paddle.disable_static()
x = paddle.to_tensor(self.input_x)
op = eval("paddle.%s" % (self.op_type))
out = op(x, "1.0")
paddle.enable_static()
self.assertRaises(TypeError, test_dynamic_api_string)
def test_dynamic_api_bool(self):
if self.op_type == "equal":
paddle.disable_static()
x = paddle.to_tensor(self.input_x)
op = eval("paddle.%s" % (self.op_type))
out = op(x, True)
self.real_result = np.array([1, 0, 0, 0]).astype(np.int64)
self.assertEqual((out.numpy() == self.real_result).all(), True)
paddle.enable_static()
def test_broadcast_api_1(self): def test_broadcast_api_1(self):
paddle.enable_static() paddle.enable_static()
with program_guard(Program(), Program()): with program_guard(Program(), Program()):
......
...@@ -28,6 +28,7 @@ from ..fluid.layers import logical_xor # noqa: F401 ...@@ -28,6 +28,7 @@ from ..fluid.layers import logical_xor # noqa: F401
from paddle.common_ops_import import core from paddle.common_ops_import import core
from paddle import _C_ops from paddle import _C_ops
from paddle.tensor.creation import full
__all__ = [] __all__ = []
...@@ -174,6 +175,13 @@ def equal(x, y, name=None): ...@@ -174,6 +175,13 @@ def equal(x, y, name=None):
result1 = paddle.equal(x, y) result1 = paddle.equal(x, y)
print(result1) # result1 = [True False False] print(result1) # result1 = [True False False]
""" """
if not isinstance(y, (int, bool, float, Variable)):
raise TypeError(
"Type of input args must be float, bool, int or Tensor, but received type {}".
format(type(y)))
if not isinstance(y, Variable):
y = full(shape=[1], dtype=x.dtype, fill_value=y)
if in_dygraph_mode(): if in_dygraph_mode():
return _C_ops.equal(x, y) return _C_ops.equal(x, y)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册