提交 a6a44ea1 编写于 作者: W WenmuZhou

add export model of pse

上级 c64d5519
......@@ -24,7 +24,7 @@ PaddleOCR开源的文本检测算法列表:
|DB|MobileNetV3|77.29%|73.08%|75.12%|[下载链接](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/det_mv3_db_v2.0_train.tar)|
|SAST|ResNet50_vd|91.39%|83.77%|87.42%|[下载链接](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/det_r50_vd_sast_icdar15_v2.0_train.tar)|
|PSE|ResNet50_vd|85.81%|79.53%|82.55%|[下载链接](https://paddleocr.bj.bcebos.com/dygraph_v2.1/en_det/det_r50_vd_pse_v2.0_train.tar)|
|PSE|MobileNetV3|82.20%|70.47%|75.89%|[下载链接](https://paddleocr.bj.bcebos.com/dygraph_v2.1/en_det/det_mv3_pse_v2.0_train.tar)|
|PSE|MobileNetV3|82.20%|70.48%|75.89%|[下载链接](https://paddleocr.bj.bcebos.com/dygraph_v2.1/en_det/det_mv3_pse_v2.0_train.tar)|
在Total-text文本检测公开数据集上,算法效果如下:
......
......@@ -26,7 +26,7 @@ On the ICDAR2015 dataset, the text detection result is as follows:
|DB|MobileNetV3|77.29%|73.08%|75.12%|[Download link](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/det_mv3_db_v2.0_train.tar)|
|SAST|ResNet50_vd|91.39%|83.77%|87.42%|[Download link](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/det_r50_vd_sast_icdar15_v2.0_train.tar)|
|PSE|ResNet50_vd|85.81%|79.53%|82.55%|[Download link](https://paddleocr.bj.bcebos.com/dygraph_v2.1/en_det/det_r50_vd_pse_v2.0_train.tar)|
|PSE|MobileNetV3|82.20%|70.47%|75.89%|[Download link](https://paddleocr.bj.bcebos.com/dygraph_v2.1/en_det/det_mv3_pse_v2.0_train.tar)|
|PSE|MobileNetV3|82.20%|70.48%|75.89%|[Download link](https://paddleocr.bj.bcebos.com/dygraph_v2.1/en_det/det_mv3_pse_v2.0_train.tar)|
On Total-Text dataset, the text detection result is as follows:
......
......@@ -70,33 +70,31 @@ class FPN(nn.Layer):
m.bias = paddle.create_parameter(shape=m.bias.shape, dtype='float32',
default_initializer=paddle.nn.initializer.Constant(0.0))
def _upsample(self, x, y, scale=1):
_, _, H, W = y.shape
return F.upsample(x, size=(H // scale, W // scale), mode='bilinear')
def _upsample(self, x, scale=1):
return F.upsample(x, scale_factor=scale, mode='bilinear')
def _upsample_add(self, x, y):
_, _, H, W = y.shape
return F.upsample(x, size=(H, W), mode='bilinear') + y
def _upsample_add(self, x, y, scale=1):
return F.upsample(x, scale_factor=scale, mode='bilinear') + y
def forward(self, x):
f2, f3, f4, f5 = x
p5 = self.toplayer_(f5)
f4 = self.latlayer1_(f4)
p4 = self._upsample_add(p5, f4)
p4 = self._upsample_add(p5, f4,2)
p4 = self.smooth1_(p4)
f3 = self.latlayer2_(f3)
p3 = self._upsample_add(p4, f3)
p3 = self._upsample_add(p4, f3,2)
p3 = self.smooth2_(p3)
f2 = self.latlayer3_(f2)
p2 = self._upsample_add(p3, f2)
p2 = self._upsample_add(p3, f2,2)
p2 = self.smooth3_(p2)
p3 = self._upsample(p3, p2)
p4 = self._upsample(p4, p2)
p5 = self._upsample(p5, p2)
p3 = self._upsample(p3, 2)
p4 = self._upsample(p4, 4)
p5 = self._upsample(p5, 8)
fuse = paddle.concat([p2, p3, p4, p5], axis=1)
return fuse
\ No newline at end of file
......@@ -89,6 +89,14 @@ class TextDetector(object):
postprocess_params["sample_pts_num"] = 2
postprocess_params["expand_scale"] = 1.0
postprocess_params["shrink_ratio_of_width"] = 0.3
elif self.det_algorithm == "PSE":
postprocess_params['name'] = 'PSEPostProcess'
postprocess_params["thresh"] = args.det_pse_thresh
postprocess_params["box_thresh"] = args.det_pse_box_thresh
postprocess_params["min_area"] = args.det_pse_min_area
postprocess_params["box_type"] = args.det_pse_box_type
postprocess_params["scale"] = args.det_pse_scale
self.det_pse_box_type = args.det_pse_box_type
else:
logger.info("unknown det_algorithm:{}".format(self.det_algorithm))
sys.exit(0)
......@@ -207,7 +215,7 @@ class TextDetector(object):
preds['f_score'] = outputs[1]
preds['f_tco'] = outputs[2]
preds['f_tvo'] = outputs[3]
elif self.det_algorithm == 'DB':
elif self.det_algorithm in ['DB','PSE']:
preds['maps'] = outputs[0]
else:
raise NotImplementedError
......@@ -217,6 +225,8 @@ class TextDetector(object):
dt_boxes = post_result[0]['points']
if self.det_algorithm == "SAST" and self.det_sast_polygon:
dt_boxes = self.filter_tag_det_res_only_clip(dt_boxes, ori_im.shape)
elif self.det_algorithm == "PSE" and self.det_pse_box_type=='poly':
dt_boxes = self.filter_tag_det_res_only_clip(dt_boxes, ori_im.shape)
else:
dt_boxes = self.filter_tag_det_res(dt_boxes, ori_im.shape)
......
......@@ -62,6 +62,13 @@ def init_args():
parser.add_argument("--det_sast_nms_thresh", type=float, default=0.2)
parser.add_argument("--det_sast_polygon", type=bool, default=False)
# PSE parmas
parser.add_argument("--det_pse_thresh", type=float, default=0)
parser.add_argument("--det_pse_box_thresh", type=float, default=0.85)
parser.add_argument("--det_pse_min_area", type=float, default=16)
parser.add_argument("--det_pse_box_type", type=str, default='poly')
parser.add_argument("--det_pse_scale", type=int, default=1)
# params for text recognizer
parser.add_argument("--rec_algorithm", type=str, default='CRNN')
parser.add_argument("--rec_model_dir", type=str)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册