未验证 提交 9c2c5e80 编写于 作者: Z zhoujun 提交者: GitHub

Merge pull request #5700 from WenmuZhou/fix_cpp_lite_android

[Detection]add FECNet inference
...@@ -56,7 +56,7 @@ PostProcess: ...@@ -56,7 +56,7 @@ PostProcess:
thresh: 0 thresh: 0
box_thresh: 0.85 box_thresh: 0.85
min_area: 16 min_area: 16
box_type: box # 'box' or 'poly' box_type: quad # 'quad' or 'poly'
scale: 1 scale: 1
Metric: Metric:
......
...@@ -27,14 +27,11 @@ Architecture: ...@@ -27,14 +27,11 @@ Architecture:
out_indices: [1,2,3] out_indices: [1,2,3]
Neck: Neck:
name: FCEFPN name: FCEFPN
in_channels: [512, 1024, 2048]
out_channels: 256 out_channels: 256
has_extra_convs: False has_extra_convs: False
extra_stage: 0 extra_stage: 0
Head: Head:
name: FCEHead name: FCEHead
in_channels: 256
scales: [8, 16, 32]
fourier_degree: 5 fourier_degree: 5
Loss: Loss:
name: FCELoss name: FCELoss
...@@ -57,6 +54,7 @@ PostProcess: ...@@ -57,6 +54,7 @@ PostProcess:
alpha: 1.0 alpha: 1.0
beta: 1.0 beta: 1.0
fourier_degree: 5 fourier_degree: 5
box_type: 'poly'
Metric: Metric:
name: DetFCEMetric name: DetFCEMetric
...@@ -123,8 +121,8 @@ Eval: ...@@ -123,8 +121,8 @@ Eval:
ignore_orientation: True ignore_orientation: True
- DetLabelEncode: # Class handling label - DetLabelEncode: # Class handling label
- DetResizeForTest: - DetResizeForTest:
# resize_long: 1280 limit_type: 'min'
rescale_img: [1080, 736] limit_side_len: 736
- NormalizeImage: - NormalizeImage:
scale: 1./255. scale: 1./255.
mean: [0.485, 0.456, 0.406] mean: [0.485, 0.456, 0.406]
......
...@@ -55,7 +55,7 @@ PostProcess: ...@@ -55,7 +55,7 @@ PostProcess:
thresh: 0 thresh: 0
box_thresh: 0.85 box_thresh: 0.85
min_area: 16 min_area: 16
box_type: box # 'box' or 'poly' box_type: quad # 'quad' or 'poly'
scale: 1 scale: 1
Metric: Metric:
......
...@@ -43,13 +43,12 @@ class FCEHead(nn.Layer): ...@@ -43,13 +43,12 @@ class FCEHead(nn.Layer):
fourier_degree (int) : The maximum Fourier transform degree k. fourier_degree (int) : The maximum Fourier transform degree k.
""" """
def __init__(self, in_channels, scales, fourier_degree=5): def __init__(self, in_channels, fourier_degree=5):
super().__init__() super().__init__()
assert isinstance(in_channels, int) assert isinstance(in_channels, int)
self.downsample_ratio = 1.0 self.downsample_ratio = 1.0
self.in_channels = in_channels self.in_channels = in_channels
self.scales = scales
self.fourier_degree = fourier_degree self.fourier_degree = fourier_degree
self.out_channels_cls = 4 self.out_channels_cls = 4
self.out_channels_reg = (2 * self.fourier_degree + 1) * 2 self.out_channels_reg = (2 * self.fourier_degree + 1) * 2
...@@ -82,9 +81,7 @@ class FCEHead(nn.Layer): ...@@ -82,9 +81,7 @@ class FCEHead(nn.Layer):
def forward(self, feats, targets=None): def forward(self, feats, targets=None):
cls_res, reg_res = multi_apply(self.forward_single, feats) cls_res, reg_res = multi_apply(self.forward_single, feats)
level_num = len(cls_res) level_num = len(cls_res)
# import pdb;pdb.set_trace()
outs = {} outs = {}
if not self.training: if not self.training:
for i in range(level_num): for i in range(level_num):
tr_pred = F.softmax(cls_res[i][:, 0:2, :, :], axis=1) tr_pred = F.softmax(cls_res[i][:, 0:2, :, :], axis=1)
......
...@@ -74,7 +74,7 @@ class FCEPostProcess(object): ...@@ -74,7 +74,7 @@ class FCEPostProcess(object):
nms_thr=0.1, nms_thr=0.1,
alpha=1.0, alpha=1.0,
beta=1.0, beta=1.0,
text_repr_type='poly', box_type='poly',
**kwargs): **kwargs):
self.scales = scales self.scales = scales
...@@ -85,7 +85,7 @@ class FCEPostProcess(object): ...@@ -85,7 +85,7 @@ class FCEPostProcess(object):
self.nms_thr = nms_thr self.nms_thr = nms_thr
self.alpha = alpha self.alpha = alpha
self.beta = beta self.beta = beta
self.text_repr_type = text_repr_type self.box_type = box_type
def __call__(self, preds, shape_list): def __call__(self, preds, shape_list):
score_maps = [] score_maps = []
...@@ -149,7 +149,7 @@ class FCEPostProcess(object): ...@@ -149,7 +149,7 @@ class FCEPostProcess(object):
scale=scale, scale=scale,
alpha=self.alpha, alpha=self.alpha,
beta=self.beta, beta=self.beta,
text_repr_type=self.text_repr_type, box_type=self.box_type,
score_thr=self.score_thr, score_thr=self.score_thr,
nms_thr=self.nms_thr) nms_thr=self.nms_thr)
...@@ -160,7 +160,7 @@ class FCEPostProcess(object): ...@@ -160,7 +160,7 @@ class FCEPostProcess(object):
scale, scale,
alpha=1.0, alpha=1.0,
beta=2.0, beta=2.0,
text_repr_type='poly', box_type='poly',
score_thr=0.3, score_thr=0.3,
nms_thr=0.1): nms_thr=0.1):
"""Decoding predictions of FCENet to instances. """Decoding predictions of FCENet to instances.
...@@ -175,7 +175,7 @@ class FCEPostProcess(object): ...@@ -175,7 +175,7 @@ class FCEPostProcess(object):
= (Score_{text region} ^ alpha) = (Score_{text region} ^ alpha)
* (Score_{text center region}^ beta) * (Score_{text center region}^ beta)
beta (float) : The parameter to calculate final score. beta (float) : The parameter to calculate final score.
text_repr_type (str): Boundary encoding type 'poly' or 'quad'. box_type (str): Boundary encoding type 'poly' or 'quad'.
score_thr (float) : The threshold used to filter out the final score_thr (float) : The threshold used to filter out the final
candidates. candidates.
nms_thr (float) : The threshold of nms. nms_thr (float) : The threshold of nms.
...@@ -186,7 +186,7 @@ class FCEPostProcess(object): ...@@ -186,7 +186,7 @@ class FCEPostProcess(object):
""" """
assert isinstance(preds, list) assert isinstance(preds, list)
assert len(preds) == 2 assert len(preds) == 2
assert text_repr_type in ['poly', 'quad'] assert box_type in ['poly', 'quad']
cls_pred = preds[0][0] cls_pred = preds[0][0]
tr_pred = cls_pred[0:2] tr_pred = cls_pred[0:2]
...@@ -228,7 +228,7 @@ class FCEPostProcess(object): ...@@ -228,7 +228,7 @@ class FCEPostProcess(object):
boundaries = poly_nms(boundaries, nms_thr) boundaries = poly_nms(boundaries, nms_thr)
if text_repr_type == 'quad': if box_type == 'quad':
new_boundaries = [] new_boundaries = []
for boundary in boundaries: for boundary in boundaries:
poly = np.array(boundary[:-1]).reshape(-1, 2).astype(np.float32) poly = np.array(boundary[:-1]).reshape(-1, 2).astype(np.float32)
...@@ -236,5 +236,6 @@ class FCEPostProcess(object): ...@@ -236,5 +236,6 @@ class FCEPostProcess(object):
points = cv2.boxPoints(cv2.minAreaRect(poly)) points = cv2.boxPoints(cv2.minAreaRect(poly))
points = np.int0(points) points = np.int0(points)
new_boundaries.append(points.reshape(-1).tolist() + [score]) new_boundaries.append(points.reshape(-1).tolist() + [score])
boundaries = new_boundaries
return boundaries return boundaries
...@@ -37,10 +37,10 @@ class PSEPostProcess(object): ...@@ -37,10 +37,10 @@ class PSEPostProcess(object):
thresh=0.5, thresh=0.5,
box_thresh=0.85, box_thresh=0.85,
min_area=16, min_area=16,
box_type='box', box_type='quad',
scale=4, scale=4,
**kwargs): **kwargs):
assert box_type in ['box', 'poly'], 'Only box and poly is supported' assert box_type in ['quad', 'poly'], 'Only quad and poly is supported'
self.thresh = thresh self.thresh = thresh
self.box_thresh = box_thresh self.box_thresh = box_thresh
self.min_area = min_area self.min_area = min_area
...@@ -95,7 +95,7 @@ class PSEPostProcess(object): ...@@ -95,7 +95,7 @@ class PSEPostProcess(object):
label[ind] = 0 label[ind] = 0
continue continue
if self.box_type == 'box': if self.box_type == 'quad':
rect = cv2.minAreaRect(points) rect = cv2.minAreaRect(points)
bbox = cv2.boxPoints(rect) bbox = cv2.boxPoints(rect)
elif self.box_type == 'poly': elif self.box_type == 'poly':
......
...@@ -98,6 +98,18 @@ class TextDetector(object): ...@@ -98,6 +98,18 @@ class TextDetector(object):
postprocess_params["box_type"] = args.det_pse_box_type postprocess_params["box_type"] = args.det_pse_box_type
postprocess_params["scale"] = args.det_pse_scale postprocess_params["scale"] = args.det_pse_scale
self.det_pse_box_type = args.det_pse_box_type self.det_pse_box_type = args.det_pse_box_type
elif self.det_algorithm == "FCE":
pre_process_list[0] = {
'DetResizeForTest': {
'rescale_img': [1080, 736]
}
}
postprocess_params['name'] = 'FCEPostProcess'
postprocess_params["scales"] = args.scales
postprocess_params["alpha"] = args.alpha
postprocess_params["beta"] = args.beta
postprocess_params["fourier_degree"] = args.fourier_degree
postprocess_params["box_type"] = args.det_fce_box_type
else: else:
logger.info("unknown det_algorithm:{}".format(self.det_algorithm)) logger.info("unknown det_algorithm:{}".format(self.det_algorithm))
sys.exit(0) sys.exit(0)
...@@ -234,15 +246,18 @@ class TextDetector(object): ...@@ -234,15 +246,18 @@ class TextDetector(object):
preds['f_tvo'] = outputs[3] preds['f_tvo'] = outputs[3]
elif self.det_algorithm in ['DB', 'PSE']: elif self.det_algorithm in ['DB', 'PSE']:
preds['maps'] = outputs[0] preds['maps'] = outputs[0]
elif self.det_algorithm == 'FCE':
for i, output in enumerate(outputs):
preds['level_{}'.format(i)] = output
else: else:
raise NotImplementedError raise NotImplementedError
#self.predictor.try_shrink_memory() #self.predictor.try_shrink_memory()
post_result = self.postprocess_op(preds, shape_list) post_result = self.postprocess_op(preds, shape_list)
dt_boxes = post_result[0]['points'] dt_boxes = post_result[0]['points']
if (self.det_algorithm == "SAST" and if (self.det_algorithm == "SAST" and self.det_sast_polygon) or (
self.det_sast_polygon) or (self.det_algorithm == "PSE" and self.det_algorithm in ["PSE", "FCE"] and
self.det_pse_box_type == 'poly'): self.postprocess_op.box_type == 'poly'):
dt_boxes = self.filter_tag_det_res_only_clip(dt_boxes, ori_im.shape) dt_boxes = self.filter_tag_det_res_only_clip(dt_boxes, ori_im.shape)
else: else:
dt_boxes = self.filter_tag_det_res(dt_boxes, ori_im.shape) dt_boxes = self.filter_tag_det_res(dt_boxes, ori_im.shape)
......
...@@ -68,9 +68,16 @@ def init_args(): ...@@ -68,9 +68,16 @@ def init_args():
parser.add_argument("--det_pse_thresh", type=float, default=0) parser.add_argument("--det_pse_thresh", type=float, default=0)
parser.add_argument("--det_pse_box_thresh", type=float, default=0.85) parser.add_argument("--det_pse_box_thresh", type=float, default=0.85)
parser.add_argument("--det_pse_min_area", type=float, default=16) parser.add_argument("--det_pse_min_area", type=float, default=16)
parser.add_argument("--det_pse_box_type", type=str, default='box') parser.add_argument("--det_pse_box_type", type=str, default='quad')
parser.add_argument("--det_pse_scale", type=int, default=1) parser.add_argument("--det_pse_scale", type=int, default=1)
# FCE parmas
parser.add_argument("--scales", type=list, default=[8, 16, 32])
parser.add_argument("--alpha", type=float, default=1.0)
parser.add_argument("--beta", type=float, default=1.0)
parser.add_argument("--fourier_degree", type=int, default=5)
parser.add_argument("--det_fce_box_type", type=str, default='poly')
# params for text recognizer # params for text recognizer
parser.add_argument("--rec_algorithm", type=str, default='CRNN') parser.add_argument("--rec_algorithm", type=str, default='CRNN')
parser.add_argument("--rec_model_dir", type=str) parser.add_argument("--rec_model_dir", type=str)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册