提交 1a2c7b28 编写于 作者: 文幕地方's avatar 文幕地方

add FCENet inference

上级 30e8f905
......@@ -56,7 +56,7 @@ PostProcess:
thresh: 0
box_thresh: 0.85
min_area: 16
box_type: box # 'box' or 'poly'
box_type: quad # 'quad' or 'poly'
scale: 1
Metric:
......
......@@ -27,14 +27,11 @@ Architecture:
out_indices: [1,2,3]
Neck:
name: FCEFPN
in_channels: [512, 1024, 2048]
out_channels: 256
has_extra_convs: False
extra_stage: 0
Head:
name: FCEHead
in_channels: 256
scales: [8, 16, 32]
fourier_degree: 5
Loss:
name: FCELoss
......@@ -57,6 +54,7 @@ PostProcess:
alpha: 1.0
beta: 1.0
fourier_degree: 5
box_type: 'poly'
Metric:
name: DetFCEMetric
......@@ -123,8 +121,8 @@ Eval:
ignore_orientation: True
- DetLabelEncode: # Class handling label
- DetResizeForTest:
# resize_long: 1280
rescale_img: [1080, 736]
limit_type: 'min'
limit_side_len: 736
- NormalizeImage:
scale: 1./255.
mean: [0.485, 0.456, 0.406]
......
......@@ -55,7 +55,7 @@ PostProcess:
thresh: 0
box_thresh: 0.85
min_area: 16
box_type: box # 'box' or 'poly'
box_type: quad # 'quad' or 'poly'
scale: 1
Metric:
......
......@@ -43,13 +43,12 @@ class FCEHead(nn.Layer):
fourier_degree (int) : The maximum Fourier transform degree k.
"""
def __init__(self, in_channels, scales, fourier_degree=5):
def __init__(self, in_channels, fourier_degree=5):
super().__init__()
assert isinstance(in_channels, int)
self.downsample_ratio = 1.0
self.in_channels = in_channels
self.scales = scales
self.fourier_degree = fourier_degree
self.out_channels_cls = 4
self.out_channels_reg = (2 * self.fourier_degree + 1) * 2
......@@ -82,9 +81,7 @@ class FCEHead(nn.Layer):
def forward(self, feats, targets=None):
cls_res, reg_res = multi_apply(self.forward_single, feats)
level_num = len(cls_res)
# import pdb;pdb.set_trace()
outs = {}
if not self.training:
for i in range(level_num):
tr_pred = F.softmax(cls_res[i][:, 0:2, :, :], axis=1)
......
......@@ -74,7 +74,7 @@ class FCEPostProcess(object):
nms_thr=0.1,
alpha=1.0,
beta=1.0,
text_repr_type='poly',
box_type='poly',
**kwargs):
self.scales = scales
......@@ -85,7 +85,7 @@ class FCEPostProcess(object):
self.nms_thr = nms_thr
self.alpha = alpha
self.beta = beta
self.text_repr_type = text_repr_type
self.box_type = box_type
def __call__(self, preds, shape_list):
score_maps = []
......@@ -149,7 +149,7 @@ class FCEPostProcess(object):
scale=scale,
alpha=self.alpha,
beta=self.beta,
text_repr_type=self.text_repr_type,
box_type=self.box_type,
score_thr=self.score_thr,
nms_thr=self.nms_thr)
......@@ -160,7 +160,7 @@ class FCEPostProcess(object):
scale,
alpha=1.0,
beta=2.0,
text_repr_type='poly',
box_type='poly',
score_thr=0.3,
nms_thr=0.1):
"""Decoding predictions of FCENet to instances.
......@@ -175,7 +175,7 @@ class FCEPostProcess(object):
= (Score_{text region} ^ alpha)
* (Score_{text center region}^ beta)
beta (float) : The parameter to calculate final score.
text_repr_type (str): Boundary encoding type 'poly' or 'quad'.
box_type (str): Boundary encoding type 'poly' or 'quad'.
score_thr (float) : The threshold used to filter out the final
candidates.
nms_thr (float) : The threshold of nms.
......@@ -186,7 +186,7 @@ class FCEPostProcess(object):
"""
assert isinstance(preds, list)
assert len(preds) == 2
assert text_repr_type in ['poly', 'quad']
assert box_type in ['poly', 'quad']
cls_pred = preds[0][0]
tr_pred = cls_pred[0:2]
......@@ -228,7 +228,7 @@ class FCEPostProcess(object):
boundaries = poly_nms(boundaries, nms_thr)
if text_repr_type == 'quad':
if box_type == 'quad':
new_boundaries = []
for boundary in boundaries:
poly = np.array(boundary[:-1]).reshape(-1, 2).astype(np.float32)
......@@ -236,5 +236,6 @@ class FCEPostProcess(object):
points = cv2.boxPoints(cv2.minAreaRect(poly))
points = np.int0(points)
new_boundaries.append(points.reshape(-1).tolist() + [score])
boundaries = new_boundaries
return boundaries
......@@ -37,10 +37,10 @@ class PSEPostProcess(object):
thresh=0.5,
box_thresh=0.85,
min_area=16,
box_type='box',
box_type='quad',
scale=4,
**kwargs):
assert box_type in ['box', 'poly'], 'Only box and poly is supported'
assert box_type in ['quad', 'poly'], 'Only quad and poly is supported'
self.thresh = thresh
self.box_thresh = box_thresh
self.min_area = min_area
......@@ -95,7 +95,7 @@ class PSEPostProcess(object):
label[ind] = 0
continue
if self.box_type == 'box':
if self.box_type == 'quad':
rect = cv2.minAreaRect(points)
bbox = cv2.boxPoints(rect)
elif self.box_type == 'poly':
......
......@@ -98,6 +98,18 @@ class TextDetector(object):
postprocess_params["box_type"] = args.det_pse_box_type
postprocess_params["scale"] = args.det_pse_scale
self.det_pse_box_type = args.det_pse_box_type
elif self.det_algorithm == "FCE":
pre_process_list[0] = {
'DetResizeForTest': {
'rescale_img': [1080, 736]
}
}
postprocess_params['name'] = 'FCEPostProcess'
postprocess_params["scales"] = args.scales
postprocess_params["alpha"] = args.alpha
postprocess_params["beta"] = args.beta
postprocess_params["fourier_degree"] = args.fourier_degree
postprocess_params["box_type"] = args.det_fce_box_type
else:
logger.info("unknown det_algorithm:{}".format(self.det_algorithm))
sys.exit(0)
......@@ -234,15 +246,18 @@ class TextDetector(object):
preds['f_tvo'] = outputs[3]
elif self.det_algorithm in ['DB', 'PSE']:
preds['maps'] = outputs[0]
elif self.det_algorithm == 'FCE':
for i, output in enumerate(outputs):
preds['level_{}'.format(i)] = output
else:
raise NotImplementedError
#self.predictor.try_shrink_memory()
post_result = self.postprocess_op(preds, shape_list)
dt_boxes = post_result[0]['points']
if (self.det_algorithm == "SAST" and
self.det_sast_polygon) or (self.det_algorithm == "PSE" and
self.det_pse_box_type == 'poly'):
if (self.det_algorithm == "SAST" and self.det_sast_polygon) or (
self.det_algorithm in ["PSE", "FCE"] and
self.postprocess_op.box_type == 'poly'):
dt_boxes = self.filter_tag_det_res_only_clip(dt_boxes, ori_im.shape)
else:
dt_boxes = self.filter_tag_det_res(dt_boxes, ori_im.shape)
......
......@@ -68,9 +68,16 @@ def init_args():
parser.add_argument("--det_pse_thresh", type=float, default=0)
parser.add_argument("--det_pse_box_thresh", type=float, default=0.85)
parser.add_argument("--det_pse_min_area", type=float, default=16)
parser.add_argument("--det_pse_box_type", type=str, default='box')
parser.add_argument("--det_pse_box_type", type=str, default='quad')
parser.add_argument("--det_pse_scale", type=int, default=1)
# FCE params
parser.add_argument("--scales", type=list, default=[8, 16, 32])
parser.add_argument("--alpha", type=float, default=1.0)
parser.add_argument("--beta", type=float, default=1.0)
parser.add_argument("--fourier_degree", type=int, default=5)
parser.add_argument("--det_fce_box_type", type=str, default='poly')
# params for text recognizer
parser.add_argument("--rec_algorithm", type=str, default='CRNN')
parser.add_argument("--rec_model_dir", type=str)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册