Unverified · Commit 58a5c0b3 · Author: Double_V · Committer: GitHub

[check_install] (#8177)

* support min_area_rect crop

* add check_install

* fix requirement.txt

* fix check_install

* add lanms-neo for drrg

* fix

* fix doc
Parent 2160b2fd
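
The `check_install` helper added by this commit looks a module up with `importlib.util.find_spec` and, if it is missing, falls back to `python -m pip install`. A minimal standalone sketch of that pattern is shown here. The names `is_installed` and `ensure_installed` and the `installer` hook are illustrative assumptions, not part of the commit; the hook exists only so the pip call can be swapped out when experimenting.

```python
import importlib.util
import subprocess
import sys


def is_installed(module_name):
    """Return True if a top-level module can be located without importing it."""
    return importlib.util.find_spec(module_name) is not None


def ensure_installed(module_name, install_name, installer=None):
    """Install `install_name` only when `module_name` cannot be found.

    `installer` is a hypothetical hook (not in the commit): a callable that
    receives the pip package name. By default it runs `python -m pip install`.
    Returns True if an install was attempted, False if the module was present.
    """
    if is_installed(module_name):
        return False

    if installer is None:
        def installer(name):
            # Use the running interpreter so the package lands in its environment.
            subprocess.check_call([sys.executable, "-m", "pip", "install", name])

    try:
        installer(install_name)
    except subprocess.CalledProcessError as exc:
        # Chain the original error so the pip exit status stays visible.
        raise RuntimeError(
            f"Installing {install_name} failed; please install it manually"
        ) from exc
    return True
```

Called as `ensure_installed("lanms", "lanms-neo")` immediately before the deferred `import`, this mirrors how the diff below moves optional imports such as `Polygon` and `lanms` from module level into the functions that need them.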
......@@ -26,8 +26,8 @@
|Model|Backbone|Configuration|precision|recall|Hmean|Download|
| --- | --- | --- | --- | --- | --- | --- |
|EAST|ResNet50_vd|88.71%| 81.36%| 84.88%| [trained model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/det_r50_vd_east_v2.0_train.tar)|
|EAST| MobileNetV3| 78.20%| 79.10%| 78.65%| [trained model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/det_r50_vd_east_v2.0_train.tar)|
|EAST|ResNet50_vd| [det_r50_vd_east.yml](../../configs/det/det_r50_vd_east.yml)|88.71%| 81.36%| 84.88%| [trained model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/det_r50_vd_east_v2.0_train.tar)|
|EAST|MobileNetV3|[det_mv3_east.yml](../../configs/det/det_mv3_east.yml) | 78.20%| 79.10%| 78.65%| [trained model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/det_mv3_east_v2.0_train.tar)|
<a name="2"></a>
......
......@@ -73,9 +73,9 @@ python3 tools/export_model.py -c configs/det/det_r50_vd_sast_totaltext.yml -o Gl
```
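For SAST text detection model inference, set the parameter `--det_algorithm="SAST"` and also add the parameter `--det_sast_polygon=True`. Run the following command:
For SAST text detection model inference, set the parameter `--det_algorithm="SAST"` and also add the parameter `--det_box_type=poly`. Run the following command: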
```
python3 tools/infer/predict_det.py --det_algorithm="SAST" --image_dir="./doc/imgs_en/img623.jpg" --det_model_dir="./inference/det_sast_tt/" --det_sast_polygon=True
python3 tools/infer/predict_det.py --det_algorithm="SAST" --image_dir="./doc/imgs_en/img623.jpg" --det_model_dir="./inference/det_sast_tt/" --det_box_type='poly'
```
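The visualized text detection results are saved to the `./inference_results` folder by default, and result file names are prefixed with 'det_res'. An example result is shown below: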
......
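......@@ -70,7 +70,7 @@ The SAST algorithm parameters are as follows
| :--: | :--: | :--: | :--: |
| det_sast_score_thresh | float | 0.5 | Score threshold in SAST post-processing |
| det_sast_nms_thresh | float | 0.5 | NMS threshold in SAST post-processing |
| det_sast_polygon | bool | False | Whether to detect polygons; set to True for curved text scenes (such as Total-Text) |
| det_box_type | str | quad | Whether to detect polygons; set to 'poly' for curved text scenes (such as Total-Text) |
The PSE algorithm parameters are as follows
......@@ -79,7 +79,7 @@ The PSE algorithm parameters are as follows
| det_pse_thresh | float | 0.0 | Threshold for binarizing the output map |
| det_pse_box_thresh | float | 0.85 | Threshold for filtering boxes; boxes scoring below it are discarded |
| det_pse_min_area | float | 16 | Minimum box area; boxes smaller than this are discarded |
| det_pse_box_type | str | "box" | Type of the returned box. box: four-point coordinates; poly: all point coordinates of the curved text |
| det_box_type | str | "quad" | Type of the returned box. quad: four-point coordinates; poly: all point coordinates of the curved text |
| det_pse_scale | int | 1 | Ratio of the input image to the map that enters post-processing. For example, for a `640*640` image the network output is `160*160`; with scale set to 2, the map entering post-processing has shape `320*320`. Increasing this value speeds up post-processing but reduces accuracy |
* Text recognition model parameters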
......
......@@ -26,8 +26,9 @@ On the ICDAR2015 dataset, the text detection result is as follows:
|Model|Backbone|Configuration|Precision|Recall|Hmean|Download|
| --- | --- | --- | --- | --- | --- | --- |
|EAST|ResNet50_vd|88.71%| 81.36%| 84.88%| [训练模型](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/det_r50_vd_east_v2.0_train.tar)|
|EAST| MobileNetV3| 78.20%| 79.10%| 78.65%| [训练模型](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/det_r50_vd_east_v2.0_train.tar)|
|EAST|ResNet50_vd| [det_r50_vd_east.yml](../../configs/det/det_r50_vd_east.yml)|88.71%| 81.36%| 84.88%| [model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/det_r50_vd_east_v2.0_train.tar)|
|EAST|MobileNetV3|[det_mv3_east.yml](../../configs/det/det_mv3_east.yml) | 78.20%| 79.10%| 78.65%| [model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/det_mv3_east_v2.0_train.tar)|
<a name="2"></a>
......
......@@ -74,10 +74,10 @@ First, convert the model saved in the SAST text detection training process into
python3 tools/export_model.py -c configs/det/det_r50_vd_sast_totaltext.yml -o Global.pretrained_model=./det_r50_vd_sast_totaltext_v2.0_train/best_accuracy Global.save_inference_dir=./inference/det_sast_tt
```
For SAST curved text detection model inference, you need to set the parameter `--det_algorithm="SAST"` and `--det_sast_polygon=True`, run the following command:
For SAST curved text detection model inference, you need to set the parameter `--det_algorithm="SAST"` and `--det_box_type=poly`, run the following command:
```
python3 tools/infer/predict_det.py --det_algorithm="SAST" --image_dir="./doc/imgs_en/img623.jpg" --det_model_dir="./inference/det_sast_tt/" --det_sast_polygon=True
python3 tools/infer/predict_det.py --det_algorithm="SAST" --image_dir="./doc/imgs_en/img623.jpg" --det_model_dir="./inference/det_sast_tt/" --det_box_type='poly'
```
The visualized text detection results are saved to the `./inference_results` folder by default, and the name of the result file is prefixed with 'det_res'. Examples of results are as follows:
......
......@@ -70,7 +70,7 @@ The relevant parameters of the SAST algorithm are as follows
| :--: | :--: | :--: | :--: |
| det_sast_score_thresh | float | 0.5 | Score thresholds in SAST postprocess |
| det_sast_nms_thresh | float | 0.5 | Thresholding of nms in SAST postprocess |
| det_sast_polygon | bool | False | Whether polygon detection, curved text scene (such as Total-Text) is set to True |
| det_box_type | str | 'quad' | Whether polygon detection, curved text scene (such as Total-Text) is set to 'poly' |
The relevant parameters of the PSE algorithm are as follows
......@@ -79,7 +79,7 @@ The relevant parameters of the PSE algorithm are as follows
| det_pse_thresh | float | 0.0 | Threshold for binarizing the output image |
| det_pse_box_thresh | float | 0.85 | Threshold for filtering boxes, below this threshold is discarded |
| det_pse_min_area | float | 16 | The minimum area of the box, below this threshold is discarded |
| det_pse_box_type | str | "box" | The type of the returned box, box: four point coordinates, poly: all point coordinates of the curved text |
| det_box_type | str | "quad" | The type of the returned box, quad: four point coordinates, poly: all point coordinates of the curved text |
| det_pse_scale | int | 1 | The ratio of the input image relative to the post-processed image, such as an image of `640*640`, the network output is `160*160`, and when the scale is 2, the shape of the post-processed image is `320*320`. Increasing this value can speed up the post-processing speed, but it will bring about a decrease in accuracy |
* Text recognition model related parameters
......
......@@ -19,7 +19,8 @@ import pyclipper
import paddle
import numpy as np
import Polygon as plg
from ppocr.utils.utility import check_install
import scipy.io as scio
from PIL import Image
......@@ -70,6 +71,8 @@ class MakeShrink():
return peri
def shrink(self, bboxes, rate, max_shr=20):
check_install('Polygon', 'Polygon3')
import Polygon as plg
rate = rate * rate
shrinked_bboxes = []
for bbox in bboxes:
......
......@@ -18,7 +18,7 @@ https://github.com/open-mmlab/mmocr/blob/main/mmocr/datasets/pipelines/textdet_t
import cv2
import numpy as np
from lanms import merge_quadrangle_n9 as la_nms
from ppocr.utils.utility import check_install
from numpy.linalg import norm
......@@ -543,6 +543,8 @@ class DRRGTargets(object):
score = np.ones((text_comps.shape[0], 1), dtype=np.float32)
text_comps = np.hstack([text_comps, score])
check_install('lanms', 'lanms-neo')
from lanms import merge_quadrangle_n9 as la_nms
text_comps = la_nms(text_comps, self.text_comp_nms_thr)
if text_comps.shape[0] >= 1:
......
......@@ -22,6 +22,7 @@ import cv2
import paddle
import os
from ppocr.utils.utility import check_install
import sys
......@@ -78,11 +79,11 @@ class EASTPostProcess(object):
boxes[:, 8] = score_map[xy_text[:, 0], xy_text[:, 1]]
try:
check_install('lanms', 'lanms-nova')
import lanms
boxes = lanms.merge_quadrangle_n9(boxes, nms_thresh)
except:
print(
'you should install lanms by pip3 install lanms-nova to speed up nms_locality'
'You should install lanms by pip3 install lanms-nova to speed up nms_locality'
)
boxes = nms_locality(boxes.astype(np.float64), nms_thresh)
if boxes.shape[0] == 0:
......
......@@ -141,6 +141,8 @@ class SASTPostProcess(object):
def nms(self, dets):
if self.is_python35:
from ppocr.utils.utility import check_install
check_install('lanms', 'lanms-nova')
import lanms
dets = lanms.merge_quadrangle_n9(dets, self.nms_thresh)
else:
......
......@@ -15,7 +15,11 @@
import json
import numpy as np
import scipy.io as io
from ppocr.utils.utility import check_install
check_install("Polygon", "Polygon3")
import Polygon as plg
from ppocr.utils.e2e_metric.polygon_fast import iod, area_of_intersection, area
......
......@@ -19,6 +19,9 @@ import cv2
import random
import numpy as np
import paddle
import importlib.util
import sys
import subprocess
def print_dict(d, logger, delimiter=0):
......@@ -131,6 +134,26 @@ def set_seed(seed=1024):
paddle.seed(seed)
def check_install(module_name, install_name):
    spec = importlib.util.find_spec(module_name)
    if spec is None:
        print(f'Warning! The {module_name} module is NOT installed')
        print(
            f'Trying to install the {module_name} module automatically. You can also install it manually with pip install {install_name}.'
        )
        python = sys.executable
        try:
            subprocess.check_call(
                [python, '-m', 'pip', 'install', install_name],
                stdout=subprocess.DEVNULL)
            print(f'The {module_name} module is now installed')
        except subprocess.CalledProcessError as exc:
            raise Exception(
                f"Install {module_name} failed, please install manually")
    else:
        print(f"{module_name} has been installed.")
class AverageMeter:
def __init__(self):
self.reset()
......
......@@ -13,7 +13,5 @@ cython
lxml
premailer
openpyxl
attrdict3
Polygon3
lanms-neo==1.0.2
attrdict
PyMuPDF==1.19.0
\ No newline at end of file