未验证 提交 d06d3276 编写于 作者: H hong19860320 提交者: GitHub

[cherry-pick] Fix the error message of assign op (#20508) (#20542)

test=develop, test=release/1.6

* Fix the error message of assign op
test=develop

* Refine input type checking for assign op
test=develop

* Refine unittest of assign op
test=develop
上级 91e353bd
......@@ -400,9 +400,17 @@ def assign(input, output=None):
result3 = fluid.layers.assign(np.array([[2.5, 2.5], [2.5, 2.5], [2.5, 2.5]], dtype='float32')) # result3 = [[2.5, 2.5], [2.5, 2.5], [2.5, 2.5]]
"""
helper = LayerHelper('assign', **locals())
if output is None:
output = helper.create_variable_for_type_inference(dtype=input.dtype)
if isinstance(input, Variable):
if convert_dtype(input.dtype) not in [
'float32', 'float64', 'int32', 'int64'
]:
raise TypeError(
"When the type of 'input' in assign is Variable, the data "
"type of 'input' must be float32, float64, int32 or int64, "
"but received %s." % convert_dtype(input.dtype))
if output is None:
output = helper.create_variable_for_type_inference(
dtype=input.dtype)
helper.append_op(
type='assign', inputs={'X': [input]}, outputs={'Out': [output]})
elif isinstance(input, numpy.ndarray):
......@@ -414,11 +422,16 @@ def assign(input, output=None):
value_name = "int32_values"
values = [int(v) for v in input.flat]
else:
raise ValueError("Unsupported dtype %s", input.dtype)
raise TypeError(
"When the type of 'input' in assign is numpy.ndarray, "
"the data type of 'input' must be float32 or int32, but "
"received %s." % convert_dtype(dtype))
if input.size > 1024 * 1024:
raise ValueError("The size of input is too big. Please consider "
"saving it to file and 'load_op' to load it")
if output is None:
output = helper.create_variable_for_type_inference(
dtype=input.dtype)
helper.append_op(
type='assign_value',
outputs={'Out': [output]},
......@@ -428,7 +441,8 @@ def assign(input, output=None):
value_name: values
})
else:
raise ValueError("Wrong type for assign input: %s" % type(input))
raise TypeError("The type of 'input' in assign must be Variable or "
"numpy.ndarray, but received %s" % type(input))
return output
......
......@@ -15,14 +15,18 @@
from __future__ import print_function
import op_test
import numpy
import numpy as np
import unittest
import paddle.fluid.core as core
from paddle.fluid.op import Operator
import paddle.fluid as fluid
from paddle.fluid import compiler, Program, program_guard
class TestAssignOp(op_test.OpTest):
def setUp(self):
self.op_type = "assign"
x = numpy.random.random(size=(100, 10))
x = np.random.random(size=(100, 10))
self.inputs = {'X': x}
self.outputs = {'Out': x}
......@@ -33,5 +37,32 @@ class TestAssignOp(op_test.OpTest):
self.check_grad(['X'], 'Out')
class TestAssignOpError(op_test.OpTest):
def test_errors(self):
with program_guard(Program(), Program()):
# The type of input must be Variable or numpy.ndarray.
x1 = fluid.create_lod_tensor(
np.array([[-1]]), [[1]], fluid.CPUPlace())
self.assertRaises(TypeError, fluid.layers.assign, x1)
# When the type of input is Variable, the dtype of input must be float32, float64, int32, int64.
x2 = fluid.layers.data(name='x2', shape=[4], dtype="bool")
self.assertRaises(TypeError, fluid.layers.assign, x2)
x3 = fluid.layers.data(name='x3', shape=[4], dtype="float16")
self.assertRaises(TypeError, fluid.layers.assign, x3)
x4 = fluid.layers.data(name='x4', shape=[4], dtype="uint8")
self.assertRaises(TypeError, fluid.layers.assign, x4)
# When the type of input is numpy.ndarray, the dtype of input must be float32, int32.
x5 = np.array([[2.5, 2.5]], dtype='bool')
self.assertRaises(TypeError, fluid.layers.assign, x5)
x6 = np.array([[2.5, 2.5]], dtype='float16')
self.assertRaises(TypeError, fluid.layers.assign, x6)
x7 = np.array([[2.5, 2.5]], dtype='float64')
self.assertRaises(TypeError, fluid.layers.assign, x7)
x8 = np.array([[2.5, 2.5]], dtype='int64')
self.assertRaises(TypeError, fluid.layers.assign, x8)
x9 = np.array([[2.5, 2.5]], dtype='uint8')
self.assertRaises(TypeError, fluid.layers.assign, x9)
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册