提交 a6a44ea1 编写于 作者: W WenmuZhou

add export model of pse

上级 c64d5519
...@@ -24,7 +24,7 @@ PaddleOCR开源的文本检测算法列表: ...@@ -24,7 +24,7 @@ PaddleOCR开源的文本检测算法列表:
|DB|MobileNetV3|77.29%|73.08%|75.12%|[下载链接](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/det_mv3_db_v2.0_train.tar)| |DB|MobileNetV3|77.29%|73.08%|75.12%|[下载链接](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/det_mv3_db_v2.0_train.tar)|
|SAST|ResNet50_vd|91.39%|83.77%|87.42%|[下载链接](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/det_r50_vd_sast_icdar15_v2.0_train.tar)| |SAST|ResNet50_vd|91.39%|83.77%|87.42%|[下载链接](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/det_r50_vd_sast_icdar15_v2.0_train.tar)|
|PSE|ResNet50_vd|85.81%|79.53%|82.55%|[下载链接](https://paddleocr.bj.bcebos.com/dygraph_v2.1/en_det/det_r50_vd_pse_v2.0_train.tar)| |PSE|ResNet50_vd|85.81%|79.53%|82.55%|[下载链接](https://paddleocr.bj.bcebos.com/dygraph_v2.1/en_det/det_r50_vd_pse_v2.0_train.tar)|
|PSE|MobileNetV3|82.20%|70.47%|75.89%|[下载链接](https://paddleocr.bj.bcebos.com/dygraph_v2.1/en_det/det_mv3_pse_v2.0_train.tar)| |PSE|MobileNetV3|82.20%|70.48%|75.89%|[下载链接](https://paddleocr.bj.bcebos.com/dygraph_v2.1/en_det/det_mv3_pse_v2.0_train.tar)|
在Total-text文本检测公开数据集上,算法效果如下: 在Total-text文本检测公开数据集上,算法效果如下:
......
...@@ -26,7 +26,7 @@ On the ICDAR2015 dataset, the text detection result is as follows: ...@@ -26,7 +26,7 @@ On the ICDAR2015 dataset, the text detection result is as follows:
|DB|MobileNetV3|77.29%|73.08%|75.12%|[Download link](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/det_mv3_db_v2.0_train.tar)| |DB|MobileNetV3|77.29%|73.08%|75.12%|[Download link](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/det_mv3_db_v2.0_train.tar)|
|SAST|ResNet50_vd|91.39%|83.77%|87.42%|[Download link](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/det_r50_vd_sast_icdar15_v2.0_train.tar)| |SAST|ResNet50_vd|91.39%|83.77%|87.42%|[Download link](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/det_r50_vd_sast_icdar15_v2.0_train.tar)|
|PSE|ResNet50_vd|85.81%|79.53%|82.55%|[Download link](https://paddleocr.bj.bcebos.com/dygraph_v2.1/en_det/det_r50_vd_pse_v2.0_train.tar)| |PSE|ResNet50_vd|85.81%|79.53%|82.55%|[Download link](https://paddleocr.bj.bcebos.com/dygraph_v2.1/en_det/det_r50_vd_pse_v2.0_train.tar)|
|PSE|MobileNetV3|82.20%|70.47%|75.89%|[Download link](https://paddleocr.bj.bcebos.com/dygraph_v2.1/en_det/det_mv3_pse_v2.0_train.tar)| |PSE|MobileNetV3|82.20%|70.48%|75.89%|[Download link](https://paddleocr.bj.bcebos.com/dygraph_v2.1/en_det/det_mv3_pse_v2.0_train.tar)|
On Total-Text dataset, the text detection result is as follows: On Total-Text dataset, the text detection result is as follows:
......
...@@ -70,33 +70,31 @@ class FPN(nn.Layer): ...@@ -70,33 +70,31 @@ class FPN(nn.Layer):
m.bias = paddle.create_parameter(shape=m.bias.shape, dtype='float32', m.bias = paddle.create_parameter(shape=m.bias.shape, dtype='float32',
default_initializer=paddle.nn.initializer.Constant(0.0)) default_initializer=paddle.nn.initializer.Constant(0.0))
def _upsample(self, x, y, scale=1): def _upsample(self, x, scale=1):
_, _, H, W = y.shape return F.upsample(x, scale_factor=scale, mode='bilinear')
return F.upsample(x, size=(H // scale, W // scale), mode='bilinear')
def _upsample_add(self, x, y): def _upsample_add(self, x, y, scale=1):
_, _, H, W = y.shape return F.upsample(x, scale_factor=scale, mode='bilinear') + y
return F.upsample(x, size=(H, W), mode='bilinear') + y
def forward(self, x): def forward(self, x):
f2, f3, f4, f5 = x f2, f3, f4, f5 = x
p5 = self.toplayer_(f5) p5 = self.toplayer_(f5)
f4 = self.latlayer1_(f4) f4 = self.latlayer1_(f4)
p4 = self._upsample_add(p5, f4) p4 = self._upsample_add(p5, f4,2)
p4 = self.smooth1_(p4) p4 = self.smooth1_(p4)
f3 = self.latlayer2_(f3) f3 = self.latlayer2_(f3)
p3 = self._upsample_add(p4, f3) p3 = self._upsample_add(p4, f3,2)
p3 = self.smooth2_(p3) p3 = self.smooth2_(p3)
f2 = self.latlayer3_(f2) f2 = self.latlayer3_(f2)
p2 = self._upsample_add(p3, f2) p2 = self._upsample_add(p3, f2,2)
p2 = self.smooth3_(p2) p2 = self.smooth3_(p2)
p3 = self._upsample(p3, p2) p3 = self._upsample(p3, 2)
p4 = self._upsample(p4, p2) p4 = self._upsample(p4, 4)
p5 = self._upsample(p5, p2) p5 = self._upsample(p5, 8)
fuse = paddle.concat([p2, p3, p4, p5], axis=1) fuse = paddle.concat([p2, p3, p4, p5], axis=1)
return fuse return fuse
\ No newline at end of file
...@@ -89,6 +89,14 @@ class TextDetector(object): ...@@ -89,6 +89,14 @@ class TextDetector(object):
postprocess_params["sample_pts_num"] = 2 postprocess_params["sample_pts_num"] = 2
postprocess_params["expand_scale"] = 1.0 postprocess_params["expand_scale"] = 1.0
postprocess_params["shrink_ratio_of_width"] = 0.3 postprocess_params["shrink_ratio_of_width"] = 0.3
elif self.det_algorithm == "PSE":
postprocess_params['name'] = 'PSEPostProcess'
postprocess_params["thresh"] = args.det_pse_thresh
postprocess_params["box_thresh"] = args.det_pse_box_thresh
postprocess_params["min_area"] = args.det_pse_min_area
postprocess_params["box_type"] = args.det_pse_box_type
postprocess_params["scale"] = args.det_pse_scale
self.det_pse_box_type = args.det_pse_box_type
else: else:
logger.info("unknown det_algorithm:{}".format(self.det_algorithm)) logger.info("unknown det_algorithm:{}".format(self.det_algorithm))
sys.exit(0) sys.exit(0)
...@@ -207,7 +215,7 @@ class TextDetector(object): ...@@ -207,7 +215,7 @@ class TextDetector(object):
preds['f_score'] = outputs[1] preds['f_score'] = outputs[1]
preds['f_tco'] = outputs[2] preds['f_tco'] = outputs[2]
preds['f_tvo'] = outputs[3] preds['f_tvo'] = outputs[3]
elif self.det_algorithm == 'DB': elif self.det_algorithm in ['DB','PSE']:
preds['maps'] = outputs[0] preds['maps'] = outputs[0]
else: else:
raise NotImplementedError raise NotImplementedError
...@@ -217,6 +225,8 @@ class TextDetector(object): ...@@ -217,6 +225,8 @@ class TextDetector(object):
dt_boxes = post_result[0]['points'] dt_boxes = post_result[0]['points']
if self.det_algorithm == "SAST" and self.det_sast_polygon: if self.det_algorithm == "SAST" and self.det_sast_polygon:
dt_boxes = self.filter_tag_det_res_only_clip(dt_boxes, ori_im.shape) dt_boxes = self.filter_tag_det_res_only_clip(dt_boxes, ori_im.shape)
elif self.det_algorithm == "PSE" and self.det_pse_box_type=='poly':
dt_boxes = self.filter_tag_det_res_only_clip(dt_boxes, ori_im.shape)
else: else:
dt_boxes = self.filter_tag_det_res(dt_boxes, ori_im.shape) dt_boxes = self.filter_tag_det_res(dt_boxes, ori_im.shape)
......
...@@ -62,6 +62,13 @@ def init_args(): ...@@ -62,6 +62,13 @@ def init_args():
parser.add_argument("--det_sast_nms_thresh", type=float, default=0.2) parser.add_argument("--det_sast_nms_thresh", type=float, default=0.2)
parser.add_argument("--det_sast_polygon", type=bool, default=False) parser.add_argument("--det_sast_polygon", type=bool, default=False)
# PSE params
parser.add_argument("--det_pse_thresh", type=float, default=0)
parser.add_argument("--det_pse_box_thresh", type=float, default=0.85)
parser.add_argument("--det_pse_min_area", type=float, default=16)
parser.add_argument("--det_pse_box_type", type=str, default='poly')
parser.add_argument("--det_pse_scale", type=int, default=1)
# params for text recognizer # params for text recognizer
parser.add_argument("--rec_algorithm", type=str, default='CRNN') parser.add_argument("--rec_algorithm", type=str, default='CRNN')
parser.add_argument("--rec_model_dir", type=str) parser.add_argument("--rec_model_dir", type=str)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册