提交 5ba1b5fd 编写于 作者: B barrierye

split _run func in operator into small pieces

上级 ef883a83
...@@ -200,162 +200,102 @@ class Op(object): ...@@ -200,162 +200,102 @@ class Op(object):
def load_user_resources(self): def load_user_resources(self):
pass pass
def _run(self, concurrency_idx, input_channel, output_channels, def _run_preprocess(self, parsed_data, data_id, log_func):
client_type): preped_data, error_channeldata = None, None
def get_log_func(op_info_prefix):
def log_func(info_str):
return "{} {}".format(op_info_prefix, info_str)
return log_func
op_info_prefix = "[{}|{}]".format(self.name, concurrency_idx)
log = get_log_func(op_info_prefix)
tid = threading.current_thread().ident
# create client based on client_type
self.init_client(client_type, self._client_config,
self._server_endpoints, self._fetch_names)
# load user resources
self.load_user_resources()
self._is_run = True
while self._is_run:
self._profiler_record("{}-get#{}_0".format(op_info_prefix, tid))
channeldata_dict = input_channel.front(self.name)
self._profiler_record("{}-get#{}_1".format(op_info_prefix, tid))
_LOGGER.debug(log("input_data: {}".format(channeldata_dict)))
data_id, error_channeldata, parsed_data = self._parse_channeldata(
channeldata_dict)
# error data in predecessor Op
if error_channeldata is not None:
self._push_to_output_channels(error_channeldata,
output_channels)
continue
# preprecess
try: try:
self._profiler_record("{}-prep#{}_0".format(op_info_prefix,
tid))
preped_data = self.preprocess(parsed_data) preped_data = self.preprocess(parsed_data)
self._profiler_record("{}-prep#{}_1".format(op_info_prefix,
tid))
except NotImplementedError as e: except NotImplementedError as e:
# preprocess function not implemented # preprocess function not implemented
error_info = log(e) error_info = log_func(e)
_LOGGER.error(error_info) _LOGGER.error(error_info)
self._push_to_output_channels( error_channeldata = ChannelData(
ChannelData(
ecode=ChannelDataEcode.NOT_IMPLEMENTED.value, ecode=ChannelDataEcode.NOT_IMPLEMENTED.value,
error_info=error_info, error_info=error_info,
data_id=data_id), data_id=data_id)
output_channels)
continue
except TypeError as e: except TypeError as e:
# Error type in channeldata.datatype # Error type in channeldata.datatype
error_info = log(e) error_info = log_func(e)
_LOGGER.error(error_info) _LOGGER.error(error_info)
self._push_to_output_channels( error_channeldata = ChannelData(
ChannelData(
ecode=ChannelDataEcode.TYPE_ERROR.value, ecode=ChannelDataEcode.TYPE_ERROR.value,
error_info=error_info, error_info=error_info,
data_id=data_id), data_id=data_id)
output_channels)
continue
except Exception as e: except Exception as e:
error_info = log(e) error_info = log_func(e)
_LOGGER.error(error_info) _LOGGER.error(error_info)
self._push_to_output_channels( error_channeldata = ChannelData(
ChannelData(
ecode=ChannelDataEcode.UNKNOW.value, ecode=ChannelDataEcode.UNKNOW.value,
error_info=error_info, error_info=error_info,
data_id=data_id), data_id=data_id)
output_channels) return preped_data, error_channeldata
continue
# process def _run_process(self, preped_data, data_id, log_func):
midped_data = None midped_data, error_channeldata = None, None
if self.with_serving: if self.with_serving:
ecode = ChannelDataEcode.OK.value ecode = ChannelDataEcode.OK.value
self._profiler_record("{}-midp#{}_0".format(op_info_prefix,
tid))
if self._timeout <= 0: if self._timeout <= 0:
try: try:
midped_data = self.process(preped_data) midped_data = self.process(preped_data)
except Exception as e: except Exception as e:
ecode = ChannelDataEcode.UNKNOW.value ecode = ChannelDataEcode.UNKNOW.value
error_info = log(e) error_info = log_func(e)
_LOGGER.error(error_info) _LOGGER.error(error_info)
else: else:
for i in range(self._retry): for i in range(self._retry):
try: try:
midped_data = func_timeout.func_timeout( midped_data = func_timeout.func_timeout(
self._timeout, self._timeout, self.process, args=(preped_data, ))
self.process,
args=(preped_data, ))
except func_timeout.FunctionTimedOut as e: except func_timeout.FunctionTimedOut as e:
if i + 1 >= self._retry: if i + 1 >= self._retry:
ecode = ChannelDataEcode.TIMEOUT.value ecode = ChannelDataEcode.TIMEOUT.value
error_info = log(e) error_info = log_func(e)
_LOGGER.error(error_info) _LOGGER.error(error_info)
else: else:
_LOGGER.warn( _LOGGER.warn(
log("timeout, retry({})".format(i + 1))) log_func("timeout, retry({})".format(i + 1)))
except Exception as e: except Exception as e:
ecode = ChannelDataEcode.UNKNOW.value ecode = ChannelDataEcode.UNKNOW.value
error_info = log(e) error_info = log_func(e)
_LOGGER.error(error_info) _LOGGER.error(error_info)
break break
else: else:
break break
if ecode != ChannelDataEcode.OK.value: if ecode != ChannelDataEcode.OK.value:
self._push_to_output_channels( error_channeldata = ChannelData(
ChannelData( ecode=ecode, error_info=error_info, data_id=data_id)
ecode=ecode, error_info=error_info, elif midped_data is None:
data_id=data_id),
output_channels)
continue
self._profiler_record("{}-midp#{}_1".format(op_info_prefix,
tid))
# op client return None # op client return None
if midped_data is None: error_channeldata = ChannelData(
self._push_to_output_channels(
ChannelData(
ecode=ChannelDataEcode.CLIENT_ERROR.value, ecode=ChannelDataEcode.CLIENT_ERROR.value,
error_info=log( error_info=log_func(
"predict failed. pls check the server side."), "predict failed. pls check the server side."),
data_id=data_id), data_id=data_id)
output_channels)
continue
else: else:
midped_data = preped_data midped_data = preped_data
return midped_data, error_channeldata
# postprocess def _run_postprocess(self, midped_data, data_id, log_func):
output_data = None output_data, error_channeldata = None, None
self._profiler_record("{}-postp#{}_0".format(op_info_prefix, tid))
try: try:
postped_data = self.postprocess(midped_data) postped_data = self.postprocess(midped_data)
except Exception as e: except Exception as e:
ecode = ChannelDataEcode.UNKNOW.value error_info = log_func(e)
error_info = log(e)
_LOGGER.error(error_info) _LOGGER.error(error_info)
self._push_to_output_channels( error_channeldata = ChannelData(
ChannelData( ecode=ChannelDataEcode.UNKNOW.value,
ecode=ecode, error_info=error_info, data_id=data_id), error_info=error_info,
output_channels) data_id=data_id)
continue return output_data, error_channeldata
if not isinstance(postped_data, dict): if not isinstance(postped_data, dict):
ecode = ChannelDataEcode.TYPE_ERROR.value error_info = log_func("output of postprocess funticon must be " \
error_info = log("output of postprocess funticon must be " \
"dict type, but get {}".format(type(postped_data))) "dict type, but get {}".format(type(postped_data)))
_LOGGER.error(error_info) _LOGGER.error(error_info)
self._push_to_output_channels( error_channeldata = ChannelData(
ChannelData( ecode=ChannelDataEcode.UNKNOW.value,
ecode=ecode, error_info=error_info, data_id=data_id), error_info=error_info,
output_channels) data_id=data_id)
continue return output_data, error_channeldata
err, _ = ChannelData.check_npdata(postped_data) err, _ = ChannelData.check_npdata(postped_data)
if err == 0: if err == 0:
...@@ -368,7 +308,71 @@ class Op(object): ...@@ -368,7 +308,71 @@ class Op(object):
ChannelDataType.DICT.value, ChannelDataType.DICT.value,
dictdata=postped_data, dictdata=postped_data,
data_id=data_id) data_id=data_id)
return output_data, error_channeldata
def _run(self, concurrency_idx, input_channel, output_channels,
client_type):
def get_log_func(op_info_prefix):
def log_func(info_str):
return "{} {}".format(op_info_prefix, info_str)
return log_func
op_info_prefix = "[{}|{}]".format(self.name, concurrency_idx)
log = get_log_func(op_info_prefix)
tid = threading.current_thread().ident
# create client based on client_type
self.init_client(client_type, self._client_config,
self._server_endpoints, self._fetch_names)
# load user resources
self.load_user_resources()
self._is_run = True
while self._is_run:
self._profiler_record("{}-get#{}_0".format(op_info_prefix, tid))
channeldata_dict = input_channel.front(self.name)
self._profiler_record("{}-get#{}_1".format(op_info_prefix, tid))
_LOGGER.debug(log("input_data: {}".format(channeldata_dict)))
data_id, error_channeldata, parsed_data = self._parse_channeldata(
channeldata_dict)
# error data in predecessor Op
if error_channeldata is not None:
self._push_to_output_channels(error_channeldata,
output_channels)
continue
# preprecess
self._profiler_record("{}-prep#{}_0".format(op_info_prefix, tid))
preped_data, error_channeldata = self._run_preprocess(parsed_data,
data_id, log)
self._profiler_record("{}-prep#{}_1".format(op_info_prefix, tid))
if error_channeldata is not None:
self._push_to_output_channels(error_channeldata,
output_channels)
continue
# process
self._profiler_record("{}-midp#{}_0".format(op_info_prefix, tid))
midped_data, error_channeldata = self._run_process(preped_data,
data_id, log)
self._profiler_record("{}-midp#{}_1".format(op_info_prefix, tid))
if error_channeldata is not None:
self._push_to_output_channels(error_channeldata,
output_channels)
continue
# postprocess
self._profiler_record("{}-postp#{}_0".format(op_info_prefix, tid))
output_data, error_channeldata = self._run_postprocess(midped_data,
data_id, log)
self._profiler_record("{}-postp#{}_1".format(op_info_prefix, tid)) self._profiler_record("{}-postp#{}_1".format(op_info_prefix, tid))
if error_channeldata is not None:
self._push_to_output_channels(error_channeldata,
output_channels)
continue
# push data to channel (if run succ) # push data to channel (if run succ)
self._profiler_record("{}-push#{}_0".format(op_info_prefix, tid)) self._profiler_record("{}-push#{}_0".format(op_info_prefix, tid))
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册