提交 0727cf98 编写于 作者: B barriery

update _run_preprocess, _run_process and _run_postprocess

上级 da28e782
...@@ -156,6 +156,7 @@ class Op(object): ...@@ -156,6 +156,7 @@ class Op(object):
return input_dict return input_dict
def process(self, feed_dict): def process(self, feed_dict):
#TODO: check batch
err, err_info = ChannelData.check_npdata(feed_dict) err, err_info = ChannelData.check_npdata(feed_dict)
if err != 0: if err != 0:
raise NotImplementedError( raise NotImplementedError(
...@@ -235,42 +236,52 @@ class Op(object): ...@@ -235,42 +236,52 @@ class Op(object):
def init_op(self): def init_op(self):
pass pass
def _run_preprocess(self, parsed_data_dict, log_func):
    """Run the user-defined preprocess() on each sample in the batch.

    Args:
        parsed_data_dict: dict mapping data_id -> parsed input data.
        log_func: callable that prefixes a message with op info for logging.

    Returns:
        (preped_data_dict, err_channeldata_dict): samples that preprocessed
        successfully, and per-sample error ChannelData for those that failed.
    """
    preped_data_dict = {}
    err_channeldata_dict = {}
    for data_id, parsed_data in parsed_data_dict.items():
        preped_data, error_channeldata = None, None
        try:
            preped_data = self.preprocess(parsed_data)
        except NotImplementedError as e:
            # preprocess function not implemented
            error_info = log_func(e)
            _LOGGER.error(error_info)
            error_channeldata = ChannelData(
                ecode=ChannelDataEcode.NOT_IMPLEMENTED.value,
                error_info=error_info,
                data_id=data_id)
        except TypeError as e:
            # Error type in channeldata.datatype
            error_info = log_func(e)
            _LOGGER.error(error_info)
            error_channeldata = ChannelData(
                ecode=ChannelDataEcode.TYPE_ERROR.value,
                error_info=error_info,
                data_id=data_id)
        except Exception as e:
            # any other failure inside user preprocess: mark sample as UNKNOW
            error_info = log_func(e)
            _LOGGER.error(error_info)
            error_channeldata = ChannelData(
                ecode=ChannelDataEcode.UNKNOW.value,
                error_info=error_info,
                data_id=data_id)
        # route each sample to exactly one of the two result dicts
        if error_channeldata is not None:
            err_channeldata_dict[data_id] = error_channeldata
        else:
            preped_data_dict[data_id] = preped_data
    return preped_data_dict, err_channeldata_dict
def _run_process(self, preped_data_dict, log_func):
midped_data_dict = {}
err_channeldata_dict = {}
if self.with_serving: if self.with_serving:
data_ids = preped_data_dict.keys()
batch = [preped_data_dict[data_id] for data_id in data_ids]
ecode = ChannelDataEcode.OK.value ecode = ChannelDataEcode.OK.value
if self._timeout <= 0: if self._timeout <= 0:
try: try:
midped_data = self.process(preped_data) midped_data = self.process(batch)
except Exception as e: except Exception as e:
ecode = ChannelDataEcode.UNKNOW.value ecode = ChannelDataEcode.UNKNOW.value
error_info = log_func(e) error_info = log_func(e)
...@@ -278,8 +289,8 @@ class Op(object): ...@@ -278,8 +289,8 @@ class Op(object):
else: else:
for i in range(self._retry): for i in range(self._retry):
try: try:
midped_data = func_timeout.func_timeout( midped_batch = func_timeout.func_timeout(
self._timeout, self.process, args=(preped_data, )) self._timeout, self.process, args=(batch, ))
except func_timeout.FunctionTimedOut as e: except func_timeout.FunctionTimedOut as e:
if i + 1 >= self._retry: if i + 1 >= self._retry:
ecode = ChannelDataEcode.TIMEOUT.value ecode = ChannelDataEcode.TIMEOUT.value
...@@ -296,54 +307,73 @@ class Op(object): ...@@ -296,54 +307,73 @@ class Op(object):
else: else:
break break
if ecode != ChannelDataEcode.OK.value: if ecode != ChannelDataEcode.OK.value:
error_channeldata = ChannelData( for data_id in data_ids:
ecode=ecode, error_info=error_info, data_id=data_id) err_channeldata_dict[data_id] = ChannelData(
elif midped_data is None: ecode=ecode,
error_info=error_info,
data_id=data_id)
elif midped_batch is None:
# op client return None # op client return None
error_channeldata = ChannelData( for data_id in data_ids:
ecode=ChannelDataEcode.CLIENT_ERROR.value, err_channeldata_dict[data_id] = ChannelData(
error_info=log_func( ecode=ChannelDataEcode.CLIENT_ERROR.value,
"predict failed. pls check the server side."), error_info=log_func(
data_id=data_id) "predict failed. pls check the server side."),
else: data_id=data_id)
midped_data = preped_data else:
return midped_data, error_channeldata # transform np format to dict format
for idx, data_id in enumerate(data_ids):
def _run_postprocess(self, input_dict, midped_data, data_id, log_func): midped_data_dict[data_id] = {
output_data, error_channeldata = None, None k: v[idx] for k, v in midped_batch.items()
try: }
postped_data = self.postprocess(input_dict, midped_data)
except Exception as e:
error_info = log_func(e)
_LOGGER.error(error_info)
error_channeldata = ChannelData(
ecode=ChannelDataEcode.UNKNOW.value,
error_info=error_info,
data_id=data_id)
return output_data, error_channeldata
if not isinstance(postped_data, dict):
error_info = log_func("output of postprocess funticon must be " \
"dict type, but get {}".format(type(postped_data)))
_LOGGER.error(error_info)
error_channeldata = ChannelData(
ecode=ChannelDataEcode.UNKNOW.value,
error_info=error_info,
data_id=data_id)
return output_data, error_channeldata
err, _ = ChannelData.check_npdata(postped_data)
if err == 0:
output_data = ChannelData(
ChannelDataType.CHANNEL_NPDATA.value,
npdata=postped_data,
data_id=data_id)
else: else:
output_data = ChannelData( midped_data_dict = preped_data_dict
ChannelDataType.DICT.value, return midped_data_dict, err_channeldata_dict
dictdata=postped_data,
data_id=data_id) def _run_postprocess(self, parsed_data_dict, midped_data_dict, log_func):
return output_data, error_channeldata postped_data_dict = {}
err_channeldata_dict = {}
for data_id, midped_data in mided_data_dict.items():
postped_data, err_channeldata = None, None
try:
postped_data = self.postprocess(
parsed_data_dict[data_id], midped_data)
except Exception as e:
error_info = log_func(e)
_LOGGER.error(error_info)
err_channeldata = ChannelData(
ecode=ChannelDataEcode.UNKNOW.value,
error_info=error_info,
data_id=data_id)
if err_channeldata is not None:
err_channeldata_dict[data_id] = err_channeldata
continue
else:
if not isinstance(postped_data, dict):
error_info = log_func("output of postprocess funticon must be " \
"dict type, but get {}".format(type(postped_data)))
_LOGGER.error(error_info)
err_channeldata = ChannelData(
ecode=ChannelDataEcode.UNKNOW.value,
error_info=error_info,
data_id=data_id)
err_channeldata_dict[data_id] = err_channeldata
continue
output_data = None
err, _ = ChannelData.check_npdata(postped_data)
if err == 0:
output_data = ChannelData(
ChannelDataType.CHANNEL_NPDATA.value,
npdata=postped_data,
data_id=data_id)
else:
output_data = ChannelData(
ChannelDataType.DICT.value,
dictdata=postped_data,
data_id=data_id)
postped_data_dict[data_id] = output_data
return postped_data_dict, err_channeldata_dict
def _auto_batching_generator(self, input_channel, op_name, batch_size, timeout): def _auto_batching_generator(self, input_channel, op_name, batch_size, timeout):
while True: while True:
...@@ -369,8 +399,28 @@ class Op(object): ...@@ -369,8 +399,28 @@ class Op(object):
break break
yield batch yield batch
def _parse_channeldata_batch(self, batch, output_channels):
    """Parse a batch of channeldata dicts into per-id lookup tables.

    Error samples produced by a predecessor Op are forwarded directly to
    the output channels (they already carry profile info) and excluded
    from the returned dicts.

    Args:
        batch: iterable of channeldata dicts read from the input channel.
        output_channels: channels that error ChannelData is pushed to.

    Returns:
        (parsed_data_dict, need_profile_dict, profile_dict) keyed by data_id.

    Raises:
        ChannelStopError: propagated from _push_to_output_channels when the
            channels have been stopped (caller handles shutdown).
    """
    parsed_data_dict = {}
    need_profile_dict = {}
    profile_dict = {}
    # BUG FIX: original looped over undefined `channeldata_dict_batch`;
    # the parameter is named `batch`.
    for channeldata_dict in batch:
        (data_id, error_channeldata, parsed_data,
                client_need_profile, profile_set) = \
            self._parse_channeldata(channeldata_dict)
        if error_channeldata is None:
            parsed_data_dict[data_id] = parsed_data
            need_profile_dict[data_id] = client_need_profile
            profile_dict[data_id] = profile_set
        else:
            # error data in predecessor Op
            # (error_channeldata with profile info)
            self._push_to_output_channels(
                error_channeldata, output_channels)
    return parsed_data_dict, need_profile_dict, profile_dict
def _run(self, concurrency_idx, input_channel, output_channels,
client_type, is_thread_op):
def get_log_func(op_info_prefix): def get_log_func(op_info_prefix):
def log_func(info_str): def log_func(info_str):
return "{} {}".format(op_info_prefix, info_str) return "{} {}".format(op_info_prefix, info_str)
...@@ -395,7 +445,6 @@ class Op(object): ...@@ -395,7 +445,6 @@ class Op(object):
timeout=self._auto_batching_timeout) timeout=self._auto_batching_timeout)
while True: while True:
channeldata_dict_batch = None
try: try:
channeldata_dict_batch = next(batch_generator) channeldata_dict_batch = next(batch_generator)
except ChannelStopError: except ChannelStopError:
...@@ -405,95 +454,86 @@ class Op(object): ...@@ -405,95 +454,86 @@ class Op(object):
# parse channeldata batch # parse channeldata batch
try: try:
# parse channeldata batch parsed_data_dict, need_profile_dict, profile_dict \
= self._parse_channeldata_batch(
channeldata_dict_batch, output_channels)
except ChannelStopError: except ChannelStopError:
_LOGGER.debug(log("stop.")) _LOGGER.debug(log("stop."))
self._finalize(is_thread_op)
break break
nor_dataid_list = [] if len(parsed_data_dict) == 0:
err_dataid_list = [] # data in the whole batch is all error data
nor_datas = {} continue
err_datas = {}
for channeldata_dict in channeldata_dict_batch:
(data_id, error_channeldata, parsed_data,
client_need_profile, profile_set) = \
self._parse_channeldata(channeldata_dict)
if error_channeldata is None:
nor_dataid_list.append(data_id)
nor_datas[data_id] = {
"pd": parsed_data,
"np": client_need_profile,
"ps": profile_set,
}
else:
# error data in predecessor Op
try:
# error_channeldata with profile info
self._push_to_output_channels(error_channeldata,
output_channels)
except ChannelStopError:
_LOGGER.debug(log("stop."))
break
# preprecess # preprecess
self._profiler_record("prep#{}_0".format(op_info_prefix)) self._profiler_record("prep#{}_0".format(op_info_prefix))
preped_data, error_channeldata = self._run_preprocess(parsed_data, preped_data_dict, err_channeldata_dict \
data_id, log) = self._run_preprocess(parsed_data_dict, log)
self._profiler_record("prep#{}_1".format(op_info_prefix)) self._profiler_record("prep#{}_1".format(op_info_prefix))
if error_channeldata is not None: try:
try: for data_id, err_channeldata in err_channeldata_dict.items():
self._push_to_output_channels( self._push_to_output_channels(
error_channeldata, err_channeldata,
output_channels, output_channels,
client_need_profile=client_need_profile, client_need_profile=need_profile_dict[data_id],
profile_set=profile_set) profile_set=profile_dict[data_id])
except ChannelStopError: except ChannelStopError:
_LOGGER.debug(log("stop.")) _LOGGER.debug(log("stop."))
break self._finalize(is_thread_op)
break
if len(parsed_data_dict) == 0:
continue continue
# process # process
self._profiler_record("midp#{}_0".format(op_info_prefix)) self._profiler_record("midp#{}_0".format(op_info_prefix))
midped_data, error_channeldata = self._run_process(preped_data, midped_data_dict, err_channeldata_dict \
data_id, log) = self._run_process(preped_data_dict, log)
self._profiler_record("midp#{}_1".format(op_info_prefix)) self._profiler_record("midp#{}_1".format(op_info_prefix))
if error_channeldata is not None: try:
try: for data_id, err_channeldata in err_channeldata_dict.items():
self._push_to_output_channels( self._push_to_output_channels(
error_channeldata, err_channeldata,
output_channels, output_channels,
client_need_profile=client_need_profile, client_need_profile=need_profile_dict[data_id],
profile_set=profile_set) profile_set=profile_dict[data_id])
except ChannelStopError: except ChannelStopError:
_LOGGER.debug(log("stop.")) _LOGGER.debug(log("stop."))
break self._finalize(is_thread_op)
break
if len(midped_data_dict) == 0:
continue continue
# postprocess # postprocess
self._profiler_record("postp#{}_0".format(op_info_prefix)) self._profiler_record("postp#{}_0".format(op_info_prefix))
output_data, error_channeldata = self._run_postprocess( postped_data_dict, err_channeldata_dict \
parsed_data, midped_data, data_id, log) = self._run_postprocess(
parsed_data_dict, midped_data_dict, log)
self._profiler_record("postp#{}_1".format(op_info_prefix)) self._profiler_record("postp#{}_1".format(op_info_prefix))
if error_channeldata is not None: try:
try: for data_id, err_channeldata in err_channeldata_dict.items():
self._push_to_output_channels( self._push_to_output_channels(
error_channeldata, error_channeldata,
output_channels, output_channels,
client_need_profile=client_need_profile, client_need_profile=need_profile_dict[data_id],
profile_set=profile_set) profile_set=profile_dict[data_id])
except ChannelStopError: except ChannelStopError:
_LOGGER.debug(log("stop.")) _LOGGER.debug(log("stop."))
break self._finalize(is_thread_op)
break
if len(postped_data_dict) == 0:
continue continue
# push data to channel (if run succ) # push data to channel (if run succ)
try: try:
self._push_to_output_channels( for data_id, postped_data in postped_data_dict.items():
output_data, self._push_to_output_channels(
output_channels, postped_data,
client_need_profile=client_need_profile, output_channels,
profile_set=profile_set) client_need_profile=need_profile_dict[data_id],
profile_set=profile_dict[data_id])
except ChannelStopError: except ChannelStopError:
_LOGGER.debug(log("stop.")) _LOGGER.debug(log("stop."))
self._finalize(is_thread_op)
break break
def _initialize(self, is_thread_op): def _initialize(self, is_thread_op):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册