未验证 提交 9df7730e 编写于 作者: D Double_V 提交者: GitHub

Merge pull request #7840 from LDOUBLEV/dygraph

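Add polygon params: replace the per-algorithm polygon switches (`use_polygon`, `--det_sast_polygon`, `--det_pse_box_type`, `--det_fce_box_type`) with a single `box_type` / `--det_box_type` option whose value is 'quad' or 'poly'.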
...@@ -54,6 +54,7 @@ PostProcess: ...@@ -54,6 +54,7 @@ PostProcess:
box_thresh: 0.6 box_thresh: 0.6
max_candidates: 1000 max_candidates: 1000
unclip_ratio: 1.5 unclip_ratio: 1.5
det_box_type: 'quad' # 'quad' or 'poly'
Metric: Metric:
name: DetMetric name: DetMetric
main_indicator: hmean main_indicator: hmean
......
...@@ -54,6 +54,7 @@ PostProcess: ...@@ -54,6 +54,7 @@ PostProcess:
box_thresh: 0.5 box_thresh: 0.5
max_candidates: 1000 max_candidates: 1000
unclip_ratio: 1.5 unclip_ratio: 1.5
det_box_type: 'quad' # 'quad' or 'poly'
Metric: Metric:
name: DetMetric name: DetMetric
main_indicator: hmean main_indicator: hmean
......
...@@ -38,7 +38,7 @@ class DBPostProcess(object): ...@@ -38,7 +38,7 @@ class DBPostProcess(object):
unclip_ratio=2.0, unclip_ratio=2.0,
use_dilation=False, use_dilation=False,
score_mode="fast", score_mode="fast",
use_polygon=False, box_type='quad',
**kwargs): **kwargs):
self.thresh = thresh self.thresh = thresh
self.box_thresh = box_thresh self.box_thresh = box_thresh
...@@ -46,7 +46,7 @@ class DBPostProcess(object): ...@@ -46,7 +46,7 @@ class DBPostProcess(object):
self.unclip_ratio = unclip_ratio self.unclip_ratio = unclip_ratio
self.min_size = 3 self.min_size = 3
self.score_mode = score_mode self.score_mode = score_mode
self.use_polygon = use_polygon self.box_type = box_type
assert score_mode in [ assert score_mode in [
"slow", "fast" "slow", "fast"
], "Score mode must be in [slow, fast] but got: {}".format(score_mode) ], "Score mode must be in [slow, fast] but got: {}".format(score_mode)
...@@ -233,12 +233,14 @@ class DBPostProcess(object): ...@@ -233,12 +233,14 @@ class DBPostProcess(object):
self.dilation_kernel) self.dilation_kernel)
else: else:
mask = segmentation[batch_index] mask = segmentation[batch_index]
if self.use_polygon is True: if self.box_type == 'poly':
boxes, scores = self.polygons_from_bitmap(pred[batch_index], boxes, scores = self.polygons_from_bitmap(pred[batch_index],
mask, src_w, src_h) mask, src_w, src_h)
else: elif self.box_type == 'quad':
boxes, scores = self.boxes_from_bitmap(pred[batch_index], mask, boxes, scores = self.boxes_from_bitmap(pred[batch_index], mask,
src_w, src_h) src_w, src_h)
else:
raise ValueError("box_type can only be one of ['quad', 'poly']")
boxes_batch.append({'points': boxes}) boxes_batch.append({'points': boxes})
return boxes_batch return boxes_batch
...@@ -254,7 +256,7 @@ class DistillationDBPostProcess(object): ...@@ -254,7 +256,7 @@ class DistillationDBPostProcess(object):
unclip_ratio=1.5, unclip_ratio=1.5,
use_dilation=False, use_dilation=False,
score_mode="fast", score_mode="fast",
use_polygon=False, box_type='quad',
**kwargs): **kwargs):
self.model_name = model_name self.model_name = model_name
self.key = key self.key = key
...@@ -265,7 +267,7 @@ class DistillationDBPostProcess(object): ...@@ -265,7 +267,7 @@ class DistillationDBPostProcess(object):
unclip_ratio=unclip_ratio, unclip_ratio=unclip_ratio,
use_dilation=use_dilation, use_dilation=use_dilation,
score_mode=score_mode, score_mode=score_mode,
use_polygon=use_polygon) box_type=box_type)
def __call__(self, predicts, shape_list): def __call__(self, predicts, shape_list):
results = {} results = {}
......
...@@ -67,6 +67,7 @@ class TextDetector(object): ...@@ -67,6 +67,7 @@ class TextDetector(object):
postprocess_params["unclip_ratio"] = args.det_db_unclip_ratio postprocess_params["unclip_ratio"] = args.det_db_unclip_ratio
postprocess_params["use_dilation"] = args.use_dilation postprocess_params["use_dilation"] = args.use_dilation
postprocess_params["score_mode"] = args.det_db_score_mode postprocess_params["score_mode"] = args.det_db_score_mode
postprocess_params["box_type"] = args.det_box_type
elif self.det_algorithm == "DB++": elif self.det_algorithm == "DB++":
postprocess_params['name'] = 'DBPostProcess' postprocess_params['name'] = 'DBPostProcess'
postprocess_params["thresh"] = args.det_db_thresh postprocess_params["thresh"] = args.det_db_thresh
...@@ -75,6 +76,7 @@ class TextDetector(object): ...@@ -75,6 +76,7 @@ class TextDetector(object):
postprocess_params["unclip_ratio"] = args.det_db_unclip_ratio postprocess_params["unclip_ratio"] = args.det_db_unclip_ratio
postprocess_params["use_dilation"] = args.use_dilation postprocess_params["use_dilation"] = args.use_dilation
postprocess_params["score_mode"] = args.det_db_score_mode postprocess_params["score_mode"] = args.det_db_score_mode
postprocess_params["box_type"] = args.det_box_type
pre_process_list[1] = { pre_process_list[1] = {
'NormalizeImage': { 'NormalizeImage': {
'std': [1.0, 1.0, 1.0], 'std': [1.0, 1.0, 1.0],
...@@ -98,8 +100,8 @@ class TextDetector(object): ...@@ -98,8 +100,8 @@ class TextDetector(object):
postprocess_params['name'] = 'SASTPostProcess' postprocess_params['name'] = 'SASTPostProcess'
postprocess_params["score_thresh"] = args.det_sast_score_thresh postprocess_params["score_thresh"] = args.det_sast_score_thresh
postprocess_params["nms_thresh"] = args.det_sast_nms_thresh postprocess_params["nms_thresh"] = args.det_sast_nms_thresh
self.det_sast_polygon = args.det_sast_polygon
if self.det_sast_polygon: if args.det_box_type == 'poly':
postprocess_params["sample_pts_num"] = 6 postprocess_params["sample_pts_num"] = 6
postprocess_params["expand_scale"] = 1.2 postprocess_params["expand_scale"] = 1.2
postprocess_params["shrink_ratio_of_width"] = 0.2 postprocess_params["shrink_ratio_of_width"] = 0.2
...@@ -107,14 +109,14 @@ class TextDetector(object): ...@@ -107,14 +109,14 @@ class TextDetector(object):
postprocess_params["sample_pts_num"] = 2 postprocess_params["sample_pts_num"] = 2
postprocess_params["expand_scale"] = 1.0 postprocess_params["expand_scale"] = 1.0
postprocess_params["shrink_ratio_of_width"] = 0.3 postprocess_params["shrink_ratio_of_width"] = 0.3
elif self.det_algorithm == "PSE": elif self.det_algorithm == "PSE":
postprocess_params['name'] = 'PSEPostProcess' postprocess_params['name'] = 'PSEPostProcess'
postprocess_params["thresh"] = args.det_pse_thresh postprocess_params["thresh"] = args.det_pse_thresh
postprocess_params["box_thresh"] = args.det_pse_box_thresh postprocess_params["box_thresh"] = args.det_pse_box_thresh
postprocess_params["min_area"] = args.det_pse_min_area postprocess_params["min_area"] = args.det_pse_min_area
postprocess_params["box_type"] = args.det_pse_box_type postprocess_params["box_type"] = args.det_box_type
postprocess_params["scale"] = args.det_pse_scale postprocess_params["scale"] = args.det_pse_scale
self.det_pse_box_type = args.det_pse_box_type
elif self.det_algorithm == "FCE": elif self.det_algorithm == "FCE":
pre_process_list[0] = { pre_process_list[0] = {
'DetResizeForTest': { 'DetResizeForTest': {
...@@ -126,7 +128,7 @@ class TextDetector(object): ...@@ -126,7 +128,7 @@ class TextDetector(object):
postprocess_params["alpha"] = args.alpha postprocess_params["alpha"] = args.alpha
postprocess_params["beta"] = args.beta postprocess_params["beta"] = args.beta
postprocess_params["fourier_degree"] = args.fourier_degree postprocess_params["fourier_degree"] = args.fourier_degree
postprocess_params["box_type"] = args.det_fce_box_type postprocess_params["box_type"] = args.det_box_type
elif self.det_algorithm == "CT": elif self.det_algorithm == "CT":
pre_process_list[0] = {'ScaleAlignedShort': {'short_size': 640}} pre_process_list[0] = {'ScaleAlignedShort': {'short_size': 640}}
postprocess_params['name'] = 'CTPostProcess' postprocess_params['name'] = 'CTPostProcess'
...@@ -190,6 +192,8 @@ class TextDetector(object): ...@@ -190,6 +192,8 @@ class TextDetector(object):
img_height, img_width = image_shape[0:2] img_height, img_width = image_shape[0:2]
dt_boxes_new = [] dt_boxes_new = []
for box in dt_boxes: for box in dt_boxes:
if type(box) is list:
box = np.array(box)
box = self.order_points_clockwise(box) box = self.order_points_clockwise(box)
box = self.clip_det_res(box, img_height, img_width) box = self.clip_det_res(box, img_height, img_width)
rect_width = int(np.linalg.norm(box[0] - box[1])) rect_width = int(np.linalg.norm(box[0] - box[1]))
...@@ -204,6 +208,8 @@ class TextDetector(object): ...@@ -204,6 +208,8 @@ class TextDetector(object):
img_height, img_width = image_shape[0:2] img_height, img_width = image_shape[0:2]
dt_boxes_new = [] dt_boxes_new = []
for box in dt_boxes: for box in dt_boxes:
if type(box) is list:
box = np.array(box)
box = self.clip_det_res(box, img_height, img_width) box = self.clip_det_res(box, img_height, img_width)
dt_boxes_new.append(box) dt_boxes_new.append(box)
dt_boxes = np.array(dt_boxes_new) dt_boxes = np.array(dt_boxes_new)
...@@ -262,12 +268,10 @@ class TextDetector(object): ...@@ -262,12 +268,10 @@ class TextDetector(object):
else: else:
raise NotImplementedError raise NotImplementedError
#self.predictor.try_shrink_memory()
post_result = self.postprocess_op(preds, shape_list) post_result = self.postprocess_op(preds, shape_list)
dt_boxes = post_result[0]['points'] dt_boxes = post_result[0]['points']
if (self.det_algorithm == "SAST" and self.det_sast_polygon) or (
self.det_algorithm in ["PSE", "FCE", "CT"] and if self.args.det_box_type == 'poly':
self.postprocess_op.box_type == 'poly'):
dt_boxes = self.filter_tag_det_res_only_clip(dt_boxes, ori_im.shape) dt_boxes = self.filter_tag_det_res_only_clip(dt_boxes, ori_im.shape)
else: else:
dt_boxes = self.filter_tag_det_res(dt_boxes, ori_im.shape) dt_boxes = self.filter_tag_det_res(dt_boxes, ori_im.shape)
......
...@@ -50,6 +50,7 @@ def init_args(): ...@@ -50,6 +50,7 @@ def init_args():
parser.add_argument("--det_model_dir", type=str) parser.add_argument("--det_model_dir", type=str)
parser.add_argument("--det_limit_side_len", type=float, default=960) parser.add_argument("--det_limit_side_len", type=float, default=960)
parser.add_argument("--det_limit_type", type=str, default='max') parser.add_argument("--det_limit_type", type=str, default='max')
parser.add_argument("--det_box_type", type=str, default='quad')
# DB params # DB params
parser.add_argument("--det_db_thresh", type=float, default=0.3) parser.add_argument("--det_db_thresh", type=float, default=0.3)
...@@ -58,6 +59,7 @@ def init_args(): ...@@ -58,6 +59,7 @@ def init_args():
parser.add_argument("--max_batch_size", type=int, default=10) parser.add_argument("--max_batch_size", type=int, default=10)
parser.add_argument("--use_dilation", type=str2bool, default=False) parser.add_argument("--use_dilation", type=str2bool, default=False)
parser.add_argument("--det_db_score_mode", type=str, default="fast") parser.add_argument("--det_db_score_mode", type=str, default="fast")
# EAST params # EAST params
parser.add_argument("--det_east_score_thresh", type=float, default=0.8) parser.add_argument("--det_east_score_thresh", type=float, default=0.8)
parser.add_argument("--det_east_cover_thresh", type=float, default=0.1) parser.add_argument("--det_east_cover_thresh", type=float, default=0.1)
...@@ -66,13 +68,11 @@ def init_args(): ...@@ -66,13 +68,11 @@ def init_args():
# SAST params # SAST params
parser.add_argument("--det_sast_score_thresh", type=float, default=0.5) parser.add_argument("--det_sast_score_thresh", type=float, default=0.5)
parser.add_argument("--det_sast_nms_thresh", type=float, default=0.2) parser.add_argument("--det_sast_nms_thresh", type=float, default=0.2)
parser.add_argument("--det_sast_polygon", type=str2bool, default=False)
# PSE params # PSE params
parser.add_argument("--det_pse_thresh", type=float, default=0) parser.add_argument("--det_pse_thresh", type=float, default=0)
parser.add_argument("--det_pse_box_thresh", type=float, default=0.85) parser.add_argument("--det_pse_box_thresh", type=float, default=0.85)
parser.add_argument("--det_pse_min_area", type=float, default=16) parser.add_argument("--det_pse_min_area", type=float, default=16)
parser.add_argument("--det_pse_box_type", type=str, default='quad')
parser.add_argument("--det_pse_scale", type=int, default=1) parser.add_argument("--det_pse_scale", type=int, default=1)
# FCE params # FCE params
...@@ -80,7 +80,6 @@ def init_args(): ...@@ -80,7 +80,6 @@ def init_args():
parser.add_argument("--alpha", type=float, default=1.0) parser.add_argument("--alpha", type=float, default=1.0)
parser.add_argument("--beta", type=float, default=1.0) parser.add_argument("--beta", type=float, default=1.0)
parser.add_argument("--fourier_degree", type=int, default=5) parser.add_argument("--fourier_degree", type=int, default=5)
parser.add_argument("--det_fce_box_type", type=str, default='poly')
# params for text recognizer # params for text recognizer
parser.add_argument("--rec_algorithm", type=str, default='SVTR_LCNet') parser.add_argument("--rec_algorithm", type=str, default='SVTR_LCNet')
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册