未验证 提交 1c7215ac 编写于 作者: W WangXi 提交者: GitHub

Speedup reader check_input_array when item is array (#25395)

上级 52be62c5
......@@ -97,6 +97,18 @@ class DataLoaderBase(object):
def __next__(self):
raise NotImplementedError()
@classmethod
def _check_input_array(cls, item):
arr = np.asarray(item)
if arr.dtype == np.object:
raise TypeError(
"\n\tFaild to convert input data to a regular ndarray :\n\t* Usually "
"this means the input data contains nested lists with different lengths. "
"\n\t* Check the reader function passed to 'decorate_batch_generator'"
" to locate the data causes this issue.\n\t* Please consider using "
"'fluid.create_lod_tensor' to convert it to a LoD-Tensor.")
return arr
class DataLoader(object):
"""
......@@ -807,17 +819,6 @@ class DygraphGeneratorLoader(DataLoaderBase):
self._reset()
six.reraise(*sys.exc_info())
@classmethod
def _check_input_array(cls, item):
arr = np.array(item)
if arr.dtype == np.object:
raise TypeError(
"\n\tFaild to convert input data to a regular ndarray :\n\t* Usually "
"this means the input data contains nested lists with different lengths. "
"\n\t* Check the reader function passed to 'decorate_batch_generator'"
" to locate the data causes this issue.\n\t* Please consider using "
"'fluid.create_lod_tensor' to convert it to a LoD-Tensor.")
def _exit_thread_expectedly(self):
self._thread_done_event.set()
self._blocking_queue.close()
......@@ -894,7 +895,7 @@ class DygraphGeneratorLoader(DataLoaderBase):
array = core.LoDTensorArray()
for item in sample:
if not isinstance(item, core.LoDTensor):
self._check_input_array(item)
item = self._check_input_array(item)
tmp = core.LoDTensor()
tmp.set(item, core.CPUPlace())
item = tmp
......@@ -1115,19 +1116,6 @@ class GeneratorLoader(DataLoaderBase):
assert not self._iterable, "reset() cannot be called when DataLoader is iterable"
self._reset()
@classmethod
def _check_input_array(cls, item):
arr = np.array(item)
if arr.dtype == np.object:
raise TypeError((
"\n\tFaild to convert input data to a regular ndarray :\n\t* Usually "
"this means the input data contains nested lists with different lengths. "
"\n\t* Check the reader function passed to 'decorate_batch_generator'"
" to locate the data causes this issue.\n\t* Please consider using "
"'fluid.create_lod_tensor' to convert it to a LoD-Tensor."))
return arr
def _start(self):
def __thread_main__():
try:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册