From 9b2d53fc0afe6318c13a5f4345944a48da6aa67e Mon Sep 17 00:00:00 2001 From: yeliang2258 <30516196+yeliang2258@users.noreply.github.com> Date: Fri, 17 Sep 2021 10:48:40 +0800 Subject: [PATCH] =?UTF-8?q?=E5=A2=9E=E5=BC=BAequal=20API=EF=BC=8C=E8=BE=93?= =?UTF-8?q?=E5=85=A5Y=E6=94=AF=E6=8C=81int=EF=BC=8Cfloat=EF=BC=8Cbool?= =?UTF-8?q?=E6=88=96=E8=80=85tensor=E7=B1=BB=E5=9E=8B=20(#35695)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * update equal op, input Y can be float,int,bool or tensor * update test * update code style * update code style * update doc * update str check * remote str * add type check --- .../fluid/tests/unittests/test_compare_op.py | 56 +++++++++++++++++++ python/paddle/tensor/logic.py | 8 +++ 2 files changed, 64 insertions(+) mode change 100644 => 100755 python/paddle/fluid/tests/unittests/test_compare_op.py mode change 100644 => 100755 python/paddle/tensor/logic.py diff --git a/python/paddle/fluid/tests/unittests/test_compare_op.py b/python/paddle/fluid/tests/unittests/test_compare_op.py old mode 100644 new mode 100755 index 7a14267588..8975638548 --- a/python/paddle/fluid/tests/unittests/test_compare_op.py +++ b/python/paddle/fluid/tests/unittests/test_compare_op.py @@ -96,6 +96,21 @@ def create_paddle_case(op_type, callback): fetch_list=[out]) self.assertEqual((res == self.real_result).all(), True) + def test_api_float(self): + if self.op_type == "equal": + paddle.enable_static() + with program_guard(Program(), Program()): + x = fluid.data(name='x', shape=[4], dtype='int64') + y = fluid.data(name='y', shape=[1], dtype='int64') + op = eval("paddle.%s" % (self.op_type)) + out = op(x, y) + exe = fluid.Executor(self.place) + res, = exe.run(feed={"x": self.input_x, + "y": 1.0}, + fetch_list=[out]) + self.real_result = np.array([1, 0, 0, 0]).astype(np.int64) + self.assertEqual((res == self.real_result).all(), True) + def test_dynamic_api(self): paddle.disable_static() x = paddle.to_tensor(self.input_x) @@ -105,6 +120,47 @@ def create_paddle_case(op_type, callback): self.assertEqual((out.numpy() == self.real_result).all(), True) paddle.enable_static() + def test_dynamic_api_int(self): + if self.op_type == "equal": + paddle.disable_static() + x = paddle.to_tensor(self.input_x) + op = eval("paddle.%s" % (self.op_type)) + out = op(x, 1) + self.real_result = np.array([1, 0, 0, 0]).astype(np.int64) + self.assertEqual((out.numpy() == self.real_result).all(), True) + paddle.enable_static() + + def test_dynamic_api_float(self): + if self.op_type == "equal": + paddle.disable_static() + x = paddle.to_tensor(self.input_x) + op = eval("paddle.%s" % (self.op_type)) + out = op(x, 1.0) + self.real_result = np.array([1, 0, 0, 0]).astype(np.int64) + self.assertEqual((out.numpy() == self.real_result).all(), True) + paddle.enable_static() + + def test_assert(self): + def test_dynamic_api_string(self): + if self.op_type == "equal": + paddle.disable_static() + x = paddle.to_tensor(self.input_x) + op = eval("paddle.%s" % (self.op_type)) + out = op(x, "1.0") + paddle.enable_static() + + self.assertRaises(TypeError, test_dynamic_api_string) + + def test_dynamic_api_bool(self): + if self.op_type == "equal": + paddle.disable_static() + x = paddle.to_tensor(self.input_x) + op = eval("paddle.%s" % (self.op_type)) + out = op(x, True) + self.real_result = np.array([1, 0, 0, 0]).astype(np.int64) + self.assertEqual((out.numpy() == self.real_result).all(), True) + paddle.enable_static() + def test_broadcast_api_1(self): paddle.enable_static() with program_guard(Program(), Program()): diff --git a/python/paddle/tensor/logic.py b/python/paddle/tensor/logic.py old mode 100644 new mode 100755 index 65ad308875..f944813f8e --- a/python/paddle/tensor/logic.py +++ b/python/paddle/tensor/logic.py @@ -28,6 +28,7 @@ from ..fluid.layers import logical_xor # noqa: F401 from paddle.common_ops_import import core from paddle import _C_ops +from paddle.tensor.creation import full __all__ = [] @@ -174,6 +175,13 @@ def equal(x, y, name=None): result1 = paddle.equal(x, y) print(result1) # result1 = [True False False] """ + if not isinstance(y, (int, bool, float, Variable)): + raise TypeError( + "Type of input args must be float, bool, int or Tensor, but received type {}". + format(type(y))) + if not isinstance(y, Variable): + y = full(shape=[1], dtype=x.dtype, fill_value=y) + if in_dygraph_mode(): return _C_ops.equal(x, y) -- GitLab