diff --git a/python/paddle/fluid/reader.py b/python/paddle/fluid/reader.py index 4a50b3bc0c7dc581febdaf12ab70c663c0263377..ac924580e17e8885346cab51c696dddb43279ba6 100644 --- a/python/paddle/fluid/reader.py +++ b/python/paddle/fluid/reader.py @@ -196,7 +196,7 @@ class DataLoader(object): the key of the dict is the name of each fed variables. If :attr:`return_list=True`, the return value on each device would be a list(Tensor). :attr:`return_list` can only be True - in dynamic graph mode. Default False. + in dynamic graph mode. Default True. batch_sampler(BatchSampler): an instance of `paddle.io.BatchSampler` to generate batch indices to draw samples from :attr:`dataset` and combine a batch. Default None. @@ -308,7 +308,7 @@ class DataLoader(object): dataset, feed_list=None, places=None, - return_list=False, + return_list=True, batch_sampler=None, batch_size=1, shuffle=False, @@ -403,10 +403,10 @@ class DataLoader(object): if self.dataset_kind == _DatasetKind.ITER: raise ValueError("length of IterableDataset not supported") else: - if self.batch_size is None: - return len(self.dataset) - else: + if self.auto_collate_batch: return len(self.batch_sampler) + else: + return len(self.dataset) def __iter__(self): if self.num_workers == 0: diff --git a/python/paddle/fluid/tests/unittests/test_multiprocess_dataloader_iterable_dataset_static.py b/python/paddle/fluid/tests/unittests/test_multiprocess_dataloader_iterable_dataset_static.py index 4615bf85ce69f45cf6d852cdb17e5dff1d533979..fe66f1733546bd3e9296f3355982ae5fa633a7aa 100644 --- a/python/paddle/fluid/tests/unittests/test_multiprocess_dataloader_iterable_dataset_static.py +++ b/python/paddle/fluid/tests/unittests/test_multiprocess_dataloader_iterable_dataset_static.py @@ -112,6 +112,7 @@ class TestStaticDataLoader(unittest.TestCase): places=places, num_workers=num_workers, batch_size=BATCH_SIZE, + return_list=False, drop_last=True) # assert len(dataloader) == int(SAMPLE_NUM / BATCH_SIZE) @@ -199,6 +200,7 @@ class TestStaticDataLoaderWithBatchedDataset(TestStaticDataLoader): places=places, num_workers=num_workers, batch_size=None, + return_list=False, drop_last=True) exe = fluid.Executor(place=places[0]) diff --git a/python/paddle/fluid/tests/unittests/test_multiprocess_dataloader_static.py b/python/paddle/fluid/tests/unittests/test_multiprocess_dataloader_static.py index 5ec907c290b946dc79ee6f865c42fd8094afbb35..8fd250f2a52c27e2e2fa6a6a6c917e55af292d6c 100644 --- a/python/paddle/fluid/tests/unittests/test_multiprocess_dataloader_static.py +++ b/python/paddle/fluid/tests/unittests/test_multiprocess_dataloader_static.py @@ -113,6 +113,7 @@ class TestStaticDataLoader(unittest.TestCase): places=places, num_workers=num_workers, batch_size=BATCH_SIZE, + return_list=False, drop_last=True) assert len(dataloader) == int(SAMPLE_NUM / BATCH_SIZE) @@ -226,7 +227,8 @@ class RandomBatchedDataset(Dataset): labels = [] for _ in range(BATCH_SIZE): image = np.random.random([IMAGE_SIZE]).astype('float32') - label = np.random.randint(0, self.class_num - 1, (1, )).astype('int64') + label = np.random.randint(0, self.class_num - 1, + (1, )).astype('int64') images.append(image) labels.append(label) return np.stack(images, axis=0), np.stack(labels, axis=0) @@ -248,6 +250,7 @@ class TestStaticDataLoaderWithBatchedDataset(TestStaticDataLoader): places=places, num_workers=num_workers, batch_size=None, + return_list=False, drop_last=True) assert len(dataloader) == int(SAMPLE_NUM / BATCH_SIZE)