diff --git a/python/paddle/fluid/dataloader/dataloader_iter.py b/python/paddle/fluid/dataloader/dataloader_iter.py index 0d7fc17da172c6bb598330e1cf3f727168c0c24d..af60776a3f1c58b42cf0275deba2a2d625a023d8 100644 --- a/python/paddle/fluid/dataloader/dataloader_iter.py +++ b/python/paddle/fluid/dataloader/dataloader_iter.py @@ -276,12 +276,12 @@ class _DataLoaderIterSingleProcess(_DataLoaderIterBase): data = self._reader.read_next_list() for i in range(len(data)): data[i] = data[i]._move_to_list() - data = [ - _restore_batch(d, s) for d, s in zip( - data, self._structure_infos[:len(self._places)]) + structs = [ + self._structure_infos.pop(0) + for _ in range(len(self._places)) ] - self._structure_infos = self._structure_infos[ - len(self._places):] + data = [_restore_batch(d, s) \ + for d, s in zip(data, structs)] # static graph organized data on multi-device with list, if # place number is 1, there is only 1 device, extra the data # from list for devices to be compatible with dygraph mode @@ -750,12 +750,12 @@ class _DataLoaderIterMultiProcess(_DataLoaderIterBase): data = self._reader.read_next_list() for i in range(len(data)): data[i] = data[i]._move_to_list() - data = [ - _restore_batch(d, s) for d, s in zip( - data, self._structure_infos[:len(self._places)]) + structs = [ + self._structure_infos.pop(0) + for _ in range(len(self._places)) ] - self._structure_infos = self._structure_infos[ - len(self._places):] + data = [_restore_batch(d, s) \ + for d, s in zip(data, structs)] # static graph organized data on multi-device with list, if # place number is 1, there is only 1 device, extra the data # from list for devices to be compatible with dygraph mode diff --git a/python/paddle/fluid/tests/unittests/test_multiprocess_dataloader_static.py b/python/paddle/fluid/tests/unittests/test_multiprocess_dataloader_static.py index 4da22817be296ee6ef39a2bc54f4a60c11c53293..08e7f8502dccf85ebdc7a1bc3a44a150b2147c6e 100644 --- a/python/paddle/fluid/tests/unittests/test_multiprocess_dataloader_static.py +++ b/python/paddle/fluid/tests/unittests/test_multiprocess_dataloader_static.py @@ -22,6 +22,7 @@ import unittest import multiprocessing import numpy as np +import paddle import paddle.fluid as fluid from paddle.io import Dataset, BatchSampler, DataLoader @@ -182,7 +183,7 @@ class TestStaticDataLoader(unittest.TestCase): class TestStaticDataLoaderReturnList(unittest.TestCase): - def test_single_place(self): + def run_single_place(self, num_workers): scope = fluid.Scope() image = fluid.data(name='image', shape=[None, IMAGE_SIZE], @@ -192,7 +193,7 @@ class TestStaticDataLoaderReturnList(unittest.TestCase): dataset = RandomDataset(SAMPLE_NUM, CLASS_NUM) dataloader = DataLoader(dataset, feed_list=[image, label], - num_workers=0, + num_workers=num_workers, batch_size=BATCH_SIZE, drop_last=True, return_list=True) @@ -203,7 +204,7 @@ class TestStaticDataLoaderReturnList(unittest.TestCase): assert not isinstance(d[0], list) assert not isinstance(d[1], list) - def test_multi_place(self): + def run_multi_place(self, num_workers): scope = fluid.Scope() image = fluid.data(name='image', shape=[None, IMAGE_SIZE], @@ -213,7 +214,7 @@ class TestStaticDataLoaderReturnList(unittest.TestCase): dataset = RandomDataset(SAMPLE_NUM, CLASS_NUM) dataloader = DataLoader(dataset, feed_list=[image, label], - num_workers=0, + num_workers=num_workers, batch_size=BATCH_SIZE, places=[fluid.CPUPlace()] * 2, drop_last=True, @@ -225,6 +226,12 @@ class TestStaticDataLoaderReturnList(unittest.TestCase): assert isinstance(d[0], list) assert isinstance(d[1], list) + def test_main(self): + paddle.enable_static() + for num_workers in [0, 2]: + self.run_single_place(num_workers) + self.run_multi_place(num_workers) + class RandomBatchedDataset(Dataset):