提交 afb05e6b 编写于 作者: B barrierye

add load_private_obj func in OP

上级 b96c6019
......@@ -126,7 +126,7 @@ class Op(object):
channel.add_producer(self.name)
self._outputs.append(channel)
def preprocess(self, input_dicts):
def preprocess(self, input_dicts, private_obj):
# multiple previous Op
if len(input_dicts) != 1:
raise NotImplementedError(
......@@ -136,7 +136,7 @@ class Op(object):
(_, input_dict), = input_dicts.items()
return input_dict
def process(self, client_predict_handler, feed_dict):
def process(self, client_predict_handler, feed_dict, private_obj):
err, err_info = ChannelData.check_npdata(feed_dict)
if err != 0:
raise NotImplementedError(
......@@ -146,7 +146,7 @@ class Op(object):
_LOGGER.debug(self._log("get call_result"))
return call_result
def postprocess(self, input_dict, fetch_dict):
def postprocess(self, input_dict, fetch_dict, private_obj):
return fetch_dict
def stop(self):
......@@ -197,10 +197,13 @@ class Op(object):
def init_op(self):
pass
def _run_preprocess(self, parsed_data, data_id, log_func):
def load_private_obj(self):
return None
def _run_preprocess(self, parsed_data, data_id, private_obj, log_func):
preped_data, error_channeldata = None, None
try:
preped_data = self.preprocess(parsed_data)
preped_data = self.preprocess(parsed_data, private_obj)
except NotImplementedError as e:
# preprocess function not implemented
error_info = log_func(e)
......@@ -227,14 +230,14 @@ class Op(object):
return preped_data, error_channeldata
def _run_process(self, client_predict_handler, preped_data, data_id,
log_func):
private_obj, log_func):
midped_data, error_channeldata = None, None
if self.with_serving:
ecode = ChannelDataEcode.OK.value
if self._timeout <= 0:
try:
midped_data = self.process(client_predict_handler,
preped_data)
preped_data, private_obj)
except Exception as e:
ecode = ChannelDataEcode.UNKNOW.value
error_info = log_func(e)
......@@ -245,9 +248,8 @@ class Op(object):
midped_data = func_timeout.func_timeout(
self._timeout,
self.process,
args=(
client_predict_handler,
preped_data, ))
args=(client_predict_handler, preped_data,
private_obj))
except func_timeout.FunctionTimedOut as e:
if i + 1 >= self._retry:
ecode = ChannelDataEcode.TIMEOUT.value
......@@ -277,10 +279,12 @@ class Op(object):
midped_data = preped_data
return midped_data, error_channeldata
def _run_postprocess(self, input_dict, midped_data, data_id, log_func):
def _run_postprocess(self, input_dict, midped_data, data_id, private_obj,
log_func):
output_data, error_channeldata = None, None
try:
postped_data = self.postprocess(input_dict, midped_data)
postped_data = self.postprocess(input_dict, midped_data,
private_obj)
except Exception as e:
error_info = log_func(e)
_LOGGER.error(error_info)
......@@ -337,7 +341,7 @@ class Op(object):
_LOGGER.error(log(e))
os._exit(-1)
# load user resources
# init op
try:
if use_multithread:
with self._for_init_op_lock:
......@@ -350,6 +354,14 @@ class Op(object):
_LOGGER.error(log(e))
os._exit(-1)
# load private objects (some no-thread-safe objects)
private_obj = None
try:
private_obj = self.load_private_obj()
except Exception as e:
_LOGGER.error(log(e))
os._exit(-1)
self._is_run = True
while self._is_run:
#self._profiler_record("{}-get#{}_0".format(op_info_prefix, tid))
......@@ -367,8 +379,8 @@ class Op(object):
# preprecess
self._profiler_record("{}-prep#{}_0".format(op_info_prefix, tid))
preped_data, error_channeldata = self._run_preprocess(parsed_data,
data_id, log)
preped_data, error_channeldata = self._run_preprocess(
parsed_data, data_id, private_obj, log)
self._profiler_record("{}-prep#{}_1".format(op_info_prefix, tid))
if error_channeldata is not None:
self._push_to_output_channels(error_channeldata,
......@@ -378,7 +390,7 @@ class Op(object):
# process
self._profiler_record("{}-midp#{}_0".format(op_info_prefix, tid))
midped_data, error_channeldata = self._run_process(
client_predict_handler, preped_data, data_id, log)
client_predict_handler, preped_data, data_id, private_obj, log)
self._profiler_record("{}-midp#{}_1".format(op_info_prefix, tid))
if error_channeldata is not None:
self._push_to_output_channels(error_channeldata,
......@@ -388,7 +400,7 @@ class Op(object):
# postprocess
self._profiler_record("{}-postp#{}_0".format(op_info_prefix, tid))
output_data, error_channeldata = self._run_postprocess(
parsed_data, midped_data, data_id, log)
parsed_data, midped_data, data_id, private_obj, log)
self._profiler_record("{}-postp#{}_1".format(op_info_prefix, tid))
if error_channeldata is not None:
self._push_to_output_channels(error_channeldata,
......@@ -411,7 +423,7 @@ class RequestOp(Op):
# PipelineService.name = "#G"
super(RequestOp, self).__init__(
name="#G", input_ops=[], concurrency=concurrency)
# load user resources
# init op
try:
self.init_op()
except Exception as e:
......@@ -436,7 +448,7 @@ class ResponseOp(Op):
def __init__(self, input_ops, concurrency=1):
super(ResponseOp, self).__init__(
name="#R", input_ops=input_ops, concurrency=concurrency)
# load user resources
# init op
try:
self.init_op()
except Exception as e:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册