提交 f3ac1e38 编写于 作者: B barrierye

fix bug: client in multi-thread-op

上级 6722b83c
......@@ -71,18 +71,19 @@ class Op(object):
fetch_names):
if self.with_serving == False:
_LOGGER.debug("{} no client".format(self.name))
return
return None
_LOGGER.debug("{} client_config: {}".format(self.name, client_config))
_LOGGER.debug("{} fetch_names: {}".format(self.name, fetch_names))
if client_type == 'brpc':
self._client = Client()
self._client.load_client_config(client_config)
client = Client()
client.load_client_config(client_config)
elif client_type == 'grpc':
self._client = MultiLangClient()
client = MultiLangClient()
else:
raise ValueError("unknow client type: {}".format(client_type))
self._client.connect(server_endpoints)
client.connect(server_endpoints)
self._fetch_names = fetch_names
return client
def _get_input_channel(self):
return self._input
......@@ -130,14 +131,12 @@ class Op(object):
(_, input_dict), = input_dicts.items()
return input_dict
def process(self, feed_dict):
def process(self, client_predict_handler, feed_dict):
err, err_info = ChannelData.check_npdata(feed_dict)
if err != 0:
raise NotImplementedError(
"{} Please override preprocess func.".format(err_info))
_LOGGER.debug(self._log('feed_dict: {}'.format(feed_dict)))
_LOGGER.debug(self._log('fetch: {}'.format(self._fetch_names)))
call_result = self._client.predict(
call_result = client_predict_handler(
feed=feed_dict, fetch=self._fetch_names)
_LOGGER.debug(self._log("get call_result"))
return call_result
......@@ -222,13 +221,15 @@ class Op(object):
data_id=data_id)
return preped_data, error_channeldata
def _run_process(self, preped_data, data_id, log_func):
def _run_process(self, client_predict_handler, preped_data, data_id,
log_func):
midped_data, error_channeldata = None, None
if self.with_serving:
ecode = ChannelDataEcode.OK.value
if self._timeout <= 0:
try:
midped_data = self.process(preped_data)
midped_data = self.process(client_predict_handler,
preped_data)
except Exception as e:
ecode = ChannelDataEcode.UNKNOW.value
error_info = log_func(e)
......@@ -237,7 +238,11 @@ class Op(object):
for i in range(self._retry):
try:
midped_data = func_timeout.func_timeout(
self._timeout, self.process, args=(preped_data, ))
self._timeout,
self.process,
args=(
client_predict_handler,
preped_data, ))
except func_timeout.FunctionTimedOut as e:
if i + 1 >= self._retry:
ecode = ChannelDataEcode.TIMEOUT.value
......@@ -316,8 +321,11 @@ class Op(object):
tid = threading.current_thread().ident
# create client based on client_type
self.init_client(client_type, self._client_config,
self._server_endpoints, self._fetch_names)
client = self.init_client(client_type, self._client_config,
self._server_endpoints, self._fetch_names)
client_predict_handler = None
if self.with_serving:
client_predict_handler = client.predict
# load user resources
self.load_user_resources()
......@@ -349,8 +357,8 @@ class Op(object):
# process
self._profiler_record("{}-midp#{}_0".format(op_info_prefix, tid))
midped_data, error_channeldata = self._run_process(preped_data,
data_id, log)
midped_data, error_channeldata = self._run_process(
client_predict_handler, preped_data, data_id, log)
self._profiler_record("{}-midp#{}_1".format(op_info_prefix, tid))
if error_channeldata is not None:
self._push_to_output_channels(error_channeldata,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册