channel.py 31.8 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14
#   Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# pylint: disable=doc-string-missing
B
barriery 已提交
15
from time import time as _time
D
dongdaxiang 已提交
16 17 18 19 20 21 22 23 24 25
import threading
import multiprocessing
import multiprocessing.queues
import sys
if sys.version_info.major == 2:
    import Queue
elif sys.version_info.major == 3:
    import queue as Queue
else:
    raise Exception("Error Python version")
26 27 28
import numpy as np
import logging
import enum
29
import os
30
import copy
D
dongdaxiang 已提交
31

32
_LOGGER = logging.getLogger(__name__)
B
barrierye 已提交
33

D
dongdaxiang 已提交
34 35 36 37 38 39 40

class ChannelDataEcode(enum.Enum):
    """
    Status codes carried in ``ChannelData.ecode``.

    ``OK`` is what the ``ChannelData.check_*`` helpers return when the
    payload passes validation, and ``TYPE_ERROR`` is what they return when
    it does not. The remaining members are not referenced in this part of
    the file.

    Callers compare against ``.value`` (a plain int) rather than the enum
    member itself.
    """
    OK = 0
    TIMEOUT = 1
    NOT_IMPLEMENTED = 2
    TYPE_ERROR = 3
    RPC_PACKAGE_ERROR = 4
    CLIENT_ERROR = 5
    CLOSED_ERROR = 6
    NO_SERVICE = 7
    # NOTE(review): spelled "UNKNOW" (sic); kept because it is part of the
    # public interface.
    UNKNOW = 8
D
dongdaxiang 已提交
45 46 47 48 49 50 51 52


class ChannelDataType(enum.Enum):
    """
    Kind of payload stored in a ``ChannelData``.

    The integer ``.value`` of a member is what gets stored in
    ``ChannelData.datatype``.
    """
    # payload lives in ChannelData.dictdata
    DICT = 0
    # payload lives in ChannelData.npdata
    CHANNEL_NPDATA = 1
    # no usable payload; see ChannelData.ecode / ChannelData.error_info
    ERROR = 2


53 54 55 56
class ChannelData(object):
    """
    A single item travelling through a channel.

    An instance holds either a payload (``npdata`` or ``dictdata``,
    selected by ``datatype``) or an error (``ecode`` + ``error_info``),
    together with the id of the request it belongs to.
    """

    def __init__(self,
                 datatype=None,
                 npdata=None,
                 dictdata=None,
                 data_id=None,
                 ecode=None,
                 error_info=None,
                 client_need_profile=False):
        '''
        There are several ways to use it:
        
        1. ChannelData(ChannelDataType.CHANNEL_NPDATA.value, npdata, data_id)
        2. ChannelData(ChannelDataType.DICT.value, dictdata, data_id)
        3. ChannelData(ecode, error_info, data_id)

        Protobufs are not pickle-able:
        https://stackoverflow.com/questions/55344376/how-to-import-protobuf-module

        Args:
            datatype: a ``ChannelDataType`` value (int). Ignored when
                ``ecode`` is given.
            npdata: payload for ``CHANNEL_NPDATA``; validated by
                ``check_npdata``.
            dictdata: payload for ``DICT``; validated by ``check_dictdata``.
            data_id: id of the data, stored as ``self.id``.
            ecode: a ``ChannelDataEcode`` value (int). When not None the
                instance is an error item and both ``data_id`` and
                ``error_info`` are required.
            error_info: human readable error description.
            client_need_profile: whether profile data should be collected.

        If the payload fails validation the instance is silently turned
        into an error item (``datatype`` becomes ``ERROR``) and the problem
        is logged. An inconsistent call (error item without id/info, or an
        unknown ``datatype``) terminates the process with ``os._exit(-1)``.
        '''
        if ecode is not None:
            if data_id is None or error_info is None:
                _LOGGER.critical("Failed to generate ChannelData: data_id"
                                 " and error_info cannot be None")
                os._exit(-1)
            datatype = ChannelDataType.ERROR.value
        else:
            if datatype == ChannelDataType.CHANNEL_NPDATA.value:
                ecode, error_info = ChannelData.check_npdata(npdata)
                if ecode != ChannelDataEcode.OK.value:
                    datatype = ChannelDataType.ERROR.value
                    _LOGGER.error("(logid={}) {}".format(data_id, error_info))
            elif datatype == ChannelDataType.DICT.value:
                ecode, error_info = ChannelData.check_dictdata(dictdata)
                if ecode != ChannelDataEcode.OK.value:
                    datatype = ChannelDataType.ERROR.value
                    _LOGGER.error("(logid={}) {}".format(data_id, error_info))
            else:
                _LOGGER.critical("(logid={}) datatype not match".format(
                    data_id))
                os._exit(-1)
        self.datatype = datatype
        self.npdata = npdata
        self.dictdata = dictdata
        self.id = data_id
        self.ecode = ecode
        self.error_info = error_info
        self.client_need_profile = client_need_profile
        self.profile_data_set = set()

    def add_profile(self, profile_set):
        """
        Merge ``profile_set`` (a set) into ``self.profile_data_set`` and
        mark the item as needing profile output.
        """
        if self.client_need_profile is False:
            self.client_need_profile = True
        self.profile_data_set |= profile_set

    @staticmethod
    def check_dictdata(dictdata):
        """
        Validate a DICT payload.

        Accepts a dict (batch size 1) or a list of dicts (batch).

        Returns:
            (ecode, error_info): ``(OK.value, None)`` on success, otherwise
            ``(TYPE_ERROR.value, message)`` for the first offending item.
        """
        ecode = ChannelDataEcode.OK.value
        error_info = None
        if isinstance(dictdata, list):
            # batch data
            for sample in dictdata:
                if not isinstance(sample, dict):
                    ecode = ChannelDataEcode.TYPE_ERROR.value
                    error_info = "Failed to check data: the type of " \
                            "data must be dict, but get {}.".format(type(sample))
                    break
        elif not isinstance(dictdata, dict):
            # batch size = 1
            ecode = ChannelDataEcode.TYPE_ERROR.value
            error_info = "Failed to check data: the type of data must " \
                    "be dict, but get {}.".format(type(dictdata))
        return ecode, error_info

    @staticmethod
    def check_batch_npdata(batch):
        """
        Run ``check_npdata`` on every element of ``batch`` and stop at the
        first failure.

        Returns:
            (ecode, error_info) of the first failing element, or
            ``(OK.value, None)`` when all elements pass (or ``batch`` is
            empty).
        """
        ecode = ChannelDataEcode.OK.value
        error_info = None
        for npdata in batch:
            ecode, error_info = ChannelData.check_npdata(npdata)
            if ecode != ChannelDataEcode.OK.value:
                break
        return ecode, error_info

    @staticmethod
    def check_npdata(npdata):
        """
        Validate a CHANNEL_NPDATA payload.

        Accepts a dict whose values are all ``np.ndarray`` (batch size 1)
        or a list of such dicts (batch).

        Returns:
            (ecode, error_info): ``(OK.value, None)`` on success, otherwise
            ``(TYPE_ERROR.value, message)`` for the first offending item.
        """
        ecode = ChannelDataEcode.OK.value
        error_info = None
        if isinstance(npdata, list):
            # batch data
            for sample in npdata:
                if not isinstance(sample, dict):
                    ecode = ChannelDataEcode.TYPE_ERROR.value
                    error_info = "Failed to check data: the " \
                            "value of data must be dict, but get {}.".format(
                                    type(sample))
                    break
                for _, value in sample.items():
                    if not isinstance(value, np.ndarray):
                        ecode = ChannelDataEcode.TYPE_ERROR.value
                        error_info = "Failed to check data: the" \
                                " value of data must be np.ndarray, but get {}.".format(
                                        type(value))
                        return ecode, error_info
        elif isinstance(npdata, dict):
            # batch_size = 1
            for _, value in npdata.items():
                if not isinstance(value, np.ndarray):
                    ecode = ChannelDataEcode.TYPE_ERROR.value
                    error_info = "Failed to check data: the value " \
                            "of data must be np.ndarray, but get {}.".format(
                                    type(value))
                    break
        else:
            ecode = ChannelDataEcode.TYPE_ERROR.value
            error_info = "Failed to check data: the value of data " \
                    "must be dict, but get {}.".format(type(npdata))
        return ecode, error_info

    def parse(self):
        """
        Return the payload that matches ``self.datatype``.

        Returns ``self.npdata`` for CHANNEL_NPDATA and ``self.dictdata``
        for DICT. Any other datatype (including ERROR) is treated as fatal
        and terminates the process with ``os._exit(-1)``.
        """
        feed = None
        if self.datatype == ChannelDataType.CHANNEL_NPDATA.value:
            # return narray
            feed = self.npdata
        elif self.datatype == ChannelDataType.DICT.value:
            # return dict
            feed = self.dictdata
        else:
            _LOGGER.critical("Failed to parse channeldata: error " \
                    "type({}) in datatype.".format(self.datatype))
            os._exit(-1)
        return feed

    def __cmp__(self, other):
        """Python 2 ordering hook: order items by ``id``."""
        if self.id < other.id:
            return -1
        elif self.id == other.id:
            return 0
        else:
            return 1

    def __lt__(self, other):
        """
        Python 3 ordering hook: order items by ``id``.

        Python 3 never calls ``__cmp__``, so without this method instances
        cannot be compared with ``<`` (as ``heapq``/``PriorityQueue`` and
        ``sorted`` do).
        """
        return self.id < other.id

    def __str__(self):
        return "type[{}], ecode[{}], id[{}]".format(
            ChannelDataType(self.datatype).name, self.ecode, self.id)


B
barrierye 已提交
198
class ProcessChannel(object):
    r""" 
    (Process version) The channel used for communication between Ops.

    1. Support multiple different Op feed data (multiple producer)
        Different types of data will be packaged through the data ID
    2. Support multiple different Op fetch data (multiple consumer)
        Only when all types of Ops get the data of the same ID,
        the data will be popped; The Op of the same type will not
        get the data of the same ID.
    3. Function front support timeout param to make auto-batching.

    Note:
    1. The ID of the data in the channel must be different.
    2. The function add_producer() and add_consumer() are not thread safe,
       and can only be called during initialization.

    There are two buffers and one queue in Channel:

        op_A \                                           / op_D
        op_B - a. input_buf -> b. queue -> c. output_buf - op_E
        op_C /                                           \ op_F
    
    a. In input_buf, the input of multiple predecessor Ops is packed by data ID.
    b. The packed data will be stored in queue.
    c. In order to support multiple successor Ops to retrieve data, output_buf
        maintains the data obtained from queue.
    """

    def __init__(self, manager, name=None, maxsize=0):
        """
        Args:
            manager: object providing ``PriorityQueue``, ``Value``, ``dict``
                and ``list`` factories (presumably a multiprocessing
                manager with ``PriorityQueue`` registered - confirm against
                the caller).
            name: channel name, only used as a log prefix.
            maxsize: capacity passed to the internal queue.
        """
        # For queue multiprocess: after putting an object on 
        # an empty queue there may be an infinitessimal delay
        # before the queue's :meth:`~Queue.empty`
        # see more:
        # - https://bugs.python.org/issue18277
        # - https://hg.python.org/cpython/rev/860fc6a2bd21
        # NOTE(review): the items put into this queue are plain dicts, which
        # Python 3 cannot order; confirm a priority queue is really intended.
        self._que = manager.PriorityQueue(maxsize=maxsize)
        self._maxsize = maxsize
        self.name = name
        self._stop = manager.Value('i', 0)  # 0: running, 1: stopped

        # guards every buffer below and is used to wait for queue space/data
        self._cv = multiprocessing.Condition()

        self._producers = []
        self._pushed_producer_count = manager.dict()  # {data_id: count}
        self._input_buf = manager.dict()  # {data_id: {op_name: data}}

        self._reset_max_cursor = 1000000000000000000
        self._consumer_cursors = manager.dict()  # {op_name: cursor}
        self._cursor_count = manager.dict()  # {cursor: count}
        self._base_cursor = manager.Value('i', 0)
        self._output_buf = manager.list()

    def get_maxsize(self):
        """Return the ``maxsize`` given at construction."""
        return self._maxsize

    def size(self):
        """Return the current ``qsize()`` of the internal queue."""
        return self._que.qsize()

    def get_producers(self):
        """Return the list of registered producer names."""
        return self._producers

    def get_consumers(self):
        """Return the names of the registered consumers."""
        return self._consumer_cursors.keys()

    def _log(self, info_str):
        """Prefix ``info_str`` with the channel name for logging."""
        return "[{}] {}".format(self.name, info_str)

    def add_producer(self, op_name):
        """ not thread safe, and can only be called during initialization. """
        if op_name in self._producers:
            _LOGGER.critical(
                self._log("Failed to add producer: producer({})"
                          " is already in channel".format(op_name)))
            os._exit(-1)
        self._producers.append(op_name)
        _LOGGER.debug(self._log("Succ add a producer: {}".format(op_name)))

    def add_consumer(self, op_name):
        """ not thread safe, and can only be called during initialization. """
        if op_name in self._consumer_cursors:
            _LOGGER.critical(
                self._log("Failed to add consumer: consumer({})"
                          " is already in channel".format(op_name)))
            os._exit(-1)
        self._consumer_cursors[op_name] = 0

        # every consumer starts at cursor 0
        if self._cursor_count.get(0) is None:
            self._cursor_count[0] = 0
        self._cursor_count[0] += 1
        _LOGGER.debug(self._log("Succ add a consumer: {}".format(op_name)))

    def push(self, channeldata, op_name=None):
        """
        Put ``channeldata`` produced by ``op_name`` into the channel.

        With a single producer the item goes straight into the queue as
        ``{op_name: channeldata}``. With several producers the item is
        parked in ``_input_buf`` until every producer has pushed data with
        the same id; the packed dict is then moved into the queue.

        Blocks (waiting on the condition variable) while the queue is full.

        Returns:
            True.

        Raises:
            ChannelStopError: the channel was stopped before the data could
                be queued.

        Calling this with no registered producer, or with ``op_name=None``
        when there are several producers, terminates the process.
        """
        _LOGGER.debug(
            self._log("(logid={}) Op({}) Pushing data".format(channeldata.id,
                                                              op_name)))
        if len(self._producers) == 0:
            _LOGGER.critical(
                self._log(
                    "(logid={}) Op({}) Failed to push data: expected number"
                    " of producers to be greater than 0, but the it is 0.".
                    format(channeldata.id, op_name)))
            os._exit(-1)
        elif len(self._producers) == 1:
            with self._cv:
                while self._stop.value == 0:
                    try:
                        self._que.put({op_name: channeldata}, timeout=0)
                        break
                    except Queue.Full:
                        self._cv.wait()
                if self._stop.value == 1:
                    raise ChannelStopError()
                self._cv.notify_all()
            _LOGGER.debug(
                self._log("(logid={}) Op({}) Pushed data into internal queue.".
                          format(channeldata.id, op_name)))
            return True
        elif op_name is None:
            _LOGGER.critical(
                self._log(
                    "(logid={}) Op({}) Failed to push data: there are multiple "
                    "producers, so op_name cannot be None.".format(
                        channeldata.id, op_name)))
            os._exit(-1)

        producer_num = len(self._producers)
        data_id = channeldata.id
        put_data = None
        with self._cv:
            if data_id not in self._input_buf:
                self._input_buf[data_id] = {
                    name: None
                    for name in self._producers
                }
                self._pushed_producer_count[data_id] = 0
            # see: https://docs.python.org/3.6/library/multiprocessing.html?highlight=multiprocess#proxy-objects
            # self._input_buf[data_id][op_name] = channeldata
            # A nested dict fetched from a proxy is a local copy, so it has
            # to be modified and then assigned back.
            tmp_input_buf = self._input_buf[data_id]
            tmp_input_buf[op_name] = channeldata
            self._input_buf[data_id] = tmp_input_buf

            if self._pushed_producer_count[data_id] + 1 == producer_num:
                # last producer for this id: the packed entry is complete
                put_data = self._input_buf[data_id]
                self._input_buf.pop(data_id)
                self._pushed_producer_count.pop(data_id)
            else:
                self._pushed_producer_count[data_id] += 1

            if put_data is None:
                _LOGGER.debug(
                    self._log(
                        "(logid={}) Op({}) Pushed data into input_buffer.".
                        format(data_id, op_name)))
            else:
                while self._stop.value == 0:
                    try:
                        self._que.put(put_data, timeout=0)
                        break
                    except Queue.Full:
                        # put() signals a full queue with Full (not Empty);
                        # wait until a consumer frees a slot.
                        self._cv.wait()
                if self._stop.value == 1:
                    raise ChannelStopError()

                _LOGGER.debug(
                    self._log(
                        "(logid={}) Op({}) Pushed data into internal_queue.".
                        format(data_id, op_name)))
            self._cv.notify_all()
        return True

    def front(self, op_name=None, timeout=None):
        """
        Fetch the next item for consumer ``op_name``.

        Args:
            op_name: consumer name; required when there are several
                consumers.
            timeout: seconds to wait for data. ``None`` or a value <= 0
                means wait without a deadline.

        Returns:
            The dict ``{producer_name: ChannelData}`` that was queued by
            ``push``.

        Raises:
            ChannelTimeoutError: no data arrived before the deadline.
            ChannelStopError: the channel was stopped while waiting.

        Calling this with no registered consumer, or with ``op_name=None``
        when there are several consumers, terminates the process.
        """
        _LOGGER.debug(
            self._log("Op({}) Getting data[?]; timeout(s)={}".format(op_name,
                                                                     timeout)))
        endtime = None
        if timeout is not None:
            if timeout <= 0:
                timeout = None
            else:
                endtime = _time() + timeout

        if len(self._consumer_cursors) == 0:
            _LOGGER.critical(
                self._log(
                    "Op({}) Failed to get data: expected number of consumers to be "
                    "greater than 0, but the it is 0.".format(op_name)))
            os._exit(-1)
        elif len(self._consumer_cursors) == 1:
            resp = None
            with self._cv:
                while self._stop.value == 0 and resp is None:
                    try:
                        resp = self._que.get(timeout=0)
                        break
                    except Queue.Empty:
                        if timeout is not None:
                            remaining = endtime - _time()
                            if remaining <= 0.0:
                                _LOGGER.debug(
                                    self._log("Op({}) Failed to get data: "
                                              "timeout".format(op_name)))
                                raise ChannelTimeoutError()
                            self._cv.wait(remaining)
                        else:
                            self._cv.wait()
                if self._stop.value == 1:
                    raise ChannelStopError()
            # dict views are not indexable on Python 3, hence list(...)
            _LOGGER.debug(
                self._log("(logid={}) Op({}) Got data".format(
                    list(resp.values())[0].id, op_name)))
            return resp
        elif op_name is None:
            _LOGGER.critical(
                self._log(
                    "Op({}) Failed to get data: there are multiple consumers, "
                    "so op_name cannot be None.".format(op_name)))
            os._exit(-1)

        # In output_buf, different Ops (according to op_name) have different
        # cursors. In addition, there is a base_cursor. Their difference is
        # the data_idx to be taken by the corresponding Op at the current
        # time:    data_idx = consumer_cursor - base_cursor
        # 
        #            base_cursor    consumer_B_cursor (data_idx: 3)
        #                 |                       |
        # output_buf: | data0 | data1 | data2 | data3 |
        #                 |
        #   consumer_A_cursor (data_idx: 0)
        with self._cv:
            # When the data required by the current Op is not in output_buf,
            # it is necessary to obtain a data from queue and add it to output_buf.
            while self._stop.value == 0 and self._consumer_cursors[
                    op_name] - self._base_cursor.value >= len(self._output_buf):
                try:
                    channeldata = self._que.get(timeout=0)
                    self._output_buf.append(channeldata)
                    _LOGGER.debug(
                        self._log(
                            "(logid={}) Op({}) Pop ready item into output_buffer".
                            format(list(channeldata.values())[0].id, op_name)))
                    break
                except Queue.Empty:
                    if timeout is not None:
                        remaining = endtime - _time()
                        if remaining <= 0.0:
                            _LOGGER.debug(
                                self._log("Op({}) Failed to get data: timeout".
                                          format(op_name)))
                            raise ChannelTimeoutError()
                        self._cv.wait(remaining)
                    else:
                        self._cv.wait()
            if self._stop.value == 1:
                raise ChannelStopError()

            consumer_cursor = self._consumer_cursors[op_name]
            base_cursor = self._base_cursor.value
            data_idx = consumer_cursor - base_cursor
            resp = self._output_buf[data_idx]

            self._cursor_count[consumer_cursor] -= 1
            if consumer_cursor == base_cursor and self._cursor_count[
                    consumer_cursor] == 0:
                # When all the different Ops get the data that data_idx points
                # to, pop the data from output_buf.
                self._cursor_count.pop(consumer_cursor)
                self._output_buf.pop(0)
                self._base_cursor.value += 1
                # to avoid cursor overflow
                if self._base_cursor.value >= self._reset_max_cursor:
                    _LOGGER.info(self._log("Reset cursor in Channel"))
                    self._base_cursor.value -= self._reset_max_cursor
                    for name in self._consumer_cursors.keys():
                        self._consumer_cursors[name] -= self._reset_max_cursor
                    cursor_count_tmp = {
                        cursor - self._reset_max_cursor: count
                        for cursor, count in self._cursor_count.copy().items()
                    }
                    self._cursor_count.clear()
                    for cursor, count in cursor_count_tmp.items():
                        self._cursor_count[cursor] = count

            # advance this consumer and account for it at its new cursor
            self._consumer_cursors[op_name] += 1
            new_consumer_cursor = self._consumer_cursors[op_name]
            if self._cursor_count.get(new_consumer_cursor) is None:
                self._cursor_count[new_consumer_cursor] = 0
            self._cursor_count[new_consumer_cursor] += 1

            self._cv.notify_all()

        _LOGGER.debug(
            self._log("(logid={}) Op({}) Got data from output_buffer".format(
                list(resp.values())[0].id, op_name)))
        return resp

    def stop(self):
        """
        Mark the channel as stopped and wake every waiter so that blocked
        ``push``/``front`` calls can raise ``ChannelStopError``.
        """
        _LOGGER.info(self._log("stop."))
        self._stop.value = 1
        with self._cv:
            self._cv.notify_all()
499 500


501
class ThreadChannel(Queue.PriorityQueue):
502 503 504 505 506 507 508 509 510
    """ 
    (Thread version)The channel used for communication between Ops.

    1. Support multiple different Op feed data (multiple producer)
        Different types of data will be packaged through the data ID
    2. Support multiple different Op fetch data (multiple consumer)
        Only when all types of Ops get the data of the same ID,
        the data will be popped; The Op of the same type will not
        get the data of the same ID.
B
barriery 已提交
511
    3. Function front support timeout param to make auto-batching.
512 513 514 515 516

    Note:
    1. The ID of the data in the channel must be different.
    2. The function add_producer() and add_consumer() are not thread safe,
       and can only be called during initialization.
B
barrierye 已提交
517 518 519 520 521 522 523 524 525 526 527

    There are two buffers and one queue in Channel:

        op_A \                                           / op_D
        op_B - a. input_buf -> b. queue -> c. output_buf - op_E
        op_C /                                           \ op_F
    
    a. In input_buf, the input of multiple predecessor Ops is packed by data ID.
    b. The packed data will be stored in queue.
    c. In order to support multiple successor Ops to retrieve data, output_buf
        maintains the data obtained from queue.
528 529
    """

B
barriery 已提交
530
    def __init__(self, name=None, maxsize=-1):
531 532 533 534 535 536 537 538
        Queue.Queue.__init__(self, maxsize=maxsize)
        self._maxsize = maxsize
        self.name = name
        self._stop = False

        self._cv = threading.Condition()

        self._producers = []
B
barrierye 已提交
539
        self._pushed_producer_count = {}  # {data_id: count}
B
barrierye 已提交
540
        self._input_buf = {}  # {data_id: {op_name: data}}
541

542
        self._reset_max_cursor = 1000000000000000000
B
barrierye 已提交
543 544 545 546
        self._consumer_cursors = {}  # {op_name: idx}
        self._cursor_count = {}  # {cursor: count}
        self._base_cursor = 0
        self._output_buf = []
547

B
barriery 已提交
548 549 550
    def get_maxsize(self):
        return self._maxsize

B
barriery 已提交
551 552 553
    def size(self):
        return self.qsize()

554 555 556 557
    def get_producers(self):
        return self._producers

    def get_consumers(self):
B
barrierye 已提交
558
        return self._consumer_cursors.keys()
559 560 561 562 563 564 565

    def _log(self, info_str):
        return "[{}] {}".format(self.name, info_str)

    def add_producer(self, op_name):
        """
        Register ``op_name`` as a producer of this channel.

        Not thread safe, and can only be called during initialization.
        Registering the same name twice is treated as fatal and terminates
        the process with ``os._exit(-1)``.
        """
        if op_name in self._producers:
            _LOGGER.critical(
                self._log("Failed to add producer: producer({}) is "
                          "already in channel".format(op_name)))
            os._exit(-1)
        self._producers.append(op_name)
        _LOGGER.debug(self._log("Succ add a producer: {}".format(op_name)))
572 573 574

    def add_consumer(self, op_name):
        """
        Register ``op_name`` as a consumer of this channel.

        Not thread safe, and can only be called during initialization.
        The consumer starts at cursor 0 and the number of consumers sitting
        at cursor 0 is incremented. Registering the same name twice is
        treated as fatal and terminates the process with ``os._exit(-1)``.
        """
        if op_name in self._consumer_cursors:
            _LOGGER.critical(
                self._log("Failed to add consumer: consumer({}) is "
                          "already in channel".format(op_name)))
            os._exit(-1)
        self._consumer_cursors[op_name] = 0

        if self._cursor_count.get(0) is None:
            self._cursor_count[0] = 0
        self._cursor_count[0] += 1
        _LOGGER.debug(self._log("Succ add a consumer: {}".format(op_name)))
586 587

    def push(self, channeldata, op_name=None):
        """
        Put ``channeldata`` produced by ``op_name`` into the channel.

        With a single producer the item goes straight into the queue as
        ``{op_name: channeldata}``. With several producers the item is
        parked in ``_input_buf`` until every producer has pushed data with
        the same id; the packed dict is then moved into the queue.

        Blocks (waiting on the condition variable) while the queue is full.

        Returns:
            True.

        Raises:
            ChannelStopError: the channel was stopped before the data could
                be queued.

        Calling this with no registered producer, or with ``op_name=None``
        when there are several producers, terminates the process.
        """
        _LOGGER.debug(
            self._log("(logid={}) Op({}) Pushing data".format(channeldata.id,
                                                              op_name)))
        if len(self._producers) == 0:
            _LOGGER.critical(
                self._log(
                    "(logid={}) Op({}) Failed to push data: expected number of "
                    "producers to be greater than 0, but the it is 0.".format(
                        channeldata.id, op_name)))
            os._exit(-1)
        elif len(self._producers) == 1:
            with self._cv:
                while self._stop is False:
                    try:
                        self.put({op_name: channeldata}, timeout=0)
                        break
                    except Queue.Full:
                        self._cv.wait()
                if self._stop:
                    raise ChannelStopError()
                self._cv.notify_all()
            _LOGGER.debug(
                self._log("(logid={}) Op({}) Pushed data into internal_queue.".
                          format(channeldata.id, op_name)))
            return True
        elif op_name is None:
            _LOGGER.critical(
                self._log(
                    "(logid={}) Op({}) Failed to push data: there are multiple"
                    " producers, so op_name cannot be None.".format(
                        channeldata.id, op_name)))
            os._exit(-1)

        producer_num = len(self._producers)
        data_id = channeldata.id
        put_data = None
        with self._cv:
            if data_id not in self._input_buf:
                self._input_buf[data_id] = {
                    name: None
                    for name in self._producers
                }
                self._pushed_producer_count[data_id] = 0
            self._input_buf[data_id][op_name] = channeldata
            if self._pushed_producer_count[data_id] + 1 == producer_num:
                # last producer for this id: the packed entry is complete
                put_data = self._input_buf[data_id]
                self._input_buf.pop(data_id)
                self._pushed_producer_count.pop(data_id)
            else:
                self._pushed_producer_count[data_id] += 1

            if put_data is None:
                _LOGGER.debug(
                    self._log(
                        "(logid={}) Op({}) Pushed data into input_buffer.".
                        format(data_id, op_name)))
            else:
                while self._stop is False:
                    try:
                        self.put(put_data, timeout=0)
                        break
                    except Queue.Full:
                        # put() signals a full queue with Full (not Empty);
                        # wait until a consumer frees a slot.
                        self._cv.wait()
                if self._stop:
                    raise ChannelStopError()

                _LOGGER.debug(
                    self._log(
                        "(logid={}) Op({}) Pushed data into internal_queue.".
                        format(data_id, op_name)))
            self._cv.notify_all()
        return True

B
barriery 已提交
661
    def front(self, op_name=None, timeout=None):
        """Fetch the next item for the consumer ``op_name``.

        With exactly one registered consumer the item is taken straight from
        the internal queue. With several consumers, items are staged in
        ``self._output_buf`` and each consumer reads them through its own
        cursor; an item is removed from the buffer once the consumer sitting
        on ``self._base_cursor`` is the last one still pointing at it.

        Args:
            op_name: Name of the consuming Op. May only be None when exactly
                one consumer is registered in ``self._consumer_cursors``.
            timeout: Maximum number of seconds to wait for data. ``None`` or
                a value ``<= 0`` means wait without a deadline.

        Returns:
            The item obtained from the channel. Its first value must expose
            an ``id`` attribute (used for logging). In the multi-consumer
            case every consumer except the one that pops the item from
            ``self._output_buf`` receives a deep copy.

        Raises:
            ChannelTimeoutError: No data became available before the deadline.
            ChannelStopError: ``self._stop`` was set while fetching.

        Note:
            The process is terminated with ``os._exit(-1)`` when no consumer
            is registered, or when ``op_name`` is None although several
            consumers are registered.
        """
        _LOGGER.debug(
            self._log("Op({}) Getting data[?]; timeout(s)={}".format(op_name,
                                                                     timeout)))
        # Normalize the timeout: non-positive values mean "no deadline".
        endtime = None
        if timeout is not None:
            if timeout <= 0:
                timeout = None
            else:
                endtime = _time() + timeout

        if len(self._consumer_cursors) == 0:
            _LOGGER.critical(
                self._log(
                    "Op({}) Failed to get data: expected number of consumers to be "
                    "greater than 0, but the it is 0.".format(op_name)))
            os._exit(-1)
        elif len(self._consumer_cursors) == 1:
            # Single consumer: no fan-out bookkeeping is needed, read the
            # queue directly.
            resp = None
            with self._cv:
                while self._stop is False and resp is None:
                    try:
                        # Non-blocking get; waiting is done on the condition
                        # variable below.
                        resp = self.get(timeout=0)
                        break
                    except Queue.Empty:
                        if timeout is not None:
                            remaining = endtime - _time()
                            if remaining <= 0.0:
                                _LOGGER.debug(
                                    self._log(
                                        "Op({}) Failed to get data: timeout".
                                        format(op_name)))
                                raise ChannelTimeoutError()
                            self._cv.wait(remaining)
                        else:
                            self._cv.wait()
                if self._stop:
                    raise ChannelStopError()
            # list(...) is required: on Python 3 dict.values() returns a view
            # that cannot be indexed.
            _LOGGER.debug(
                self._log("(logid={}) Op({}) Got data".format(
                    list(resp.values())[0].id, op_name)))
            return resp
        elif op_name is None:
            _LOGGER.critical(
                self._log("Op({}) Failed to get data: there are multiple "
                          "consumers, so op_name cannot be None.".format(
                              op_name)))
            os._exit(-1)

        # In output_buf, different Ops (according to op_name) have different
        # cursors. In addition, there is a base_cursor. Their difference is
        # the data_idx to be taken by the corresponding Op at the current
        # time:    data_idx = consumer_cursor - base_cursor
        #
        #            base_cursor    consumer_B_cursor (data_idx: 3)
        #                 |                       |
        # output_buf: | data0 | data1 | data2 | data3 |
        #                 |
        #   consumer_A_cursor (data_idx: 0)
        with self._cv:
            # When the data required by the current Op is not in output_buf,
            # it is necessary to obtain a data from queue and add it to output_buf.
            while self._stop is False and self._consumer_cursors[
                    op_name] - self._base_cursor >= len(self._output_buf):
                try:
                    channeldata = self.get(timeout=0)
                    self._output_buf.append(channeldata)
                    # list(...) keeps the indexing valid on Python 3, where
                    # dict.values() is a non-subscriptable view.
                    _LOGGER.debug(
                        self._log(
                            "(logid={}) Op({}) Pop ready item into output_buffer".
                            format(list(channeldata.values())[0].id, op_name)))
                    break
                except Queue.Empty:
                    if timeout is not None:
                        remaining = endtime - _time()
                        if remaining <= 0.0:
                            _LOGGER.debug(
                                self._log("Op({}) Failed to get data: timeout".
                                          format(op_name)))
                            raise ChannelTimeoutError()
                        self._cv.wait(remaining)
                    else:
                        self._cv.wait()
            if self._stop:
                raise ChannelStopError()

            consumer_cursor = self._consumer_cursors[op_name]
            base_cursor = self._base_cursor
            data_idx = consumer_cursor - base_cursor

            resp = None

            # This consumer is leaving its current cursor position.
            self._cursor_count[consumer_cursor] -= 1
            if consumer_cursor == base_cursor and self._cursor_count[
                    consumer_cursor] == 0:
                # When all the different Ops get the data that data_idx points
                # to, pop the data from output_buf.
                self._cursor_count.pop(consumer_cursor)
                resp = self._output_buf.pop(0)
                self._base_cursor += 1
                # to avoid cursor overflow
                if self._base_cursor >= self._reset_max_cursor:
                    _LOGGER.info(self._log("Reset cursor in Channel"))
                    # Shift every cursor (and the per-cursor counters) down by
                    # the same amount so their relative offsets are preserved.
                    self._base_cursor -= self._reset_max_cursor
                    for name in self._consumer_cursors:
                        self._consumer_cursors[name] -= self._reset_max_cursor
                    self._cursor_count = {
                        cursor - self._reset_max_cursor: count
                        for cursor, count in self._cursor_count.items()
                    }
            else:
                # Other consumers still need this item: hand out a copy and
                # leave the original in output_buf.
                resp = copy.deepcopy(self._output_buf[data_idx])

            # Advance this consumer and register it at its new position.
            self._consumer_cursors[op_name] += 1
            new_consumer_cursor = self._consumer_cursors[op_name]
            if self._cursor_count.get(new_consumer_cursor) is None:
                self._cursor_count[new_consumer_cursor] = 0
            self._cursor_count[new_consumer_cursor] += 1

            self._cv.notify_all()

        # list(...) keeps the indexing valid on Python 3 (see above).
        _LOGGER.debug(
            self._log("(logid={}) Op({}) Got data from output_buffer".format(
                list(resp.values())[0].id, op_name)))
        return resp
786 787

    def stop(self):
        """Flag the channel as stopped and wake every thread waiting on it.

        Sets ``self._stop`` and then notifies all waiters of ``self._cv`` so
        that loops which re-check ``self._stop`` after waking can exit.
        """
        stop_message = self._log("stop.")
        _LOGGER.info(stop_message)
        self._stop = True
        # The notification is issued while holding the condition's lock.
        condition = self._cv
        with condition:
            condition.notify_all()

B
barriery 已提交
793

B
barriery 已提交
794 795 796
class ChannelTimeoutError(RuntimeError):
    """Raised when no channel data becomes available before the deadline."""

    def __init__(self):
        # Takes no message; initialise the base exception with empty args.
        super(ChannelTimeoutError, self).__init__()
B
barrierye 已提交
797

B
barriery 已提交
798

B
barrierye 已提交
799 800 801
class ChannelStopError(RuntimeError):
    """Raised when an operation is attempted on a channel that was stopped."""

    def __init__(self):
        # Takes no message; initialise the base exception with empty args.
        super(ChannelStopError, self).__init__()