diff --git a/configs/det/det_r50_db++_icdar15.yml b/configs/det/det_r50_db++_icdar15.yml index e0cd6012b660573a79ff013a1b6e2309074a3d86..2bb2cb8fd6cc999541cd10df7264ef09445295f4 100644 --- a/configs/det/det_r50_db++_icdar15.yml +++ b/configs/det/det_r50_db++_icdar15.yml @@ -54,6 +54,7 @@ PostProcess: box_thresh: 0.6 max_candidates: 1000 unclip_ratio: 1.5 + det_box_type: 'quad' # 'quad' or 'poly' Metric: name: DetMetric main_indicator: hmean diff --git a/configs/det/det_r50_db++_td_tr.yml b/configs/det/det_r50_db++_td_tr.yml index 65021bb66184381ba732980ac1b7a65d7bd3a355..f3b02aa21de225b99c9a4ac81d6b6a6bd898753c 100644 --- a/configs/det/det_r50_db++_td_tr.yml +++ b/configs/det/det_r50_db++_td_tr.yml @@ -54,6 +54,7 @@ PostProcess: box_thresh: 0.5 max_candidates: 1000 unclip_ratio: 1.5 + det_box_type: 'quad' # 'quad' or 'poly' Metric: name: DetMetric main_indicator: hmean diff --git a/ppocr/postprocess/db_postprocess.py b/ppocr/postprocess/db_postprocess.py index 5e2553c3a09f8359d1641d2d49b1bfb84df695ac..dfe107816c195b36bf06568843b008bf66ff24c7 100755 --- a/ppocr/postprocess/db_postprocess.py +++ b/ppocr/postprocess/db_postprocess.py @@ -38,7 +38,7 @@ class DBPostProcess(object): unclip_ratio=2.0, use_dilation=False, score_mode="fast", - use_polygon=False, + box_type='quad', **kwargs): self.thresh = thresh self.box_thresh = box_thresh @@ -46,7 +46,7 @@ class DBPostProcess(object): self.unclip_ratio = unclip_ratio self.min_size = 3 self.score_mode = score_mode - self.use_polygon = use_polygon + self.box_type = box_type assert score_mode in [ "slow", "fast" ], "Score mode must be in [slow, fast] but got: {}".format(score_mode) @@ -233,12 +233,14 @@ class DBPostProcess(object): self.dilation_kernel) else: mask = segmentation[batch_index] - if self.use_polygon is True: + if self.box_type == 'poly': boxes, scores = self.polygons_from_bitmap(pred[batch_index], mask, src_w, src_h) - else: + elif self.box_type == 'quad': boxes, scores = self.boxes_from_bitmap(pred[batch_index], mask, src_w, src_h) + else: + raise ValueError("box_type can only be one of ['quad', 'poly']") boxes_batch.append({'points': boxes}) return boxes_batch @@ -254,7 +256,7 @@ class DistillationDBPostProcess(object): unclip_ratio=1.5, use_dilation=False, score_mode="fast", - use_polygon=False, + box_type='quad', **kwargs): self.model_name = model_name self.key = key @@ -265,7 +267,7 @@ class DistillationDBPostProcess(object): unclip_ratio=unclip_ratio, use_dilation=use_dilation, score_mode=score_mode, - use_polygon=use_polygon) + box_type=box_type) def __call__(self, predicts, shape_list): results = {} diff --git a/tools/infer/predict_det.py b/tools/infer/predict_det.py index 52c225d2b3913cf8c0dc88abcc07f7ccfd3cc914..1b4446a6717bccdc5b3de4ba70e058885479be84 100755 --- a/tools/infer/predict_det.py +++ b/tools/infer/predict_det.py @@ -67,6 +67,7 @@ class TextDetector(object): postprocess_params["unclip_ratio"] = args.det_db_unclip_ratio postprocess_params["use_dilation"] = args.use_dilation postprocess_params["score_mode"] = args.det_db_score_mode + postprocess_params["box_type"] = args.det_box_type elif self.det_algorithm == "DB++": postprocess_params['name'] = 'DBPostProcess' postprocess_params["thresh"] = args.det_db_thresh @@ -75,6 +76,7 @@ class TextDetector(object): postprocess_params["unclip_ratio"] = args.det_db_unclip_ratio postprocess_params["use_dilation"] = args.use_dilation postprocess_params["score_mode"] = args.det_db_score_mode + postprocess_params["box_type"] = args.det_box_type pre_process_list[1] = { 'NormalizeImage': { 'std': [1.0, 1.0, 1.0], @@ -98,8 +100,8 @@ class TextDetector(object): postprocess_params['name'] = 'SASTPostProcess' postprocess_params["score_thresh"] = args.det_sast_score_thresh postprocess_params["nms_thresh"] = args.det_sast_nms_thresh - self.det_sast_polygon = args.det_sast_polygon - if self.det_sast_polygon: + + if args.det_box_type == 'poly': postprocess_params["sample_pts_num"] = 6 postprocess_params["expand_scale"] = 1.2 postprocess_params["shrink_ratio_of_width"] = 0.2 @@ -107,14 +109,14 @@ class TextDetector(object): postprocess_params["sample_pts_num"] = 2 postprocess_params["expand_scale"] = 1.0 postprocess_params["shrink_ratio_of_width"] = 0.3 + elif self.det_algorithm == "PSE": postprocess_params['name'] = 'PSEPostProcess' postprocess_params["thresh"] = args.det_pse_thresh postprocess_params["box_thresh"] = args.det_pse_box_thresh postprocess_params["min_area"] = args.det_pse_min_area - postprocess_params["box_type"] = args.det_pse_box_type + postprocess_params["box_type"] = args.det_box_type postprocess_params["scale"] = args.det_pse_scale - self.det_pse_box_type = args.det_pse_box_type elif self.det_algorithm == "FCE": pre_process_list[0] = { 'DetResizeForTest': { @@ -126,7 +128,7 @@ class TextDetector(object): postprocess_params["alpha"] = args.alpha postprocess_params["beta"] = args.beta postprocess_params["fourier_degree"] = args.fourier_degree - postprocess_params["box_type"] = args.det_fce_box_type + postprocess_params["box_type"] = args.det_box_type elif self.det_algorithm == "CT": pre_process_list[0] = {'ScaleAlignedShort': {'short_size': 640}} postprocess_params['name'] = 'CTPostProcess' @@ -190,6 +192,8 @@ class TextDetector(object): img_height, img_width = image_shape[0:2] dt_boxes_new = [] for box in dt_boxes: + if type(box) is list: + box = np.array(box) box = self.order_points_clockwise(box) box = self.clip_det_res(box, img_height, img_width) rect_width = int(np.linalg.norm(box[0] - box[1])) @@ -204,6 +208,8 @@ class TextDetector(object): img_height, img_width = image_shape[0:2] dt_boxes_new = [] for box in dt_boxes: + if type(box) is list: + box = np.array(box) box = self.clip_det_res(box, img_height, img_width) dt_boxes_new.append(box) dt_boxes = np.array(dt_boxes_new) @@ -262,12 +268,10 @@ class TextDetector(object): else: raise NotImplementedError - #self.predictor.try_shrink_memory() post_result = self.postprocess_op(preds, shape_list) dt_boxes = post_result[0]['points'] - if (self.det_algorithm == "SAST" and self.det_sast_polygon) or ( - self.det_algorithm in ["PSE", "FCE", "CT"] and - self.postprocess_op.box_type == 'poly'): + + if self.args.det_box_type == 'poly': dt_boxes = self.filter_tag_det_res_only_clip(dt_boxes, ori_im.shape) else: dt_boxes = self.filter_tag_det_res(dt_boxes, ori_im.shape) diff --git a/tools/infer/utility.py b/tools/infer/utility.py index e555dbec1b314510aaaf6b31f1b35bf60fefa98e..f6a44e35a5b303d6ed30bf8057a62409aa690fef 100644 --- a/tools/infer/utility.py +++ b/tools/infer/utility.py @@ -50,6 +50,7 @@ def init_args(): parser.add_argument("--det_model_dir", type=str) parser.add_argument("--det_limit_side_len", type=float, default=960) parser.add_argument("--det_limit_type", type=str, default='max') + parser.add_argument("--det_box_type", type=str, default='quad') # DB parmas parser.add_argument("--det_db_thresh", type=float, default=0.3) @@ -58,6 +59,7 @@ def init_args(): parser.add_argument("--max_batch_size", type=int, default=10) parser.add_argument("--use_dilation", type=str2bool, default=False) parser.add_argument("--det_db_score_mode", type=str, default="fast") + # EAST parmas parser.add_argument("--det_east_score_thresh", type=float, default=0.8) parser.add_argument("--det_east_cover_thresh", type=float, default=0.1) @@ -66,13 +68,11 @@ def init_args(): # SAST parmas parser.add_argument("--det_sast_score_thresh", type=float, default=0.5) parser.add_argument("--det_sast_nms_thresh", type=float, default=0.2) - parser.add_argument("--det_sast_polygon", type=str2bool, default=False) # PSE parmas parser.add_argument("--det_pse_thresh", type=float, default=0) parser.add_argument("--det_pse_box_thresh", type=float, default=0.85) parser.add_argument("--det_pse_min_area", type=float, default=16) - parser.add_argument("--det_pse_box_type", type=str, default='quad') parser.add_argument("--det_pse_scale", type=int, default=1) # FCE parmas @@ -80,7 +80,6 @@ def init_args(): parser.add_argument("--alpha", type=float, default=1.0) parser.add_argument("--beta", type=float, default=1.0) parser.add_argument("--fourier_degree", type=int, default=5) - parser.add_argument("--det_fce_box_type", type=str, default='poly') # params for text recognizer parser.add_argument("--rec_algorithm", type=str, default='SVTR_LCNet')