diff --git a/python/paddle/fluid/dygraph/math_op_patch.py b/python/paddle/fluid/dygraph/math_op_patch.py index 1df3e31ae4b269e8afdb65f97235aba7a3c4b549..41cce6a0858a6e43b5c78b02d1c266a3df7e258d 100644 --- a/python/paddle/fluid/dygraph/math_op_patch.py +++ b/python/paddle/fluid/dygraph/math_op_patch.py @@ -29,6 +29,7 @@ _supported_int_dtype_ = [ core.VarDesc.VarType.INT16, core.VarDesc.VarType.INT32, core.VarDesc.VarType.INT64, + core.VarDesc.VarType.BOOL, ] # NOTE(chenweihang): We currently do not fully support the type promotion diff --git a/python/paddle/fluid/layers/math_op_patch.py b/python/paddle/fluid/layers/math_op_patch.py index 96947bf72c7ddf299a4f4b372be3d62de4aaa1b5..a68331b156b3bf2f2bc4f18471ce97f59c5d67f4 100644 --- a/python/paddle/fluid/layers/math_op_patch.py +++ b/python/paddle/fluid/layers/math_op_patch.py @@ -22,6 +22,7 @@ from ..framework import Variable, unique_name from .layer_function_generator import OpProtoHolder _supported_int_dtype_ = [ + core.VarDesc.VarType.BOOL, core.VarDesc.VarType.UINT8, core.VarDesc.VarType.INT8, core.VarDesc.VarType.INT16, diff --git a/python/paddle/fluid/tests/unittests/test_elementwise_add_op.py b/python/paddle/fluid/tests/unittests/test_elementwise_add_op.py index fde7ea4b23801ed8b07ea72e078ed7646ec02aa7..cc362005f331193c367128e079ee8805113951f8 100644 --- a/python/paddle/fluid/tests/unittests/test_elementwise_add_op.py +++ b/python/paddle/fluid/tests/unittests/test_elementwise_add_op.py @@ -520,6 +520,23 @@ class TestRealComplexElementwiseAddOp(TestComplexElementwiseAddOp): self.grad_y = self.grad_out +class TestBoolAddFloatElementwiseAddop(unittest.TestCase): + def test_static_add(self): + paddle.enable_static() + a = 1.5 + b = paddle.full([4, 5, 6], True, dtype='bool') + c = a + b + self.assertTrue(c.dtype == core.VarDesc.VarType.FP32) + paddle.enable_static() + + def test_dygraph_add(self): + paddle.disable_static() + a = 1.5 + b = paddle.full([4, 5, 6], True, dtype='bool') + c = a + b + self.assertTrue(c.dtype == core.VarDesc.VarType.FP32) + + if __name__ == '__main__': paddle.enable_static() unittest.main()