channel.py 39.0 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14
#   Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# pylint: disable=doc-string-missing
B
barriery 已提交
15
from time import time as _time
D
dongdaxiang 已提交
16 17 18 19 20 21 22 23 24 25
import threading
import multiprocessing
import multiprocessing.queues
import sys
if sys.version_info.major == 2:
    import Queue
elif sys.version_info.major == 3:
    import queue as Queue
else:
    raise Exception("Error Python version")
26 27 28
import numpy as np
import logging
import enum
29
import os
30
import copy
31
import time
H
huangjianhui 已提交
32
from .error_catch import ErrorCatch, CustomException, ProductErrCode
F
felixhjh 已提交
33
from .error_catch import CustomExceptionCode as ChannelDataErrcode
D
dongdaxiang 已提交
34

35
_LOGGER = logging.getLogger(__name__)
B
barrierye 已提交
36

D
dongdaxiang 已提交
37
@enum.unique
class ChannelDataType(enum.Enum):
    """
    Channel data type.

    Tags the kind of payload carried by a ChannelData instance. The
    integer ``.value`` of a member (not the member itself) is what
    ChannelData stores in its ``datatype`` attribute.
    """
    # Payload is kept in ChannelData.dictdata.
    DICT = 0
    # Payload is kept in ChannelData.npdata.
    CHANNEL_NPDATA = 1
    # The ChannelData describes an error instead of a usable payload.
    ERROR = 2


46 47 48 49
class ChannelData(object):
    """
    A single item travelling through a channel.

    An instance carries either a payload (``npdata`` or ``dictdata``) or
    an error description, together with the ids used for routing and
    logging and an optional set of profiling records.
    """

    def __init__(self,
                 datatype=None,
                 npdata=None,
                 dictdata=None,
                 data_id=None,
                 log_id=None,
                 error_code=None,
                 error_info=None,
                 prod_error_code=None,
                 prod_error_info=None,
                 client_need_profile=False):
        '''
        There are several ways to use it (keyword arguments are shown
        because the positional order is datatype, npdata, dictdata, ...):

        1. ChannelData(datatype=ChannelDataType.CHANNEL_NPDATA.value,
                       npdata=npdata, data_id=data_id, log_id=log_id)
        2. ChannelData(datatype=ChannelDataType.DICT.value,
                       dictdata=dictdata, data_id=data_id, log_id=log_id)
        3. ChannelData(error_code=error_code, error_info=error_info,
                       prod_error_code=prod_error_code,
                       prod_error_info=prod_error_info,
                       data_id=data_id, log_id=log_id)

        When error_code or prod_error_code is given, the instance becomes
        an ERROR item and both data_id and error_info are mandatory.
        Otherwise the payload selected by datatype is validated; a failed
        validation turns the instance into an ERROR item carrying the
        validation error_code / error_info.

        NOTE: invalid construction (missing data_id / error_info for an
        error item, or an unknown datatype) logs a critical message and
        terminates the process through os._exit(-1).

        Protobufs are not pickle-able:
        https://stackoverflow.com/questions/55344376/how-to-import-protobuf-module
        '''
        if error_code is not None or prod_error_code is not None:
            if data_id is None or error_info is None:
                _LOGGER.critical("Failed to generate ChannelData: data_id"
                                 " and error_info cannot be None")
                os._exit(-1)
            datatype = ChannelDataType.ERROR.value
        else:
            if datatype == ChannelDataType.CHANNEL_NPDATA.value:
                error_code, error_info = ChannelData.check_npdata(npdata)
                if error_code != ChannelDataErrcode.OK.value:
                    datatype = ChannelDataType.ERROR.value
                    _LOGGER.error("(data_id={} log_id={}) {}".format(
                        data_id, log_id, error_info))
            elif datatype == ChannelDataType.DICT.value:
                error_code, error_info = ChannelData.check_dictdata(dictdata)
                if error_code != ChannelDataErrcode.OK.value:
                    datatype = ChannelDataType.ERROR.value
                    _LOGGER.error("(data_id={} log_id={}) {}".format(
                        data_id, log_id, error_info))
            else:
                _LOGGER.critical("(data_id={} log_id={}) datatype not match".
                                 format(data_id, log_id))
                os._exit(-1)
        self.datatype = datatype
        self.npdata = npdata
        self.dictdata = dictdata
        self.id = data_id
        self.log_id = log_id
        self.error_code = error_code
        self.error_info = error_info
        self.prod_error_code = prod_error_code
        self.prod_error_info = prod_error_info
        self.client_need_profile = client_need_profile
        self.profile_data_set = set()

    def get_size(self):
        """
        Return a rough payload size in bytes.

        Sums sys.getsizeof() of every key and value of dictdata and of
        npdata when they are dicts. sys.getsizeof() is shallow, and a
        payload that is not a dict (e.g. a batch list) contributes 0, so
        the result is only an estimate meant for logging.
        """
        size = 0
        if isinstance(self.dictdata, dict):
            for k in self.dictdata:
                size += sys.getsizeof(self.dictdata[k]) + sys.getsizeof(k)
        if isinstance(self.npdata, dict):
            for k in self.npdata:
                size += sys.getsizeof(self.npdata[k]) + sys.getsizeof(k)
        return size

    def add_profile(self, profile_set):
        """
        Merge profile_set (a set) into the stored profiling records and
        mark the item as needing profile output.
        """
        if self.client_need_profile is False:
            self.client_need_profile = True
        self.profile_data_set |= profile_set

    @staticmethod
    def check_dictdata(dictdata):
        """
        Validate a DICT payload.

        Accepts a dict (batch size 1) or a list of dicts (batch).

        Returns:
            (error_code, error_info): error_code is
            ChannelDataErrcode.OK.value and error_info is None on success;
            otherwise ChannelDataErrcode.TYPE_ERROR.value and a message.
        """
        error_code = ChannelDataErrcode.OK.value
        error_info = None
        if isinstance(dictdata, list):
            # batch data
            for sample in dictdata:
                if not isinstance(sample, dict):
                    error_code = ChannelDataErrcode.TYPE_ERROR.value
                    error_info = "Failed to check data: the type of " \
                            "data must be dict, but get {}.".format(type(sample))
                    break
        elif not isinstance(dictdata, dict):
            # batch size = 1
            error_code = ChannelDataErrcode.TYPE_ERROR.value
            error_info = "Failed to check data: the type of data must " \
                    "be dict, but get {}.".format(type(dictdata))
        return error_code, error_info

    @staticmethod
    def check_batch_npdata(batch):
        """
        Validate every element of batch with check_npdata().

        Stops at the first failing element and returns its
        (error_code, error_info); returns (OK, None) for an empty batch
        or when every element passes.
        """
        error_code = ChannelDataErrcode.OK.value
        error_info = None
        for npdata in batch:
            error_code, error_info = ChannelData.check_npdata(npdata)
            if error_code != ChannelDataErrcode.OK.value:
                break
        return error_code, error_info

    @staticmethod
    def check_npdata(npdata):
        """
        Validate a CHANNEL_NPDATA payload.

        Accepts a dict whose values are np.ndarray (batch size 1) or a
        list of such dicts (batch).

        Returns:
            (error_code, error_info): error_code is
            ChannelDataErrcode.OK.value and error_info is None on success;
            otherwise ChannelDataErrcode.TYPE_ERROR.value and a message
            describing the first offending element.
        """
        error_code = ChannelDataErrcode.OK.value
        error_info = None
        if isinstance(npdata, list):
            # batch data
            for sample in npdata:
                if not isinstance(sample, dict):
                    error_code = ChannelDataErrcode.TYPE_ERROR.value
                    error_info = "Failed to check data: the " \
                            "value of data must be dict, but get {}.".format(
                                    type(sample))
                    break
                for _, value in sample.items():
                    if not isinstance(value, np.ndarray):
                        error_code = ChannelDataErrcode.TYPE_ERROR.value
                        error_info = "Failed to check data: the" \
                                " value of data must be np.ndarray, but get {}.".format(
                                        type(value))
                        return error_code, error_info
        elif isinstance(npdata, dict):
            # batch_size = 1
            for _, value in npdata.items():
                if not isinstance(value, np.ndarray):
                    error_code = ChannelDataErrcode.TYPE_ERROR.value
                    error_info = "Failed to check data: the value " \
                            "of data must be np.ndarray, but get {}.".format(
                                    type(value))
                    break
        else:
            error_code = ChannelDataErrcode.TYPE_ERROR.value
            error_info = "Failed to check data: the value of data " \
                    "must be dict, but get {}.".format(type(npdata))
        return error_code, error_info

    def parse(self):
        """
        Return the payload matching self.datatype (npdata or dictdata).

        NOTE: for any other datatype (including ERROR) a critical message
        is logged and the process is terminated through os._exit(-1).
        """
        feed = None
        if self.datatype == ChannelDataType.CHANNEL_NPDATA.value:
            # return narray
            feed = self.npdata
        elif self.datatype == ChannelDataType.DICT.value:
            # return dict
            feed = self.dictdata
        else:
            _LOGGER.critical("Failed to parse channeldata: error " \
                    "type({}) in datatype.".format(self.datatype))
            os._exit(-1)
        return feed

    def __cmp__(self, other):
        """
        Python 2 three-way comparison by data id: returns -1, 0 or 1.
        """
        if self.id < other.id:
            return -1
        elif self.id == other.id:
            return 0
        else:
            return 1

    def __lt__(self, other):
        """
        Order items by data id.

        Python 3 ignores __cmp__, so ordering (e.g. inside a priority
        queue or sorted()) needs a rich comparison method. Equality and
        hashing are intentionally left at their identity-based defaults.
        """
        if not isinstance(other, ChannelData):
            return NotImplemented
        return self.id < other.id

    def get_all_data(self):
        """
        Return a one-line summary (type name, error code, ids and the
        estimated payload size) intended for log output.
        """
        return "type[{}], error_code[{}], data_id[{}], log_id[{}], dict_size[{}]".format(
            ChannelDataType(self.datatype).name, self.error_code, self.id,
            self.log_id, self.get_size())
208 209


B
barrierye 已提交
210
class ProcessChannel(object):
    r"""
    (Process version) The channel used for communication between Ops.

    1. Support multiple different Op feed data (multiple producer)
        Different types of data will be packaged through the data ID
    2. Support multiple different Op fetch data (multiple consumer)
        Only when all types of Ops get the data of the same ID,
        the data will be poped; The Op of the same type will not
        get the data of the same ID.
    3. Function front support timeout param to make auto-batching.

    Note:
    1. The ID of the data in the channel must be different.
    2. The function add_producer() and add_consumer() are not thread safe,
       and can only be called during initialization.

    There are two buffers and one queue in Channel:

        op_A \                                           / op_D
        op_B - a. input_buf -> b. queue -> c. output_buf - op_E
        op_C /                                           \ op_F

    a. In input_buf, the input of multiple predecessor Ops is packed by data ID.
    b. The packed data will be stored in queue.
    c. In order to support multiple successor Ops to retrieve data, output_buf
        maintains the data obtained from queue.
    """

    def __init__(self,
                 manager,
                 name=None,
                 maxsize=0,
                 channel_recv_frist_arrive=False):
        """
        Args:
            manager: object providing PriorityQueue(), Value(), dict() and
                list() factories for the shared state of the channel
                (presumably a multiprocessing manager -- confirm at the
                call site).
            name: channel name, used as the prefix of log messages.
            maxsize: capacity passed to the internal priority queue.
            channel_recv_frist_arrive: when there are several producers,
                forward the first item that arrives for a data id and drop
                later items whose id is not greater than the largest id
                already forwarded.
        """
        # For queue multiprocess: after putting an object on
        # an empty queue there may be an infinitessimal delay
        # before the queue's :meth:`~Queue.empty`
        # see more:
        # - https://bugs.python.org/issue18277
        # - https://hg.python.org/cpython/rev/860fc6a2bd21
        self._que = manager.PriorityQueue(maxsize=maxsize)
        self._maxsize = maxsize
        self.name = name
        self._stop = manager.Value('i', 0)

        self._cv = multiprocessing.Condition()

        self._producers = []
        self._pushed_producer_count = manager.dict()  # {data_id: count}
        self._input_buf = manager.dict()  # {data_id: {op_name: data}}

        self._reset_max_cursor = 1000000000000000000
        self._consumer_cursors = manager.dict()  # {op_name: cursor}
        self._cursor_count = manager.dict()  # {cursor: count}
        self._base_cursor = manager.Value('i', 0)
        self._output_buf = manager.list()

        self._cur_max_dataid = manager.Value('i', -1)
        self._channel_recv_frist_arrive = channel_recv_frist_arrive

    def get_maxsize(self):
        """Return the maxsize the channel was created with."""
        return self._maxsize

    def size(self):
        """Return the number of items currently in the internal queue."""
        return self._que.qsize()

    def get_producers(self):
        """Return the list of registered producer Op names."""
        return self._producers

    def get_consumers(self):
        """Return the names of the registered consumer Ops."""
        return self._consumer_cursors.keys()

    def _log(self, info_str):
        """Prefix info_str with the channel name for log output."""
        return "[{}] {}".format(self.name, info_str)

    def add_producer(self, op_name):
        """ not thread safe, and can only be called during initialization. """
        if op_name in self._producers:
            _LOGGER.critical(
                self._log("Failed to add producer: producer({})" \
                        " is already in channel".format(op_name)))
            os._exit(-1)
        self._producers.append(op_name)
        _LOGGER.debug(self._log("Succ add a producer: {}".format(op_name)))

    def add_consumer(self, op_name):
        """ not thread safe, and can only be called during initialization. """
        if op_name in self._consumer_cursors:
            _LOGGER.critical(
                    self._log("Failed to add consumer: consumer({})" \
                            " is already in channel".format(op_name)))
            os._exit(-1)
        self._consumer_cursors[op_name] = 0

        # Every consumer starts at cursor 0; count how many sit there.
        if self._cursor_count.get(0) is None:
            self._cursor_count[0] = 0
        self._cursor_count[0] += 1
        _LOGGER.debug(self._log("Succ add a consumer: {}".format(op_name)))

    def push(self, channeldata, op_name=None):
        """
        Put channeldata into the channel on behalf of producer op_name.

        * One producer: the item goes straight into the internal queue.
        * Several producers with channel_recv_frist_arrive: the item is
          queued only if its id is greater than the largest id queued so
          far, otherwise it is dropped.
        * Several producers otherwise: items are collected per data id in
          input_buf and queued once every producer has pushed its part.

        While the queue is full the call waits on the channel condition.

        Returns:
            True.

        Raises:
            ChannelStopError: the channel has been stopped.

        NOTE: pushing without any registered producer, or without op_name
        when several producers exist, terminates the process through
        os._exit(-1).
        """
        _LOGGER.debug(
            self._log(
                "(data_id={} log_id={}) Op({}) Enter channel::push producers:{}, time:{}".
                format(channeldata.id, channeldata.log_id, op_name,
                       len(self._producers), time.time())))

        if len(self._producers) == 0:
            _LOGGER.critical(
                self._log(
                    "(data_id={} log_id={}) Op({}) Failed to push data: expected number"
                    " of producers to be greater than 0, but the it is 0.".
                    format(channeldata.id, channeldata.log_id, op_name)))
            os._exit(-1)
        elif len(self._producers) == 1:
            start_time = _time()
            with self._cv:
                enter_cv_time = _time()
                push_que_time = enter_cv_time
                while self._stop.value == 0:
                    try:
                        self._que.put((channeldata.id, {
                            op_name: channeldata
                        }),
                                      timeout=0)
                        push_que_time = _time()
                        break
                    except Queue.Full:
                        self._cv.wait()
                if self._stop.value == 1:
                    raise ChannelStopError()
                self._cv.notify_all()
                notify_all_time = _time()
                _LOGGER.debug(
                    "(data_id={}) Op({}) channel push cost! enter_cv:{} ms, push_que:{} ms, notify:{} ms, data_size:{}, time:{}".
                    format(channeldata.id, op_name, (enter_cv_time - start_time)
                           * 1000, (push_que_time - enter_cv_time) * 1000, (
                               notify_all_time - push_que_time) * 1000,
                           channeldata.get_size(), time.time()))
            _LOGGER.debug(
                self._log(
                    "(data_id={} log_id={}) Op({}) Pushed data into internal queue.".
                    format(channeldata.id, channeldata.log_id, op_name)))
            return True
        elif self._channel_recv_frist_arrive == True:
            start_time = _time()
            with self._cv:
                _LOGGER.debug(
                    "(data_id={}) Op({}) Channel({}) enter channel_recv_first_arrive. _cur_max_dataid:{}".
                    format(channeldata.id, op_name, self.name,
                           self._cur_max_dataid.value))
                if channeldata.id > self._cur_max_dataid.value:
                    enter_cv_time = _time()
                    push_que_time = enter_cv_time
                    while self._stop.value == 0:
                        try:
                            self._que.put((channeldata.id, {
                                op_name: channeldata
                            }),
                                          timeout=0)
                            push_que_time = _time()
                            self._cur_max_dataid.value = channeldata.id
                            break
                        except Queue.Full:
                            self._cv.wait()
                    if self._stop.value == 1:
                        raise ChannelStopError()
                    self._cv.notify_all()
                    notify_all_time = _time()
                    _LOGGER.debug(
                        "(data_id={}) Op({}) channel push cost! enter_cv:{} ms, push_que:{} ms, notify:{} ms, data_size:{}, time:{}".
                        format(channeldata.id, op_name, (
                            enter_cv_time - start_time) * 1000, (
                                push_que_time - enter_cv_time) * 1000, (
                                    notify_all_time - push_que_time) * 1000,
                               channeldata.get_size(), time.time()))
                else:
                    # log and drop it
                    _LOGGER.debug(
                        "(data_id={}) Op({}) send data is dropped! cur_max_dataid:{}".
                        format(channeldata.id, op_name,
                               self._cur_max_dataid.value))
            return True
        elif op_name is None:
            _LOGGER.critical(
                self._log(
                    "(data_id={} log_id={}) Op({}) Failed to push data: there are multiple "
                    "producers, so op_name cannot be None.".format(
                        channeldata.id, channeldata.log_id, op_name)))
            os._exit(-1)

        producer_num = len(self._producers)
        data_id = channeldata.id
        log_id = channeldata.log_id
        put_data = None
        with self._cv:
            if data_id not in self._input_buf:
                self._input_buf[data_id] = {
                    name: None
                    for name in self._producers
                }
                self._pushed_producer_count[data_id] = 0
            # see: https://docs.python.org/3.6/library/multiprocessing.html?highlight=multiprocess#proxy-objects
            # Mutating the nested dict in place would not propagate through
            # the proxy, so read it out, modify it and assign it back:
            # self._input_buf[data_id][op_name] = channeldata
            tmp_input_buf = self._input_buf[data_id]
            tmp_input_buf[op_name] = channeldata
            self._input_buf[data_id] = tmp_input_buf

            if self._pushed_producer_count[data_id] + 1 == producer_num:
                # Last producer for this data id: the pack is complete.
                put_data = self._input_buf[data_id]
                self._input_buf.pop(data_id)
                self._pushed_producer_count.pop(data_id)
            else:
                self._pushed_producer_count[data_id] += 1

            if put_data is None:
                _LOGGER.debug(
                    self._log(
                        "(data_id={} log_id={}) Op({}) Pushed data into input_buffer.".
                        format(data_id, log_id, op_name)))
            else:
                while self._stop.value == 0:
                    try:
                        self._que.put((data_id, put_data), timeout=0)
                        break
                    except Queue.Full:
                        # put() signals a full queue with Queue.Full (not
                        # Queue.Empty); wait until a consumer makes room.
                        self._cv.wait()
                if self._stop.value == 1:
                    raise ChannelStopError()

                _LOGGER.debug(
                    self._log(
                        "(data_id={} log_id={}) Op({}) Pushed data into internal_queue. time:{}".
                        format(data_id, log_id, op_name, time.time())))
            self._cv.notify_all()
        return True

    def front(self, op_name=None, timeout=None):
        """
        Fetch the next item for consumer op_name.

        Args:
            op_name: consumer name; required when several consumers are
                registered.
            timeout: maximum time in seconds to wait for data. None, zero
                or a negative value means wait without a deadline.

        Returns:
            The packed item taken from the queue: a dict
            {producer_op_name: ChannelData}.

        Raises:
            ChannelTimeoutError: no data arrived before the deadline.
            ChannelStopError: the channel has been stopped.

        NOTE: calling without any registered consumer, or without op_name
        when several consumers exist, terminates the process through
        os._exit(-1).
        """
        _LOGGER.debug(
            self._log("Op({}) Getting data[?]; timeout(s)={}".format(op_name,
                                                                     timeout)))
        endtime = None
        if timeout is not None:
            if timeout <= 0:
                timeout = None
            else:
                endtime = _time() + timeout

        if len(self._consumer_cursors) == 0:
            _LOGGER.critical(
                self._log(
                    "Op({}) Failed to get data: expected number of consumers to be " \
                            "greater than 0, but the it is 0.".format(op_name)))
            os._exit(-1)
        elif len(self._consumer_cursors) == 1:
            resp = None
            # Timestamps in microseconds, only used for the debug log.
            time_1 = int(round(_time() * 1000000))
            time_2 = time_1
            time_3 = time_2
            with self._cv:
                time_2 = int(round(_time() * 1000000))
                while self._stop.value == 0 and resp is None:
                    try:
                        resp = self._que.get(timeout=0)[1]
                        time_3 = int(round(_time() * 1000000))
                        break
                    except Queue.Empty:
                        if timeout is not None:
                            remaining = endtime - _time()
                            if remaining <= 0.0:
                                _LOGGER.debug(
                                    self._log("Op({}) Failed to get data: "
                                              "timeout".format(op_name)))
                                raise ChannelTimeoutError()
                            self._cv.wait(remaining)
                        else:
                            self._cv.wait()
                if self._stop.value == 1:
                    raise ChannelStopError()
            key = list(resp.keys())[0]
            data_id = resp[key].id
            _LOGGER.debug(
                "(data_id={}) op({}) front cost enter_cv:{} ms, queue_get:{} ms, time:{}".
                format(data_id, op_name, (time_2 - time_1) / 1000.0, (
                    time_3 - time_2) / 1000.0, time.time()))
            if resp is not None:
                list_values = list(resp.values())
                _LOGGER.debug(
                    self._log("(data_id={} log_id={}) Op({}) Got data".format(
                        list_values[0].id, list_values[0].log_id, op_name)))
            return resp
        elif op_name is None:
            _LOGGER.critical(
                self._log(
                    "Op({}) Failed to get data: there are multiple consumers, "
                    "so op_name cannot be None.".format(op_name)))
            os._exit(-1)

        # In output_buf, different Ops (according to op_name) have different
        # cursors. In addition, there is a base_cursor. Their difference is
        # the data_idx to be taken by the corresponding Op at the current
        # time:    data_idx = consumer_cursor - base_cursor
        #
        #            base_cursor    consumer_B_cursor (data_idx: 3)
        #                 |                       |
        # output_buf: | data0 | data1 | data2 | data3 |
        #                 |
        #   consumer_A_cursor (data_idx: 0)
        with self._cv:
            # When the data required by the current Op is not in output_buf,
            # it is necessary to obtain a data from queue and add it to output_buf.
            while self._stop.value == 0 and self._consumer_cursors[
                    op_name] - self._base_cursor.value >= len(self._output_buf):
                try:
                    channeldata = self._que.get(timeout=0)[1]
                    self._output_buf.append(channeldata)
                    list_values = list(channeldata.values())
                    _LOGGER.debug(
                        self._log(
                            "(data_id={} log_id={}) Op({}) Pop ready item into output_buffer, time:{}".
                            format(list_values[0].id, list_values[0].log_id,
                                   op_name, time.time())))
                    break
                except Queue.Empty:
                    if timeout is not None:
                        remaining = endtime - _time()
                        if remaining <= 0.0:
                            _LOGGER.debug(
                                self._log("Op({}) Failed to get data: timeout".
                                          format(op_name)))
                            raise ChannelTimeoutError()
                        self._cv.wait(remaining)
                    else:
                        self._cv.wait()
            if self._stop.value == 1:
                raise ChannelStopError()

            consumer_cursor = self._consumer_cursors[op_name]
            base_cursor = self._base_cursor.value
            data_idx = consumer_cursor - base_cursor
            resp = self._output_buf[data_idx]

            self._cursor_count[consumer_cursor] -= 1
            if consumer_cursor == base_cursor and self._cursor_count[
                    consumer_cursor] == 0:
                # When all the different Ops get the data that data_idx points
                # to, pop the data from output_buf.
                self._cursor_count.pop(consumer_cursor)
                self._output_buf.pop(0)
                self._base_cursor.value += 1
                # to avoid cursor overflow
                if self._base_cursor.value >= self._reset_max_cursor:
                    _LOGGER.info(self._log("Reset cursor in Channel"))
                    self._base_cursor.value -= self._reset_max_cursor
                    for name in self._consumer_cursors.keys():
                        self._consumer_cursors[name] -= self._reset_max_cursor
                    cursor_count_tmp = {
                        cursor - self._reset_max_cursor: count
                        for cursor, count in self._cursor_count.copy().items()
                    }
                    self._cursor_count.clear()
                    for cursor, count in cursor_count_tmp.items():
                        self._cursor_count[cursor] = count

            # Advance this consumer and register it at its new cursor.
            self._consumer_cursors[op_name] += 1
            new_consumer_cursor = self._consumer_cursors[op_name]
            if self._cursor_count.get(new_consumer_cursor) is None:
                self._cursor_count[new_consumer_cursor] = 0
            self._cursor_count[new_consumer_cursor] += 1

            self._cv.notify_all()

        if resp is not None:
            list_values = list(resp.values())
            _LOGGER.debug(
                self._log(
                    "(data_id={} log_id={}) Op({}) Got data from output_buffer, time:{}".
                    format(list_values[0].id, list_values[0].log_id, op_name,
                           time.time())))
        return resp

    def stop(self):
        """
        Mark the channel as stopped and wake every waiter so that blocked
        push() / front() calls can raise ChannelStopError.
        """
        _LOGGER.info(self._log("stop."))
        self._stop.value = 1
        with self._cv:
            self._cv.notify_all()
598 599


600
class ThreadChannel(Queue.PriorityQueue):
    r"""
    (Thread version) The channel used for communication between Ops.

    1. Support multiple different Op feed data (multiple producer)
        Different types of data will be packaged through the data ID
    2. Support multiple different Op fetch data (multiple consumer)
        Only when all types of Ops get the data of the same ID,
        the data will be popped; The Op of the same type will not
        get the data of the same ID.
    3. Function front support timeout param to make auto-batching.

    Note:
    1. The ID of the data in the channel must be different.
    2. The function add_producer() and add_consumer() are not thread safe,
       and can only be called during initialization.

    There are two buffers and one queue in Channel:

        op_A \                                           / op_D
        op_B - a. input_buf -> b. queue -> c. output_buf - op_E
        op_C /                                           \ op_F

    a. In input_buf, the input of multiple predecessor Ops is packed by data ID.
    b. The packed data will be stored in queue.
    c. In order to support multiple successor Ops to retrieve data, output_buf
        maintains the data obtained from queue.
    """

    def __init__(self, name=None, maxsize=-1, channel_recv_frist_arrive=False):
        """Create an empty channel.

        Args:
            name: label used as prefix of every log line of this channel.
            maxsize: capacity forwarded to the underlying priority queue.
            channel_recv_frist_arrive: when True and there are several
                producers, push() enqueues only data whose id is larger than
                every id enqueued so far and drops the rest.
        """
        Queue.PriorityQueue.__init__(self, maxsize=maxsize)
        self._maxsize = maxsize
        self.name = name
        self._stop = False

        # Single condition variable guarding all channel state below.
        self._cv = threading.Condition()

        self._producers = []
        self._pushed_producer_count = {}  # {data_id: count}
        self._input_buf = {}  # {data_id: {op_name: data}}

        self._reset_max_cursor = 1000000000000000000
        self._consumer_cursors = {}  # {op_name: idx}
        self._cursor_count = {}  # {cursor: count}
        self._base_cursor = 0
        self._output_buf = []

        self._channel_recv_frist_arrive = channel_recv_frist_arrive
        self._cur_max_dataid = -1

    def get_maxsize(self):
        """Return the maxsize value given at construction."""
        return self._maxsize

    def size(self):
        """Return the number of items currently in the internal queue."""
        return self.qsize()

    def get_producers(self):
        """Return the list of registered producer names."""
        return self._producers

    def get_consumers(self):
        """Return the registered consumer names (keys of the cursor table)."""
        return self._consumer_cursors.keys()

    def _log(self, info_str):
        """Prefix info_str with the channel name."""
        return "[{}] {}".format(self.name, info_str)

    def add_producer(self, op_name):
        """ not thread safe, and can only be called during initialization. """
        if op_name in self._producers:
            _LOGGER.critical(
                self._log("Failed to add producer: producer({}) is "
                          "already in channel".format(op_name)))
            os._exit(-1)
        self._producers.append(op_name)
        _LOGGER.debug(self._log("Succ add a producer: {}".format(op_name)))

    def add_consumer(self, op_name):
        """ not thread safe, and can only be called during initialization. """
        if op_name in self._consumer_cursors:
            _LOGGER.critical(
                self._log("Failed to add consumer: consumer({}) is "
                          "already in channel".format(op_name)))
            os._exit(-1)
        # Every new consumer starts at cursor 0.
        self._consumer_cursors[op_name] = 0
        self._cursor_count[0] = self._cursor_count.get(0, 0) + 1
        _LOGGER.debug(self._log("Succ add a consumer: {}".format(op_name)))

    def _blocking_put(self, item):
        """Enqueue item, waiting on self._cv while the queue is full.

        The caller must hold self._cv.

        Raises:
            ChannelStopError: if the channel is (or becomes) stopped.
        """
        while self._stop is False:
            try:
                # timeout=0 makes put() raise Queue.Full immediately instead
                # of blocking on the queue's own internal lock.
                self.put(item, timeout=0)
                break
            except Queue.Full:
                self._cv.wait()
        if self._stop:
            raise ChannelStopError()

    def _wait_for_data(self, op_name, endtime):
        """Wait on self._cv until notified or until endtime is reached.

        The caller must hold self._cv. endtime is an absolute timestamp (as
        returned by _time()) or None to wait without a deadline.

        Raises:
            ChannelTimeoutError: if endtime has already passed.
        """
        if endtime is None:
            self._cv.wait()
            return
        remaining = endtime - _time()
        if remaining <= 0.0:
            _LOGGER.debug(
                self._log("Op({}) Failed to get data: timeout".format(
                    op_name)))
            raise ChannelTimeoutError()
        self._cv.wait(remaining)

    def push(self, channeldata, op_name=None):
        """Feed one piece of data produced by op_name into the channel.

        With a single producer the data is enqueued directly. With several
        producers the data is gathered in the input buffer and enqueued once
        as many pushes as there are producers arrived for its id (or, in
        first-arrive mode, the first data of each new larger id is enqueued
        and the rest is dropped).

        Args:
            channeldata: object exposing `id` and `log_id` attributes.
            op_name: name of the producing Op; required with several
                producers.

        Returns:
            True.

        Raises:
            ChannelStopError: if the channel has been stopped.
        """
        _LOGGER.debug(
            self._log("(data_id={} log_id={}) Op({}) Pushing data".format(
                channeldata.id, channeldata.log_id, op_name)))

        producer_num = len(self._producers)
        if producer_num == 0:
            _LOGGER.critical(
                self._log(
                    "(data_id={} log_id={}) Op({}) Failed to push data: expected number of "
                    "producers to be greater than 0, but the it is 0.".format(
                        channeldata.id, channeldata.log_id, op_name)))
            os._exit(-1)
        elif producer_num == 1:
            with self._cv:
                self._blocking_put((channeldata.id, {op_name: channeldata}))
                self._cv.notify_all()
            _LOGGER.debug(
                self._log(
                    "(data_id={} log_id={}) Op({}) Pushed data into internal_queue.".
                    format(channeldata.id, channeldata.log_id, op_name)))
            return True
        elif self._channel_recv_frist_arrive is True:
            with self._cv:
                if channeldata.id > self._cur_max_dataid:
                    self._blocking_put((channeldata.id, {
                        op_name: channeldata
                    }))
                    self._cur_max_dataid = channeldata.id
                    self._cv.notify_all()
                else:
                    # log and drop it
                    _LOGGER.debug(
                        "(data_id={}) Op({}) send data is dropped! cur_max_dataid:{}".
                        format(channeldata.id, op_name, self._cur_max_dataid))
            return True
        elif op_name is None:
            _LOGGER.critical(
                self._log(
                    "(data_id={} log_id={}) Op({}) Failed to push data: there are multiple"
                    " producers, so op_name cannot be None.".format(
                        channeldata.id, channeldata.log_id, op_name)))
            os._exit(-1)

        data_id = channeldata.id
        log_id = channeldata.log_id
        put_data = None
        with self._cv:
            if data_id not in self._input_buf:
                self._input_buf[data_id] = {
                    name: None
                    for name in self._producers
                }
                self._pushed_producer_count[data_id] = 0
            self._input_buf[data_id][op_name] = channeldata
            if self._pushed_producer_count[data_id] + 1 == producer_num:
                # Last expected push for this id: the pack is complete.
                put_data = self._input_buf.pop(data_id)
                self._pushed_producer_count.pop(data_id)
            else:
                self._pushed_producer_count[data_id] += 1

            if put_data is None:
                _LOGGER.debug(
                    self._log(
                        "(data_id={} log_id={}) Op({}) Pushed data into input_buffer.".
                        format(data_id, log_id, op_name)))
            else:
                # put() signals a full queue with Queue.Full; waiting on the
                # condition lets a consumer make room before retrying.
                self._blocking_put((data_id, put_data))

                _LOGGER.debug(
                    self._log(
                        "(data_id={} log_id={}) Op({}) Pushed data into internal_queue.".
                        format(data_id, log_id, op_name)))
            self._cv.notify_all()
        return True

    def front(self, op_name=None, timeout=None):
        """Fetch the next piece of data for consumer op_name.

        Args:
            op_name: name of the consuming Op; required with several
                consumers.
            timeout: maximum number of seconds to wait; None or a value
                <= 0 waits without a deadline.

        Returns:
            The {op_name: channeldata} dict that was enqueued by push().

        Raises:
            ChannelTimeoutError: if no data arrived before the deadline.
            ChannelStopError: if the channel has been stopped.
        """
        _LOGGER.debug(
            self._log("Op({}) Getting data[?]; timeout(s)={}".format(op_name,
                                                                     timeout)))
        # Absolute deadline; stays None when waiting without a time limit.
        endtime = None
        if timeout is not None and timeout > 0:
            endtime = _time() + timeout

        consumer_num = len(self._consumer_cursors)
        if consumer_num == 0:
            _LOGGER.critical(
                self._log(
                    "Op({}) Failed to get data: expected number of consumers to be "
                    "greater than 0, but the it is 0.".format(op_name)))
            os._exit(-1)
        elif consumer_num == 1:
            resp = None
            with self._cv:
                while self._stop is False and resp is None:
                    try:
                        resp = self.get(timeout=0)[1]
                        break
                    except Queue.Empty:
                        self._wait_for_data(op_name, endtime)
                if self._stop:
                    raise ChannelStopError()
                # A slot was freed: wake producers that are waiting on the
                # condition because the queue was full.
                self._cv.notify_all()
            if resp is not None:
                list_values = list(resp.values())
                _LOGGER.debug(
                    self._log("(data_id={} log_id={}) Op({}) Got data".format(
                        list_values[0].id, list_values[0].log_id, op_name)))
            return resp
        elif op_name is None:
            _LOGGER.critical(
                self._log("Op({}) Failed to get data: there are multiple "
                          "consumers, so op_name cannot be None.".format(
                              op_name)))
            os._exit(-1)

        # In output_buf, different Ops (according to op_name) have different
        # cursors. In addition, there is a base_cursor. Their difference is
        # the data_idx to be taken by the corresponding Op at the current
        # time:    data_idx = consumer_cursor - base_cursor
        #
        #            base_cursor    consumer_B_cursor (data_idx: 3)
        #                 |                       |
        # output_buf: | data0 | data1 | data2 | data3 |
        #                 |
        #   consumer_A_cursor (data_idx: 0)
        with self._cv:
            # When the data required by the current Op is not in output_buf,
            # it is necessary to obtain a data from queue and add it to output_buf.
            while self._stop is False and self._consumer_cursors[
                    op_name] - self._base_cursor >= len(self._output_buf):
                try:
                    channeldata = self.get(timeout=0)[1]
                except Queue.Empty:
                    self._wait_for_data(op_name, endtime)
                else:
                    self._output_buf.append(channeldata)
                    list_values = list(channeldata.values())
                    _LOGGER.debug(
                        self._log(
                            "(data_id={} log_id={}) Op({}) Pop ready item into output_buffer".
                            format(list_values[0].id, list_values[0].log_id,
                                   op_name)))
                    break
            if self._stop:
                raise ChannelStopError()

            consumer_cursor = self._consumer_cursors[op_name]
            base_cursor = self._base_cursor
            data_idx = consumer_cursor - base_cursor

            resp = None

            self._cursor_count[consumer_cursor] -= 1
            if consumer_cursor == base_cursor and self._cursor_count[
                    consumer_cursor] == 0:
                # When all the different Ops get the data that data_idx points
                # to, pop the data from output_buf.
                self._cursor_count.pop(consumer_cursor)
                resp = self._output_buf.pop(0)
                self._base_cursor += 1
                # to avoid cursor overflow
                if self._base_cursor >= self._reset_max_cursor:
                    _LOGGER.info(self._log("Reset cursor in Channel"))
                    self._base_cursor -= self._reset_max_cursor
                    for name in self._consumer_cursors:
                        self._consumer_cursors[name] -= self._reset_max_cursor
                    self._cursor_count = {
                        cursor - self._reset_max_cursor: count
                        for cursor, count in self._cursor_count.items()
                    }
            else:
                # Other consumers still need this entry: hand out a copy.
                resp = copy.deepcopy(self._output_buf[data_idx])

            self._consumer_cursors[op_name] += 1
            new_consumer_cursor = self._consumer_cursors[op_name]
            self._cursor_count[new_consumer_cursor] = self._cursor_count.get(
                new_consumer_cursor, 0) + 1

            self._cv.notify_all()

        if resp is not None:
            list_values = list(resp.values())
            _LOGGER.debug(
                self._log(
                    "(data_id={} log_id={}) Op({}) Got data from output_buffer".
                    format(list_values[0].id, list_values[0].log_id, op_name)))
        return resp

    def stop(self):
        """Mark the channel as stopped and wake every waiting thread."""
        _LOGGER.info(self._log("stop."))
        self._stop = True
        with self._cv:
            self._cv.notify_all()

B
barriery 已提交
931

B
barriery 已提交
932 933 934
class ChannelTimeoutError(RuntimeError):
    """Raised by a channel's front() when no data arrived before the timeout."""

    def __init__(self):
        # The error carries no message or payload.
        RuntimeError.__init__(self)
B
barrierye 已提交
935

B
barriery 已提交
936

B
barrierye 已提交
937 938 939
class ChannelStopError(RuntimeError):
    """Raised by a channel's push() and front() once the channel is stopped."""

    def __init__(self):
        # The error carries no message or payload.
        RuntimeError.__init__(self)