diff --git a/python/pipeline/operator.py b/python/pipeline/operator.py index eed109e6543b27f37fafbd9ee328db8e2692166a..c228213d8b5cbacb70c1ee1e788ca5c769d82802 100644 --- a/python/pipeline/operator.py +++ b/python/pipeline/operator.py @@ -126,7 +126,7 @@ class Op(object): channel.add_producer(self.name) self._outputs.append(channel) - def preprocess(self, input_dicts): + def preprocess(self, input_dicts, private_obj): # multiple previous Op if len(input_dicts) != 1: raise NotImplementedError( @@ -136,7 +136,7 @@ class Op(object): (_, input_dict), = input_dicts.items() return input_dict - def process(self, client_predict_handler, feed_dict): + def process(self, client_predict_handler, feed_dict, private_obj): err, err_info = ChannelData.check_npdata(feed_dict) if err != 0: raise NotImplementedError( @@ -146,7 +146,7 @@ class Op(object): _LOGGER.debug(self._log("get call_result")) return call_result - def postprocess(self, input_dict, fetch_dict): + def postprocess(self, input_dict, fetch_dict, private_obj): return fetch_dict def stop(self): @@ -197,10 +197,13 @@ class Op(object): def init_op(self): pass - def _run_preprocess(self, parsed_data, data_id, log_func): + def load_private_obj(self): + return None + + def _run_preprocess(self, parsed_data, data_id, private_obj, log_func): preped_data, error_channeldata = None, None try: - preped_data = self.preprocess(parsed_data) + preped_data = self.preprocess(parsed_data, private_obj) except NotImplementedError as e: # preprocess function not implemented error_info = log_func(e) @@ -227,14 +230,14 @@ class Op(object): return preped_data, error_channeldata def _run_process(self, client_predict_handler, preped_data, data_id, - log_func): + private_obj, log_func): midped_data, error_channeldata = None, None if self.with_serving: ecode = ChannelDataEcode.OK.value if self._timeout <= 0: try: midped_data = self.process(client_predict_handler, - preped_data) + preped_data, private_obj) except Exception as e: ecode = ChannelDataEcode.UNKNOW.value error_info = log_func(e) @@ -245,9 +248,8 @@ class Op(object): midped_data = func_timeout.func_timeout( self._timeout, self.process, - args=( - client_predict_handler, - preped_data, )) + args=(client_predict_handler, preped_data, + private_obj)) except func_timeout.FunctionTimedOut as e: if i + 1 >= self._retry: ecode = ChannelDataEcode.TIMEOUT.value @@ -277,10 +279,12 @@ class Op(object): midped_data = preped_data return midped_data, error_channeldata - def _run_postprocess(self, input_dict, midped_data, data_id, log_func): + def _run_postprocess(self, input_dict, midped_data, data_id, private_obj, + log_func): output_data, error_channeldata = None, None try: - postped_data = self.postprocess(input_dict, midped_data) + postped_data = self.postprocess(input_dict, midped_data, + private_obj) except Exception as e: error_info = log_func(e) _LOGGER.error(error_info) @@ -337,7 +341,7 @@ class Op(object): _LOGGER.error(log(e)) os._exit(-1) - # load user resources + # init op try: if use_multithread: with self._for_init_op_lock: @@ -350,6 +354,14 @@ class Op(object): _LOGGER.error(log(e)) os._exit(-1) + # load private objects (some no-thread-safe objects) + private_obj = None + try: + private_obj = self.load_private_obj() + except Exception as e: + _LOGGER.error(log(e)) + os._exit(-1) + self._is_run = True while self._is_run: #self._profiler_record("{}-get#{}_0".format(op_info_prefix, tid)) @@ -367,8 +379,8 @@ class Op(object): # preprecess self._profiler_record("{}-prep#{}_0".format(op_info_prefix, tid)) - preped_data, error_channeldata = self._run_preprocess(parsed_data, - data_id, log) + preped_data, error_channeldata = self._run_preprocess( + parsed_data, data_id, private_obj, log) self._profiler_record("{}-prep#{}_1".format(op_info_prefix, tid)) if error_channeldata is not None: self._push_to_output_channels(error_channeldata, @@ -378,7 +390,7 @@ class Op(object): # process self._profiler_record("{}-midp#{}_0".format(op_info_prefix, tid)) midped_data, error_channeldata = self._run_process( - client_predict_handler, preped_data, data_id, log) + client_predict_handler, preped_data, data_id, private_obj, log) self._profiler_record("{}-midp#{}_1".format(op_info_prefix, tid)) if error_channeldata is not None: self._push_to_output_channels(error_channeldata, @@ -388,7 +400,7 @@ class Op(object): # postprocess self._profiler_record("{}-postp#{}_0".format(op_info_prefix, tid)) output_data, error_channeldata = self._run_postprocess( - parsed_data, midped_data, data_id, log) + parsed_data, midped_data, data_id, private_obj, log) self._profiler_record("{}-postp#{}_1".format(op_info_prefix, tid)) if error_channeldata is not None: self._push_to_output_channels(error_channeldata, @@ -411,7 +423,7 @@ class RequestOp(Op): # PipelineService.name = "#G" super(RequestOp, self).__init__( name="#G", input_ops=[], concurrency=concurrency) - # load user resources + # init op try: self.init_op() except Exception as e: @@ -436,7 +448,7 @@ class ResponseOp(Op): def __init__(self, input_ops, concurrency=1): super(ResponseOp, self).__init__( name="#R", input_ops=input_ops, concurrency=concurrency) - # load user resources + # init op try: self.init_op() except Exception as e: