未验证 提交 bda7eab7 编写于 作者: L liuwei1031 提交者: GitHub

improve the error message when handling ndarray with unsupported dtype (#19949)

* impove error message when passing ndarray with object dtype

* imporve message format

* change assert to raise TypeError

* remind user how to locate the irregular data instead of printing

* add unittest for input array type check
上级 1d32897c
......@@ -451,6 +451,17 @@ class GeneratorLoader(DataLoaderBase):
assert not self._iterable, "reset() cannot be called when DataLoader is iterable"
self._reset()
@classmethod
def _check_input_array(cls, item):
arr = np.array(item)
if arr.dtype == np.object:
raise TypeError((
"\n\tFaild to convert input data to a regular ndarray :\n\t* Usually "
"this means the input data contains nested lists with different lengths. "
"\n\t* Check the reader function passed to 'decorate_batch_generator'"
" to locate the data causes this issue.\n\t* Please consider using "
"'fluid.create_lod_tensor' to convert it to a LoD-Tensor."))
def _start(self):
def __thread_main__():
try:
......@@ -458,6 +469,7 @@ class GeneratorLoader(DataLoaderBase):
array = core.LoDTensorArray()
for item in tensors:
if not isinstance(item, core.LoDTensor):
self._check_input_array(item)
tmp = core.LoDTensor()
tmp.set(item, core.CPUPlace())
item = tmp
......
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle.fluid as fluid
import unittest
import numpy as np
class TestPyReaderErrorMsg(unittest.TestCase):
def test_check_input_array(self):
fluid.reader.GeneratorLoader._check_input_array([
np.random.randint(
100, size=[2]), np.random.randint(
100, size=[2]), np.random.randint(
100, size=[2])
])
self.assertRaises(
TypeError,
fluid.reader.GeneratorLoader._check_input_array, [
np.random.randint(
100, size=[2]), np.random.randint(
100, size=[1]), np.random.randint(
100, size=[3])
])
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册