diff --git a/python/paddle/fluid/dataloader/dataloader_iter.py b/python/paddle/fluid/dataloader/dataloader_iter.py
index 069dff28ccf9b771cfe6621e79004caa16dcfafb..cc98d378f148949e4c443c4494e3936ed6a34e09 100644
--- a/python/paddle/fluid/dataloader/dataloader_iter.py
+++ b/python/paddle/fluid/dataloader/dataloader_iter.py
@@ -59,6 +59,7 @@ class _DataLoaderIterBase(object):
         self._places = loader.places
         self._return_list = loader.return_list
         self._batch_sampler = loader.batch_sampler
+        self._drop_last = loader.drop_last
         self._auto_collate_batch = loader.auto_collate_batch
         self._num_workers = loader.num_workers
         self._use_buffer_reader = loader.use_buffer_reader
@@ -111,7 +112,7 @@ class _DataLoaderIterSingleProcess(_DataLoaderIterBase):
 
         self._dataset_fetcher = _DatasetKind.create_fetcher(
             self._dataset_kind, self._dataset, self._auto_collate_batch,
-            self._collate_fn, True)
+            self._collate_fn, self._drop_last)
 
         # NOTE: _structrue_infos used to record the data structure of
         #     batch to restore batch structure after reading Tensor
@@ -309,8 +310,8 @@ class _DataLoaderIterMultiProcess(_DataLoaderIterBase):
                 args=(self._dataset, self._dataset_kind, indices_queue,
                       self._data_queue, self._workers_done_event,
                       self._auto_collate_batch, self._collate_fn,
-                      self._worker_init_fn, i, self._num_workers,
-                      self._use_shared_memory))
+                      self._drop_last, self._worker_init_fn, i,
+                      self._num_workers, self._use_shared_memory))
             worker.daemon = True
             worker.start()
             self._workers.append(worker)
diff --git a/python/paddle/fluid/dataloader/worker.py b/python/paddle/fluid/dataloader/worker.py
index 66ca4150460d7485504f955206cceeb551e5725b..622f85cf65ab79945041b4a2b9648c72e876e72b 100644
--- a/python/paddle/fluid/dataloader/worker.py
+++ b/python/paddle/fluid/dataloader/worker.py
@@ -253,7 +253,7 @@ def _generate_states(base_seed=0, worker_id=0):
 
 
 def _worker_loop(dataset, dataset_kind, indices_queue, out_queue, done_event,
-                 auto_collate_batch, collate_fn, init_fn, worker_id,
+                 auto_collate_batch, collate_fn, drop_last, init_fn, worker_id,
                  num_workers, use_shared_memory):
     try:
         # NOTE: [ mmap files clear ] When the child process exits unexpectedly,
@@ -282,8 +282,9 @@ def _worker_loop(dataset, dataset_kind, indices_queue, out_queue, done_event,
         try:
             if init_fn is not None:
                 init_fn(worker_id)
-            fetcher = _DatasetKind.create_fetcher(
-                dataset_kind, dataset, auto_collate_batch, collate_fn, True)
+            fetcher = _DatasetKind.create_fetcher(dataset_kind, dataset,
+                                                  auto_collate_batch,
+                                                  collate_fn, drop_last)
         except:
             init_exception = _WorkerException(worker_id)
 
diff --git a/python/paddle/fluid/reader.py b/python/paddle/fluid/reader.py
index 7076ef22ba605c28605334882810b3497b2e3c09..dfc887292e7cff408a4e87ffcf8e4e90523b1e84 100644
--- a/python/paddle/fluid/reader.py
+++ b/python/paddle/fluid/reader.py
@@ -401,6 +401,7 @@ class DataLoader(object):
                     shuffle=shuffle,
                     drop_last=drop_last)
 
+        self.drop_last = drop_last
         self.auto_collate_batch = self.batch_sampler is not None
 
         self.pin_memory = False
diff --git a/python/paddle/fluid/tests/unittests/test_multiprocess_dataloader_dataset.py b/python/paddle/fluid/tests/unittests/test_multiprocess_dataloader_dataset.py
index 30e70a77c369c1f79f00c7c971b06b1f6bfc4a2d..8f1febcdeddf71a97e6d25104618f15eb0beced6 100755
--- a/python/paddle/fluid/tests/unittests/test_multiprocess_dataloader_dataset.py
+++ b/python/paddle/fluid/tests/unittests/test_multiprocess_dataloader_dataset.py
@@ -397,5 +397,30 @@ class TestDataLoaderGenerateStates(unittest.TestCase):
             assert out == outp
 
 
+class TestDatasetWithDropLast(unittest.TestCase):
+    def run_main(self, dataset, num_samples, batch_size):
+        for num_workers in [0, 1]:
+            for drop_last in [True, False]:
+                steps = (num_samples + (1 - int(drop_last)) * \
+                        (batch_size - 1)) // batch_size
+                dataloader = DataLoader(
+                    dataset,
+                    batch_size=batch_size,
+                    drop_last=drop_last,
+                    num_workers=num_workers)
+                datas = []
+                for data in dataloader:
+                    datas.append(data)
+                assert len(datas) == steps
+
+    def test_map_dataset(self):
+        dataset = RandomDataset(10)
+        self.run_main(dataset, 10, 3)
+
+    def test_iterable_dataset(self):
+        dataset = RandomIterableDataset(10)
+        self.run_main(dataset, 10, 3)
+
+
 if __name__ == '__main__':
     unittest.main()
diff --git a/python/paddle/fluid/tests/unittests/test_multiprocess_dataloader_exception.py b/python/paddle/fluid/tests/unittests/test_multiprocess_dataloader_exception.py
index 1bda6edfecf1c78526d25b1428f33505b3a58b2c..52f4c2567730f58dd7dfab2f7e309d1de2d1e90d 100644
--- a/python/paddle/fluid/tests/unittests/test_multiprocess_dataloader_exception.py
+++ b/python/paddle/fluid/tests/unittests/test_multiprocess_dataloader_exception.py
@@ -180,7 +180,7 @@ class TestDataLoaderWorkerLoop(unittest.TestCase):
                 indices_queue.put(None)
                 _worker_loop(loader._dataset, 0, indices_queue,
                              loader._data_queue, loader._workers_done_event,
-                             True, _collate_fn, _init_fn, 0, 1,
+                             True, _collate_fn, True, _init_fn, 0, 1,
                              loader._use_shared_memory)
                 self.assertTrue(False)
             except AssertionError:
@@ -224,7 +224,7 @@ class TestDataLoaderWorkerLoop(unittest.TestCase):
                 loader._workers_done_event.set()
                 _worker_loop(loader._dataset, 0, indices_queue,
                              loader._data_queue, loader._workers_done_event,
-                             True, _collate_fn, _init_fn, 0, 1,
+                             True, _collate_fn, True, _init_fn, 0, 1,
                              loader._use_shared_memory)
                 self.assertTrue(True)
             except AssertionError:
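
Usage note (not part of the patch): a minimal sketch of the behaviour this change threads through. drop_last is now forwarded to both the single-process and multi-process fetchers, so drop_last=False keeps the final partial batch instead of dropping it. TinyDataset below is a hypothetical stand-in for the RandomDataset used in the new unit test; the sketch assumes the public paddle.io API and default dynamic-graph mode.

    import numpy as np
    from paddle.io import Dataset, DataLoader

    class TinyDataset(Dataset):
        # hypothetical map-style dataset with 10 samples
        def __getitem__(self, idx):
            return np.array([idx], dtype='float32')

        def __len__(self):
            return 10

    for drop_last in [True, False]:
        loader = DataLoader(
            TinyDataset(), batch_size=3, drop_last=drop_last, num_workers=0)
        # drop_last=True  -> 10 // 3 = 3 full batches
        # drop_last=False -> (10 + 2) // 3 = 4 batches; the last holds 1 sample
        print(drop_last, len(list(loader)))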