# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# function:
#   transform samples in 'source' using 'worker'

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import sys
import uuid
import logging
import signal
import threading
import traceback

import six
if six.PY3:
    from queue import Empty
else:
    from Queue import Empty

logger = logging.getLogger(__name__)

# pid of the process that imported this module; the SIGINT handler compares
# against it so that only this process signals the registered workers
main_pid = os.getpid()

# multiprocessing consumers created by ParallelMap (only when use_process is
# enabled); iterated by the SIGINT handler to forward the interrupt to them
worker_set = set()


class EndSignal(object):
    """ signal used to notify worker to exit

    Instances travel through the input/output queues alongside ordinary
    samples; receivers detect them with ``isinstance``.

    Args:
        id: identifier of the sender, e.g. 'producer-xyz' or
            'consumer-xyz-0'.
        errno (int): exit code; 0 by default, the producer uses -1 and the
            consumers use -2 when they fail.
        errmsg (str): human readable reason for the exit.
    """

    def __init__(self, id, errno=0, errmsg=''):
        self.id = id
        self.errno = errno
        self.errmsg = errmsg

    def __repr__(self):
        return "{}(id={!r}, errno={!r}, errmsg={!r})".format(
            type(self).__name__, self.id, self.errno, self.errmsg)


class ParallelMap(object):
    """
    Transform samples to mapped samples which is similar to
    'basic.MappedDataset', but multiple workers (threads or processes)
    will be used

    One producer thread pulls samples from ``source`` into an input queue;
    ``worker_num`` consumers (threads, or processes when ``use_process`` is
    enabled) apply ``worker`` to each sample and push the results to an
    output queue which ``next()`` reads from.

    Args:
        source: upstream dataset; must provide ``next()``, ``reset()`` and
            ``drained()``.
        worker (callable): mapping applied to every sample.
        worker_num (int): number of consumer workers.
        bufsize (int): capacity of the input and the output queue.
        use_process (bool): use ``multiprocessing`` workers with
            ``SharedQueue`` queues instead of threads. Always disabled on
            Windows.
        memsize (str|int): passed as ``memsize`` to each ``SharedQueue``
            when ``use_process`` is enabled; either a string ending with
            'G'/'g' or 'M'/'m' (e.g. '3G', '512M', '1.5G') or a number
            that is forwarded unchanged.

    Notes:
        this class is not thread-safe
    """

    def __init__(self,
                 source,
                 worker,
                 worker_num,
                 bufsize=100,
                 use_process=False,
                 memsize='3G'):
        self._worker_num = worker_num
        self._bufsize = bufsize
        self._use_process = use_process
        if self._use_process and sys.platform == "win32":
            logger.debug("Use multi-thread reader instead of "
                         "multi-process reader on Windows.")
            self._use_process = False

        # Non-string values are forwarded unchanged and assumed to already
        # be a byte count (the unit the string form is converted to). This
        # guarantees _setup() always finds self._memsize.
        self._memsize = memsize
        if self._use_process and isinstance(memsize, six.string_types):
            assert memsize and memsize[-1].lower() in ['g', 'm'], \
                "invalid param for memsize[%s], should be " \
                "ended with 'G' or 'g' or 'M' or 'm'" % (memsize)
            power = 3 if memsize[-1].lower() == 'g' else 2
            # float() also accepts fractional sizes such as '1.5G'
            self._memsize = int(float(memsize[:-1]) * (1024**power))

        self._started = False
        self._source = source
        self._worker = worker
        self._exit = False
        self._setup()
        # (sic) misspelled attribute name kept for compatibility; it is
        # written by _produce() and reset() but not read within this class
        self._souce_drained = False

    def __iter__(self):
        return self

    def __next__(self):
        return self.next()

    def _setup(self):
        """setup input/output queues and workers """
        use_process = self._use_process
        bufsize = self._bufsize
        if use_process:
            from .shared_queue import SharedQueue as Queue
            from multiprocessing import Process as Worker
            from multiprocessing import Event
            memsize = self._memsize
            self._inq = Queue(bufsize, memsize=memsize)
            self._outq = Queue(bufsize, memsize=memsize)
        else:
            if six.PY3:
                from queue import Queue
            else:
                from Queue import Queue
            from threading import Thread as Worker
            from threading import Event
            self._inq = Queue(bufsize)
            self._outq = Queue(bufsize)

        consumer_num = self._worker_num
        uid = str(uuid.uuid4())[-3:]
        # the producer is always a thread of this process; only the
        # consumers may be separate processes
        self._producer = threading.Thread(
            target=self._produce,
            args=('producer-' + uid, self._source, self._inq))
        self._producer.daemon = True

        self._consumers = []
        self._consumer_endsig = {}
        for i in range(consumer_num):
            consumer_id = 'consumer-' + uid + '-' + str(i)
            p = Worker(
                target=self._consume,
                args=(consumer_id, self._inq, self._outq, self._worker))
            self._consumers.append(p)
            p.daemon = True
            p.id = consumer_id
            if use_process:
                # registered so the module-level SIGINT handler can reach it
                worker_set.add(p)

        self._epoch = -1
        self._feeding_ev = Event()
        self._produced = 0  # produced sample in self._produce
        self._consumed = 0  # consumed sample in self.next

    def _produce(self, id, source, inq):
        """Fetch data from source and feed it to 'inq' queue

        Runs in the producer thread. When the source is exhausted it pauses
        on ``_feeding_ev`` until ``reset()`` or ``stop()`` sets it again. On
        any other error it sends one ``EndSignal`` with errno -1 and exits.
        """
        endsig = EndSignal(id)
        while True:
            self._feeding_ev.wait()
            if self._exit:
                break
            try:
                s = source.next()
                inq.put(s)
                self._produced += 1
            except StopIteration:
                self._souce_drained = True
                self._feeding_ev.clear()
                self._feeding_ev.wait()
            except Exception as e:
                endsig.errno = -1
                endsig.errmsg = "producer[{}] failed with error: {}" \
                    .format(id, str(e))
                inq.put(endsig)
                break

    def _consume(self, id, inq, outq, worker):
        """Fetch data from 'inq', process it and put result to 'outq'

        Exits after forwarding an ``EndSignal`` to ``outq``: either one
        received from ``inq``, or one with errno -2 when ``worker`` raises.
        """
        if self._use_process:
            # handle SIGTERM signal to exit to prevent print stack frame
            signal.signal(signal.SIGTERM, lambda signum, frame: sys.exit())

        endsig = EndSignal(id)
        while True:
            sample = inq.get()
            if isinstance(sample, EndSignal):
                endsig.errno = sample.errno
                endsig.errmsg = "consumer[{}] exits for reason[{}]" \
                    .format(id, sample.errmsg)
                outq.put(endsig)
                break

            try:
                result = worker(sample)
                outq.put(result)
            except Exception as e:
                endsig.errno = -2
                endsig.errmsg = "consumer[{}] failed to map with error:[{}]" \
                    .format(id, str(e))
                outq.put(endsig)
                break

    def drained(self):
        """Return True once the source is drained and every produced
        sample has been returned by ``next()``."""
        assert self._epoch >= 0, "first epoch has not started yet"
        return self._source.drained() and self._produced == self._consumed

    def stop(self):
        """ notify to exit

        Wakes the producer so it observes ``_exit`` and sends one
        ``EndSignal`` per consumer through the input queue.
        """
        self._exit = True
        self._feeding_ev.set()
        for _ in range(len(self._consumers)):
            # the message belongs in errmsg; passing it positionally used
            # to put it in errno
            self._inq.put(EndSignal(0, errmsg="notify consumers to exit"))

    def _consumer_healthy(self):
        """Return False if any consumer died without sending an EndSignal,
        logging a warning for each such consumer."""
        abnormal_num = 0
        for w in self._consumers:
            if not w.is_alive() and w.id not in self._consumer_endsig:
                abnormal_num += 1
                if self._use_process:
                    errmsg = "consumer[{}] exit abnormally with exitcode[{}]" \
                        .format(w.pid, w.exitcode)
                else:
                    errmsg = "consumer[{}] exit abnormally".format(w.ident)

                logger.warning(errmsg)

        if abnormal_num > 0:
            logger.warning("{} consumers have exited abnormally!!!".format(
                abnormal_num))

        return abnormal_num == 0

    def next(self):
        """ get next transformed sample

        Starts the workers on first use. Raises StopIteration when the
        epoch is drained, after ``stop()``, when a consumer died abnormally,
        or once every consumer has sent an EndSignal.
        """
        if self._epoch < 0:
            self.reset()

        if self.drained():
            raise StopIteration()

        while not self._exit:
            try:
                sample = self._outq.get(timeout=3)
            except Empty:
                # nothing ready yet: keep waiting unless a consumer died
                if not self._consumer_healthy():
                    raise StopIteration()
                else:
                    continue

            if isinstance(sample, EndSignal):
                self._consumer_endsig[sample.id] = sample
                logger.warning("recv endsignal from outq with errmsg[{}]"
                               .format(sample.errmsg))

                if len(self._consumer_endsig) < len(self._consumers):
                    # pass the signal on so the remaining consumers exit too
                    self._inq.put(sample)
                else:
                    self._exit = True
                    raise StopIteration("all consumers exited, no more samples")
            else:
                self._consumed += 1
                return sample

        raise StopIteration()

    def reset(self):
        """ reset for a new epoch of samples

        The first call starts the consumers and the producer. Later calls
        carry over samples still in flight when the previous epoch was not
        fully consumed.
        """
        assert not self._exit, "cannot reset for already stopped dataset"

        if self._epoch < 0:
            self._epoch = 0
            for w in self._consumers:
                w.start()
            self._producer.start()
        else:
            assert self._consumer_healthy(), "cannot start another pass of data" \
                " for some consumers exited abnormally before!!!"

            if not self.drained():
                logger.warning("reset before epoch[{}] finishes".format(
                    self._epoch))
                # samples produced but not yet consumed remain in the queues
                self._produced = self._produced - self._consumed
            else:
                self._produced = 0

            self._epoch += 1

        assert not self._consumer_endsig, \
            "some consumers already exited, cannot start another epoch"

        self._source.reset()
        self._souce_drained = False
        self._consumed = 0
        self._feeding_ev.set()
        self._feeding_ev.set()


# FIXME: fix me if you have a better implementation
# Handle SIGTERM by exiting through sys.exit() so that terminated reader
# processes do not print a stack frame.
signal.signal(signal.SIGTERM, lambda signum, frame: sys.exit())


# FIXME(dkp): KeyboardInterrupt should be handled inside ParallelMap
# and do such as: 1. exit workers 2. close queues 3. release shared
# memory, HACK KeyboardInterrupt with global signal.SIGINT handler
# here, should be refined later
def _term_workers(sig_num, frame):
    """SIGINT handler: forward SIGINT to registered worker processes and exit.

    Only acts in the process whose pid equals ``main_pid``; in any other
    process (e.g. a worker subprocess) it returns without doing anything.

    Args:
        sig_num (int): received signal number (unused).
        frame: current stack frame (unused).

    Raises:
        SystemExit: always, when called in the main process.
    """
    # only do subprocess killing in main process
    if os.getpid() != main_pid:
        return

    logger.info("KeyboardInterrupt: main proc {} exit, kill subprocess {}"
                .format(os.getpid(), [w.pid for w in worker_set]))
    for w in worker_set:
        if w.pid is None:
            # process was never started, nothing to signal
            continue
        try:
            os.kill(w.pid, signal.SIGINT)
        except OSError as e:
            # the worker may already have exited; keep signalling the rest
            # and still exit instead of raising out of the signal handler
            logger.warning("failed to send SIGINT to subprocess {}: {}"
                           .format(w.pid, e))
    sys.exit()


signal.signal(signal.SIGINT, _term_workers)