提交 7cb306e5 编写于 作者: B barrierye

remove client_predict_handler

上级 710f7042
...@@ -139,12 +139,12 @@ class Op(object): ...@@ -139,12 +139,12 @@ class Op(object):
(_, input_dict), = input_dicts.items() (_, input_dict), = input_dicts.items()
return input_dict return input_dict
def process(self, client_predict_handler, feed_dict): def process(self, feed_dict):
err, err_info = ChannelData.check_npdata(feed_dict) err, err_info = ChannelData.check_npdata(feed_dict)
if err != 0: if err != 0:
raise NotImplementedError( raise NotImplementedError(
"{} Please override preprocess func.".format(err_info)) "{} Please override preprocess func.".format(err_info))
call_result = client_predict_handler( call_result = self.client.predict(
feed=feed_dict, fetch=self._fetch_names) feed=feed_dict, fetch=self._fetch_names)
_LOGGER.debug(self._log("get call_result")) _LOGGER.debug(self._log("get call_result"))
return call_result return call_result
...@@ -229,15 +229,13 @@ class Op(object): ...@@ -229,15 +229,13 @@ class Op(object):
data_id=data_id) data_id=data_id)
return preped_data, error_channeldata return preped_data, error_channeldata
def _run_process(self, client_predict_handler, preped_data, data_id, def _run_process(self, preped_data, data_id, log_func):
log_func):
midped_data, error_channeldata = None, None midped_data, error_channeldata = None, None
if self.with_serving: if self.with_serving:
ecode = ChannelDataEcode.OK.value ecode = ChannelDataEcode.OK.value
if self._timeout <= 0: if self._timeout <= 0:
try: try:
midped_data = self.process(client_predict_handler, midped_data = self.process(preped_data)
preped_data)
except Exception as e: except Exception as e:
ecode = ChannelDataEcode.UNKNOW.value ecode = ChannelDataEcode.UNKNOW.value
error_info = log_func(e) error_info = log_func(e)
...@@ -246,9 +244,7 @@ class Op(object): ...@@ -246,9 +244,7 @@ class Op(object):
for i in range(self._retry): for i in range(self._retry):
try: try:
midped_data = func_timeout.func_timeout( midped_data = func_timeout.func_timeout(
self._timeout, self._timeout, self.process, args=(preped_data, ))
self.process,
args=(client_predict_handler, preped_data))
except func_timeout.FunctionTimedOut as e: except func_timeout.FunctionTimedOut as e:
if i + 1 >= self._retry: if i + 1 >= self._retry:
ecode = ChannelDataEcode.TIMEOUT.value ecode = ChannelDataEcode.TIMEOUT.value
...@@ -326,36 +322,36 @@ class Op(object): ...@@ -326,36 +322,36 @@ class Op(object):
log = get_log_func(op_info_prefix) log = get_log_func(op_info_prefix)
tid = threading.current_thread().ident tid = threading.current_thread().ident
client = None
client_predict_handler = None
# create client based on client_type
try:
client = self.init_client(client_type, self._client_config,
self._server_endpoints, self._fetch_names)
if client is not None:
client_predict_handler = client.predict
except Exception as e:
_LOGGER.error(log(e))
os._exit(-1)
# init op # init op
self.concurrency_idx = concurrency_idx self.concurrency_idx = concurrency_idx
try: try:
if use_multithread: if use_multithread:
with self._for_init_op_lock: with self._for_init_op_lock:
if not self._succ_init_op: if not self._succ_init_op:
# init profiler
self._profiler = TimeProfiler()
self._profiler.enable(self._use_profile)
# init client
self.client = self.init_client(
client_type, self._client_config,
self._server_endpoints, self._fetch_names)
# user defined
self.init_op() self.init_op()
self._succ_init_op = True self._succ_init_op = True
else: else:
# init profiler
self._profiler = TimeProfiler()
self._profiler.enable(self._use_profile)
# init client
self.client = self.init_client(client_type, self._client_config,
self._server_endpoints,
self._fetch_names)
# user defined
self.init_op() self.init_op()
except Exception as e: except Exception as e:
_LOGGER.error(log(e)) _LOGGER.error(log(e))
os._exit(-1) os._exit(-1)
# init profiler
self._profiler = TimeProfiler()
self._profiler.enable(self._use_profile)
self._is_run = True self._is_run = True
while self._is_run: while self._is_run:
#self._profiler_record("get#{}_0".format(op_info_prefix)) #self._profiler_record("get#{}_0".format(op_info_prefix))
...@@ -383,8 +379,8 @@ class Op(object): ...@@ -383,8 +379,8 @@ class Op(object):
# process # process
self._profiler_record("midp#{}_0".format(op_info_prefix)) self._profiler_record("midp#{}_0".format(op_info_prefix))
midped_data, error_channeldata = self._run_process( midped_data, error_channeldata = self._run_process(preped_data,
client_predict_handler, preped_data, data_id, log) data_id, log)
self._profiler_record("midp#{}_1".format(op_info_prefix)) self._profiler_record("midp#{}_1".format(op_info_prefix))
if error_channeldata is not None: if error_channeldata is not None:
self._push_to_output_channels(error_channeldata, self._push_to_output_channels(error_channeldata,
...@@ -531,12 +527,8 @@ class VirtualOp(Op): ...@@ -531,12 +527,8 @@ class VirtualOp(Op):
self._is_run = True self._is_run = True
while self._is_run: while self._is_run:
#self._profiler_record("get#{}_0".format(op_info_prefix))
channeldata_dict = input_channel.front(self.name) channeldata_dict = input_channel.front(self.name)
#self._profiler_record("get#{}_1".format(op_info_prefix))
#self._profiler_record("push#{}_0".format(op_info_prefix))
for name, data in channeldata_dict.items(): for name, data in channeldata_dict.items():
self._push_to_output_channels( self._push_to_output_channels(
data, channels=output_channels, name=name) data, channels=output_channels, name=name)
#self._profiler_record("push#{}_1".format(op_info_prefix))
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册