未验证 提交 a25418b3 编写于 作者: L Leo Chen 提交者: GitHub

Set dtype of compare op's output to bool (#21864)

* add unittests, test=develop

* set dtype of compare op to bool, test=develop
上级 934d9986
......@@ -27,6 +27,8 @@ _supported_int_dtype_ = [
core.VarDesc.VarType.INT64,
]
compare_ops = ['__eq__', '__ne__', '__lt__', '__le__', '__gt__', '__ge__']
def monkey_patch_variable():
def unique_tmp_name():
......@@ -224,7 +226,12 @@ def monkey_patch_variable():
self = other_var
other_var = tmp
out = create_new_tmp_var(current_block(self), dtype=lhs_dtype)
# NOTE(zhiqiu): the output of compare operator should be bool.
if method_name in compare_ops:
out = create_new_tmp_var(current_block(self), dtype="bool")
else:
out = create_new_tmp_var(current_block(self), dtype=lhs_dtype)
axis = -1
if other_var.shape[0] == -1:
axis = 0
......
......@@ -200,6 +200,46 @@ class TestMathOpPatches(unittest.TestCase):
b_np_actual = (a_np / 7).astype('int64')
self.assertTrue(numpy.array_equal(b_np, b_np_actual))
@prog_scope()
def test_equal(self):
a = fluid.layers.data(name="a", shape=[1], dtype='float32')
b = fluid.layers.data(name="b", shape=[1], dtype='float32')
c = (a == b)
place = fluid.CPUPlace()
exe = fluid.Executor(place)
a_np = numpy.array([3, 4, 10, 14, 9, 18]).astype('float32')
b_np = numpy.array([3, 4, 11, 15, 8, 18]).astype('float32')
c_np, = exe.run(fluid.default_main_program(),
feed={"a": a_np,
"b": b_np},
fetch_list=[c])
self.assertTrue(numpy.array_equal(c_np, a_np == b_np))
self.assertEqual(c.dtype, fluid.core.VarDesc.VarType.BOOL)
@prog_scope()
def test_equal_and_cond(self):
a = fluid.layers.data(name="a", shape=[1], dtype='float32')
b = fluid.layers.data(name="b", shape=[1], dtype='float32')
one = fluid.layers.ones(shape=[1], dtype='int32')
zero = fluid.layers.zeros(shape=[1], dtype='int32')
cond = (one == zero)
c = fluid.layers.cond(cond, lambda: a + b, lambda: a - b)
place = fluid.CPUPlace()
exe = fluid.Executor(place)
a_np = numpy.array([3, 4, 10, 14, 9, 18]).astype('float')
b_np = numpy.array([3, 4, 11, 15, 8, 18]).astype('float')
c_np, = exe.run(fluid.default_main_program(),
feed={"a": a_np,
"b": b_np},
fetch_list=[c])
self.assertTrue(numpy.array_equal(c_np, a_np - b_np))
@prog_scope()
def test_neg(self):
a = fluid.layers.data(name="a", shape=[10, 1])
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册