channel.py 29.0 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14
#   Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# pylint: disable=doc-string-missing
B
barriery 已提交
15
from time import time as _time
D
dongdaxiang 已提交
16 17 18 19 20 21 22 23 24 25
import threading
import multiprocessing
import multiprocessing.queues
import sys
if sys.version_info.major == 2:
    import Queue
elif sys.version_info.major == 3:
    import queue as Queue
else:
    raise Exception("Error Python version")
26 27 28 29
import numpy as np
import logging
import enum
import copy
D
dongdaxiang 已提交
30

W
wangjiawei04 已提交
31
_LOGGER = logging.getLogger()  # no name given: this is the root logger, so records use the root logger's level/handlers
B
barrierye 已提交
32

D
dongdaxiang 已提交
33 34 35 36 37 38 39

class ChannelDataEcode(enum.Enum):
    """Status codes attached to a ChannelData.

    ChannelData keeps the integer ``.value`` of a member (not the member
    itself) in its ``ecode`` attribute; ``OK`` marks valid data.
    """
    # success
    OK = 0
    # failures
    TIMEOUT = 1
    NOT_IMPLEMENTED = 2
    TYPE_ERROR = 3
    RPC_PACKAGE_ERROR = 4
    CLIENT_ERROR = 5
    CLOSED_ERROR = 6
    UNKNOW = 7  # (sic) public name kept as-is for compatibility


class ChannelDataType(enum.Enum):
    """Kind of payload held by a ChannelData (stored as the integer ``.value``)."""
    DICT = 0  # payload in ChannelData.dictdata
    CHANNEL_NPDATA = 1  # payload in ChannelData.npdata
    ERROR = 2  # no usable payload; see ecode / error_info


class ChannelData(object):
    """Unit of data passed through a channel between Ops.

    Holds either numpy data (``npdata``), dict data (``dictdata``) or an
    error (``ecode`` / ``error_info``), tagged with ``data_id``.
    ``datatype`` and ``ecode`` are stored as the integer ``.value`` of
    ChannelDataType / ChannelDataEcode members.
    """

    def __init__(self,
                 datatype=None,
                 npdata=None,
                 dictdata=None,
                 data_id=None,
                 ecode=None,
                 error_info=None,
                 client_need_profile=False):
        '''
        There are several ways to use it. Pass the arguments by keyword:
        the positional order is (datatype, npdata, dictdata, data_id,
        ecode, error_info, client_need_profile), so a positional call such
        as ChannelData(ecode, error_info, data_id) would bind ecode to
        datatype.

        1. ChannelData(datatype=ChannelDataType.CHANNEL_NPDATA.value,
                       npdata=npdata, data_id=data_id)
        2. ChannelData(datatype=ChannelDataType.DICT.value,
                       dictdata=dictdata, data_id=data_id)
        3. ChannelData(ecode=ecode, error_info=error_info, data_id=data_id)

        When ``ecode`` is given, ``datatype`` is forced to ERROR and both
        ``data_id`` and ``error_info`` are required (ValueError otherwise).
        Otherwise the payload is validated; on failure the object is still
        built, but with datatype ERROR and the failing ecode / error_info.
        A datatype other than CHANNEL_NPDATA / DICT raises ValueError.

        Protobufs are not pickle-able:
        https://stackoverflow.com/questions/55344376/how-to-import-protobuf-module
        '''
        if ecode is not None:
            if data_id is None or error_info is None:
                raise ValueError("data_id and error_info cannot be None")
            datatype = ChannelDataType.ERROR.value
        else:
            if datatype == ChannelDataType.CHANNEL_NPDATA.value:
                ecode, error_info = ChannelData.check_npdata(npdata)
                if ecode != ChannelDataEcode.OK.value:
                    datatype = ChannelDataType.ERROR.value
                    _LOGGER.error(error_info)
            elif datatype == ChannelDataType.DICT.value:
                ecode, error_info = ChannelData.check_dictdata(dictdata)
                if ecode != ChannelDataEcode.OK.value:
                    datatype = ChannelDataType.ERROR.value
                    _LOGGER.error(error_info)
            else:
                raise ValueError("datatype not match")
        self.datatype = datatype
        self.npdata = npdata
        self.dictdata = dictdata
        self.id = data_id
        self.ecode = ecode
        self.error_info = error_info
        self.client_need_profile = client_need_profile
        self.profile_data_set = set()

    def add_profile(self, profile_set):
        """Merge ``profile_set`` into ``profile_data_set``.

        Also sets ``client_need_profile`` to True if it is currently False.
        """
        if self.client_need_profile is False:
            self.client_need_profile = True
        self.profile_data_set |= profile_set

    @staticmethod
    def check_dictdata(dictdata):
        """Validate dict data: a dict, or a list of dicts (batch data).

        Returns:
            (ecode, error_info): (ChannelDataEcode.OK.value, None) when
            valid, otherwise the TYPE_ERROR value and a message.
        """
        ecode = ChannelDataEcode.OK.value
        error_info = None
        if isinstance(dictdata, list):
            # batch data
            for sample in dictdata:
                if not isinstance(sample, dict):
                    ecode = ChannelDataEcode.TYPE_ERROR.value
                    error_info = "the value of data must " \
                            "be dict, but get {}.".format(type(sample))
                    break
        elif not isinstance(dictdata, dict):
            # batch size = 1
            ecode = ChannelDataEcode.TYPE_ERROR.value
            error_info = "the value of data must " \
                        "be dict, but get {}.".format(type(dictdata))
        return ecode, error_info

    @staticmethod
    def check_batch_npdata(batch):
        """Run check_npdata on every element of ``batch``.

        Returns the first failing (ecode, error_info), or
        (ChannelDataEcode.OK.value, None) if all pass (including empty batch).
        """
        ecode = ChannelDataEcode.OK.value
        error_info = None
        for npdata in batch:
            ecode, error_info = ChannelData.check_npdata(npdata)
            if ecode != ChannelDataEcode.OK.value:
                break
        return ecode, error_info

    @staticmethod
    def check_npdata(npdata):
        """Validate numpy data: a dict whose values are np.ndarray, or a
        list of such dicts (batch data).

        Returns:
            (ecode, error_info): (ChannelDataEcode.OK.value, None) when
            valid, otherwise the TYPE_ERROR value and a message.
        """
        ecode = ChannelDataEcode.OK.value
        error_info = None
        if isinstance(npdata, list):
            # batch data
            for sample in npdata:
                if not isinstance(sample, dict):
                    ecode = ChannelDataEcode.TYPE_ERROR.value
                    error_info = "the value of data must " \
                            "be dict, but get {}.".format(type(sample))
                    break
                for _, value in sample.items():
                    if not isinstance(value, np.ndarray):
                        ecode = ChannelDataEcode.TYPE_ERROR.value
                        error_info = "the value of data must " \
                                "be np.ndarray, but get {}.".format(type(value))
                        return ecode, error_info
        elif isinstance(npdata, dict):
            # batch_size = 1
            for _, value in npdata.items():
                if not isinstance(value, np.ndarray):
                    ecode = ChannelDataEcode.TYPE_ERROR.value
                    error_info = "the value of data must " \
                            "be np.ndarray, but get {}.".format(type(value))
                    break
        else:
            ecode = ChannelDataEcode.TYPE_ERROR.value
            error_info = "the value of data must " \
                    "be dict, but get {}.".format(type(npdata))
        return ecode, error_info

    def parse(self):
        """Return the payload: ``npdata`` for CHANNEL_NPDATA, ``dictdata``
        for DICT.

        Raises:
            TypeError: for any other datatype (e.g. ERROR).
        """
        feed = None
        if self.datatype == ChannelDataType.CHANNEL_NPDATA.value:
            # return narray
            feed = self.npdata
        elif self.datatype == ChannelDataType.DICT.value:
            # return dict
            feed = self.dictdata
        else:
            raise TypeError("Error type({}) in datatype.".format(self.datatype))
        return feed

    def __str__(self):
        return "type[{}], ecode[{}], id[{}]".format(
            ChannelDataType(self.datatype).name, self.ecode, self.id)


class ProcessChannel(object):
    """
    (Process version) The channel used for communication between Ops.

    1. Support multiple different Op feed data (multiple producer)
        Different types of data will be packaged through the data ID
    2. Support multiple different Op fetch data (multiple consumer)
        Only when all types of Ops get the data of the same ID,
        the data will be popped; The Op of the same type will not
        get the data of the same ID.
    3. Function front support timeout param to make auto-batching.

    Note:
    1. The ID of the data in the channel must be different.
    2. The function add_producer() and add_consumer() are not thread safe,
       and can only be called during initialization.

    There are two buffers and one queue in Channel:

        op_A \                                           / op_D
        op_B - a. input_buf -> b. queue -> c. output_buf - op_E
        op_C /                                           \ op_F

    a. In input_buf, the input of multiple predecessor Ops is packed by data ID.
    b. The packed data will be stored in queue.
    c. In order to support multiple successor Ops to retrieve data, output_buf
        maintains the data obtained from queue.

    All shared state lives in objects created by ``manager``; every
    push()/front()/stop() serializes on a multiprocessing.Condition. After
    stop(), blocked or later push()/front() calls raise ChannelStopError.
    """

    def __init__(self, manager, name=None, maxsize=0):
        # For queue multiprocess: after putting an object on
        # an empty queue there may be an infinitesimal delay
        # before the queue's :meth:`~Queue.empty`
        # see more:
        # - https://bugs.python.org/issue18277
        # - https://hg.python.org/cpython/rev/860fc6a2bd21
        self._que = manager.Queue(maxsize=maxsize)
        self._maxsize = maxsize
        self.name = name
        # 0 while running, 1 after stop()
        self._stop = manager.Value('i', 0)

        self._cv = multiprocessing.Condition()

        self._producers = []
        self._pushed_producer_count = manager.dict()  # {data_id: count}
        self._input_buf = manager.dict()  # {data_id: {op_name: data}}

        # cursors are rebased by this amount once base_cursor reaches it
        self._reset_max_cursor = 1000000000000000000
        self._consumer_cursors = manager.dict()  # {op_name: cursor}
        self._cursor_count = manager.dict()  # {cursor: count}
        self._base_cursor = manager.Value('i', 0)
        self._output_buf = manager.list()

    def get_producers(self):
        return self._producers

    def get_consumers(self):
        return self._consumer_cursors.keys()

    def _log(self, info_str):
        return "[{}] {}".format(self.name, info_str)

    def debug(self):
        return self._log("p: {}, c: {}".format(self.get_producers(),
                                               self.get_consumers()))

    def add_producer(self, op_name):
        """ not thread safe, and can only be called during initialization. """
        if op_name in self._producers:
            raise ValueError(
                self._log("producer({}) is already in channel".format(op_name)))
        self._producers.append(op_name)

    def add_consumer(self, op_name):
        """ not thread safe, and can only be called during initialization. """
        if op_name in self._consumer_cursors:
            raise ValueError(
                self._log("consumer({}) is already in channel".format(op_name)))
        self._consumer_cursors[op_name] = 0

        if self._cursor_count.get(0) is None:
            self._cursor_count[0] = 0
        self._cursor_count[0] += 1

    def push(self, channeldata, op_name=None):
        """Push ``channeldata`` produced by ``op_name`` into the channel.

        With one producer, ``{op_name: channeldata}`` is put into the queue
        directly. With several producers the data is buffered in input_buf
        under ``channeldata.id``; once every producer has pushed that id,
        the packed ``{producer_name: channeldata}`` dict is put into the
        queue. Blocks while the queue is full (only when maxsize > 0).

        Returns:
            True once the data has been accepted.
        Raises:
            Exception: no producer registered, or op_name is None while
                there are multiple producers.
            ChannelStopError: the channel was stopped.
        """
        _LOGGER.debug(
            self._log("{} try to push data: {}".format(op_name,
                                                       channeldata.__str__())))
        if len(self._producers) == 0:
            raise Exception(
                self._log(
                    "expected number of producers to be greater than 0, but the it is 0."
                ))
        elif len(self._producers) == 1:
            with self._cv:
                while self._stop.value == 0:
                    try:
                        self._que.put({op_name: channeldata}, timeout=0)
                        break
                    except Queue.Full:
                        self._cv.wait()
                if self._stop.value == 1:
                    raise ChannelStopError()
                _LOGGER.debug(
                    self._log("{} channel size: {}".format(op_name,
                                                           self._que.qsize())))
                self._cv.notify_all()
                _LOGGER.debug(self._log("{} notify all".format(op_name)))
            _LOGGER.debug(self._log("{} push data succ!".format(op_name)))
            return True
        elif op_name is None:
            raise Exception(
                self._log(
                    "There are multiple producers, so op_name cannot be None."))

        producer_num = len(self._producers)
        data_id = channeldata.id
        put_data = None
        with self._cv:
            _LOGGER.debug(self._log("{} get lock".format(op_name)))
            if data_id not in self._input_buf:
                self._input_buf[data_id] = {
                    name: None
                    for name in self._producers
                }
                self._pushed_producer_count[data_id] = 0
            # In-place changes to a value read from a manager dict proxy are
            # not propagated, so read, modify and write the entry back.
            # see: https://docs.python.org/3.6/library/multiprocessing.html?highlight=multiprocess#proxy-objects
            tmp_input_buf = self._input_buf[data_id]
            tmp_input_buf[op_name] = channeldata
            self._input_buf[data_id] = tmp_input_buf

            if self._pushed_producer_count[data_id] + 1 == producer_num:
                put_data = self._input_buf[data_id]
                self._input_buf.pop(data_id)
                self._pushed_producer_count.pop(data_id)
            else:
                self._pushed_producer_count[data_id] += 1

            if put_data is None:
                _LOGGER.debug(
                    self._log("{} push data succ, but not push to queue.".
                              format(op_name)))
            else:
                while self._stop.value == 0:
                    try:
                        _LOGGER.debug(
                            self._log("{} push data succ: {}".format(
                                op_name, put_data.__str__())))
                        self._que.put(put_data, timeout=0)
                        break
                    except Queue.Full:
                        # put(timeout=0) on a full bounded queue raises Full;
                        # wait for a consumer to free a slot and notify.
                        self._cv.wait()
                if self._stop.value == 1:
                    raise ChannelStopError()

                _LOGGER.debug(
                    self._log("multi | {} push data succ!".format(op_name)))
            self._cv.notify_all()
        return True

    def front(self, op_name=None, timeout=None):
        """Get the next data for consumer ``op_name``.

        Args:
            op_name: consumer name; required when there are several
                consumers.
            timeout: seconds to wait for data; None or <= 0 waits forever.
        Returns:
            The packed data dict taken from the queue (treat as read only).
        Raises:
            Exception: no consumer registered, or op_name is None while
                there are multiple consumers.
            ChannelTimeoutError: no data arrived within ``timeout``.
            ChannelStopError: the channel was stopped.
        """
        endtime = None
        if timeout is not None:
            if timeout <= 0:
                timeout = None
            else:
                endtime = _time() + timeout

        _LOGGER.debug(self._log("{} try to get data...".format(op_name)))
        if len(self._consumer_cursors) == 0:
            raise Exception(
                self._log(
                    "expected number of consumers to be greater than 0, but the it is 0."
                ))
        elif len(self._consumer_cursors) == 1:
            resp = None
            with self._cv:
                while self._stop.value == 0 and resp is None:
                    try:
                        _LOGGER.debug(
                            self._log("{} try to get(with channel empty: {})".
                                      format(op_name, self._que.empty())))
                        resp = self._que.get(timeout=0)
                        break
                    except Queue.Empty:
                        if timeout is not None:
                            remaining = endtime - _time()
                            if remaining <= 0.0:
                                raise ChannelTimeoutError()
                            self._cv.wait(remaining)
                        else:
                            self._cv.wait()
                if self._stop.value == 1:
                    raise ChannelStopError()
                # A queue slot was freed: wake producers that may be blocked
                # in push() on a full queue (otherwise they wait forever).
                self._cv.notify_all()
            return resp
        elif op_name is None:
            raise Exception(
                self._log(
                    "There are multiple consumers, so op_name cannot be None."))

        # In output_buf, different Ops (according to op_name) have different
        # cursors. In addition, there is a base_cursor. Their difference is
        # the data_idx to be taken by the corresponding Op at the current
        # time:    data_idx = consumer_cursor - base_cursor
        #
        #            base_cursor    consumer_B_cursor (data_idx: 3)
        #                 |                       |
        # output_buf: | data0 | data1 | data2 | data3 |
        #                 |
        #   consumer_A_cursor (data_idx: 0)
        with self._cv:
            # When the data required by the current Op is not in output_buf,
            # it is necessary to obtain a data from queue and add it to output_buf.
            while self._stop.value == 0 and self._consumer_cursors[
                    op_name] - self._base_cursor.value >= len(self._output_buf):
                _LOGGER.debug(
                    self._log(
                        "({}) B self._consumer_cursors: {}, self._base_cursor: {}, len(self._output_buf): {}".
                        format(op_name, self._consumer_cursors,
                               self._base_cursor.value, len(self._output_buf))))
                try:
                    _LOGGER.debug(
                        self._log("{} try to get(with channel size: {})".format(
                            op_name, self._que.qsize())))
                    channeldata = self._que.get(timeout=0)
                    self._output_buf.append(channeldata)
                    break
                except Queue.Empty:
                    if timeout is not None:
                        remaining = endtime - _time()
                        if remaining <= 0.0:
                            raise ChannelTimeoutError()
                        self._cv.wait(remaining)
                    else:
                        self._cv.wait()
            if self._stop.value == 1:
                raise ChannelStopError()

            consumer_cursor = self._consumer_cursors[op_name]
            base_cursor = self._base_cursor.value
            data_idx = consumer_cursor - base_cursor
            resp = self._output_buf[data_idx]
            _LOGGER.debug(self._log("{} get data: {}".format(op_name, resp)))

            self._cursor_count[consumer_cursor] -= 1
            if consumer_cursor == base_cursor and self._cursor_count[
                    consumer_cursor] == 0:
                # When all the different Ops get the data that data_idx points
                # to, pop the data from output_buf.
                self._cursor_count.pop(consumer_cursor)
                self._output_buf.pop(0)
                self._base_cursor.value += 1
                # to avoid cursor overflow
                if self._base_cursor.value >= self._reset_max_cursor:
                    self._base_cursor.value -= self._reset_max_cursor
                    for name in self._consumer_cursors.keys():
                        self._consumer_cursors[name] -= self._reset_max_cursor
                    cursor_count_tmp = {
                        cursor - self._reset_max_cursor: count
                        for cursor, count in self._cursor_count.copy().items()
                    }
                    self._cursor_count.clear()
                    for cursor, count in cursor_count_tmp.items():
                        self._cursor_count[cursor] = count

            self._consumer_cursors[op_name] += 1
            new_consumer_cursor = self._consumer_cursors[op_name]
            if self._cursor_count.get(new_consumer_cursor) is None:
                self._cursor_count[new_consumer_cursor] = 0
            self._cursor_count[new_consumer_cursor] += 1

            _LOGGER.debug(
                self._log(
                    "({}) A self._consumer_cursors: {}, self._base_cursor: {}, len(self._output_buf): {}".
                    format(op_name, self._consumer_cursors,
                           self._base_cursor.value, len(self._output_buf))))
            _LOGGER.debug(self._log("{} notify all".format(op_name)))
            self._cv.notify_all()

        _LOGGER.debug(self._log("multi | {} get data succ!".format(op_name)))
        return resp  # treat as read only

    def stop(self):
        """Mark the channel stopped and wake every waiter.

        Pending push()/front() loops then raise ChannelStopError.
        """
        _LOGGER.debug(self._log("stop."))
        self._stop.value = 1
        with self._cv:
            self._cv.notify_all()


class ThreadChannel(Queue.Queue):
    """
    (Thread version)The channel used for communication between Ops.

    1. Support multiple different Op feed data (multiple producer)
        Different types of data will be packaged through the data ID
    2. Support multiple different Op fetch data (multiple consumer)
        Only when all types of Ops get the data of the same ID,
        the data will be popped; The Op of the same type will not
        get the data of the same ID.
    3. Function front support timeout param to make auto-batching.

    Note:
    1. The ID of the data in the channel must be different.
    2. The function add_producer() and add_consumer() are not thread safe,
       and can only be called during initialization.

    There are two buffers and one queue in Channel:

        op_A \                                           / op_D
        op_B - a. input_buf -> b. queue -> c. output_buf - op_E
        op_C /                                           \ op_F

    a. In input_buf, the input of multiple predecessor Ops is packed by data ID.
    b. The packed data will be stored in queue.
    c. In order to support multiple successor Ops to retrieve data, output_buf
        maintains the data obtained from queue.

    The queue is this object itself (a Queue.Queue subclass); the default
    maxsize=-1 makes it unbounded. push()/front()/stop() serialize on a
    threading.Condition; after stop(), blocked or later push()/front()
    calls raise ChannelStopError.
    """

    def __init__(self, name=None, maxsize=-1):
        Queue.Queue.__init__(self, maxsize=maxsize)
        self._maxsize = maxsize
        self.name = name
        self._stop = False

        self._cv = threading.Condition()

        self._producers = []
        self._pushed_producer_count = {}  # {data_id: count}
        self._input_buf = {}  # {data_id: {op_name: data}}

        # cursors are rebased by this amount once base_cursor reaches it
        self._reset_max_cursor = 1000000000000000000
        self._consumer_cursors = {}  # {op_name: idx}
        self._cursor_count = {}  # {cursor: count}
        self._base_cursor = 0
        self._output_buf = []

    def get_producers(self):
        return self._producers

    def get_consumers(self):
        return self._consumer_cursors.keys()

    def _log(self, info_str):
        return "[{}] {}".format(self.name, info_str)

    def debug(self):
        return self._log("p: {}, c: {}".format(self.get_producers(),
                                               self.get_consumers()))

    def add_producer(self, op_name):
        """ not thread safe, and can only be called during initialization. """
        if op_name in self._producers:
            raise ValueError(
                self._log("producer({}) is already in channel".format(op_name)))
        self._producers.append(op_name)

    def add_consumer(self, op_name):
        """ not thread safe, and can only be called during initialization. """
        if op_name in self._consumer_cursors:
            raise ValueError(
                self._log("consumer({}) is already in channel".format(op_name)))
        self._consumer_cursors[op_name] = 0

        if self._cursor_count.get(0) is None:
            self._cursor_count[0] = 0
        self._cursor_count[0] += 1

    def push(self, channeldata, op_name=None):
        """Push ``channeldata`` produced by ``op_name`` into the channel.

        With one producer, ``{op_name: channeldata}`` is put into the queue
        directly. With several producers the data is buffered in input_buf
        under ``channeldata.id``; once every producer has pushed that id,
        the packed ``{producer_name: channeldata}`` dict is put into the
        queue. Blocks while the queue is full (only when maxsize > 0).

        Returns:
            True once the data has been accepted.
        Raises:
            Exception: no producer registered, or op_name is None while
                there are multiple producers.
            ChannelStopError: the channel was stopped.
        """
        _LOGGER.debug(
            self._log("{} try to push data: {}".format(op_name,
                                                       channeldata.__str__())))
        if len(self._producers) == 0:
            raise Exception(
                self._log(
                    "expected number of producers to be greater than 0, but the it is 0."
                ))
        elif len(self._producers) == 1:
            with self._cv:
                while self._stop is False:
                    try:
                        self.put({op_name: channeldata}, timeout=0)
                        break
                    except Queue.Full:
                        self._cv.wait()
                if self._stop:
                    raise ChannelStopError()
                self._cv.notify_all()
            _LOGGER.debug(self._log("{} push data succ!".format(op_name)))
            return True
        elif op_name is None:
            raise Exception(
                self._log(
                    "There are multiple producers, so op_name cannot be None."))

        producer_num = len(self._producers)
        data_id = channeldata.id
        put_data = None
        with self._cv:
            _LOGGER.debug(self._log("{} get lock".format(op_name)))
            if data_id not in self._input_buf:
                self._input_buf[data_id] = {
                    name: None
                    for name in self._producers
                }
                self._pushed_producer_count[data_id] = 0
            self._input_buf[data_id][op_name] = channeldata
            if self._pushed_producer_count[data_id] + 1 == producer_num:
                put_data = self._input_buf[data_id]
                self._input_buf.pop(data_id)
                self._pushed_producer_count.pop(data_id)
            else:
                self._pushed_producer_count[data_id] += 1

            if put_data is None:
                _LOGGER.debug(
                    self._log("{} push data succ, but not push to queue.".
                              format(op_name)))
            else:
                while self._stop is False:
                    try:
                        self.put(put_data, timeout=0)
                        break
                    except Queue.Full:
                        # put(timeout=0) on a full bounded queue raises Full;
                        # wait for a consumer to free a slot and notify.
                        self._cv.wait()
                if self._stop:
                    raise ChannelStopError()

                _LOGGER.debug(
                    self._log("multi | {} push data succ!".format(op_name)))
            self._cv.notify_all()
        return True

    def front(self, op_name=None, timeout=None):
        """Get the next data for consumer ``op_name``.

        Args:
            op_name: consumer name; required when there are several
                consumers.
            timeout: seconds to wait for data; None or <= 0 waits forever.
        Returns:
            The packed data dict. With several consumers, every consumer
            except the last one to read an item gets a deep copy.
        Raises:
            Exception: no consumer registered, or op_name is None while
                there are multiple consumers.
            ChannelTimeoutError: no data arrived within ``timeout``.
            ChannelStopError: the channel was stopped.
        """
        endtime = None
        if timeout is not None:
            if timeout <= 0:
                timeout = None
            else:
                endtime = _time() + timeout

        _LOGGER.debug(self._log("{} try to get data".format(op_name)))
        if len(self._consumer_cursors) == 0:
            raise Exception(
                self._log(
                    "expected number of consumers to be greater than 0, but the it is 0."
                ))
        elif len(self._consumer_cursors) == 1:
            resp = None
            with self._cv:
                while self._stop is False and resp is None:
                    try:
                        resp = self.get(timeout=0)
                        break
                    except Queue.Empty:
                        if timeout is not None:
                            remaining = endtime - _time()
                            if remaining <= 0.0:
                                raise ChannelTimeoutError()
                            self._cv.wait(remaining)
                        else:
                            self._cv.wait()
                if self._stop:
                    raise ChannelStopError()
                # A queue slot was freed: wake producers that may be blocked
                # in push() on a full queue (otherwise they wait forever).
                self._cv.notify_all()
            _LOGGER.debug(
                self._log("{} get data succ: {}".format(op_name, resp.__str__(
                ))))
            return resp
        elif op_name is None:
            raise Exception(
                self._log(
                    "There are multiple consumers, so op_name cannot be None."))

        # In output_buf, different Ops (according to op_name) have different
        # cursors. In addition, there is a base_cursor. Their difference is
        # the data_idx to be taken by the corresponding Op at the current
        # time:    data_idx = consumer_cursor - base_cursor
        #
        #            base_cursor    consumer_B_cursor (data_idx: 3)
        #                 |                       |
        # output_buf: | data0 | data1 | data2 | data3 |
        #                 |
        #   consumer_A_cursor (data_idx: 0)
        with self._cv:
            # When the data required by the current Op is not in output_buf,
            # it is necessary to obtain a data from queue and add it to output_buf.
            while self._stop is False and self._consumer_cursors[
                    op_name] - self._base_cursor >= len(self._output_buf):
                try:
                    channeldata = self.get(timeout=0)
                    self._output_buf.append(channeldata)
                    break
                except Queue.Empty:
                    if timeout is not None:
                        remaining = endtime - _time()
                        if remaining <= 0.0:
                            raise ChannelTimeoutError()
                        self._cv.wait(remaining)
                    else:
                        self._cv.wait()
            if self._stop:
                raise ChannelStopError()

            consumer_cursor = self._consumer_cursors[op_name]
            base_cursor = self._base_cursor
            data_idx = consumer_cursor - base_cursor

            resp = None

            self._cursor_count[consumer_cursor] -= 1
            if consumer_cursor == base_cursor and self._cursor_count[
                    consumer_cursor] == 0:
                # When all the different Ops get the data that data_idx points
                # to, pop the data from output_buf.
                self._cursor_count.pop(consumer_cursor)
                resp = self._output_buf.pop(0)
                self._base_cursor += 1
                # to avoid cursor overflow
                if self._base_cursor >= self._reset_max_cursor:
                    self._base_cursor -= self._reset_max_cursor
                    for name in self._consumer_cursors:
                        self._consumer_cursors[name] -= self._reset_max_cursor
                    self._cursor_count = {
                        cursor - self._reset_max_cursor: count
                        for cursor, count in self._cursor_count.items()
                    }
            else:
                # other consumers still need this item: hand out a copy
                resp = copy.deepcopy(self._output_buf[data_idx])
            _LOGGER.debug(self._log("{} get data: {}".format(op_name, resp)))

            self._consumer_cursors[op_name] += 1
            new_consumer_cursor = self._consumer_cursors[op_name]
            if self._cursor_count.get(new_consumer_cursor) is None:
                self._cursor_count[new_consumer_cursor] = 0
            self._cursor_count[new_consumer_cursor] += 1

            self._cv.notify_all()

        _LOGGER.debug(self._log("multi | {} get data succ!".format(op_name)))
        return resp

    def stop(self):
        """Mark the channel stopped and wake every waiter.

        Pending push()/front() loops then raise ChannelStopError.
        """
        _LOGGER.debug(self._log("stop."))
        self._stop = True
        with self._cv:
            self._cv.notify_all()


class ChannelTimeoutError(RuntimeError):
    """Raised by a channel's front() when no data arrives before the timeout."""

    def __init__(self):
        super(ChannelTimeoutError, self).__init__()


class ChannelStopError(RuntimeError):
    """Raised by a channel's push()/front() once the channel has been stopped."""

    def __init__(self):
        super(ChannelStopError, self).__init__()