未验证 提交 c7ae6c62 编写于 作者: G Guo Sheng 提交者: GitHub

Fix the assign data check (#20564)

* Fix the assign data check. test=develop

* Fix test_assign_op.py. test=develop
上级 7faa3e95
...@@ -455,12 +455,12 @@ def assign(input, output=None): ...@@ -455,12 +455,12 @@ def assign(input, output=None):
helper = LayerHelper('assign', **locals()) helper = LayerHelper('assign', **locals())
if isinstance(input, Variable): if isinstance(input, Variable):
if convert_dtype(input.dtype) not in [ if convert_dtype(input.dtype) not in [
'float32', 'float64', 'int32', 'int64' 'float32', 'float64', 'int32', 'int64', 'bool'
]: ]:
raise TypeError( raise TypeError(
"When the type of 'input' in assign is Variable, the data " "When the type of 'input' in assign is Variable, the data "
"type of 'input' must be float32, float64, int32 or int64, " "type of 'input' must be float32, float64, int32, int64 or "
"but received %s." % convert_dtype(input.dtype)) "bool, but received %s." % convert_dtype(input.dtype))
if output is None: if output is None:
output = helper.create_variable_for_type_inference( output = helper.create_variable_for_type_inference(
dtype=input.dtype) dtype=input.dtype)
......
...@@ -44,9 +44,7 @@ class TestAssignOpError(op_test.OpTest): ...@@ -44,9 +44,7 @@ class TestAssignOpError(op_test.OpTest):
x1 = fluid.create_lod_tensor( x1 = fluid.create_lod_tensor(
np.array([[-1]]), [[1]], fluid.CPUPlace()) np.array([[-1]]), [[1]], fluid.CPUPlace())
self.assertRaises(TypeError, fluid.layers.assign, x1) self.assertRaises(TypeError, fluid.layers.assign, x1)
# When the type of input is Variable, the dtype of input must be float32, float64, int32, int64. # When the type of input is Variable, the dtype of input must be float32, float64, int32, int64, bool.
x2 = fluid.layers.data(name='x2', shape=[4], dtype="bool")
self.assertRaises(TypeError, fluid.layers.assign, x2)
x3 = fluid.layers.data(name='x3', shape=[4], dtype="float16") x3 = fluid.layers.data(name='x3', shape=[4], dtype="float16")
self.assertRaises(TypeError, fluid.layers.assign, x3) self.assertRaises(TypeError, fluid.layers.assign, x3)
x4 = fluid.layers.data(name='x4', shape=[4], dtype="uint8") x4 = fluid.layers.data(name='x4', shape=[4], dtype="uint8")
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册