Commit ff7774a4, authored by TeslaZhao

Support mini-batch in Pipeline mode

Parent 567dc666
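This commit changes the pipeline Op data contract: preprocess may now return either a single feed dict (one batch) or a list of feed dicts (many mini-batches), and postprocess receives the matching fetch result, a dict or a list of dicts. A minimal sketch of the two shapes; the tensor shapes below are made up for illustration:

    import numpy as np

    # One batch: preprocess returns a dict; process() runs once and
    # postprocess receives a single fetch dict.
    one_batch_feed = {"image": np.zeros((8, 3, 32, 320), dtype="float32")}

    # Mini-batches: preprocess returns a list of dicts; process() runs once
    # per element and postprocess receives a list of fetch dicts.
    mini_batch_feeds = [
        {"image": np.zeros((6, 3, 32, 320), dtype="float32")},
        {"image": np.zeros((2, 3, 32, 320), dtype="float32")},
    ]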
@@ -79,6 +79,9 @@ class RecOp(Op):
         feed_list = []
         img_list = []
         max_wh_ratio = 0
+        ## One batch, the type of feed_data is dict.
+        """
         for i, dtbox in enumerate(dt_boxes):
             boximg = self.get_rotate_crop_image(im, dt_boxes[i])
             img_list.append(boximg)
@@ -92,14 +95,73 @@ class RecOp(Op):
             norm_img = self.ocr_reader.resize_norm_img(img, max_wh_ratio)
             imgs[id] = norm_img
         feed = {"image": imgs.copy()}
         return feed, False, None, ""
+        """
+
+        ## Many mini-batches, the type of feed_data is list.
+        max_batch_size = 6  # len(dt_boxes)
+
+        # If max_batch_size is 0, skip the predict stage
+        if max_batch_size == 0:
+            return {}, True, None, ""
+        boxes_size = len(dt_boxes)
+        batch_size = boxes_size // max_batch_size
+        rem = boxes_size % max_batch_size
+        #_LOGGER.info("max_batch_len:{}, batch_size:{}, rem:{}, boxes_size:{}".format(max_batch_size, batch_size, rem, boxes_size))
+        for bt_idx in range(0, batch_size + 1):
+            imgs = None
+            boxes_num_in_one_batch = 0
+            if bt_idx == batch_size:
+                if rem == 0:
+                    continue
+                else:
+                    boxes_num_in_one_batch = rem
+            elif bt_idx < batch_size:
+                boxes_num_in_one_batch = max_batch_size
+            else:
+                _LOGGER.error("batch_size error, bt_idx={}, batch_size={}".
+                              format(bt_idx, batch_size))
+                break
+
+            start = bt_idx * max_batch_size
+            end = start + boxes_num_in_one_batch
+            img_list = []
+            for box_idx in range(start, end):
+                boximg = self.get_rotate_crop_image(im, dt_boxes[box_idx])
+                img_list.append(boximg)
+                h, w = boximg.shape[0:2]
+                wh_ratio = w * 1.0 / h
+                max_wh_ratio = max(max_wh_ratio, wh_ratio)
+            _, w, h = self.ocr_reader.resize_norm_img(img_list[0],
+                                                      max_wh_ratio).shape
+            #_LOGGER.info("---- idx:{}, w:{}, h:{}".format(bt_idx, w, h))
+            imgs = np.zeros((boxes_num_in_one_batch, 3, w, h)).astype('float32')
+            for id, img in enumerate(img_list):
+                norm_img = self.ocr_reader.resize_norm_img(img, max_wh_ratio)
+                imgs[id] = norm_img
+
+            feed = {"image": imgs.copy()}
+            feed_list.append(feed)
+        #_LOGGER.info("feed_list : {}".format(feed_list))
+        return feed_list, False, None, ""
 
-    def postprocess(self, input_dicts, fetch_dict, log_id):
-        rec_res = self.ocr_reader.postprocess(fetch_dict, with_score=True)
-        res_lst = []
-        for res in rec_res:
-            res_lst.append(res[0])
-        res = {"res": str(res_lst)}
+    def postprocess(self, input_dicts, fetch_data, log_id):
+        res_list = []
+        if isinstance(fetch_data, dict):
+            if len(fetch_data) > 0:
+                rec_batch_res = self.ocr_reader.postprocess(
+                    fetch_data, with_score=True)
+                for res in rec_batch_res:
+                    res_list.append(res[0])
+        elif isinstance(fetch_data, list):
+            for one_batch in fetch_data:
+                one_batch_res = self.ocr_reader.postprocess(
+                    one_batch, with_score=True)
+                for res in one_batch_res:
+                    res_list.append(res[0])
+        res = {"res": str(res_list)}
         return res, None, ""
...
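The mini-batch loop in RecOp.preprocess above iterates over batch_size + 1 chunk indices and sizes the trailing chunk with the remainder. The same arithmetic as a standalone sketch (a hypothetical helper, not part of the commit), with a worked example:

    def mini_batch_ranges(boxes_size, max_batch_size):
        # Yield [start, end) pairs covering boxes_size items in chunks of
        # at most max_batch_size, mirroring the loop in RecOp.preprocess.
        batch_size = boxes_size // max_batch_size  # number of full chunks
        rem = boxes_size % max_batch_size          # size of the trailing chunk
        for bt_idx in range(batch_size + 1):
            if bt_idx == batch_size:
                if rem == 0:
                    continue  # boxes_size divides evenly, no partial chunk
                num = rem
            else:
                num = max_batch_size
            start = bt_idx * max_batch_size
            yield start, start + num

    # e.g. 14 detected boxes with max_batch_size=6 give three mini-batches
    assert list(mini_batch_ranges(14, 6)) == [(0, 6), (6, 12), (12, 14)]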
@@ -400,7 +400,7 @@ class Op(object):
             log_id: global unique id for RTT, 0 default
         Return:
-            input_dict: data for process stage
+            output_data: data for process stage
             is_skip_process: skip process stage or not, False default
             prod_errcode: None default, otherwise, product errors occurred.
                 It is handled in the same way as exception.
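The renamed output_data is whatever preprocess hands to the process stage. For reference, a minimal custom preprocess honoring the documented return tuple; this is a sketch, the signature and the "raw"/"x" field names are assumptions, and the Op base class import path varies by Serving version:

    import numpy as np

    class MyOp(Op):  # assumes the pipeline Op base class is in scope
        def preprocess(self, input_dicts, data_id, log_id):
            (_, input_dict), = input_dicts.items()
            if "raw" not in input_dict:
                # Product error: skip the process stage and report upstream.
                return {}, True, 5001, "missing 'raw' field"
            output_data = {"x": np.asarray(input_dict["raw"], dtype="float32")}
            # (output_data, is_skip_process, prod_errcode, prod_errinfo)
            return output_data, False, None, ""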
@@ -453,20 +453,23 @@ class Op(object):
             call_result.pop("serving_status_code")
         return call_result
 
-    def postprocess(self, input_dict, fetch_dict, log_id=0):
+    def postprocess(self, input_data, fetch_data, log_id=0):
         """
         In postprocess stage, assemble data for next op or output.
         Args:
-            input_dict: data returned in preprocess stage.
-            fetch_dict: data returned in process stage.
+            input_data: data returned in preprocess stage, dict (for single
+                predict) or list (for batch predict)
+            fetch_data: data returned in process stage, dict (for single
+                predict) or list (for batch predict)
             log_id: logid, 0 default
 
         Returns:
-            fetch_dict: return fetch_dict default
+            fetch_dict: fetch result, must be dict type.
             prod_errcode: None default, otherwise, product errors occurred.
                 It is handled in the same way as exception.
             prod_errinfo: "" default
         """
+        fetch_dict = {}
+        if isinstance(fetch_data, dict):
+            fetch_dict = fetch_data
         return fetch_dict, None, ""
 
     def _parse_channeldata(self, channeldata_dict):
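Because the default postprocess above only passes fetch_data through when it is a dict, an Op whose process stage ran per mini-batch must flatten the list itself. A sketch patterned on RecOp.postprocess from the first file; decode_batch is a hypothetical stand-in for a real decoder such as ocr_reader.postprocess:

    class MyBatchOp(Op):  # assumes the pipeline Op base class is in scope
        def postprocess(self, input_data, fetch_data, log_id=0):
            results = []
            # Normalize: single predict yields a dict, batch predict a list.
            batches = [fetch_data] if isinstance(fetch_data, dict) else fetch_data
            for one_batch in batches:
                results.extend(decode_batch(one_batch))  # hypothetical decoder
            # The returned fetch result must be a dict.
            return {"res": str(results)}, None, ""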
@@ -685,180 +688,211 @@ class Op(object):
         _LOGGER.debug("{} Running process".format(op_info_prefix))
         midped_data_dict = collections.OrderedDict()
         err_channeldata_dict = collections.OrderedDict()
-        ### if (batch_num == 1 && skip == True), then skip the process stage.
         is_skip_process = False
         data_ids = list(preped_data_dict.keys())
+
+        # skip process stage
         if len(data_ids) == 1 and skip_process_dict.get(data_ids[0]) == True:
             is_skip_process = True
-            _LOGGER.info("(data_id={} log_id={}) skip process stage".format(
-                data_ids[0], logid_dict.get(data_ids[0])))
-
-        if self.with_serving is True and is_skip_process is False:
-            # use typical_logid to mark batch data
-            typical_logid = data_ids[0]
-            if len(data_ids) != 1:
-                for data_id in data_ids:
-                    _LOGGER.info(
-                        "(data_id={} logid={}) {} During access to PaddleServingService,"
-                        " we selected logid={} (from batch: {}) as a "
-                        "representative for logging.".format(
-                            data_id, logid_dict.get(data_id), op_info_prefix,
-                            typical_logid, data_ids))
-
-            # combine samples to batch
-            one_input = preped_data_dict[data_ids[0]]
-            feed_batch = []
-            feed_dict = {}
-            input_offset = None
-            cur_offset = 0
-            input_offset_dict = {}
-
-            if isinstance(one_input, dict):
-                # sample input
-                if len(data_ids) == 1:
-                    feed_batch = [
-                        preped_data_dict[data_id] for data_id in data_ids
-                    ]
-                else:
-                    for data_id in data_ids:
-                        for key, val in preped_data_dict[data_id].items():
-                            has_val = feed_dict.get(key)
-                            if has_val is None:
-                                feed_dict[key] = val
-                                continue
-                            # merge 2 np.array
-                            if isinstance(val, np.ndarray):
-                                feed_dict[key] = np.append(
-                                    feed_dict[key], val, axis=0)
-                    feed_batch.append(feed_dict)
-                for data_id in data_ids:
-                    start = cur_offset
-                    for key, val in preped_data_dict[data_id].items():
-                        if isinstance(val, (list, np.ndarray)):
-                            cur_offset += len(val)
-                        else:
-                            cur_offset += 1
-                        break
-                    input_offset_dict[data_id] = [start, cur_offset]
-            elif isinstance(one_input, list):
-                # batch input
-                input_offset = [0]
-                for data_id in data_ids:
-                    batch_input = preped_data_dict[data_id]
-                    offset = input_offset[-1] + len(batch_input)
-                    feed_batch += batch_input
-                    input_offset.append(offset)
-            else:
-                _LOGGER.critical(
-                    "(data_id={} log_id={}){} Failed to process: expect input type is dict(sample"
-                    " input) or list(batch input), but get {}".format(
-                        data_ids[0], typical_logid, op_info_prefix,
-                        type(one_input)))
-                os._exit(-1)
-
-            midped_batch = None
-            error_code = ChannelDataErrcode.OK.value
-            if self._timeout <= 0:
-                try:
-                    midped_batch = self.process(feed_batch, typical_logid)
-                except Exception as e:
-                    error_code = ChannelDataErrcode.UNKNOW.value
-                    error_info = "(data_id={} log_id={}) {} Failed to process(batch: {}): {}".format(
-                        data_ids[0], typical_logid, op_info_prefix, data_ids, e)
-                    _LOGGER.error(error_info, exc_info=True)
-            else:
-                # retry N times configured in yaml files.
-                for i in range(self._retry):
-                    try:
-                        # time out for each process
-                        midped_batch = func_timeout.func_timeout(
-                            self._timeout,
-                            self.process,
-                            args=(feed_batch, typical_logid))
-                    except func_timeout.FunctionTimedOut as e:
-                        if i + 1 >= self._retry:
-                            error_code = ChannelDataErrcode.TIMEOUT.value
-                            error_info = "(log_id={}) {} Failed to process(batch: {}): " \
-                                "exceeded retry count.".format(
-                                    typical_logid, op_info_prefix, data_ids)
-                            _LOGGER.error(error_info)
-                        else:
-                            _LOGGER.warning(
-                                "(log_id={}) {} Failed to process(batch: {}): timeout,"
-                                " and retrying({}/{})...".format(
-                                    typical_logid, op_info_prefix, data_ids,
-                                    i + 1, self._retry))
-                    except Exception as e:
-                        error_code = ChannelDataErrcode.UNKNOW.value
-                        error_info = "(log_id={}) {} Failed to process(batch: {}): {}".format(
-                            typical_logid, op_info_prefix, data_ids, e)
-                        _LOGGER.error(error_info, exc_info=True)
-                        break
-                    else:
-                        break
-            if error_code != ChannelDataErrcode.OK.value:
-                for data_id in data_ids:
-                    err_channeldata_dict[data_id] = ChannelData(
-                        error_code=error_code,
-                        error_info=error_info,
-                        data_id=data_id,
-                        log_id=logid_dict.get(data_id))
-            elif midped_batch is None:
-                # op client return None
-                error_info = "(log_id={}) {} Failed to predict, please check if " \
-                    "PaddleServingService is working properly.".format(
-                        typical_logid, op_info_prefix)
-                _LOGGER.error(error_info)
-                for data_id in data_ids:
-                    err_channeldata_dict[data_id] = ChannelData(
-                        error_code=ChannelDataErrcode.CLIENT_ERROR.value,
-                        error_info=error_info,
-                        data_id=data_id,
-                        log_id=logid_dict.get(data_id))
-            else:
-                # transform np format to dict format
-                var_names = midped_batch.keys()
-                lod_var_names = set()
-                lod_offset_names = set()
-                for name in var_names:
-                    lod_offset_name = "{}.lod".format(name)
-                    if lod_offset_name in var_names:
-                        _LOGGER.debug(
-                            "(log_id={}) {} {} is LodTensor. lod_offset_name:{}".
-                            format(typical_logid, op_info_prefix, name,
-                                   lod_offset_name))
-                        lod_var_names.add(name)
-                        lod_offset_names.add(lod_offset_name)
-                for idx, data_id in enumerate(data_ids):
-                    midped_data_dict[data_id] = {}
-                for name, value in midped_batch.items():
-                    if name in lod_offset_names:
-                        continue
-                    if name in lod_var_names:
-                        # lodtensor
-                        lod_offset_name = "{}.lod".format(name)
-                        lod_offset = midped_batch[lod_offset_name]
-                        for idx, data_id in enumerate(data_ids):
-                            data_offset_left = input_offset_dict[data_id][0]
-                            data_offset_right = input_offset_dict[data_id][1]
-                            lod_offset_left = lod_offset[data_offset_left]
-                            lod_offset_right = lod_offset[data_offset_right]
-                            midped_data_dict[data_id][name] = value[
-                                lod_offset_left:lod_offset_right]
-                            midped_data_dict[data_id][lod_offset_name] = \
-                                lod_offset[data_offset_left:data_offset_right + 1] - lod_offset[data_offset_left]
-                    else:
-                        # normal tensor
-                        for idx, data_id in enumerate(data_ids):
-                            left = input_offset_dict[data_id][0]
-                            right = input_offset_dict[data_id][1]
-                            midped_data_dict[data_id][name] = value[left:right]
-        else:
-            midped_data_dict = preped_data_dict
-        _LOGGER.debug("{} Succ process".format(op_info_prefix))
+        if self.with_serving is False or is_skip_process is True:
+            midped_data_dict = preped_data_dict
+            _LOGGER.warning("(data_id={} log_id={}) OP={} skip process stage. "
+                            "with_serving={}, is_skip_process={}".format(
+                                data_ids[0], logid_dict.get(data_ids[0]),
+                                self.name, self.with_serving, is_skip_process))
+            return midped_data_dict, err_channeldata_dict
+
+        # use typical_logid to mark batch data
+        # data_ids is one self-increasing unique key.
+        typical_logid = data_ids[0]
+        if len(data_ids) != 1:
+            for data_id in data_ids:
+                _LOGGER.info(
+                    "(data_id={} logid={}) Auto-batching is on, Op={}! "
+                    "We selected logid={} (from batch: {}) as a "
+                    "representative for logging.".format(
+                        data_id, logid_dict.get(data_id), self.name,
+                        typical_logid, data_ids))
+
+        one_input = preped_data_dict[data_ids[0]]
+        feed_batch = []
+        feed_dict = {}
+        cur_offset = 0
+        input_offset_dict = {}
+        batch_input = False
+
+        if isinstance(one_input, dict):
+            # For dict type, data structure is dict.
+            # Merge multiple dicts for data_ids into one dict.
+            # feed_batch is the input param of the predict func.
+            # input_offset_dict is used for data restoration[data_ids]
+            if len(data_ids) == 1:
+                feed_batch = [preped_data_dict[data_id] for data_id in data_ids]
+            else:
+                for data_id in data_ids:
+                    for key, val in preped_data_dict[data_id].items():
+                        has_val = feed_dict.get(key)
+                        if has_val is None:
+                            feed_dict[key] = val
+                            continue
+                        # merge 2 np.array
+                        if isinstance(val, np.ndarray):
+                            feed_dict[key] = np.append(
+                                feed_dict[key], val, axis=0)
+                feed_batch.append(feed_dict)
+
+            for data_id in data_ids:
+                start = cur_offset
+                for key, val in preped_data_dict[data_id].items():
+                    if isinstance(val, (list, np.ndarray)):
+                        cur_offset += len(val)
+                    else:
+                        cur_offset += 1
+                    break
+                input_offset_dict[data_id] = [start, cur_offset]
+        elif isinstance(one_input, list):
+            # For list type, data structure of one_input is [dict, dict, ...]
+            # Data structure of feed_batch is [dict1_1, dict1_2, dict2_1, ...]
+            # Data structure of input_offset_dict is { data_id : [start, end] }
+            batch_input = True
+            for data_id in data_ids:
+                feed_batch.extend(preped_data_dict[data_id])
+                data_size = len(preped_data_dict[data_id])
+                start = cur_offset
+                cur_offset = start + data_size
+                input_offset_dict[data_id] = [start, cur_offset]
+        else:
+            _LOGGER.critical(
+                "(data_id={} log_id={}){} Failed to process: expect input type is dict"
+                " or list(batch input), but get {}".format(
+                    data_ids[0], typical_logid, op_info_prefix,
+                    type(one_input)))
+            for data_id in data_ids:
+                error_code = ChannelDataErrcode.TYPE_ERROR.value
+                error_info = "expect input type is dict or list, but get {}".format(
+                    type(one_input))
+                err_channeldata_dict[data_id] = ChannelData(
+                    error_code=error_code,
+                    error_info=error_info,
+                    data_id=data_id,
+                    log_id=logid_dict.get(data_id))
+            return midped_data_dict, err_channeldata_dict
+
+        midped_batch = None
+        error_code = ChannelDataErrcode.OK.value
+        if self._timeout <= 0:
+            # No retry
+            try:
+                if batch_input is False:
+                    midped_batch = self.process(feed_batch, typical_logid)
+                else:
+                    midped_batch = []
+                    for idx in range(len(feed_batch)):
+                        predict_res = self.process([feed_batch[idx]],
+                                                   typical_logid)
+                        midped_batch.append(predict_res)
+            except Exception as e:
+                error_code = ChannelDataErrcode.UNKNOW.value
+                error_info = "(data_id={} log_id={}) {} Failed to process(batch: {}): {}".format(
+                    data_ids[0], typical_logid, op_info_prefix, data_ids, e)
+                _LOGGER.error(error_info, exc_info=True)
+        else:
+            # retry N times configured in yaml files.
+            for i in range(self._retry):
+                try:
+                    # time out for each process
+                    if batch_input is False:
+                        midped_batch = func_timeout.func_timeout(
+                            self._timeout,
+                            self.process,
+                            args=(feed_batch, typical_logid))
+                    else:
+                        midped_batch = []
+                        for idx in range(len(feed_batch)):
+                            predict_res = func_timeout.func_timeout(
+                                self._timeout,
+                                self.process,
+                                args=([feed_batch[idx]], typical_logid))
+                            midped_batch.append(predict_res)
+                except func_timeout.FunctionTimedOut as e:
+                    if i + 1 >= self._retry:
+                        error_code = ChannelDataErrcode.TIMEOUT.value
+                        error_info = "(log_id={}) {} Failed to process(batch: {}): " \
+                            "exceeded retry count.".format(
+                                typical_logid, op_info_prefix, data_ids)
+                        _LOGGER.error(error_info)
+                    else:
+                        _LOGGER.warning(
+                            "(log_id={}) {} Failed to process(batch: {}): timeout,"
+                            " and retrying({}/{})...".format(
+                                typical_logid, op_info_prefix, data_ids,
+                                i + 1, self._retry))
+                except Exception as e:
+                    error_code = ChannelDataErrcode.UNKNOW.value
+                    error_info = "(log_id={}) {} Failed to process(batch: {}): {}".format(
+                        typical_logid, op_info_prefix, data_ids, e)
+                    _LOGGER.error(error_info, exc_info=True)
+                    break
+                else:
+                    break
+
+        # 2 kinds of errors
+        if error_code != ChannelDataErrcode.OK.value or midped_batch is None:
+            error_info = "(log_id={}) {} failed to predict.".format(
+                typical_logid, self.name)
+            _LOGGER.error(error_info)
+            for data_id in data_ids:
+                err_channeldata_dict[data_id] = ChannelData(
+                    error_code=ChannelDataErrcode.CLIENT_ERROR.value,
+                    error_info=error_info,
+                    data_id=data_id,
+                    log_id=logid_dict.get(data_id))
+            return midped_data_dict, err_channeldata_dict
+
+        # Split batch infer result to each data_id
+        if batch_input is False:
+            var_names = midped_batch.keys()
+            lod_var_names = set()
+            lod_offset_names = set()
+            # midped_batch is dict type for single input
+            for name in var_names:
+                lod_offset_name = "{}.lod".format(name)
+                if lod_offset_name in var_names:
+                    _LOGGER.debug("(log_id={}) {} {} is LodTensor".format(
+                        typical_logid, op_info_prefix, name))
+                    lod_var_names.add(name)
+                    lod_offset_names.add(lod_offset_name)
+
+            for idx, data_id in enumerate(data_ids):
+                midped_data_dict[data_id] = {}
+
+            for name, value in midped_batch.items():
+                if name in lod_offset_names:
+                    continue
+                if name in lod_var_names:
+                    # lodtensor
+                    lod_offset_name = "{}.lod".format(name)
+                    lod_offset = midped_batch[lod_offset_name]
+                    for idx, data_id in enumerate(data_ids):
+                        data_offset_left = input_offset_dict[data_id][0]
+                        data_offset_right = input_offset_dict[data_id][1]
+                        lod_offset_left = lod_offset[data_offset_left]
+                        lod_offset_right = lod_offset[data_offset_right]
+                        midped_data_dict[data_id][name] = value[
+                            lod_offset_left:lod_offset_right]
+                        midped_data_dict[data_id][lod_offset_name] = \
+                            lod_offset[data_offset_left:data_offset_right + 1] - lod_offset[data_offset_left]
+                else:
+                    # normal tensor
+                    for idx, data_id in enumerate(data_ids):
+                        start = input_offset_dict[data_id][0]
+                        end = input_offset_dict[data_id][1]
+                        midped_data_dict[data_id][name] = value[start:end]
+        else:
+            # midped_batch is list type for batch input
+            for idx, data_id in enumerate(data_ids):
+                start = input_offset_dict[data_id][0]
+                end = input_offset_dict[data_id][1]
+                midped_data_dict[data_id] = midped_batch[start:end]
         return midped_data_dict, err_channeldata_dict
 
     def _run_postprocess(self, parsed_data_dict, midped_data_dict,
...
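For reference, the input_offset_dict bookkeeping in _run_process records the [start, end) slice each data_id occupies in the merged batch; after inference the same offsets split the fetch tensors (dict input) or the list of per-mini-batch results (list input) back apart. A small worked sketch of the normal-tensor path, with made-up ids and values:

    import numpy as np

    # Two requests merged into one 5-sample batch during _run_process.
    input_offset_dict = {101: [0, 3], 102: [3, 5]}
    midped_batch = {"score": np.arange(5, dtype="float32")}

    midped_data_dict = {}
    for data_id, (start, end) in input_offset_dict.items():
        midped_data_dict[data_id] = {
            name: value[start:end] for name, value in midped_batch.items()
        }

    assert midped_data_dict[101]["score"].tolist() == [0.0, 1.0, 2.0]
    assert midped_data_dict[102]["score"].tolist() == [3.0, 4.0]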