# parallel_map.py
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# function:
#   transform samples in 'source' using 'mapper'

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import logging
import os
import signal
import sys
import threading
import uuid

import six

if six.PY3:
    from queue import Empty
else:
    from Queue import Empty

from .transformer import ProxiedDataset

# module-level logger, named after this module
logger = logging.getLogger(__name__)


class EndSignal(object):
    """ signal used to notify worker to exit

    Instances are passed through the input/output queues in place of a
    sample to tell the receiver that the sender has finished or failed.

    Args:
        id: identifier of the sender (e.g. a producer or consumer name)
        errno: status code, 0 by default; in this file -1 marks a producer
            failure and -2 marks a mapper failure inside a consumer
        errmsg (str): human readable reason for the exit
    """

    def __init__(self, id, errno=0, errmsg=''):
        self.id = id
        self.errno = errno
        self.errmsg = errmsg

    def __repr__(self):
        return 'EndSignal(id=%r, errno=%r, errmsg=%r)' % (
            self.id, self.errno, self.errmsg)


class ParallelMappedDataset(ProxiedDataset):
    """
    Transform samples to mapped samples which is similar to 'basic.MappedDataset',
    but multiple workers (threads or processes) will be used

    One producer thread pulls samples from 'source' into an input queue,
    'worker_num' consumers apply 'mapper' to each sample and push the result
    to an output queue, and 'next()' reads mapped samples from that queue.

    Args:
        source: upstream dataset; its 'next()', 'reset()' and 'drained()'
            methods are called by this class
        mapper (callable): function applied to every sample
        worker_args (dict): options (keys are case-insensitive) overriding
            these defaults:
                bufsize (int): capacity of both queues, default 100
                worker_num (int): number of consumers, default 8
                use_process (bool): use processes instead of threads,
                    default False
                memsize (str|int): handed to SharedQueue in process mode,
                    a string such as '3G' is converted to bytes,
                    default '3G'

    Notes:
        this class is not thread-safe
    """

    def __init__(self, source, mapper, worker_args):
        super(ParallelMappedDataset, self).__init__(source)
        worker_args = {k.lower(): v for k, v in worker_args.items()}

        args = {
            'bufsize': 100,
            'worker_num': 8,
            'use_process': False,
            'memsize': '3G',
        }
        args.update(worker_args)
        if args['use_process'] and isinstance(args['memsize'],
                                              six.string_types):
            memsize = args['memsize']
            # slice instead of index so that an empty string fails the
            # assertion below rather than raising IndexError
            assert memsize[-1:].lower() == 'g', \
                "invalid param for memsize[%s], should be ended with 'G' or 'g'" % (memsize)
            args['memsize'] = int(memsize[:-1]) * 1024 ** 3

        self._worker_args = args
        self._started = False
        self._source = source
        self._mapper = mapper
        self._exit = False
        self._setup()

    def _setup(self):
        """setup input/output queues and workers """
        use_process = self._worker_args.get('use_process', False)
        if use_process and sys.platform == "win32":
            logger.info("Use multi-thread reader instead of "
                        "multi-process reader on Windows.")
            use_process = False
        # Record the mode that is really in effect: '_consume' and
        # '_consumer_healthy' read this flag and must not treat thread
        # workers as processes after the Windows fallback above.
        self._worker_args['use_process'] = use_process

        bufsize = self._worker_args['bufsize']
        if use_process:
            from .shared_queue import SharedQueue as Queue
            from multiprocessing import Process as Worker
            from multiprocessing import Event
            memsize = self._worker_args['memsize']
            self._inq = Queue(bufsize, memsize=memsize)
            self._outq = Queue(bufsize, memsize=memsize)
        else:
            if six.PY3:
                from queue import Queue
            else:
                from Queue import Queue
            from threading import Thread as Worker
            from threading import Event
            self._inq = Queue(bufsize)
            self._outq = Queue(bufsize)

        consumer_num = self._worker_args['worker_num']
        # short random tag shared by the names of this dataset's workers
        tag = str(uuid.uuid4())[-3:]
        self._producer = threading.Thread(
            target=self._produce,
            args=('producer-' + tag, self._source, self._inq))
        self._producer.daemon = True

        self._consumers = []
        # EndSignals received from consumers so far, keyed by consumer id
        self._consumer_endsig = {}
        for i in range(consumer_num):
            consumer_id = 'consumer-' + tag + '-' + str(i)
            p = Worker(
                target=self._consume,
                args=(consumer_id, self._inq, self._outq, self._mapper))
            p.daemon = True
            # remembered so that an EndSignal can be matched to its worker
            p.id = consumer_id
            self._consumers.append(p)

        self._epoch = -1
        self._feeding_ev = Event()
        self._produced = 0  # produced sample in self._produce
        self._consumed = 0  # consumed sample in self.next

    def _produce(self, id, source, inq):
        """Fetch data from source and feed it to 'inq' queue

        Runs in the producer thread until 'stop()' is called or 'source'
        raises an unexpected error, in which case an EndSignal with
        errno -1 is put into 'inq'.
        """
        endsig = EndSignal(id)
        while True:
            self._feeding_ev.wait()
            if self._exit:
                break
            try:
                inq.put(source.next())
                self._produced += 1
            except StopIteration:
                # epoch finished: pause until 'reset()' or 'stop()' sets
                # the event again
                self._feeding_ev.clear()
                self._feeding_ev.wait()  # wait other guy to wake up me
            except Exception as e:
                endsig.errno = -1
                endsig.errmsg = "producer[{}] failed with error: {}".format(
                    id, str(e))
                inq.put(endsig)
                break

    def _consume(self, id, inq, outq, mapper):
        """Fetch data from 'inq', process it and put result to 'outq'

        Runs in each consumer until an EndSignal is read from 'inq' or
        'mapper' raises; in both cases an EndSignal carrying this
        consumer's id is put into 'outq' before the loop ends.
        """
        if self._worker_args['use_process']:
            # handle SIGTERM signal to exit to prevent print stack frame
            signal.signal(signal.SIGTERM, lambda signum, frame: sys.exit())

        endsig = EndSignal(id)
        while True:
            sample = inq.get()
            if isinstance(sample, EndSignal):
                endsig.errno = sample.errno
                endsig.errmsg = "consumer[%s] exits for reason[%s]" % (
                    id, sample.errmsg)
                outq.put(endsig)
                break

            try:
                result = mapper(sample)
                outq.put(result)
            except Exception as e:
                endsig.errno = -2
                endsig.errmsg = 'consumer[%s] failed to map with error:[%s]' % (
                    id, str(e))
                outq.put(endsig)
                break

    def drained(self):
        """ whether all samples of the current epoch have been returned
        """
        assert self._epoch >= 0, "first epoch has not started yet"
        return self._source.drained() and self._produced == self._consumed

    def stop(self):
        """ notify to exit
        """
        self._exit = True
        # wake up the producer so that it can observe 'self._exit'
        self._feeding_ev.set()
        for _ in self._consumers:
            # the text must be passed as 'errmsg'; positionally it would
            # land in 'errno'
            self._inq.put(EndSignal(0, errmsg="notify consumers to exit"))

    def _consumer_healthy(self):
        """ check that no consumer died without sending an EndSignal

        Returns:
            bool: True if every consumer is alive or has already reported
                its exit through an EndSignal
        """
        is_process = self._worker_args['use_process']
        abnormal_num = 0
        for w in self._consumers:
            if w.is_alive() or w.id in self._consumer_endsig:
                continue

            abnormal_num += 1
            if is_process:
                logger.warning(
                    'consumer[%s] exit abnormally with exitcode[%s]',
                    w.pid, w.exitcode)
            else:
                logger.warning('consumer[%s] exit abnormally', w.ident)

        if abnormal_num > 0:
            logger.warning('%d consumers have exited abnormally!!!',
                           abnormal_num)

        return abnormal_num == 0

    def next(self):
        """ get next transformed sample

        Raises:
            StopIteration: if the epoch is drained, the dataset has been
                stopped, or a consumer exited abnormally
            ValueError: if every consumer has sent an EndSignal
        """
        if self._epoch < 0:
            self.reset()

        if self.drained():
            raise StopIteration()

        while not self._exit:
            try:
                sample = self._outq.get(timeout=3)
            except Empty:
                if not self._consumer_healthy():
                    raise StopIteration()
                continue

            if not isinstance(sample, EndSignal):
                self._consumed += 1
                return sample

            self._consumer_endsig[sample.id] = sample
            logger.warning("recv endsignal from outq with errmsg[%s]",
                           sample.errmsg)

            if len(self._consumer_endsig) < len(self._consumers):
                # hand the signal to the input queue so that another
                # consumer reads it and exits as well
                self._inq.put(sample)
            else:
                raise ValueError("all consumers exited, no more samples")

        raise StopIteration()

    def reset(self):
        """ reset for a new epoch of samples
        """
        assert not self._exit, "cannot reset for already stopped dataset"

        if self._epoch < 0:
            # first call: start the workers
            self._epoch = 0
            for w in self._consumers:
                w.start()
            self._producer.start()
        else:
            assert self._consumer_healthy(), "cannot start another pass of data" \
                " for some consumers exited abnormally before!!!"

            if not self.drained():
                logger.warning("reset before epoch[%s] finishes", self._epoch)
                # samples produced but not yet consumed are carried over
                self._produced = self._produced - self._consumed
            else:
                self._produced = 0

            self._epoch += 1

        assert len(self._consumer_endsig) == 0, \
            "some consumers already exited," " cannot start another epoch"

        self._source.reset()
        self._consumed = 0
        self._feeding_ev.set()


# FIXME(dengkaipeng): fix me if you have a better implementation
# handle SIGTERM in the reader process by exiting quietly, so that
# terminating it does not print a stack frame
signal.signal(signal.SIGTERM, lambda signum, frame: sys.exit())