From a991b6a0869290c6ece7f0cd2348f86273146b3c Mon Sep 17 00:00:00 2001 From: Kaipeng Deng Date: Wed, 16 Mar 2022 09:39:11 +0800 Subject: [PATCH] fix IterableDataset may block model when num_workers > 0. test=develop (#40541) --- python/paddle/fluid/dataloader/dataloader_iter.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/python/paddle/fluid/dataloader/dataloader_iter.py b/python/paddle/fluid/dataloader/dataloader_iter.py index 706ec0d523..5385ac28b9 100644 --- a/python/paddle/fluid/dataloader/dataloader_iter.py +++ b/python/paddle/fluid/dataloader/dataloader_iter.py @@ -564,6 +564,14 @@ class _DataLoaderIterMultiProcess(_DataLoaderIterBase): self._rcvd_idx += 1 self._batches_outstanding -= 1 else: + # NOTE: when _rcvd_idx catches up with _send_idx, it means + # one of the following: + # 1. all 2 * num_workers batches have been loaded + # and stored in _blocking_queue + # 2. all data has been drained + # we need to let _thread block at _data_queue + # get_data to free the CPU, otherwise it may take + # CPU time away from model running # NOTE: in persistent workers mode, do not check data # drained here, simply let it go to _data_queue # reading to get _ResumeIteration @@ -573,7 +581,6 @@ class _DataLoaderIterMultiProcess(_DataLoaderIterBase): # may also be data in blocking queue if self._batches_outstanding < len(self._places): return None - continue if self._rcvd_idx in self._task_infos and \ len(self._task_infos[self._rcvd_idx]) == 3: -- GitLab