From 51b9e139700fc2d19a120bb10ad865b63cd23435 Mon Sep 17 00:00:00 2001 From: barrierye Date: Thu, 2 Jul 2020 11:19:39 +0800 Subject: [PATCH] fix bug: client in multi-thread-op --- python/pipeline/operator.py | 40 ++++++++++++++++++++++--------------- 1 file changed, 24 insertions(+), 16 deletions(-) diff --git a/python/pipeline/operator.py b/python/pipeline/operator.py index 7ba5e52c..cc99c585 100644 --- a/python/pipeline/operator.py +++ b/python/pipeline/operator.py @@ -71,18 +71,19 @@ class Op(object): fetch_names): if self.with_serving == False: _LOGGER.debug("{} no client".format(self.name)) - return + return None _LOGGER.debug("{} client_config: {}".format(self.name, client_config)) _LOGGER.debug("{} fetch_names: {}".format(self.name, fetch_names)) if client_type == 'brpc': - self._client = Client() - self._client.load_client_config(client_config) + client = Client() + client.load_client_config(client_config) elif client_type == 'grpc': - self._client = MultiLangClient() + client = MultiLangClient() else: raise ValueError("unknow client type: {}".format(client_type)) - self._client.connect(server_endpoints) + client.connect(server_endpoints) self._fetch_names = fetch_names + return client def _get_input_channel(self): return self._input @@ -130,14 +131,12 @@ class Op(object): (_, input_dict), = input_dicts.items() return input_dict - def process(self, feed_dict): + def process(self, client_predict_handler, feed_dict): err, err_info = ChannelData.check_npdata(feed_dict) if err != 0: raise NotImplementedError( "{} Please override preprocess func.".format(err_info)) - _LOGGER.debug(self._log('feed_dict: {}'.format(feed_dict))) - _LOGGER.debug(self._log('fetch: {}'.format(self._fetch_names))) - call_result = self._client.predict( + call_result = client_predict_handler( feed=feed_dict, fetch=self._fetch_names) _LOGGER.debug(self._log("get call_result")) return call_result @@ -222,13 +221,15 @@ class Op(object): data_id=data_id) return preped_data, error_channeldata - def _run_process(self, preped_data, data_id, log_func): + def _run_process(self, client_predict_handler, preped_data, data_id, + log_func): midped_data, error_channeldata = None, None if self.with_serving: ecode = ChannelDataEcode.OK.value if self._timeout <= 0: try: - midped_data = self.process(preped_data) + midped_data = self.process(client_predict_handler, + preped_data) except Exception as e: ecode = ChannelDataEcode.UNKNOW.value error_info = log_func(e) @@ -237,7 +238,11 @@ class Op(object): for i in range(self._retry): try: midped_data = func_timeout.func_timeout( - self._timeout, self.process, args=(preped_data, )) + self._timeout, + self.process, + args=( + client_predict_handler, + preped_data, )) except func_timeout.FunctionTimedOut as e: if i + 1 >= self._retry: ecode = ChannelDataEcode.TIMEOUT.value @@ -316,8 +321,11 @@ class Op(object): tid = threading.current_thread().ident # create client based on client_type - self.init_client(client_type, self._client_config, - self._server_endpoints, self._fetch_names) + client = self.init_client(client_type, self._client_config, + self._server_endpoints, self._fetch_names) + client_predict_handler = None + if self.with_serving: + client_predict_handler = client.predict # load user resources self.load_user_resources() @@ -349,8 +357,8 @@ class Op(object): # process self._profiler_record("{}-midp#{}_0".format(op_info_prefix, tid)) - midped_data, error_channeldata = self._run_process(preped_data, - data_id, log) + midped_data, error_channeldata = self._run_process( + client_predict_handler, preped_data, data_id, log) self._profiler_record("{}-midp#{}_1".format(op_info_prefix, tid)) if error_channeldata is not None: self._push_to_output_channels(error_channeldata, -- GitLab