Commit ff7774a4 authored by TeslaZhao

Support mini-batch in Pipeline mode

Parent 567dc666
@@ -79,6 +79,9 @@ class RecOp(Op):
         feed_list = []
         img_list = []
         max_wh_ratio = 0
+        ## One batch, the type of feed_data is dict.
+        """
         for i, dtbox in enumerate(dt_boxes):
             boximg = self.get_rotate_crop_image(im, dt_boxes[i])
             img_list.append(boximg)
@@ -92,14 +95,73 @@ class RecOp(Op):
             norm_img = self.ocr_reader.resize_norm_img(img, max_wh_ratio)
             imgs[id] = norm_img
         feed = {"image": imgs.copy()}
-        return feed, False, None, ""
-
-    def postprocess(self, input_dicts, fetch_dict, log_id):
-        rec_res = self.ocr_reader.postprocess(fetch_dict, with_score=True)
-        res_lst = []
-        for res in rec_res:
-            res_lst.append(res[0])
-        res = {"res": str(res_lst)}
+        """
+
+        ## Many mini-batches, the type of feed_data is list.
+        max_batch_size = 6  # len(dt_boxes)
+
+        # If max_batch_size is 0, skip the predict stage.
+        if max_batch_size == 0:
+            return {}, True, None, ""
+        boxes_size = len(dt_boxes)
+        batch_size = boxes_size // max_batch_size
+        rem = boxes_size % max_batch_size
+        #_LOGGER.info("max_batch_len:{}, batch_size:{}, rem:{}, boxes_size:{}".format(max_batch_size, batch_size, rem, boxes_size))
+        for bt_idx in range(0, batch_size + 1):
+            imgs = None
+            boxes_num_in_one_batch = 0
+            if bt_idx == batch_size:
+                if rem == 0:
+                    continue
+                else:
+                    boxes_num_in_one_batch = rem
+            elif bt_idx < batch_size:
+                boxes_num_in_one_batch = max_batch_size
+            else:
+                _LOGGER.error("batch_size error, bt_idx={}, batch_size={}".
+                              format(bt_idx, batch_size))
+                break
+
+            start = bt_idx * max_batch_size
+            end = start + boxes_num_in_one_batch
+            img_list = []
+            for box_idx in range(start, end):
+                boximg = self.get_rotate_crop_image(im, dt_boxes[box_idx])
+                img_list.append(boximg)
+                h, w = boximg.shape[0:2]
+                wh_ratio = w * 1.0 / h
+                max_wh_ratio = max(max_wh_ratio, wh_ratio)
+            _, w, h = self.ocr_reader.resize_norm_img(img_list[0],
+                                                      max_wh_ratio).shape
+            #_LOGGER.info("---- idx:{}, w:{}, h:{}".format(bt_idx, w, h))
+            imgs = np.zeros((boxes_num_in_one_batch, 3, w, h)).astype('float32')
+            for id, img in enumerate(img_list):
+                norm_img = self.ocr_reader.resize_norm_img(img, max_wh_ratio)
+                imgs[id] = norm_img
+            feed = {"image": imgs.copy()}
+            feed_list.append(feed)
+        #_LOGGER.info("feed_list : {}".format(feed_list))
+
+        return feed_list, False, None, ""
+
+    def postprocess(self, input_dicts, fetch_data, log_id):
+        res_list = []
+        if isinstance(fetch_data, dict):
+            if len(fetch_data) > 0:
+                rec_batch_res = self.ocr_reader.postprocess(
+                    fetch_data, with_score=True)
+                for res in rec_batch_res:
+                    res_list.append(res[0])
+        elif isinstance(fetch_data, list):
+            for one_batch in fetch_data:
+                one_batch_res = self.ocr_reader.postprocess(
+                    one_batch, with_score=True)
+                for res in one_batch_res:
+                    res_list.append(res[0])
+        res = {"res": str(res_list)}
         return res, None, ""
...
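The batching arithmetic in the preprocess stage above splits `len(dt_boxes)` cropped boxes into `boxes_size // max_batch_size` full mini-batches plus one remainder batch. A minimal standalone sketch of that arithmetic (the helper `split_ranges` is ours, not part of the commit; it assumes `max_batch_size > 0`, which preprocess guards separately):

```python
# Standalone sketch of the mini-batch splitting used in RecOp.preprocess.
# split_ranges() is a hypothetical helper, not part of this commit.
def split_ranges(boxes_size, max_batch_size):
    """Yield [start, end) index ranges covering boxes_size items."""
    batch_size = boxes_size // max_batch_size  # number of full mini-batches
    rem = boxes_size % max_batch_size          # size of the trailing batch
    for bt_idx in range(batch_size + 1):
        num = rem if bt_idx == batch_size else max_batch_size
        if num == 0:                           # evenly divisible: no tail batch
            continue
        start = bt_idx * max_batch_size
        yield start, start + num

print(list(split_ranges(14, 6)))  # -> [(0, 6), (6, 12), (12, 14)]
print(list(split_ranges(12, 6)))  # -> [(0, 6), (6, 12)]
```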
@@ -400,7 +400,7 @@ class Op(object):
             log_id: global unique id for RTT, 0 default

         Return:
-            input_dict: data for process stage
+            output_data: data for process stage
             is_skip_process: skip process stage or not, False default
             prod_errcode: None default, otherwise, product errors occurred.
                           It is handled in the same way as exception.
@@ -453,20 +453,23 @@ class Op(object):
                 call_result.pop("serving_status_code")
             return call_result

-    def postprocess(self, input_dict, fetch_dict, log_id=0):
+    def postprocess(self, input_data, fetch_data, log_id=0):
         """
         In postprocess stage, assemble data for next op or output.
         Args:
-            input_dict: data returned in preprocess stage.
-            fetch_dict: data returned in process stage.
+            input_data: data returned in preprocess stage, dict (single predict) or list (batch predict)
+            fetch_data: data returned in process stage, dict (single predict) or list (batch predict)
             log_id: logid, 0 default

         Returns:
-            fetch_dict: return fetch_dict default
+            fetch_dict: fetch result, must be dict type
             prod_errcode: None default, otherwise, product errors occurred.
                           It is handled in the same way as exception.
             prod_errinfo: "" default
         """
+        fetch_dict = {}
+        if isinstance(fetch_data, dict):
+            fetch_dict = fetch_data
         return fetch_dict, None, ""

     def _parse_channeldata(self, channeldata_dict):
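With this change, the `fetch_data` reaching `postprocess` is a single fetch dict after a plain predict, or a list of fetch dicts after a mini-batch predict; the default implementation above only passes a dict through, so batch-mode Ops must flatten the list themselves (as `RecOp.postprocess` does earlier in this commit). A small sketch of that shape handling (`merge_fetch` and the `score` key are illustrative, not framework API):

```python
# Standalone sketch of the dict-or-list handling a user Op.postprocess
# now needs. merge_fetch() is a hypothetical helper, not framework API.
def merge_fetch(fetch_data):
    merged = {}
    if isinstance(fetch_data, dict):      # single predict: pass through
        merged = fetch_data
    elif isinstance(fetch_data, list):    # mini-batch predict: flatten
        for one_batch in fetch_data:      # one fetch dict per mini-batch
            for key, val in one_batch.items():
                merged.setdefault(key, []).append(val)
    return merged

print(merge_fetch({"score": 0.9}))                    # {'score': 0.9}
print(merge_fetch([{"score": 0.9}, {"score": 0.8}]))  # {'score': [0.9, 0.8]}
```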
@@ -685,41 +688,46 @@ class Op(object):
         _LOGGER.debug("{} Running process".format(op_info_prefix))
         midped_data_dict = collections.OrderedDict()
         err_channeldata_dict = collections.OrderedDict()
+        ### If (batch_num == 1 && skip == True), then skip the process stage.
         is_skip_process = False
         data_ids = list(preped_data_dict.keys())
+
+        # skip process stage
         if len(data_ids) == 1 and skip_process_dict.get(data_ids[0]) == True:
             is_skip_process = True
-            _LOGGER.info("(data_id={} log_id={}) skip process stage".format(
-                data_ids[0], logid_dict.get(data_ids[0])))
-
-        if self.with_serving is True and is_skip_process is False:
-            # use typical_logid to mark batch data
-            typical_logid = data_ids[0]
-            if len(data_ids) != 1:
-                for data_id in data_ids:
-                    _LOGGER.info(
-                        "(data_id={} logid={}) {} During access to PaddleServingService,"
-                        " we selected logid={} (from batch: {}) as a "
-                        "representative for logging.".format(
-                            data_id,
-                            logid_dict.get(data_id), op_info_prefix,
-                            typical_logid, data_ids))
+        if self.with_serving is False or is_skip_process is True:
+            midped_data_dict = preped_data_dict
+            _LOGGER.warning("(data_id={} log_id={}) OP={} skip process stage. " \
+                "with_serving={}, is_skip_process={}".format(data_ids[0],
+                logid_dict.get(data_ids[0]), self.name, self.with_serving,
+                is_skip_process))
+            return midped_data_dict, err_channeldata_dict
+
+        # use typical_logid to mark batch data
+        # data_ids is one self-increasing unique key.
+        typical_logid = data_ids[0]
+        if len(data_ids) != 1:
+            for data_id in data_ids:
+                _LOGGER.info(
+                    "(data_id={} logid={}) Auto-batching is on, Op={}. " \
+                    "We selected logid={} (from batch: {}) as a " \
+                    "representative for logging.".format(
+                        data_id, logid_dict.get(data_id), self.name,
+                        typical_logid, data_ids))

+        # combine samples to batch
         one_input = preped_data_dict[data_ids[0]]
         feed_batch = []
         feed_dict = {}
-        input_offset = None
         cur_offset = 0
         input_offset_dict = {}
+        batch_input = False
+
         if isinstance(one_input, dict):
-            # sample input
+            # For dict type, the data structure is a single dict.
+            # Merge the dicts of all data_ids into one dict.
+            # feed_batch is the input param of the predict func.
+            # input_offset_dict is used to restore data for each data_id.
             if len(data_ids) == 1:
-                feed_batch = [
-                    preped_data_dict[data_id] for data_id in data_ids
-                ]
+                feed_batch = [preped_data_dict[data_id] for data_id in data_ids]
             else:
                 for data_id in data_ids:
                     for key, val in preped_data_dict[data_id].items():
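The body of this key-wise merge loop is elided by the hunk boundary; in rough terms it concatenates each key's arrays across requests and records the row range each request occupies. A hedged sketch under those assumptions (variable names and shapes are ours, and the real loop carries extra bookkeeping such as LoD fields):

```python
import numpy as np

# Rough sketch of the dict-branch merge: per-key values from each
# sample are concatenated along axis 0, and input_offset_dict records
# each request's row range. `requests` stands in for preped_data_dict.
requests = {7: {"image": np.zeros((1, 3, 32, 100))},
            8: {"image": np.ones((2, 3, 32, 100))}}

feed_dict, input_offset_dict, cur_offset = {}, {}, 0
for data_id, sample in requests.items():
    start = cur_offset
    for key, val in sample.items():
        feed_dict.setdefault(key, []).append(val)
    cur_offset += next(iter(sample.values())).shape[0]
    input_offset_dict[data_id] = [start, cur_offset]

feed_batch = [{k: np.concatenate(v, axis=0) for k, v in feed_dict.items()}]
print(feed_batch[0]["image"].shape)  # (3, 3, 32, 100)
print(input_offset_dict)             # {7: [0, 1], 8: [1, 3]}
```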
@@ -743,25 +751,45 @@ class Op(object):
                             break
                     input_offset_dict[data_id] = [start, cur_offset]
         elif isinstance(one_input, list):
-            # batch input
-            input_offset = [0]
+            # For list type, the data structure of one_input is [dict, dict, ...].
+            # The data structure of feed_batch is [dict1_1, dict1_2, dict2_1, ...].
+            # The data structure of input_offset_dict is { data_id : [start, end] }.
+            batch_input = True
             for data_id in data_ids:
-                batch_input = preped_data_dict[data_id]
-                offset = input_offset[-1] + len(batch_input)
-                feed_batch += batch_input
-                input_offset.append(offset)
+                feed_batch.extend(preped_data_dict[data_id])
+                data_size = len(preped_data_dict[data_id])
+                start = cur_offset
+                cur_offset = start + data_size
+                input_offset_dict[data_id] = [start, cur_offset]
         else:
             _LOGGER.critical(
-                "(data_id={} log_id={}){} Failed to process: expect input type is dict(sample"
-                " input) or list(batch input), but get {}".format(data_ids[
+                "(data_id={} log_id={}){} Failed to process: expect input type is dict"
+                " or list(batch input), but get {}".format(data_ids[
                     0], typical_logid, op_info_prefix, type(one_input)))
-            os._exit(-1)
+            error_code = ChannelDataErrcode.TYPE_ERROR.value
+            error_info = "expect input type is dict or list, but get {}".format(
+                type(one_input))
+            for data_id in data_ids:
+                err_channeldata_dict[data_id] = ChannelData(
+                    error_code=error_code,
+                    error_info=error_info,
+                    data_id=data_id,
+                    log_id=logid_dict.get(data_id))
+            return midped_data_dict, err_channeldata_dict

         midped_batch = None
         error_code = ChannelDataErrcode.OK.value
         if self._timeout <= 0:
+            # No retry
             try:
-                midped_batch = self.process(feed_batch, typical_logid)
+                if batch_input is False:
+                    midped_batch = self.process(feed_batch, typical_logid)
+                else:
+                    midped_batch = []
+                    for idx in range(len(feed_batch)):
+                        predict_res = self.process([feed_batch[idx]],
+                                                   typical_logid)
+                        midped_batch.append(predict_res)
             except Exception as e:
                 error_code = ChannelDataErrcode.UNKNOW.value
                 error_info = "(data_id={} log_id={}) {} Failed to process(batch: {}): {}".format(
@@ -772,23 +800,32 @@ class Op(object):
             for i in range(self._retry):
                 try:
                     # time out for each process
-                    midped_batch = func_timeout.func_timeout(
-                        self._timeout,
-                        self.process,
-                        args=(feed_batch, typical_logid))
+                    if batch_input is False:
+                        midped_batch = func_timeout.func_timeout(
+                            self._timeout,
+                            self.process,
+                            args=(feed_batch, typical_logid))
+                    else:
+                        midped_batch = []
+                        for idx in range(len(feed_batch)):
+                            predict_res = func_timeout.func_timeout(
+                                self._timeout,
+                                self.process,
+                                args=([feed_batch[idx]], typical_logid))
+                            midped_batch.append(predict_res)
                 except func_timeout.FunctionTimedOut as e:
                     if i + 1 >= self._retry:
                         error_code = ChannelDataErrcode.TIMEOUT.value
                         error_info = "(log_id={}) {} Failed to process(batch: {}): " \
-                            "exceeded retry count.".format(
-                                typical_logid, op_info_prefix, data_ids)
+                            "exceeded retry count.".format(typical_logid, op_info_prefix, data_ids)
                         _LOGGER.error(error_info)
                     else:
                         _LOGGER.warning(
                             "(log_id={}) {} Failed to process(batch: {}): timeout,"
                             " and retrying({}/{})...".format(
-                                typical_logid, op_info_prefix, data_ids, i +
-                                1, self._retry))
+                                typical_logid, op_info_prefix, data_ids, i + 1,
+                                self._retry))
                 except Exception as e:
                     error_code = ChannelDataErrcode.UNKNOW.value
                     error_info = "(log_id={}) {} Failed to process(batch: {}): {}".format(
@@ -797,18 +834,11 @@ class Op(object):
                     break
             else:
                 break
-        if error_code != ChannelDataErrcode.OK.value:
-            for data_id in data_ids:
-                err_channeldata_dict[data_id] = ChannelData(
-                    error_code=error_code,
-                    error_info=error_info,
-                    data_id=data_id,
-                    log_id=logid_dict.get(data_id))
-        elif midped_batch is None:
-            # op client return None
-            error_info = "(log_id={}) {} Failed to predict, please check if " \
-                "PaddleServingService is working properly.".format(
-                    typical_logid, op_info_prefix)
+
+        # 2 kinds of errors
+        if error_code != ChannelDataErrcode.OK.value or midped_batch is None:
+            error_info = "(log_id={}) {} failed to predict.".format(
+                typical_logid, self.name)
             _LOGGER.error(error_info)
             for data_id in data_ids:
                 err_channeldata_dict[data_id] = ChannelData(
@@ -816,18 +846,19 @@ class Op(object):
                     error_info=error_info,
                     data_id=data_id,
                     log_id=logid_dict.get(data_id))
-        else:
-            # transform np format to dict format
+            return midped_data_dict, err_channeldata_dict
+
+        # Split batch infer result to each data_id
+        if batch_input is False:
             var_names = midped_batch.keys()
             lod_var_names = set()
             lod_offset_names = set()
+            # midped_batch is dict type for single input
             for name in var_names:
                 lod_offset_name = "{}.lod".format(name)
                 if lod_offset_name in var_names:
-                    _LOGGER.debug(
-                        "(log_id={}) {} {} is LodTensor. lod_offset_name:{}".
-                        format(typical_logid, op_info_prefix, name,
-                               lod_offset_name))
+                    _LOGGER.debug("(log_id={}) {} {} is LodTensor".format(
+                        typical_logid, op_info_prefix, name))
                     lod_var_names.add(name)
                     lod_offset_names.add(lod_offset_name)
@@ -853,12 +884,15 @@ class Op(object):
                 else:
                     # normal tensor
                     for idx, data_id in enumerate(data_ids):
-                        left = input_offset_dict[data_id][0]
-                        right = input_offset_dict[data_id][1]
-                        midped_data_dict[data_id][name] = value[left:right]
+                        start = input_offset_dict[data_id][0]
+                        end = input_offset_dict[data_id][1]
+                        midped_data_dict[data_id][name] = value[start:end]
         else:
-            midped_data_dict = preped_data_dict
-        _LOGGER.debug("{} Succ process".format(op_info_prefix))
+            # midped_batch is list type for batch input
+            for idx, data_id in enumerate(data_ids):
+                start = input_offset_dict[data_id][0]
+                end = input_offset_dict[data_id][1]
+                midped_data_dict[data_id] = midped_batch[start:end]
         return midped_data_dict, err_channeldata_dict

     def _run_postprocess(self, parsed_data_dict, midped_data_dict,
...