提交 a1c4a797 编写于 作者: W wanglong03

Fix hang bug when a consumer in parallel_map.py exits abnormally

上级 182d5092
......@@ -25,6 +25,7 @@ class Dataset(object):
def __init__(self):
self._epoch = -1
self._pos = 0
def __next__(self):
return self.next()
......@@ -33,8 +34,8 @@ class Dataset(object):
return self
def __str__(self):
return "{}(fname:{}, epoch:{:d}, size:{:d}, pos:{:d})".format(
type(self).__name__, self._fname, self._epoch,
return "{}(epoch:{:d}, size:{:d}, pos:{:d})".format(
type(self).__name__, self._epoch,
self.size(), self._pos)
def next(self):
......
......@@ -79,8 +79,8 @@ class RoiDbSource(Dataset):
self._imid2path = None
def __str__(self):
return 'RoiDbSource(fname:%s,epoch:%d,size:%d,pos:%d)' \
% (self._fname, self._epoch, self.size(), self._pos)
return 'RoiDbSource(epoch:%d,size:%d,pos:%d,fname:%s)' \
% (self._epoch, self.size(), self._pos, self._fname)
def next(self):
""" load next sample
......
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import time
import unittest
import sys
import logging
import random
import copy
import set_env
import ppdet.data.transform as tf
from ppdet.data.dataset import Dataset
class MemorySource(Dataset):
    """ memory data source for testing
    """

    def __init__(self, samples):
        super(MemorySource, self).__init__()
        self._epoch = -1
        self._pos = -1
        self._drained = False
        self._samples = samples

    def next(self):
        """Return a deep copy of the next sample, advancing the cursor."""
        # lazily start the first epoch on the first call
        if self._epoch < 0:
            self.reset()

        if self._pos >= self.size():
            self._drained = True
            raise StopIteration("no more data in " + str(self))

        # deep-copy so callers cannot mutate the stored samples
        item = copy.deepcopy(self._samples[self._pos])
        self._pos += 1
        return item

    def reset(self):
        """Begin a new epoch: rewind the cursor and reshuffle samples."""
        self._epoch = 0 if self._epoch < 0 else self._epoch + 1
        self._pos = 0
        self._drained = False
        random.shuffle(self._samples)

    def size(self):
        """Number of samples held by this source."""
        return len(self._samples)

    def drained(self):
        """Whether the current epoch has been fully consumed."""
        assert self._epoch >= 0, "the first epoch has not started yet"
        return self._pos >= self.size()

    def epoch_id(self):
        """Index of the current epoch (-1 before the first reset)."""
        return self._epoch
class TestDataset(unittest.TestCase):
    """Test cases for ppdet.data.dataset
    """

    @classmethod
    def setUpClass(cls):
        """ setup
        """
        pass

    @classmethod
    def tearDownClass(cls):
        """ tearDownClass """
        pass

    def test_next(self):
        """ test next
        """
        data = list(range(10))
        source = MemorySource(data)
        for value in source:
            self.assertTrue(value in data)

    def test_transform_with_abnormal_worker(self):
        """ test dataset transform with abnormally exit process
        """
        data = list(range(1000))
        source = MemorySource(data)

        def _mapper(sample):
            # simulate a worker crashing on one specific sample
            if sample == 3:
                sys.exit(1)
            return 2 * sample

        conf = {'WORKER_NUM': 2, 'use_process': True}
        mapped = tf.map(source, _mapper, conf)

        seen = 0
        for value in mapped:
            seen += 1
            self.assertTrue(value / 2 in data)

        # exactly one sample is lost with the crashed worker
        self.assertEqual(len(data) - 1, seen)

    def test_transform_with_delay_worker(self):
        """ test dataset transform with delayed process
        """
        data = list(range(1000))
        source = MemorySource(data)

        def _mapper(sample):
            # simulate a slow worker stalled on one specific sample
            if sample == 3:
                time.sleep(30)
            return 2 * sample

        conf = {'WORKER_NUM': 2, 'use_process': True}
        mapped = tf.map(source, _mapper, conf)

        seen = 0
        for value in mapped:
            seen += 1
            self.assertTrue(value / 2 in data)

        # a slow worker must not lose any sample
        self.assertEqual(len(data), seen)
# Run the test suite directly: configure root logging, then delegate to unittest.
if __name__ == '__main__':
    logging.basicConfig()
    unittest.main()
......@@ -20,7 +20,13 @@ from __future__ import division
from __future__ import print_function
import sys
import os
import six
if six.PY3:
from queue import Empty
else:
from Queue import Empty
import uuid
import logging
import signal
......@@ -31,7 +37,10 @@ logger = logging.getLogger(__name__)
class EndSignal(object):
def __init__(self, errno=0, errmsg=''):
""" signal used to notify worker to exit
"""
def __init__(self, id, errno=0, errmsg=''):
self.id = id
self.errno = errno
self.errmsg = errmsg
......@@ -99,22 +108,25 @@ class ParallelMappedDataset(ProxiedDataset):
self._producer.daemon = True
self._consumers = []
self._consumer_endsig = {}
for i in range(consumer_num):
consumer_id = 'consumer-' + id + '-' + str(i)
p = Worker(
target=self._consume,
args=('consumer-' + id + '_' + str(i), self._inq, self._outq,
args=(consumer_id, self._inq, self._outq,
self._mapper))
self._consumers.append(p)
p.daemon = True
setattr(p, 'id', consumer_id)
self._epoch = -1
self._feeding_ev = Event()
self._produced = 0 # produced sample in self._produce
self._consumed = 0 # consumed sample in self.next
self._stopped_consumers = 0
def _produce(self, id, source, inq):
"""Fetch data from source and feed it to 'inq' queue"""
endsig = EndSignal(id)
while True:
self._feeding_ev.wait()
if self._exit:
......@@ -125,31 +137,34 @@ class ParallelMappedDataset(ProxiedDataset):
except StopIteration:
self._feeding_ev.clear()
self._feeding_ev.wait() # wait other guy to wake up me
logger.debug("producer[{}] starts new epoch".format(id))
except Exception as e:
msg = "producer[{}] failed with error: {}".format(id, str(e))
inq.put(EndSignal(-1, msg))
endsig.errno = -1
endsig.errmsg = "producer[{}] failed with error: {}".format(id, str(e))
inq.put(endsig)
break
logger.debug("producer[{}] exits".format(id))
def _consume(self, id, inq, outq, mapper):
"""Fetch data from 'inq', process it and put result to 'outq'"""
if self._worker_args['use_process']:
# handle SIGTERM signal to exit to prevent print stack frame
signal.signal(signal.SIGTERM, lambda signum, frame : sys.exit())
endsig = EndSignal(id)
while True:
sample = inq.get()
if isinstance(sample, EndSignal):
sample.errmsg += "[consumer[{}] exits]".format(id)
outq.put(sample)
logger.debug("end signal received, " +
"consumer[{}] exits".format(id))
endsig.errno = sample.errno
endsig.errmsg = "consumer[%s] exits for reason[%s]" % (id, sample.errmsg)
outq.put(endsig)
break
try:
result = mapper(sample)
outq.put(result)
except Exception as e:
msg = 'failed to map consumer[%s], error: {}'.format(str(e), id)
outq.put(EndSignal(-1, msg))
endsig.errno = -2
endsig.errmsg = 'consumer[%s] failed to map with error:[%s]' % (id, str(e))
outq.put(endsig)
break
def drained(self):
......@@ -164,6 +179,26 @@ class ParallelMappedDataset(ProxiedDataset):
for _ in range(len(self._consumers)):
self._inq.put(EndSignal(0, "notify consumers to exit"))
def _consumer_healthy(self):
abnormal_num = 0
is_process = self._worker_args['use_process']
for w in self._consumers:
if is_process:
if not w.is_alive():
if w.id not in self._consumer_endsig:
abnormal_num += 1
logger.warn('consumer[%s] exit abnormally with exitcode[%d]' % (w.pid, w.exitcode))
else:
if not w.is_alive():
if w.id not in self._consumer_endsig:
abnormal_num += 1
logger.warn('consumer[%s] exit abnormally' % (w.ident))
if abnormal_num > 0:
logger.warn('%d consumers have exited abnormally!!!' % (abnormal_num))
return abnormal_num == 0
def next(self):
""" get next transformed sample
"""
......@@ -173,33 +208,46 @@ class ParallelMappedDataset(ProxiedDataset):
if self.drained():
raise StopIteration()
while True:
sample = self._outq.get()
while not self._exit:
try:
sample = self._outq.get(timeout=3)
except Empty as e:
if not self._consumer_healthy():
raise StopIteration()
else:
continue
if isinstance(sample, EndSignal):
self._stopped_consumers += 1
if sample.errno != 0:
logger.warn("consumer failed with error: {}".format(
sample.errmsg))
self._consumer_endsig[sample.id] = sample
logger.warn("recv endsignal from outq with errmsg[{}]".format(sample.errmsg))
if self._stopped_consumers < len(self._consumers):
if len(self._consumer_endsig.keys()) < len(self._consumers):
self._inq.put(sample)
else:
raise ValueError("all consumers exited, no more samples")
else:
self._consumed += 1
return sample
raise StopIteration()
def reset(self):
""" reset for a new epoch of samples
"""
assert not self._exit, "cannot reset for already stopped dataset"
if self._epoch < 0:
self._epoch = 0
for p in self._consumers:
p.start()
for w in self._consumers:
w.start()
self._producer.start()
else:
assert self._consumer_healthy(), "cannot start another pass of data" \
" for some consumers exited abnormally before!!!"
if not self.drained():
logger.warn("do not reset before epoch[%d] finishes".format(
logger.warn("reset before epoch[{}] finishes".format(
self._epoch))
self._produced = self._produced - self._consumed
else:
......@@ -207,7 +255,7 @@ class ParallelMappedDataset(ProxiedDataset):
self._epoch += 1
assert self._stopped_consumers == 0, "some consumers already exited," \
assert len(self._consumer_endsig.keys()) == 0, "some consumers already exited," \
+ " cannot start another epoch"
self._source.reset()
......@@ -217,9 +265,4 @@ class ParallelMappedDataset(ProxiedDataset):
# FIXME(dengkaipeng): fix me if you have better impliment
# handle terminate reader process, do not print stack frame
def _reader_exit(signum, frame):
logger.debug("Reader process exit.")
sys.exit()
signal.signal(signal.SIGTERM, _reader_exit)
signal.signal(signal.SIGTERM, lambda signum, frame : sys.exit())
......@@ -22,9 +22,11 @@ import six
if six.PY3:
import pickle
from io import BytesIO as StringIO
from queue import Empty
else:
import cPickle as pickle
from cStringIO import StringIO
from Queue import Empty
import logging
import traceback
......@@ -87,6 +89,8 @@ class SharedQueue(Queue):
buff = super(SharedQueue, self).get(**kwargs)
data = buff.get()
return pickle.load(StringIO(data))
except Empty as e:
raise e
except Exception as e:
stack_info = traceback.format_exc()
err_msg = 'failed to get element from SharedQueue '\
......
......@@ -316,8 +316,6 @@ class PageAllocator(object):
start_pos = pos
flags = ''
while True:
# maybe flags already has some '0' pages,
# so just check 'page_num - len(flags)' pages
flags = self.get_page_status(
pos, page_num, ret_flag=True)
......@@ -344,8 +342,8 @@ class PageAllocator(object):
if free_pages == 0:
err_msg = 'all pages have been used:%s' % (str(self))
else:
err_msg = 'not found available pages with page_status[%s] '\
'and %d free pages' % (str(page_status), free_pages)
err_msg = 'not found enough pages[avail:%d, expect:%d] '\
'with total free pages[%d]' % (page_status[0], page_num, free_pages)
err_msg = 'failed to malloc %d pages at pos[%d] for reason[%s] and allocator status[%s]' \
% (page_num, pos, err_msg, str(self))
raise MemoryFullError(err_msg)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册