diff --git a/python/paddle/fluid/dataloader/dataloader_iter.py b/python/paddle/fluid/dataloader/dataloader_iter.py index 7a4e3cb6ec80cd6fcdb638d2ac47578a946e3bb0..56c5b989d16d707b1383d27a36fd795e97fe0e15 100644 --- a/python/paddle/fluid/dataloader/dataloader_iter.py +++ b/python/paddle/fluid/dataloader/dataloader_iter.py @@ -341,7 +341,13 @@ class _DataLoaderIterSingleProcess(_DataLoaderIterBase): return self._reader.read_next_var_list() else: if self._return_list: - return self._reader.read_next_list() + # static graph organized data on multi-device with list, if + # place number is 1, there is only 1 device, extra the data + # from list for devices to be compatible with dygraph mode + if len(self._places) == 1: + return self._reader.read_next_list()[0] + else: + return self._reader.read_next_list() else: return self._reader.read_next() except StopIteration: diff --git a/python/paddle/fluid/tests/unittests/test_multiprocess_dataloader_static.py b/python/paddle/fluid/tests/unittests/test_multiprocess_dataloader_static.py index 38497f91fc18847e40efa691a65c2a7adc20e51c..c01e2e75b8195c0b2f2c46a6d18969055a68f977 100644 --- a/python/paddle/fluid/tests/unittests/test_multiprocess_dataloader_static.py +++ b/python/paddle/fluid/tests/unittests/test_multiprocess_dataloader_static.py @@ -170,5 +170,50 @@ class TestStaticDataLoader(unittest.TestCase): self.assertLess(diff, 1e-2) +class TestStaticDataLoaderReturnList(unittest.TestCase): + def test_single_place(self): + scope = fluid.Scope() + image = fluid.data( + name='image', shape=[None, IMAGE_SIZE], dtype='float32') + label = fluid.data(name='label', shape=[None, 1], dtype='int64') + with fluid.scope_guard(scope): + dataset = RandomDataset(SAMPLE_NUM, CLASS_NUM) + dataloader = DataLoader( + dataset, + feed_list=[image, label], + num_workers=0, + batch_size=BATCH_SIZE, + drop_last=True, + return_list=True) + + for d in dataloader: + assert isinstance(d, list) + assert len(d) == 2 + assert not isinstance(d[0], list) + assert not isinstance(d[1], list) + + def test_multi_place(self): + scope = fluid.Scope() + image = fluid.data( + name='image', shape=[None, IMAGE_SIZE], dtype='float32') + label = fluid.data(name='label', shape=[None, 1], dtype='int64') + with fluid.scope_guard(scope): + dataset = RandomDataset(SAMPLE_NUM, CLASS_NUM) + dataloader = DataLoader( + dataset, + feed_list=[image, label], + num_workers=0, + batch_size=BATCH_SIZE, + places=[fluid.CPUPlace()] * 2, + drop_last=True, + return_list=True) + + for d in dataloader: + assert isinstance(d, list) + assert len(d) == 2 + assert isinstance(d[0], list) + assert isinstance(d[1], list) + + if __name__ == '__main__': unittest.main()