未验证 提交 16146088 编写于 作者: K Kaipeng Deng 提交者: GitHub

fix drop_last not working on IterableDataset (#34801)

* fix drop_last not working in IterableDataset. test=develop
上级 181f7cec
...@@ -59,6 +59,7 @@ class _DataLoaderIterBase(object): ...@@ -59,6 +59,7 @@ class _DataLoaderIterBase(object):
self._places = loader.places self._places = loader.places
self._return_list = loader.return_list self._return_list = loader.return_list
self._batch_sampler = loader.batch_sampler self._batch_sampler = loader.batch_sampler
self._drop_last = loader.drop_last
self._auto_collate_batch = loader.auto_collate_batch self._auto_collate_batch = loader.auto_collate_batch
self._num_workers = loader.num_workers self._num_workers = loader.num_workers
self._use_buffer_reader = loader.use_buffer_reader self._use_buffer_reader = loader.use_buffer_reader
...@@ -111,7 +112,7 @@ class _DataLoaderIterSingleProcess(_DataLoaderIterBase): ...@@ -111,7 +112,7 @@ class _DataLoaderIterSingleProcess(_DataLoaderIterBase):
self._dataset_fetcher = _DatasetKind.create_fetcher( self._dataset_fetcher = _DatasetKind.create_fetcher(
self._dataset_kind, self._dataset, self._auto_collate_batch, self._dataset_kind, self._dataset, self._auto_collate_batch,
self._collate_fn, True) self._collate_fn, self._drop_last)
# NOTE: _structrue_infos used to record the data structure of # NOTE: _structrue_infos used to record the data structure of
# batch to restore batch structure after reading Tensor # batch to restore batch structure after reading Tensor
...@@ -309,8 +310,8 @@ class _DataLoaderIterMultiProcess(_DataLoaderIterBase): ...@@ -309,8 +310,8 @@ class _DataLoaderIterMultiProcess(_DataLoaderIterBase):
args=(self._dataset, self._dataset_kind, indices_queue, args=(self._dataset, self._dataset_kind, indices_queue,
self._data_queue, self._workers_done_event, self._data_queue, self._workers_done_event,
self._auto_collate_batch, self._collate_fn, self._auto_collate_batch, self._collate_fn,
self._worker_init_fn, i, self._num_workers, self._drop_last, self._worker_init_fn, i,
self._use_shared_memory)) self._num_workers, self._use_shared_memory))
worker.daemon = True worker.daemon = True
worker.start() worker.start()
self._workers.append(worker) self._workers.append(worker)
......
...@@ -253,7 +253,7 @@ def _generate_states(base_seed=0, worker_id=0): ...@@ -253,7 +253,7 @@ def _generate_states(base_seed=0, worker_id=0):
def _worker_loop(dataset, dataset_kind, indices_queue, out_queue, done_event, def _worker_loop(dataset, dataset_kind, indices_queue, out_queue, done_event,
auto_collate_batch, collate_fn, init_fn, worker_id, auto_collate_batch, collate_fn, drop_last, init_fn, worker_id,
num_workers, use_shared_memory): num_workers, use_shared_memory):
try: try:
# NOTE: [ mmap files clear ] When the child process exits unexpectedly, # NOTE: [ mmap files clear ] When the child process exits unexpectedly,
...@@ -282,8 +282,9 @@ def _worker_loop(dataset, dataset_kind, indices_queue, out_queue, done_event, ...@@ -282,8 +282,9 @@ def _worker_loop(dataset, dataset_kind, indices_queue, out_queue, done_event,
try: try:
if init_fn is not None: if init_fn is not None:
init_fn(worker_id) init_fn(worker_id)
fetcher = _DatasetKind.create_fetcher( fetcher = _DatasetKind.create_fetcher(dataset_kind, dataset,
dataset_kind, dataset, auto_collate_batch, collate_fn, True) auto_collate_batch,
collate_fn, drop_last)
except: except:
init_exception = _WorkerException(worker_id) init_exception = _WorkerException(worker_id)
......
...@@ -401,6 +401,7 @@ class DataLoader(object): ...@@ -401,6 +401,7 @@ class DataLoader(object):
shuffle=shuffle, shuffle=shuffle,
drop_last=drop_last) drop_last=drop_last)
self.drop_last = drop_last
self.auto_collate_batch = self.batch_sampler is not None self.auto_collate_batch = self.batch_sampler is not None
self.pin_memory = False self.pin_memory = False
......
...@@ -397,5 +397,30 @@ class TestDataLoaderGenerateStates(unittest.TestCase): ...@@ -397,5 +397,30 @@ class TestDataLoaderGenerateStates(unittest.TestCase):
assert out == outp assert out == outp
class TestDatasetWithDropLast(unittest.TestCase):
    """Verify that DataLoader honors ``drop_last`` for both map-style and
    iterable-style datasets, in single-process and worker-process modes."""

    def run_main(self, dataset, num_samples, batch_size):
        # Exercise single-process (num_workers=0) and subprocess (=1) loading.
        for num_workers in [0, 1]:
            for drop_last in [True, False]:
                # Expected batch count: floor division when the trailing
                # partial batch is dropped, ceiling division when it is kept.
                if drop_last:
                    expected_steps = num_samples // batch_size
                else:
                    expected_steps = (num_samples + batch_size - 1) \
                        // batch_size
                loader = DataLoader(
                    dataset,
                    batch_size=batch_size,
                    drop_last=drop_last,
                    num_workers=num_workers)
                batches = [batch for batch in loader]
                assert len(batches) == expected_steps

    def test_map_dataset(self):
        # Map-style dataset: 10 samples in batches of 3.
        self.run_main(RandomDataset(10), 10, 3)

    def test_iterable_dataset(self):
        # Iterable-style dataset: 10 samples in batches of 3 — the case
        # where drop_last previously had no effect.
        self.run_main(RandomIterableDataset(10), 10, 3)
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
...@@ -180,7 +180,7 @@ class TestDataLoaderWorkerLoop(unittest.TestCase): ...@@ -180,7 +180,7 @@ class TestDataLoaderWorkerLoop(unittest.TestCase):
indices_queue.put(None) indices_queue.put(None)
_worker_loop(loader._dataset, 0, indices_queue, _worker_loop(loader._dataset, 0, indices_queue,
loader._data_queue, loader._workers_done_event, loader._data_queue, loader._workers_done_event,
True, _collate_fn, _init_fn, 0, 1, True, _collate_fn, True, _init_fn, 0, 1,
loader._use_shared_memory) loader._use_shared_memory)
self.assertTrue(False) self.assertTrue(False)
except AssertionError: except AssertionError:
...@@ -224,7 +224,7 @@ class TestDataLoaderWorkerLoop(unittest.TestCase): ...@@ -224,7 +224,7 @@ class TestDataLoaderWorkerLoop(unittest.TestCase):
loader._workers_done_event.set() loader._workers_done_event.set()
_worker_loop(loader._dataset, 0, indices_queue, _worker_loop(loader._dataset, 0, indices_queue,
loader._data_queue, loader._workers_done_event, loader._data_queue, loader._workers_done_event,
True, _collate_fn, _init_fn, 0, 1, True, _collate_fn, True, _init_fn, 0, 1,
loader._use_shared_memory) loader._use_shared_memory)
self.assertTrue(True) self.assertTrue(True)
except AssertionError: except AssertionError:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册