Commit ff7774a4 authored by TeslaZhao

Support mini-batch in Pipeline mode

Parent 567dc666
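This change lets RecOp split the detected text boxes into mini-batches inside the Pipeline web service: preprocess may now return a list of feed dicts (one per mini-batch) instead of a single dict, and postprocess accepts either a single fetch dict or a list of per-batch fetch dicts. Below is a minimal sketch of that contract, assuming the Pipeline framework runs the predict stage once per feed dict in the returned list and hands the matching results back as a list; preprocess_sketch, merge_fetch, and decode_one_batch are illustrative placeholders, not part of this commit.

    import numpy as np

    # Sketch only: mirrors the mini-batch return contract introduced by this commit.
    # "norm_imgs" stands in for the normalized crops that the real preprocess builds
    # with ocr_reader.resize_norm_img; here they are just dummy arrays.
    def preprocess_sketch(norm_imgs, max_batch_size=6):
        feed_list = []
        for start in range(0, len(norm_imgs), max_batch_size):
            batch = np.stack(norm_imgs[start:start + max_batch_size])
            feed_list.append({"image": batch.astype("float32")})
        # same 4-tuple shape as RecOp.preprocess returns in the diff below
        return feed_list, False, None, ""

    def merge_fetch(fetch_data, decode_one_batch):
        # postprocess accepts either one fetch dict or a list of per-batch fetch dicts;
        # decode_one_batch stands in for ocr_reader.postprocess on one mini-batch
        res_list = []
        batches = [fetch_data] if isinstance(fetch_data, dict) else fetch_data
        for one_batch in batches:
            res_list.extend(decode_one_batch(one_batch))
        return {"res": str(res_list)}

    # e.g. 8 dummy crops of shape (3, 32, 100) -> 2 mini-batches of 6 and 2 crops
    feeds, _, _, _ = preprocess_sketch([np.zeros((3, 32, 100), dtype="float32")] * 8)
    print(len(feeds), feeds[0]["image"].shape)   # 2 (6, 3, 32, 100)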
@@ -79,6 +79,9 @@ class RecOp(Op):
         feed_list = []
         img_list = []
         max_wh_ratio = 0
+        ## One batch, the type of feed_data is dict.
+        """
         for i, dtbox in enumerate(dt_boxes):
             boximg = self.get_rotate_crop_image(im, dt_boxes[i])
             img_list.append(boximg)
@@ -92,14 +95,73 @@ class RecOp(Op):
             norm_img = self.ocr_reader.resize_norm_img(img, max_wh_ratio)
             imgs[id] = norm_img
         feed = {"image": imgs.copy()}
         return feed, False, None, ""
+        """
 
-    def postprocess(self, input_dicts, fetch_dict, log_id):
-        rec_res = self.ocr_reader.postprocess(fetch_dict, with_score=True)
-        res_lst = []
-        for res in rec_res:
-            res_lst.append(res[0])
-        res = {"res": str(res_lst)}
+        ## Many mini-batches, the type of feed_data is list.
+        max_batch_size = 6  # len(dt_boxes)
+
+        # If max_batch_size is 0, skip the predict stage
+        if max_batch_size == 0:
+            return {}, True, None, ""
+        boxes_size = len(dt_boxes)
+        batch_size = boxes_size // max_batch_size
+        rem = boxes_size % max_batch_size
+        #_LOGGER.info("max_batch_len:{}, batch_size:{}, rem:{}, boxes_size:{}".format(max_batch_size, batch_size, rem, boxes_size))
+        for bt_idx in range(0, batch_size + 1):
+            imgs = None
+            boxes_num_in_one_batch = 0
+            if bt_idx == batch_size:
+                if rem == 0:
+                    continue
+                else:
+                    boxes_num_in_one_batch = rem
+            elif bt_idx < batch_size:
+                boxes_num_in_one_batch = max_batch_size
+            else:
+                _LOGGER.error("batch_size error, bt_idx={}, batch_size={}".
+                              format(bt_idx, batch_size))
+                break
+
+            start = bt_idx * max_batch_size
+            end = start + boxes_num_in_one_batch
+            img_list = []
+            for box_idx in range(start, end):
+                boximg = self.get_rotate_crop_image(im, dt_boxes[box_idx])
+                img_list.append(boximg)
+                h, w = boximg.shape[0:2]
+                wh_ratio = w * 1.0 / h
+                max_wh_ratio = max(max_wh_ratio, wh_ratio)
+            _, w, h = self.ocr_reader.resize_norm_img(img_list[0],
+                                                      max_wh_ratio).shape
+            #_LOGGER.info("---- idx:{}, w:{}, h:{}".format(bt_idx, w, h))
+            imgs = np.zeros((boxes_num_in_one_batch, 3, w, h)).astype('float32')
+            for id, img in enumerate(img_list):
+                norm_img = self.ocr_reader.resize_norm_img(img, max_wh_ratio)
+                imgs[id] = norm_img
+            feed = {"image": imgs.copy()}
+            feed_list.append(feed)
+        #_LOGGER.info("feed_list : {}".format(feed_list))
+
+        return feed_list, False, None, ""
+
+    def postprocess(self, input_dicts, fetch_data, log_id):
+        res_list = []
+        if isinstance(fetch_data, dict):
+            if len(fetch_data) > 0:
+                rec_batch_res = self.ocr_reader.postprocess(
+                    fetch_data, with_score=True)
+                for res in rec_batch_res:
+                    res_list.append(res[0])
+        elif isinstance(fetch_data, list):
+            for one_batch in fetch_data:
+                one_batch_res = self.ocr_reader.postprocess(
+                    one_batch, with_score=True)
+                for res in one_batch_res:
+                    res_list.append(res[0])
+        res = {"res": str(res_list)}
         return res, None, ""
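The loop above splits boxes_size boxes into boxes_size // max_batch_size full mini-batches plus one remainder batch of boxes_size % max_batch_size boxes, handled in the bt_idx == batch_size iteration (and skipped when the count divides evenly). A small standalone sketch of that arithmetic, separate from the serving code:

    # Standalone illustration of the splitting arithmetic used in preprocess above.
    def split_counts(boxes_size, max_batch_size):
        batch_size = boxes_size // max_batch_size   # number of full mini-batches
        rem = boxes_size % max_batch_size           # size of the trailing mini-batch, if any
        counts = [max_batch_size] * batch_size
        if rem:
            counts.append(rem)
        return counts

    assert split_counts(13, 6) == [6, 6, 1]
    assert split_counts(12, 6) == [6, 6]   # rem == 0: no extra batch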