Unverified    Commit 45444820 authored by Jiawei Wang and committed by GitHub

Merge pull request #1183 from TeslaZhao/develop

Support mini-batch in Pipeline mode
......@@ -79,6 +79,9 @@ class RecOp(Op):
feed_list = []
img_list = []
max_wh_ratio = 0
## One batch, the type of feed_data is dict.
"""
for i, dtbox in enumerate(dt_boxes):
boximg = self.get_rotate_crop_image(im, dt_boxes[i])
img_list.append(boximg)
......@@ -92,14 +95,73 @@ class RecOp(Op):
norm_img = self.ocr_reader.resize_norm_img(img, max_wh_ratio)
imgs[id] = norm_img
feed = {"image": imgs.copy()}
return feed, False, None, ""
def postprocess(self, input_dicts, fetch_dict, log_id):
rec_res = self.ocr_reader.postprocess(fetch_dict, with_score=True)
res_lst = []
for res in rec_res:
res_lst.append(res[0])
res = {"res": str(res_lst)}
"""
## Many mini-batches, the type of feed_data is list.
max_batch_size = 6 # len(dt_boxes)
# If max_batch_size is 0, skip the predict stage
if max_batch_size == 0:
return {}, True, None, ""
boxes_size = len(dt_boxes)
batch_size = boxes_size // max_batch_size
rem = boxes_size % max_batch_size
#_LOGGER.info("max_batch_len:{}, batch_size:{}, rem:{}, boxes_size:{}".format(max_batch_size, batch_size, rem, boxes_size))
for bt_idx in range(0, batch_size + 1):
imgs = None
boxes_num_in_one_batch = 0
if bt_idx == batch_size:
if rem == 0:
continue
else:
boxes_num_in_one_batch = rem
elif bt_idx < batch_size:
boxes_num_in_one_batch = max_batch_size
else:
_LOGGER.error("batch_size error, bt_idx={}, batch_size={}".
format(bt_idx, batch_size))
break
start = bt_idx * max_batch_size
end = start + boxes_num_in_one_batch
img_list = []
for box_idx in range(start, end):
boximg = self.get_rotate_crop_image(im, dt_boxes[box_idx])
img_list.append(boximg)
h, w = boximg.shape[0:2]
wh_ratio = w * 1.0 / h
max_wh_ratio = max(max_wh_ratio, wh_ratio)
_, h, w = self.ocr_reader.resize_norm_img(img_list[0],
max_wh_ratio).shape
#_LOGGER.info("---- idx:{}, h:{}, w:{}".format(bt_idx, h, w))
imgs = np.zeros((boxes_num_in_one_batch, 3, h, w)).astype('float32')
for id, img in enumerate(img_list):
norm_img = self.ocr_reader.resize_norm_img(img, max_wh_ratio)
imgs[id] = norm_img
feed = {"image": imgs.copy()}
feed_list.append(feed)
#_LOGGER.info("feed_list : {}".format(feed_list))
return feed_list, False, None, ""
def postprocess(self, input_dicts, fetch_data, log_id):
res_list = []
if isinstance(fetch_data, dict):
if len(fetch_data) > 0:
rec_batch_res = self.ocr_reader.postprocess(
fetch_data, with_score=True)
for res in rec_batch_res:
res_list.append(res[0])
elif isinstance(fetch_data, list):
for one_batch in fetch_data:
one_batch_res = self.ocr_reader.postprocess(
one_batch, with_score=True)
for res in one_batch_res:
res_list.append(res[0])
res = {"res": str(res_list)}
return res, None, ""
......
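For reference, a minimal standalone sketch of the chunking arithmetic the new `preprocess` uses above: `dt_boxes` is replaced by a toy list, the `[start, end)` index math mirrors the loop, and no model or image data is involved:

```python
# Sketch of the mini-batch index math in RecOp.preprocess (toy data only).
max_batch_size = 6
dt_boxes = list(range(20))  # stand-in for 20 detected text boxes

boxes_size = len(dt_boxes)
batch_size = boxes_size // max_batch_size  # 3 full mini-batches
rem = boxes_size % max_batch_size          # 2 boxes left over

for bt_idx in range(batch_size + 1):
    # The extra iteration handles the remainder (skipped when rem == 0).
    n = rem if bt_idx == batch_size else max_batch_size
    if n == 0:
        continue
    start = bt_idx * max_batch_size
    end = start + n
    print(bt_idx, dt_boxes[start:end])
# 0 [0, 1, 2, 3, 4, 5] ... 3 [18, 19]
```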
......@@ -400,7 +400,7 @@ class Op(object):
log_id: global unique id for RTT, 0 default
Return:
input_dict: data for process stage
output_data: data for process stage
is_skip_process: skip process stage or not, False default
prod_errcode: None default; otherwise, a product error occurred.
It is handled in the same way as an exception.
......@@ -453,20 +453,23 @@ class Op(object):
call_result.pop("serving_status_code")
return call_result
def postprocess(self, input_dict, fetch_dict, log_id=0):
def postprocess(self, input_data, fetch_data, log_id=0):
"""
In the postprocess stage, assemble data for the next op or for output.
Args:
input_dict: data returned in preprocess stage.
fetch_dict: data returned in process stage.
input_data: data returned in preprocess stage, dict (for single predict) or list (for batch predict)
fetch_data: data returned in process stage, dict (for single predict) or list (for batch predict)
log_id: logid, 0 default
Returns:
fetch_dict: return fetch_dict default
fetch_dict: fetch result, must be dict type.
prod_errcode: None default; otherwise, a product error occurred.
It is handled in the same way as an exception.
prod_errinfo: "" default
"""
fetch_dict = {}
if isinstance(fetch_data, dict):
fetch_dict = fetch_data
return fetch_dict, None, ""
def _parse_channeldata(self, channeldata_dict):
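As the updated docstring describes, `fetch_data` reaching `postprocess` is either a dict (single predict) or a list of dicts (one per mini-batch). A minimal sketch of a user Op following that contract; `MyOp` and the `"score"` fetch key are illustrative, and the import path is an assumption that may differ between releases:

```python
import numpy as np
# Assumed import path; it may differ between Serving releases.
from paddle_serving_server.pipeline.operator import Op


class MyOp(Op):
    def postprocess(self, input_data, fetch_data, log_id=0):
        # Normalize both shapes into one list of per-mini-batch dicts.
        batches = [fetch_data] if isinstance(fetch_data, dict) else fetch_data
        scores = []
        for one_batch in batches:
            # "score" is a hypothetical fetch variable name.
            scores.extend(np.asarray(one_batch["score"]).flatten().tolist())
        # The first return value must be a dict, then prod_errcode, prod_errinfo.
        return {"res": str(scores)}, None, ""
```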
......@@ -685,41 +688,46 @@ class Op(object):
_LOGGER.debug("{} Running process".format(op_info_prefix))
midped_data_dict = collections.OrderedDict()
err_channeldata_dict = collections.OrderedDict()
### If (batch_num == 1 and skip == True), skip the process stage.
is_skip_process = False
data_ids = list(preped_data_dict.keys())
# skip process stage
if len(data_ids) == 1 and skip_process_dict.get(data_ids[0]) == True:
is_skip_process = True
_LOGGER.info("(data_id={} log_id={}) skip process stage".format(
data_ids[0], logid_dict.get(data_ids[0])))
if self.with_serving is False or is_skip_process is True:
midped_data_dict = preped_data_dict
_LOGGER.warning("(data_id={} log_id={}) OP={} skip process stage. " \
"with_serving={}, is_skip_process={}".format(data_ids[0],
logid_dict.get(data_ids[0]), self.name, self.with_serving,
is_skip_process))
return midped_data_dict, err_channeldata_dict
if self.with_serving is True and is_skip_process is False:
# use typical_logid to mark batch data
# Each data_id is a self-incrementing unique key.
typical_logid = data_ids[0]
if len(data_ids) != 1:
for data_id in data_ids:
_LOGGER.info(
"(data_id={} logid={}) {} During access to PaddleServingService,"
" we selected logid={} (from batch: {}) as a "
"(data_id={} logid={}) Auto-batching is On Op={}!!" \
"We selected logid={} (from batch: {}) as a " \
"representative for logging.".format(
data_id,
logid_dict.get(data_id), op_info_prefix,
data_id, logid_dict.get(data_id), self.name,
typical_logid, data_ids))
# combine samples to batch
one_input = preped_data_dict[data_ids[0]]
feed_batch = []
feed_dict = {}
input_offset = None
cur_offset = 0
input_offset_dict = {}
batch_input = False
if isinstance(one_input, dict):
# sample input
# For dict type, one_input is a single sample (dict).
# Merge the dicts of all data_ids into one batch dict.
# feed_batch is the input param of the predict func.
# input_offset_dict records [start, end) offsets per data_id for restoring results.
if len(data_ids) == 1:
feed_batch = [
preped_data_dict[data_id] for data_id in data_ids
]
feed_batch = [preped_data_dict[data_id] for data_id in data_ids]
else:
for data_id in data_ids:
for key, val in preped_data_dict[data_id].items():
......@@ -743,25 +751,45 @@ class Op(object):
break
input_offset_dict[data_id] = [start, cur_offset]
elif isinstance(one_input, list):
# batch input
input_offset = [0]
# For list type, data structure of one_input is [dict, dict, ...]
# Data structure of feed_batch is [dict1_1, dict1_2, dict2_1, ...]
# Data structure of input_offset_dict is { data_id : [start, end] }
batch_input = True
for data_id in data_ids:
batch_input = preped_data_dict[data_id]
offset = input_offset[-1] + len(batch_input)
feed_batch += batch_input
input_offset.append(offset)
feed_batch.extend(preped_data_dict[data_id])
data_size = len(preped_data_dict[data_id])
start = cur_offset
cur_offset = start + data_size
input_offset_dict[data_id] = [start, cur_offset]
else:
_LOGGER.critical(
"(data_id={} log_id={}){} Failed to process: expect input type is dict(sample"
" input) or list(batch input), but get {}".format(data_ids[
"(data_id={} log_id={}){} Failed to process: expect input type is dict"
" or list(batch input), but get {}".format(data_ids[
0], typical_logid, op_info_prefix, type(one_input)))
os._exit(-1)
for data_id in data_ids:
error_code = ChannelDataErrcode.TYPE_ERROR.value
error_info = "expect input type is dict or list, but get {}".format(
type(one_input))
err_channeldata_dict[data_id] = ChannelData(
error_code=error_code,
error_info=error_info,
data_id=data_id,
log_id=logid_dict.get(data_id))
return midped_data_dict, err_channeldata_dict
midped_batch = None
error_code = ChannelDataErrcode.OK.value
if self._timeout <= 0:
# No retry
try:
if batch_input is False:
midped_batch = self.process(feed_batch, typical_logid)
else:
midped_batch = []
for idx in range(len(feed_batch)):
predict_res = self.process([feed_batch[idx]],
typical_logid)
midped_batch.append(predict_res)
except Exception as e:
error_code = ChannelDataErrcode.UNKNOW.value
error_info = "(data_id={} log_id={}) {} Failed to process(batch: {}): {}".format(
......@@ -772,23 +800,32 @@ class Op(object):
for i in range(self._retry):
try:
# time out for each process
if batch_input is False:
midped_batch = func_timeout.func_timeout(
self._timeout,
self.process,
args=(feed_batch, typical_logid))
else:
midped_batch = []
for idx in range(len(feed_batch)):
predict_res = func_timeout.func_timeout(
self._timeout,
self.process,
args=([feed_batch[idx]], typical_logid))
midped_batch.append(predict_res)
except func_timeout.FunctionTimedOut as e:
if i + 1 >= self._retry:
error_code = ChannelDataErrcode.TIMEOUT.value
error_info = "(log_id={}) {} Failed to process(batch: {}): " \
"exceeded retry count.".format(
typical_logid, op_info_prefix, data_ids)
"exceeded retry count.".format(typical_logid, op_info_prefix, data_ids)
_LOGGER.error(error_info)
else:
_LOGGER.warning(
"(log_id={}) {} Failed to process(batch: {}): timeout,"
" and retrying({}/{})...".format(
typical_logid, op_info_prefix, data_ids, i +
1, self._retry))
typical_logid, op_info_prefix, data_ids, i + 1,
self._retry))
except Exception as e:
error_code = ChannelDataErrcode.UNKNOW.value
error_info = "(log_id={}) {} Failed to process(batch: {}): {}".format(
......@@ -797,18 +834,11 @@ class Op(object):
break
else:
break
if error_code != ChannelDataErrcode.OK.value:
for data_id in data_ids:
err_channeldata_dict[data_id] = ChannelData(
error_code=error_code,
error_info=error_info,
data_id=data_id,
log_id=logid_dict.get(data_id))
elif midped_batch is None:
# op client return None
error_info = "(log_id={}) {} Failed to predict, please check if " \
"PaddleServingService is working properly.".format(
typical_logid, op_info_prefix)
# Two kinds of errors: a process error occurred, or the predict result is None.
if error_code != ChannelDataErrcode.OK.value or midped_batch is None:
error_info = "(log_id={}) {} failed to predict.".format(
typical_logid, self.name)
_LOGGER.error(error_info)
for data_id in data_ids:
err_channeldata_dict[data_id] = ChannelData(
......@@ -816,18 +846,19 @@ class Op(object):
error_info=error_info,
data_id=data_id,
log_id=logid_dict.get(data_id))
else:
# transform np format to dict format
return midped_data_dict, err_channeldata_dict
# Split the batch infer result back to each data_id
if batch_input is False:
var_names = midped_batch.keys()
lod_var_names = set()
lod_offset_names = set()
# midped_batch is dict type for single input
for name in var_names:
lod_offset_name = "{}.lod".format(name)
if lod_offset_name in var_names:
_LOGGER.debug(
"(log_id={}) {} {} is LodTensor. lod_offset_name:{}".
format(typical_logid, op_info_prefix, name,
lod_offset_name))
_LOGGER.debug("(log_id={}) {} {} is LodTensor".format(
typical_logid, op_info_prefix, name))
lod_var_names.add(name)
lod_offset_names.add(lod_offset_name)
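The per-data_id splitting of LodTensor values is elided in the next hunk; a hedged sketch of the idea with toy numpy data (not the repository's exact code): the `.lod` array holds cumulative row offsets per sample, so each data_id's slice spans `lod[start]` to `lod[end]`, and its lod is rebased to zero:

```python
import numpy as np

# Toy batch: 3 samples with 2, 3, and 1 rows respectively.
value = np.arange(6 * 4).reshape(6, 4)           # concatenated rows of all samples
lod = np.array([0, 2, 5, 6])                     # cumulative row offsets per sample
input_offset_dict = {101: [0, 2], 102: [2, 3]}   # sample ranges per data_id

for data_id, (start, end) in input_offset_dict.items():
    rows = value[lod[start]:lod[end]]            # rows belonging to this data_id
    sub_lod = lod[start:end + 1] - lod[start]    # rebase offsets to zero
    print(data_id, rows.shape, sub_lod)
# 101 (5, 4) [0 2 5]
# 102 (1, 4) [0 1]
```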
......@@ -853,12 +884,15 @@ class Op(object):
else:
# normal tensor
for idx, data_id in enumerate(data_ids):
left = input_offset_dict[data_id][0]
right = input_offset_dict[data_id][1]
midped_data_dict[data_id][name] = value[left:right]
start = input_offset_dict[data_id][0]
end = input_offset_dict[data_id][1]
midped_data_dict[data_id][name] = value[start:end]
else:
midped_data_dict = preped_data_dict
_LOGGER.debug("{} Succ process".format(op_info_prefix))
# midped_batch is list type for batch input
for idx, data_id in enumerate(data_ids):
start = input_offset_dict[data_id][0]
end = input_offset_dict[data_id][1]
midped_data_dict[data_id] = midped_batch[start:end]
return midped_data_dict, err_channeldata_dict
def _run_postprocess(self, parsed_data_dict, midped_data_dict,
......
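Finally, a self-contained sketch of the offset bookkeeping in `_run_process` for the list (batch-input) path: per-data_id mini-batch lists are concatenated into `feed_batch`, each mini-batch is predicted, and the results are split back by `[start, end)` offsets. `fake_predict` stands in for `self.process`, and the data is illustrative:

```python
# Toy per-request preprocess output: each data_id maps to a list of mini-batch feeds.
preped_data_dict = {
    101: [{"image": "b1"}, {"image": "b2"}],  # data_id 101 -> 2 mini-batches
    102: [{"image": "b3"}],                   # data_id 102 -> 1 mini-batch
}

# Combine: concatenate all mini-batches and record per-data_id offsets.
feed_batch, input_offset_dict, cur_offset = [], {}, 0
for data_id, batches in preped_data_dict.items():
    feed_batch.extend(batches)
    start, cur_offset = cur_offset, cur_offset + len(batches)
    input_offset_dict[data_id] = [start, cur_offset]

def fake_predict(feed):  # stand-in for self.process on one mini-batch
    return {"res": feed["image"].upper()}

midped_batch = [fake_predict(feed) for feed in feed_batch]

# Split: restore each data_id's slice of results.
midped_data_dict = {
    data_id: midped_batch[start:end]
    for data_id, (start, end) in input_offset_dict.items()
}
print(midped_data_dict)
# {101: [{'res': 'B1'}, {'res': 'B2'}], 102: [{'res': 'B3'}]}
```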