From 91bab752a94f0e67ea5817287e61e9675541f500 Mon Sep 17 00:00:00 2001 From: Kaipeng Deng Date: Fri, 20 Nov 2020 20:39:52 +0800 Subject: [PATCH] fix dataloader default value and doc (#28728) * fix dataloader. test=develop --- python/paddle/fluid/reader.py | 10 +++++----- ..._multiprocess_dataloader_iterable_dataset_static.py | 2 ++ .../unittests/test_multiprocess_dataloader_static.py | 5 ++++- 3 files changed, 11 insertions(+), 6 deletions(-) diff --git a/python/paddle/fluid/reader.py b/python/paddle/fluid/reader.py index 4a50b3bc0c7..ac924580e17 100644 --- a/python/paddle/fluid/reader.py +++ b/python/paddle/fluid/reader.py @@ -196,7 +196,7 @@ class DataLoader(object): the key of the dict is the name of each fed variables. If :attr:`return_list=True`, the return value on each device would be a list(Tensor). :attr:`return_list` can only be True - in dynamic graph mode. Default False. + in dynamic graph mode. Default True. batch_sampler(BatchSampler): an instance of `paddle.io.BatchSampler` to generate batch indices to draw samples from :attr:`dataset` and combine a batch. Default None. @@ -308,7 +308,7 @@ class DataLoader(object): dataset, feed_list=None, places=None, - return_list=False, + return_list=True, batch_sampler=None, batch_size=1, shuffle=False, @@ -403,10 +403,10 @@ class DataLoader(object): if self.dataset_kind == _DatasetKind.ITER: raise ValueError("length of IterableDataset not supported") else: - if self.batch_size is None: - return len(self.dataset) - else: + if self.auto_collate_batch: return len(self.batch_sampler) + else: + return len(self.dataset) def __iter__(self): if self.num_workers == 0: diff --git a/python/paddle/fluid/tests/unittests/test_multiprocess_dataloader_iterable_dataset_static.py b/python/paddle/fluid/tests/unittests/test_multiprocess_dataloader_iterable_dataset_static.py index 4615bf85ce6..fe66f173354 100644 --- a/python/paddle/fluid/tests/unittests/test_multiprocess_dataloader_iterable_dataset_static.py +++ b/python/paddle/fluid/tests/unittests/test_multiprocess_dataloader_iterable_dataset_static.py @@ -112,6 +112,7 @@ class TestStaticDataLoader(unittest.TestCase): places=places, num_workers=num_workers, batch_size=BATCH_SIZE, + return_list=False, drop_last=True) # assert len(dataloader) == int(SAMPLE_NUM / BATCH_SIZE) @@ -199,6 +200,7 @@ class TestStaticDataLoaderWithBatchedDataset(TestStaticDataLoader): places=places, num_workers=num_workers, batch_size=None, + return_list=False, drop_last=True) exe = fluid.Executor(place=places[0]) diff --git a/python/paddle/fluid/tests/unittests/test_multiprocess_dataloader_static.py b/python/paddle/fluid/tests/unittests/test_multiprocess_dataloader_static.py index 5ec907c290b..8fd250f2a52 100644 --- a/python/paddle/fluid/tests/unittests/test_multiprocess_dataloader_static.py +++ b/python/paddle/fluid/tests/unittests/test_multiprocess_dataloader_static.py @@ -113,6 +113,7 @@ class TestStaticDataLoader(unittest.TestCase): places=places, num_workers=num_workers, batch_size=BATCH_SIZE, + return_list=False, drop_last=True) assert len(dataloader) == int(SAMPLE_NUM / BATCH_SIZE) @@ -226,7 +227,8 @@ class RandomBatchedDataset(Dataset): labels = [] for _ in range(BATCH_SIZE): image = np.random.random([IMAGE_SIZE]).astype('float32') - label = np.random.randint(0, self.class_num - 1, (1, )).astype('int64') + label = np.random.randint(0, self.class_num - 1, + (1, )).astype('int64') images.append(image) labels.append(label) return np.stack(images, axis=0), np.stack(labels, axis=0) @@ -248,6 +250,7 @@ class TestStaticDataLoaderWithBatchedDataset(TestStaticDataLoader): places=places, num_workers=num_workers, batch_size=None, + return_list=False, drop_last=True) assert len(dataloader) == int(SAMPLE_NUM / BATCH_SIZE) -- GitLab