diff --git a/ppocr/postprocess/db_postprocess.py b/ppocr/postprocess/db_postprocess.py index 91729e0a93c8b013cd734abc37cb6ac2b4960312..769ddbe23253ce58e2bccd46ef5074cc2a7d27da 100755 --- a/ppocr/postprocess/db_postprocess.py +++ b/ppocr/postprocess/db_postprocess.py @@ -34,12 +34,18 @@ class DBPostProcess(object): max_candidates=1000, unclip_ratio=2.0, use_dilation=False, + score_mode="fast", **kwargs): self.thresh = thresh self.box_thresh = box_thresh self.max_candidates = max_candidates self.unclip_ratio = unclip_ratio self.min_size = 3 + self.score_mode = score_mode + assert score_mode in [ + "slow", "fast" + ], "Score mode must be in [slow, fast] but got: {}".format(score_mode) + self.dilation_kernel = None if not use_dilation else np.array( [[1, 1], [1, 1]]) @@ -69,7 +75,10 @@ class DBPostProcess(object): if sside < self.min_size: continue points = np.array(points) - score = self.box_score_fast(pred, points.reshape(-1, 2)) + if self.score_mode == "fast": + score = self.box_score_fast(pred, points.reshape(-1, 2)) + else: + score = self.box_score_slow(pred, contour) if self.box_thresh > score: continue @@ -120,6 +129,9 @@ class DBPostProcess(object): return box, min(bounding_box[1]) def box_score_fast(self, bitmap, _box): + ''' + box_score_fast: use bbox mean score as the mean score + ''' h, w = bitmap.shape[:2] box = _box.copy() xmin = np.clip(np.floor(box[:, 0].min()).astype(np.int), 0, w - 1) @@ -133,6 +145,27 @@ class DBPostProcess(object): cv2.fillPoly(mask, box.reshape(1, -1, 2).astype(np.int32), 1) return cv2.mean(bitmap[ymin:ymax + 1, xmin:xmax + 1], mask)[0] + def box_score_slow(self, bitmap, contour): + ''' + box_score_slow: use polyon mean score as the mean score + ''' + h, w = bitmap.shape[:2] + contour = contour.copy() + contour = np.reshape(contour, (-1, 2)) + + xmin = np.clip(np.min(contour[:, 0]), 0, w - 1) + xmax = np.clip(np.max(contour[:, 0]), 0, w - 1) + ymin = np.clip(np.min(contour[:, 1]), 0, h - 1) + ymax = np.clip(np.max(contour[:, 1]), 0, h - 1) + + mask = np.zeros((ymax - ymin + 1, xmax - xmin + 1), dtype=np.uint8) + + contour[:, 0] = contour[:, 0] - xmin + contour[:, 1] = contour[:, 1] - ymin + + cv2.fillPoly(mask, contour.reshape(1, -1, 2).astype(np.int32), 1) + return cv2.mean(bitmap[ymin:ymax + 1, xmin:xmax + 1], mask)[0] + def __call__(self, outs_dict, shape_list): pred = outs_dict['maps'] if isinstance(pred, paddle.Tensor): diff --git a/tools/infer/predict_det.py b/tools/infer/predict_det.py index f5ea0504f97f3e40853d431061f7086653f2628e..717a09ecdad9362b7ea6e556c1136f5791d2f06c 100755 --- a/tools/infer/predict_det.py +++ b/tools/infer/predict_det.py @@ -62,6 +62,7 @@ class TextDetector(object): postprocess_params["max_candidates"] = 1000 postprocess_params["unclip_ratio"] = args.det_db_unclip_ratio postprocess_params["use_dilation"] = args.use_dilation + postprocess_params["score_mode"] = args.det_db_score_mode elif self.det_algorithm == "EAST": postprocess_params['name'] = 'EASTPostProcess' postprocess_params["score_thresh"] = args.det_east_score_thresh diff --git a/tools/infer/utility.py b/tools/infer/utility.py index 9fa51d80b41ae80abe206ee779accd44d49cebba..b5fe3ba9813b319d5baf7a9cf2ed0cb655e12021 100755 --- a/tools/infer/utility.py +++ b/tools/infer/utility.py @@ -48,6 +48,7 @@ def parse_args(): parser.add_argument("--det_db_unclip_ratio", type=float, default=1.6) parser.add_argument("--max_batch_size", type=int, default=10) parser.add_argument("--use_dilation", type=bool, default=False) + parser.add_argument("--det_db_score_mode", type=str, default="fast") # EAST parmas parser.add_argument("--det_east_score_thresh", type=float, default=0.8) parser.add_argument("--det_east_cover_thresh", type=float, default=0.1)