未验证 提交 c7ae6c62 编写于 作者: G Guo Sheng 提交者: GitHub

Fix the assign data check (#20564)

* Fix the assign data check. test=develop

* Fix test_assign_op.py. test=develop
上级 7faa3e95
......@@ -455,12 +455,12 @@ def assign(input, output=None):
helper = LayerHelper('assign', **locals())
if isinstance(input, Variable):
if convert_dtype(input.dtype) not in [
'float32', 'float64', 'int32', 'int64'
'float32', 'float64', 'int32', 'int64', 'bool'
]:
raise TypeError(
"When the type of 'input' in assign is Variable, the data "
"type of 'input' must be float32, float64, int32 or int64, "
"but received %s." % convert_dtype(input.dtype))
"type of 'input' must be float32, float64, int32, int64 or "
"bool, but received %s." % convert_dtype(input.dtype))
if output is None:
output = helper.create_variable_for_type_inference(
dtype=input.dtype)
......
......@@ -44,9 +44,7 @@ class TestAssignOpError(op_test.OpTest):
x1 = fluid.create_lod_tensor(
np.array([[-1]]), [[1]], fluid.CPUPlace())
self.assertRaises(TypeError, fluid.layers.assign, x1)
# When the type of input is Variable, the dtype of input must be float32, float64, int32, int64.
x2 = fluid.layers.data(name='x2', shape=[4], dtype="bool")
self.assertRaises(TypeError, fluid.layers.assign, x2)
# When the type of input is Variable, the dtype of input must be float32, float64, int32, int64, bool.
x3 = fluid.layers.data(name='x3', shape=[4], dtype="float16")
self.assertRaises(TypeError, fluid.layers.assign, x3)
x4 = fluid.layers.data(name='x4', shape=[4], dtype="uint8")
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册