未验证 提交 7051bbc2 编写于 作者: K Kaipeng Deng 提交者: GitHub

fix test_multiprocess_dataloader unittest. test=develop (#26241)

* fix test_multiprocess_dataloader unittest. test=develop
上级 9a6a4fbc
...@@ -359,6 +359,9 @@ class _DataLoaderIterMultiProcess(_DataLoaderIterBase): ...@@ -359,6 +359,9 @@ class _DataLoaderIterMultiProcess(_DataLoaderIterBase):
self._outstanding_capacity = 2 * max(self._num_workers, self._outstanding_capacity = 2 * max(self._num_workers,
len(self._places)) len(self._places))
# see _try_put_indices
self._thread_lock = threading.Lock()
# init workers and indices queues and put 2 indices in each indices queue # init workers and indices queues and put 2 indices in each indices queue
self._init_workers() self._init_workers()
for _ in range(self._outstanding_capacity): for _ in range(self._outstanding_capacity):
...@@ -660,22 +663,32 @@ class _DataLoaderIterMultiProcess(_DataLoaderIterBase): ...@@ -660,22 +663,32 @@ class _DataLoaderIterMultiProcess(_DataLoaderIterBase):
def _try_put_indices(self): def _try_put_indices(self):
assert self._batches_outstanding <= self._outstanding_capacity, \ assert self._batches_outstanding <= self._outstanding_capacity, \
"too many indices have been put to queue" "too many indices have been put to queue"
try: # In multi-process mode for IterableDataset, _try_put_indices will
indices = next(self._sampler_iter) # be called in both the main process (because our implementation has a
except StopIteration: # blocking queue, and the blocking queue read is in the main process) and
return # a thread, which may cause the following errors:
# 1. "ValueError: generator already executing" in next(self._sampler_iter)
# 2. re-entrant increment of _send_idx
# Add a lock for thread safety. Since _try_put_indices is only a lightweight
# function that is not in the data reading pipeline, this lock has almost no
# influence on performance.
with self._thread_lock:
try:
indices = next(self._sampler_iter)
except StopIteration:
return
for i in range(self._num_workers): for i in range(self._num_workers):
worker_idx = next(self._workers_idx_cycle) worker_idx = next(self._workers_idx_cycle)
if self._worker_status[worker_idx]: if self._worker_status[worker_idx]:
break break
else: else:
return return
self._indices_queues[worker_idx].put((self._send_idx, indices)) self._indices_queues[worker_idx].put((self._send_idx, indices))
self._task_infos[self._send_idx] = (worker_idx, ) self._task_infos[self._send_idx] = (worker_idx, )
self._batches_outstanding += 1 self._batches_outstanding += 1
self._send_idx += 1 self._send_idx += 1
def __del__(self): def __del__(self):
self._try_shutdown_all() self._try_shutdown_all()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册