diff --git a/configs/det/det_r50_db++_icdar15.yml b/configs/det/det_r50_db++_icdar15.yml
index e0cd6012b660573a79ff013a1b6e2309074a3d86..2bb2cb8fd6cc999541cd10df7264ef09445295f4 100644
--- a/configs/det/det_r50_db++_icdar15.yml
+++ b/configs/det/det_r50_db++_icdar15.yml
@@ -54,6 +54,7 @@ PostProcess:
   box_thresh: 0.6
   max_candidates: 1000
   unclip_ratio: 1.5
+  det_box_type: 'quad' # 'quad' or 'poly'
 Metric:
   name: DetMetric
   main_indicator: hmean
diff --git a/configs/det/det_r50_db++_td_tr.yml b/configs/det/det_r50_db++_td_tr.yml
index 65021bb66184381ba732980ac1b7a65d7bd3a355..f3b02aa21de225b99c9a4ac81d6b6a6bd898753c 100644
--- a/configs/det/det_r50_db++_td_tr.yml
+++ b/configs/det/det_r50_db++_td_tr.yml
@@ -54,6 +54,7 @@ PostProcess:
   box_thresh: 0.5
   max_candidates: 1000
   unclip_ratio: 1.5
+  det_box_type: 'quad' # 'quad' or 'poly'
 Metric:
   name: DetMetric
   main_indicator: hmean
diff --git a/ppocr/postprocess/db_postprocess.py b/ppocr/postprocess/db_postprocess.py
index 5e2553c3a09f8359d1641d2d49b1bfb84df695ac..cbe7bf4bd6ba33e13c0b0718ff2a585014ab6acc 100755
--- a/ppocr/postprocess/db_postprocess.py
+++ b/ppocr/postprocess/db_postprocess.py
@@ -38,7 +38,7 @@ class DBPostProcess(object):
                  unclip_ratio=2.0,
                  use_dilation=False,
                  score_mode="fast",
-                 use_polygon=False,
+                 box_type='quad',
                  **kwargs):
         self.thresh = thresh
         self.box_thresh = box_thresh
@@ -46,7 +46,7 @@ class DBPostProcess(object):
         self.unclip_ratio = unclip_ratio
         self.min_size = 3
         self.score_mode = score_mode
-        self.use_polygon = use_polygon
+        self.box_type = box_type
         assert score_mode in [
             "slow", "fast"
         ], "Score mode must be in [slow, fast] but got: {}".format(score_mode)
@@ -233,12 +233,14 @@ class DBPostProcess(object):
                     self.dilation_kernel)
             else:
                 mask = segmentation[batch_index]
-            if self.use_polygon is True:
+            if self.box_type == 'poly':
                 boxes, scores = self.polygons_from_bitmap(pred[batch_index],
                                                           mask, src_w, src_h)
-            else:
+            elif self.box_type == 'quad':
                 boxes, scores = self.boxes_from_bitmap(pred[batch_index], mask,
                                                        src_w, src_h)
+            else:
+                raise ValueError("box_type can only be one of ['quad', 'poly']")
 
             boxes_batch.append({'points': boxes})
         return boxes_batch
@@ -254,7 +256,7 @@ class DistillationDBPostProcess(object):
                  unclip_ratio=1.5,
                  use_dilation=False,
                  score_mode="fast",
-                 use_polygon=False,
+                 box_type='quad',
                  **kwargs):
         self.model_name = model_name
         self.key = key
@@ -265,7 +267,7 @@ class DistillationDBPostProcess(object):
             unclip_ratio=unclip_ratio,
             use_dilation=use_dilation,
             score_mode=score_mode,
-            use_polygon=use_polygon)
+            box_type=box_type)
 
     def __call__(self, predicts, shape_list):
         results = {}
diff --git a/tools/infer/predict_det.py b/tools/infer/predict_det.py
index 98116afd9a21f297d429c26fc88881986c628866..fdfe5e5731cc2fcc5eb82fea373d1af08ac7c85d 100755
--- a/tools/infer/predict_det.py
+++ b/tools/infer/predict_det.py
@@ -67,7 +67,7 @@ class TextDetector(object):
             postprocess_params["unclip_ratio"] = args.det_db_unclip_ratio
             postprocess_params["use_dilation"] = args.use_dilation
             postprocess_params["score_mode"] = args.det_db_score_mode
-            postprocess_params["use_polygon"] = args.det_use_polygon
+            postprocess_params["box_type"] = args.det_box_type
         elif self.det_algorithm == "DB++":
             postprocess_params['name'] = 'DBPostProcess'
             postprocess_params["thresh"] = args.det_db_thresh
@@ -76,7 +76,7 @@ class TextDetector(object):
             postprocess_params["unclip_ratio"] = args.det_db_unclip_ratio
             postprocess_params["use_dilation"] = args.use_dilation
             postprocess_params["score_mode"] = args.det_db_score_mode
-            postprocess_params["use_polygon"] = args.det_use_polygon
+            postprocess_params["box_type"] = args.det_box_type
             pre_process_list[1] = {
                 'NormalizeImage': {
                     'std': [1.0, 1.0, 1.0],
@@ -100,8 +100,8 @@ class TextDetector(object):
             postprocess_params['name'] = 'SASTPostProcess'
             postprocess_params["score_thresh"] = args.det_sast_score_thresh
             postprocess_params["nms_thresh"] = args.det_sast_nms_thresh
-            self.det_sast_polygon = args.det_sast_polygon
-            if self.det_sast_polygon:
+
+            if args.det_box_type == 'poly':
                 postprocess_params["sample_pts_num"] = 6
                 postprocess_params["expand_scale"] = 1.2
                 postprocess_params["shrink_ratio_of_width"] = 0.2
@@ -109,14 +109,14 @@ class TextDetector(object):
                 postprocess_params["sample_pts_num"] = 2
                 postprocess_params["expand_scale"] = 1.0
                 postprocess_params["shrink_ratio_of_width"] = 0.3
+
         elif self.det_algorithm == "PSE":
             postprocess_params['name'] = 'PSEPostProcess'
             postprocess_params["thresh"] = args.det_pse_thresh
             postprocess_params["box_thresh"] = args.det_pse_box_thresh
             postprocess_params["min_area"] = args.det_pse_min_area
-            postprocess_params["box_type"] = args.det_pse_box_type
+            postprocess_params["box_type"] = args.det_box_type
             postprocess_params["scale"] = args.det_pse_scale
-            self.det_pse_box_type = args.det_pse_box_type
         elif self.det_algorithm == "FCE":
             pre_process_list[0] = {
                 'DetResizeForTest': {
@@ -128,7 +128,7 @@ class TextDetector(object):
             postprocess_params["alpha"] = args.alpha
             postprocess_params["beta"] = args.beta
             postprocess_params["fourier_degree"] = args.fourier_degree
-            postprocess_params["box_type"] = args.det_fce_box_type
+            postprocess_params["box_type"] = args.det_box_type
         elif self.det_algorithm == "CT":
             pre_process_list[0] = {'ScaleAlignedShort': {'short_size': 640}}
             postprocess_params['name'] = 'CTPostProcess'
@@ -269,11 +269,7 @@ class TextDetector(object):
 
         post_result = self.postprocess_op(preds, shape_list)
         dt_boxes = post_result[0]['points']
-        if (self.det_algorithm == "SAST" and self.det_sast_polygon) or (
-                self.det_algorithm in ["PSE", "FCE", "CT"] and
-                self.postprocess_op.box_type == 'poly'):
-            dt_boxes = self.filter_tag_det_res_only_clip(dt_boxes, ori_im.shape)
-        elif 'DB' in self.det_algorithm and self.postprocess_op.use_polygon is True:
+        if self.args.det_box_type == 'poly':
             dt_boxes = self.filter_tag_det_res_only_clip(dt_boxes, ori_im.shape)
         else:
             dt_boxes = self.filter_tag_det_res(dt_boxes, ori_im.shape)
diff --git a/tools/infer/utility.py b/tools/infer/utility.py
index 30ae712d1772503e60b65afbaffe3151bb3be9f3..f6a44e35a5b303d6ed30bf8057a62409aa690fef 100644
--- a/tools/infer/utility.py
+++ b/tools/infer/utility.py
@@ -50,6 +50,7 @@ def init_args():
     parser.add_argument("--det_model_dir", type=str)
     parser.add_argument("--det_limit_side_len", type=float, default=960)
     parser.add_argument("--det_limit_type", type=str, default='max')
+    parser.add_argument("--det_box_type", type=str, default='quad')
 
     # DB parmas
     parser.add_argument("--det_db_thresh", type=float, default=0.3)
@@ -58,7 +59,7 @@ def init_args():
     parser.add_argument("--max_batch_size", type=int, default=10)
     parser.add_argument("--use_dilation", type=str2bool, default=False)
     parser.add_argument("--det_db_score_mode", type=str, default="fast")
-    parser.add_argument("--det_use_polygon", type=str2bool, default=False)
+
     # EAST parmas
     parser.add_argument("--det_east_score_thresh", type=float, default=0.8)
     parser.add_argument("--det_east_cover_thresh", type=float, default=0.1)
@@ -67,13 +68,11 @@ def init_args():
     # SAST parmas
     parser.add_argument("--det_sast_score_thresh", type=float, default=0.5)
     parser.add_argument("--det_sast_nms_thresh", type=float, default=0.2)
-    parser.add_argument("--det_sast_polygon", type=str2bool, default=False)
 
     # PSE parmas
     parser.add_argument("--det_pse_thresh", type=float, default=0)
     parser.add_argument("--det_pse_box_thresh", type=float, default=0.85)
     parser.add_argument("--det_pse_min_area", type=float, default=16)
-    parser.add_argument("--det_pse_box_type", type=str, default='quad')
     parser.add_argument("--det_pse_scale", type=int, default=1)
 
     # FCE parmas
@@ -81,7 +80,6 @@ def init_args():
     parser.add_argument("--alpha", type=float, default=1.0)
     parser.add_argument("--beta", type=float, default=1.0)
     parser.add_argument("--fourier_degree", type=int, default=5)
-    parser.add_argument("--det_fce_box_type", type=str, default='poly')
 
     # params for text recognizer
     parser.add_argument("--rec_algorithm", type=str, default='SVTR_LCNet')