From d626a675cb2ff413290a1e72dd2682fb3b99ca35 Mon Sep 17 00:00:00 2001
From: walloollaw <37680514+walloollaw@users.noreply.github.com>
Date: Thu, 12 Dec 2019 18:55:14 +0800
Subject: [PATCH] Fix hang bug in parallel_map.py when multiprocessing mode is
 used and some consumers exit abnormally (#106)

---
 ppdet/data/dataset.py                      |   5 +-
 ppdet/data/source/roidb_source.py          |   4 +-
 ppdet/data/tests/test_dataset.py           | 143 ++++++++++++++++++
 ppdet/data/transform/parallel_map.py       | 122 ++++++++++-----
 ppdet/data/transform/shared_queue/queue.py |   4 +
 .../transform/shared_queue/sharedmemory.py |  10 +-
 6 files changed, 240 insertions(+), 48 deletions(-)
 create mode 100644 ppdet/data/tests/test_dataset.py

diff --git a/ppdet/data/dataset.py b/ppdet/data/dataset.py
index 31d4df4a0..0ee38d12f 100644
--- a/ppdet/data/dataset.py
+++ b/ppdet/data/dataset.py
@@ -25,6 +25,7 @@ class Dataset(object):
 
     def __init__(self):
         self._epoch = -1
+        self._pos = 0
 
     def __next__(self):
         return self.next()
@@ -33,8 +34,8 @@ class Dataset(object):
         return self
 
     def __str__(self):
-        return "{}(fname:{}, epoch:{:d}, size:{:d}, pos:{:d})".format(
-            type(self).__name__, self._fname, self._epoch,
+        return "{}(epoch:{:d}, size:{:d}, pos:{:d})".format(
+            type(self).__name__, self._epoch,
             self.size(), self._pos)
 
     def next(self):
diff --git a/ppdet/data/source/roidb_source.py b/ppdet/data/source/roidb_source.py
index 4d898d951..44317974a 100644
--- a/ppdet/data/source/roidb_source.py
+++ b/ppdet/data/source/roidb_source.py
@@ -82,8 +82,8 @@ class RoiDbSource(Dataset):
         self._imid2path = None
 
     def __str__(self):
-        return 'RoiDbSource(fname:%s,epoch:%d,size:%d,pos:%d)' \
-            % (self._fname, self._epoch, self.size(), self._pos)
+        return 'RoiDbSource(epoch:%d,size:%d,pos:%d,fname:%s)' \
+            % (self._epoch, self.size(), self._pos, self._fname)
 
     def next(self):
         """ load next sample
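
NOTE: the two hunks above fix a crash in logging rather than the hang itself:
Dataset.__str__ referenced self._fname, an attribute only some subclasses
(e.g. RoiDbSource) define, so printing a bare Dataset raised AttributeError;
the base class now also initializes self._pos, which __str__ reports. A
minimal sketch of the iteration contract that __str__ relies on; RangeSource
is a hypothetical subclass, for illustration only:

    # Illustration only, not part of the patch: a toy subclass showing the
    # next()/reset()/size() contract that Dataset.__str__ now reports on.
    from ppdet.data.dataset import Dataset

    class RangeSource(Dataset):  # hypothetical subclass
        def __init__(self, n):
            super(RangeSource, self).__init__()
            self._n = n

        def size(self):
            return self._n

        def reset(self):
            self._epoch += 1
            self._pos = 0

        def next(self):
            if self._epoch < 0:
                self.reset()
            if self._pos >= self._n:
                raise StopIteration("no more data in " + str(self))
            self._pos += 1
            return self._pos - 1

    src = RangeSource(3)
    print(list(src))  # [0, 1, 2]
    print(src)        # RangeSource(epoch:0, size:3, pos:3)
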
diff --git a/ppdet/data/tests/test_dataset.py b/ppdet/data/tests/test_dataset.py
new file mode 100644
index 000000000..db99f464b
--- /dev/null
+++ b/ppdet/data/tests/test_dataset.py
@@ -0,0 +1,143 @@
+# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import time
+import unittest
+import sys
+import logging
+import random
+import copy
+
+import set_env
+
+import ppdet.data.transform as tf
+from ppdet.data.dataset import Dataset
+
+class MemorySource(Dataset):
+    """ memory data source for testing
+    """
+    def __init__(self, samples):
+        super(MemorySource, self).__init__()
+        self._epoch = -1
+
+        self._pos = -1
+        self._drained = False
+        self._samples = samples
+
+    def next(self):
+        if self._epoch < 0:
+            self.reset()
+
+        if self._pos >= self.size():
+            self._drained = True
+            raise StopIteration("no more data in " + str(self))
+        else:
+            sample = copy.deepcopy(self._samples[self._pos])
+            self._pos += 1
+            return sample
+
+    def reset(self):
+        if self._epoch < 0:
+            self._epoch = 0
+        else:
+            self._epoch += 1
+
+        self._pos = 0
+        self._drained = False
+        random.shuffle(self._samples)
+
+    def size(self):
+        return len(self._samples)
+
+    def drained(self):
+        assert self._epoch >= 0, "the first epoch has not started yet"
+        return self._pos >= self.size()
+
+    def epoch_id(self):
+        return self._epoch
+
+
+class TestDataset(unittest.TestCase):
+    """Test cases for ppdet.data.dataset
+    """
+
+    @classmethod
+    def setUpClass(cls):
+        """ setup
+        """
+        pass
+
+    @classmethod
+    def tearDownClass(cls):
+        """ tearDownClass """
+        pass
+
+    def test_next(self):
+        """ test next
+        """
+        samples = list(range(10))
+        mem_sc = MemorySource(samples)
+
+        for i, d in enumerate(mem_sc):
+            self.assertTrue(d in samples)
+
+    def test_transform_with_abnormal_worker(self):
+        """ test dataset transform with an abnormally exiting worker process
+        """
+        samples = list(range(1000))
+        ds = MemorySource(samples)
+
+        def _mapper(sample):
+            if sample == 3:
+                sys.exit(1)
+
+            return 2 * sample
+
+        worker_conf = {'WORKER_NUM': 2, 'use_process': True}
+        mapped = tf.map(ds, _mapper, worker_conf)
+
+        ct = 0
+        for i, d in enumerate(mapped):
+            ct += 1
+            self.assertTrue(d / 2 in samples)
+
+        self.assertEqual(len(samples) - 1, ct)
+
+    def test_transform_with_delay_worker(self):
+        """ test dataset transform with a delayed worker process
+        """
+        samples = list(range(1000))
+        ds = MemorySource(samples)
+
+        def _mapper(sample):
+            if sample == 3:
+                time.sleep(30)
+
+            return 2 * sample
+
+        worker_conf = {'WORKER_NUM': 2, 'use_process': True}
+        mapped = tf.map(ds, _mapper, worker_conf)
+
+        ct = 0
+        for i, d in enumerate(mapped):
+            ct += 1
+            self.assertTrue(d / 2 in samples)
+
+        self.assertEqual(len(samples), ct)
+
+
+if __name__ == '__main__':
+    logging.basicConfig()
+    unittest.main()
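
NOTE: the new tests reproduce the reported hang directly: a mapper that calls
sys.exit(1) kills one consumer process, and before this patch the parent then
blocked forever in outq.get(). A sketch of driving the transform the same way
outside unittest, assuming it is run from ppdet/data/tests so that set_env and
MemorySource resolve; tf.map and the worker_conf keys are as used in the test
above:

    import sys

    import ppdet.data.transform as tf
    from test_dataset import MemorySource  # helper added by this patch

    def flaky_mapper(sample):
        if sample == 3:
            sys.exit(1)  # one consumer process dies abnormally
        return 2 * sample

    ds = MemorySource(list(range(100)))
    mapped = tf.map(ds, flaky_mapper, {'WORKER_NUM': 2, 'use_process': True})

    # Before this patch the loop below hung once the worker died; now it
    # finishes, losing only the sample whose worker exited.
    print(sum(1 for _ in mapped))  # 99
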
diff --git a/ppdet/data/transform/parallel_map.py b/ppdet/data/transform/parallel_map.py
index a4dfa0556..bb4ff6e48 100644
--- a/ppdet/data/transform/parallel_map.py
+++ b/ppdet/data/transform/parallel_map.py
@@ -21,6 +21,11 @@ from __future__ import print_function
 
 import sys
 import six
+if six.PY3:
+    from queue import Empty
+else:
+    from Queue import Empty
+
 import uuid
 import logging
 import signal
@@ -31,15 +36,19 @@ logger = logging.getLogger(__name__)
 
 
 class EndSignal(object):
-    def __init__(self, errno=0, errmsg=''):
+    """ signal used to notify worker to exit
+    """
+    def __init__(self, id, errno=0, errmsg=''):
+        self.id = id
         self.errno = errno
         self.errmsg = errmsg
 
 
 class ParallelMappedDataset(ProxiedDataset):
     """
-    Transform samples to mapped samples which is similar to 'basic.MappedDataset',
-    but multiple workers (threads or processes) will be used
+    Transform samples to mapped samples which is similar to
+    'basic.MappedDataset', but multiple workers (threads or processes)
+    will be used
 
     Notes:
         this class is not thread-safe
@@ -58,9 +67,10 @@ class ParallelMappedDataset(ProxiedDataset):
         args.update(worker_args)
         if args['use_process'] and type(args['memsize']) is str:
             assert args['memsize'][-1].lower() == 'g', \
-                "invalid param for memsize[%s], should be ended with 'G' or 'g'" % (args['memsize'])
+                "invalid param for memsize[{}], should " \
+                "be ended with 'G' or 'g'".format(args['memsize'])
             gb = args['memsize'][:-1]
-            args['memsize'] = int(gb) * 1024**3
+            args['memsize'] = int(gb) * 1024 ** 3
 
         self._worker_args = args
         self._started = False
@@ -103,22 +113,25 @@ class ParallelMappedDataset(ProxiedDataset):
         self._producer.daemon = True
 
         self._consumers = []
+        self._consumer_endsig = {}
         for i in range(consumer_num):
+            consumer_id = 'consumer-' + id + '-' + str(i)
             p = Worker(
                 target=self._consume,
-                args=('consumer-' + id + '_' + str(i), self._inq, self._outq,
+                args=(consumer_id, self._inq, self._outq,
                       self._mapper))
             self._consumers.append(p)
             p.daemon = True
+            setattr(p, 'id', consumer_id)
 
         self._epoch = -1
         self._feeding_ev = Event()
         self._produced = 0  # produced sample in self._produce
         self._consumed = 0  # consumed sample in self.next
-        self._stopped_consumers = 0
 
     def _produce(self, id, source, inq):
         """Fetch data from source and feed it to 'inq' queue"""
+        endsig = EndSignal(id)
         while True:
             self._feeding_ev.wait()
             if self._exit:
@@ -128,32 +141,38 @@ class ParallelMappedDataset(ProxiedDataset):
                 self._produced += 1
             except StopIteration:
                 self._feeding_ev.clear()
-                self._feeding_ev.wait()  # wait other guy to wake up me
-                logger.debug("producer[{}] starts new epoch".format(id))
+                self._feeding_ev.wait()
             except Exception as e:
-                msg = "producer[{}] failed with error: {}".format(id, str(e))
-                inq.put(EndSignal(-1, msg))
+                endsig.errno = -1
+                endsig.errmsg = "producer[{}] failed with error: {}" \
+                    .format(id, str(e))
+                inq.put(endsig)
                 break
 
-        logger.debug("producer[{}] exits".format(id))
-
     def _consume(self, id, inq, outq, mapper):
         """Fetch data from 'inq', process it and put result to 'outq'"""
+        if self._worker_args['use_process']:
+            # handle SIGTERM signal to exit without printing a stack trace
+            signal.signal(signal.SIGTERM, lambda signum, frame : sys.exit())
+
+        endsig = EndSignal(id)
         while True:
            sample = inq.get()
            if isinstance(sample, EndSignal):
-                endsig.errno = sample.errno
-                sample.errmsg += "[consumer[{}] exits]".format(id)
-                outq.put(sample)
-                logger.debug("end signal received, " +
-                             "consumer[{}] exits".format(id))
+                endsig.errno = sample.errno
+                endsig.errmsg = "consumer[{}] exits for reason[{}]" \
+                    .format(id, sample.errmsg)
+                outq.put(endsig)
                 break
 
             try:
                 result = mapper(sample)
                 outq.put(result)
             except Exception as e:
-                msg = 'failed to map consumer[%s], error: {}'.format(str(e), id)
-                outq.put(EndSignal(-1, msg))
+                endsig.errno = -2
+                endsig.errmsg = "consumer[{}] failed to map with error:[{}]" \
+                    .format(id, str(e))
+                outq.put(endsig)
                 break
 
     def drained(self):
@@ -168,6 +187,25 @@ class ParallelMappedDataset(ProxiedDataset):
         for _ in range(len(self._consumers)):
             self._inq.put(EndSignal(0, "notify consumers to exit"))
 
+    def _consumer_healthy(self):
+        abnormal_num = 0
+        for w in self._consumers:
+            if not w.is_alive() and w.id not in self._consumer_endsig:
+                abnormal_num += 1
+                if self._worker_args['use_process']:
+                    errmsg = "consumer[{}] exit abnormally with exitcode[{}]" \
+                        .format(w.pid, w.exitcode)
+                else:
+                    errmsg = "consumer[{}] exit abnormally".format(w.ident)
+
+                logger.warn(errmsg)
+
+        if abnormal_num > 0:
+            logger.warn("{} consumers have exited abnormally!!!" \
+                .format(abnormal_num))
+
+        return abnormal_num == 0
+
     def next(self):
         """ get next transformed sample
         """
@@ -177,41 +215,54 @@ class ParallelMappedDataset(ProxiedDataset):
         if self.drained():
             raise StopIteration()
 
-        while True:
-            sample = self._outq.get()
+        while not self._exit:
+            try:
+                sample = self._outq.get(timeout=3)
+            except Empty as e:
+                if not self._consumer_healthy():
+                    raise StopIteration()
+                else:
+                    continue
+
             if isinstance(sample, EndSignal):
-                self._stopped_consumers += 1
-                if sample.errno != 0:
-                    logger.warn("consumer failed with error: {}".format(
-                        sample.errmsg))
+                self._consumer_endsig[sample.id] = sample
+                logger.warn("recv endsignal from outq with errmsg[{}]" \
+                    .format(sample.errmsg))
 
-                if self._stopped_consumers < len(self._consumers):
+                if len(self._consumer_endsig.keys()) < len(self._consumers):
                     self._inq.put(sample)
                 else:
-                    raise ValueError("all consumers exited, no more samples")
+                    self._exit = True
+                    raise StopIteration("all consumers exited, no more samples")
             else:
                 self._consumed += 1
                 return sample
 
+        raise StopIteration()
+
     def reset(self):
         """ reset for a new epoch of samples
         """
+        assert not self._exit, "cannot reset for already stopped dataset"
+
         if self._epoch < 0:
             self._epoch = 0
-            for p in self._consumers:
-                p.start()
+            for w in self._consumers:
+                w.start()
             self._producer.start()
         else:
+            assert self._consumer_healthy(), "cannot start another pass of data" \
+                " for some consumers exited abnormally before!!!"
+
             if not self.drained():
-                logger.warn("do not reset before epoch[%d] finishes".format(
-                    self._epoch))
+                logger.warn("reset before epoch[{}] finishes".format(self._epoch))
                 self._produced = self._produced - self._consumed
             else:
                 self._produced = 0
                 self._epoch += 1
 
-        assert self._stopped_consumers == 0, "some consumers already exited," \
+        assert len(self._consumer_endsig.keys()) == 0, "some consumers already exited," \
+            " cannot start another epoch"
 
         self._source.reset()
@@ -221,9 +272,4 @@ class ParallelMappedDataset(ProxiedDataset):
 
 # FIXME(dengkaipeng): fix me if you have a better implementation
 # handle terminate reader process, do not print stack frame
-def _reader_exit(signum, frame):
-    logger.debug("Reader process exit.")
-    sys.exit()
-
-
-signal.signal(signal.SIGTERM, _reader_exit)
+signal.signal(signal.SIGTERM, lambda signum, frame : sys.exit())
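
NOTE: the essence of the fix above is to stop blocking forever: next() now
polls self._outq.get(timeout=3) and, on Empty, consults _consumer_healthy()
to decide between retrying and raising StopIteration. A self-contained sketch
of the same pattern with plain multiprocessing; all names are hypothetical
and nothing from ppdet is assumed:

    # A timed get() plus a liveness check replaces a blocking get() that
    # would hang when every worker is already dead.
    import multiprocessing as mp
    from queue import Empty

    def consumer(inq, outq):
        for item in iter(inq.get, None):
            if item == 3:
                raise SystemExit(1)  # simulate an abnormal consumer exit
            outq.put(2 * item)

    def drain(outq, workers, expected):
        got = 0
        while got < expected:
            try:
                yield outq.get(timeout=3)
                got += 1
            except Empty:
                if not any(w.is_alive() for w in workers):
                    break  # all consumers gone: stop instead of hanging

    if __name__ == '__main__':
        inq, outq = mp.Queue(), mp.Queue()
        w = mp.Process(target=consumer, args=(inq, outq))
        w.start()
        for i in range(5):
            inq.put(i)
        inq.put(None)
        print(list(drain(outq, [w], expected=5)))  # [0, 2, 4], no hang
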
diff --git a/ppdet/data/transform/shared_queue/queue.py b/ppdet/data/transform/shared_queue/queue.py
index 0bd44d3e9..8f0ba8ab4 100644
--- a/ppdet/data/transform/shared_queue/queue.py
+++ b/ppdet/data/transform/shared_queue/queue.py
@@ -22,9 +22,11 @@ import six
 if six.PY3:
     import pickle
     from io import BytesIO as StringIO
+    from queue import Empty
 else:
     import cPickle as pickle
     from cStringIO import StringIO
+    from Queue import Empty
 
 import logging
 import traceback
@@ -87,6 +89,8 @@ class SharedQueue(Queue):
             buff = super(SharedQueue, self).get(**kwargs)
             data = buff.get()
             return pickle.load(StringIO(data))
+        except Empty as e:
+            raise e
         except Exception as e:
             stack_info = traceback.format_exc()
             err_msg = 'failed to get element from SharedQueue '\
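
NOTE: this queue.py change is a prerequisite for the timeout-based polling:
SharedQueue.get used to wrap every exception, which would have turned a
routine queue.Empty from get(timeout=...) into a fatal error. Re-raising
Empty keeps it distinguishable, matching standard-library behavior:

    # Sketch: Empty from a timed get() means "nothing yet, poll again",
    # not "the queue is broken"; callers rely on catching it as-is.
    from queue import Empty, Queue

    q = Queue()
    try:
        q.get(timeout=0.1)
    except Empty:
        pass  # retry or check worker liveness, as parallel_map.py now does
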
diff --git a/ppdet/data/transform/shared_queue/sharedmemory.py b/ppdet/data/transform/shared_queue/sharedmemory.py
index 933ab24c2..8b1d3ab40 100644
--- a/ppdet/data/transform/shared_queue/sharedmemory.py
+++ b/ppdet/data/transform/shared_queue/sharedmemory.py
@@ -316,8 +316,6 @@ class PageAllocator(object):
         start_pos = pos
         flags = ''
         while True:
-            # maybe flags already has some '0' pages,
-            # so just check 'page_num - len(flags)' pages
             flags = self.get_page_status(pos, page_num, ret_flag=True)
 
             if flags.count('0') == page_num:
@@ -343,10 +341,10 @@ class PageAllocator(object):
             if free_pages == 0:
                 err_msg = 'all pages have been used:%s' % (str(self))
             else:
-                err_msg = 'not found available pages with page_status[%s] '\
-                    'and %d free pages' % (str(page_status), free_pages)
-            err_msg = 'failed to malloc %d pages at pos[%d] for reason[%s] and allocator status[%s]' \
-                % (page_num, pos, err_msg, str(self))
+                err_msg = 'not found enough pages[avail:%d, expect:%d] '\
+                    'with total free pages[%d]' % (page_status[0], page_num, free_pages)
+            err_msg = 'failed to malloc %d pages at pos[%d] for reason[%s] '\
+                'and allocator status[%s]' % (page_num, pos, err_msg, str(self))
             raise MemoryFullError(err_msg)
 
         self.set_page_status(pos, page_num, '1')
-- 
GitLab
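
NOTE: the clearer PageAllocator message reports how many shared-memory pages
were available versus requested, which usually points at an undersized
memsize setting. A standalone restatement, with a hypothetical helper name,
of the memsize parsing that the parallel_map.py hunk tightened, for reference
when tuning it:

    # parse_memsize mirrors the patched logic: a size string such as
    # '2g' or '2G' becomes a byte count for the shared memory pool.
    def parse_memsize(memsize='3g'):
        assert isinstance(memsize, str) and memsize[-1].lower() == 'g', \
            "invalid param for memsize[{}], should be ended with 'G' or 'g'".format(memsize)
        return int(memsize[:-1]) * 1024 ** 3

    print(parse_memsize('2g'))  # 2147483648 bytes
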