Commit d626a675, authored by walloollaw, committed by qingqing01

Fix hang bug in parallel_map.py when multiprocessing mode is used and some consumers exit abnormally (#106)
Parent commit: cbdd3bc5
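The failure mode this commit addresses is easy to reproduce outside the project. Below is a minimal standalone sketch (plain multiprocessing, not PaddleDetection code): a consumer process dies before posting either a result or an end-signal, so a bare outq.get() would block forever; polling with a timeout plus a liveness check, which is the strategy the patch adopts in next(), detects the dead worker instead.

# Minimal repro/fix sketch (illustrative, not project code).
import multiprocessing as mp
import sys
from queue import Empty


def crashing_consumer(outq):
    # dies before putting a result or an end-signal
    sys.exit(1)


if __name__ == '__main__':
    outq = mp.Queue()
    worker = mp.Process(target=crashing_consumer, args=(outq,))
    worker.start()

    while True:
        try:
            # bounded wait instead of blocking forever
            sample = outq.get(timeout=3)
            break
        except Empty:
            if not worker.is_alive():
                print("consumer died abnormally, exitcode:", worker.exitcode)
                break
    worker.join()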
@@ -25,6 +25,7 @@ class Dataset(object):
     def __init__(self):
         self._epoch = -1
+        self._pos = 0

     def __next__(self):
         return self.next()
@@ -33,8 +34,8 @@ class Dataset(object):
         return self

     def __str__(self):
-        return "{}(fname:{}, epoch:{:d}, size:{:d}, pos:{:d})".format(
-            type(self).__name__, self._fname, self._epoch,
+        return "{}(epoch:{:d}, size:{:d}, pos:{:d})".format(
+            type(self).__name__, self._epoch,
             self.size(), self._pos)

     def next(self):
...
@@ -82,8 +82,8 @@ class RoiDbSource(Dataset):
         self._imid2path = None

     def __str__(self):
-        return 'RoiDbSource(fname:%s,epoch:%d,size:%d,pos:%d)' \
-            % (self._fname, self._epoch, self.size(), self._pos)
+        return 'RoiDbSource(epoch:%d,size:%d,pos:%d,fname:%s)' \
+            % (self._epoch, self.size(), self._pos, self._fname)

     def next(self):
         """ load next sample
...
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import time
import unittest
import sys
import logging
import random
import copy

import set_env
import ppdet.data.transform as tf
from ppdet.data.dataset import Dataset


class MemorySource(Dataset):
    """ memory data source for testing
    """

    def __init__(self, samples):
        super(MemorySource, self).__init__()
        self._epoch = -1
        self._pos = -1
        self._drained = False
        self._samples = samples

    def next(self):
        if self._epoch < 0:
            self.reset()

        if self._pos >= self.size():
            self._drained = True
            raise StopIteration("no more data in " + str(self))
        else:
            sample = copy.deepcopy(self._samples[self._pos])
            self._pos += 1
            return sample

    def reset(self):
        if self._epoch < 0:
            self._epoch = 0
        else:
            self._epoch += 1

        self._pos = 0
        self._drained = False
        random.shuffle(self._samples)

    def size(self):
        return len(self._samples)

    def drained(self):
        assert self._epoch >= 0, "the first epoch has not started yet"
        return self._pos >= self.size()

    def epoch_id(self):
        return self._epoch


class TestDataset(unittest.TestCase):
    """Test cases for ppdet.data.dataset
    """

    @classmethod
    def setUpClass(cls):
        """ setup
        """
        pass

    @classmethod
    def tearDownClass(cls):
        """ tearDownClass """
        pass

    def test_next(self):
        """ test next
        """
        samples = list(range(10))
        mem_sc = MemorySource(samples)

        for i, d in enumerate(mem_sc):
            self.assertTrue(d in samples)

    def test_transform_with_abnormal_worker(self):
        """ test dataset transform with abnormally exit process
        """
        samples = list(range(1000))
        ds = MemorySource(samples)

        def _mapper(sample):
            if sample == 3:
                sys.exit(1)

            return 2 * sample

        worker_conf = {'WORKER_NUM': 2, 'use_process': True}
        mapped = tf.map(ds, _mapper, worker_conf)

        ct = 0
        for i, d in enumerate(mapped):
            ct += 1
            self.assertTrue(d / 2 in samples)

        # the sample that kills its worker is lost, so exactly one
        # sample is missing from the output
        self.assertEqual(len(samples) - 1, ct)

    def test_transform_with_delay_worker(self):
        """ test dataset transform with delayed process
        """
        samples = list(range(1000))
        ds = MemorySource(samples)

        def _mapper(sample):
            if sample == 3:
                time.sleep(30)

            return 2 * sample

        worker_conf = {'WORKER_NUM': 2, 'use_process': True}
        mapped = tf.map(ds, _mapper, worker_conf)

        ct = 0
        for i, d in enumerate(mapped):
            ct += 1
            self.assertTrue(d / 2 in samples)

        # a slow worker only delays results; nothing is lost
        self.assertEqual(len(samples), ct)


if __name__ == '__main__':
    logging.basicConfig()
    unittest.main()
@@ -21,6 +21,11 @@ from __future__ import print_function
 import sys
 import six
+
+if six.PY3:
+    from queue import Empty
+else:
+    from Queue import Empty
+
 import uuid
 import logging
 import signal
@@ -31,15 +36,19 @@ logger = logging.getLogger(__name__)
 class EndSignal(object):
-    def __init__(self, errno=0, errmsg=''):
+    """ signal used to notify worker to exit
+    """
+
+    def __init__(self, id, errno=0, errmsg=''):
+        self.id = id
         self.errno = errno
         self.errmsg = errmsg
 class ParallelMappedDataset(ProxiedDataset):
     """
-    Transform samples to mapped samples which is similar to 'basic.MappedDataset',
-    but multiple workers (threads or processes) will be used
+    Transform samples to mapped samples which is similar to
+    'basic.MappedDataset', but multiple workers (threads or processes)
+    will be used

     Notes:
         this class is not thread-safe
@@ -58,9 +67,10 @@ class ParallelMappedDataset(ProxiedDataset):
         args.update(worker_args)
         if args['use_process'] and type(args['memsize']) is str:
             assert args['memsize'][-1].lower() == 'g', \
-                "invalid param for memsize[%s], should be ended with 'G' or 'g'" % (args['memsize'])
+                "invalid param for memsize[{}], should " \
+                "be ended with 'G' or 'g'".format(args['memsize'])
             gb = args['memsize'][:-1]
-            args['memsize'] = int(gb) * 1024**3
+            args['memsize'] = int(gb) * 1024 ** 3

         self._worker_args = args
         self._started = False
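The memsize handling in the hunk above, restated as a standalone helper (the name parse_memsize is hypothetical, not project API): a size string such as '2G' or '2g' is validated and converted to bytes.

def parse_memsize(memsize):
    # mirrors the assert/convert logic introduced in the diff above
    if isinstance(memsize, str):
        assert memsize[-1].lower() == 'g', \
            "invalid param for memsize[{}], should " \
            "be ended with 'G' or 'g'".format(memsize)
        memsize = int(memsize[:-1]) * 1024 ** 3
    return memsize

assert parse_memsize('2g') == 2 * 1024 ** 3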
@@ -103,22 +113,25 @@ class ParallelMappedDataset(ProxiedDataset):
         self._producer.daemon = True

         self._consumers = []
+        self._consumer_endsig = {}
         for i in range(consumer_num):
+            consumer_id = 'consumer-' + id + '-' + str(i)
             p = Worker(
                 target=self._consume,
-                args=('consumer-' + id + '_' + str(i), self._inq, self._outq,
+                args=(consumer_id, self._inq, self._outq,
                       self._mapper))
             self._consumers.append(p)
             p.daemon = True
+            setattr(p, 'id', consumer_id)

         self._epoch = -1
         self._feeding_ev = Event()
         self._produced = 0  # produced sample in self._produce
         self._consumed = 0  # consumed sample in self.next
-        self._stopped_consumers = 0
     def _produce(self, id, source, inq):
         """Fetch data from source and feed it to 'inq' queue"""
+        endsig = EndSignal(id)
         while True:
             self._feeding_ev.wait()
             if self._exit:
@@ -128,32 +141,38 @@ class ParallelMappedDataset(ProxiedDataset):
                 self._produced += 1
             except StopIteration:
                 self._feeding_ev.clear()
-                self._feeding_ev.wait()  # wait other guy to wake up me
+                self._feeding_ev.wait()
+                logger.debug("producer[{}] starts new epoch".format(id))
             except Exception as e:
-                msg = "producer[{}] failed with error: {}".format(id, str(e))
-                inq.put(EndSignal(-1, msg))
+                endsig.errno = -1
+                endsig.errmsg = "producer[{}] failed with error: {}" \
+                    .format(id, str(e))
+                inq.put(endsig)
                 break

+        logger.debug("producer[{}] exits".format(id))
     def _consume(self, id, inq, outq, mapper):
         """Fetch data from 'inq', process it and put result to 'outq'"""
+        if self._worker_args['use_process']:
+            # handle SIGTERM signal to exit to prevent print stack frame
+            signal.signal(signal.SIGTERM, lambda signum, frame: sys.exit())
+
+        endsig = EndSignal(id)
         while True:
             sample = inq.get()
             if isinstance(sample, EndSignal):
-                sample.errmsg += "[consumer[{}] exits]".format(id)
-                outq.put(sample)
-                logger.debug("end signal received, " +
-                             "consumer[{}] exits".format(id))
+                endsig.errno = sample.errno
+                endsig.errmsg = "consumer[{}] exits for reason[{}]" \
+                    .format(id, sample.errmsg)
+                outq.put(endsig)
                 break

             try:
                 result = mapper(sample)
                 outq.put(result)
             except Exception as e:
-                msg = 'failed to map consumer[%s], error: {}'.format(str(e), id)
-                outq.put(EndSignal(-1, msg))
+                endsig.errno = -2
+                endsig.errmsg = "consumer[{}] failed to map with error:[{}]" \
+                    .format(id, str(e))
+                outq.put(endsig)
                 break
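The SIGTERM handler installed at the top of _consume can be demonstrated in isolation. In this small standalone demo (not project code), the child exits via sys.exit(), i.e. a clean SystemExit reported as exitcode 0, instead of being killed directly by the signal (which would report -15 on Linux):

import multiprocessing as mp
import signal
import sys
import time


def worker():
    # same handler as in _consume: turn SIGTERM into a clean SystemExit
    signal.signal(signal.SIGTERM, lambda signum, frame: sys.exit())
    while True:
        time.sleep(0.1)


if __name__ == '__main__':
    p = mp.Process(target=worker)
    p.start()
    time.sleep(0.5)
    p.terminate()   # delivers SIGTERM
    p.join()
    print("worker exitcode:", p.exitcode)   # 0 with the handler, -15 without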
     def drained(self):
@@ -168,6 +187,25 @@ class ParallelMappedDataset(ProxiedDataset):
         for _ in range(len(self._consumers)):
             self._inq.put(EndSignal(0, "notify consumers to exit"))
+    def _consumer_healthy(self):
+        abnormal_num = 0
+        for w in self._consumers:
+            if not w.is_alive() and w.id not in self._consumer_endsig:
+                abnormal_num += 1
+                if self._worker_args['use_process']:
+                    errmsg = "consumer[{}] exit abnormally with exitcode[{}]" \
+                        .format(w.pid, w.exitcode)
+                else:
+                    errmsg = "consumer[{}] exit abnormally".format(w.ident)
+
+                logger.warn(errmsg)
+
+        if abnormal_num > 0:
+            logger.warn("{} consumers have exited abnormally!!!" \
+                .format(abnormal_num))
+
+        return abnormal_num == 0
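Restated compactly (names here are illustrative, not project API): a consumer counts as healthy while it is still running, or once it has announced its exit with an EndSignal; only a silent death is abnormal.

def consumers_healthy(workers, endsig_by_id):
    # 'id' is the attribute attached via setattr(p, 'id', consumer_id) above
    return all(w.is_alive() or w.id in endsig_by_id for w in workers)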
     def next(self):
         """ get next transformed sample
         """
@@ -177,41 +215,54 @@ class ParallelMappedDataset(ProxiedDataset):
         if self.drained():
             raise StopIteration()

-        while True:
-            sample = self._outq.get()
+        while not self._exit:
+            try:
+                sample = self._outq.get(timeout=3)
+            except Empty as e:
+                if not self._consumer_healthy():
+                    raise StopIteration()
+                else:
+                    continue
+
             if isinstance(sample, EndSignal):
-                self._stopped_consumers += 1
-                if sample.errno != 0:
-                    logger.warn("consumer failed with error: {}".format(
-                        sample.errmsg))
+                self._consumer_endsig[sample.id] = sample
+                logger.warn("recv endsignal from outq with errmsg[{}]" \
+                    .format(sample.errmsg))

-                if self._stopped_consumers < len(self._consumers):
+                if len(self._consumer_endsig.keys()) < len(self._consumers):
                     self._inq.put(sample)
                 else:
-                    raise ValueError("all consumers exited, no more samples")
+                    self._exit = True
+                    raise StopIteration("all consumers exited, no more samples")
             else:
                 self._consumed += 1
                 return sample

+        raise StopIteration()
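The EndSignal bookkeeping above implements a sentinel fan-out: each exiting consumer posts an EndSignal to outq, and the parent re-posts it to inq until every consumer has acknowledged, so none stays blocked on inq.get(). A thread-based toy of the same protocol (illustrative, not project code):

import threading
from queue import Queue


class EndSignal(object):
    def __init__(self, id):
        self.id = id


def consumer(cid, inq, outq):
    while True:
        item = inq.get()
        if isinstance(item, EndSignal):
            outq.put(EndSignal(cid))   # acknowledge exit with own id
            break
        outq.put(item * 2)


inq, outq = Queue(), Queue()
workers = [threading.Thread(target=consumer, args=(i, inq, outq))
           for i in range(2)]
for w in workers:
    w.start()
for x in range(4):
    inq.put(x)
inq.put(EndSignal('producer'))         # one sentinel enters the pipeline

done, results = set(), []
while len(done) < len(workers):
    item = outq.get()
    if isinstance(item, EndSignal):
        done.add(item.id)
        if len(done) < len(workers):
            inq.put(item)              # fan the sentinel out to the next consumer
    else:
        results.append(item)
for w in workers:
    w.join()
print(sorted(results))                 # [0, 2, 4, 6]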
     def reset(self):
         """ reset for a new epoch of samples
         """
+        assert not self._exit, "cannot reset for already stopped dataset"
+
         if self._epoch < 0:
             self._epoch = 0
-            for p in self._consumers:
-                p.start()
+            for w in self._consumers:
+                w.start()
             self._producer.start()
         else:
+            assert self._consumer_healthy(), "cannot start another pass of data" \
+                " for some consumers exited abnormally before!!!"
+
             if not self.drained():
-                logger.warn("do not reset before epoch[%d] finishes".format(
-                    self._epoch))
+                logger.warn("reset before epoch[{}] finishes".format(self._epoch))
                 self._produced = self._produced - self._consumed
             else:
                 self._produced = 0

             self._epoch += 1

-        assert self._stopped_consumers == 0, "some consumers already exited," \
+        assert len(self._consumer_endsig.keys()) == 0, "some consumers already exited," \
             + " cannot start another epoch"

         self._source.reset()
@@ -221,9 +272,4 @@ class ParallelMappedDataset(ProxiedDataset):

 # FIXME(dengkaipeng): fix me if you have better impliment
 # handle terminate reader process, do not print stack frame
-def _reader_exit(signum, frame):
-    logger.debug("Reader process exit.")
-    sys.exit()
-
-signal.signal(signal.SIGTERM, _reader_exit)
+signal.signal(signal.SIGTERM, lambda signum, frame: sys.exit())
@@ -22,9 +22,11 @@ import six
 if six.PY3:
     import pickle
     from io import BytesIO as StringIO
+    from queue import Empty
 else:
     import cPickle as pickle
     from cStringIO import StringIO
+    from Queue import Empty

 import logging
 import traceback
@@ -87,6 +89,8 @@ class SharedQueue(Queue):
             buff = super(SharedQueue, self).get(**kwargs)
             data = buff.get()
             return pickle.load(StringIO(data))
+        except Empty as e:
+            raise e
        except Exception as e:
             stack_info = traceback.format_exc()
             err_msg = 'failed to get element from SharedQueue '\
...
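The new except Empty branch exists so that a timeout from get(timeout=...) propagates to the caller, where ParallelMappedDataset.next() catches it to probe consumer health, while every other failure still goes through the logging path. A hedged sketch of the pattern (guarded_get is an illustrative name, not project API):

import logging
import traceback
from queue import Empty

logger = logging.getLogger(__name__)


def guarded_get(q, **kwargs):
    try:
        return q.get(**kwargs)
    except Empty:
        raise    # expected on timeout; the caller decides what to do
    except Exception:
        # anything else is a real error worth logging before re-raising
        logger.warning('failed to get element from queue: %s',
                       traceback.format_exc())
        raise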
@@ -316,8 +316,6 @@ class PageAllocator(object):
         start_pos = pos
         flags = ''
         while True:
-            # maybe flags already has some '0' pages,
-            # so just check 'page_num - len(flags)' pages
             flags = self.get_page_status(pos, page_num, ret_flag=True)

             if flags.count('0') == page_num:
@@ -343,10 +341,10 @@ class PageAllocator(object):
             if free_pages == 0:
                 err_msg = 'all pages have been used:%s' % (str(self))
             else:
-                err_msg = 'not found available pages with page_status[%s] '\
-                    'and %d free pages' % (str(page_status), free_pages)
+                err_msg = 'not found enough pages[avail:%d, expect:%d] '\
+                    'with total free pages[%d]' % (page_status[0], page_num, free_pages)

-            err_msg = 'failed to malloc %d pages at pos[%d] for reason[%s] and allocator status[%s]' \
-                % (page_num, pos, err_msg, str(self))
+            err_msg = 'failed to malloc %d pages at pos[%d] for reason[%s] '\
+                'and allocator status[%s]' % (page_num, pos, err_msg, str(self))
             raise MemoryFullError(err_msg)

         self.set_page_status(pos, page_num, '1')
...