From 4671d85a03a6ff501f32775ef88b8e0468ffb2a6 Mon Sep 17 00:00:00 2001 From: Kaipeng Deng Date: Mon, 26 Oct 2020 20:15:23 +0800 Subject: [PATCH] fix DataLoader return same format between static & dynamic in single mode (#28176) * fix DataLoader return same format between static & dynamic in single mode. test=develop --- .../fluid/dataloader/dataloader_iter.py | 8 +++- .../test_multiprocess_dataloader_static.py | 45 +++++++++++++++++++ 2 files changed, 52 insertions(+), 1 deletion(-) diff --git a/python/paddle/fluid/dataloader/dataloader_iter.py b/python/paddle/fluid/dataloader/dataloader_iter.py index 7d203b349a1..d32a543eb49 100644 --- a/python/paddle/fluid/dataloader/dataloader_iter.py +++ b/python/paddle/fluid/dataloader/dataloader_iter.py @@ -341,7 +341,13 @@ class _DataLoaderIterSingleProcess(_DataLoaderIterBase): return self._reader.read_next_var_list() else: if self._return_list: - return self._reader.read_next_list() + # static graph organized data on multi-device with list, if + # place number is 1, there is only 1 device, extra the data + # from list for devices to be compatible with dygraph mode + if len(self._places) == 1: + return self._reader.read_next_list()[0] + else: + return self._reader.read_next_list() else: return self._reader.read_next() except StopIteration: diff --git a/python/paddle/fluid/tests/unittests/test_multiprocess_dataloader_static.py b/python/paddle/fluid/tests/unittests/test_multiprocess_dataloader_static.py index 38497f91fc1..c01e2e75b81 100644 --- a/python/paddle/fluid/tests/unittests/test_multiprocess_dataloader_static.py +++ b/python/paddle/fluid/tests/unittests/test_multiprocess_dataloader_static.py @@ -170,5 +170,50 @@ class TestStaticDataLoader(unittest.TestCase): self.assertLess(diff, 1e-2) +class TestStaticDataLoaderReturnList(unittest.TestCase): + def test_single_place(self): + scope = fluid.Scope() + image = fluid.data( + name='image', shape=[None, IMAGE_SIZE], dtype='float32') + label = fluid.data(name='label', shape=[None, 1], dtype='int64') + with fluid.scope_guard(scope): + dataset = RandomDataset(SAMPLE_NUM, CLASS_NUM) + dataloader = DataLoader( + dataset, + feed_list=[image, label], + num_workers=0, + batch_size=BATCH_SIZE, + drop_last=True, + return_list=True) + + for d in dataloader: + assert isinstance(d, list) + assert len(d) == 2 + assert not isinstance(d[0], list) + assert not isinstance(d[1], list) + + def test_multi_place(self): + scope = fluid.Scope() + image = fluid.data( + name='image', shape=[None, IMAGE_SIZE], dtype='float32') + label = fluid.data(name='label', shape=[None, 1], dtype='int64') + with fluid.scope_guard(scope): + dataset = RandomDataset(SAMPLE_NUM, CLASS_NUM) + dataloader = DataLoader( + dataset, + feed_list=[image, label], + num_workers=0, + batch_size=BATCH_SIZE, + places=[fluid.CPUPlace()] * 2, + drop_last=True, + return_list=True) + + for d in dataloader: + assert isinstance(d, list) + assert len(d) == 2 + assert isinstance(d[0], list) + assert isinstance(d[1], list) + + if __name__ == '__main__': unittest.main() -- GitLab