未验证 提交 7da4455f 编写于 作者: W wawltor 提交者: GitHub

support the bool tensor and scalar (#32272)

上级 3a804a0e
......@@ -29,6 +29,7 @@ _supported_int_dtype_ = [
core.VarDesc.VarType.INT16,
core.VarDesc.VarType.INT32,
core.VarDesc.VarType.INT64,
core.VarDesc.VarType.BOOL,
]
# NOTE(chenweihang): We currently do not fully support the type promotion
......
......@@ -22,6 +22,7 @@ from ..framework import Variable, unique_name
from .layer_function_generator import OpProtoHolder
_supported_int_dtype_ = [
core.VarDesc.VarType.BOOL,
core.VarDesc.VarType.UINT8,
core.VarDesc.VarType.INT8,
core.VarDesc.VarType.INT16,
......
......@@ -520,6 +520,23 @@ class TestRealComplexElementwiseAddOp(TestComplexElementwiseAddOp):
self.grad_y = self.grad_out
class TestBoolAddFloatElementwiseAddop(unittest.TestCase):
def test_static_add(self):
paddle.enable_static()
a = 1.5
b = paddle.full([4, 5, 6], True, dtype='bool')
c = a + b
self.assertTrue(c.dtype == core.VarDesc.VarType.FP32)
paddle.enable_static()
def test_dygraph_add(self):
paddle.disable_static()
a = 1.5
b = paddle.full([4, 5, 6], True, dtype='bool')
c = a + b
self.assertTrue(c.dtype == core.VarDesc.VarType.FP32)
if __name__ == '__main__':
paddle.enable_static()
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册