diff --git a/ppdet/data/parallel_map.py b/ppdet/data/parallel_map.py index 8d1bc6d0b24a8fbe1c245f7841c13f546f86897c..789fda1f2ed1c18a162e287562bbf09315c5211e 100644 --- a/ppdet/data/parallel_map.py +++ b/ppdet/data/parallel_map.py @@ -119,7 +119,7 @@ class ParallelMap(object): self._producer = threading.Thread( target=self._produce, args=('producer-' + id, self._source, self._inq)) - self._producer.daemon = False + self._producer.daemon = True self._consumers = [] self._consumer_endsig = {} @@ -130,7 +130,7 @@ class ParallelMap(object): target=self._consume, args=(consumer_id, self._inq, self._outq, self._worker)) self._consumers.append(p) - p.daemon = use_process + p.daemon = True setattr(p, 'id', consumer_id) if use_process: worker_set.add(p) diff --git a/ppdet/data/reader.py b/ppdet/data/reader.py index dd923506ffc8170b543c450cce29d94029ae5ca8..4f4678800c74ffc0cfa73b933336462fcf5e5b42 100644 --- a/ppdet/data/reader.py +++ b/ppdet/data/reader.py @@ -304,9 +304,11 @@ class Reader(object): if self._epoch < 0: self.reset() if self.drained(): + self.stop() raise StopIteration batch = self._load_batch() if self._drop_last and len(batch) < self._batch_size: + self.stop() raise StopIteration if self._worker_num > -1: return batch @@ -418,8 +420,8 @@ def create_reader(cfg, max_iter=0, global_cfg=None, devices_num=1): n += 1 if max_iter > 0 and n == max_iter: return - reader.reset() if max_iter <= 0: return + reader.reset() return _reader