未验证 提交 7b759798 编写于 作者: M MissPenguin 提交者: GitHub

Merge pull request #2604 from littletomatodonkey/2.1/fix_db_post

add py impl of pr 2574
...@@ -34,12 +34,18 @@ class DBPostProcess(object): ...@@ -34,12 +34,18 @@ class DBPostProcess(object):
max_candidates=1000, max_candidates=1000,
unclip_ratio=2.0, unclip_ratio=2.0,
use_dilation=False, use_dilation=False,
score_mode="fast",
**kwargs): **kwargs):
self.thresh = thresh self.thresh = thresh
self.box_thresh = box_thresh self.box_thresh = box_thresh
self.max_candidates = max_candidates self.max_candidates = max_candidates
self.unclip_ratio = unclip_ratio self.unclip_ratio = unclip_ratio
self.min_size = 3 self.min_size = 3
self.score_mode = score_mode
assert score_mode in [
"slow", "fast"
], "Score mode must be in [slow, fast] but got: {}".format(score_mode)
self.dilation_kernel = None if not use_dilation else np.array( self.dilation_kernel = None if not use_dilation else np.array(
[[1, 1], [1, 1]]) [[1, 1], [1, 1]])
...@@ -69,7 +75,10 @@ class DBPostProcess(object): ...@@ -69,7 +75,10 @@ class DBPostProcess(object):
if sside < self.min_size: if sside < self.min_size:
continue continue
points = np.array(points) points = np.array(points)
score = self.box_score_fast(pred, points.reshape(-1, 2)) if self.score_mode == "fast":
score = self.box_score_fast(pred, points.reshape(-1, 2))
else:
score = self.box_score_slow(pred, contour)
if self.box_thresh > score: if self.box_thresh > score:
continue continue
...@@ -120,6 +129,9 @@ class DBPostProcess(object): ...@@ -120,6 +129,9 @@ class DBPostProcess(object):
return box, min(bounding_box[1]) return box, min(bounding_box[1])
def box_score_fast(self, bitmap, _box): def box_score_fast(self, bitmap, _box):
'''
box_score_fast: use bbox mean score as the mean score
'''
h, w = bitmap.shape[:2] h, w = bitmap.shape[:2]
box = _box.copy() box = _box.copy()
xmin = np.clip(np.floor(box[:, 0].min()).astype(np.int), 0, w - 1) xmin = np.clip(np.floor(box[:, 0].min()).astype(np.int), 0, w - 1)
...@@ -133,6 +145,27 @@ class DBPostProcess(object): ...@@ -133,6 +145,27 @@ class DBPostProcess(object):
cv2.fillPoly(mask, box.reshape(1, -1, 2).astype(np.int32), 1) cv2.fillPoly(mask, box.reshape(1, -1, 2).astype(np.int32), 1)
return cv2.mean(bitmap[ymin:ymax + 1, xmin:xmax + 1], mask)[0] return cv2.mean(bitmap[ymin:ymax + 1, xmin:xmax + 1], mask)[0]
def box_score_slow(self, bitmap, contour):
'''
box_score_slow: use polyon mean score as the mean score
'''
h, w = bitmap.shape[:2]
contour = contour.copy()
contour = np.reshape(contour, (-1, 2))
xmin = np.clip(np.min(contour[:, 0]), 0, w - 1)
xmax = np.clip(np.max(contour[:, 0]), 0, w - 1)
ymin = np.clip(np.min(contour[:, 1]), 0, h - 1)
ymax = np.clip(np.max(contour[:, 1]), 0, h - 1)
mask = np.zeros((ymax - ymin + 1, xmax - xmin + 1), dtype=np.uint8)
contour[:, 0] = contour[:, 0] - xmin
contour[:, 1] = contour[:, 1] - ymin
cv2.fillPoly(mask, contour.reshape(1, -1, 2).astype(np.int32), 1)
return cv2.mean(bitmap[ymin:ymax + 1, xmin:xmax + 1], mask)[0]
def __call__(self, outs_dict, shape_list): def __call__(self, outs_dict, shape_list):
pred = outs_dict['maps'] pred = outs_dict['maps']
if isinstance(pred, paddle.Tensor): if isinstance(pred, paddle.Tensor):
......
...@@ -62,6 +62,7 @@ class TextDetector(object): ...@@ -62,6 +62,7 @@ class TextDetector(object):
postprocess_params["max_candidates"] = 1000 postprocess_params["max_candidates"] = 1000
postprocess_params["unclip_ratio"] = args.det_db_unclip_ratio postprocess_params["unclip_ratio"] = args.det_db_unclip_ratio
postprocess_params["use_dilation"] = args.use_dilation postprocess_params["use_dilation"] = args.use_dilation
postprocess_params["score_mode"] = args.det_db_score_mode
elif self.det_algorithm == "EAST": elif self.det_algorithm == "EAST":
postprocess_params['name'] = 'EASTPostProcess' postprocess_params['name'] = 'EASTPostProcess'
postprocess_params["score_thresh"] = args.det_east_score_thresh postprocess_params["score_thresh"] = args.det_east_score_thresh
......
...@@ -48,6 +48,7 @@ def parse_args(): ...@@ -48,6 +48,7 @@ def parse_args():
parser.add_argument("--det_db_unclip_ratio", type=float, default=1.6) parser.add_argument("--det_db_unclip_ratio", type=float, default=1.6)
parser.add_argument("--max_batch_size", type=int, default=10) parser.add_argument("--max_batch_size", type=int, default=10)
parser.add_argument("--use_dilation", type=bool, default=False) parser.add_argument("--use_dilation", type=bool, default=False)
parser.add_argument("--det_db_score_mode", type=str, default="fast")
# EAST parmas # EAST parmas
parser.add_argument("--det_east_score_thresh", type=float, default=0.8) parser.add_argument("--det_east_score_thresh", type=float, default=0.8)
parser.add_argument("--det_east_cover_thresh", type=float, default=0.1) parser.add_argument("--det_east_cover_thresh", type=float, default=0.1)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册