diff --git a/python/pipeline/channel.py b/python/pipeline/channel.py index ce55b187e66ae02916d04a57732391de01f4ece5..953f378ae306bc8174c30224b85763b44ee2c811 100644 --- a/python/pipeline/channel.py +++ b/python/pipeline/channel.py @@ -27,7 +27,7 @@ import logging import enum import copy -_LOGGER = logging.getLogger(__name__) +_LOGGER = logging.getLogger() class ChannelDataEcode(enum.Enum): @@ -92,7 +92,16 @@ class ChannelData(object): def check_dictdata(dictdata): ecode = ChannelDataEcode.OK.value error_info = None - if not isinstance(dictdata, dict): + if isinstance(dictdata, list): + # batch data + for sample in dictdata: + if not isinstance(sample, dict): + ecode = ChannelDataEcode.TYPE_ERROR.value + error_info = "the value of data must " \ + "be dict, but get {}.".format(type(sample)) + break + elif not isinstance(dictdata, dict): + # batch size = 1 ecode = ChannelDataEcode.TYPE_ERROR.value error_info = "the value of data must " \ "be dict, but get {}.".format(type(dictdata)) @@ -102,12 +111,32 @@ class ChannelData(object): def check_npdata(npdata): ecode = ChannelDataEcode.OK.value error_info = None - for _, value in npdata.items(): - if not isinstance(value, np.ndarray): - ecode = ChannelDataEcode.TYPE_ERROR.value - error_info = "the value of data must " \ - "be np.ndarray, but get {}.".format(type(value)) - break + if isinstance(npdata, list): + # batch data + for sample in npdata: + if not isinstance(sample, dict): + ecode = ChannelDataEcode.TYPE_ERROR.value + error_info = "the value of data must " \ + "be dict, but get {}.".format(type(sample)) + break + for _, value in sample.items(): + if not isinstance(value, np.ndarray): + ecode = ChannelDataEcode.TYPE_ERROR.value + error_info = "the value of data must " \ + "be np.ndarray, but get {}.".format(type(value)) + return ecode, error_info + elif isinstance(npdata, dict): + # batch_size = 1 + for _, value in npdata.items(): + if not isinstance(value, np.ndarray): + ecode = ChannelDataEcode.TYPE_ERROR.value + error_info = "the value of data must " \ + "be np.ndarray, but get {}.".format(type(value)) + break + else: + ecode = ChannelDataEcode.TYPE_ERROR.value + error_info = "the value of data must " \ + "be dict, but get {}.".format(type(npdata)) return ecode, error_info def parse(self): diff --git a/python/pipeline/operator.py b/python/pipeline/operator.py index d82cac888298f83a1c8412f742adbf7de3932471..d2323f265c7fac65bc97d9b8d9a3dea8afe4cf2e 100644 --- a/python/pipeline/operator.py +++ b/python/pipeline/operator.py @@ -19,13 +19,14 @@ from paddle_serving_client import MultiLangClient, Client from concurrent import futures import logging import func_timeout +import os from numpy import * from .proto import pipeline_service_pb2 from .channel import ThreadChannel, ProcessChannel, ChannelDataEcode, ChannelData, ChannelDataType from .util import NameGenerator -_LOGGER = logging.getLogger(__name__) +_LOGGER = logging.getLogger() _op_name_gen = NameGenerator("Op") @@ -59,6 +60,10 @@ class Op(object): self._outputs = [] self._profiler = None + # only for multithread + self._for_init_op_lock = threading.Lock() + self._succ_init_op = False + def init_profiler(self, profiler): self._profiler = profiler @@ -71,18 +76,19 @@ class Op(object): fetch_names): if self.with_serving == False: _LOGGER.debug("{} no client".format(self.name)) - return + return None _LOGGER.debug("{} client_config: {}".format(self.name, client_config)) _LOGGER.debug("{} fetch_names: {}".format(self.name, fetch_names)) if client_type == 'brpc': - self._client = Client() - self._client.load_client_config(client_config) + client = Client() + client.load_client_config(client_config) elif client_type == 'grpc': - self._client = MultiLangClient() + client = MultiLangClient() else: raise ValueError("unknow client type: {}".format(client_type)) - self._client.connect(server_endpoints) + client.connect(server_endpoints) self._fetch_names = fetch_names + return client def _get_input_channel(self): return self._input @@ -130,19 +136,17 @@ class Op(object): (_, input_dict), = input_dicts.items() return input_dict - def process(self, feed_dict): + def process(self, client_predict_handler, feed_dict): err, err_info = ChannelData.check_npdata(feed_dict) if err != 0: raise NotImplementedError( "{} Please override preprocess func.".format(err_info)) - _LOGGER.debug(self._log('feed_dict: {}'.format(feed_dict))) - _LOGGER.debug(self._log('fetch: {}'.format(self._fetch_names))) - call_result = self._client.predict( + call_result = client_predict_handler( feed=feed_dict, fetch=self._fetch_names) _LOGGER.debug(self._log("get call_result")) return call_result - def postprocess(self, fetch_dict): + def postprocess(self, input_dict, fetch_dict): return fetch_dict def stop(self): @@ -174,7 +178,7 @@ class Op(object): p = multiprocessing.Process( target=self._run, args=(concurrency_idx, self._get_input_channel(), - self._get_output_channels(), client_type)) + self._get_output_channels(), client_type, False)) p.start() proces.append(p) return proces @@ -185,12 +189,12 @@ class Op(object): t = threading.Thread( target=self._run, args=(concurrency_idx, self._get_input_channel(), - self._get_output_channels(), client_type)) + self._get_output_channels(), client_type, True)) t.start() threads.append(t) return threads - def load_user_resources(self): + def init_op(self): pass def _run_preprocess(self, parsed_data, data_id, log_func): @@ -222,13 +226,15 @@ class Op(object): data_id=data_id) return preped_data, error_channeldata - def _run_process(self, preped_data, data_id, log_func): + def _run_process(self, client_predict_handler, preped_data, data_id, + log_func): midped_data, error_channeldata = None, None if self.with_serving: ecode = ChannelDataEcode.OK.value if self._timeout <= 0: try: - midped_data = self.process(preped_data) + midped_data = self.process(client_predict_handler, + preped_data) except Exception as e: ecode = ChannelDataEcode.UNKNOW.value error_info = log_func(e) @@ -237,7 +243,11 @@ class Op(object): for i in range(self._retry): try: midped_data = func_timeout.func_timeout( - self._timeout, self.process, args=(preped_data, )) + self._timeout, + self.process, + args=( + client_predict_handler, + preped_data, )) except func_timeout.FunctionTimedOut as e: if i + 1 >= self._retry: ecode = ChannelDataEcode.TIMEOUT.value @@ -267,10 +277,10 @@ class Op(object): midped_data = preped_data return midped_data, error_channeldata - def _run_postprocess(self, midped_data, data_id, log_func): + def _run_postprocess(self, input_dict, midped_data, data_id, log_func): output_data, error_channeldata = None, None try: - postped_data = self.postprocess(midped_data) + postped_data = self.postprocess(input_dict, midped_data) except Exception as e: error_info = log_func(e) _LOGGER.error(error_info) @@ -303,8 +313,8 @@ class Op(object): data_id=data_id) return output_data, error_channeldata - def _run(self, concurrency_idx, input_channel, output_channels, - client_type): + def _run(self, concurrency_idx, input_channel, output_channels, client_type, + use_multithread): def get_log_func(op_info_prefix): def log_func(info_str): return "{} {}".format(op_info_prefix, info_str) @@ -315,12 +325,30 @@ class Op(object): log = get_log_func(op_info_prefix) tid = threading.current_thread().ident + client = None + client_predict_handler = None # create client based on client_type - self.init_client(client_type, self._client_config, - self._server_endpoints, self._fetch_names) + try: + client = self.init_client(client_type, self._client_config, + self._server_endpoints, self._fetch_names) + if client is not None: + client_predict_handler = client.predict + except Exception as e: + _LOGGER.error(log(e)) + os._exit(-1) # load user resources - self.load_user_resources() + try: + if use_multithread: + with self._for_init_op_lock: + if not self._succ_init_op: + self.init_op() + self._succ_init_op = True + else: + self.init_op() + except Exception as e: + _LOGGER.error(log(e)) + os._exit(-1) self._is_run = True while self._is_run: @@ -349,8 +377,8 @@ class Op(object): # process self._profiler_record("{}-midp#{}_0".format(op_info_prefix, tid)) - midped_data, error_channeldata = self._run_process(preped_data, - data_id, log) + midped_data, error_channeldata = self._run_process( + client_predict_handler, preped_data, data_id, log) self._profiler_record("{}-midp#{}_1".format(op_info_prefix, tid)) if error_channeldata is not None: self._push_to_output_channels(error_channeldata, @@ -359,8 +387,8 @@ class Op(object): # postprocess self._profiler_record("{}-postp#{}_0".format(op_info_prefix, tid)) - output_data, error_channeldata = self._run_postprocess(midped_data, - data_id, log) + output_data, error_channeldata = self._run_postprocess( + parsed_data, midped_data, data_id, log) self._profiler_record("{}-postp#{}_1".format(op_info_prefix, tid)) if error_channeldata is not None: self._push_to_output_channels(error_channeldata, @@ -384,7 +412,11 @@ class RequestOp(Op): super(RequestOp, self).__init__( name="#G", input_ops=[], concurrency=concurrency) # load user resources - self.load_user_resources() + try: + self.init_op() + except Exception as e: + _LOGGER.error(e) + os._exit(-1) def unpack_request_package(self, request): dictdata = {} @@ -405,7 +437,11 @@ class ResponseOp(Op): super(ResponseOp, self).__init__( name="#R", input_ops=input_ops, concurrency=concurrency) # load user resources - self.load_user_resources() + try: + self.init_op() + except Exception as e: + _LOGGER.error(e) + os._exit(-1) def pack_response_package(self, channeldata): resp = pipeline_service_pb2.Response() @@ -450,17 +486,26 @@ class VirtualOp(Op): def add_virtual_pred_op(self, op): self._virtual_pred_ops.append(op) + def _actual_pred_op_names(self, op): + if not isinstance(op, VirtualOp): + return [op.name] + names = [] + for x in op._virtual_pred_ops: + names.extend(self._actual_pred_op_names(x)) + return names + def add_output_channel(self, channel): if not isinstance(channel, (ThreadChannel, ProcessChannel)): raise TypeError( self._log('output channel must be Channel type, not {}'.format( type(channel)))) for op in self._virtual_pred_ops: - channel.add_producer(op.name) + for op_name in self._actual_pred_op_names(op): + channel.add_producer(op_name) self._outputs.append(channel) - def _run(self, concurrency_idx, input_channel, output_channels, - client_type): + def _run(self, concurrency_idx, input_channel, output_channels, client_type, + use_multithread): def get_log_func(op_info_prefix): def log_func(info_str): return "{} {}".format(op_info_prefix, info_str) diff --git a/python/pipeline/pipeline_client.py b/python/pipeline/pipeline_client.py index 4ad05b5a953d5084ffda360c0a1ac561463898a4..891a2b2a2fe759d0f392ad043dcc6f9173bd0e3a 100644 --- a/python/pipeline/pipeline_client.py +++ b/python/pipeline/pipeline_client.py @@ -20,7 +20,7 @@ import functools from .proto import pipeline_service_pb2 from .proto import pipeline_service_pb2_grpc -_LOGGER = logging.getLogger(__name__) +_LOGGER = logging.getLogger() class PipelineClient(object): @@ -52,7 +52,7 @@ class PipelineClient(object): return {"ecode": resp.ecode, "error_info": resp.error_info} fetch_map = {"ecode": resp.ecode} for idx, key in enumerate(resp.key): - if key not in fetch: + if fetch is not None and key not in fetch: continue data = resp.value[idx] try: @@ -62,16 +62,16 @@ class PipelineClient(object): fetch_map[key] = data return fetch_map - def predict(self, feed_dict, fetch, asyn=False): + def predict(self, feed_dict, fetch=None, asyn=False): if not isinstance(feed_dict, dict): raise TypeError( "feed must be dict type with format: {name: value}.") - if not isinstance(fetch, list): + if fetch is not None and not isinstance(fetch, list): raise TypeError("fetch must be list type with format: [name].") req = self._pack_request_package(feed_dict) if not asyn: resp = self._stub.inference(req) - return self._unpack_response_package(resp) + return self._unpack_response_package(resp, fetch) else: call_future = self._stub.inference.future(req) return PipelinePredictFuture( diff --git a/python/pipeline/pipeline_server.py b/python/pipeline/pipeline_server.py index 55289eeca42e02bb979d4a21791fdde44e0aff02..2c1c19bd6aa23214ab8fc385869cc29c8e86b37e 100644 --- a/python/pipeline/pipeline_server.py +++ b/python/pipeline/pipeline_server.py @@ -45,7 +45,7 @@ from .channel import ThreadChannel, ProcessChannel, ChannelData, ChannelDataEcod from .profiler import TimeProfiler from .util import NameGenerator -_LOGGER = logging.getLogger(__name__) +_LOGGER = logging.getLogger() _profiler = TimeProfiler() @@ -235,6 +235,10 @@ class PipelineServer(object): return use_ops, succ_ops_of_use_op use_ops, out_degree_ops = get_use_ops(response_op) + _LOGGER.info("================= use op ==================") + for op in use_ops: + _LOGGER.info(op.name) + _LOGGER.info("===========================================") if len(use_ops) <= 1: raise Exception( "Besides RequestOp and ResponseOp, there should be at least one Op in DAG." diff --git a/python/pipeline/profiler.py b/python/pipeline/profiler.py index 146203f7c184b506bb8fd70dadac1d89166a2de9..49eabf5b318789823073154dd3f8ad38e18638a1 100644 --- a/python/pipeline/profiler.py +++ b/python/pipeline/profiler.py @@ -24,7 +24,7 @@ else: raise Exception("Error Python version") import time -_LOGGER = logging.getLogger(__name__) +_LOGGER = logging.getLogger() class TimeProfiler(object): @@ -58,7 +58,7 @@ class TimeProfiler(object): print_str += "{}_{}:{} ".format(name, tag, timestamp) else: tmp[name] = (tag, timestamp) - print_str += "\n" + print_str = "\n{}\n".format(print_str) sys.stderr.write(print_str) for name, item in tmp.items(): tag, timestamp = item