Commit d626a675 authored by walloollaw, committed by qingqing01

Fix hang bug in parallel_map.py when multiprocessing mode is used and some consumers exit abnormally (#106)
Parent cbdd3bc5
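
The hang happened because 'ParallelMappedDataset.next()' blocked forever on 'self._outq.get()': a consumer that died without posting its 'EndSignal' left the parent waiting on a queue that would never be fed again. The fix polls the output queue with a timeout and checks worker liveness on every miss. A minimal sketch of that pattern — the names ('get_next', 'workers', 'outq') are illustrative, not the project's actual API:

    from queue import Empty

    def get_next(outq, workers, timeout=3):
        """Poll 'outq' instead of blocking forever on a dead worker."""
        while True:
            try:
                return outq.get(timeout=timeout)
            except Empty:
                # a worker that crashed never enqueues anything, so a
                # plain blocking get() would hang here indefinitely
                if any(not w.is_alive() and w.exitcode not in (0, None)
                       for w in workers):
                    raise StopIteration("a worker exited abnormally")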
@@ -25,6 +25,7 @@ class Dataset(object):
     def __init__(self):
         self._epoch = -1
+        self._pos = 0

     def __next__(self):
         return self.next()
@@ -33,8 +34,8 @@ class Dataset(object):
         return self

     def __str__(self):
-        return "{}(fname:{}, epoch:{:d}, size:{:d}, pos:{:d})".format(
-            type(self).__name__, self._fname, self._epoch,
+        return "{}(epoch:{:d}, size:{:d}, pos:{:d})".format(
+            type(self).__name__, self._epoch,
             self.size(), self._pos)

     def next(self):
@@ -82,8 +82,8 @@ class RoiDbSource(Dataset):
         self._imid2path = None

     def __str__(self):
-        return 'RoiDbSource(fname:%s,epoch:%d,size:%d,pos:%d)' \
-            % (self._fname, self._epoch, self.size(), self._pos)
+        return 'RoiDbSource(epoch:%d,size:%d,pos:%d,fname:%s)' \
+            % (self._epoch, self.size(), self._pos, self._fname)

     def next(self):
         """ load next sample
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import time
import unittest
import sys
import logging
import random
import copy

import set_env
import ppdet.data.transform as tf
from ppdet.data.dataset import Dataset


class MemorySource(Dataset):
    """ memory data source for testing
    """

    def __init__(self, samples):
        super(MemorySource, self).__init__()
        self._epoch = -1
        self._pos = -1
        self._drained = False
        self._samples = samples

    def next(self):
        if self._epoch < 0:
            self.reset()

        if self._pos >= self.size():
            self._drained = True
            raise StopIteration("no more data in " + str(self))
        else:
            sample = copy.deepcopy(self._samples[self._pos])
            self._pos += 1
            return sample

    def reset(self):
        if self._epoch < 0:
            self._epoch = 0
        else:
            self._epoch += 1

        self._pos = 0
        self._drained = False
        random.shuffle(self._samples)

    def size(self):
        return len(self._samples)

    def drained(self):
        assert self._epoch >= 0, "the first epoch has not started yet"
        return self._pos >= self.size()

    def epoch_id(self):
        return self._epoch


class TestDataset(unittest.TestCase):
    """Test cases for ppdet.data.dataset
    """

    @classmethod
    def setUpClass(cls):
        """ setup
        """
        pass

    @classmethod
    def tearDownClass(cls):
        """ tearDownClass """
        pass

    def test_next(self):
        """ test next
        """
        samples = list(range(10))
        mem_sc = MemorySource(samples)

        for i, d in enumerate(mem_sc):
            self.assertTrue(d in samples)

    def test_transform_with_abnormal_worker(self):
        """ test dataset transform with abnormally exit process
        """
        samples = list(range(1000))
        ds = MemorySource(samples)

        def _mapper(sample):
            if sample == 3:
                sys.exit(1)

            return 2 * sample

        worker_conf = {'WORKER_NUM': 2, 'use_process': True}
        mapped = tf.map(ds, _mapper, worker_conf)

        ct = 0
        for i, d in enumerate(mapped):
            ct += 1
            self.assertTrue(d / 2 in samples)

        self.assertEqual(len(samples) - 1, ct)

    def test_transform_with_delay_worker(self):
        """ test dataset transform with delayed process
        """
        samples = list(range(1000))
        ds = MemorySource(samples)

        def _mapper(sample):
            if sample == 3:
                time.sleep(30)

            return 2 * sample

        worker_conf = {'WORKER_NUM': 2, 'use_process': True}
        mapped = tf.map(ds, _mapper, worker_conf)

        ct = 0
        for i, d in enumerate(mapped):
            ct += 1
            self.assertTrue(d / 2 in samples)

        self.assertEqual(len(samples), ct)


if __name__ == '__main__':
    logging.basicConfig()
    unittest.main()
@@ -21,6 +21,11 @@ from __future__ import print_function
 import sys
 import six

+if six.PY3:
+    from queue import Empty
+else:
+    from Queue import Empty
+
 import uuid
 import logging
 import signal
@@ -31,15 +36,19 @@ logger = logging.getLogger(__name__)

 class EndSignal(object):
-    def __init__(self, errno=0, errmsg=''):
+    """ signal used to notify worker to exit
+    """
+    def __init__(self, id, errno=0, errmsg=''):
+        self.id = id
         self.errno = errno
         self.errmsg = errmsg

 class ParallelMappedDataset(ProxiedDataset):
     """
-    Transform samples to mapped samples which is similar to 'basic.MappedDataset',
-    but multiple workers (threads or processes) will be used
+    Transform samples to mapped samples which is similar to
+    'basic.MappedDataset', but multiple workers (threads or processes)
+    will be used

     Notes:
         this class is not thread-safe
@@ -58,9 +67,10 @@ class ParallelMappedDataset(ProxiedDataset):
         args.update(worker_args)
         if args['use_process'] and type(args['memsize']) is str:
             assert args['memsize'][-1].lower() == 'g', \
-                "invalid param for memsize[%s], should be ended with 'G' or 'g'" % (args['memsize'])
+                "invalid param for memsize[{}], should " \
+                "be ended with 'G' or 'g'".format(args['memsize'])
             gb = args['memsize'][:-1]
-            args['memsize'] = int(gb) * 1024**3
+            args['memsize'] = int(gb) * 1024 ** 3

         self._worker_args = args
         self._started = False
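
For reference, a 'memsize' of '3g' therefore becomes 3 * 1024 ** 3 = 3221225472 bytes. A standalone sketch of the same parsing rule — 'parse_memsize' is a hypothetical helper, not part of the project:

    def parse_memsize(memsize):
        # accept sizes such as '3g' / '3G' and convert to bytes
        assert memsize[-1].lower() == 'g', \
            "invalid memsize[{}], should end with 'G' or 'g'".format(memsize)
        return int(memsize[:-1]) * 1024 ** 3

    assert parse_memsize('3g') == 3221225472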
@@ -103,22 +113,25 @@
         self._producer.daemon = True

         self._consumers = []
+        self._consumer_endsig = {}
         for i in range(consumer_num):
+            consumer_id = 'consumer-' + id + '-' + str(i)
             p = Worker(
                 target=self._consume,
-                args=('consumer-' + id + '_' + str(i), self._inq, self._outq,
+                args=(consumer_id, self._inq, self._outq,
                       self._mapper))
             self._consumers.append(p)
             p.daemon = True
+            setattr(p, 'id', consumer_id)

         self._epoch = -1
         self._feeding_ev = Event()
         self._produced = 0  # produced sample in self._produce
         self._consumed = 0  # consumed sample in self.next
-        self._stopped_consumers = 0

     def _produce(self, id, source, inq):
         """Fetch data from source and feed it to 'inq' queue"""
+        endsig = EndSignal(id)
         while True:
             self._feeding_ev.wait()
             if self._exit:
@@ -128,32 +141,38 @@
                 self._produced += 1
             except StopIteration:
                 self._feeding_ev.clear()
-                self._feeding_ev.wait()  # wait other guy to wake up me
+                logger.debug("producer[{}] starts new epoch".format(id))
+                self._feeding_ev.wait()
             except Exception as e:
-                msg = "producer[{}] failed with error: {}".format(id, str(e))
-                inq.put(EndSignal(-1, msg))
+                endsig.errno = -1
+                endsig.errmsg = "producer[{}] failed with error: {}" \
+                    .format(id, str(e))
+                inq.put(endsig)
                 break

         logger.debug("producer[{}] exits".format(id))

     def _consume(self, id, inq, outq, mapper):
         """Fetch data from 'inq', process it and put result to 'outq'"""
         if self._worker_args['use_process']:
             # handle SIGTERM signal to exit to prevent print stack frame
             signal.signal(signal.SIGTERM, lambda signum, frame : sys.exit())

+        endsig = EndSignal(id)
         while True:
             sample = inq.get()
             if isinstance(sample, EndSignal):
-                sample.errmsg += "[consumer[{}] exits]".format(id)
-                outq.put(sample)
-                logger.debug("end signal received, " +
-                             "consumer[{}] exits".format(id))
+                endsig.errno = sample.errno
+                endsig.errmsg = "consumer[{}] exits for reason[{}]" \
+                    .format(id, sample.errmsg)
+                outq.put(endsig)
                 break

             try:
                 result = mapper(sample)
                 outq.put(result)
             except Exception as e:
-                msg = 'failed to map consumer[%s], error: {}'.format(str(e), id)
-                outq.put(EndSignal(-1, msg))
+                endsig.errno = -2
+                endsig.errmsg = "consumer[{}] failed to map with error:[{}]" \
+                    .format(id, str(e))
+                outq.put(endsig)
                 break

     def drained(self):
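
The exit handshake above is a chained sentinel ("poison pill") protocol: a consumer that sees an 'EndSignal' posts its own signal to the output queue and stops, and the reader (see 'next()' below) forwards the signal back to the input queue until every consumer has acknowledged. A self-contained sketch of the idea using threads and plain queues — illustrative code, not the actual classes:

    import threading
    from queue import Queue

    SENTINEL = object()

    def consumer(inq, outq):
        while True:
            item = inq.get()
            if item is SENTINEL:
                outq.put(SENTINEL)      # acknowledge shutdown, then exit
                break
            outq.put(item * 2)

    def drain(outq, inq, num_consumers):
        stopped = 0
        while stopped < num_consumers:
            item = outq.get()
            if item is SENTINEL:
                stopped += 1
                if stopped < num_consumers:
                    inq.put(SENTINEL)   # forward the pill to the next consumer
            else:
                yield item

    if __name__ == '__main__':
        inq, outq = Queue(), Queue()
        workers = [threading.Thread(target=consumer, args=(inq, outq))
                   for _ in range(2)]
        for w in workers:
            w.start()
        for x in range(5):
            inq.put(x)
        inq.put(SENTINEL)               # the producer signals end of data
        print(sorted(drain(outq, inq, len(workers))))   # [0, 2, 4, 6, 8]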
@@ -168,6 +187,25 @@
             for _ in range(len(self._consumers)):
                 self._inq.put(EndSignal(0, "notify consumers to exit"))

+    def _consumer_healthy(self):
+        abnormal_num = 0
+        for w in self._consumers:
+            if not w.is_alive() and w.id not in self._consumer_endsig:
+                abnormal_num += 1
+                if self._worker_args['use_process']:
+                    errmsg = "consumer[{}] exit abnormally with exitcode[{}]" \
+                        .format(w.pid, w.exitcode)
+                else:
+                    errmsg = "consumer[{}] exit abnormally".format(w.ident)
+                logger.warn(errmsg)
+
+        if abnormal_num > 0:
+            logger.warn("{} consumers have exited abnormally!!!" \
+                .format(abnormal_num))
+
+        return abnormal_num == 0
+
     def next(self):
         """ get next transformed sample
         """
@@ -177,41 +215,54 @@
         if self.drained():
             raise StopIteration()

-        while True:
-            sample = self._outq.get()
+        while not self._exit:
+            try:
+                sample = self._outq.get(timeout=3)
+            except Empty as e:
+                if not self._consumer_healthy():
+                    raise StopIteration()
+                else:
+                    continue
+
             if isinstance(sample, EndSignal):
-                self._stopped_consumers += 1
-                if sample.errno != 0:
-                    logger.warn("consumer failed with error: {}".format(
-                        sample.errmsg))
+                self._consumer_endsig[sample.id] = sample
+                logger.warn("recv endsignal from outq with errmsg[{}]" \
+                    .format(sample.errmsg))

-                if self._stopped_consumers < len(self._consumers):
+                if len(self._consumer_endsig.keys()) < len(self._consumers):
                     self._inq.put(sample)
                 else:
-                    raise ValueError("all consumers exited, no more samples")
+                    self._exit = True
+                    raise StopIteration("all consumers exited, no more samples")
             else:
                 self._consumed += 1
                 return sample

+        raise StopIteration()
+
     def reset(self):
         """ reset for a new epoch of samples
         """
+        assert not self._exit, "cannot reset for already stopped dataset"
+
         if self._epoch < 0:
             self._epoch = 0
-            for p in self._consumers:
-                p.start()
+            for w in self._consumers:
+                w.start()
             self._producer.start()
         else:
+            assert self._consumer_healthy(), "cannot start another pass of data" \
+                " for some consumers exited abnormally before!!!"
+
             if not self.drained():
-                logger.warn("do not reset before epoch[%d] finishes".format(
-                    self._epoch))
+                logger.warn("reset before epoch[{}] finishes".format(self._epoch))
                 self._produced = self._produced - self._consumed
             else:
                 self._produced = 0

             self._epoch += 1

-        assert self._stopped_consumers == 0, "some consumers already exited," \
+        assert len(self._consumer_endsig.keys()) == 0, "some consumers already exited," \
             + " cannot start another epoch"

         self._source.reset()
@@ -221,9 +272,4 @@
 # FIXME(dengkaipeng): fix me if you have better impliment
 # handle terminate reader process, do not print stack frame
-def _reader_exit(signum, frame):
-    logger.debug("Reader process exit.")
-    sys.exit()
-
-signal.signal(signal.SIGTERM, _reader_exit)
+signal.signal(signal.SIGTERM, lambda signum, frame : sys.exit())
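Both the reader module and '_consume()' now install the same one-line handler, so a worker terminated by its parent exits silently instead of printing a stack trace. A minimal POSIX demo of the effect (standard library only):

    import signal
    import sys
    import time
    from multiprocessing import Process

    def worker():
        signal.signal(signal.SIGTERM, lambda signum, frame: sys.exit())
        while True:       # simulate a long-running reader loop
            time.sleep(1)

    if __name__ == '__main__':
        p = Process(target=worker)
        p.start()
        time.sleep(0.5)
        p.terminate()     # delivers SIGTERM; the handler exits quietly
        p.join()
        print("worker exitcode:", p.exitcode)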
@@ -22,9 +22,11 @@ import six
 if six.PY3:
     import pickle
     from io import BytesIO as StringIO
+    from queue import Empty
 else:
     import cPickle as pickle
     from cStringIO import StringIO
+    from Queue import Empty

 import logging
 import traceback
@@ -87,6 +89,8 @@ class SharedQueue(Queue):
             buff = super(SharedQueue, self).get(**kwargs)
             data = buff.get()
             return pickle.load(StringIO(data))
+        except Empty as e:
+            raise e
         except Exception as e:
             stack_info = traceback.format_exc()
             err_msg = 'failed to get element from SharedQueue '\
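
Order matters here: 'except Empty' must come before the broad 'except Exception' handler that follows it, otherwise a 'get(timeout=...)' timeout would be swallowed and logged as a failure instead of reaching the retry loop in 'next()'. A quick illustration of the clause ordering (plain 'queue.Queue' for brevity):

    from queue import Queue, Empty

    q = Queue()
    try:
        q.get(timeout=0.1)
    except Empty:
        print("timed out, caller may retry")   # this branch runs
    except Exception as e:
        # without the clause above, Empty (a subclass of Exception)
        # would land here and be misreported as a real failure
        print("unexpected failure:", e)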
@@ -316,8 +316,6 @@ class PageAllocator(object):
         start_pos = pos
         flags = ''
         while True:
-            # maybe flags already has some '0' pages,
-            # so just check 'page_num - len(flags)' pages
             flags = self.get_page_status(pos, page_num, ret_flag=True)

             if flags.count('0') == page_num:
@@ -343,10 +341,10 @@ class PageAllocator(object):
             if free_pages == 0:
                 err_msg = 'all pages have been used:%s' % (str(self))
             else:
-                err_msg = 'not found available pages with page_status[%s] '\
-                    'and %d free pages' % (str(page_status), free_pages)
-            err_msg = 'failed to malloc %d pages at pos[%d] for reason[%s] and allocator status[%s]' \
-                % (page_num, pos, err_msg, str(self))
+                err_msg = 'not found enough pages[avail:%d, expect:%d] '\
+                    'with total free pages[%d]' % (page_status[0], page_num, free_pages)
+            err_msg = 'failed to malloc %d pages at pos[%d] for reason[%s] '\
+                'and allocator status[%s]' % (page_num, pos, err_msg, str(self))
             raise MemoryFullError(err_msg)

         self.set_page_status(pos, page_num, '1')