channel.py 31.6 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14
#   Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# pylint: disable=doc-string-missing
B
barriery 已提交
15
from time import time as _time
D
dongdaxiang 已提交
16 17 18 19 20 21 22 23 24 25
import threading
import multiprocessing
import multiprocessing.queues
import sys
if sys.version_info.major == 2:
    import Queue
elif sys.version_info.major == 3:
    import queue as Queue
else:
    raise Exception("Error Python version")
26 27 28
import numpy as np
import logging
import enum
29
import os
30
import copy
D
dongdaxiang 已提交
31

32
_LOGGER = logging.getLogger(__name__)
B
barrierye 已提交
33

D
dongdaxiang 已提交
34 35 36 37 38 39 40

class ChannelDataEcode(enum.Enum):
    """Status codes stored in ``ChannelData.ecode``.

    The rest of this module compares against the integer ``.value`` of a
    member (for example ``ChannelDataEcode.OK.value``), so both the member
    names and their numeric values must stay stable.
    """

    OK = 0
    TIMEOUT = 1
    NOT_IMPLEMENTED = 2
    TYPE_ERROR = 3
    RPC_PACKAGE_ERROR = 4
    CLIENT_ERROR = 5
    CLOSED_ERROR = 6
    # Spelling kept as-is: renaming this member would break existing callers.
    UNKNOW = 7
D
dongdaxiang 已提交
44 45 46 47 48 49 50 51


class ChannelDataType(enum.Enum):
    """Payload kinds of a ChannelData; code in this file uses the int ``.value``."""
    # Payload is carried in ChannelData.dictdata (validated by check_dictdata).
    DICT = 0
    # Payload is carried in ChannelData.npdata (validated by check_npdata).
    CHANNEL_NPDATA = 1
    # Set by ChannelData when an ecode is supplied or payload validation fails.
    ERROR = 2


52 53 54 55
class ChannelData(object):
    """Unit of data that is passed through a channel.

    An instance carries either a payload (``npdata`` or ``dictdata``,
    selected by ``datatype``) or an error (``ecode`` plus ``error_info``).
    Payloads are validated on construction; a payload that fails validation
    turns the instance into an error instance instead of raising.
    """

    def __init__(self,
                 datatype=None,
                 npdata=None,
                 dictdata=None,
                 data_id=None,
                 ecode=None,
                 error_info=None,
                 client_need_profile=False):
        '''
        There are several ways to use it (keyword arguments are required
        because of the parameter order):

        1. ChannelData(datatype=ChannelDataType.CHANNEL_NPDATA.value,
                       npdata=npdata, data_id=data_id)
        2. ChannelData(datatype=ChannelDataType.DICT.value,
                       dictdata=dictdata, data_id=data_id)
        3. ChannelData(ecode=ecode, error_info=error_info, data_id=data_id)

        Args:
            datatype: a ``ChannelDataType`` value; ignored when ``ecode`` is
                given (the instance then becomes ``ERROR``).
            npdata: payload checked by ``check_npdata``.
            dictdata: payload checked by ``check_dictdata``.
            data_id: identifier stored as ``self.id``; mandatory for errors.
            ecode: a ``ChannelDataEcode`` value describing an error.
            error_info: error description; mandatory when ``ecode`` is given.
            client_need_profile: initial value of the profiling flag.

        The process is terminated with ``os._exit(-1)`` when an error
        instance lacks ``data_id``/``error_info`` or when ``datatype`` is
        neither CHANNEL_NPDATA nor DICT.

        Protobufs are not pickle-able:
        https://stackoverflow.com/questions/55344376/how-to-import-protobuf-module
        '''
        if ecode is not None:
            if data_id is None or error_info is None:
                _LOGGER.critical("Failed to generate ChannelData: data_id"
                                 " and error_info cannot be None")
                os._exit(-1)
            datatype = ChannelDataType.ERROR.value
        else:
            if datatype == ChannelDataType.CHANNEL_NPDATA.value:
                ecode, error_info = ChannelData.check_npdata(npdata)
                if ecode != ChannelDataEcode.OK.value:
                    # Invalid payload: degrade to an error instance.
                    datatype = ChannelDataType.ERROR.value
                    _LOGGER.error("(logid={}) {}".format(data_id, error_info))
            elif datatype == ChannelDataType.DICT.value:
                ecode, error_info = ChannelData.check_dictdata(dictdata)
                if ecode != ChannelDataEcode.OK.value:
                    # Invalid payload: degrade to an error instance.
                    datatype = ChannelDataType.ERROR.value
                    _LOGGER.error("(logid={}) {}".format(data_id, error_info))
            else:
                _LOGGER.critical("(logid={}) datatype not match".format(
                    data_id))
                os._exit(-1)
        self.datatype = datatype
        self.npdata = npdata
        self.dictdata = dictdata
        self.id = data_id
        self.ecode = ecode
        self.error_info = error_info
        self.client_need_profile = client_need_profile
        self.profile_data_set = set()

    def add_profile(self, profile_set):
        """Merge ``profile_set`` into the stored profile data.

        Also switches ``client_need_profile`` on if it was ``False``.
        """
        if self.client_need_profile is False:
            self.client_need_profile = True
        self.profile_data_set |= profile_set

    @staticmethod
    def check_dictdata(dictdata):
        """Validate a dict payload (a dict, or a list of dicts for a batch).

        Returns:
            ``(ecode, error_info)``: ``(OK.value, None)`` on success,
            otherwise ``(TYPE_ERROR.value, message)`` for the first offender.
        """
        # A list is treated as a batch; anything else as a single sample.
        samples = dictdata if isinstance(dictdata, list) else [dictdata]
        for sample in samples:
            if not isinstance(sample, dict):
                return (ChannelDataEcode.TYPE_ERROR.value,
                        "Failed to check data: the type of "
                        "data must be dict, but get {}.".format(type(sample)))
        return ChannelDataEcode.OK.value, None

    @staticmethod
    def check_batch_npdata(batch):
        """Run ``check_npdata`` on every element of ``batch``.

        Returns:
            The ``(ecode, error_info)`` of the first failing element, or
            ``(OK.value, None)`` if all elements pass (or ``batch`` is empty).
        """
        ecode = ChannelDataEcode.OK.value
        error_info = None
        for npdata in batch:
            ecode, error_info = ChannelData.check_npdata(npdata)
            if ecode != ChannelDataEcode.OK.value:
                break
        return ecode, error_info

    @staticmethod
    def check_npdata(npdata):
        """Validate an ndarray payload.

        Accepts a dict whose values are all ``np.ndarray`` (batch size 1) or
        a list of such dicts (batch).

        Returns:
            ``(ecode, error_info)``: ``(OK.value, None)`` on success,
            otherwise ``(TYPE_ERROR.value, message)`` for the first offender.
        """
        if isinstance(npdata, list):
            # batch data
            samples = npdata
        elif isinstance(npdata, dict):
            # batch_size = 1
            samples = [npdata]
        else:
            return (ChannelDataEcode.TYPE_ERROR.value,
                    "Failed to check data: the value of data "
                    "must be dict, but get {}.".format(type(npdata)))

        for sample in samples:
            if not isinstance(sample, dict):
                return (ChannelDataEcode.TYPE_ERROR.value,
                        "Failed to check data: the "
                        "value of data must be dict, but get {}.".format(
                            type(sample)))
            for value in sample.values():
                if not isinstance(value, np.ndarray):
                    return (ChannelDataEcode.TYPE_ERROR.value,
                            "Failed to check data: the"
                            " value of data must be np.ndarray, but get {}.".
                            format(type(value)))
        return ChannelDataEcode.OK.value, None

    def parse(self):
        """Return the payload that matches ``self.datatype``.

        Returns ``self.npdata`` for CHANNEL_NPDATA and ``self.dictdata`` for
        DICT. Any other datatype (including ERROR) terminates the process
        with ``os._exit(-1)``.
        """
        feed = None
        if self.datatype == ChannelDataType.CHANNEL_NPDATA.value:
            # return narray
            feed = self.npdata
        elif self.datatype == ChannelDataType.DICT.value:
            # return dict
            feed = self.dictdata
        else:
            _LOGGER.critical("Failed to parse channeldata: error " \
                    "type({}) in datatype.".format(self.datatype))
            os._exit(-1)
        return feed

    def __str__(self):
        return "type[{}], ecode[{}], id[{}]".format(
            ChannelDataType(self.datatype).name, self.ecode, self.id)


B
barrierye 已提交
189
class ProcessChannel(object):
    r"""
    (Process version) The channel used for communication between Ops.

    1. Support multiple different Op feed data (multiple producer)
        Different types of data will be packaged through the data ID
    2. Support multiple different Op fetch data (multiple consumer)
        Only when all types of Ops get the data of the same ID,
        the data will be popped; The Op of the same type will not
        get the data of the same ID.
    3. Function front support timeout param to make auto-batching.

    Note:
    1. The ID of the data in the channel must be different.
    2. The function add_producer() and add_consumer() are not thread safe,
       and can only be called during initialization.

    There are two buffers and one queue in Channel:

        op_A \                                           / op_D
        op_B - a. input_buf -> b. queue -> c. output_buf - op_E
        op_C /                                           \ op_F

    a. In input_buf, the input of multiple predecessor Ops is packed by data ID.
    b. The packed data will be stored in queue.
    c. In order to support multiple successor Ops to retrieve data, output_buf
        maintains the data obtained from queue.
    """

    def __init__(self, manager, name=None, maxsize=0):
        """Create the channel.

        Args:
            manager: object providing ``Queue``, ``Value``, ``dict`` and
                ``list`` factories for the shared state (presumably a
                ``multiprocessing.Manager()`` - confirm against the caller).
            name: label used as a prefix in log messages.
            maxsize: capacity passed to ``manager.Queue``.
        """
        # For queue multiprocess: after putting an object on
        # an empty queue there may be an infinitesimal delay
        # before the queue's :meth:`~Queue.empty`
        # see more:
        # - https://bugs.python.org/issue18277
        # - https://hg.python.org/cpython/rev/860fc6a2bd21
        self._que = manager.Queue(maxsize=maxsize)
        self._maxsize = maxsize
        self.name = name
        self._stop = manager.Value('i', 0)

        self._cv = multiprocessing.Condition()

        self._producers = []
        self._pushed_producer_count = manager.dict()  # {data_id: count}
        self._input_buf = manager.dict()  # {data_id: {op_name: data}}

        self._reset_max_cursor = 1000000000000000000
        self._consumer_cursors = manager.dict()  # {op_name: cursor}
        self._cursor_count = manager.dict()  # {cursor: count}
        self._base_cursor = manager.Value('i', 0)
        self._output_buf = manager.list()

    def get_maxsize(self):
        """Return the ``maxsize`` the channel was created with."""
        return self._maxsize

    def size(self):
        """Return ``qsize()`` of the internal queue."""
        return self._que.qsize()

    def get_producers(self):
        """Return the list of registered producer names."""
        return self._producers

    def get_consumers(self):
        """Return the registered consumer names (keys of the cursor map)."""
        return self._consumer_cursors.keys()

    def _log(self, info_str):
        """Prefix ``info_str`` with the channel name for logging."""
        return "[{}] {}".format(self.name, info_str)

    def add_producer(self, op_name):
        """ not thread safe, and can only be called during initialization. """
        if op_name in self._producers:
            _LOGGER.critical(
                self._log("Failed to add producer: producer({})" \
                        " is already in channel".format(op_name)))
            os._exit(-1)
        self._producers.append(op_name)
        _LOGGER.debug(self._log("Succ add a producer: {}".format(op_name)))

    def add_consumer(self, op_name):
        """ not thread safe, and can only be called during initialization. """
        if op_name in self._consumer_cursors:
            _LOGGER.critical(
                    self._log("Failed to add consumer: consumer({})" \
                            " is already in channel".format(op_name)))
            os._exit(-1)
        self._consumer_cursors[op_name] = 0

        # Every new consumer starts at cursor 0.
        if self._cursor_count.get(0) is None:
            self._cursor_count[0] = 0
        self._cursor_count[0] += 1
        _LOGGER.debug(self._log("Succ add a consumer: {}".format(op_name)))

    def push(self, channeldata, op_name=None):
        """Feed ``channeldata`` into the channel on behalf of ``op_name``.

        With a single producer the item is queued directly as
        ``{op_name: channeldata}``. With several producers the items are
        collected in the input buffer by ``channeldata.id`` and queued once
        every producer has delivered its part.

        Returns:
            True once the data has been buffered or queued.

        Raises:
            ChannelStopError: if the channel has been stopped.

        The process is terminated with ``os._exit(-1)`` when no producer is
        registered, or when there are several producers and ``op_name`` is
        None.
        """
        _LOGGER.debug(
            self._log("(logid={}) Op({}) Pushing data".format(channeldata.id,
                                                              op_name)))
        if len(self._producers) == 0:
            _LOGGER.critical(
                self._log(
                    "(logid={}) Op({}) Failed to push data: expected number"
                    " of producers to be greater than 0, but the it is 0.".
                    format(channeldata.id, op_name)))
            os._exit(-1)
        elif len(self._producers) == 1:
            with self._cv:
                while self._stop.value == 0:
                    try:
                        self._que.put({op_name: channeldata}, timeout=0)
                        break
                    except Queue.Full:
                        self._cv.wait()
                if self._stop.value == 1:
                    raise ChannelStopError()
                self._cv.notify_all()
            _LOGGER.debug(
                self._log("(logid={}) Op({}) Pushed data into internal queue.".
                          format(channeldata.id, op_name)))
            return True
        elif op_name is None:
            _LOGGER.critical(
                self._log(
                    "(logid={}) Op({}) Failed to push data: there are multiple "
                    "producers, so op_name cannot be None.".format(
                        channeldata.id, op_name)))
            os._exit(-1)

        producer_num = len(self._producers)
        data_id = channeldata.id
        put_data = None
        with self._cv:
            if data_id not in self._input_buf:
                self._input_buf[data_id] = {
                    name: None
                    for name in self._producers
                }
                self._pushed_producer_count[data_id] = 0
            # see: https://docs.python.org/3.6/library/multiprocessing.html?highlight=multiprocess#proxy-objects
            # A nested dict inside a manager dict is not a proxy, so
            # `self._input_buf[data_id][op_name] = channeldata` would be
            # lost; the inner dict must be read, modified and written back.
            tmp_input_buf = self._input_buf[data_id]
            tmp_input_buf[op_name] = channeldata
            self._input_buf[data_id] = tmp_input_buf

            if self._pushed_producer_count[data_id] + 1 == producer_num:
                # Last missing part arrived: the packed item is ready.
                put_data = self._input_buf[data_id]
                self._input_buf.pop(data_id)
                self._pushed_producer_count.pop(data_id)
            else:
                self._pushed_producer_count[data_id] += 1

            if put_data is None:
                _LOGGER.debug(
                    self._log(
                        "(logid={}) Op({}) Pushed data into input_buffer.".
                        format(data_id, op_name)))
            else:
                while self._stop.value == 0:
                    try:
                        self._que.put(put_data, timeout=0)
                        break
                    except Queue.Full:
                        # put() signals a full queue with Queue.Full (not
                        # Queue.Empty); wait until a consumer makes room.
                        self._cv.wait()
                if self._stop.value == 1:
                    raise ChannelStopError()

                _LOGGER.debug(
                    self._log(
                        "(logid={}) Op({}) Pushed data into internal_queue.".
                        format(data_id, op_name)))
            self._cv.notify_all()
        return True

    def front(self, op_name=None, timeout=None):
        """Fetch the next item for consumer ``op_name``.

        Args:
            op_name: consumer name; required when several consumers exist.
            timeout: seconds to wait for data. ``None`` or a value ``<= 0``
                waits without a deadline.

        Returns:
            The dict that was queued by ``push``.

        Raises:
            ChannelTimeoutError: if no data arrived before the deadline.
            ChannelStopError: if the channel has been stopped.

        The process is terminated with ``os._exit(-1)`` when no consumer is
        registered, or when there are several consumers and ``op_name`` is
        None.
        """
        _LOGGER.debug(
            self._log("Op({}) Getting data[?]; timeout(s)={}".format(op_name,
                                                                     timeout)))
        endtime = None
        if timeout is not None:
            if timeout <= 0:
                timeout = None
            else:
                endtime = _time() + timeout

        if len(self._consumer_cursors) == 0:
            _LOGGER.critical(
                self._log(
                    "Op({}) Failed to get data: expected number of consumers to be " \
                            "greater than 0, but the it is 0.".format(op_name)))
            os._exit(-1)
        elif len(self._consumer_cursors) == 1:
            resp = None
            with self._cv:
                while self._stop.value == 0 and resp is None:
                    try:
                        resp = self._que.get(timeout=0)
                        break
                    except Queue.Empty:
                        if timeout is not None:
                            remaining = endtime - _time()
                            if remaining <= 0.0:
                                _LOGGER.debug(
                                    self._log("Op({}) Failed to get data: "
                                              "timeout".format(op_name)))
                                raise ChannelTimeoutError()
                            self._cv.wait(remaining)
                        else:
                            self._cv.wait()
                if self._stop.value == 1:
                    raise ChannelStopError()
                # Wake producers that are waiting for room in a bounded
                # queue; without this they would never be notified.
                self._cv.notify_all()
            # list(...) because dict views are not indexable on Python 3.
            _LOGGER.debug(
                self._log("(logid={}) Op({}) Got data".format(
                    list(resp.values())[0].id, op_name)))
            return resp
        elif op_name is None:
            _LOGGER.critical(
                self._log(
                    "Op({}) Failed to get data: there are multiple consumers, "
                    "so op_name cannot be None.".format(op_name)))
            os._exit(-1)

        # In output_buf, different Ops (according to op_name) have different
        # cursors. In addition, there is a base_cursor. Their difference is
        # the data_idx to be taken by the corresponding Op at the current
        # time:    data_idx = consumer_cursor - base_cursor
        #
        #            base_cursor    consumer_B_cursor (data_idx: 3)
        #                 |                       |
        # output_buf: | data0 | data1 | data2 | data3 |
        #                 |
        #   consumer_A_cursor (data_idx: 0)
        with self._cv:
            # When the data required by the current Op is not in output_buf,
            # it is necessary to obtain a data from queue and add it to output_buf.
            while self._stop.value == 0 and self._consumer_cursors[
                    op_name] - self._base_cursor.value >= len(self._output_buf):
                try:
                    channeldata = self._que.get(timeout=0)
                    self._output_buf.append(channeldata)
                    # list(...) because dict views are not indexable on
                    # Python 3.
                    _LOGGER.debug(
                        self._log(
                            "(logid={}) Op({}) Pop ready item into output_buffer".
                            format(list(channeldata.values())[0].id, op_name)))
                    break
                except Queue.Empty:
                    if timeout is not None:
                        remaining = endtime - _time()
                        if remaining <= 0.0:
                            _LOGGER.debug(
                                self._log("Op({}) Failed to get data: timeout".
                                          format(op_name)))
                            raise ChannelTimeoutError()
                        self._cv.wait(remaining)
                    else:
                        self._cv.wait()
            if self._stop.value == 1:
                raise ChannelStopError()

            consumer_cursor = self._consumer_cursors[op_name]
            base_cursor = self._base_cursor.value
            data_idx = consumer_cursor - base_cursor
            resp = self._output_buf[data_idx]

            self._cursor_count[consumer_cursor] -= 1
            if consumer_cursor == base_cursor and self._cursor_count[
                    consumer_cursor] == 0:
                # When all the different Ops get the data that data_idx points
                # to, pop the data from output_buf.
                self._cursor_count.pop(consumer_cursor)
                self._output_buf.pop(0)
                self._base_cursor.value += 1
                # to avoid cursor overflow
                if self._base_cursor.value >= self._reset_max_cursor:
                    _LOGGER.info(self._log("Reset cursor in Channel"))
                    self._base_cursor.value -= self._reset_max_cursor
                    for name in self._consumer_cursors.keys():
                        self._consumer_cursors[name] -= self._reset_max_cursor
                    cursor_count_tmp = {
                        cursor - self._reset_max_cursor: count
                        for cursor, count in self._cursor_count.copy().items()
                    }
                    self._cursor_count.clear()
                    for cursor, count in cursor_count_tmp.items():
                        self._cursor_count[cursor] = count

            self._consumer_cursors[op_name] += 1
            new_consumer_cursor = self._consumer_cursors[op_name]
            if self._cursor_count.get(new_consumer_cursor) is None:
                self._cursor_count[new_consumer_cursor] = 0
            self._cursor_count[new_consumer_cursor] += 1

            self._cv.notify_all()

        _LOGGER.debug(
            self._log("(logid={}) Op({}) Got data from output_buffer".format(
                list(resp.values())[0].id, op_name)))
        return resp

    def stop(self):
        """Mark the channel as stopped and wake every waiting producer/consumer."""
        _LOGGER.info(self._log("stop."))
        self._stop.value = 1
        with self._cv:
            self._cv.notify_all()
490 491 492 493 494 495 496 497 498 499 500 501


class ThreadChannel(Queue.Queue):
    """ 
    (Thread version)The channel used for communication between Ops.

    1. Support multiple different Op feed data (multiple producer)
        Different types of data will be packaged through the data ID
    2. Support multiple different Op fetch data (multiple consumer)
        Only when all types of Ops get the data of the same ID,
        the data will be poped; The Op of the same type will not
        get the data of the same ID.
B
barriery 已提交
502
    3. Function front support timeout param to make auto-batching.
503 504 505 506 507

    Note:
    1. The ID of the data in the channel must be different.
    2. The function add_producer() and add_consumer() are not thread safe,
       and can only be called during initialization.
B
barrierye 已提交
508 509 510 511 512 513 514 515 516 517 518

    There are two buffers and one queue in Channel:

        op_A \                                           / op_D
        op_B - a. input_buf -> b. queue -> c. output_buf - op_E
        op_C /                                           \ op_F
    
    a. In input_buf, the input of multiple predecessor Ops is packed by data ID.
    b. The packed data will be stored in queue.
    c. In order to support multiple successor Ops to retrieve data, output_buf
        maintains the data obtained from queue.
519 520
    """

B
barriery 已提交
521
    def __init__(self, name=None, maxsize=-1):
522 523 524 525 526 527 528 529
        Queue.Queue.__init__(self, maxsize=maxsize)
        self._maxsize = maxsize
        self.name = name
        self._stop = False

        self._cv = threading.Condition()

        self._producers = []
B
barrierye 已提交
530
        self._pushed_producer_count = {}  # {data_id: count}
B
barrierye 已提交
531
        self._input_buf = {}  # {data_id: {op_name: data}}
532

533
        self._reset_max_cursor = 1000000000000000000
B
barrierye 已提交
534 535 536 537
        self._consumer_cursors = {}  # {op_name: idx}
        self._cursor_count = {}  # {cursor: count}
        self._base_cursor = 0
        self._output_buf = []
538

B
barriery 已提交
539 540 541
    def get_maxsize(self):
        return self._maxsize

B
barriery 已提交
542 543 544
    def size(self):
        return self.qsize()

545 546 547 548
    def get_producers(self):
        return self._producers

    def get_consumers(self):
B
barrierye 已提交
549
        return self._consumer_cursors.keys()
550 551 552 553 554 555 556

    def _log(self, info_str):
        return "[{}] {}".format(self.name, info_str)

    def add_producer(self, op_name):
        """Register ``op_name`` as a producer of this channel.

        Not thread safe, and can only be called during initialization.
        Registering the same name twice terminates the process with
        ``os._exit(-1)``.
        """
        if op_name in self._producers:
            _LOGGER.critical(
                self._log("Failed to add producer: producer({}) is "
                          "already in channel".format(op_name)))
            os._exit(-1)
        self._producers.append(op_name)
        _LOGGER.debug(self._log("Succ add a producer: {}".format(op_name)))
563 564 565

    def add_consumer(self, op_name):
        """Register ``op_name`` as a consumer of this channel.

        Not thread safe, and can only be called during initialization.
        Registering the same name twice terminates the process with
        ``os._exit(-1)``.
        """
        if op_name in self._consumer_cursors:
            _LOGGER.critical(
                self._log("Failed to add consumer: consumer({}) is "
                          "already in channel".format(op_name)))
            os._exit(-1)
        self._consumer_cursors[op_name] = 0

        # Every new consumer starts at cursor 0.
        if self._cursor_count.get(0) is None:
            self._cursor_count[0] = 0
        self._cursor_count[0] += 1
        _LOGGER.debug(self._log("Succ add a consumer: {}".format(op_name)))
577 578

    def push(self, channeldata, op_name=None):
        """Feed ``channeldata`` into the channel on behalf of ``op_name``.

        With a single producer the item is queued directly as
        ``{op_name: channeldata}``. With several producers the items are
        collected in the input buffer by ``channeldata.id`` and queued once
        every producer has delivered its part.

        Returns:
            True once the data has been buffered or queued.

        Raises:
            ChannelStopError: if the channel has been stopped.

        The process is terminated with ``os._exit(-1)`` when no producer is
        registered, or when there are several producers and ``op_name`` is
        None.
        """
        _LOGGER.debug(
            self._log("(logid={}) Op({}) Pushing data".format(channeldata.id,
                                                              op_name)))
        if len(self._producers) == 0:
            _LOGGER.critical(
                self._log(
                    "(logid={}) Op({}) Failed to push data: expected number of "
                    "producers to be greater than 0, but the it is 0.".format(
                        channeldata.id, op_name)))
            os._exit(-1)
        elif len(self._producers) == 1:
            with self._cv:
                while self._stop is False:
                    try:
                        self.put({op_name: channeldata}, timeout=0)
                        break
                    except Queue.Full:
                        self._cv.wait()
                if self._stop:
                    raise ChannelStopError()
                self._cv.notify_all()
            _LOGGER.debug(
                self._log("(logid={}) Op({}) Pushed data into internal_queue.".
                          format(channeldata.id, op_name)))
            return True
        elif op_name is None:
            _LOGGER.critical(
                self._log(
                    "(logid={}) Op({}) Failed to push data: there are multiple"
                    " producers, so op_name cannot be None.".format(
                        channeldata.id, op_name)))
            os._exit(-1)

        producer_num = len(self._producers)
        data_id = channeldata.id
        put_data = None
        with self._cv:
            if data_id not in self._input_buf:
                self._input_buf[data_id] = {
                    name: None
                    for name in self._producers
                }
                self._pushed_producer_count[data_id] = 0
            self._input_buf[data_id][op_name] = channeldata
            if self._pushed_producer_count[data_id] + 1 == producer_num:
                # Last missing part arrived: the packed item is ready.
                put_data = self._input_buf[data_id]
                self._input_buf.pop(data_id)
                self._pushed_producer_count.pop(data_id)
            else:
                self._pushed_producer_count[data_id] += 1

            if put_data is None:
                _LOGGER.debug(
                    self._log(
                        "(logid={}) Op({}) Pushed data into input_buffer.".
                        format(data_id, op_name)))
            else:
                while self._stop is False:
                    try:
                        self.put(put_data, timeout=0)
                        break
                    except Queue.Full:
                        # put() signals a full queue with Queue.Full (not
                        # Queue.Empty); wait until a consumer makes room.
                        self._cv.wait()
                if self._stop:
                    raise ChannelStopError()

                _LOGGER.debug(
                    self._log(
                        "(logid={}) Op({}) Pushed data into internal_queue.".
                        format(data_id, op_name)))
            self._cv.notify_all()
        return True

B
barriery 已提交
652
    def front(self, op_name=None, timeout=None):
        """Fetch the next item for the consumer named ``op_name``.

        With exactly one registered consumer the item is taken directly
        from the internal queue. With several consumers, items are staged
        in ``self._output_buf`` so that every consumer sees every item: a
        consumer that is the last one still pointing at the oldest buffered
        item receives the buffered object itself (which is then dropped
        from the buffer); any other read returns a deep copy.

        Args:
            op_name: Name of the consuming Op. It may be None only when
                exactly one consumer is registered.
            timeout: Maximum number of seconds to wait for data. None, zero
                or a negative value all mean "wait without a deadline".

        Returns:
            The item obtained from the queue / output buffer. It is
            expected to be a mapping whose first value exposes ``id``
            (only used here for logging).

        Raises:
            ChannelTimeoutError: no item became available before the
                deadline.
            ChannelStopError: ``self._stop`` is set when the wait ends.

        Note:
            The process is terminated with ``os._exit(-1)`` when no
            consumer is registered, or when ``op_name`` is None while
            several consumers are registered.
        """
        _LOGGER.debug(
            self._log("Op({}) Getting data[?]; timeout(s)={}".format(op_name,
                                                                     timeout)))
        endtime = None
        if timeout is not None:
            if timeout <= 0:
                # A non-positive timeout is treated as "no deadline".
                timeout = None
            else:
                endtime = _time() + timeout

        if len(self._consumer_cursors) == 0:
            _LOGGER.critical(
                self._log(
                    "Op({}) Failed to get data: expected number of consumers to be "
                    "greater than 0, but the it is 0.".format(op_name)))
            os._exit(-1)
        elif len(self._consumer_cursors) == 1:
            resp = None
            with self._cv:
                while self._stop is False and resp is None:
                    try:
                        resp = self.get(timeout=0)
                        break
                    except Queue.Empty:
                        if timeout is not None:
                            remaining = endtime - _time()
                            if remaining <= 0.0:
                                _LOGGER.debug(
                                    self._log(
                                        "Op({}) Failed to get data: timeout".
                                        format(op_name)))
                                raise ChannelTimeoutError()
                            self._cv.wait(remaining)
                        else:
                            self._cv.wait()
                if self._stop:
                    raise ChannelStopError()
            # dict.values() is a non-indexable view on Python 3, so take the
            # first value through an iterator (works on Python 2 as well).
            _LOGGER.debug(
                self._log("(logid={}) Op({}) Got data".format(
                    next(iter(resp.values())).id, op_name)))
            return resp
        elif op_name is None:
            _LOGGER.critical(
                self._log("Op({}) Failed to get data: there are multiple "
                          "consumers, so op_name cannot be None.".format(
                              op_name)))
            os._exit(-1)

        # In output_buf, different Ops (according to op_name) have different
        # cursors. In addition, there is a base_cursor. Their difference is
        # the data_idx to be taken by the corresponding Op at the current
        # time:    data_idx = consumer_cursor - base_cursor
        #
        #            base_cursor    consumer_B_cursor (data_idx: 3)
        #                 |                       |
        # output_buf: | data0 | data1 | data2 | data3 |
        #                 |
        #   consumer_A_cursor (data_idx: 0)
        with self._cv:
            # When the data required by the current Op is not in output_buf,
            # it is necessary to obtain a data from queue and add it to
            # output_buf.
            while self._stop is False and self._consumer_cursors[
                    op_name] - self._base_cursor >= len(self._output_buf):
                try:
                    channeldata = self.get(timeout=0)
                    self._output_buf.append(channeldata)
                    _LOGGER.debug(
                        self._log(
                            "(logid={}) Op({}) Pop ready item into output_buffer".
                            format(
                                next(iter(channeldata.values())).id,
                                op_name)))
                    break
                except Queue.Empty:
                    if timeout is not None:
                        remaining = endtime - _time()
                        if remaining <= 0.0:
                            _LOGGER.debug(
                                self._log("Op({}) Failed to get data: timeout".
                                          format(op_name)))
                            raise ChannelTimeoutError()
                        self._cv.wait(remaining)
                    else:
                        self._cv.wait()
            if self._stop:
                raise ChannelStopError()

            consumer_cursor = self._consumer_cursors[op_name]
            base_cursor = self._base_cursor
            data_idx = consumer_cursor - base_cursor

            resp = None

            # This consumer is leaving its current cursor position.
            self._cursor_count[consumer_cursor] -= 1
            if consumer_cursor == base_cursor and self._cursor_count[
                    consumer_cursor] == 0:
                # When all the different Ops get the data that data_idx points
                # to, pop the data from output_buf.
                self._cursor_count.pop(consumer_cursor)
                resp = self._output_buf.pop(0)
                self._base_cursor += 1
                # to avoid cursor overflow
                if self._base_cursor >= self._reset_max_cursor:
                    _LOGGER.info(self._log("Reset cursor in Channel"))
                    self._base_cursor -= self._reset_max_cursor
                    for name in self._consumer_cursors:
                        self._consumer_cursors[name] -= self._reset_max_cursor
                    self._cursor_count = {
                        cursor - self._reset_max_cursor: count
                        for cursor, count in self._cursor_count.items()
                    }
            else:
                # Other consumers still need this item: hand out a copy so
                # that they do not share mutable state.
                resp = copy.deepcopy(self._output_buf[data_idx])

            # Advance this consumer and register it at its new position.
            self._consumer_cursors[op_name] += 1
            new_consumer_cursor = self._consumer_cursors[op_name]
            self._cursor_count[new_consumer_cursor] = self._cursor_count.get(
                new_consumer_cursor, 0) + 1

            self._cv.notify_all()

        _LOGGER.debug(
            self._log("(logid={}) Op({}) Got data from output_buffer".format(
                next(iter(resp.values())).id, op_name)))
        return resp
777 778

    def stop(self):
        """Mark the channel as stopped and wake every thread waiting on it.

        Sets ``self._stop`` and then notifies all waiters on ``self._cv``
        so that they can observe the flag.
        """
        _LOGGER.info(self._log("stop."))
        self._stop = True
        condition = self._cv
        with condition:
            # Waiters re-check ``self._stop`` once they are woken up.
            condition.notify_all()

B
barriery 已提交
784

B
barriery 已提交
785 786 787
class ChannelTimeoutError(RuntimeError):
    """Raised when waiting for channel data exceeds the caller's timeout.

    Any positional arguments (typically a message) are forwarded to
    ``RuntimeError``; calling it with no arguments keeps working.
    """

    def __init__(self, *args):
        super(ChannelTimeoutError, self).__init__(*args)
B
barrierye 已提交
788

B
barriery 已提交
789

B
barrierye 已提交
790 791 792
class ChannelStopError(RuntimeError):
    """Raised when a channel operation finds the channel has been stopped.

    Any positional arguments (typically a message) are forwarded to
    ``RuntimeError``; calling it with no arguments keeps working.
    """

    def __init__(self, *args):
        super(ChannelStopError, self).__init__(*args)