未验证 提交 472c9085 编写于 作者: A Aurelius84 提交者: GitHub

[@to_static]Enhance error msg in paddle.assign in static mode (#38069)

* Enhance error msg in paddle.assign in static mode

* fix unittest
上级 8f800dc0
......@@ -616,6 +616,11 @@ def assign(input, output=None):
helper.append_op(
type='assign', inputs={'X': [input]}, outputs={'Out': [output]})
elif isinstance(input, numpy.ndarray):
# Not support [var, var, ...] currently.
if len(input.shape) > 0 and any(isinstance(x, Variable) for x in input):
raise TypeError(
"Required type(input) numpy.ndarray, but found `list(Variable)` in input."
)
dtype = convert_np_dtype_to_dtype_(input.dtype)
if dtype == VarDesc.VarType.FP64:
# Setting FP64 numpy data is not supported in Paddle, so we
......
......@@ -181,6 +181,13 @@ class TestAssignOpErrorApi(unittest.TestCase):
x2 = np.array([[2.5, 2.5]], dtype='uint8')
self.assertRaises(TypeError, paddle.assign, x2)
def test_type_error(self):
paddle.enable_static()
with program_guard(Program(), Program()):
x = [paddle.randn([3, 3]), paddle.randn([3, 3])]
# not support to assign list(var)
self.assertRaises(TypeError, paddle.assign, x)
if __name__ == '__main__':
paddle.enable_static()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册