From a25418b32c617f59fb3a797b4d72c9b9dbde12b3 Mon Sep 17 00:00:00 2001 From: Leo Chen Date: Mon, 23 Dec 2019 10:42:51 +0800 Subject: [PATCH] Set dtype of compare op's output to bool (#21864) * add unittests, test=develop * set dtype of compare op to bool, test=develop --- python/paddle/fluid/layers/math_op_patch.py | 9 ++++- .../tests/unittests/test_math_op_patch.py | 40 +++++++++++++++++++ 2 files changed, 48 insertions(+), 1 deletion(-) diff --git a/python/paddle/fluid/layers/math_op_patch.py b/python/paddle/fluid/layers/math_op_patch.py index 3a164a2fd0c..eefb1326472 100644 --- a/python/paddle/fluid/layers/math_op_patch.py +++ b/python/paddle/fluid/layers/math_op_patch.py @@ -27,6 +27,8 @@ _supported_int_dtype_ = [ core.VarDesc.VarType.INT64, ] +compare_ops = ['__eq__', '__ne__', '__lt__', '__le__', '__gt__', '__ge__'] + def monkey_patch_variable(): def unique_tmp_name(): @@ -224,7 +226,12 @@ def monkey_patch_variable(): self = other_var other_var = tmp - out = create_new_tmp_var(current_block(self), dtype=lhs_dtype) + # NOTE(zhiqiu): the output of compare operator should be bool. + if method_name in compare_ops: + out = create_new_tmp_var(current_block(self), dtype="bool") + else: + out = create_new_tmp_var(current_block(self), dtype=lhs_dtype) + axis = -1 if other_var.shape[0] == -1: axis = 0 diff --git a/python/paddle/fluid/tests/unittests/test_math_op_patch.py b/python/paddle/fluid/tests/unittests/test_math_op_patch.py index 62944d0c879..f6eff22d6ce 100644 --- a/python/paddle/fluid/tests/unittests/test_math_op_patch.py +++ b/python/paddle/fluid/tests/unittests/test_math_op_patch.py @@ -200,6 +200,46 @@ class TestMathOpPatches(unittest.TestCase): b_np_actual = (a_np / 7).astype('int64') self.assertTrue(numpy.array_equal(b_np, b_np_actual)) + @prog_scope() + def test_equal(self): + a = fluid.layers.data(name="a", shape=[1], dtype='float32') + b = fluid.layers.data(name="b", shape=[1], dtype='float32') + c = (a == b) + + place = fluid.CPUPlace() + exe = fluid.Executor(place) + a_np = numpy.array([3, 4, 10, 14, 9, 18]).astype('float32') + b_np = numpy.array([3, 4, 11, 15, 8, 18]).astype('float32') + + c_np, = exe.run(fluid.default_main_program(), + feed={"a": a_np, + "b": b_np}, + fetch_list=[c]) + + self.assertTrue(numpy.array_equal(c_np, a_np == b_np)) + self.assertEqual(c.dtype, fluid.core.VarDesc.VarType.BOOL) + + @prog_scope() + def test_equal_and_cond(self): + a = fluid.layers.data(name="a", shape=[1], dtype='float32') + b = fluid.layers.data(name="b", shape=[1], dtype='float32') + + one = fluid.layers.ones(shape=[1], dtype='int32') + zero = fluid.layers.zeros(shape=[1], dtype='int32') + cond = (one == zero) + c = fluid.layers.cond(cond, lambda: a + b, lambda: a - b) + + place = fluid.CPUPlace() + exe = fluid.Executor(place) + a_np = numpy.array([3, 4, 10, 14, 9, 18]).astype('float') + b_np = numpy.array([3, 4, 11, 15, 8, 18]).astype('float') + c_np, = exe.run(fluid.default_main_program(), + feed={"a": a_np, + "b": b_np}, + fetch_list=[c]) + + self.assertTrue(numpy.array_equal(c_np, a_np - b_np)) + @prog_scope() def test_neg(self): a = fluid.layers.data(name="a", shape=[10, 1]) -- GitLab