From 7cb306e510385801d34fef741a1bd73fc15d6a37 Mon Sep 17 00:00:00 2001 From: barrierye Date: Wed, 8 Jul 2020 03:45:10 +0800 Subject: [PATCH] remove client_predict_handler --- python/pipeline/operator.py | 54 ++++++++++++++++--------------------- 1 file changed, 23 insertions(+), 31 deletions(-) diff --git a/python/pipeline/operator.py b/python/pipeline/operator.py index 507807c6..094a71c2 100644 --- a/python/pipeline/operator.py +++ b/python/pipeline/operator.py @@ -139,12 +139,12 @@ class Op(object): (_, input_dict), = input_dicts.items() return input_dict - def process(self, client_predict_handler, feed_dict): + def process(self, feed_dict): err, err_info = ChannelData.check_npdata(feed_dict) if err != 0: raise NotImplementedError( "{} Please override preprocess func.".format(err_info)) - call_result = client_predict_handler( + call_result = self.client.predict( feed=feed_dict, fetch=self._fetch_names) _LOGGER.debug(self._log("get call_result")) return call_result @@ -229,15 +229,13 @@ class Op(object): data_id=data_id) return preped_data, error_channeldata - def _run_process(self, client_predict_handler, preped_data, data_id, - log_func): + def _run_process(self, preped_data, data_id, log_func): midped_data, error_channeldata = None, None if self.with_serving: ecode = ChannelDataEcode.OK.value if self._timeout <= 0: try: - midped_data = self.process(client_predict_handler, - preped_data) + midped_data = self.process(preped_data) except Exception as e: ecode = ChannelDataEcode.UNKNOW.value error_info = log_func(e) @@ -246,9 +244,7 @@ class Op(object): for i in range(self._retry): try: midped_data = func_timeout.func_timeout( - self._timeout, - self.process, - args=(client_predict_handler, preped_data)) + self._timeout, self.process, args=(preped_data, )) except func_timeout.FunctionTimedOut as e: if i + 1 >= self._retry: ecode = ChannelDataEcode.TIMEOUT.value @@ -326,36 +322,36 @@ class Op(object): log = get_log_func(op_info_prefix) tid = threading.current_thread().ident - client = None - client_predict_handler = None - # create client based on client_type - try: - client = self.init_client(client_type, self._client_config, - self._server_endpoints, self._fetch_names) - if client is not None: - client_predict_handler = client.predict - except Exception as e: - _LOGGER.error(log(e)) - os._exit(-1) - # init op self.concurrency_idx = concurrency_idx try: if use_multithread: with self._for_init_op_lock: if not self._succ_init_op: + # init profiler + self._profiler = TimeProfiler() + self._profiler.enable(self._use_profile) + # init client + self.client = self.init_client( + client_type, self._client_config, + self._server_endpoints, self._fetch_names) + # user defined self.init_op() self._succ_init_op = True else: + # init profiler + self._profiler = TimeProfiler() + self._profiler.enable(self._use_profile) + # init client + self.client = self.init_client(client_type, self._client_config, + self._server_endpoints, + self._fetch_names) + # user defined self.init_op() except Exception as e: _LOGGER.error(log(e)) os._exit(-1) - # init profiler - self._profiler = TimeProfiler() - self._profiler.enable(self._use_profile) - self._is_run = True while self._is_run: #self._profiler_record("get#{}_0".format(op_info_prefix)) @@ -383,8 +379,8 @@ class Op(object): # process self._profiler_record("midp#{}_0".format(op_info_prefix)) - midped_data, error_channeldata = self._run_process( - client_predict_handler, preped_data, data_id, log) + midped_data, error_channeldata = self._run_process(preped_data, + data_id, log) self._profiler_record("midp#{}_1".format(op_info_prefix)) if error_channeldata is not None: self._push_to_output_channels(error_channeldata, @@ -531,12 +527,8 @@ class VirtualOp(Op): self._is_run = True while self._is_run: - #self._profiler_record("get#{}_0".format(op_info_prefix)) channeldata_dict = input_channel.front(self.name) - #self._profiler_record("get#{}_1".format(op_info_prefix)) - #self._profiler_record("push#{}_0".format(op_info_prefix)) for name, data in channeldata_dict.items(): self._push_to_output_channels( data, channels=output_channels, name=name) - #self._profiler_record("push#{}_1".format(op_info_prefix)) -- GitLab