From a6a44ea124406a5960790cee52ea0f079cf8794e Mon Sep 17 00:00:00 2001 From: WenmuZhou Date: Thu, 29 Jul 2021 19:31:30 +0800 Subject: [PATCH] add export model of pse --- doc/doc_ch/algorithm_overview.md | 2 +- doc/doc_en/algorithm_overview_en.md | 2 +- ppocr/modeling/necks/fpn.py | 22 ++++++++++------------ tools/infer/predict_det.py | 12 +++++++++++- tools/infer/utility.py | 7 +++++++ 5 files changed, 30 insertions(+), 15 deletions(-) diff --git a/doc/doc_ch/algorithm_overview.md b/doc/doc_ch/algorithm_overview.md index 9ba50ffe..b0d9b3ef 100755 --- a/doc/doc_ch/algorithm_overview.md +++ b/doc/doc_ch/algorithm_overview.md @@ -24,7 +24,7 @@ PaddleOCR开源的文本检测算法列表: |DB|MobileNetV3|77.29%|73.08%|75.12%|[下载链接](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/det_mv3_db_v2.0_train.tar)| |SAST|ResNet50_vd|91.39%|83.77%|87.42%|[下载链接](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/det_r50_vd_sast_icdar15_v2.0_train.tar)| |PSE|ResNet50_vd|85.81%|79.53%|82.55%|[下载链接](https://paddleocr.bj.bcebos.com/dygraph_v2.1/en_det/det_r50_vd_pse_v2.0_train.tar)| -|PSE|MobileNetV3|82.20%|70.47%|75.89%|[下载链接](https://paddleocr.bj.bcebos.com/dygraph_v2.1/en_det/det_mv3_pse_v2.0_train.tar)| +|PSE|MobileNetV3|82.20%|70.48%|75.89%|[下载链接](https://paddleocr.bj.bcebos.com/dygraph_v2.1/en_det/det_mv3_pse_v2.0_train.tar)| 在Total-text文本检测公开数据集上,算法效果如下: diff --git a/doc/doc_en/algorithm_overview_en.md b/doc/doc_en/algorithm_overview_en.md index a34508f9..592b9ef0 100755 --- a/doc/doc_en/algorithm_overview_en.md +++ b/doc/doc_en/algorithm_overview_en.md @@ -26,7 +26,7 @@ On the ICDAR2015 dataset, the text detection result is as follows: |DB|MobileNetV3|77.29%|73.08%|75.12%|[Download link](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/det_mv3_db_v2.0_train.tar)| |SAST|ResNet50_vd|91.39%|83.77%|87.42%|[Download link](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/det_r50_vd_sast_icdar15_v2.0_train.tar)| |PSE|ResNet50_vd|85.81%|79.53%|82.55%|[Download link](https://paddleocr.bj.bcebos.com/dygraph_v2.1/en_det/det_r50_vd_pse_v2.0_train.tar)| -|PSE|MobileNetV3|82.20%|70.47%|75.89%|[Download link](https://paddleocr.bj.bcebos.com/dygraph_v2.1/en_det/det_mv3_pse_v2.0_train.tar)| +|PSE|MobileNetV3|82.20%|70.48%|75.89%|[Download link](https://paddleocr.bj.bcebos.com/dygraph_v2.1/en_det/det_mv3_pse_v2.0_train.tar)| On Total-Text dataset, the text detection result is as follows: diff --git a/ppocr/modeling/necks/fpn.py b/ppocr/modeling/necks/fpn.py index 49089200..8728a5c9 100644 --- a/ppocr/modeling/necks/fpn.py +++ b/ppocr/modeling/necks/fpn.py @@ -70,33 +70,31 @@ class FPN(nn.Layer): m.bias = paddle.create_parameter(shape=m.bias.shape, dtype='float32', default_initializer=paddle.nn.initializer.Constant(0.0)) - def _upsample(self, x, y, scale=1): - _, _, H, W = y.shape - return F.upsample(x, size=(H // scale, W // scale), mode='bilinear') + def _upsample(self, x, scale=1): + return F.upsample(x, scale_factor=scale, mode='bilinear') - def _upsample_add(self, x, y): - _, _, H, W = y.shape - return F.upsample(x, size=(H, W), mode='bilinear') + y + def _upsample_add(self, x, y, scale=1): + return F.upsample(x, scale_factor=scale, mode='bilinear') + y def forward(self, x): f2, f3, f4, f5 = x p5 = self.toplayer_(f5) f4 = self.latlayer1_(f4) - p4 = self._upsample_add(p5, f4) + p4 = self._upsample_add(p5, f4,2) p4 = self.smooth1_(p4) f3 = self.latlayer2_(f3) - p3 = self._upsample_add(p4, f3) + p3 = self._upsample_add(p4, f3,2) p3 = self.smooth2_(p3) f2 = self.latlayer3_(f2) - p2 = self._upsample_add(p3, f2) + p2 = self._upsample_add(p3, f2,2) p2 = self.smooth3_(p2) - p3 = self._upsample(p3, p2) - p4 = self._upsample(p4, p2) - p5 = self._upsample(p5, p2) + p3 = self._upsample(p3, 2) + p4 = self._upsample(p4, 4) + p5 = self._upsample(p5, 8) fuse = paddle.concat([p2, p3, p4, p5], axis=1) return fuse \ No newline at end of file diff --git a/tools/infer/predict_det.py b/tools/infer/predict_det.py index 6a45f81e..f727b4f5 100755 --- a/tools/infer/predict_det.py +++ b/tools/infer/predict_det.py @@ -89,6 +89,14 @@ class TextDetector(object): postprocess_params["sample_pts_num"] = 2 postprocess_params["expand_scale"] = 1.0 postprocess_params["shrink_ratio_of_width"] = 0.3 + elif self.det_algorithm == "PSE": + postprocess_params['name'] = 'PSEPostProcess' + postprocess_params["thresh"] = args.det_pse_thresh + postprocess_params["box_thresh"] = args.det_pse_box_thresh + postprocess_params["min_area"] = args.det_pse_min_area + postprocess_params["box_type"] = args.det_pse_box_type + postprocess_params["scale"] = args.det_pse_scale + self.det_pse_box_type = args.det_pse_box_type else: logger.info("unknown det_algorithm:{}".format(self.det_algorithm)) sys.exit(0) @@ -207,7 +215,7 @@ class TextDetector(object): preds['f_score'] = outputs[1] preds['f_tco'] = outputs[2] preds['f_tvo'] = outputs[3] - elif self.det_algorithm == 'DB': + elif self.det_algorithm in ['DB','PSE']: preds['maps'] = outputs[0] else: raise NotImplementedError @@ -217,6 +225,8 @@ class TextDetector(object): dt_boxes = post_result[0]['points'] if self.det_algorithm == "SAST" and self.det_sast_polygon: dt_boxes = self.filter_tag_det_res_only_clip(dt_boxes, ori_im.shape) + elif self.det_algorithm == "PSE" and self.det_pse_box_type=='poly': + dt_boxes = self.filter_tag_det_res_only_clip(dt_boxes, ori_im.shape) else: dt_boxes = self.filter_tag_det_res(dt_boxes, ori_im.shape) diff --git a/tools/infer/utility.py b/tools/infer/utility.py index 28e9818b..9bb6136e 100755 --- a/tools/infer/utility.py +++ b/tools/infer/utility.py @@ -62,6 +62,13 @@ def init_args(): parser.add_argument("--det_sast_nms_thresh", type=float, default=0.2) parser.add_argument("--det_sast_polygon", type=bool, default=False) + # PSE parmas + parser.add_argument("--det_pse_thresh", type=float, default=0) + parser.add_argument("--det_pse_box_thresh", type=float, default=0.85) + parser.add_argument("--det_pse_min_area", type=float, default=16) + parser.add_argument("--det_pse_box_type", type=str, default='poly') + parser.add_argument("--det_pse_scale", type=int, default=1) + # params for text recognizer parser.add_argument("--rec_algorithm", type=str, default='CRNN') parser.add_argument("--rec_model_dir", type=str) -- GitLab