提交 0727cf98 编写于 作者: B barriery

update _run_preprocess, _run_process and _run_postprocess

上级 da28e782
......@@ -156,6 +156,7 @@ class Op(object):
return input_dict
def process(self, feed_dict):
#TODO: check batch
err, err_info = ChannelData.check_npdata(feed_dict)
if err != 0:
raise NotImplementedError(
......@@ -235,7 +236,10 @@ class Op(object):
def init_op(self):
    """Do nothing and return None.

    This implementation is an intentional no-op.
    NOTE(review): presumably a hook that concrete ops override to set up
    their own state -- confirm against the rest of the class.
    """
    return None
def _run_preprocess(self, parsed_data, data_id, log_func):
def _run_preprocess(self, parsed_data_dict, log_func):
preped_data_dict = {}
err_channeldata_dict = {}
for data_id, parsed_data in parsed_data_dict.items():
preped_data, error_channeldata = None, None
try:
preped_data = self.preprocess(parsed_data)
......@@ -262,15 +266,22 @@ class Op(object):
ecode=ChannelDataEcode.UNKNOW.value,
error_info=error_info,
data_id=data_id)
return preped_data, error_channeldata
if error_channeldata is not None:
err_channeldata_dict[data_id] = error_channeldata
else:
preped_data_dict[data_id] = preped_data
return preped_data_dict, err_channeldata_dict
def _run_process(self, preped_data, data_id, log_func):
midped_data, error_channeldata = None, None
def _run_process(self, preped_data_dict, log_func):
midped_data_dict = {}
err_channeldata_dict = {}
if self.with_serving:
data_ids = preped_data_dict.keys()
batch = [preped_data_dict[data_id] for data_id in data_ids]
ecode = ChannelDataEcode.OK.value
if self._timeout <= 0:
try:
midped_data = self.process(preped_data)
midped_data = self.process(batch)
except Exception as e:
ecode = ChannelDataEcode.UNKNOW.value
error_info = log_func(e)
......@@ -278,8 +289,8 @@ class Op(object):
else:
for i in range(self._retry):
try:
midped_data = func_timeout.func_timeout(
self._timeout, self.process, args=(preped_data, ))
midped_batch = func_timeout.func_timeout(
self._timeout, self.process, args=(batch, ))
except func_timeout.FunctionTimedOut as e:
if i + 1 >= self._retry:
ecode = ChannelDataEcode.TIMEOUT.value
......@@ -296,42 +307,60 @@ class Op(object):
else:
break
if ecode != ChannelDataEcode.OK.value:
error_channeldata = ChannelData(
ecode=ecode, error_info=error_info, data_id=data_id)
elif midped_data is None:
for data_id in data_ids:
err_channeldata_dict[data_id] = ChannelData(
ecode=ecode,
error_info=error_info,
data_id=data_id)
elif midped_batch is None:
# op client returned None
error_channeldata = ChannelData(
for data_id in data_ids:
err_channeldata_dict[data_id] = ChannelData(
ecode=ChannelDataEcode.CLIENT_ERROR.value,
error_info=log_func(
"predict failed. pls check the server side."),
data_id=data_id)
else:
midped_data = preped_data
return midped_data, error_channeldata
def _run_postprocess(self, input_dict, midped_data, data_id, log_func):
output_data, error_channeldata = None, None
# transform np format to dict format
for idx, data_id in enumerate(data_ids):
midped_data_dict[data_id] = {
k: v[idx] for k, v in midped_batch.items()
}
else:
midped_data_dict = preped_data_dict
return midped_data_dict, err_channeldata_dict
def _run_postprocess(self, parsed_data_dict, midped_data_dict, log_func):
postped_data_dict = {}
err_channeldata_dict = {}
for data_id, midped_data in mided_data_dict.items():
postped_data, err_channeldata = None, None
try:
postped_data = self.postprocess(input_dict, midped_data)
postped_data = self.postprocess(
parsed_data_dict[data_id], midped_data)
except Exception as e:
error_info = log_func(e)
_LOGGER.error(error_info)
error_channeldata = ChannelData(
err_channeldata = ChannelData(
ecode=ChannelDataEcode.UNKNOW.value,
error_info=error_info,
data_id=data_id)
return output_data, error_channeldata
if err_channeldata is not None:
err_channeldata_dict[data_id] = err_channeldata
continue
else:
if not isinstance(postped_data, dict):
error_info = log_func("output of postprocess funticon must be " \
"dict type, but get {}".format(type(postped_data)))
_LOGGER.error(error_info)
error_channeldata = ChannelData(
err_channeldata = ChannelData(
ecode=ChannelDataEcode.UNKNOW.value,
error_info=error_info,
data_id=data_id)
return output_data, error_channeldata
err_channeldata_dict[data_id] = err_channeldata
continue
output_data = None
err, _ = ChannelData.check_npdata(postped_data)
if err == 0:
output_data = ChannelData(
......@@ -343,7 +372,8 @@ class Op(object):
ChannelDataType.DICT.value,
dictdata=postped_data,
data_id=data_id)
return output_data, error_channeldata
postped_data_dict[data_id] = output_data
return postped_data_dict, err_channeldata_dict
def _auto_batching_generator(self, input_channel, op_name, batch_size, timeout):
while True:
......@@ -369,8 +399,28 @@ class Op(object):
break
yield batch
def _run(self, concurrency_idx, input_channel, output_channels, client_type,
is_thread_op):
def _parse_channeldata_batch(self, batch, output_channels):
parsed_data_dict = {}
need_profile_dict = {}
profile_dict = {}
for channeldata_dict in channeldata_dict_batch:
(data_id, error_channeldata, parsed_data,
client_need_profile, profile_set) = \
self._parse_channeldata(channeldata_dict)
if error_channeldata is None:
parsed_data_dict[data_id] = parsed_data
need_profile_dict[data_id] = client_need_profile
profile_dict[data_id] = profile_set
else:
# error data in predecessor Op
# (error_channeldata with profile info)
self._push_to_output_channels(
error_channeldata, output_channels)
return parsed_data_dict, need_profile_dict, profile_dict
def _run(self, concurrency_idx, input_channel, output_channels,
client_type, is_thread_op):
def get_log_func(op_info_prefix):
def log_func(info_str):
return "{} {}".format(op_info_prefix, info_str)
......@@ -395,7 +445,6 @@ class Op(object):
timeout=self._auto_batching_timeout)
while True:
channeldata_dict_batch = None
try:
channeldata_dict_batch = next(batch_generator)
except ChannelStopError:
......@@ -405,95 +454,86 @@ class Op(object):
# parse channeldata batch
try:
# parse channeldata batch
except ChannelStopError:
_LOGGER.debug(log("stop."))
break
nor_dataid_list = []
err_dataid_list = []
nor_datas = {}
err_datas = {}
for channeldata_dict in channeldata_dict_batch:
(data_id, error_channeldata, parsed_data,
client_need_profile, profile_set) = \
self._parse_channeldata(channeldata_dict)
if error_channeldata is None:
nor_dataid_list.append(data_id)
nor_datas[data_id] = {
"pd": parsed_data,
"np": client_need_profile,
"ps": profile_set,
}
else:
# error data in predecessor Op
try:
# error_channeldata with profile info
self._push_to_output_channels(error_channeldata,
output_channels)
parsed_data_dict, need_profile_dict, profile_dict \
= self._parse_channeldata_batch(
channeldata_dict_batch, output_channels)
except ChannelStopError:
_LOGGER.debug(log("stop."))
self._finalize(is_thread_op)
break
if len(parsed_data_dict) == 0:
# data in the whole batch is all error data
continue
# preprocess
self._profiler_record("prep#{}_0".format(op_info_prefix))
preped_data, error_channeldata = self._run_preprocess(parsed_data,
data_id, log)
preped_data_dict, err_channeldata_dict \
= self._run_preprocess(parsed_data_dict, log)
self._profiler_record("prep#{}_1".format(op_info_prefix))
if error_channeldata is not None:
try:
for data_id, err_channeldata in err_channeldata_dict.items():
self._push_to_output_channels(
error_channeldata,
err_channeldata,
output_channels,
client_need_profile=client_need_profile,
profile_set=profile_set)
client_need_profile=need_profile_dict[data_id],
profile_set=profile_dict[data_id])
except ChannelStopError:
_LOGGER.debug(log("stop."))
self._finalize(is_thread_op)
break
if len(parsed_data_dict) == 0:
continue
# process
self._profiler_record("midp#{}_0".format(op_info_prefix))
midped_data, error_channeldata = self._run_process(preped_data,
data_id, log)
midped_data_dict, err_channeldata_dict \
= self._run_process(preped_data_dict, log)
self._profiler_record("midp#{}_1".format(op_info_prefix))
if error_channeldata is not None:
try:
for data_id, err_channeldata in err_channeldata_dict.items():
self._push_to_output_channels(
error_channeldata,
err_channeldata,
output_channels,
client_need_profile=client_need_profile,
profile_set=profile_set)
client_need_profile=need_profile_dict[data_id],
profile_set=profile_dict[data_id])
except ChannelStopError:
_LOGGER.debug(log("stop."))
self._finalize(is_thread_op)
break
if len(midped_data_dict) == 0:
continue
# postprocess
self._profiler_record("postp#{}_0".format(op_info_prefix))
output_data, error_channeldata = self._run_postprocess(
parsed_data, midped_data, data_id, log)
postped_data_dict, err_channeldata_dict \
= self._run_postprocess(
parsed_data_dict, midped_data_dict, log)
self._profiler_record("postp#{}_1".format(op_info_prefix))
if error_channeldata is not None:
try:
for data_id, err_channeldata in err_channeldata_dict.items():
self._push_to_output_channels(
error_channeldata,
output_channels,
client_need_profile=client_need_profile,
profile_set=profile_set)
client_need_profile=need_profile_dict[data_id],
profile_set=profile_dict[data_id])
except ChannelStopError:
_LOGGER.debug(log("stop."))
self._finalize(is_thread_op)
break
if len(postped_data_dict) == 0:
continue
# push data to channel (if run succ)
try:
for data_id, postped_data in postped_data_dict.items():
self._push_to_output_channels(
output_data,
postped_data,
output_channels,
client_need_profile=client_need_profile,
profile_set=profile_set)
client_need_profile=need_profile_dict[data_id],
profile_set=profile_dict[data_id])
except ChannelStopError:
_LOGGER.debug(log("stop."))
self._finalize(is_thread_op)
break
def _initialize(self, is_thread_op):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册