From 51e21362a0ce36a58857a835dc81fa174c77bd82 Mon Sep 17 00:00:00 2001
From: barrierye
Date: Thu, 9 Jul 2020 22:05:19 +0800
Subject: [PATCH] add ProcessPipelineServicer

---
 python/pipeline/channel.py         |  36 ++++++++--
 python/pipeline/dag.py             |  75 +++++++++++---------
 python/pipeline/operator.py        | 106 ++++++++++++++++++++---------
 python/pipeline/pipeline_server.py |  59 +++++++++++-----
 4 files changed, 188 insertions(+), 88 deletions(-)

diff --git a/python/pipeline/channel.py b/python/pipeline/channel.py
index 5de1accf..071b0014 100644
--- a/python/pipeline/channel.py
+++ b/python/pipeline/channel.py
@@ -37,7 +37,8 @@ class ChannelDataEcode(enum.Enum):
     TYPE_ERROR = 3
     RPC_PACKAGE_ERROR = 4
     CLIENT_ERROR = 5
-    UNKNOW = 6
+    CLOSED_ERROR = 6
+    UNKNOW = 7
 
 
 class ChannelDataType(enum.Enum):
@@ -258,6 +259,8 @@ class ProcessChannel(object):
                     break
                 except Queue.Full:
                     self._cv.wait()
+            if self._stop:
+                raise ChannelStopError()
             _LOGGER.debug(
                 self._log("{} channel size: {}".format(op_name, self._que.qsize())))
@@ -308,6 +311,8 @@
                     break
                 except Queue.Empty:
                     self._cv.wait()
+            if self._stop:
+                raise ChannelStopError()
 
             _LOGGER.debug(
                 self._log("multi | {} push data succ!".format(op_name)))
@@ -337,6 +342,8 @@
                             "{} wait for empty queue(with channel empty: {})".
                             format(op_name, self._que.empty())))
                     self._cv.wait()
+            if self._stop:
+                raise ChannelStopError()
             _LOGGER.debug(
                 self._log("{} get data succ: {}".format(op_name, resp.__str__(
                 ))))
@@ -379,6 +386,8 @@
                             "{} wait for empty queue(with channel size: {})".
                             format(op_name, self._que.qsize())))
                     self._cv.wait()
+            if self._stop:
+                raise ChannelStopError()
 
             consumer_cursor = self._consumer_cursors[op_name]
             base_cursor = self._base_cursor.value
@@ -425,10 +434,10 @@
         return resp  # reference, read only
 
     def stop(self):
-        #TODO
-        self.close()
+        _LOGGER.info(self._log("stop."))
         self._stop = True
-        self._cv.notify_all()
+        with self._cv:
+            self._cv.notify_all()
 
 
 class ThreadChannel(Queue.Queue):
@@ -527,6 +536,8 @@
                     break
                 except Queue.Full:
                     self._cv.wait()
+            if self._stop:
+                raise ChannelStopError()
             self._cv.notify_all()
         _LOGGER.debug(self._log("{} push data succ!".format(op_name)))
         return True
@@ -565,6 +576,8 @@
                     break
                 except Queue.Empty:
                     self._cv.wait()
+            if self._stop:
+                raise ChannelStopError()
 
             _LOGGER.debug(
                 self._log("multi | {} push data succ!".format(op_name)))
@@ -587,6 +600,8 @@
                     break
                 except Queue.Empty:
                     self._cv.wait()
+            if self._stop:
+                raise ChannelStopError()
             _LOGGER.debug(
                 self._log("{} get data succ: {}".format(op_name, resp.__str__(
                 ))))
@@ -617,6 +632,8 @@
                     break
                 except Queue.Empty:
                     self._cv.wait()
+            if self._stop:
+                raise ChannelStopError()
 
             consumer_cursor = self._consumer_cursors[op_name]
             base_cursor = self._base_cursor
@@ -657,7 +674,12 @@
         return resp
 
     def stop(self):
-        #TODO
-        self.close()
+        _LOGGER.info(self._log("stop."))
         self._stop = True
-        self._cv.notify_all()
+        with self._cv:
+            self._cv.notify_all()
+
+
+class ChannelStopError(RuntimeError):
+    def __init__(self):
+        pass
diff --git a/python/pipeline/dag.py b/python/pipeline/dag.py
index deeb87ed..100bda6f 100644
--- a/python/pipeline/dag.py
+++ b/python/pipeline/dag.py
@@ -26,7 +26,8 @@ import os
 import logging
 
 from .operator import Op, RequestOp, ResponseOp, VirtualOp
-from .channel import ThreadChannel, ProcessChannel, ChannelData, ChannelDataEcode, ChannelDataType
+from .channel import (ThreadChannel, ProcessChannel, ChannelData,
+                      ChannelDataEcode, ChannelDataType, ChannelStopError)
 from .profiler import TimeProfiler
 from .util import NameGenerator
@@ -34,33 +35,29 @@ _LOGGER = logging.getLogger()
 
 
 class DAGExecutor(object):
-    def __init__(self, response_op, yml_config, show_info):
-        self._retry = yml_config.get('retry', 1)
+    def __init__(self, response_op, dag_config, show_info):
+        self._retry = dag_config.get('retry', 1)
 
-        client_type = yml_config.get('client_type', 'brpc')
-        use_multithread = yml_config.get('use_multithread', True)
-        use_profile = yml_config.get('profile', False)
-        channel_size = yml_config.get('channel_size', 0)
-        self._asyn_profile = yml_config.get('asyn_profile', False)
+        client_type = dag_config.get('client_type', 'brpc')
+        use_profile = dag_config.get('use_profile', False)
+        channel_size = dag_config.get('channel_size', 0)
+        self._is_thread_op = dag_config.get('is_thread_op', True)
 
         if show_info and use_profile:
             _LOGGER.info("================= PROFILER ================")
-            if use_multithread:
+            if self._is_thread_op:
                 _LOGGER.info("op: thread")
+                _LOGGER.info("profile mode: sync")
             else:
                 _LOGGER.info("op: process")
-                if self._asyn_profile:
-                    _LOGGER.info("profile mode: asyn (This mode is only used"
-                                 " when using the process version Op)")
-                else:
-                    _LOGGER.info("profile mode: sync")
+                _LOGGER.info("profile mode: asyn")
             _LOGGER.info("-------------------------------------------")
 
         self.name = "@G"
         self._profiler = TimeProfiler()
         self._profiler.enable(use_profile)
 
-        self._dag = DAG(self.name, response_op, use_profile, use_multithread,
+        self._dag = DAG(self.name, response_op, use_profile, self._is_thread_op,
                         client_type, channel_size, show_info)
         (in_channel, out_channel, pack_rpc_func,
          unpack_rpc_func) = self._dag.build()
@@ -80,17 +77,14 @@ class DAGExecutor(object):
         self._cv_pool = {}
         self._cv_for_cv_pool = threading.Condition()
         self._fetch_buffer = None
-        self._is_run = False
         self._recive_func = None
 
     def start(self):
-        self._is_run = True
         self._recive_func = threading.Thread(
             target=DAGExecutor._recive_out_channel_func, args=(self, ))
         self._recive_func.start()
 
     def stop(self):
-        self._is_run = False
         self._dag.stop()
         self._dag.join()
 
@@ -119,8 +113,22 @@
     def _recive_out_channel_func(self):
         cv = None
-        while self._is_run:
-            channeldata_dict = self._out_channel.front(self.name)
+        while True:
+            try:
+                channeldata_dict = self._out_channel.front(self.name)
+            except ChannelStopError:
+                _LOGGER.info(self._log("stop."))
+                with self._cv_for_cv_pool:
+                    for data_id, cv in self._cv_pool.items():
+                        closed_error_data = ChannelData(
+                            ecode=ChannelDataEcode.CLOSED_ERROR.value,
+                            error_info="dag closed.",
+                            data_id=data_id)
+                        with cv:
+                            self._fetch_buffer = closed_error_data
+                            cv.notify_all()
+                break
+
             if len(channeldata_dict) != 1:
                 _LOGGER.error("out_channel cannot have multiple input ops")
                 os._exit(-1)
@@ -147,7 +155,6 @@
                 cv.wait()
             _LOGGER.debug("resp func get lock (data_id: {})".format(data_id))
             resp = copy.deepcopy(self._fetch_buffer)
-            # cv.notify_all()
         with self._cv_for_cv_pool:
             self._cv_pool.pop(data_id)
         return resp
@@ -170,7 +177,7 @@
     def call(self, rpc_request):
         data_id = self._get_next_data_id()
 
-        if self._asyn_profile:
+        if not self._is_thread_op:
            self._profiler.record("call_{}#DAG-{}_0".format(data_id, data_id))
         else:
             self._profiler.record("call_{}#DAG_0".format(data_id))
@@ -183,7 +190,15 @@ class DAGExecutor(object):
         for i in range(self._retry):
             _LOGGER.debug(self._log('push data'))
             #self._profiler.record("push_{}#{}_0".format(data_id, self.name))
-            self._in_channel.push(req_channeldata, self.name)
+            try:
+                self._in_channel.push(req_channeldata, self.name)
+            except ChannelStopError:
+                _LOGGER.info(self._log("stop."))
+                return self._pack_for_rpc_resp(
+                    ChannelData(
+                        ecode=ChannelDataEcode.CLOSED_ERROR.value,
+                        error_info="dag closed.",
+                        data_id=data_id))
             #self._profiler.record("push_{}#{}_1".format(data_id, self.name))
 
             _LOGGER.debug(self._log('wait for infer'))
@@ -201,7 +216,7 @@
         rpc_resp = self._pack_for_rpc_resp(resp_channeldata)
         self._profiler.record("postpack_{}#{}_1".format(data_id, self.name))
 
-        if self._asyn_profile:
+        if not self._is_thread_op:
             self._profiler.record("call_{}#DAG-{}_1".format(data_id, data_id))
         else:
             self._profiler.record("call_{}#DAG_1".format(data_id))
@@ -217,16 +232,16 @@
 
 
 class DAG(object):
-    def __init__(self, request_name, response_op, use_profile, use_multithread,
+    def __init__(self, request_name, response_op, use_profile, is_thread_op,
                  client_type, channel_size, show_info):
         self._request_name = request_name
         self._response_op = response_op
         self._use_profile = use_profile
-        self._use_multithread = use_multithread
+        self._is_thread_op = is_thread_op
         self._channel_size = channel_size
         self._client_type = client_type
         self._show_info = show_info
-        if not self._use_multithread:
+        if not self._is_thread_op:
             self._manager = multiprocessing.Manager()
 
     def get_use_ops(self, response_op):
@@ -254,7 +269,7 @@
     def _gen_channel(self, name_gen):
         channel = None
-        if self._use_multithread:
+        if self._is_thread_op:
             channel = ThreadChannel(
                 name=name_gen.next(), maxsize=self._channel_size)
         else:
@@ -439,7 +454,7 @@
         self._threads_or_proces = []
         for op in self._actual_ops:
             op.use_profiler(self._use_profile)
-            if self._use_multithread:
+            if self._is_thread_op:
                 self._threads_or_proces.extend(
                     op.start_with_thread(self._client_type))
             else:
@@ -453,7 +468,5 @@
             x.join()
 
     def stop(self):
-        for op in self._actual_ops:
-            op.stop()
         for chl in self._channels:
             chl.stop()
diff --git a/python/pipeline/operator.py b/python/pipeline/operator.py
index 7d1e1326..bc9a2eda 100644
--- a/python/pipeline/operator.py
+++ b/python/pipeline/operator.py
@@ -24,7 +24,8 @@ import sys
 from numpy import *
 
 from .proto import pipeline_service_pb2
-from .channel import ThreadChannel, ProcessChannel, ChannelDataEcode, ChannelData, ChannelDataType
+from .channel import (ThreadChannel, ProcessChannel, ChannelDataEcode,
+                      ChannelData, ChannelDataType, ChannelStopError)
 from .util import NameGenerator
 from .profiler import TimeProfiler
@@ -44,7 +45,6 @@
                  retry=1):
         if name is None:
             name = _op_name_gen.next()
-        self._is_run = False
         self.name = name  # to identify the type of OP, it must be globally unique
         self.concurrency = concurrency  # amount of concurrency
         self.set_input_ops(input_ops)
@@ -65,7 +65,9 @@
 
         # only for multithread
         self._for_init_op_lock = threading.Lock()
+        self._for_close_op_lock = threading.Lock()
         self._succ_init_op = False
+        self._succ_close_op = False
 
     def use_profiler(self, use_profile):
         self._use_profile = use_profile
@@ -93,9 +95,6 @@
             self._fetch_names = fetch_names
         return client
 
-    def _get_input_channel(self):
-        return self._input
-
     def get_input_ops(self):
         return self._input_ops
@@ -118,8 +117,11 @@
         channel.add_consumer(self.name)
         self._input = channel
 
-    def _get_output_channels(self):
-        return self._outputs
+    def _clean_input_channel(self):
+        self._input = None
+
+    def _get_input_channel(self):
+        return self._input
 
     def add_output_channel(self, channel):
         if not isinstance(channel, (ThreadChannel, ProcessChannel)):
@@ -129,6 +131,12 @@
             channel.add_producer(self.name)
         self._outputs.append(channel)
 
+    def _clean_output_channels(self):
+        self._outputs = []
+
+    def _get_output_channels(self):
+        return self._outputs
+
     def preprocess(self, input_dicts):
         # multiple previous Op
         if len(input_dicts) != 1:
@@ -144,7 +152,7 @@
         if err != 0:
             raise NotImplementedError(
                 "{} Please override preprocess func.".format(err_info))
-        call_result = self._client.predict(
+        call_result = self.client.predict(
             feed=feed_dict, fetch=self._fetch_names)
         _LOGGER.debug(self._log("get call_result"))
         return call_result
@@ -152,9 +160,6 @@
     def postprocess(self, input_dict, fetch_dict):
         return fetch_dict
 
-    def stop(self):
-        self._is_run = False
-
     def _parse_channeldata(self, channeldata_dict):
         data_id, error_channeldata = None, None
         parsed_data = {}
@@ -332,30 +337,42 @@
                         self._profiler = TimeProfiler()
                         self._profiler.enable(self._use_profile)
                         # init client
-                        self._client = self.init_client(
+                        self.client = self.init_client(
                             client_type, self._client_config, self._server_endpoints,
                             self._fetch_names)
                         # user defined
                         self.init_op()
                         self._succ_init_op = True
+                        self._succ_close_op = False
             else:
                 # init profiler
                 self._profiler = TimeProfiler()
                 self._profiler.enable(self._use_profile)
                 # init client
-                self._client = self.init_client(
-                    client_type, self._client_config, self._server_endpoints,
-                    self._fetch_names)
+                self.client = self.init_client(client_type, self._client_config,
+                                               self._server_endpoints,
+                                               self._fetch_names)
                 # user defined
                 self.init_op()
         except Exception as e:
             _LOGGER.error(log(e))
             os._exit(-1)
 
-        self._is_run = True
-        while self._is_run:
+        while True:
             #self._profiler_record("get#{}_0".format(op_info_prefix))
-            channeldata_dict = input_channel.front(self.name)
+            try:
+                channeldata_dict = input_channel.front(self.name)
+            except ChannelStopError:
+                _LOGGER.info(log("stop."))
+                with self._for_close_op_lock:
+                    if not self._succ_close_op:
+                        self._clean_input_channel()
+                        self._clean_output_channels()
+                        self._profiler = None
+                        self.client = None
+                        self._succ_init_op = False
+                        self._succ_close_op = True
+                break
             #self._profiler_record("get#{}_1".format(op_info_prefix))
 
             _LOGGER.debug(log("input_data: {}".format(channeldata_dict)))
@@ -363,8 +380,11 @@
                 channeldata_dict)
             # error data in predecessor Op
             if error_channeldata is not None:
-                self._push_to_output_channels(error_channeldata,
-                                              output_channels)
+                try:
+                    self._push_to_output_channels(error_channeldata,
+                                                  output_channels)
+                except ChannelStopError:
+                    _LOGGER.info(log("stop."))
                 continue
 
             # preprocess
@@ -373,8 +393,11 @@
                                                                   data_id, log)
             self._profiler_record("prep#{}_1".format(op_info_prefix))
             if error_channeldata is not None:
-                self._push_to_output_channels(error_channeldata,
-                                              output_channels)
+                try:
+                    self._push_to_output_channels(error_channeldata,
+                                                  output_channels)
+                except ChannelStopError:
+                    _LOGGER.info(log("stop."))
                 continue
 
             # process
@@ -383,8 +406,11 @@
                                                             data_id, log)
             self._profiler_record("midp#{}_1".format(op_info_prefix))
             if error_channeldata is not None:
-                self._push_to_output_channels(error_channeldata,
-                                              output_channels)
+                try:
+                    self._push_to_output_channels(error_channeldata,
+                                                  output_channels)
+                except ChannelStopError:
+                    _LOGGER.info(log("stop."))
                 continue
 
             # postprocess
@@ -393,8 +419,11 @@
                 parsed_data, midped_data, data_id, log)
             self._profiler_record("postp#{}_1".format(op_info_prefix))
             if error_channeldata is not None:
-                self._push_to_output_channels(error_channeldata,
-                                              output_channels)
+                try:
+                    self._push_to_output_channels(error_channeldata,
+                                                  output_channels)
+                except ChannelStopError:
+                    _LOGGER.info(log("stop."))
                 continue
 
             if self._use_profile:
@@ -405,7 +434,11 @@
             # push data to channel (if run succ)
             #self._profiler_record("push#{}_0".format(op_info_prefix))
-            self._push_to_output_channels(output_data, output_channels)
+            try:
+                self._push_to_output_channels(output_data, output_channels)
+            except ChannelStopError:
+                _LOGGER.info(log("stop."))
+                break
             #self._profiler_record("push#{}_1".format(op_info_prefix))
             #self._profiler.print_profile()
 
@@ -525,10 +558,17 @@
         log = get_log_func(op_info_prefix)
         tid = threading.current_thread().ident
 
-        self._is_run = True
-        while self._is_run:
-            channeldata_dict = input_channel.front(self.name)
+        while True:
+            try:
+                channeldata_dict = input_channel.front(self.name)
+            except ChannelStopError:
+                _LOGGER.info(log("stop."))
+                break
 
-            for name, data in channeldata_dict.items():
-                self._push_to_output_channels(
-                    data, channels=output_channels, name=name)
+            try:
+                for name, data in channeldata_dict.items():
+                    self._push_to_output_channels(
+                        data, channels=output_channels, name=name)
+            except ChannelStopError:
+                _LOGGER.info(log("stop."))
+                break
diff --git a/python/pipeline/pipeline_server.py b/python/pipeline/pipeline_server.py
index e41556e0..7e20425a 100644
--- a/python/pipeline/pipeline_server.py
+++ b/python/pipeline/pipeline_server.py
@@ -29,16 +29,36 @@ _LOGGER = logging.getLogger()
 
 
 class PipelineService(pipeline_service_pb2_grpc.PipelineServiceServicer):
-    def __init__(self, response_op, yml_config, show_info=True):
+    def __init__(self, response_op, dag_config):
         super(PipelineService, self).__init__()
         # init dag executor
-        self._dag_executor = DAGExecutor(response_op, yml_config, show_info)
+        self._dag_executor = DAGExecutor(
+            response_op, dag_config, show_info=True)
         self._dag_executor.start()
 
     def inference(self, request, context):
         resp = self._dag_executor.call(request)
         return resp
 
+    def __del__(self):
+        self._dag_executor.stop()
+
+
+class ProcessPipelineService(pipeline_service_pb2_grpc.PipelineServiceServicer):
+    def __init__(self, response_op, dag_config):
+        super(ProcessPipelineService, self).__init__()
+        self._response_op = response_op
+        self._dag_config = dag_config
+
+    def inference(self, request, context):
+        # init dag executor
+        dag_executor = DAGExecutor(
+            self._response_op, self._dag_config, show_info=False)
+        dag_executor.start()
+        resp = dag_executor.call(request)
+        dag_executor.stop()
+        return resp
+
 
 @contextlib.contextmanager
 def _reserve_port(port):
@@ -75,25 +95,30 @@
     def prepare_server(self, yml_file):
         with open(yml_file) as f:
-            self._yml_config = yaml.load(f.read())
-        self._port = self._yml_config.get('port', 8080)
+            yml_config = yaml.load(f.read())
+        self._port = yml_config.get('port')
+        if self._port is None:
+            raise SystemExit("Please set *port* in [{}] yaml file.".format(
+                yml_file))
         if not self._port_is_available(self._port):
             raise SystemExit("Port {} is already used".format(self._port))
-        self._worker_num = self._yml_config.get('worker_num', 2)
-        self._multiprocess_servicer = self._yml_config.get(
-            'multiprocess_servicer', False)
+        self._worker_num = yml_config.get('worker_num', 1)
+        self._build_dag_each_request = yml_config.get('build_dag_each_request',
+                                                      False)
 
         _LOGGER.info("============= PIPELINE SERVER =============")
         _LOGGER.info("port: {}".format(self._port))
         _LOGGER.info("worker_num: {}".format(self._worker_num))
-        servicer_info = "multiprocess_servicer: {}".format(
-            self._multiprocess_servicer)
+        servicer_info = "build_dag_each_request: {}".format(
+            self._build_dag_each_request)
-        if self._multiprocess_servicer is True:
+        if self._build_dag_each_request is True:
             servicer_info += " (Make sure the grpcio whl is installed with the --no-binary flag)"
         _LOGGER.info(servicer_info)
         _LOGGER.info("-------------------------------------------")
 
+        self._dag_config = yml_config.get("dag", {})
+
     def run_server(self):
-        if self._multiprocess_servicer:
+        if self._build_dag_each_request:
             with _reserve_port(self._port) as port:
                 bind_address = 'localhost:{}'.format(port)
                 workers = []
@@ -101,8 +126,8 @@
                 for i in range(self._worker_num):
                     show_info = (i == 0)
                     worker = multiprocessing.Process(
                         target=self._run_server_func,
-                        args=(bind_address, self._response_op, self._yml_config,
-                              self._worker_num, show_info))
+                        args=(bind_address, self._response_op, self._dag_config,
+                              self._worker_num))
                     worker.start()
                     workers.append(worker)
                 for worker in workers:
@@ -111,20 +136,20 @@
             server = grpc.server(
                 futures.ThreadPoolExecutor(max_workers=self._worker_num))
             pipeline_service_pb2_grpc.add_PipelineServiceServicer_to_server(
-                PipelineService(self._response_op, self._yml_config), server)
+                PipelineService(self._response_op, self._dag_config), server)
             server.add_insecure_port('[::]:{}'.format(self._port))
             server.start()
             server.wait_for_termination()
 
-    def _run_server_func(self, bind_address, response_op, yml_config,
-                         worker_num, show_info):
+    def _run_server_func(self, bind_address, response_op, dag_config,
+                         worker_num):
         options = (('grpc.so_reuseport', 1), )
         server = grpc.server(
             futures.ThreadPoolExecutor(
                 max_workers=worker_num, ), options=options)
         pipeline_service_pb2_grpc.add_PipelineServiceServicer_to_server(
-            PipelineService(response_op, yml_config, show_info), server)
+            ProcessPipelineService(response_op, dag_config), server)
         server.add_insecure_port(bind_address)
         server.start()
         server.wait_for_termination()
-- 
GitLab
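
Usage sketch (appendix, not part of the patch): the snippet below wires the new
build_dag_each_request switch end to end. The config keys match prepare_server()
above; the import path, the set_response_op() setter, and the Op constructor
arguments are assumptions drawn from the surrounding pipeline package, shown here
only as illustrative placeholders.

    # toy_pipeline.py -- minimal sketch under the assumptions stated above.
    from paddle_serving_server.pipeline import PipelineServer
    from paddle_serving_server.pipeline.operator import Op, RequestOp, ResponseOp

    # config.yml read by prepare_server(); keys introduced or renamed in this patch:
    #
    #   port: 18080                    # required; SystemExit if missing
    #   worker_num: 1                  # gRPC workers (one process each when the flag below is true)
    #   build_dag_each_request: false  # true -> ProcessPipelineService builds a DAG per RPC
    #   dag:
    #       is_thread_op: true         # thread Ops (sync profile) vs process Ops (asyn profile)
    #       client_type: brpc
    #       retry: 1
    #       use_profile: false
    #       channel_size: 0

    read_op = RequestOp()
    # Hypothetical model Op; endpoints, fetch list, and client config are placeholders.
    model_op = Op(name="model",
                  input_ops=[read_op],
                  server_endpoints=["127.0.0.1:9393"],
                  fetch_list=["score"],
                  client_config="serving_client_conf.prototxt",
                  concurrency=1)
    response_op = ResponseOp(input_ops=[model_op])

    server = PipelineServer()
    server.set_response_op(response_op)  # assumed setter for self._response_op
    server.prepare_server("config.yml")  # reads port/worker_num/build_dag_each_request/dag
    server.run_server()

With build_dag_each_request: false, run_server() hosts a single PipelineService
whose long-lived DAGExecutor is shared across the thread pool; with true, it
reserves the port via SO_REUSEPORT, forks worker_num server processes, and each
ProcessPipelineService.inference() builds, calls, and stops a private DAGExecutor,
so tearing down one request's DAG cannot wedge later requests.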