diff --git a/ppocr/postprocess/__init__.py b/ppocr/postprocess/__init__.py index 7c0c7fd003a38966a24fd116d8cfd3805aed6797..8f41a005f5b90e7edf11fad80b9b7eac89257160 100644 --- a/ppocr/postprocess/__init__.py +++ b/ppocr/postprocess/__init__.py @@ -34,6 +34,7 @@ from .pg_postprocess import PGPostProcess from .vqa_token_ser_layoutlm_postprocess import VQASerTokenLayoutLMPostProcess, DistillationSerPostProcess from .vqa_token_re_layoutlm_postprocess import VQAReTokenLayoutLMPostProcess, DistillationRePostProcess from .table_postprocess import TableMasterLabelDecode, TableLabelDecode +from .picodet_postprocess import PicoDetPostProcess def build_post_process(config, global_config=None): @@ -47,7 +48,7 @@ def build_post_process(config, global_config=None): 'DistillationSARLabelDecode', 'ViTSTRLabelDecode', 'ABINetLabelDecode', 'TableMasterLabelDecode', 'SPINLabelDecode', 'DistillationSerPostProcess', 'DistillationRePostProcess', - 'VLLabelDecode' + 'VLLabelDecode', 'PicoDetPostProcess' ] if config['name'] == 'PSEPostProcess': diff --git a/ppocr/postprocess/picodet_postprocess.py b/ppocr/postprocess/picodet_postprocess.py new file mode 100644 index 0000000000000000000000000000000000000000..1a0aeb4387ea4778c1c6bec910262f1c4e136084 --- /dev/null +++ b/ppocr/postprocess/picodet_postprocess.py @@ -0,0 +1,250 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import numpy as np +from scipy.special import softmax + + +def hard_nms(box_scores, iou_threshold, top_k=-1, candidate_size=200): + """ + Args: + box_scores (N, 5): boxes in corner-form and probabilities. + iou_threshold: intersection over union threshold. + top_k: keep top_k results. If k <= 0, keep all the results. + candidate_size: only consider the candidates with the highest scores. + Returns: + picked: a list of indexes of the kept boxes + """ + scores = box_scores[:, -1] + boxes = box_scores[:, :-1] + picked = [] + indexes = np.argsort(scores) + indexes = indexes[-candidate_size:] + while len(indexes) > 0: + current = indexes[-1] + picked.append(current) + if 0 < top_k == len(picked) or len(indexes) == 1: + break + current_box = boxes[current, :] + indexes = indexes[:-1] + rest_boxes = boxes[indexes, :] + iou = iou_of( + rest_boxes, + np.expand_dims( + current_box, axis=0), ) + indexes = indexes[iou <= iou_threshold] + + return box_scores[picked, :] + + +def iou_of(boxes0, boxes1, eps=1e-5): + """Return intersection-over-union (Jaccard index) of boxes. + Args: + boxes0 (N, 4): ground truth boxes. + boxes1 (N or 1, 4): predicted boxes. + eps: a small number to avoid 0 as denominator. + Returns: + iou (N): IoU values. + """ + overlap_left_top = np.maximum(boxes0[..., :2], boxes1[..., :2]) + overlap_right_bottom = np.minimum(boxes0[..., 2:], boxes1[..., 2:]) + + overlap_area = area_of(overlap_left_top, overlap_right_bottom) + area0 = area_of(boxes0[..., :2], boxes0[..., 2:]) + area1 = area_of(boxes1[..., :2], boxes1[..., 2:]) + return overlap_area / (area0 + area1 - overlap_area + eps) + + +def area_of(left_top, right_bottom): + """Compute the areas of rectangles given two corners. + Args: + left_top (N, 2): left top corner. + right_bottom (N, 2): right bottom corner. + Returns: + area (N): return the area. + """ + hw = np.clip(right_bottom - left_top, 0.0, None) + return hw[..., 0] * hw[..., 1] + + +class PicoDetPostProcess(object): + """ + Args: + input_shape (int): network input image size + ori_shape (int): ori image shape of before padding + scale_factor (float): scale factor of ori image + enable_mkldnn (bool): whether to open MKLDNN + """ + + def __init__(self, + layout_dict_path, + strides=[8, 16, 32, 64], + score_threshold=0.4, + nms_threshold=0.5, + nms_top_k=1000, + keep_top_k=100): + self.labels = self.load_layout_dict(layout_dict_path) + self.strides = strides + self.score_threshold = score_threshold + self.nms_threshold = nms_threshold + self.nms_top_k = nms_top_k + self.keep_top_k = keep_top_k + + def load_layout_dict(self, layout_dict_path): + with open(layout_dict_path, 'r', encoding='utf-8') as fp: + labels = fp.readlines() + return [label.strip('\n') for label in labels] + + def warp_boxes(self, boxes, ori_shape): + """Apply transform to boxes + """ + width, height = ori_shape[1], ori_shape[0] + n = len(boxes) + if n: + # warp points + xy = np.ones((n * 4, 3)) + xy[:, :2] = boxes[:, [0, 1, 2, 3, 0, 3, 2, 1]].reshape( + n * 4, 2) # x1y1, x2y2, x1y2, x2y1 + # xy = xy @ M.T # transform + xy = (xy[:, :2] / xy[:, 2:3]).reshape(n, 8) # rescale + # create new boxes + x = xy[:, [0, 2, 4, 6]] + y = xy[:, [1, 3, 5, 7]] + xy = np.concatenate( + (x.min(1), y.min(1), x.max(1), y.max(1))).reshape(4, n).T + # clip boxes + xy[:, [0, 2]] = xy[:, [0, 2]].clip(0, width) + xy[:, [1, 3]] = xy[:, [1, 3]].clip(0, height) + return xy.astype(np.float32) + else: + return boxes + + def img_info(self, ori_img, img): + origin_shape = ori_img.shape + resize_shape = img.shape + im_scale_y = resize_shape[2] / float(origin_shape[0]) + im_scale_x = resize_shape[3] / float(origin_shape[1]) + scale_factor = np.array([im_scale_y, im_scale_x], dtype=np.float32) + img_shape = np.array(img.shape[2:], dtype=np.float32) + + input_shape = np.array(img).astype('float32').shape[2:] + ori_shape = np.array((img_shape, )).astype('float32') + scale_factor = np.array((scale_factor, )).astype('float32') + return ori_shape, input_shape, scale_factor + + def __call__(self, ori_img, img, preds): + scores, raw_boxes = preds['boxes'], preds['boxes_num'] + batch_size = raw_boxes[0].shape[0] + reg_max = int(raw_boxes[0].shape[-1] / 4 - 1) + out_boxes_num = [] + out_boxes_list = [] + results = [] + ori_shape, input_shape, scale_factor = self.img_info(ori_img, img) + + for batch_id in range(batch_size): + # generate centers + decode_boxes = [] + select_scores = [] + for stride, box_distribute, score in zip(self.strides, raw_boxes, + scores): + box_distribute = box_distribute[batch_id] + score = score[batch_id] + # centers + fm_h = input_shape[0] / stride + fm_w = input_shape[1] / stride + h_range = np.arange(fm_h) + w_range = np.arange(fm_w) + ww, hh = np.meshgrid(w_range, h_range) + ct_row = (hh.flatten() + 0.5) * stride + ct_col = (ww.flatten() + 0.5) * stride + center = np.stack((ct_col, ct_row, ct_col, ct_row), axis=1) + + # box distribution to distance + reg_range = np.arange(reg_max + 1) + box_distance = box_distribute.reshape((-1, reg_max + 1)) + box_distance = softmax(box_distance, axis=1) + box_distance = box_distance * np.expand_dims(reg_range, axis=0) + box_distance = np.sum(box_distance, axis=1).reshape((-1, 4)) + box_distance = box_distance * stride + + # top K candidate + topk_idx = np.argsort(score.max(axis=1))[::-1] + topk_idx = topk_idx[:self.nms_top_k] + center = center[topk_idx] + score = score[topk_idx] + box_distance = box_distance[topk_idx] + + # decode box + decode_box = center + [-1, -1, 1, 1] * box_distance + + select_scores.append(score) + decode_boxes.append(decode_box) + + # nms + bboxes = np.concatenate(decode_boxes, axis=0) + confidences = np.concatenate(select_scores, axis=0) + picked_box_probs = [] + picked_labels = [] + for class_index in range(0, confidences.shape[1]): + probs = confidences[:, class_index] + mask = probs > self.score_threshold + probs = probs[mask] + if probs.shape[0] == 0: + continue + subset_boxes = bboxes[mask, :] + box_probs = np.concatenate( + [subset_boxes, probs.reshape(-1, 1)], axis=1) + box_probs = hard_nms( + box_probs, + iou_threshold=self.nms_threshold, + top_k=self.keep_top_k, ) + picked_box_probs.append(box_probs) + picked_labels.extend([class_index] * box_probs.shape[0]) + + if len(picked_box_probs) == 0: + out_boxes_list.append(np.empty((0, 4))) + out_boxes_num.append(0) + + else: + picked_box_probs = np.concatenate(picked_box_probs) + + # resize output boxes + picked_box_probs[:, :4] = self.warp_boxes( + picked_box_probs[:, :4], ori_shape[batch_id]) + im_scale = np.concatenate([ + scale_factor[batch_id][::-1], scale_factor[batch_id][::-1] + ]) + picked_box_probs[:, :4] /= im_scale + # clas score box + out_boxes_list.append( + np.concatenate( + [ + np.expand_dims( + np.array(picked_labels), + axis=-1), np.expand_dims( + picked_box_probs[:, 4], axis=-1), + picked_box_probs[:, :4] + ], + axis=1)) + out_boxes_num.append(len(picked_labels)) + + out_boxes_list = np.concatenate(out_boxes_list, axis=0) + out_boxes_num = np.asarray(out_boxes_num).astype(np.int32) + + for dt in out_boxes_list: + clsid, bbox, score = int(dt[0]), dt[2:], dt[1] + label = self.labels[clsid] + result = {'bbox': bbox, 'label': label} + results.append(result) + return results diff --git a/ppstructure/layout/predict_layout.py b/ppstructure/layout/predict_layout.py new file mode 100755 index 0000000000000000000000000000000000000000..a108d1005a7ed4b571071c07371c68e182f2dd09 --- /dev/null +++ b/ppstructure/layout/predict_layout.py @@ -0,0 +1,130 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os +import sys + +__dir__ = os.path.dirname(os.path.abspath(__file__)) +sys.path.append(__dir__) +sys.path.append(os.path.abspath(os.path.join(__dir__, '../..'))) + +os.environ["FLAGS_allocator_strategy"] = 'auto_growth' + +import cv2 +import numpy as np +import time + +import tools.infer.utility as utility +from ppocr.data import create_operators, transform +from ppocr.postprocess import build_post_process +from ppocr.utils.logging import get_logger +from ppocr.utils.utility import get_image_file_list, check_and_read_gif +from ppstructure.utility import parse_args +from picodet_postprocess import PicoDetPostProcess + +logger = get_logger() + +class LayoutPredictor(object): + def __init__(self, args): + pre_process_list = [{ + 'Resize': { + 'size': [800, 608] + } + }, { + 'NormalizeImage': { + 'std': [0.229, 0.224, 0.225], + 'mean': [0.485, 0.456, 0.406], + 'scale': '1./255.', + 'order': 'hwc' + } + }, { + 'ToCHWImage': None + }, { + 'KeepKeys': { + 'keep_keys': ['image'] + } + }] + postprocess_params = { + 'name': 'PicoDetPostProcess', + "layout_dict_path": args.layout_dict_path, + "score_threshold": args.score_threshold, + "nms_threshold": args.nms_threshold, + } + + self.preprocess_op = create_operators(pre_process_list) + self.postprocess_op = build_post_process(postprocess_params) + self.predictor, self.input_tensor, self.output_tensors, self.config = \ + utility.create_predictor(args, 'layout', logger) + + def __call__(self, img): + ori_im = img.copy() + data = {'image': img} + data = transform(data, self.preprocess_op) + img = data[0] + + if img is None: + return None, 0 + + img = np.expand_dims(img, axis=0) + img = img.copy() + + preds, elapse = 0, 1 + starttime = time.time() + + self.input_tensor.copy_from_cpu(img) + self.predictor.run() + + np_score_list, np_boxes_list = [], [] + output_names = self.predictor.get_output_names() + num_outs = int(len(output_names) / 2) + for out_idx in range(num_outs): + np_score_list.append( + self.predictor.get_output_handle(output_names[out_idx]) + .copy_to_cpu()) + np_boxes_list.append( + self.predictor.get_output_handle(output_names[ + out_idx + num_outs]).copy_to_cpu()) + preds = dict(boxes=np_score_list, boxes_num=np_boxes_list) + + post_preds = self.postprocess_op(ori_im, img, preds) + elapse = time.time() - starttime + return post_preds, elapse + + +def main(args): + image_file_list = get_image_file_list(args.image_dir) + layout_predictor = LayoutPredictor(args) + count = 0 + total_time = 0 + + repeats = 50 + for image_file in image_file_list: + img, flag = check_and_read_gif(image_file) + if not flag: + img = cv2.imread(image_file) + if img is None: + logger.info("error in loading image:{}".format(image_file)) + continue + + layout_res, elapse = layout_predictor(img) + + logger.info("result: {}".format(layout_res)) + + if count > 0: + total_time += elapse + count += 1 + logger.info("Predict time of {}: {}".format(image_file, elapse)) + + +if __name__ == "__main__": + main(parse_args()) diff --git a/ppstructure/utility.py b/ppstructure/utility.py index af0616239b167ff9ca5f6e1222015d51338d6bab..80090fd9d5d936a36d3fa613d3a2df760f6c0ee8 100644 --- a/ppstructure/utility.py +++ b/ppstructure/utility.py @@ -32,6 +32,19 @@ def init_args(): type=str, default="../ppocr/utils/dict/table_structure_dict.txt") # params for layout + parser.add_argument("--layout_model_dir", type=str) + parser.add_argument( + "--layout_dict_path", + type=str, + default="../../ppocr/utils/dict/layout_pubalynet_dict.txt") + parser.add_argument( + "--score_threshold", + type=float, + default=0.5, + help="Threshold of score.") + parser.add_argument( + "--nms_threshold", type=float, default=0.5, help="Threshold of nms.") + parser.add_argument( "--layout_path_model", type=str, @@ -87,7 +100,7 @@ def draw_structure_result(image, result, font_path): image = Image.fromarray(image) boxes, txts, scores = [], [], [] for region in result: - if region['type'] == 'Table': + if region['type'] == 'table': pass else: for text_result in region['res']: