From ac66194eb5f9b4282f47f78437ccef0f229d2f91 Mon Sep 17 00:00:00 2001 From: barrierye Date: Sun, 28 Jun 2020 20:03:35 +0800 Subject: [PATCH] add comment for channel --- python/pipeline/channel.py | 241 +++++++++++++++++------------ python/pipeline/operator.py | 3 - python/pipeline/pipeline_server.py | 22 ++- 3 files changed, 159 insertions(+), 107 deletions(-) diff --git a/python/pipeline/channel.py b/python/pipeline/channel.py index 9b0023e9..23b14b95 100644 --- a/python/pipeline/channel.py +++ b/python/pipeline/channel.py @@ -143,6 +143,17 @@ class ProcessChannel(multiprocessing.queues.Queue): 1. The ID of the data in the channel must be different. 2. The function add_producer() and add_consumer() are not thread safe, and can only be called during initialization. + + There are two buffers and one queue in Channel: + + op_A \ / op_D + op_B - a. input_buf -> b. queue -> c. output_buf - op_E + op_C / \ op_F + + a. In input_buf, the input of multiple predecessor Ops is packed by data ID. + b. The packed data will be stored in queue. + c. In order to support multiple successor Ops to retrieve data, output_buf + maintains the data obtained from queue. 
""" def __init__(self, manager, name=None, maxsize=0, timeout=None): @@ -162,19 +173,19 @@ class ProcessChannel(multiprocessing.queues.Queue): self._cv = multiprocessing.Condition() self._producers = [] - self._producer_res_count = manager.dict() # {data_id: count} - self._push_res = manager.dict() # {data_id: {op_name: data}} + self.pushed_producer_count = manager.dict() # {data_id: count} + self._input_buf = manager.dict() # {data_id: {op_name: data}} - self._consumers = manager.dict() # {op_name: idx} - self._idx_consumer_num = manager.dict() # {idx: num} - self._consumer_base_idx = manager.Value('i', 0) - self._front_res = manager.list() + self._consumer_cursors = manager.dict() # {op_name: cursor} + self._cursor_count = manager.dict() # {cursor: count} + self._base_cursor = manager.Value('i', 0) + self._output_buf = manager.list() def get_producers(self): return self._producers def get_consumers(self): - return self._consumers.keys() + return self._consumer_cursors.keys() def _log(self, info_str): return "[{}] {}".format(self.name, info_str) @@ -192,14 +203,14 @@ class ProcessChannel(multiprocessing.queues.Queue): def add_consumer(self, op_name): """ not thread safe, and can only be called during initialization. 
""" - if op_name in self._consumers: + if op_name in self._consumer_cursors: raise ValueError( self._log("consumer({}) is already in channel".format(op_name))) - self._consumers[op_name] = 0 + self._consumer_cursors[op_name] = 0 - if self._idx_consumer_num.get(0) is None: - self._idx_consumer_num[0] = 0 - self._idx_consumer_num[0] += 1 + if self._cursor_count.get(0) is None: + self._cursor_count[0] = 0 + self._cursor_count[0] += 1 def push(self, channeldata, op_name=None): _LOGGER.debug( @@ -235,24 +246,24 @@ class ProcessChannel(multiprocessing.queues.Queue): put_data = None with self._cv: _LOGGER.debug(self._log("{} get lock".format(op_name))) - if data_id not in self._push_res: - self._push_res[data_id] = { + if data_id not in self._input_buf: + self._input_buf[data_id] = { name: None for name in self._producers } - self._producer_res_count[data_id] = 0 + self.pushed_producer_count[data_id] = 0 # see: https://docs.python.org/3.6/library/multiprocessing.html?highlight=multiprocess#proxy-objects - # self._push_res[data_id][op_name] = channeldata - tmp_push_res = self._push_res[data_id] - tmp_push_res[op_name] = channeldata - self._push_res[data_id] = tmp_push_res - - if self._producer_res_count[data_id] + 1 == producer_num: - put_data = self._push_res[data_id] - self._push_res.pop(data_id) - self._producer_res_count.pop(data_id) + # self._input_buf[data_id][op_name] = channeldata + tmp_input_buf = self._input_buf[data_id] + tmp_input_buf[op_name] = channeldata + self._input_buf[data_id] = tmp_input_buf + + if self.pushed_producer_count[data_id] + 1 == producer_num: + put_data = self._input_buf[data_id] + self._input_buf.pop(data_id) + self.pushed_producer_count.pop(data_id) else: - self._producer_res_count[data_id] += 1 + self.pushed_producer_count[data_id] += 1 if put_data is None: _LOGGER.debug( @@ -276,12 +287,12 @@ class ProcessChannel(multiprocessing.queues.Queue): def front(self, op_name=None): _LOGGER.debug(self._log("{} try to get 
data...".format(op_name))) - if len(self._consumers) == 0: + if len(self._consumer_cursors) == 0: raise Exception( self._log( "expected number of consumers to be greater than 0, but the it is 0." )) - elif len(self._consumers) == 1: + elif len(self._consumer_cursors) == 1: resp = None with self._cv: while self._stop is False and resp is None: @@ -312,16 +323,26 @@ class ProcessChannel(multiprocessing.queues.Queue): self._log( "There are multiple consumers, so op_name cannot be None.")) + # In output_buf, different Ops (according to op_name) have different + # cursors. In addition, there is a base_cursor. Their difference is + # the data_idx to be taken by the corresponding Op at the current + # time: data_idx = consumer_cursor - base_cursor + # + # base_cursor consumer_B_cursor (data_idx: 3) + # | | + # output_buf: | data0 | data1 | data2 | data3 | + # | + # consumer_A_cursor (data_idx: 0) with self._cv: - # data_idx = consumer_idx - base_idx - while self._stop is False and self._consumers[ - op_name] - self._consumer_base_idx.value >= len( - self._front_res): + # When the data required by the current Op is not in output_buf, + # it is necessary to obtain a data from queue and add it to output_buf. + while self._stop is False and self._consumer_cursors[ + op_name] - self._base_cursor.value >= len(self._output_buf): _LOGGER.debug( self._log( - "({}) B self._consumers: {}, self._consumer_base_idx: {}, len(self._front_res): {}". - format(op_name, self._consumers, self. - _consumer_base_idx.value, len(self._front_res)))) + "({}) B self._consumer_cursors: {}, self._base_cursor: {}, len(self._output_buf): {}". 
+ format(op_name, self._consumer_cursors, + self._base_cursor.value, len(self._output_buf)))) try: _LOGGER.debug( self._log("{} try to get(with channel size: {})".format( @@ -333,7 +354,7 @@ class ProcessChannel(multiprocessing.queues.Queue): # - https://bugs.python.org/issue18277 # - https://hg.python.org/cpython/rev/860fc6a2bd21 channeldata = self.get(timeout=1e-3) - self._front_res.append(channeldata) + self._output_buf.append(channeldata) break except Queue.Empty: _LOGGER.debug( @@ -342,29 +363,31 @@ class ProcessChannel(multiprocessing.queues.Queue): format(op_name, self.qsize()))) self._cv.wait() - consumer_idx = self._consumers[op_name] - base_idx = self._consumer_base_idx.value - data_idx = consumer_idx - base_idx - resp = self._front_res[data_idx] + consumer_cursor = self._consumer_cursors[op_name] + base_cursor = self._base_cursor.value + data_idx = consumer_cursor - base_cursor + resp = self._output_buf[data_idx] _LOGGER.debug(self._log("{} get data: {}".format(op_name, resp))) - self._idx_consumer_num[consumer_idx] -= 1 - if consumer_idx == base_idx and self._idx_consumer_num[ - consumer_idx] == 0: - self._idx_consumer_num.pop(consumer_idx) - self._front_res.pop(0) - self._consumer_base_idx.value += 1 - - self._consumers[op_name] += 1 - new_consumer_idx = self._consumers[op_name] - if self._idx_consumer_num.get(new_consumer_idx) is None: - self._idx_consumer_num[new_consumer_idx] = 0 - self._idx_consumer_num[new_consumer_idx] += 1 + self._cursor_count[consumer_cursor] -= 1 + if consumer_cursor == base_cursor and self._cursor_count[ + consumer_cursor] == 0: + # When all the different Ops get the data that data_idx points + # to, pop the data from output_buf. 
+ self._cursor_count.pop(consumer_cursor) + self._output_buf.pop(0) + self._base_cursor.value += 1 + + self._consumer_cursors[op_name] += 1 + new_consumer_cursor = self._consumer_cursors[op_name] + if self._cursor_count.get(new_consumer_cursor) is None: + self._cursor_count[new_consumer_cursor] = 0 + self._cursor_count[new_consumer_cursor] += 1 _LOGGER.debug( self._log( - "({}) A self._consumers: {}, self._consumer_base_idx: {}, len(self._front_res): {}". - format(op_name, self._consumers, self._consumer_base_idx. - value, len(self._front_res)))) + "({}) A self._consumer_cursors: {}, self._base_cursor: {}, len(self._output_buf): {}". + format(op_name, self._consumer_cursors, + self._base_cursor.value, len(self._output_buf)))) _LOGGER.debug(self._log("{} notify all".format(op_name))) self._cv.notify_all() @@ -394,6 +417,17 @@ class ThreadChannel(Queue.Queue): 1. The ID of the data in the channel must be different. 2. The function add_producer() and add_consumer() are not thread safe, and can only be called during initialization. + + There are two buffers and one queue in Channel: + + op_A \ / op_D + op_B - a. input_buf -> b. queue -> c. output_buf - op_E + op_C / \ op_F + + a. In input_buf, the input of multiple predecessor Ops is packed by data ID. + b. The packed data will be stored in queue. + c. In order to support multiple successor Ops to retrieve data, output_buf + maintains the data obtained from queue. 
""" def __init__(self, name=None, maxsize=-1, timeout=None): @@ -406,19 +440,19 @@ class ThreadChannel(Queue.Queue): self._cv = threading.Condition() self._producers = [] - self._producer_res_count = {} # {data_id: count} - self._push_res = {} # {data_id: {op_name: data}} + self.pushed_producer_count = {} # {data_id: count} + self._input_buf = {} # {data_id: {op_name: data}} - self._consumers = {} # {op_name: idx} - self._idx_consumer_num = {} # {idx: num} - self._consumer_base_idx = 0 - self._front_res = [] + self._consumer_cursors = {} # {op_name: idx} + self._cursor_count = {} # {cursor: count} + self._base_cursor = 0 + self._output_buf = [] def get_producers(self): return self._producers def get_consumers(self): - return self._consumers.keys() + return self._consumer_cursors.keys() def _log(self, info_str): return "[{}] {}".format(self.name, info_str) @@ -436,14 +470,14 @@ class ThreadChannel(Queue.Queue): def add_consumer(self, op_name): """ not thread safe, and can only be called during initialization. 
""" - if op_name in self._consumers: + if op_name in self._consumer_cursors: raise ValueError( self._log("consumer({}) is already in channel".format(op_name))) - self._consumers[op_name] = 0 + self._consumer_cursors[op_name] = 0 - if self._idx_consumer_num.get(0) is None: - self._idx_consumer_num[0] = 0 - self._idx_consumer_num[0] += 1 + if self._cursor_count.get(0) is None: + self._cursor_count[0] = 0 + self._cursor_count[0] += 1 def push(self, channeldata, op_name=None): _LOGGER.debug( @@ -475,19 +509,19 @@ class ThreadChannel(Queue.Queue): put_data = None with self._cv: _LOGGER.debug(self._log("{} get lock".format(op_name))) - if data_id not in self._push_res: - self._push_res[data_id] = { + if data_id not in self._input_buf: + self._input_buf[data_id] = { name: None for name in self._producers } - self._producer_res_count[data_id] = 0 - self._push_res[data_id][op_name] = channeldata - if self._producer_res_count[data_id] + 1 == producer_num: - put_data = self._push_res[data_id] - self._push_res.pop(data_id) - self._producer_res_count.pop(data_id) + self.pushed_producer_count[data_id] = 0 + self._input_buf[data_id][op_name] = channeldata + if self.pushed_producer_count[data_id] + 1 == producer_num: + put_data = self._input_buf[data_id] + self._input_buf.pop(data_id) + self.pushed_producer_count.pop(data_id) else: - self._producer_res_count[data_id] += 1 + self.pushed_producer_count[data_id] += 1 if put_data is None: _LOGGER.debug( @@ -508,12 +542,12 @@ class ThreadChannel(Queue.Queue): def front(self, op_name=None): _LOGGER.debug(self._log("{} try to get data".format(op_name))) - if len(self._consumers) == 0: + if len(self._consumer_cursors) == 0: raise Exception( self._log( "expected number of consumers to be greater than 0, but the it is 0." 
)) - elif len(self._consumers) == 1: + elif len(self._consumer_cursors) == 1: resp = None with self._cv: while self._stop is False and resp is None: @@ -531,35 +565,48 @@ class ThreadChannel(Queue.Queue): self._log( "There are multiple consumers, so op_name cannot be None.")) + # In output_buf, different Ops (according to op_name) have different + # cursors. In addition, there is a base_cursor. Their difference is + # the data_idx to be taken by the corresponding Op at the current + # time: data_idx = consumer_cursor - base_cursor + # + # base_cursor consumer_B_cursor (data_idx: 3) + # | | + # output_buf: | data0 | data1 | data2 | data3 | + # | + # consumer_A_cursor (data_idx: 0) with self._cv: - # data_idx = consumer_idx - base_idx - while self._stop is False and self._consumers[ - op_name] - self._consumer_base_idx >= len(self._front_res): + # When the data required by the current Op is not in output_buf, + # it is necessary to obtain a data from queue and add it to output_buf. + while self._stop is False and self._consumer_cursors[ + op_name] - self._base_cursor >= len(self._output_buf): try: channeldata = self.get(timeout=0) - self._front_res.append(channeldata) + self._output_buf.append(channeldata) break except Queue.Empty: self._cv.wait() - consumer_idx = self._consumers[op_name] - base_idx = self._consumer_base_idx - data_idx = consumer_idx - base_idx - resp = self._front_res[data_idx] + consumer_cursor = self._consumer_cursors[op_name] + base_cursor = self._base_cursor + data_idx = consumer_cursor - base_cursor + resp = self._output_buf[data_idx] _LOGGER.debug(self._log("{} get data: {}".format(op_name, resp))) - self._idx_consumer_num[consumer_idx] -= 1 - if consumer_idx == base_idx and self._idx_consumer_num[ - consumer_idx] == 0: - self._idx_consumer_num.pop(consumer_idx) - self._front_res.pop(0) - self._consumer_base_idx += 1 - - self._consumers[op_name] += 1 - new_consumer_idx = self._consumers[op_name] - if 
self._idx_consumer_num.get(new_consumer_idx) is None: - self._idx_consumer_num[new_consumer_idx] = 0 - self._idx_consumer_num[new_consumer_idx] += 1 + self._cursor_count[consumer_cursor] -= 1 + if consumer_cursor == base_cursor and self._cursor_count[ + consumer_cursor] == 0: + # When all the different Ops get the data that data_idx points + # to, pop the data from output_buf. + self._cursor_count.pop(consumer_cursor) + self._output_buf.pop(0) + self._base_cursor += 1 + + self._consumer_cursors[op_name] += 1 + new_consumer_cursor = self._consumer_cursors[op_name] + if self._cursor_count.get(new_consumer_cursor) is None: + self._cursor_count[new_consumer_cursor] = 0 + self._cursor_count[new_consumer_cursor] += 1 self._cv.notify_all() diff --git a/python/pipeline/operator.py b/python/pipeline/operator.py index cee56027..d82cac88 100644 --- a/python/pipeline/operator.py +++ b/python/pipeline/operator.py @@ -146,9 +146,6 @@ class Op(object): return fetch_dict def stop(self): - self._input.stop() - for channel in self._outputs: - channel.stop() self._is_run = False def _parse_channeldata(self, channeldata_dict): diff --git a/python/pipeline/pipeline_server.py b/python/pipeline/pipeline_server.py index 9963a2eb..92ad1312 100644 --- a/python/pipeline/pipeline_server.py +++ b/python/pipeline/pipeline_server.py @@ -66,6 +66,7 @@ class PipelineService(pipeline_service_pb2_grpc.PipelineServiceServicer): self._globel_resp_dict = {} self._id_counter = 0 self._retry = retry + self._is_run = True self._pack_func = pack_func self._unpack_func = unpack_func self._recive_func = threading.Thread( @@ -91,8 +92,11 @@ class PipelineService(pipeline_service_pb2_grpc.PipelineServiceServicer): out_channel.add_consumer(self.name) self._out_channel = out_channel + def stop(self): + self._is_run = False + def _recive_out_channel_func(self): - while True: + while self._is_run: channeldata_dict = self._out_channel.front(self.name) if len(channeldata_dict) != 1: raise Exception("out_channel 
cannot have multiple input ops") @@ -416,22 +420,26 @@ class PipelineServer(object): op.start_with_process(self._client_type)) return threads_or_proces - def _stop_ops(self): + def _stop_all(self, service): + service.stop() for op in self._actual_ops: op.stop() + for chl in self._channels: + chl.stop() def run_server(self): op_threads_or_proces = self._run_ops() + service = PipelineService(self._in_channel, self._out_channel, + self._unpack_func, self._pack_func, + self._retry) server = grpc.server( futures.ThreadPoolExecutor(max_workers=self._worker_num)) - pipeline_service_pb2_grpc.add_PipelineServiceServicer_to_server( - PipelineService(self._in_channel, self._out_channel, - self._unpack_func, self._pack_func, self._retry), - server) + pipeline_service_pb2_grpc.add_PipelineServiceServicer_to_server(service, + server) server.add_insecure_port('[::]:{}'.format(self._port)) server.start() server.wait_for_termination() - self._stop_ops() # TODO + self._stop_all(service) # TODO for x in op_threads_or_proces: x.join() -- GitLab