# channel.py
#   Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# pylint: disable=doc-string-missing
B
barriery 已提交
15
from time import time as _time
D
dongdaxiang 已提交
16 17 18 19 20 21 22 23 24 25
import threading
import multiprocessing
import multiprocessing.queues
import sys
if sys.version_info.major == 2:
    import Queue
elif sys.version_info.major == 3:
    import queue as Queue
else:
    raise Exception("Error Python version")
26 27 28
import numpy as np
import logging
import enum
29
import os
30
import copy
31
import time
D
dongdaxiang 已提交
32

33
# Module-level logger, named after this module, shared by the channel
# classes defined below.
_LOGGER = logging.getLogger(__name__)
B
barrierye 已提交
34

D
dongdaxiang 已提交
35

T
TeslaZhao 已提交
36 37 38 39
class ChannelDataErrcode(enum.Enum):
    """
    Error codes carried by a ChannelData.

    ``OK`` (0) means "no error"; every other member marks a failure.
    Callers in this file compare against the integer ``.value`` of a
    member, not against the member itself.

    Note: the misspelled member name ``UNKNOW`` is kept as-is because it
    is part of the public interface.
    """
    OK = 0
    TIMEOUT = 1
    NOT_IMPLEMENTED = 2
    TYPE_ERROR = 3
    RPC_PACKAGE_ERROR = 4
    CLIENT_ERROR = 5
    CLOSED_ERROR = 6
    NO_SERVICE = 7
    UNKNOW = 8
    INPUT_PARAMS_ERROR = 9

    # Numbered apart from the codes above; presumably used together with
    # ProductErrCode subclasses -- TODO confirm against callers.
    PRODUCT_ERROR = 100
T
TeslaZhao 已提交
52 53 54 55 56 57 58 59


class ProductErrCode(enum.Enum):
    """
    Base enumeration for business (product-level) error codes.

    It declares no members of its own: product developers are expected
    to subclass it and add their own error codes there.
    """
D
dongdaxiang 已提交
60 61 62


class ChannelDataType(enum.Enum):
    """
    Kind of payload stored in a ChannelData.

    ChannelData keeps the integer ``.value`` of a member in its
    ``datatype`` attribute.
    """
    DICT = 0  # payload is held in ChannelData.dictdata
    CHANNEL_NPDATA = 1  # payload is held in ChannelData.npdata
    ERROR = 2  # no usable payload; error_code / error_info describe why


71 72 73 74
class ChannelData(object):
    """
    One unit of data travelling through a channel.

    An instance holds either a payload (``npdata`` or ``dictdata``) or an
    error description, together with the ids used for ordering and
    logging and a set of profiling records.
    """

    def __init__(self,
                 datatype=None,
                 npdata=None,
                 dictdata=None,
                 data_id=None,
                 log_id=None,
                 error_code=None,
                 error_info=None,
                 prod_error_code=None,
                 prod_error_info=None,
                 client_need_profile=False):
        '''
        There are several ways to use it:

        1. ChannelData(ChannelDataType.CHANNEL_NPDATA.value, npdata, data_id, log_id)
        2. ChannelData(ChannelDataType.DICT.value, dictdata, data_id, log_id)
        3. ChannelData(error_code, error_info, prod_error_code, prod_error_info, data_id, log_id)

        When ``error_code`` or ``prod_error_code`` is given, the instance
        becomes an ERROR item and both ``data_id`` and ``error_info`` are
        mandatory. Otherwise the payload matching ``datatype`` is
        validated; a failed validation turns the instance into an ERROR
        item carrying the validation error.

        The process is terminated with ``os._exit(-1)`` when the
        mandatory error fields are missing or ``datatype`` is neither
        CHANNEL_NPDATA nor DICT.

        Protobufs are not pickle-able:
        https://stackoverflow.com/questions/55344376/how-to-import-protobuf-module
        '''
        if error_code is not None or prod_error_code is not None:
            if data_id is None or error_info is None:
                _LOGGER.critical("Failed to generate ChannelData: data_id"
                                 " and error_info cannot be None")
                os._exit(-1)
            datatype = ChannelDataType.ERROR.value
        else:
            if datatype == ChannelDataType.CHANNEL_NPDATA.value:
                error_code, error_info = ChannelData.check_npdata(npdata)
                if error_code != ChannelDataErrcode.OK.value:
                    datatype = ChannelDataType.ERROR.value
                    _LOGGER.error("(data_id={} log_id={}) {}".format(
                        data_id, log_id, error_info))
            elif datatype == ChannelDataType.DICT.value:
                error_code, error_info = ChannelData.check_dictdata(dictdata)
                if error_code != ChannelDataErrcode.OK.value:
                    datatype = ChannelDataType.ERROR.value
                    _LOGGER.error("(data_id={} log_id={}) {}".format(
                        data_id, log_id, error_info))
            else:
                _LOGGER.critical("(data_id={} log_id={}) datatype not match".
                                 format(data_id, log_id))
                os._exit(-1)
        self.datatype = datatype
        self.npdata = npdata
        self.dictdata = dictdata
        self.id = data_id
        self.log_id = log_id
        self.error_code = error_code
        self.error_info = error_info
        self.prod_error_code = prod_error_code
        self.prod_error_info = prod_error_info
        self.client_need_profile = client_need_profile
        self.profile_data_set = set()

    def get_size(self):
        """
        Return a rough size figure for the payload.

        Sums ``sys.getsizeof`` over the keys and values of ``dictdata``
        and ``npdata``. Only dict payloads are counted; any other type
        (including a list batch) contributes 0.
        """
        size = 0
        for payload in (self.dictdata, self.npdata):
            if isinstance(payload, dict):
                for key, value in payload.items():
                    size += sys.getsizeof(value) + sys.getsizeof(key)
        return size

    def add_profile(self, profile_set):
        """
        Merge ``profile_set`` into the collected profile records and make
        sure ``client_need_profile`` is switched on.
        """
        if self.client_need_profile is False:
            self.client_need_profile = True
        self.profile_data_set |= profile_set

    @staticmethod
    def check_dictdata(dictdata):
        """
        Validate a DICT payload.

        ``dictdata`` must be a dict, or a list whose elements are all
        dicts (batch). Returns ``(error_code, error_info)``: ``(OK, None)``
        on success, otherwise TYPE_ERROR and a message describing the
        first offending item.
        """
        error_code = ChannelDataErrcode.OK.value
        error_info = None
        if isinstance(dictdata, list):
            # batch data
            for sample in dictdata:
                if not isinstance(sample, dict):
                    error_code = ChannelDataErrcode.TYPE_ERROR.value
                    error_info = "Failed to check data: the type of " \
                            "data must be dict, but get {}.".format(type(sample))
                    break
        elif not isinstance(dictdata, dict):
            # batch size = 1
            error_code = ChannelDataErrcode.TYPE_ERROR.value
            error_info = "Failed to check data: the type of data must " \
                    "be dict, but get {}.".format(type(dictdata))
        return error_code, error_info

    @staticmethod
    def check_batch_npdata(batch):
        """
        Run check_npdata() on every element of ``batch`` and stop at the
        first failure. Returns the ``(error_code, error_info)`` of that
        failure, or ``(OK, None)`` when every element passes.
        """
        error_code = ChannelDataErrcode.OK.value
        error_info = None
        for npdata in batch:
            error_code, error_info = ChannelData.check_npdata(npdata)
            if error_code != ChannelDataErrcode.OK.value:
                break
        return error_code, error_info

    @staticmethod
    def check_npdata(npdata):
        """
        Validate a CHANNEL_NPDATA payload.

        ``npdata`` must be a dict whose values are all ``np.ndarray``, or
        a list of such dicts (batch). Returns ``(error_code, error_info)``:
        ``(OK, None)`` on success, otherwise TYPE_ERROR and a message
        describing the first offending item.
        """
        error_code = ChannelDataErrcode.OK.value
        error_info = None
        if isinstance(npdata, list):
            # batch data
            for sample in npdata:
                if not isinstance(sample, dict):
                    error_code = ChannelDataErrcode.TYPE_ERROR.value
                    error_info = "Failed to check data: the " \
                            "value of data must be dict, but get {}.".format(
                                    type(sample))
                    break
                for value in sample.values():
                    if not isinstance(value, np.ndarray):
                        error_code = ChannelDataErrcode.TYPE_ERROR.value
                        error_info = "Failed to check data: the" \
                                " value of data must be np.ndarray, but get {}.".format(
                                        type(value))
                        return error_code, error_info
        elif isinstance(npdata, dict):
            # batch_size = 1
            for value in npdata.values():
                if not isinstance(value, np.ndarray):
                    error_code = ChannelDataErrcode.TYPE_ERROR.value
                    error_info = "Failed to check data: the value " \
                            "of data must be np.ndarray, but get {}.".format(
                                    type(value))
                    break
        else:
            error_code = ChannelDataErrcode.TYPE_ERROR.value
            error_info = "Failed to check data: the value of data " \
                    "must be dict, but get {}.".format(type(npdata))
        return error_code, error_info

    def parse(self):
        """
        Return the payload that matches ``datatype``: ``npdata`` for
        CHANNEL_NPDATA, ``dictdata`` for DICT. Any other datatype
        (including ERROR) terminates the process with ``os._exit(-1)``.
        """
        feed = None
        if self.datatype == ChannelDataType.CHANNEL_NPDATA.value:
            feed = self.npdata
        elif self.datatype == ChannelDataType.DICT.value:
            feed = self.dictdata
        else:
            _LOGGER.critical("Failed to parse channeldata: error " \
                    "type({}) in datatype.".format(self.datatype))
            os._exit(-1)
        return feed

    def __cmp__(self, other):
        """
        Three-way comparison by data id (-1, 0 or 1).

        Note: only Python 2 calls ``__cmp__`` implicitly; Python 3
        ignores it for ordering.
        """
        if self.id < other.id:
            return -1
        elif self.id == other.id:
            return 0
        else:
            return 1

    def get_all_data(self):
        """Return a one-line, human-readable summary of this item."""
        return "type[{}], error_code[{}], data_id[{}], log_id[{}], dict_size[{}]".format(
            ChannelDataType(self.datatype).name, self.error_code, self.id,
            self.log_id, self.get_size())
234 235


B
barrierye 已提交
236
class ProcessChannel(object):
    r"""
    (Process version) The channel used for communication between Ops.

    1. Support multiple different Op feed data (multiple producer)
        Different types of data will be packaged through the data ID
    2. Support multiple different Op fetch data (multiple consumer)
        Only when all types of Ops get the data of the same ID,
        the data will be popped; The Op of the same type will not
        get the data of the same ID.
    3. Function front support timeout param to make auto-batching.

    Note:
    1. The ID of the data in the channel must be different.
    2. The function add_producer() and add_consumer() are not thread safe,
       and can only be called during initialization.

    There are two buffers and one queue in Channel:

        op_A \                                           / op_D
        op_B - a. input_buf -> b. queue -> c. output_buf - op_E
        op_C /                                           \ op_F

    a. In input_buf, the input of multiple predecessor Ops is packed by data ID.
    b. The packed data will be stored in queue.
    c. In order to support multiple successor Ops to retrieve data, output_buf
        maintains the data obtained from queue.
    """

    def __init__(self,
                 manager,
                 name=None,
                 maxsize=0,
                 channel_recv_frist_arrive=False):
        """
        Args:
            manager: object providing ``PriorityQueue``, ``Value``,
                ``dict`` and ``list`` factories (presumably a
                multiprocessing manager -- confirm against the caller).
            name: channel name, used as a prefix in log messages.
            maxsize: capacity passed to the internal priority queue.
            channel_recv_frist_arrive: with several producers, forward
                only items whose data id is larger than every id
                forwarded so far and drop the rest, instead of packing
                the items of all producers together.
        """
        # For queue multiprocess: after putting an object on
        # an empty queue there may be an infinitesimal delay
        # before the queue's :meth:`~Queue.empty`
        # see more:
        # - https://bugs.python.org/issue18277
        # - https://hg.python.org/cpython/rev/860fc6a2bd21
        self._que = manager.PriorityQueue(maxsize=maxsize)
        self._maxsize = maxsize
        self.name = name
        self._stop = manager.Value('i', 0)

        self._cv = multiprocessing.Condition()

        self._producers = []
        self._pushed_producer_count = manager.dict()  # {data_id: count}
        self._input_buf = manager.dict()  # {data_id: {op_name: data}}

        self._reset_max_cursor = 1000000000000000000
        self._consumer_cursors = manager.dict()  # {op_name: cursor}
        self._cursor_count = manager.dict()  # {cursor: count}
        self._base_cursor = manager.Value('i', 0)
        self._output_buf = manager.list()

        self._cur_max_dataid = manager.Value('i', -1)
        self._channel_recv_frist_arrive = channel_recv_frist_arrive

    def get_maxsize(self):
        """Return the ``maxsize`` the channel was created with."""
        return self._maxsize

    def size(self):
        """Return the current ``qsize()`` of the internal queue."""
        return self._que.qsize()

    def get_producers(self):
        """Return the list of registered producer names."""
        return self._producers

    def get_consumers(self):
        """Return the names of the registered consumers."""
        return self._consumer_cursors.keys()

    def _log(self, info_str):
        """Prefix ``info_str`` with the channel name for logging."""
        return "[{}] {}".format(self.name, info_str)

    def add_producer(self, op_name):
        """ not thread safe, and can only be called during initialization. """
        if op_name in self._producers:
            _LOGGER.critical(
                self._log("Failed to add producer: producer({})" \
                        " is already in channel".format(op_name)))
            os._exit(-1)
        self._producers.append(op_name)
        _LOGGER.debug(self._log("Succ add a producer: {}".format(op_name)))

    def add_consumer(self, op_name):
        """ not thread safe, and can only be called during initialization. """
        if op_name in self._consumer_cursors:
            _LOGGER.critical(
                    self._log("Failed to add consumer: consumer({})" \
                            " is already in channel".format(op_name)))
            os._exit(-1)
        self._consumer_cursors[op_name] = 0

        # Every new consumer starts at cursor 0.
        if self._cursor_count.get(0) is None:
            self._cursor_count[0] = 0
        self._cursor_count[0] += 1
        _LOGGER.debug(self._log("Succ add a consumer: {}".format(op_name)))

    def push(self, channeldata, op_name=None):
        """
        Feed one item into the channel.

        * One producer: the item goes straight into the internal queue.
        * Several producers with ``channel_recv_frist_arrive``: the item
          is queued only when its id is larger than every id queued so
          far; otherwise it is logged and dropped.
        * Several producers otherwise: items are collected in the input
          buffer per data id and queued once every producer has
          delivered its part.

        While the queue is full the call waits on the channel condition
        until it is notified.

        Args:
            channeldata: item to push; its ``id`` and ``log_id`` are read.
            op_name: producer name; mandatory in the packing mode.

        Returns:
            True.

        Raises:
            ChannelStopError: the channel has been stopped.

        The process is terminated with ``os._exit(-1)`` when no producer
        is registered, or when ``op_name`` is None in the packing mode.
        """
        _LOGGER.debug(
            self._log(
                "(data_id={} log_id={}) Op({}) Enter channel::push producers:{}, time:{}".
                format(channeldata.id, channeldata.log_id, op_name,
                       len(self._producers), time.time())))

        if len(self._producers) == 0:
            _LOGGER.critical(
                self._log(
                    "(data_id={} log_id={}) Op({}) Failed to push data: expected number"
                    " of producers to be greater than 0, but the it is 0.".
                    format(channeldata.id, channeldata.log_id, op_name)))
            os._exit(-1)
        elif len(self._producers) == 1:
            start_time = _time()
            with self._cv:
                enter_cv_time = _time()
                push_que_time = enter_cv_time
                while self._stop.value == 0:
                    try:
                        self._que.put((channeldata.id, {
                            op_name: channeldata
                        }),
                                      timeout=0)
                        push_que_time = _time()
                        break
                    except Queue.Full:
                        self._cv.wait()
                if self._stop.value == 1:
                    raise ChannelStopError()
                self._cv.notify_all()
                notify_all_time = _time()
                _LOGGER.debug(
                    "(data_id={}) Op({}) channel push cost! enter_cv:{} ms, push_que:{} ms, notify:{} ms, data_size:{}, time:{}".
                    format(channeldata.id, op_name, (enter_cv_time - start_time)
                           * 1000, (push_que_time - enter_cv_time) * 1000, (
                               notify_all_time - push_que_time) * 1000,
                           channeldata.get_size(), time.time()))
            _LOGGER.debug(
                self._log(
                    "(data_id={} log_id={}) Op({}) Pushed data into internal queue.".
                    format(channeldata.id, channeldata.log_id, op_name)))
            return True
        elif self._channel_recv_frist_arrive == True:
            start_time = _time()
            with self._cv:
                _LOGGER.debug(
                    "(data_id={}) Op({}) Channel({}) enter channel_recv_first_arrive. _cur_max_dataid:{}".
                    format(channeldata.id, op_name, self.name,
                           self._cur_max_dataid.value))
                if channeldata.id > self._cur_max_dataid.value:
                    enter_cv_time = _time()
                    push_que_time = enter_cv_time
                    while self._stop.value == 0:
                        try:
                            self._que.put((channeldata.id, {
                                op_name: channeldata
                            }),
                                          timeout=0)
                            push_que_time = _time()
                            self._cur_max_dataid.value = channeldata.id
                            break
                        except Queue.Full:
                            self._cv.wait()
                    if self._stop.value == 1:
                        raise ChannelStopError()
                    self._cv.notify_all()
                    notify_all_time = _time()
                    _LOGGER.debug(
                        "(data_id={}) Op({}) channel push cost! enter_cv:{} ms, push_que:{} ms, notify:{} ms, data_size:{}, time:{}".
                        format(channeldata.id, op_name, (
                            enter_cv_time - start_time) * 1000, (
                                push_que_time - enter_cv_time) * 1000, (
                                    notify_all_time - push_que_time) * 1000,
                               channeldata.get_size(), time.time()))
                else:
                    # log and drop it
                    _LOGGER.debug(
                        "(data_id={}) Op({}) send data is dropped! cur_max_dataid:{}".
                        format(channeldata.id, op_name,
                               self._cur_max_dataid.value))
            return True
        elif op_name is None:
            _LOGGER.critical(
                self._log(
                    "(data_id={} log_id={}) Op({}) Failed to push data: there are multiple "
                    "producers, so op_name cannot be None.".format(
                        channeldata.id, channeldata.log_id, op_name)))
            os._exit(-1)

        producer_num = len(self._producers)
        data_id = channeldata.id
        log_id = channeldata.log_id
        put_data = None
        with self._cv:
            if data_id not in self._input_buf:
                self._input_buf[data_id] = {
                    name: None
                    for name in self._producers
                }
                self._pushed_producer_count[data_id] = 0
            # see: https://docs.python.org/3.6/library/multiprocessing.html?highlight=multiprocess#proxy-objects
            # A nested dict fetched through the proxy is a copy, so it has
            # to be modified locally and assigned back as a whole.
            tmp_input_buf = self._input_buf[data_id]
            tmp_input_buf[op_name] = channeldata
            self._input_buf[data_id] = tmp_input_buf

            if self._pushed_producer_count[data_id] + 1 == producer_num:
                # Last missing part arrived: the packed item is complete.
                put_data = self._input_buf[data_id]
                self._input_buf.pop(data_id)
                self._pushed_producer_count.pop(data_id)
            else:
                self._pushed_producer_count[data_id] += 1

            if put_data is None:
                _LOGGER.debug(
                    self._log(
                        "(data_id={} log_id={}) Op({}) Pushed data into input_buffer.".
                        format(data_id, log_id, op_name)))
            else:
                while self._stop.value == 0:
                    try:
                        self._que.put((data_id, put_data), timeout=0)
                        break
                    # put() signals a full queue with the "Full" exception;
                    # catching the "empty" one here let the error escape
                    # instead of waiting for free space.
                    except Queue.Full:
                        self._cv.wait()
                if self._stop.value == 1:
                    raise ChannelStopError()

                _LOGGER.debug(
                    self._log(
                        "(data_id={} log_id={}) Op({}) Pushed data into internal_queue. time:{}".
                        format(data_id, log_id, op_name, time.time())))
            self._cv.notify_all()
        return True

    def front(self, op_name=None, timeout=None):
        """
        Fetch the next item for consumer ``op_name``.

        With one consumer the item is taken directly from the internal
        queue. With several consumers items are moved from the queue
        into the output buffer, each consumer reads the entry its own
        cursor points at, and an entry is removed once no consumer
        cursor remains on it.

        Args:
            op_name: consumer name; mandatory with several consumers.
            timeout: seconds to wait for data; ``None`` or a value <= 0
                means wait without a time limit.

        Returns:
            The queued ``{op_name: ChannelData}`` dict.

        Raises:
            ChannelTimeoutError: no data arrived within ``timeout``.
            ChannelStopError: the channel has been stopped.

        The process is terminated with ``os._exit(-1)`` when no consumer
        is registered, or when ``op_name`` is None with several consumers.
        """
        _LOGGER.debug(
            self._log("Op({}) Getting data[?]; timeout(s)={}".format(op_name,
                                                                     timeout)))
        endtime = None
        if timeout is not None:
            if timeout <= 0:
                timeout = None
            else:
                endtime = _time() + timeout

        if len(self._consumer_cursors) == 0:
            _LOGGER.critical(
                self._log(
                    "Op({}) Failed to get data: expected number of consumers to be " \
                            "greater than 0, but the it is 0.".format(op_name)))
            os._exit(-1)
        elif len(self._consumer_cursors) == 1:
            resp = None
            # Timestamps in microseconds, used only for the cost log below.
            time_1 = int(round(_time() * 1000000))
            time_2 = time_1
            time_3 = time_2
            with self._cv:
                time_2 = int(round(_time() * 1000000))
                while self._stop.value == 0 and resp is None:
                    try:
                        resp = self._que.get(timeout=0)[1]
                        time_3 = int(round(_time() * 1000000))
                        break
                    except Queue.Empty:
                        if timeout is not None:
                            remaining = endtime - _time()
                            if remaining <= 0.0:
                                _LOGGER.debug(
                                    self._log("Op({}) Failed to get data: "
                                              "timeout".format(op_name)))
                                raise ChannelTimeoutError()
                            self._cv.wait(remaining)
                        else:
                            self._cv.wait()
                if self._stop.value == 1:
                    raise ChannelStopError()
            # Past this point resp is a non-empty dict: the loop only ends
            # with data, with a stop (raised above) or with a timeout.
            first_item = list(resp.values())[0]
            _LOGGER.debug(
                "(data_id={}) op({}) front cost enter_cv:{} ms, queue_get:{} ms, time:{}".
                format(first_item.id, op_name, (time_2 - time_1) / 1000.0, (
                    time_3 - time_2) / 1000.0, time.time()))
            _LOGGER.debug(
                self._log("(data_id={} log_id={}) Op({}) Got data".format(
                    first_item.id, first_item.log_id, op_name)))
            return resp
        elif op_name is None:
            _LOGGER.critical(
                self._log(
                    "Op({}) Failed to get data: there are multiple consumers, "
                    "so op_name cannot be None.".format(op_name)))
            os._exit(-1)

        # In output_buf, different Ops (according to op_name) have different
        # cursors. In addition, there is a base_cursor. Their difference is
        # the data_idx to be taken by the corresponding Op at the current
        # time:    data_idx = consumer_cursor - base_cursor
        #
        #            base_cursor    consumer_B_cursor (data_idx: 3)
        #                 |                       |
        # output_buf: | data0 | data1 | data2 | data3 |
        #                 |
        #   consumer_A_cursor (data_idx: 0)
        with self._cv:
            # When the data required by the current Op is not in output_buf,
            # it is necessary to obtain a data from queue and add it to output_buf.
            while self._stop.value == 0 and self._consumer_cursors[
                    op_name] - self._base_cursor.value >= len(self._output_buf):
                try:
                    channeldata = self._que.get(timeout=0)[1]
                    self._output_buf.append(channeldata)
                    list_values = list(channeldata.values())
                    _LOGGER.debug(
                        self._log(
                            "(data_id={} log_id={}) Op({}) Pop ready item into output_buffer, time:{}".
                            format(list_values[0].id, list_values[0].log_id,
                                   op_name, time.time())))
                    break
                except Queue.Empty:
                    if timeout is not None:
                        remaining = endtime - _time()
                        if remaining <= 0.0:
                            _LOGGER.debug(
                                self._log("Op({}) Failed to get data: timeout".
                                          format(op_name)))
                            raise ChannelTimeoutError()
                        self._cv.wait(remaining)
                    else:
                        self._cv.wait()
            if self._stop.value == 1:
                raise ChannelStopError()

            consumer_cursor = self._consumer_cursors[op_name]
            base_cursor = self._base_cursor.value
            data_idx = consumer_cursor - base_cursor
            resp = self._output_buf[data_idx]

            self._cursor_count[consumer_cursor] -= 1
            if consumer_cursor == base_cursor and self._cursor_count[
                    consumer_cursor] == 0:
                # When all the different Ops get the data that data_idx points
                # to, pop the data from output_buf.
                self._cursor_count.pop(consumer_cursor)
                self._output_buf.pop(0)
                self._base_cursor.value += 1
                # to avoid cursor overflow
                if self._base_cursor.value >= self._reset_max_cursor:
                    _LOGGER.info(self._log("Reset cursor in Channel"))
                    self._base_cursor.value -= self._reset_max_cursor
                    for name in self._consumer_cursors.keys():
                        self._consumer_cursors[name] -= self._reset_max_cursor
                    cursor_count_tmp = {
                        cursor - self._reset_max_cursor: count
                        for cursor, count in self._cursor_count.copy().items()
                    }
                    self._cursor_count.clear()
                    for cursor, count in cursor_count_tmp.items():
                        self._cursor_count[cursor] = count

            # Advance this consumer and account for its new cursor position.
            self._consumer_cursors[op_name] += 1
            new_consumer_cursor = self._consumer_cursors[op_name]
            if self._cursor_count.get(new_consumer_cursor) is None:
                self._cursor_count[new_consumer_cursor] = 0
            self._cursor_count[new_consumer_cursor] += 1

            self._cv.notify_all()

        if resp is not None:
            list_values = list(resp.values())
            _LOGGER.debug(
                self._log(
                    "(data_id={} log_id={}) Op({}) Got data from output_buffer, time:{}".
                    format(list_values[0].id, list_values[0].log_id, op_name,
                           time.time())))
        return resp

    def stop(self):
        """
        Set the stop flag and wake every waiter on the channel condition,
        so that blocked push()/front() calls can observe the flag.
        """
        _LOGGER.info(self._log("stop."))
        self._stop.value = 1
        with self._cv:
            self._cv.notify_all()
624 625


626
class ThreadChannel(Queue.PriorityQueue):
    r"""
    (Thread version) The channel used for communication between Ops.

    1. Support multiple different Op feed data (multiple producer)
        Different types of data will be packaged through the data ID
    2. Support multiple different Op fetch data (multiple consumer)
        Only when all types of Ops get the data of the same ID,
        the data will be popped; The Op of the same type will not
        get the data of the same ID.
    3. front() supports a timeout parameter to enable auto-batching.

    Note:
    1. The ID of the data in the channel must be different.
    2. The function add_producer() and add_consumer() are not thread safe,
       and can only be called during initialization.

    There are two buffers and one queue in Channel:

        op_A \                                           / op_D
        op_B - a. input_buf -> b. queue -> c. output_buf - op_E
        op_C /                                           \ op_F

    a. In input_buf, the input of multiple predecessor Ops is packed by data ID.
    b. The packed data will be stored in queue.
    c. In order to support multiple successor Ops to retrieve data, output_buf
        maintains the data obtained from queue.
    """

    def __init__(self, name=None, maxsize=-1, channel_recv_frist_arrive=False):
        """
        Args:
            name: channel name, used as the prefix of every log line.
            maxsize: capacity handed to the underlying queue.
            channel_recv_frist_arrive: when True and the channel has several
                producers, push() queues only the first arrival of each new
                (larger) data id and drops the others instead of merging them.
        """
        Queue.Queue.__init__(self, maxsize=maxsize)
        self._maxsize = maxsize
        self.name = name
        self._stop = False

        # Guards every buffer below; producers and consumers also wait on it.
        self._cv = threading.Condition()

        self._producers = []
        self._pushed_producer_count = {}  # {data_id: count}
        self._input_buf = {}  # {data_id: {op_name: data}}

        # Cursors are shifted down by this amount once base_cursor reaches it.
        self._reset_max_cursor = 1000000000000000000
        self._consumer_cursors = {}  # {op_name: idx}
        self._cursor_count = {}  # {cursor: count}
        self._base_cursor = 0
        self._output_buf = []

        self._channel_recv_frist_arrive = channel_recv_frist_arrive
        # Largest data id queued so far in first-arrival mode.
        self._cur_max_dataid = -1

    def get_maxsize(self):
        """Return the maxsize this channel was created with."""
        return self._maxsize

    def size(self):
        """Return the number of items currently in the internal queue."""
        return self.qsize()

    def get_producers(self):
        """Return the list of registered producer names."""
        return self._producers

    def get_consumers(self):
        """Return the registered consumer names (keys of the cursor table)."""
        return self._consumer_cursors.keys()

    def _log(self, info_str):
        """Prefix `info_str` with the channel name."""
        return "[{}] {}".format(self.name, info_str)

    def add_producer(self, op_name):
        """ not thread safe, and can only be called during initialization. """
        if op_name in self._producers:
            _LOGGER.critical(
                self._log("Failed to add producer: producer({}) is "
                          "already in channel".format(op_name)))
            os._exit(-1)
        self._producers.append(op_name)
        _LOGGER.debug(self._log("Succ add a producer: {}".format(op_name)))

    def add_consumer(self, op_name):
        """ not thread safe, and can only be called during initialization. """
        if op_name in self._consumer_cursors:
            _LOGGER.critical(
                self._log("Failed to add consumer: consumer({}) is "
                          "already in channel".format(op_name)))
            os._exit(-1)
        # Every new consumer starts at cursor 0.
        self._consumer_cursors[op_name] = 0

        if self._cursor_count.get(0) is None:
            self._cursor_count[0] = 0
        self._cursor_count[0] += 1
        _LOGGER.debug(self._log("Succ add a consumer: {}".format(op_name)))

    def _put_until_stopped(self, item):
        """
        Put `item` into the internal queue, waiting on the condition while
        the queue is full. The caller must hold ``self._cv``.

        Raises:
            ChannelStopError: the channel is (or became) stopped.
        """
        while self._stop is False:
            try:
                self.put(item, timeout=0)
                break
            except Queue.Full:
                # put() signals a full queue with Queue.Full (never
                # Queue.Empty); wait until a consumer or stop() notifies.
                self._cv.wait()
        if self._stop:
            raise ChannelStopError()

    def _wait_for_data(self, op_name, endtime):
        """
        Wait on the condition until notified or until `endtime` is reached.
        The caller must hold ``self._cv``.

        Args:
            op_name: consumer name, only used in the timeout log line.
            endtime: absolute deadline in seconds, or None to wait without
                a deadline.

        Raises:
            ChannelTimeoutError: the deadline has already passed.
        """
        if endtime is None:
            self._cv.wait()
            return
        remaining = endtime - _time()
        if remaining <= 0.0:
            _LOGGER.debug(
                self._log("Op({}) Failed to get data: timeout".format(
                    op_name)))
            raise ChannelTimeoutError()
        self._cv.wait(remaining)

    def push(self, channeldata, op_name=None):
        """
        Feed one piece of data into the channel.

        Behaviour depends on the number of registered producers:
          * 0 producers: logs a critical error and exits the process.
          * 1 producer: the data is queued directly.
          * several producers with channel_recv_frist_arrive: only data whose
            id is larger than every id queued so far is queued, the rest is
            dropped.
          * several producers otherwise: the data is kept in input_buf until
            as many pushes as there are producers have arrived for its id,
            then the packed dict is queued. `op_name` is mandatory here.

        Args:
            channeldata: object exposing `id` and `log_id`.
            op_name: name of the pushing Op.

        Returns:
            True.

        Raises:
            ChannelStopError: the channel has been stopped.
        """
        _LOGGER.debug(
            self._log("(data_id={} log_id={}) Op({}) Pushing data".format(
                channeldata.id, channeldata.log_id, op_name)))

        producer_num = len(self._producers)
        data_id = channeldata.id
        log_id = channeldata.log_id

        if producer_num == 0:
            _LOGGER.critical(
                self._log(
                    "(data_id={} log_id={}) Op({}) Failed to push data: expected number of "
                    "producers to be greater than 0, but the it is 0.".format(
                        data_id, log_id, op_name)))
            os._exit(-1)
        elif producer_num == 1:
            with self._cv:
                self._put_until_stopped((data_id, {op_name: channeldata}))
                self._cv.notify_all()
            _LOGGER.debug(
                self._log(
                    "(data_id={} log_id={}) Op({}) Pushed data into internal_queue.".
                    format(data_id, log_id, op_name)))
            return True
        elif self._channel_recv_frist_arrive is True:
            with self._cv:
                if data_id > self._cur_max_dataid:
                    self._put_until_stopped((data_id, {op_name: channeldata}))
                    self._cur_max_dataid = data_id
                    self._cv.notify_all()
                else:
                    # log and drop it
                    _LOGGER.debug(
                        "(data_id={}) Op({}) send data is dropped! cur_max_dataid:{}".
                        format(data_id, op_name, self._cur_max_dataid))
            return True
        elif op_name is None:
            _LOGGER.critical(
                self._log(
                    "(data_id={} log_id={}) Op({}) Failed to push data: there are multiple"
                    " producers, so op_name cannot be None.".format(
                        data_id, log_id, op_name)))
            os._exit(-1)

        put_data = None
        with self._cv:
            if data_id not in self._input_buf:
                self._input_buf[data_id] = {
                    name: None
                    for name in self._producers
                }
                self._pushed_producer_count[data_id] = 0
            self._input_buf[data_id][op_name] = channeldata
            if self._pushed_producer_count[data_id] + 1 == producer_num:
                # Last expected push for this id: take the packed dict out
                # of input_buf so it can be queued.
                put_data = self._input_buf.pop(data_id)
                self._pushed_producer_count.pop(data_id)
            else:
                self._pushed_producer_count[data_id] += 1

            if put_data is None:
                _LOGGER.debug(
                    self._log(
                        "(data_id={} log_id={}) Op({}) Pushed data into input_buffer.".
                        format(data_id, log_id, op_name)))
            else:
                self._put_until_stopped((data_id, put_data))
                _LOGGER.debug(
                    self._log(
                        "(data_id={} log_id={}) Op({}) Pushed data into internal_queue.".
                        format(data_id, log_id, op_name)))
            self._cv.notify_all()
        return True

    def front(self, op_name=None, timeout=None):
        """
        Fetch the next piece of data for consumer `op_name`.

        With a single registered consumer the item is taken straight from
        the queue. With several consumers each one walks output_buf with its
        own cursor; an entry is removed only when the last consumer standing
        on the base cursor reads it, other readers get a deep copy.

        Args:
            op_name: name of the fetching Op; mandatory when the channel has
                several consumers.
            timeout: seconds to wait for data. None or a value <= 0 means
                wait without a deadline.

        Returns:
            The queued payload (a dict of {op_name: channeldata}).

        Raises:
            ChannelTimeoutError: no data arrived before the deadline.
            ChannelStopError: the channel has been stopped.
        """
        _LOGGER.debug(
            self._log("Op({}) Getting data[?]; timeout(s)={}".format(op_name,
                                                                     timeout)))
        endtime = None
        if timeout is not None:
            if timeout <= 0:
                timeout = None
            else:
                endtime = _time() + timeout

        consumer_num = len(self._consumer_cursors)
        if consumer_num == 0:
            _LOGGER.critical(
                self._log(
                    "Op({}) Failed to get data: expected number of consumers to be "
                    "greater than 0, but the it is 0.".format(op_name)))
            os._exit(-1)
        elif consumer_num == 1:
            resp = None
            with self._cv:
                while self._stop is False and resp is None:
                    try:
                        resp = self.get(timeout=0)[1]
                        break
                    except Queue.Empty:
                        self._wait_for_data(op_name, endtime)
                if self._stop:
                    raise ChannelStopError()
                # A slot was freed: wake producers that are waiting in
                # push() for room in a bounded queue.
                self._cv.notify_all()
            if resp is not None:
                list_values = list(resp.values())
                _LOGGER.debug(
                    self._log("(data_id={} log_id={}) Op({}) Got data".format(
                        list_values[0].id, list_values[0].log_id, op_name)))
            return resp
        elif op_name is None:
            _LOGGER.critical(
                self._log("Op({}) Failed to get data: there are multiple "
                          "consumers, so op_name cannot be None.".format(
                              op_name)))
            os._exit(-1)

        # In output_buf, different Ops (according to op_name) have different
        # cursors. In addition, there is a base_cursor. Their difference is
        # the data_idx to be taken by the corresponding Op at the current
        # time:    data_idx = consumer_cursor - base_cursor
        #
        #            base_cursor    consumer_B_cursor (data_idx: 3)
        #                 |                       |
        # output_buf: | data0 | data1 | data2 | data3 |
        #                 |
        #   consumer_A_cursor (data_idx: 0)
        with self._cv:
            # When the data required by the current Op is not in output_buf,
            # it is necessary to obtain a data from queue and add it to output_buf.
            while self._stop is False and self._consumer_cursors[
                    op_name] - self._base_cursor >= len(self._output_buf):
                try:
                    channeldata = self.get(timeout=0)[1]
                except Queue.Empty:
                    self._wait_for_data(op_name, endtime)
                    continue
                self._output_buf.append(channeldata)
                list_values = list(channeldata.values())
                _LOGGER.debug(
                    self._log(
                        "(data_id={} log_id={}) Op({}) Pop ready item into output_buffer".
                        format(list_values[0].id, list_values[0].log_id,
                               op_name)))
                break
            if self._stop:
                raise ChannelStopError()

            consumer_cursor = self._consumer_cursors[op_name]
            base_cursor = self._base_cursor
            data_idx = consumer_cursor - base_cursor

            resp = None

            self._cursor_count[consumer_cursor] -= 1
            if consumer_cursor == base_cursor and self._cursor_count[
                    consumer_cursor] == 0:
                # When all the different Ops get the data that data_idx points
                # to, pop the data from output_buf.
                self._cursor_count.pop(consumer_cursor)
                resp = self._output_buf.pop(0)
                self._base_cursor += 1
                # to avoid cursor overflow
                if self._base_cursor >= self._reset_max_cursor:
                    _LOGGER.info(self._log("Reset cursor in Channel"))
                    self._base_cursor -= self._reset_max_cursor
                    for name in self._consumer_cursors:
                        self._consumer_cursors[name] -= self._reset_max_cursor
                    self._cursor_count = {
                        cursor - self._reset_max_cursor: count
                        for cursor, count in self._cursor_count.items()
                    }
            else:
                # Other consumers still need this entry: hand out a copy.
                resp = copy.deepcopy(self._output_buf[data_idx])

            # Advance this consumer and account for it at its new cursor.
            self._consumer_cursors[op_name] += 1
            new_consumer_cursor = self._consumer_cursors[op_name]
            if self._cursor_count.get(new_consumer_cursor) is None:
                self._cursor_count[new_consumer_cursor] = 0
            self._cursor_count[new_consumer_cursor] += 1

            self._cv.notify_all()

        if resp is not None:
            list_values = list(resp.values())
            _LOGGER.debug(
                self._log(
                    "(data_id={} log_id={}) Op({}) Got data from output_buffer".
                    format(list_values[0].id, list_values[0].log_id, op_name)))
        return resp

    def stop(self):
        """Mark the channel as stopped and wake every thread waiting on it."""
        _LOGGER.info(self._log("stop."))
        self._stop = True
        with self._cv:
            self._cv.notify_all()

B
barriery 已提交
957

B
barriery 已提交
958 959 960
class ChannelTimeoutError(RuntimeError):
    """
    Raised when waiting for channel data exceeds the requested timeout
    (see ThreadChannel.front). Takes no constructor arguments.
    """

    def __init__(self):
        super(ChannelTimeoutError, self).__init__()
B
barrierye 已提交
961

B
barriery 已提交
962

B
barrierye 已提交
963 964 965
class ChannelStopError(RuntimeError):
    """
    Raised by channel push()/front() once the channel has been stopped
    (see ThreadChannel). Takes no constructor arguments.
    """

    def __init__(self):
        super(ChannelStopError, self).__init__()