未验证 提交 1ac6a093 编写于 作者: K Kaipeng Deng 提交者: GitHub

[cherry pick] dataloader fix (#31028)

* fix dataloader collate return list mix tensor and numpy array. test=develop

* remove numpy array check in single-process dataloader. test=develop
上级 9ad9f357
...@@ -320,7 +320,6 @@ class _DataLoaderIterSingleProcess(_DataLoaderIterBase): ...@@ -320,7 +320,6 @@ class _DataLoaderIterSingleProcess(_DataLoaderIterBase):
array = core.LoDTensorArray() array = core.LoDTensorArray()
for slot in batch: for slot in batch:
if not isinstance(slot, core.LoDTensor): if not isinstance(slot, core.LoDTensor):
self._check_input_array(slot)
# FIXME(dkp): blocking_queue only support # FIXME(dkp): blocking_queue only support
# core.LoDTensorArray as input now, read # core.LoDTensorArray as input now, read
# numpy data into a LoDTensorArray here, # numpy data into a LoDTensorArray here,
...@@ -346,19 +345,6 @@ class _DataLoaderIterSingleProcess(_DataLoaderIterBase): ...@@ -346,19 +345,6 @@ class _DataLoaderIterSingleProcess(_DataLoaderIterBase):
logging.warning("DataLoader reader thread raised an exception.") logging.warning("DataLoader reader thread raised an exception.")
six.reraise(*sys.exc_info()) six.reraise(*sys.exc_info())
@classmethod
def _check_input_array(cls, item):
if isinstance(item, paddle.Tensor):
return
arr = np.array(item)
if arr.dtype == np.object:
raise TypeError((
"\n\tFaild to convert input data to a regular ndarray :\n\t* Usually "
"this means the input data contains nested lists with different lengths. "
"\n\t* Check the reader function passed to 'decorate_batch_generator'"
" to locate the data causes this issue.\n\t* Please consider using "
"'fluid.create_lod_tensor' to convert it to a LoD-Tensor."))
def __next__(self): def __next__(self):
try: try:
if in_dygraph_mode(): if in_dygraph_mode():
...@@ -454,11 +440,16 @@ def _worker_loop(dataset, dataset_kind, indices_queue, out_queue, done_event, ...@@ -454,11 +440,16 @@ def _worker_loop(dataset, dataset_kind, indices_queue, out_queue, done_event,
if use_shared_memory: if use_shared_memory:
# FIXME(dkp): _convert_to_tensor_list only support np.array # FIXME(dkp): _convert_to_tensor_list only support np.array
# list now, should support paddle.Tensor list # list now, should support paddle.Tensor list
if isinstance(batch[0][0], paddle.Tensor): new_batch = []
np_batch = [] for sample in batch:
for sample in batch: new_sample = []
np_batch.append([s.numpy() for s in sample]) for s in sample:
batch = np_batch if isinstance(s, paddle.Tensor):
new_sample.append(s.numpy())
else:
new_sample.append(s)
new_batch.append(new_sample)
batch = new_batch
tensor_list = core._convert_to_tensor_list(batch) tensor_list = core._convert_to_tensor_list(batch)
out_queue.put((idx, tensor_list)) out_queue.put((idx, tensor_list))
......
...@@ -142,5 +142,43 @@ class TestChainDataset(unittest.TestCase): ...@@ -142,5 +142,43 @@ class TestChainDataset(unittest.TestCase):
self.run_main(num_workers=0, places=p) self.run_main(num_workers=0, places=p)
class NumpyMixTensorDataset(Dataset):
def __init__(self, sample_num):
self.sample_num = sample_num
def __len__(self):
return self.sample_num
def __getitem__(self, idx):
np.random.seed(idx)
image = np.random.random([IMAGE_SIZE]).astype('float32')
label = np.random.randint(0, 9, (1, )).astype('int64')
return paddle.to_tensor(image, place=paddle.CPUPlace()), label
class TestNumpyMixTensorDataset(TestTensorDataset):
def run_main(self, num_workers, places):
paddle.static.default_startup_program().random_seed = 1
paddle.static.default_main_program().random_seed = 1
place = paddle.CPUPlace()
with fluid.dygraph.guard(place):
dataset = NumpyMixTensorDataset(16)
assert len(dataset) == 16
dataloader = DataLoader(
dataset,
places=place,
num_workers=num_workers,
batch_size=1,
drop_last=True)
for i, (input, label) in enumerate(dataloader()):
assert len(input) == 1
assert len(label) == 1
assert input.shape == [1, IMAGE_SIZE]
assert label.shape == [1, 1]
assert isinstance(input, paddle.Tensor)
assert isinstance(label, paddle.Tensor)
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册