提交 c64f6ddd 编写于 作者: B barrierye

Merge branch 'pipeline-update' of https://github.com/barrierye/Serving into pipeline-update

......@@ -127,7 +127,7 @@ class Op(object):
channel.add_producer(self.name)
self._outputs.append(channel)
def preprocess(self, input_dicts, private_obj):
def preprocess(self, input_dicts):
# multiple previous Op
if len(input_dicts) != 1:
raise NotImplementedError(
......@@ -137,7 +137,7 @@ class Op(object):
(_, input_dict), = input_dicts.items()
return input_dict
def process(self, client_predict_handler, feed_dict, private_obj):
def process(self, client_predict_handler, feed_dict):
err, err_info = ChannelData.check_npdata(feed_dict)
if err != 0:
raise NotImplementedError(
......@@ -147,7 +147,7 @@ class Op(object):
_LOGGER.debug(self._log("get call_result"))
return call_result
def postprocess(self, input_dict, fetch_dict, private_obj):
def postprocess(self, input_dict, fetch_dict):
return fetch_dict
def stop(self):
......@@ -198,13 +198,10 @@ class Op(object):
def init_op(self):
pass
def load_private_obj(self):
return None
def _run_preprocess(self, parsed_data, data_id, private_obj, log_func):
def _run_preprocess(self, parsed_data, data_id, log_func):
preped_data, error_channeldata = None, None
try:
preped_data = self.preprocess(parsed_data, private_obj)
preped_data = self.preprocess(parsed_data)
except NotImplementedError as e:
# preprocess function not implemented
error_info = log_func(e)
......@@ -231,14 +228,14 @@ class Op(object):
return preped_data, error_channeldata
def _run_process(self, client_predict_handler, preped_data, data_id,
private_obj, log_func):
log_func):
midped_data, error_channeldata = None, None
if self.with_serving:
ecode = ChannelDataEcode.OK.value
if self._timeout <= 0:
try:
midped_data = self.process(client_predict_handler,
preped_data, private_obj)
preped_data)
except Exception as e:
ecode = ChannelDataEcode.UNKNOW.value
error_info = log_func(e)
......@@ -249,8 +246,7 @@ class Op(object):
midped_data = func_timeout.func_timeout(
self._timeout,
self.process,
args=(client_predict_handler, preped_data,
private_obj))
args=(client_predict_handler, preped_data))
except func_timeout.FunctionTimedOut as e:
if i + 1 >= self._retry:
ecode = ChannelDataEcode.TIMEOUT.value
......@@ -280,12 +276,10 @@ class Op(object):
midped_data = preped_data
return midped_data, error_channeldata
def _run_postprocess(self, input_dict, midped_data, data_id, private_obj,
log_func):
def _run_postprocess(self, input_dict, midped_data, data_id, log_func):
output_data, error_channeldata = None, None
try:
postped_data = self.postprocess(input_dict, midped_data,
private_obj)
postped_data = self.postprocess(input_dict, midped_data)
except Exception as e:
error_info = log_func(e)
_LOGGER.error(error_info)
......@@ -355,14 +349,6 @@ class Op(object):
_LOGGER.error(log(e))
os._exit(-1)
# load private objects (some no-thread-safe objects)
private_obj = None
try:
private_obj = self.load_private_obj()
except Exception as e:
_LOGGER.error(log(e))
os._exit(-1)
self._is_run = True
while self._is_run:
#self._profiler_record("get#{}_0".format(op_info_prefix))
......@@ -380,8 +366,8 @@ class Op(object):
# preprecess
self._profiler_record("prep#{}_0".format(op_info_prefix))
preped_data, error_channeldata = self._run_preprocess(
parsed_data, data_id, private_obj, log)
preped_data, error_channeldata = self._run_preprocess(parsed_data,
data_id, log)
self._profiler_record("prep#{}_1".format(op_info_prefix))
if error_channeldata is not None:
self._push_to_output_channels(error_channeldata,
......@@ -391,7 +377,7 @@ class Op(object):
# process
self._profiler_record("midp#{}_0".format(op_info_prefix))
midped_data, error_channeldata = self._run_process(
client_predict_handler, preped_data, data_id, private_obj, log)
client_predict_handler, preped_data, data_id, log)
self._profiler_record("midp#{}_1".format(op_info_prefix))
if error_channeldata is not None:
self._push_to_output_channels(error_channeldata,
......@@ -401,7 +387,7 @@ class Op(object):
# postprocess
self._profiler_record("postp#{}_0".format(op_info_prefix))
output_data, error_channeldata = self._run_postprocess(
parsed_data, midped_data, data_id, private_obj, log)
parsed_data, midped_data, data_id, log)
self._profiler_record("postp#{}_1".format(op_info_prefix))
if error_channeldata is not None:
self._push_to_output_channels(error_channeldata,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册