diff --git a/python/paddle/fluid/dataloader/dataloader_iter.py b/python/paddle/fluid/dataloader/dataloader_iter.py index 5cb831eee3a4b0497419ae5eec2972b4cda9e60b..a81d73d7e9a621d2a02ed91541f32b827bdff38c 100644 --- a/python/paddle/fluid/dataloader/dataloader_iter.py +++ b/python/paddle/fluid/dataloader/dataloader_iter.py @@ -359,6 +359,9 @@ class _DataLoaderIterMultiProcess(_DataLoaderIterBase): self._outstanding_capacity = 2 * max(self._num_workers, len(self._places)) + # lock serializing _try_put_indices calls, see the comment there + self._thread_lock = threading.Lock() + # init workers and indices queues and put 2 indices in each indices queue self._init_workers() for _ in range(self._outstanding_capacity): @@ -660,22 +663,32 @@ class _DataLoaderIterMultiProcess(_DataLoaderIterBase): def _try_put_indices(self): assert self._batches_outstanding <= self._outstanding_capacity, \ "too many indices have been put to queue" - try: - indices = next(self._sampler_iter) - except StopIteration: - return + # In multi-process mode for IterableDataset, _try_put_indices will + # be called both in the main process (because our implementation has a blocking queue, + # and the blocking queue is read in the main process) and in a thread, which may + # cause the following errors: + # 1. "ValueError: generator already executing" in next(self._sampler_iter) + # 2. 
re-enter in increase _send_idx + # add a lock for thread safety; since _try_put_indices is only a lightweight + # function which is not in the data reading pipeline, this lock has almost no + # influence on performance + with self._thread_lock: + try: + indices = next(self._sampler_iter) + except StopIteration: + return - for i in range(self._num_workers): - worker_idx = next(self._workers_idx_cycle) - if self._worker_status[worker_idx]: - break - else: - return + for i in range(self._num_workers): + worker_idx = next(self._workers_idx_cycle) + if self._worker_status[worker_idx]: + break + else: + return - self._indices_queues[worker_idx].put((self._send_idx, indices)) - self._task_infos[self._send_idx] = (worker_idx, ) - self._batches_outstanding += 1 - self._send_idx += 1 + self._indices_queues[worker_idx].put((self._send_idx, indices)) + self._task_infos[self._send_idx] = (worker_idx, ) + self._batches_outstanding += 1 + self._send_idx += 1 def __del__(self): self._try_shutdown_all()