channel.py 31.4 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14
#   Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# pylint: disable=doc-string-missing
B
barriery 已提交
15
from time import time as _time
D
dongdaxiang 已提交
16 17 18 19 20 21 22 23 24 25
import threading
import multiprocessing
import multiprocessing.queues
import sys
if sys.version_info.major == 2:
    import Queue
elif sys.version_info.major == 3:
    import queue as Queue
else:
    raise Exception("Error Python version")
26 27 28
import numpy as np
import logging
import enum
29
import os
30
import copy
D
dongdaxiang 已提交
31

W
wangjiawei04 已提交
32
_LOGGER = logging.getLogger()
B
barrierye 已提交
33

D
dongdaxiang 已提交
34 35 36 37 38 39 40

class ChannelDataEcode(enum.Enum):
    """Status codes attached to data travelling through a channel.

    Code in this file compares against the raw integer (``.value``) rather
    than the enum member itself.
    """
    OK = 0
    TIMEOUT = 1
    NOT_IMPLEMENTED = 2
    TYPE_ERROR = 3
    RPC_PACKAGE_ERROR = 4
    CLIENT_ERROR = 5
    CLOSED_ERROR = 6
    # NOTE(review): spelling ("UNKNOW") is kept as-is; renaming it would
    # break any caller that references the member by name.
    UNKNOW = 7
D
dongdaxiang 已提交
44 45 46 47 48 49 50 51


class ChannelDataType(enum.Enum):
    """Kind of payload stored in a channel data item.

    Members are numbered consecutively from zero in declaration order:
    DICT = 0, CHANNEL_NPDATA = 1, ERROR = 2.
    """
    DICT, CHANNEL_NPDATA, ERROR = range(3)


52 53 54 55
class ChannelData(object):
    """One item travelling through a channel.

    An instance carries either a payload (``npdata`` or ``dictdata``) or an
    error (``ecode`` together with ``error_info``), plus the data id and the
    profiling information collected for it.
    """

    def __init__(self,
                 datatype=None,
                 npdata=None,
                 dictdata=None,
                 data_id=None,
                 ecode=None,
                 error_info=None,
                 client_need_profile=False):
        '''
        There are several ways to use it:

        1. ChannelData(datatype=ChannelDataType.CHANNEL_NPDATA.value,
                       npdata=npdata, data_id=data_id)
        2. ChannelData(datatype=ChannelDataType.DICT.value,
                       dictdata=dictdata, data_id=data_id)
        3. ChannelData(ecode=ecode, error_info=error_info, data_id=data_id)

        When ``ecode`` is given, the instance is an error item: both
        ``data_id`` and ``error_info`` must be provided, otherwise the
        process is terminated with ``os._exit(-1)``.

        Otherwise the payload selected by ``datatype`` is validated with
        ``check_npdata`` / ``check_dictdata``. If validation fails, the
        instance becomes an ERROR item carrying the resulting ``ecode`` and
        ``error_info``. An unrecognized ``datatype`` terminates the process.

        Protobufs are not pickle-able:
        https://stackoverflow.com/questions/55344376/how-to-import-protobuf-module
        '''
        if ecode is not None:
            if data_id is None or error_info is None:
                _LOGGER.critical("Failed to generate ChannelData: data_id"
                                 " and error_info cannot be None")
                os._exit(-1)
            datatype = ChannelDataType.ERROR.value
        else:
            if datatype == ChannelDataType.CHANNEL_NPDATA.value:
                ecode, error_info = ChannelData.check_npdata(npdata)
                if ecode != ChannelDataEcode.OK.value:
                    datatype = ChannelDataType.ERROR.value
                    _LOGGER.error("(logid={}) {}".format(data_id, error_info))
            elif datatype == ChannelDataType.DICT.value:
                ecode, error_info = ChannelData.check_dictdata(dictdata)
                if ecode != ChannelDataEcode.OK.value:
                    datatype = ChannelDataType.ERROR.value
                    _LOGGER.error("(logid={}) {}".format(data_id, error_info))
            else:
                _LOGGER.critical("(logid={}) datatype not match".format(
                    data_id))
                os._exit(-1)
        self.datatype = datatype
        self.npdata = npdata
        self.dictdata = dictdata
        self.id = data_id
        self.ecode = ecode
        self.error_info = error_info
        self.client_need_profile = client_need_profile
        self.profile_data_set = set()

    def add_profile(self, profile_set):
        """Merge ``profile_set`` into the stored profile data.

        Also switches ``client_need_profile`` to True if it was False.
        """
        if self.client_need_profile is False:
            self.client_need_profile = True
        self.profile_data_set |= profile_set

    @staticmethod
    def check_dictdata(dictdata):
        """Validate that ``dictdata`` is a dict or a list of dicts.

        Returns:
            (ecode, error_info): ``(OK.value, None)`` on success, otherwise
            ``(TYPE_ERROR.value, message)`` for the first offending item.
        """
        # A list is treated as a batch; anything else as a single sample.
        samples = dictdata if isinstance(dictdata, list) else [dictdata]
        for sample in samples:
            if not isinstance(sample, dict):
                error_info = "Failed to check data: the type of " \
                        "data must be dict, but get {}.".format(type(sample))
                return ChannelDataEcode.TYPE_ERROR.value, error_info
        return ChannelDataEcode.OK.value, None

    @staticmethod
    def check_batch_npdata(batch):
        """Run ``check_npdata`` on every element of ``batch``.

        Returns:
            (ecode, error_info) of the first failing element, or
            ``(OK.value, None)`` when every element passes (or ``batch`` is
            empty).
        """
        ecode = ChannelDataEcode.OK.value
        error_info = None
        for npdata in batch:
            ecode, error_info = ChannelData.check_npdata(npdata)
            if ecode != ChannelDataEcode.OK.value:
                break
        return ecode, error_info

    @staticmethod
    def check_npdata(npdata):
        """Validate that ``npdata`` is a dict (or list of dicts) of ndarrays.

        Every sample must be a dict and every value in it an ``np.ndarray``.

        Returns:
            (ecode, error_info): ``(OK.value, None)`` on success, otherwise
            ``(TYPE_ERROR.value, message)`` for the first offending item.
        """
        # A list is treated as a batch; anything else as a single sample.
        samples = npdata if isinstance(npdata, list) else [npdata]
        for sample in samples:
            if not isinstance(sample, dict):
                error_info = "Failed to check data: the " \
                        "value of data must be dict, but get {}.".format(
                                type(sample))
                return ChannelDataEcode.TYPE_ERROR.value, error_info
            for value in sample.values():
                if not isinstance(value, np.ndarray):
                    error_info = "Failed to check data: the" \
                            " value of data must be np.ndarray, but get {}.".format(
                                    type(value))
                    return ChannelDataEcode.TYPE_ERROR.value, error_info
        return ChannelDataEcode.OK.value, None

    def parse(self):
        """Return the payload that matches ``self.datatype``.

        Returns ``self.npdata`` for CHANNEL_NPDATA and ``self.dictdata`` for
        DICT. Any other datatype (including ERROR) terminates the process
        with ``os._exit(-1)``.
        """
        feed = None
        if self.datatype == ChannelDataType.CHANNEL_NPDATA.value:
            feed = self.npdata
        elif self.datatype == ChannelDataType.DICT.value:
            feed = self.dictdata
        else:
            _LOGGER.critical("Failed to parse channeldata: error " \
                    "type({}) in datatype.".format(self.datatype))
            os._exit(-1)
        return feed

    def __str__(self):
        return "type[{}], ecode[{}], id[{}]".format(
            ChannelDataType(self.datatype).name, self.ecode, self.id)


B
barrierye 已提交
189
class ProcessChannel(object):
    r"""
    (Process version) The channel used for communication between Ops.

    1. Support multiple different Op feed data (multiple producer)
        Different types of data will be packaged through the data ID
    2. Support multiple different Op fetch data (multiple consumer)
        Only when all types of Ops get the data of the same ID,
        the data will be popped; The Op of the same type will not
        get the data of the same ID.
    3. Function front supports a timeout param to make auto-batching.

    Note:
    1. The ID of the data in the channel must be different.
    2. The function add_producer() and add_consumer() are not thread safe,
       and can only be called during initialization.

    There are two buffers and one queue in Channel:

        op_A \                                           / op_D
        op_B - a. input_buf -> b. queue -> c. output_buf - op_E
        op_C /                                           \ op_F

    a. In input_buf, the input of multiple predecessor Ops is packed by data ID.
    b. The packed data will be stored in queue.
    c. In order to support multiple successor Ops to retrieve data, output_buf
        maintains the data obtained from queue.
    """

    def __init__(self, manager, name=None, maxsize=0):
        """Create the channel.

        Args:
            manager: object providing ``Queue``, ``Value``, ``dict`` and
                ``list`` factories (used for all shared state below).
            name: label used as a prefix in log messages.
            maxsize: capacity passed to ``manager.Queue``.
        """
        # For queue multiprocess: after putting an object on
        # an empty queue there may be an infinitesimal delay
        # before the queue's :meth:`~Queue.empty`
        # see more:
        # - https://bugs.python.org/issue18277
        # - https://hg.python.org/cpython/rev/860fc6a2bd21
        self._que = manager.Queue(maxsize=maxsize)
        self._maxsize = maxsize
        self.name = name
        self._stop = manager.Value('i', 0)

        self._cv = multiprocessing.Condition()

        self._producers = []
        self._pushed_producer_count = manager.dict()  # {data_id: count}
        self._input_buf = manager.dict()  # {data_id: {op_name: data}}

        self._reset_max_cursor = 1000000000000000000
        self._consumer_cursors = manager.dict()  # {op_name: cursor}
        self._cursor_count = manager.dict()  # {cursor: count}
        self._base_cursor = manager.Value('i', 0)
        self._output_buf = manager.list()

    def get_producers(self):
        """Return the list of registered producer op names."""
        return self._producers

    def get_consumers(self):
        """Return the registered consumer op names."""
        return self._consumer_cursors.keys()

    def _log(self, info_str):
        """Prefix ``info_str`` with the channel name for logging."""
        return "[{}] {}".format(self.name, info_str)

    @staticmethod
    def _first_data_id(packed):
        """Return the ``id`` of the first ChannelData in a packed dict.

        ``dict.values()`` is a view on Python 3 and cannot be indexed, so
        the first value is taken through an iterator (works on 2 and 3).
        """
        return next(iter(packed.values())).id

    def add_producer(self, op_name):
        """ not thread safe, and can only be called during initialization. """
        if op_name in self._producers:
            _LOGGER.critical(
                self._log("Failed to add producer: producer({})" \
                        " is already in channel".format(op_name)))
            os._exit(-1)
        self._producers.append(op_name)
        _LOGGER.debug(self._log("Succ add a producer: {}".format(op_name)))

    def add_consumer(self, op_name):
        """ not thread safe, and can only be called during initialization. """
        if op_name in self._consumer_cursors:
            _LOGGER.critical(
                    self._log("Failed to add consumer: consumer({})" \
                            " is already in channel".format(op_name)))
            os._exit(-1)
        self._consumer_cursors[op_name] = 0

        # Every new consumer starts at cursor 0.
        if self._cursor_count.get(0) is None:
            self._cursor_count[0] = 0
        self._cursor_count[0] += 1
        _LOGGER.debug(self._log("Succ add a consumer: {}".format(op_name)))

    def push(self, channeldata, op_name=None):
        """Put ``channeldata`` produced by ``op_name`` into the channel.

        With a single producer the item goes straight to the internal queue
        as ``{op_name: channeldata}``. With several producers the item is
        held in the input buffer until every producer has pushed data with
        the same id; the packed dict is then moved to the internal queue.
        Blocks while the queue is full.

        Returns:
            True once the data has been buffered or queued.

        Raises:
            ChannelStopError: if the channel has been stopped.
        """
        _LOGGER.debug(
            self._log("(logid={}) Op({}) Pushing data".format(channeldata.id,
                                                              op_name)))
        if len(self._producers) == 0:
            _LOGGER.critical(
                self._log(
                    "(logid={}) Op({}) Failed to push data: expected number"
                    " of producers to be greater than 0, but the it is 0.".
                    format(channeldata.id, op_name)))
            os._exit(-1)
        elif len(self._producers) == 1:
            with self._cv:
                while self._stop.value == 0:
                    try:
                        self._que.put({op_name: channeldata}, timeout=0)
                        break
                    except Queue.Full:
                        self._cv.wait()
                if self._stop.value == 1:
                    raise ChannelStopError()
                self._cv.notify_all()
            _LOGGER.debug(
                self._log("(logid={}) Op({}) Pushed data into internal queue.".
                          format(channeldata.id, op_name)))
            return True
        elif op_name is None:
            _LOGGER.critical(
                self._log(
                    "(logid={}) Op({}) Failed to push data: there are multiple "
                    "producers, so op_name cannot be None.".format(
                        channeldata.id, op_name)))
            os._exit(-1)

        producer_num = len(self._producers)
        data_id = channeldata.id
        put_data = None
        with self._cv:
            if data_id not in self._input_buf:
                self._input_buf[data_id] = {
                    name: None
                    for name in self._producers
                }
                self._pushed_producer_count[data_id] = 0
            # see: https://docs.python.org/3.6/library/multiprocessing.html?highlight=multiprocess#proxy-objects
            # Mutating the nested dict in place would not propagate through
            # the proxy, so read it, modify the copy and assign it back.
            tmp_input_buf = self._input_buf[data_id]
            tmp_input_buf[op_name] = channeldata
            self._input_buf[data_id] = tmp_input_buf

            if self._pushed_producer_count[data_id] + 1 == producer_num:
                # Last producer for this id: the packed data is complete.
                put_data = self._input_buf[data_id]
                self._input_buf.pop(data_id)
                self._pushed_producer_count.pop(data_id)
            else:
                self._pushed_producer_count[data_id] += 1

            if put_data is None:
                _LOGGER.debug(
                    self._log(
                        "(logid={}) Op({}) Pushed data into input_buffer.".
                        format(data_id, op_name)))
            else:
                while self._stop.value == 0:
                    try:
                        self._que.put(put_data, timeout=0)
                        break
                    except Queue.Full:
                        # put() signals a full queue with Queue.Full (not
                        # Queue.Empty); wait until a consumer makes room.
                        self._cv.wait()
                if self._stop.value == 1:
                    raise ChannelStopError()

                _LOGGER.debug(
                    self._log(
                        "(logid={}) Op({}) Pushed data into internal_queue.".
                        format(data_id, op_name)))
            self._cv.notify_all()
        return True

    def front(self, op_name=None, timeout=None):
        """Fetch the next packed item for consumer ``op_name``.

        Args:
            op_name: consumer name; required when there are several
                consumers.
            timeout: seconds to wait for data. ``None`` or a value <= 0
                waits indefinitely.

        Returns:
            The packed dict that was pushed into the channel.

        Raises:
            ChannelTimeoutError: if no data arrived within ``timeout``.
            ChannelStopError: if the channel has been stopped.
        """
        _LOGGER.debug(
            self._log("Op({}) Getting data[?]; timeout(s)={}".format(op_name,
                                                                     timeout)))
        endtime = None
        if timeout is not None:
            if timeout <= 0:
                timeout = None
            else:
                endtime = _time() + timeout

        if len(self._consumer_cursors) == 0:
            _LOGGER.critical(
                self._log(
                    "Op({}) Failed to get data: expected number of consumers to be " \
                            "greater than 0, but the it is 0.".format(op_name)))
            os._exit(-1)
        elif len(self._consumer_cursors) == 1:
            resp = None
            with self._cv:
                while self._stop.value == 0 and resp is None:
                    try:
                        resp = self._que.get(timeout=0)
                        break
                    except Queue.Empty:
                        if timeout is not None:
                            remaining = endtime - _time()
                            if remaining <= 0.0:
                                _LOGGER.debug(
                                    self._log("Op({}) Failed to get data: "
                                              "timeout".format(op_name)))
                                raise ChannelTimeoutError()
                            self._cv.wait(remaining)
                        else:
                            self._cv.wait()
                if self._stop.value == 1:
                    raise ChannelStopError()
            _LOGGER.debug(
                self._log("(logid={}) Op({}) Got data".format(
                    self._first_data_id(resp), op_name)))
            return resp
        elif op_name is None:
            _LOGGER.critical(
                self._log(
                    "Op({}) Failed to get data: there are multiple consumers, "
                    "so op_name cannot be None.".format(op_name)))
            os._exit(-1)

        # In output_buf, different Ops (according to op_name) have different
        # cursors. In addition, there is a base_cursor. Their difference is
        # the data_idx to be taken by the corresponding Op at the current
        # time:    data_idx = consumer_cursor - base_cursor
        #
        #            base_cursor    consumer_B_cursor (data_idx: 3)
        #                 |                       |
        # output_buf: | data0 | data1 | data2 | data3 |
        #                 |
        #   consumer_A_cursor (data_idx: 0)
        with self._cv:
            # When the data required by the current Op is not in output_buf,
            # it is necessary to obtain a data from queue and add it to output_buf.
            while self._stop.value == 0 and self._consumer_cursors[
                    op_name] - self._base_cursor.value >= len(self._output_buf):
                try:
                    channeldata = self._que.get(timeout=0)
                    self._output_buf.append(channeldata)
                    _LOGGER.debug(
                        self._log(
                            "(logid={}) Op({}) Pop ready item into output_buffer".
                            format(self._first_data_id(channeldata), op_name)))
                    break
                except Queue.Empty:
                    if timeout is not None:
                        remaining = endtime - _time()
                        if remaining <= 0.0:
                            _LOGGER.debug(
                                self._log("Op({}) Failed to get data: timeout".
                                          format(op_name)))
                            raise ChannelTimeoutError()
                        self._cv.wait(remaining)
                    else:
                        self._cv.wait()
            if self._stop.value == 1:
                raise ChannelStopError()

            consumer_cursor = self._consumer_cursors[op_name]
            base_cursor = self._base_cursor.value
            data_idx = consumer_cursor - base_cursor
            resp = self._output_buf[data_idx]

            self._cursor_count[consumer_cursor] -= 1
            if consumer_cursor == base_cursor and self._cursor_count[
                    consumer_cursor] == 0:
                # When all the different Ops get the data that data_idx points
                # to, pop the data from output_buf.
                self._cursor_count.pop(consumer_cursor)
                self._output_buf.pop(0)
                self._base_cursor.value += 1
                # to avoid cursor overflow
                if self._base_cursor.value >= self._reset_max_cursor:
                    _LOGGER.info(self._log("Reset cursor in Channel"))
                    self._base_cursor.value -= self._reset_max_cursor
                    for name in self._consumer_cursors.keys():
                        self._consumer_cursors[name] -= self._reset_max_cursor
                    cursor_count_tmp = {
                        cursor - self._reset_max_cursor: count
                        for cursor, count in self._cursor_count.copy().items()
                    }
                    self._cursor_count.clear()
                    for cursor, count in cursor_count_tmp.items():
                        self._cursor_count[cursor] = count

            self._consumer_cursors[op_name] += 1
            new_consumer_cursor = self._consumer_cursors[op_name]
            if self._cursor_count.get(new_consumer_cursor) is None:
                self._cursor_count[new_consumer_cursor] = 0
            self._cursor_count[new_consumer_cursor] += 1

            self._cv.notify_all()

        _LOGGER.debug(
            self._log("(logid={}) Op({}) Got data from output_buffer".format(
                self._first_data_id(resp), op_name)))
        return resp

    def stop(self):
        """Mark the channel as stopped and wake every waiting push/front."""
        _LOGGER.info(self._log("stop."))
        self._stop.value = 1
        with self._cv:
            self._cv.notify_all()
484 485 486 487 488 489 490 491 492 493 494 495


class ThreadChannel(Queue.Queue):
    """ 
    (Thread version) The channel used for communication between Ops.

    1. Support multiple different Op feed data (multiple producer)
        Different types of data will be packaged through the data ID
    2. Support multiple different Op fetch data (multiple consumer)
        Only when all types of Ops get the data of the same ID,
        the data will be popped; The Op of the same type will not
        get the data of the same ID.
    3. Function front supports a timeout param to make auto-batching.

    Note:
    1. The ID of the data in the channel must be different.
    2. The function add_producer() and add_consumer() are not thread safe,
       and can only be called during initialization.

    There are two buffers and one queue in Channel:

        op_A \                                           / op_D
        op_B - a. input_buf -> b. queue -> c. output_buf - op_E
        op_C /                                           \ op_F
    
    a. In input_buf, the input of multiple predecessor Ops is packed by data ID.
    b. The packed data will be stored in queue.
    c. In order to support multiple successor Ops to retrieve data, output_buf
        maintains the data obtained from queue.
    """

B
barriery 已提交
515
    def __init__(self, name=None, maxsize=-1):
516 517 518 519 520 521 522 523
        Queue.Queue.__init__(self, maxsize=maxsize)
        self._maxsize = maxsize
        self.name = name
        self._stop = False

        self._cv = threading.Condition()

        self._producers = []
B
barrierye 已提交
524
        self._pushed_producer_count = {}  # {data_id: count}
B
barrierye 已提交
525
        self._input_buf = {}  # {data_id: {op_name: data}}
526

527
        self._reset_max_cursor = 1000000000000000000
B
barrierye 已提交
528 529 530 531
        self._consumer_cursors = {}  # {op_name: idx}
        self._cursor_count = {}  # {cursor: count}
        self._base_cursor = 0
        self._output_buf = []
532 533 534 535 536

    def get_producers(self):
        return self._producers

    def get_consumers(self):
B
barrierye 已提交
537
        return self._consumer_cursors.keys()
538 539 540 541 542 543 544

    def _log(self, info_str):
        return "[{}] {}".format(self.name, info_str)

    def add_producer(self, op_name):
        """ not thread safe, and can only be called during initialization.

        Registers ``op_name`` as a producer. Registering the same name twice
        is logged as critical and terminates the process with
        ``os._exit(-1)``.
        """
        if op_name in self._producers:
            _LOGGER.critical(
                self._log("Failed to add producer: producer({}) is "
                          "already in channel".format(op_name)))
            os._exit(-1)
        self._producers.append(op_name)
        _LOGGER.debug(self._log("Succ add a producer: {}".format(op_name)))
551 552 553

    def add_consumer(self, op_name):
        """ not thread safe, and can only be called during initialization.

        Registers ``op_name`` as a consumer starting at cursor 0.
        Registering the same name twice is logged as critical and terminates
        the process with ``os._exit(-1)``.
        """
        if op_name in self._consumer_cursors:
            _LOGGER.critical(
                self._log("Failed to add consumer: consumer({}) is "
                          "already in channel".format(op_name)))
            os._exit(-1)
        self._consumer_cursors[op_name] = 0

        # Count one more consumer positioned at cursor 0.
        self._cursor_count[0] = self._cursor_count.get(0, 0) + 1
        _LOGGER.debug(self._log("Succ add a consumer: {}".format(op_name)))
565 566

    def push(self, channeldata, op_name=None):
        """Put ``channeldata`` produced by ``op_name`` into the channel.

        With a single producer the item goes straight to the queue as
        ``{op_name: channeldata}``. With several producers the item is held
        in the input buffer until every producer has pushed data with the
        same id; the packed dict is then moved to the queue. Blocks while
        the queue is full.

        Returns:
            True once the data has been buffered or queued.

        Raises:
            ChannelStopError: if the channel has been stopped.
        """
        _LOGGER.debug(
            self._log("(logid={}) Op({}) Pushing data".format(channeldata.id,
                                                              op_name)))
        if len(self._producers) == 0:
            _LOGGER.critical(
                self._log(
                    "(logid={}) Op({}) Failed to push data: expected number of "
                    "producers to be greater than 0, but the it is 0.".format(
                        channeldata.id, op_name)))
            os._exit(-1)
        elif len(self._producers) == 1:
            with self._cv:
                while self._stop is False:
                    try:
                        self.put({op_name: channeldata}, timeout=0)
                        break
                    except Queue.Full:
                        self._cv.wait()
                if self._stop:
                    raise ChannelStopError()
                self._cv.notify_all()
            _LOGGER.debug(
                self._log("(logid={}) Op({}) Pushed data into internal_queue.".
                          format(channeldata.id, op_name)))
            return True
        elif op_name is None:
            _LOGGER.critical(
                self._log(
                    "(logid={}) Op({}) Failed to push data: there are multiple"
                    " producers, so op_name cannot be None.".format(
                        channeldata.id, op_name)))
            os._exit(-1)

        producer_num = len(self._producers)
        data_id = channeldata.id
        put_data = None
        with self._cv:
            if data_id not in self._input_buf:
                self._input_buf[data_id] = {
                    name: None
                    for name in self._producers
                }
                self._pushed_producer_count[data_id] = 0
            self._input_buf[data_id][op_name] = channeldata
            if self._pushed_producer_count[data_id] + 1 == producer_num:
                # Last producer for this id: the packed data is complete.
                put_data = self._input_buf[data_id]
                self._input_buf.pop(data_id)
                self._pushed_producer_count.pop(data_id)
            else:
                self._pushed_producer_count[data_id] += 1

            if put_data is None:
                _LOGGER.debug(
                    self._log(
                        "(logid={}) Op({}) Pushed data into input_buffer.".
                        format(data_id, op_name)))
            else:
                while self._stop is False:
                    try:
                        self.put(put_data, timeout=0)
                        break
                    except Queue.Full:
                        # put() signals a full queue with Queue.Full (not
                        # Queue.Empty); wait until a consumer makes room.
                        self._cv.wait()
                if self._stop:
                    raise ChannelStopError()

                _LOGGER.debug(
                    self._log(
                        "(logid={}) Op({}) Pushed data into internal_queue.".
                        format(data_id, op_name)))
            self._cv.notify_all()
        return True

B
barriery 已提交
640
    def front(self, op_name=None, timeout=None):
        """Fetch the next item for consumer ``op_name`` from this channel.

        With exactly one registered consumer the item is taken straight from
        the underlying queue. With several consumers, items are staged in
        ``self._output_buf`` and every consumer advances its own cursor: the
        item is removed from the buffer only when the last consumer still
        pointing at the base cursor reads it; every other read returns a
        deep copy of the buffered item.

        Args:
            op_name: Name of the consuming Op. Must not be None when more
                than one consumer is registered.
            timeout: Maximum number of seconds to wait for data. ``None`` or
                a value ``<= 0`` means wait without a deadline.

        Returns:
            The item obtained from the channel. The code treats it as a
            dict-like object whose first value exposes an ``id`` attribute
            (used for logging only).

        Raises:
            ChannelTimeoutError: No item became available before the
                deadline.
            ChannelStopError: The channel was stopped.

        Note:
            The process is terminated with ``os._exit(-1)`` when no consumer
            is registered, or when several consumers are registered and
            ``op_name`` is None.
        """
        _LOGGER.debug(
            self._log("Op({}) Getting data[?]; timeout(s)={}".format(op_name,
                                                                     timeout)))
        endtime = None
        if timeout is not None:
            if timeout <= 0:
                # A non-positive timeout is treated as "no deadline".
                timeout = None
            else:
                endtime = _time() + timeout

        if len(self._consumer_cursors) == 0:
            _LOGGER.critical(
                self._log(
                    "Op({}) Failed to get data: expected number of consumers to be "
                    "greater than 0, but the it is 0.".format(op_name)))
            os._exit(-1)
        elif len(self._consumer_cursors) == 1:
            # Single consumer: no staging buffer is needed, read the queue
            # directly.
            resp = None
            with self._cv:
                while self._stop is False and resp is None:
                    try:
                        resp = self.get(timeout=0)
                        break
                    except Queue.Empty:
                        if timeout is not None:
                            remaining = endtime - _time()
                            if remaining <= 0.0:
                                _LOGGER.debug(
                                    self._log(
                                        "Op({}) Failed to get data: timeout".
                                        format(op_name)))
                                raise ChannelTimeoutError()
                            self._cv.wait(remaining)
                        else:
                            self._cv.wait()
                if self._stop:
                    raise ChannelStopError()
            # list(...) is required: on Python 3 ``dict.values()`` returns a
            # view that cannot be indexed.
            _LOGGER.debug(
                self._log("(logid={}) Op({}) Got data".format(
                    list(resp.values())[0].id, op_name)))
            return resp
        elif op_name is None:
            _LOGGER.critical(
                self._log("Op({}) Failed to get data: there are multiple "
                          "consumers, so op_name cannot be None.".format(
                              op_name)))
            os._exit(-1)

        # In output_buf, different Ops (according to op_name) have different
        # cursors. In addition, there is a base_cursor. Their difference is
        # the data_idx to be taken by the corresponding Op at the current
        # time:    data_idx = consumer_cursor - base_cursor
        #
        #            base_cursor    consumer_B_cursor (data_idx: 3)
        #                 |                       |
        # output_buf: | data0 | data1 | data2 | data3 |
        #                 |
        #   consumer_A_cursor (data_idx: 0)
        with self._cv:
            # When the data required by the current Op is not in output_buf,
            # it is necessary to obtain a data from queue and add it to output_buf.
            while self._stop is False and self._consumer_cursors[
                    op_name] - self._base_cursor >= len(self._output_buf):
                try:
                    channeldata = self.get(timeout=0)
                    self._output_buf.append(channeldata)
                    # list(...) keeps the indexing valid on Python 3.
                    _LOGGER.debug(
                        self._log(
                            "(logid={}) Op({}) Pop ready item into output_buffer".
                            format(list(channeldata.values())[0].id, op_name)))
                    break
                except Queue.Empty:
                    if timeout is not None:
                        remaining = endtime - _time()
                        if remaining <= 0.0:
                            _LOGGER.debug(
                                self._log("Op({}) Failed to get data: timeout".
                                          format(op_name)))
                            raise ChannelTimeoutError()
                        self._cv.wait(remaining)
                    else:
                        self._cv.wait()
            if self._stop:
                raise ChannelStopError()

            consumer_cursor = self._consumer_cursors[op_name]
            base_cursor = self._base_cursor
            data_idx = consumer_cursor - base_cursor

            resp = None

            self._cursor_count[consumer_cursor] -= 1
            if consumer_cursor == base_cursor and self._cursor_count[
                    consumer_cursor] == 0:
                # When all the different Ops get the data that data_idx points
                # to, pop the data from output_buf.
                self._cursor_count.pop(consumer_cursor)
                resp = self._output_buf.pop(0)
                self._base_cursor += 1
                # to avoid cursor overflow
                if self._base_cursor >= self._reset_max_cursor:
                    _LOGGER.info(self._log("Reset cursor in Channel"))
                    self._base_cursor -= self._reset_max_cursor
                    for name in self._consumer_cursors:
                        self._consumer_cursors[name] -= self._reset_max_cursor
                    self._cursor_count = {
                        cursor - self._reset_max_cursor: count
                        for cursor, count in self._cursor_count.items()
                    }
            else:
                # The item stays in the buffer for other consumers, so this
                # reader gets its own copy.
                resp = copy.deepcopy(self._output_buf[data_idx])

            # Advance this consumer and account for it at its new position.
            self._consumer_cursors[op_name] += 1
            new_consumer_cursor = self._consumer_cursors[op_name]
            if self._cursor_count.get(new_consumer_cursor) is None:
                self._cursor_count[new_consumer_cursor] = 0
            self._cursor_count[new_consumer_cursor] += 1

            self._cv.notify_all()

        _LOGGER.debug(
            self._log("(logid={}) Op({}) Got data from output_buffer".format(
                list(resp.values())[0].id, op_name)))
        return resp
765 766

    def stop(self):
        """Mark the channel as stopped and wake every waiting thread.

        Sets ``self._stop`` to True, then notifies all threads blocked on
        ``self._cv`` so that they re-check the flag.
        """
        _LOGGER.info(self._log("stop."))
        self._stop = True
        # notify_all() must be called with the condition's lock held.
        with self._cv:
            self._cv.notify_all()

B
barriery 已提交
772

B
barriery 已提交
773 774 775
class ChannelTimeoutError(RuntimeError):
    """Raised when waiting for channel data exceeds the given timeout."""

    def __init__(self):
        # The constructor deliberately takes no message or other arguments.
        super(ChannelTimeoutError, self).__init__()
B
barrierye 已提交
776

B
barriery 已提交
777

B
barrierye 已提交
778 779 780
class ChannelStopError(RuntimeError):
    """Raised when an operation is attempted on a channel that was stopped."""

    def __init__(self):
        # The constructor deliberately takes no message or other arguments.
        super(ChannelStopError, self).__init__()