提交 5ba1b5fd 编写于 作者: B barrierye

split _run func in operator into small pieces

上级 ef883a83
......@@ -200,6 +200,116 @@ class Op(object):
def load_user_resources(self):
pass
def _run_preprocess(self, parsed_data, data_id, log_func):
preped_data, error_channeldata = None, None
try:
preped_data = self.preprocess(parsed_data)
except NotImplementedError as e:
# preprocess function not implemented
error_info = log_func(e)
_LOGGER.error(error_info)
error_channeldata = ChannelData(
ecode=ChannelDataEcode.NOT_IMPLEMENTED.value,
error_info=error_info,
data_id=data_id)
except TypeError as e:
# Error type in channeldata.datatype
error_info = log_func(e)
_LOGGER.error(error_info)
error_channeldata = ChannelData(
ecode=ChannelDataEcode.TYPE_ERROR.value,
error_info=error_info,
data_id=data_id)
except Exception as e:
error_info = log_func(e)
_LOGGER.error(error_info)
error_channeldata = ChannelData(
ecode=ChannelDataEcode.UNKNOW.value,
error_info=error_info,
data_id=data_id)
return preped_data, error_channeldata
def _run_process(self, preped_data, data_id, log_func):
midped_data, error_channeldata = None, None
if self.with_serving:
ecode = ChannelDataEcode.OK.value
if self._timeout <= 0:
try:
midped_data = self.process(preped_data)
except Exception as e:
ecode = ChannelDataEcode.UNKNOW.value
error_info = log_func(e)
_LOGGER.error(error_info)
else:
for i in range(self._retry):
try:
midped_data = func_timeout.func_timeout(
self._timeout, self.process, args=(preped_data, ))
except func_timeout.FunctionTimedOut as e:
if i + 1 >= self._retry:
ecode = ChannelDataEcode.TIMEOUT.value
error_info = log_func(e)
_LOGGER.error(error_info)
else:
_LOGGER.warn(
log_func("timeout, retry({})".format(i + 1)))
except Exception as e:
ecode = ChannelDataEcode.UNKNOW.value
error_info = log_func(e)
_LOGGER.error(error_info)
break
else:
break
if ecode != ChannelDataEcode.OK.value:
error_channeldata = ChannelData(
ecode=ecode, error_info=error_info, data_id=data_id)
elif midped_data is None:
# op client return None
error_channeldata = ChannelData(
ecode=ChannelDataEcode.CLIENT_ERROR.value,
error_info=log_func(
"predict failed. pls check the server side."),
data_id=data_id)
else:
midped_data = preped_data
return midped_data, error_channeldata
def _run_postprocess(self, midped_data, data_id, log_func):
output_data, error_channeldata = None, None
try:
postped_data = self.postprocess(midped_data)
except Exception as e:
error_info = log_func(e)
_LOGGER.error(error_info)
error_channeldata = ChannelData(
ecode=ChannelDataEcode.UNKNOW.value,
error_info=error_info,
data_id=data_id)
return output_data, error_channeldata
if not isinstance(postped_data, dict):
error_info = log_func("output of postprocess funticon must be " \
"dict type, but get {}".format(type(postped_data)))
_LOGGER.error(error_info)
error_channeldata = ChannelData(
ecode=ChannelDataEcode.UNKNOW.value,
error_info=error_info,
data_id=data_id)
return output_data, error_channeldata
err, _ = ChannelData.check_npdata(postped_data)
if err == 0:
output_data = ChannelData(
ChannelDataType.CHANNEL_NPDATA.value,
npdata=postped_data,
data_id=data_id)
else:
output_data = ChannelData(
ChannelDataType.DICT.value,
dictdata=postped_data,
data_id=data_id)
return output_data, error_channeldata
def _run(self, concurrency_idx, input_channel, output_channels,
client_type):
def get_log_func(op_info_prefix):
......@@ -228,7 +338,6 @@ class Op(object):
data_id, error_channeldata, parsed_data = self._parse_channeldata(
channeldata_dict)
# error data in predecessor Op
if error_channeldata is not None:
self._push_to_output_channels(error_channeldata,
......@@ -236,139 +345,34 @@ class Op(object):
continue
# preprecess
try:
self._profiler_record("{}-prep#{}_0".format(op_info_prefix,
tid))
preped_data = self.preprocess(parsed_data)
self._profiler_record("{}-prep#{}_1".format(op_info_prefix,
tid))
except NotImplementedError as e:
# preprocess function not implemented
error_info = log(e)
_LOGGER.error(error_info)
self._push_to_output_channels(
ChannelData(
ecode=ChannelDataEcode.NOT_IMPLEMENTED.value,
error_info=error_info,
data_id=data_id),
output_channels)
continue
except TypeError as e:
# Error type in channeldata.datatype
error_info = log(e)
_LOGGER.error(error_info)
self._push_to_output_channels(
ChannelData(
ecode=ChannelDataEcode.TYPE_ERROR.value,
error_info=error_info,
data_id=data_id),
output_channels)
continue
except Exception as e:
error_info = log(e)
_LOGGER.error(error_info)
self._push_to_output_channels(
ChannelData(
ecode=ChannelDataEcode.UNKNOW.value,
error_info=error_info,
data_id=data_id),
output_channels)
self._profiler_record("{}-prep#{}_0".format(op_info_prefix, tid))
preped_data, error_channeldata = self._run_preprocess(parsed_data,
data_id, log)
self._profiler_record("{}-prep#{}_1".format(op_info_prefix, tid))
if error_channeldata is not None:
self._push_to_output_channels(error_channeldata,
output_channels)
continue
# process
midped_data = None
if self.with_serving:
ecode = ChannelDataEcode.OK.value
self._profiler_record("{}-midp#{}_0".format(op_info_prefix,
tid))
if self._timeout <= 0:
try:
midped_data = self.process(preped_data)
except Exception as e:
ecode = ChannelDataEcode.UNKNOW.value
error_info = log(e)
_LOGGER.error(error_info)
else:
for i in range(self._retry):
try:
midped_data = func_timeout.func_timeout(
self._timeout,
self.process,
args=(preped_data, ))
except func_timeout.FunctionTimedOut as e:
if i + 1 >= self._retry:
ecode = ChannelDataEcode.TIMEOUT.value
error_info = log(e)
_LOGGER.error(error_info)
else:
_LOGGER.warn(
log("timeout, retry({})".format(i + 1)))
except Exception as e:
ecode = ChannelDataEcode.UNKNOW.value
error_info = log(e)
_LOGGER.error(error_info)
break
else:
break
if ecode != ChannelDataEcode.OK.value:
self._push_to_output_channels(
ChannelData(
ecode=ecode, error_info=error_info,
data_id=data_id),
output_channels)
continue
self._profiler_record("{}-midp#{}_1".format(op_info_prefix,
tid))
# op client return None
if midped_data is None:
self._push_to_output_channels(
ChannelData(
ecode=ChannelDataEcode.CLIENT_ERROR.value,
error_info=log(
"predict failed. pls check the server side."),
data_id=data_id),
output_channels)
continue
else:
midped_data = preped_data
self._profiler_record("{}-midp#{}_0".format(op_info_prefix, tid))
midped_data, error_channeldata = self._run_process(preped_data,
data_id, log)
self._profiler_record("{}-midp#{}_1".format(op_info_prefix, tid))
if error_channeldata is not None:
self._push_to_output_channels(error_channeldata,
output_channels)
continue
# postprocess
output_data = None
self._profiler_record("{}-postp#{}_0".format(op_info_prefix, tid))
try:
postped_data = self.postprocess(midped_data)
except Exception as e:
ecode = ChannelDataEcode.UNKNOW.value
error_info = log(e)
_LOGGER.error(error_info)
self._push_to_output_channels(
ChannelData(
ecode=ecode, error_info=error_info, data_id=data_id),
output_channels)
continue
if not isinstance(postped_data, dict):
ecode = ChannelDataEcode.TYPE_ERROR.value
error_info = log("output of postprocess funticon must be " \
"dict type, but get {}".format(type(postped_data)))
_LOGGER.error(error_info)
self._push_to_output_channels(
ChannelData(
ecode=ecode, error_info=error_info, data_id=data_id),
output_channels)
continue
err, _ = ChannelData.check_npdata(postped_data)
if err == 0:
output_data = ChannelData(
ChannelDataType.CHANNEL_NPDATA.value,
npdata=postped_data,
data_id=data_id)
else:
output_data = ChannelData(
ChannelDataType.DICT.value,
dictdata=postped_data,
data_id=data_id)
output_data, error_channeldata = self._run_postprocess(midped_data,
data_id, log)
self._profiler_record("{}-postp#{}_1".format(op_info_prefix, tid))
if error_channeldata is not None:
self._push_to_output_channels(error_channeldata,
output_channels)
continue
# push data to channel (if run succ)
self._profiler_record("{}-push#{}_0".format(op_info_prefix, tid))
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册