未验证 提交 91980595 编写于 作者: D Double_V 提交者: GitHub

Merge pull request #2443 from JetHong/pgnet-postpro

add  Pgnet postpro
...@@ -59,8 +59,10 @@ Optimizer: ...@@ -59,8 +59,10 @@ Optimizer:
PostProcess: PostProcess:
name: PGPostProcess name: PGPostProcess
score_thresh: 0.5 score_thresh: 0.5
mode: fast # fast or slow two ways
Metric: Metric:
name: E2EMetric name: E2EMetric
gt_mat_dir: # the dir of gt_mat
character_dict_path: ppocr/utils/ic15_dict.txt character_dict_path: ppocr/utils/ic15_dict.txt
main_indicator: f_score_e2e main_indicator: f_score_e2e
...@@ -106,7 +108,7 @@ Eval: ...@@ -106,7 +108,7 @@ Eval:
order: 'hwc' order: 'hwc'
- ToCHWImage: - ToCHWImage:
- KeepKeys: - KeepKeys:
keep_keys: [ 'image', 'shape', 'polys', 'strs', 'tags' ] keep_keys: [ 'image', 'shape', 'polys', 'strs', 'tags', 'img_id']
loader: loader:
shuffle: False shuffle: False
drop_last: False drop_last: False
......
...@@ -28,13 +28,10 @@ inference 模型(`paddle.jit.save`保存的模型) ...@@ -28,13 +28,10 @@ inference 模型(`paddle.jit.save`保存的模型)
- [4. 自定义文本识别字典的推理](#自定义文本识别字典的推理) - [4. 自定义文本识别字典的推理](#自定义文本识别字典的推理)
- [5. 多语言模型的推理](#多语言模型的推理) - [5. 多语言模型的推理](#多语言模型的推理)
- [四、端到端模型推理](#端到端模型推理) - [四、方向分类模型推理](#方向识别模型推理)
- [1. PGNet端到端模型推理](#PGNet端到端模型推理)
- [五、方向分类模型推理](#方向识别模型推理)
- [1. 方向分类模型推理](#方向分类模型推理) - [1. 方向分类模型推理](#方向分类模型推理)
- [、文本检测、方向分类和文字识别串联推理](#文本检测、方向分类和文字识别串联推理) - [、文本检测、方向分类和文字识别串联推理](#文本检测、方向分类和文字识别串联推理)
- [1. 超轻量中文OCR模型推理](#超轻量中文OCR模型推理) - [1. 超轻量中文OCR模型推理](#超轻量中文OCR模型推理)
- [2. 其他模型推理](#其他模型推理) - [2. 其他模型推理](#其他模型推理)
...@@ -362,38 +359,8 @@ python3 tools/infer/predict_rec.py --image_dir="./doc/imgs_words/korean/1.jpg" - ...@@ -362,38 +359,8 @@ python3 tools/infer/predict_rec.py --image_dir="./doc/imgs_words/korean/1.jpg" -
Predicts of ./doc/imgs_words/korean/1.jpg:('바탕으로', 0.9948904) Predicts of ./doc/imgs_words/korean/1.jpg:('바탕으로', 0.9948904)
``` ```
<a name="端到端模型推理"></a>
## 四、端到端模型推理
端到端模型推理,默认使用PGNet模型的配置参数。当不使用PGNet模型时,在推理时,需要通过传入相应的参数进行算法适配,细节参考下文。
<a name="PGNet端到端模型推理"></a>
### 1. PGNet端到端模型推理
#### (1). 四边形文本检测模型(ICDAR2015)
首先将PGNet端到端训练过程中保存的模型,转换成inference model。以基于Resnet50_vd骨干网络,在ICDAR2015英文数据集训练的模型为例([模型下载地址](https://paddleocr.bj.bcebos.com/dygraph_v2.0/pgnet/en_server_pgnetA.tar)),可以使用如下命令进行转换:
```
python3 tools/export_model.py -c configs/e2e/e2e_r50_vd_pg.yml -o Global.pretrained_model=./en_server_pgnetA/iter_epoch_450 Global.load_static_weights=False Global.save_inference_dir=./inference/e2e
```
**PGNet端到端模型推理,需要设置参数`--e2e_algorithm="PGNet"`**,可以执行如下命令:
```
python3 tools/infer/predict_e2e.py --e2e_algorithm="PGNet" --image_dir="./doc/imgs_en/img_10.jpg" --e2e_model_dir="./inference/e2e/" --e2e_pgnet_polygon=False
```
可视化文本检测结果默认保存到`./inference_results`文件夹里面,结果文件的名称前缀为'e2e_res'。结果示例如下:
![](../imgs_results/e2e_res_img_10_pgnet.jpg)
#### (2). 弯曲文本检测模型(Total-Text)
和四边形文本检测模型共用一个推理模型
**PGNet端到端模型推理,需要设置参数`--e2e_algorithm="PGNet"`,同时,还需要增加参数`--e2e_pgnet_polygon=True`,**可以执行如下命令:
```
python3.7 tools/infer/predict_e2e.py --e2e_algorithm="PGNet" --image_dir="./doc/imgs_en/img623.jpg" --e2e_model_dir="./inference/e2e/" --e2e_pgnet_polygon=True
```
可视化文本端到端结果默认保存到`./inference_results`文件夹里面,结果文件的名称前缀为'e2e_res'。结果示例如下:
![](../imgs_results/e2e_res_img623_pgnet.jpg)
<a name="方向分类模型推理"></a> <a name="方向分类模型推理"></a>
## 、方向分类模型推理 ## 、方向分类模型推理
下面将介绍方向分类模型推理。 下面将介绍方向分类模型推理。
...@@ -418,7 +385,7 @@ Predicts of ./doc/imgs_words/ch/word_4.jpg:['0', 0.9999982] ...@@ -418,7 +385,7 @@ Predicts of ./doc/imgs_words/ch/word_4.jpg:['0', 0.9999982]
``` ```
<a name="文本检测、方向分类和文字识别串联推理"></a> <a name="文本检测、方向分类和文字识别串联推理"></a>
## 、文本检测、方向分类和文字识别串联推理 ## 、文本检测、方向分类和文字识别串联推理
<a name="超轻量中文OCR模型推理"></a> <a name="超轻量中文OCR模型推理"></a>
### 1. 超轻量中文OCR模型推理 ### 1. 超轻量中文OCR模型推理
......
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
- [一、简介](#简介) - [一、简介](#简介)
- [二、环境配置](#环境配置) - [二、环境配置](#环境配置)
- [三、快速使用](#快速使用) - [三、快速使用](#快速使用)
- [四、模型训练、评估、推理](#快速训练) - [四、模型训练、评估、推理](#模型训练、评估、推理)
<a name="简介"></a> <a name="简介"></a>
## 一、简介 ## 一、简介
...@@ -16,11 +16,13 @@ OCR算法可以分为两阶段算法和端对端的算法。二阶段OCR算法 ...@@ -16,11 +16,13 @@ OCR算法可以分为两阶段算法和端对端的算法。二阶段OCR算法
- 提出基于图的修正模块(GRM)来进一步提高模型识别性能 - 提出基于图的修正模块(GRM)来进一步提高模型识别性能
- 精度更高,预测速度更快 - 精度更高,预测速度更快
PGNet算法细节详见[论文](https://www.aaai.org/AAAI21Papers/AAAI-2885.WangP.pdf)算法原理图如下所示: PGNet算法细节详见[论文](https://www.aaai.org/AAAI21Papers/AAAI-2885.WangP.pdf) ,算法原理图如下所示:
![](../pgnet_framework.png) ![](../pgnet_framework.png)
输入图像经过特征提取送入四个分支,分别是:文本边缘偏移量预测TBO模块,文本中心线预测TCL模块,文本方向偏移量预测TDO模块,以及文本字符分类图预测TCC模块。 输入图像经过特征提取送入四个分支,分别是:文本边缘偏移量预测TBO模块,文本中心线预测TCL模块,文本方向偏移量预测TDO模块,以及文本字符分类图预测TCC模块。
其中TBO以及TCL的输出经过后处理后可以得到文本的检测结果,TCL、TDO、TCC负责文本识别。 其中TBO以及TCL的输出经过后处理后可以得到文本的检测结果,TCL、TDO、TCC负责文本识别。
其检测识别效果图如下: 其检测识别效果图如下:
![](../imgs_results/e2e_res_img293_pgnet.png) ![](../imgs_results/e2e_res_img293_pgnet.png)
![](../imgs_results/e2e_res_img295_pgnet.png) ![](../imgs_results/e2e_res_img295_pgnet.png)
...@@ -49,24 +51,24 @@ wget https://paddleocr.bj.bcebos.com/dygraph_v2.0/pgnet/e2e_server_pgnetA_infer. ...@@ -49,24 +51,24 @@ wget https://paddleocr.bj.bcebos.com/dygraph_v2.0/pgnet/e2e_server_pgnetA_infer.
### 单张图像或者图像集合预测 ### 单张图像或者图像集合预测
```bash ```bash
# 预测image_dir指定的单张图像 # 预测image_dir指定的单张图像
python3 tools/infer/predict_e2e.py --e2e_algorithm="PGNet" --image_dir="./doc/imgs_en/img623.jpg" --e2e_model_dir="./inference/e2e/" --e2e_pgnet_polygon=True python3 tools/infer/predict_e2e.py --e2e_algorithm="PGNet" --image_dir="./doc/imgs_en/img623.jpg" --e2e_model_dir="./inference/e2e_server_pgnetA_infer/" --e2e_pgnet_polygon=True
# 预测image_dir指定的图像集合 # 预测image_dir指定的图像集合
python3 tools/infer/predict_e2e.py --e2e_algorithm="PGNet" --image_dir="./doc/imgs_en/" --e2e_model_dir="./inference/e2e/" --e2e_pgnet_polygon=True python3 tools/infer/predict_e2e.py --e2e_algorithm="PGNet" --image_dir="./doc/imgs_en/" --e2e_model_dir="./inference/e2e_server_pgnetA_infer/" --e2e_pgnet_polygon=True
# 如果想使用CPU进行预测,需设置use_gpu参数为False # 如果想使用CPU进行预测,需设置use_gpu参数为False
python3 tools/infer/predict_e2e.py --e2e_algorithm="PGNet" --image_dir="./doc/imgs_en/img623.jpg" --e2e_model_dir="./inference/e2e/" --e2e_pgnet_polygon=True --use_gpu=False python3 tools/infer/predict_e2e.py --e2e_algorithm="PGNet" --image_dir="./doc/imgs_en/img623.jpg" --e2e_model_dir="./inference/e2e_server_pgnetA_infer/" --e2e_pgnet_polygon=True --use_gpu=False
``` ```
### 可视化结果 ### 可视化结果
可视化文本检测结果默认保存到./inference_results文件夹里面,结果文件的名称前缀为'e2e_res'。结果示例如下: 可视化文本检测结果默认保存到./inference_results文件夹里面,结果文件的名称前缀为'e2e_res'。结果示例如下:
![](../imgs_results/e2e_res_img623_pgnet.jpg) ![](../imgs_results/e2e_res_img623_pgnet.jpg)
<a name="快速训练"></a> <a name="模型训练、评估、推理"></a>
## 四、模型训练、评估、推理 ## 四、模型训练、评估、推理
本节以totaltext数据集为例,介绍PaddleOCR中端到端模型的训练、评估与测试。 本节以totaltext数据集为例,介绍PaddleOCR中端到端模型的训练、评估与测试。
### 准备数据 ### 准备数据
下载解压[totaltext](https://github.com/cs-chan/Total-Text-Dataset/blob/master/Dataset/README.md)数据集到PaddleOCR/train_data/目录,数据集组织结构: 下载解压[totaltext](https://github.com/cs-chan/Total-Text-Dataset/blob/master/Dataset/README.md) 数据集到PaddleOCR/train_data/目录,数据集组织结构:
``` ```
/PaddleOCR/train_data/total_text/train/ /PaddleOCR/train_data/total_text/train/
|- rgb/ # total_text数据集的训练数据 |- rgb/ # total_text数据集的训练数据
...@@ -135,20 +137,20 @@ python3 tools/eval.py -c configs/e2e/e2e_r50_vd_pg.yml -o Global.checkpoints="{ ...@@ -135,20 +137,20 @@ python3 tools/eval.py -c configs/e2e/e2e_r50_vd_pg.yml -o Global.checkpoints="{
### 模型预测 ### 模型预测
测试单张图像的端到端识别效果 测试单张图像的端到端识别效果
```shell ```shell
python3 tools/infer_e2e.py -c configs/e2e/e2e_r50_vd_pg.yml -o Global.infer_img="./doc/imgs_en/img_10.jpg" Global.pretrained_model="./output/det_db/best_accuracy" Global.load_static_weights=false python3 tools/infer_e2e.py -c configs/e2e/e2e_r50_vd_pg.yml -o Global.infer_img="./doc/imgs_en/img_10.jpg" Global.pretrained_model="./output/e2e_pgnet/best_accuracy" Global.load_static_weights=false
``` ```
测试文件夹下所有图像的端到端识别效果 测试文件夹下所有图像的端到端识别效果
```shell ```shell
python3 tools/infer_e2e.py -c configs/e2e/e2e_r50_vd_pg.yml -o Global.infer_img="./doc/imgs_en/" Global.pretrained_model="./output/det_db/best_accuracy" Global.load_static_weights=false python3 tools/infer_e2e.py -c configs/e2e/e2e_r50_vd_pg.yml -o Global.infer_img="./doc/imgs_en/" Global.pretrained_model="./output/e2e_pgnet/best_accuracy" Global.load_static_weights=false
``` ```
### 预测推理 ### 预测推理
#### (1).四边形文本检测模型(ICDAR2015) #### (1). 四边形文本检测模型(ICDAR2015)
首先将PGNet端到端训练过程中保存的模型,转换成inference model。以基于Resnet50_vd骨干网络,以英文数据集训练的模型为例[模型下载地址](https://paddleocr.bj.bcebos.com/dygraph_v2.0/pgnet/en_server_pgnetA.tar) ,可以使用如下命令进行转换: 首先将PGNet端到端训练过程中保存的模型,转换成inference model。以基于Resnet50_vd骨干网络,以英文数据集训练的模型为例[模型下载地址](https://paddleocr.bj.bcebos.com/dygraph_v2.0/pgnet/en_server_pgnetA.tar) ,可以使用如下命令进行转换:
``` ```
wget https://paddleocr.bj.bcebos.com/dygraph_v2.0/pgnet/en_server_pgnetA.tar && tar xf en_server_pgnetA.tar wget https://paddleocr.bj.bcebos.com/dygraph_v2.0/pgnet/en_server_pgnetA.tar && tar xf en_server_pgnetA.tar
python3 tools/export_model.py -c configs/e2e/e2e_r50_vd_pg.yml -o Global.pretrained_model=./en_server_pgnetA/iter_epoch_450 Global.load_static_weights=False Global.save_inference_dir=./inference/e2e python3 tools/export_model.py -c configs/e2e/e2e_r50_vd_pg.yml -o Global.pretrained_model=./en_server_pgnetA/best_accuracy Global.load_static_weights=False Global.save_inference_dir=./inference/e2e
``` ```
**PGNet端到端模型推理,需要设置参数`--e2e_algorithm="PGNet"`**,可以执行如下命令: **PGNet端到端模型推理,需要设置参数`--e2e_algorithm="PGNet"`**,可以执行如下命令:
``` ```
...@@ -158,7 +160,7 @@ python3 tools/infer/predict_e2e.py --e2e_algorithm="PGNet" --image_dir="./doc/im ...@@ -158,7 +160,7 @@ python3 tools/infer/predict_e2e.py --e2e_algorithm="PGNet" --image_dir="./doc/im
![](../imgs_results/e2e_res_img_10_pgnet.jpg) ![](../imgs_results/e2e_res_img_10_pgnet.jpg)
#### (2).弯曲文本检测模型(Total-Text) #### (2). 弯曲文本检测模型(Total-Text)
对于弯曲文本样例 对于弯曲文本样例
**PGNet端到端模型推理,需要设置参数`--e2e_algorithm="PGNet"`,同时,还需要增加参数`--e2e_pgnet_polygon=True`,**可以执行如下命令: **PGNet端到端模型推理,需要设置参数`--e2e_algorithm="PGNet"`,同时,还需要增加参数`--e2e_pgnet_polygon=True`,**可以执行如下命令:
...@@ -168,3 +170,10 @@ python3 tools/infer/predict_e2e.py --e2e_algorithm="PGNet" --image_dir="./doc/im ...@@ -168,3 +170,10 @@ python3 tools/infer/predict_e2e.py --e2e_algorithm="PGNet" --image_dir="./doc/im
可视化文本端到端结果默认保存到`./inference_results`文件夹里面,结果文件的名称前缀为'e2e_res'。结果示例如下: 可视化文本端到端结果默认保存到`./inference_results`文件夹里面,结果文件的名称前缀为'e2e_res'。结果示例如下:
![](../imgs_results/e2e_res_img623_pgnet.jpg) ![](../imgs_results/e2e_res_img623_pgnet.jpg)
#### (3). 性能指标
| |det_precision|det_recall|det_f_score|e2e_precision|e2e_recall|e2e_f_score|FPS (size=640)|
| --- | --- | --- | --- | --- | --- | --- | --- |
|Paper|85.30|86.80|86.1|-|-|61.7|38.20|
|Ours|87.03|82.48|84.69|61.71|58.43|60.03|62.61|
*note:PaddleOCR里的PGNet实现针对预测速度做了优化,在精度下降可接受范围内,可以显著提升端对端预测速度*
...@@ -15,7 +15,7 @@ In recent years, the end-to-end OCR algorithm has been well developed, including ...@@ -15,7 +15,7 @@ In recent years, the end-to-end OCR algorithm has been well developed, including
- A graph based modification module (GRM) is proposed to further improve the performance of model recognition - A graph based modification module (GRM) is proposed to further improve the performance of model recognition
- Higher accuracy and faster prediction speed - Higher accuracy and faster prediction speed
For details of PGNet algorithm, please refer to [paper](https://www.aaai.org/AAAI21Papers/AAAI-2885.WangP.pdf), The schematic diagram of the algorithm is as follows: For details of PGNet algorithm, please refer to [paper](https://www.aaai.org/AAAI21Papers/AAAI-2885.WangP.pdf) ,The schematic diagram of the algorithm is as follows:
![](../pgnet_framework.png) ![](../pgnet_framework.png)
After feature extraction, the input image is sent to four branches: TBO module for text edge offset prediction, TCL module for text centerline prediction, TDO module for text direction offset prediction, and TCC module for text character classification graph prediction. After feature extraction, the input image is sent to four branches: TBO module for text edge offset prediction, TCL module for text centerline prediction, TDO module for text direction offset prediction, and TCC module for text character classification graph prediction.
The output of TBO and TCL can get text detection results after post-processing, and TCL, TDO and TCC are responsible for text recognition. The output of TBO and TCL can get text detection results after post-processing, and TCL, TDO and TCC are responsible for text recognition.
...@@ -49,13 +49,13 @@ After decompression, there should be the following file structure: ...@@ -49,13 +49,13 @@ After decompression, there should be the following file structure:
### Single image or image set prediction ### Single image or image set prediction
```bash ```bash
# Prediction single image specified by image_dir # Prediction single image specified by image_dir
python3 tools/infer/predict_e2e.py --e2e_algorithm="PGNet" --image_dir="./doc/imgs_en/img623.jpg" --e2e_model_dir="./inference/e2e/" --e2e_pgnet_polygon=True python3 tools/infer/predict_e2e.py --e2e_algorithm="PGNet" --image_dir="./doc/imgs_en/img623.jpg" --e2e_model_dir="./inference/e2e_server_pgnetA_infer/" --e2e_pgnet_polygon=True
# Prediction the collection of images specified by image_dir # Prediction the collection of images specified by image_dir
python3 tools/infer/predict_e2e.py --e2e_algorithm="PGNet" --image_dir="./doc/imgs_en/" --e2e_model_dir="./inference/e2e/" --e2e_pgnet_polygon=True python3 tools/infer/predict_e2e.py --e2e_algorithm="PGNet" --image_dir="./doc/imgs_en/" --e2e_model_dir="./inference/e2e_server_pgnetA_infer/" --e2e_pgnet_polygon=True
# If you want to use CPU for prediction, you need to set use_gpu parameter is false # If you want to use CPU for prediction, you need to set use_gpu parameter is false
python3 tools/infer/predict_e2e.py --e2e_algorithm="PGNet" --image_dir="./doc/imgs_en/img623.jpg" --e2e_model_dir="./inference/e2e/" --e2e_pgnet_polygon=True --use_gpu=False python3 tools/infer/predict_e2e.py --e2e_algorithm="PGNet" --image_dir="./doc/imgs_en/img623.jpg" --e2e_model_dir="./inference/e2e_server_pgnetA_infer/" --e2e_pgnet_polygon=True --use_gpu=False
``` ```
### Visualization results ### Visualization results
The visualized end-to-end results are saved to the `./inference_results` folder by default, and the name of the result file is prefixed with 'e2e_res'. Examples of results are as follows: The visualized end-to-end results are saved to the `./inference_results` folder by default, and the name of the result file is prefixed with 'e2e_res'. Examples of results are as follows:
...@@ -141,12 +141,12 @@ python3 tools/eval.py -c configs/e2e/e2e_r50_vd_pg.yml -o Global.checkpoints="{ ...@@ -141,12 +141,12 @@ python3 tools/eval.py -c configs/e2e/e2e_r50_vd_pg.yml -o Global.checkpoints="{
### Model Test ### Model Test
Test the end-to-end result on a single image: Test the end-to-end result on a single image:
```shell ```shell
python3 tools/infer_e2e.py -c configs/e2e/e2e_r50_vd_pg.yml -o Global.infer_img="./doc/imgs_en/img_10.jpg" Global.pretrained_model="./output/det_db/best_accuracy" Global.load_static_weights=false python3 tools/infer_e2e.py -c configs/e2e/e2e_r50_vd_pg.yml -o Global.infer_img="./doc/imgs_en/img_10.jpg" Global.pretrained_model="./output/e2e_pgnet/best_accuracy" Global.load_static_weights=false
``` ```
Test the end-to-end result on all images in the folder: Test the end-to-end result on all images in the folder:
```shell ```shell
python3 tools/infer_e2e.py -c configs/e2e/e2e_r50_vd_pg.yml -o Global.infer_img="./doc/imgs_en/" Global.pretrained_model="./output/det_db/best_accuracy" Global.load_static_weights=false python3 tools/infer_e2e.py -c configs/e2e/e2e_r50_vd_pg.yml -o Global.infer_img="./doc/imgs_en/" Global.pretrained_model="./output/e2e_pgnet/best_accuracy" Global.load_static_weights=false
``` ```
### Model inference ### Model inference
...@@ -154,7 +154,7 @@ python3 tools/infer_e2e.py -c configs/e2e/e2e_r50_vd_pg.yml -o Global.infer_img= ...@@ -154,7 +154,7 @@ python3 tools/infer_e2e.py -c configs/e2e/e2e_r50_vd_pg.yml -o Global.infer_img=
First, convert the model saved in the PGNet end-to-end training process into an inference model. In the first stage of training based on composite dataset, the model of English data set training is taken as an example[model download link](https://paddleocr.bj.bcebos.com/dygraph_v2.0/pgnet/en_server_pgnetA.tar), you can use the following command to convert: First, convert the model saved in the PGNet end-to-end training process into an inference model. In the first stage of training based on composite dataset, the model of English data set training is taken as an example[model download link](https://paddleocr.bj.bcebos.com/dygraph_v2.0/pgnet/en_server_pgnetA.tar), you can use the following command to convert:
``` ```
wget https://paddleocr.bj.bcebos.com/dygraph_v2.0/pgnet/en_server_pgnetA.tar && tar xf en_server_pgnetA.tar wget https://paddleocr.bj.bcebos.com/dygraph_v2.0/pgnet/en_server_pgnetA.tar && tar xf en_server_pgnetA.tar
python3 tools/export_model.py -c configs/e2e/e2e_r50_vd_pg.yml -o Global.pretrained_model=./en_server_pgnetA/iter_epoch_450 Global.load_static_weights=False Global.save_inference_dir=./inference/e2e python3 tools/export_model.py -c configs/e2e/e2e_r50_vd_pg.yml -o Global.pretrained_model=./en_server_pgnetA/best_accuracy Global.load_static_weights=False Global.save_inference_dir=./inference/e2e
``` ```
**For PGNet quadrangle end-to-end model inference, you need to set the parameter `--e2e_algorithm="PGNet"`**, run the following command: **For PGNet quadrangle end-to-end model inference, you need to set the parameter `--e2e_algorithm="PGNet"`**, run the following command:
``` ```
...@@ -173,3 +173,9 @@ python3 tools/infer/predict_e2e.py --e2e_algorithm="PGNet" --image_dir="./doc/im ...@@ -173,3 +173,9 @@ python3 tools/infer/predict_e2e.py --e2e_algorithm="PGNet" --image_dir="./doc/im
The visualized text detection results are saved to the `./inference_results` folder by default, and the name of the result file is prefixed with 'e2e_res'. Examples of results are as follows: The visualized text detection results are saved to the `./inference_results` folder by default, and the name of the result file is prefixed with 'e2e_res'. Examples of results are as follows:
![](../imgs_results/e2e_res_img623_pgnet.jpg) ![](../imgs_results/e2e_res_img623_pgnet.jpg)
#### (3). Performance
| |det_precision|det_recall|det_f_score|e2e_precision|e2e_recall|e2e_f_score|FPS (size=640)|
| --- | --- | --- | --- | --- | --- | --- | --- |
|Paper|85.30|86.80|86.1|-|-|61.7|38.20|
|Ours|87.03|82.48|84.69|61.71|58.43|60.03|62.61|
*note:PGNet in PaddleOCR optimizes the prediction speed, and can significantly improve the end-to-end prediction speed within the acceptable range of accuracy reduction*
...@@ -200,18 +200,16 @@ class E2ELabelEncode(BaseRecLabelEncode): ...@@ -200,18 +200,16 @@ class E2ELabelEncode(BaseRecLabelEncode):
self.pad_num = len(self.dict) # the length to pad self.pad_num = len(self.dict) # the length to pad
def __call__(self, data): def __call__(self, data):
text_label_index_list, temp_text = [], []
texts = data['strs'] texts = data['strs']
temp_texts = []
for text in texts: for text in texts:
text = text.lower() text = text.lower()
temp_text = [] text = self.encode(text)
for c_ in text: if text is None:
if c_ in self.dict: return None
temp_text.append(self.dict[c_]) text = text + [self.pad_num] * (self.max_text_len - len(text))
temp_text = temp_text + [self.pad_num] * (self.max_text_len - temp_texts.append(text)
len(temp_text)) data['strs'] = np.array(temp_texts)
text_label_index_list.append(temp_text)
data['strs'] = np.array(text_label_index_list)
return data return data
......
...@@ -64,9 +64,6 @@ class PGDataSet(Dataset): ...@@ -64,9 +64,6 @@ class PGDataSet(Dataset):
for line in f.readlines(): for line in f.readlines():
poly_str, txt = line.strip().split('\t') poly_str, txt = line.strip().split('\t')
poly = list(map(float, poly_str.split(','))) poly = list(map(float, poly_str.split(',')))
if self.mode.lower() == "eval":
while len(poly) < 100:
poly.append(-1)
text_polys.append( text_polys.append(
np.array( np.array(
poly, dtype=np.float32).reshape(-1, 2)) poly, dtype=np.float32).reshape(-1, 2))
...@@ -139,10 +136,6 @@ class PGDataSet(Dataset): ...@@ -139,10 +136,6 @@ class PGDataSet(Dataset):
try: try:
if self.data_format == 'icdar': if self.data_format == 'icdar':
im_path = os.path.join(data_path, 'rgb', data_line) im_path = os.path.join(data_path, 'rgb', data_line)
if self.mode.lower() == "eval":
poly_path = os.path.join(data_path, 'poly_gt',
data_line.split('.')[0] + '.txt')
else:
poly_path = os.path.join(data_path, 'poly', poly_path = os.path.join(data_path, 'poly',
data_line.split('.')[0] + '.txt') data_line.split('.')[0] + '.txt')
text_polys, text_tags, text_strs = self.extract_polys(poly_path) text_polys, text_tags, text_strs = self.extract_polys(poly_path)
...@@ -150,12 +143,14 @@ class PGDataSet(Dataset): ...@@ -150,12 +143,14 @@ class PGDataSet(Dataset):
image_dir = os.path.join(os.path.dirname(data_path), 'image') image_dir = os.path.join(os.path.dirname(data_path), 'image')
im_path, text_polys, text_tags, text_strs = self.extract_info_textnet( im_path, text_polys, text_tags, text_strs = self.extract_info_textnet(
data_line, image_dir) data_line, image_dir)
img_id = int(data_line.split(".")[0][3:])
data = { data = {
'img_path': im_path, 'img_path': im_path,
'polys': text_polys, 'polys': text_polys,
'tags': text_tags, 'tags': text_tags,
'strs': text_strs 'strs': text_strs,
'img_id': img_id
} }
with open(data['img_path'], 'rb') as f: with open(data['img_path'], 'rb') as f:
img = f.read() img = f.read()
......
...@@ -19,57 +19,28 @@ from __future__ import print_function ...@@ -19,57 +19,28 @@ from __future__ import print_function
__all__ = ['E2EMetric'] __all__ = ['E2EMetric']
from ppocr.utils.e2e_metric.Deteval import get_socre, combine_results from ppocr.utils.e2e_metric.Deteval import get_socre, combine_results
from ppocr.utils.e2e_utils.extract_textpoint import get_dict from ppocr.utils.e2e_utils.extract_textpoint_slow import get_dict
class E2EMetric(object): class E2EMetric(object):
def __init__(self, def __init__(self,
gt_mat_dir,
character_dict_path, character_dict_path,
main_indicator='f_score_e2e', main_indicator='f_score_e2e',
**kwargs): **kwargs):
self.gt_mat_dir = gt_mat_dir
self.label_list = get_dict(character_dict_path) self.label_list = get_dict(character_dict_path)
self.max_index = len(self.label_list) self.max_index = len(self.label_list)
self.main_indicator = main_indicator self.main_indicator = main_indicator
self.reset() self.reset()
def __call__(self, preds, batch, **kwargs): def __call__(self, preds, batch, **kwargs):
temp_gt_polyons_batch = batch[2] img_id = batch[5][0]
temp_gt_strs_batch = batch[3]
ignore_tags_batch = batch[4]
gt_polyons_batch = []
gt_strs_batch = []
temp_gt_polyons_batch = temp_gt_polyons_batch[0].tolist()
for temp_list in temp_gt_polyons_batch:
t = []
for index in temp_list:
if index[0] != -1 and index[1] != -1:
t.append(index)
gt_polyons_batch.append(t)
temp_gt_strs_batch = temp_gt_strs_batch[0].tolist()
for temp_list in temp_gt_strs_batch:
t = ""
for index in temp_list:
if index < self.max_index:
t += self.label_list[index]
gt_strs_batch.append(t)
for pred, gt_polyons, gt_strs, ignore_tags in zip(
[preds], [gt_polyons_batch], [gt_strs_batch], ignore_tags_batch):
# prepare gt
gt_info_list = [{
'points': gt_polyon,
'text': gt_str,
'ignore': ignore_tag
} for gt_polyon, gt_str, ignore_tag in
zip(gt_polyons, gt_strs, ignore_tags)]
# prepare det
e2e_info_list = [{ e2e_info_list = [{
'points': det_polyon, 'points': det_polyon,
'text': pred_str 'text': pred_str
} for det_polyon, pred_str in zip(pred['points'], pred['strs'])] } for det_polyon, pred_str in zip(preds['points'], preds['strs'])]
result = get_socre(gt_info_list, e2e_info_list) result = get_socre(self.gt_mat_dir, img_id, e2e_info_list)
self.results.append(result) self.results.append(result)
def get_metric(self): def get_metric(self):
......
...@@ -22,10 +22,7 @@ import sys ...@@ -22,10 +22,7 @@ import sys
__dir__ = os.path.dirname(__file__) __dir__ = os.path.dirname(__file__)
sys.path.append(__dir__) sys.path.append(__dir__)
sys.path.append(os.path.join(__dir__, '..')) sys.path.append(os.path.join(__dir__, '..'))
from ppocr.utils.e2e_utils.pgnet_pp_utils import PGNet_PostProcess
from ppocr.utils.e2e_utils.extract_textpoint import *
from ppocr.utils.e2e_utils.visual import *
import paddle
class PGPostProcess(object): class PGPostProcess(object):
...@@ -33,10 +30,12 @@ class PGPostProcess(object): ...@@ -33,10 +30,12 @@ class PGPostProcess(object):
The post process for PGNet. The post process for PGNet.
""" """
def __init__(self, character_dict_path, valid_set, score_thresh, **kwargs): def __init__(self, character_dict_path, valid_set, score_thresh, mode,
self.Lexicon_Table = get_dict(character_dict_path) **kwargs):
self.character_dict_path = character_dict_path
self.valid_set = valid_set self.valid_set = valid_set
self.score_thresh = score_thresh self.score_thresh = score_thresh
self.mode = mode
# c++ la-nms is faster, but only support python 3.5 # c++ la-nms is faster, but only support python 3.5
self.is_python35 = False self.is_python35 = False
...@@ -44,112 +43,10 @@ class PGPostProcess(object): ...@@ -44,112 +43,10 @@ class PGPostProcess(object):
self.is_python35 = True self.is_python35 = True
def __call__(self, outs_dict, shape_list): def __call__(self, outs_dict, shape_list):
p_score = outs_dict['f_score'] post = PGNet_PostProcess(self.character_dict_path, self.valid_set,
p_border = outs_dict['f_border'] self.score_thresh, outs_dict, shape_list)
p_char = outs_dict['f_char'] if self.mode == 'fast':
p_direction = outs_dict['f_direction'] data = post.pg_postprocess_fast()
if isinstance(p_score, paddle.Tensor):
p_score = p_score[0].numpy()
p_border = p_border[0].numpy()
p_direction = p_direction[0].numpy()
p_char = p_char[0].numpy()
else:
p_score = p_score[0]
p_border = p_border[0]
p_direction = p_direction[0]
p_char = p_char[0]
src_h, src_w, ratio_h, ratio_w = shape_list[0]
is_curved = self.valid_set == "totaltext"
instance_yxs_list = generate_pivot_list(
p_score,
p_char,
p_direction,
score_thresh=self.score_thresh,
is_backbone=True,
is_curved=is_curved)
p_char = paddle.to_tensor(np.expand_dims(p_char, axis=0))
char_seq_idx_set = []
for i in range(len(instance_yxs_list)):
gather_info_lod = paddle.to_tensor(instance_yxs_list[i])
f_char_map = paddle.transpose(p_char, [0, 2, 3, 1])
feature_seq = paddle.gather_nd(f_char_map, gather_info_lod)
feature_seq = np.expand_dims(feature_seq.numpy(), axis=0)
feature_len = [len(feature_seq[0])]
featyre_seq = paddle.to_tensor(feature_seq)
feature_len = np.array([feature_len]).astype(np.int64)
length = paddle.to_tensor(feature_len)
seq_pred = paddle.fluid.layers.ctc_greedy_decoder(
input=featyre_seq, blank=36, input_length=length)
seq_pred_str = seq_pred[0].numpy().tolist()[0]
seq_len = seq_pred[1].numpy()[0][0]
temp_t = []
for c in seq_pred_str[:seq_len]:
temp_t.append(c)
char_seq_idx_set.append(temp_t)
seq_strs = []
for char_idx_set in char_seq_idx_set:
pr_str = ''.join([self.Lexicon_Table[pos] for pos in char_idx_set])
seq_strs.append(pr_str)
poly_list = []
keep_str_list = []
all_point_list = []
all_point_pair_list = []
for yx_center_line, keep_str in zip(instance_yxs_list, seq_strs):
if len(yx_center_line) == 1:
yx_center_line.append(yx_center_line[-1])
offset_expand = 1.0
if self.valid_set == 'totaltext':
offset_expand = 1.2
point_pair_list = []
for batch_id, y, x in yx_center_line:
offset = p_border[:, y, x].reshape(2, 2)
if offset_expand != 1.0:
offset_length = np.linalg.norm(
offset, axis=1, keepdims=True)
expand_length = np.clip(
offset_length * (offset_expand - 1),
a_min=0.5,
a_max=3.0)
offset_detal = offset / offset_length * expand_length
offset = offset + offset_detal
ori_yx = np.array([y, x], dtype=np.float32)
point_pair = (ori_yx + offset)[:, ::-1] * 4.0 / np.array(
[ratio_w, ratio_h]).reshape(-1, 2)
point_pair_list.append(point_pair)
all_point_list.append([
int(round(x * 4.0 / ratio_w)),
int(round(y * 4.0 / ratio_h))
])
all_point_pair_list.append(point_pair.round().astype(np.int32)
.tolist())
detected_poly, pair_length_info = point_pair2poly(point_pair_list)
detected_poly = expand_poly_along_width(
detected_poly, shrink_ratio_of_width=0.2)
detected_poly[:, 0] = np.clip(
detected_poly[:, 0], a_min=0, a_max=src_w)
detected_poly[:, 1] = np.clip(
detected_poly[:, 1], a_min=0, a_max=src_h)
if len(keep_str) < 2:
continue
keep_str_list.append(keep_str)
if self.valid_set == 'partvgg':
middle_point = len(detected_poly) // 2
detected_poly = detected_poly[
[0, middle_point - 1, middle_point, -1], :]
poly_list.append(detected_poly)
elif self.valid_set == 'totaltext':
poly_list.append(detected_poly)
else: else:
print('--> Not supported format.') data = post.pg_postprocess_slow()
exit(-1)
data = {
'points': poly_list,
'strs': keep_str_list,
}
return data return data
...@@ -13,10 +13,11 @@ ...@@ -13,10 +13,11 @@
# limitations under the License. # limitations under the License.
import numpy as np import numpy as np
import scipy.io as io
from ppocr.utils.e2e_metric.polygon_fast import iod, area_of_intersection, area from ppocr.utils.e2e_metric.polygon_fast import iod, area_of_intersection, area
def get_socre(gt_dict, pred_dict): def get_socre(gt_dir, img_id, pred_dict):
allInputs = 1 allInputs = 1
def input_reading_mod(pred_dict): def input_reading_mod(pred_dict):
...@@ -30,31 +31,9 @@ def get_socre(gt_dict, pred_dict): ...@@ -30,31 +31,9 @@ def get_socre(gt_dict, pred_dict):
det.append([point, text]) det.append([point, text])
return det return det
def gt_reading_mod(gt_dict): def gt_reading_mod(gt_dir, gt_id):
"""This helper reads groundtruths from mat files""" gt = io.loadmat('%s/poly_gt_img%s.mat' % (gt_dir, gt_id))
gt = [] gt = gt['polygt']
n = len(gt_dict)
for i in range(n):
points = gt_dict[i]['points']
h = len(points)
text = gt_dict[i]['text']
xx = [
np.array(
['x:'], dtype='<U2'), 0, np.array(
['y:'], dtype='<U2'), 0, np.array(
['#'], dtype='<U1'), np.array(
['#'], dtype='<U1')
]
t_x, t_y = [], []
for j in range(h):
t_x.append(points[j][0])
t_y.append(points[j][1])
xx[1] = np.array([t_x], dtype='int16')
xx[3] = np.array([t_y], dtype='int16')
if text != "" and "#" not in text:
xx[4] = np.array([text], dtype='U{}'.format(len(text)))
xx[5] = np.array(['c'], dtype='<U1')
gt.append(xx)
return gt return gt
def detection_filtering(detections, groundtruths, threshold=0.5): def detection_filtering(detections, groundtruths, threshold=0.5):
...@@ -101,7 +80,7 @@ def get_socre(gt_dict, pred_dict): ...@@ -101,7 +80,7 @@ def get_socre(gt_dict, pred_dict):
input_id != 'Deteval_result.txt') and (input_id != 'Deteval_result_curved.txt') \ input_id != 'Deteval_result.txt') and (input_id != 'Deteval_result_curved.txt') \
and (input_id != 'Deteval_result_non_curved.txt'): and (input_id != 'Deteval_result_non_curved.txt'):
detections = input_reading_mod(pred_dict) detections = input_reading_mod(pred_dict)
groundtruths = gt_reading_mod(gt_dict) groundtruths = gt_reading_mod(gt_dir, img_id).tolist()
detections = detection_filtering( detections = detection_filtering(
detections, detections,
groundtruths) # filters detections overlapping with DC area groundtruths) # filters detections overlapping with DC area
......
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Contains various CTC decoders."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import cv2
import math
import numpy as np
from itertools import groupby
from skimage.morphology._skeletonize import thin
def get_dict(character_dict_path):
character_str = ""
with open(character_dict_path, "rb") as fin:
lines = fin.readlines()
for line in lines:
line = line.decode('utf-8').strip("\n").strip("\r\n")
character_str += line
dict_character = list(character_str)
return dict_character
def softmax(logits):
"""
logits: N x d
"""
max_value = np.max(logits, axis=1, keepdims=True)
exp = np.exp(logits - max_value)
exp_sum = np.sum(exp, axis=1, keepdims=True)
dist = exp / exp_sum
return dist
def get_keep_pos_idxs(labels, remove_blank=None):
"""
Remove duplicate and get pos idxs of keep items.
The value of keep_blank should be [None, 95].
"""
duplicate_len_list = []
keep_pos_idx_list = []
keep_char_idx_list = []
for k, v_ in groupby(labels):
current_len = len(list(v_))
if k != remove_blank:
current_idx = int(sum(duplicate_len_list) + current_len // 2)
keep_pos_idx_list.append(current_idx)
keep_char_idx_list.append(k)
duplicate_len_list.append(current_len)
return keep_char_idx_list, keep_pos_idx_list
def remove_blank(labels, blank=0):
new_labels = [x for x in labels if x != blank]
return new_labels
def insert_blank(labels, blank=0):
new_labels = [blank]
for l in labels:
new_labels += [l, blank]
return new_labels
def ctc_greedy_decoder(probs_seq, blank=95, keep_blank_in_idxs=True):
"""
CTC greedy (best path) decoder.
"""
raw_str = np.argmax(np.array(probs_seq), axis=1)
remove_blank_in_pos = None if keep_blank_in_idxs else blank
dedup_str, keep_idx_list = get_keep_pos_idxs(
raw_str, remove_blank=remove_blank_in_pos)
dst_str = remove_blank(dedup_str, blank=blank)
return dst_str, keep_idx_list
def instance_ctc_greedy_decoder(gather_info, logits_map, pts_num=4):
_, _, C = logits_map.shape
ys, xs = zip(*gather_info)
logits_seq = logits_map[list(ys), list(xs)]
probs_seq = logits_seq
labels = np.argmax(probs_seq, axis=1)
dst_str = [k for k, v_ in groupby(labels) if k != C - 1]
detal = len(gather_info) // (pts_num - 1)
keep_idx_list = [0] + [detal * (i + 1) for i in range(pts_num - 2)] + [-1]
keep_gather_list = [gather_info[idx] for idx in keep_idx_list]
return dst_str, keep_gather_list
def ctc_decoder_for_image(gather_info_list,
logits_map,
Lexicon_Table,
pts_num=6):
"""
CTC decoder using multiple processes.
"""
decoder_str = []
decoder_xys = []
for gather_info in gather_info_list:
if len(gather_info) < pts_num:
continue
dst_str, xys_list = instance_ctc_greedy_decoder(
gather_info, logits_map, pts_num=pts_num)
dst_str_readable = ''.join([Lexicon_Table[idx] for idx in dst_str])
if len(dst_str_readable) < 2:
continue
decoder_str.append(dst_str_readable)
decoder_xys.append(xys_list)
return decoder_str, decoder_xys
def sort_with_direction(pos_list, f_direction):
"""
f_direction: h x w x 2
pos_list: [[y, x], [y, x], [y, x] ...]
"""
def sort_part_with_direction(pos_list, point_direction):
pos_list = np.array(pos_list).reshape(-1, 2)
point_direction = np.array(point_direction).reshape(-1, 2)
average_direction = np.mean(point_direction, axis=0, keepdims=True)
pos_proj_leng = np.sum(pos_list * average_direction, axis=1)
sorted_list = pos_list[np.argsort(pos_proj_leng)].tolist()
sorted_direction = point_direction[np.argsort(pos_proj_leng)].tolist()
return sorted_list, sorted_direction
pos_list = np.array(pos_list).reshape(-1, 2)
point_direction = f_direction[pos_list[:, 0], pos_list[:, 1]] # x, y
point_direction = point_direction[:, ::-1] # x, y -> y, x
sorted_point, sorted_direction = sort_part_with_direction(pos_list,
point_direction)
point_num = len(sorted_point)
if point_num >= 16:
middle_num = point_num // 2
first_part_point = sorted_point[:middle_num]
first_point_direction = sorted_direction[:middle_num]
sorted_fist_part_point, sorted_fist_part_direction = sort_part_with_direction(
first_part_point, first_point_direction)
last_part_point = sorted_point[middle_num:]
last_point_direction = sorted_direction[middle_num:]
sorted_last_part_point, sorted_last_part_direction = sort_part_with_direction(
last_part_point, last_point_direction)
sorted_point = sorted_fist_part_point + sorted_last_part_point
sorted_direction = sorted_fist_part_direction + sorted_last_part_direction
return sorted_point, np.array(sorted_direction)
def add_id(pos_list, image_id=0):
"""
Add id for gather feature, for inference.
"""
new_list = []
for item in pos_list:
new_list.append((image_id, item[0], item[1]))
return new_list
def sort_and_expand_with_direction(pos_list, f_direction):
"""
f_direction: h x w x 2
pos_list: [[y, x], [y, x], [y, x] ...]
"""
h, w, _ = f_direction.shape
sorted_list, point_direction = sort_with_direction(pos_list, f_direction)
point_num = len(sorted_list)
sub_direction_len = max(point_num // 3, 2)
left_direction = point_direction[:sub_direction_len, :]
right_dirction = point_direction[point_num - sub_direction_len:, :]
left_average_direction = -np.mean(left_direction, axis=0, keepdims=True)
left_average_len = np.linalg.norm(left_average_direction)
left_start = np.array(sorted_list[0])
left_step = left_average_direction / (left_average_len + 1e-6)
right_average_direction = np.mean(right_dirction, axis=0, keepdims=True)
right_average_len = np.linalg.norm(right_average_direction)
right_step = right_average_direction / (right_average_len + 1e-6)
right_start = np.array(sorted_list[-1])
append_num = max(
int((left_average_len + right_average_len) / 2.0 * 0.15), 1)
left_list = []
right_list = []
for i in range(append_num):
ly, lx = np.round(left_start + left_step * (i + 1)).flatten().astype(
'int32').tolist()
if ly < h and lx < w and (ly, lx) not in left_list:
left_list.append((ly, lx))
ry, rx = np.round(right_start + right_step * (i + 1)).flatten().astype(
'int32').tolist()
if ry < h and rx < w and (ry, rx) not in right_list:
right_list.append((ry, rx))
all_list = left_list[::-1] + sorted_list + right_list
return all_list
def sort_and_expand_with_direction_v2(pos_list, f_direction, binary_tcl_map):
"""
f_direction: h x w x 2
pos_list: [[y, x], [y, x], [y, x] ...]
binary_tcl_map: h x w
"""
h, w, _ = f_direction.shape
sorted_list, point_direction = sort_with_direction(pos_list, f_direction)
point_num = len(sorted_list)
sub_direction_len = max(point_num // 3, 2)
left_direction = point_direction[:sub_direction_len, :]
right_dirction = point_direction[point_num - sub_direction_len:, :]
left_average_direction = -np.mean(left_direction, axis=0, keepdims=True)
left_average_len = np.linalg.norm(left_average_direction)
left_start = np.array(sorted_list[0])
left_step = left_average_direction / (left_average_len + 1e-6)
right_average_direction = np.mean(right_dirction, axis=0, keepdims=True)
right_average_len = np.linalg.norm(right_average_direction)
right_step = right_average_direction / (right_average_len + 1e-6)
right_start = np.array(sorted_list[-1])
append_num = max(
int((left_average_len + right_average_len) / 2.0 * 0.15), 1)
max_append_num = 2 * append_num
left_list = []
right_list = []
for i in range(max_append_num):
ly, lx = np.round(left_start + left_step * (i + 1)).flatten().astype(
'int32').tolist()
if ly < h and lx < w and (ly, lx) not in left_list:
if binary_tcl_map[ly, lx] > 0.5:
left_list.append((ly, lx))
else:
break
for i in range(max_append_num):
ry, rx = np.round(right_start + right_step * (i + 1)).flatten().astype(
'int32').tolist()
if ry < h and rx < w and (ry, rx) not in right_list:
if binary_tcl_map[ry, rx] > 0.5:
right_list.append((ry, rx))
else:
break
all_list = left_list[::-1] + sorted_list + right_list
return all_list
def point_pair2poly(point_pair_list):
"""
Transfer vertical point_pairs into poly point in clockwise.
"""
point_num = len(point_pair_list) * 2
point_list = [0] * point_num
for idx, point_pair in enumerate(point_pair_list):
point_list[idx] = point_pair[0]
point_list[point_num - 1 - idx] = point_pair[1]
return np.array(point_list).reshape(-1, 2)
def shrink_quad_along_width(quad, begin_width_ratio=0., end_width_ratio=1.):
ratio_pair = np.array(
[[begin_width_ratio], [end_width_ratio]], dtype=np.float32)
p0_1 = quad[0] + (quad[1] - quad[0]) * ratio_pair
p3_2 = quad[3] + (quad[2] - quad[3]) * ratio_pair
return np.array([p0_1[0], p0_1[1], p3_2[1], p3_2[0]])
def expand_poly_along_width(poly, shrink_ratio_of_width=0.3):
"""
expand poly along width.
"""
point_num = poly.shape[0]
left_quad = np.array(
[poly[0], poly[1], poly[-2], poly[-1]], dtype=np.float32)
left_ratio = -shrink_ratio_of_width * np.linalg.norm(left_quad[0] - left_quad[3]) / \
(np.linalg.norm(left_quad[0] - left_quad[1]) + 1e-6)
left_quad_expand = shrink_quad_along_width(left_quad, left_ratio, 1.0)
right_quad = np.array(
[
poly[point_num // 2 - 2], poly[point_num // 2 - 1],
poly[point_num // 2], poly[point_num // 2 + 1]
],
dtype=np.float32)
right_ratio = 1.0 + shrink_ratio_of_width * np.linalg.norm(right_quad[0] - right_quad[3]) / \
(np.linalg.norm(right_quad[0] - right_quad[1]) + 1e-6)
right_quad_expand = shrink_quad_along_width(right_quad, 0.0, right_ratio)
poly[0] = left_quad_expand[0]
poly[-1] = left_quad_expand[-1]
poly[point_num // 2 - 1] = right_quad_expand[1]
poly[point_num // 2] = right_quad_expand[2]
return poly
def restore_poly(instance_yxs_list, seq_strs, p_border, ratio_w, ratio_h, src_w,
src_h, valid_set):
poly_list = []
keep_str_list = []
for yx_center_line, keep_str in zip(instance_yxs_list, seq_strs):
if len(keep_str) < 2:
print('--> too short, {}'.format(keep_str))
continue
offset_expand = 1.0
if valid_set == 'totaltext':
offset_expand = 1.2
point_pair_list = []
for y, x in yx_center_line:
offset = p_border[:, y, x].reshape(2, 2) * offset_expand
ori_yx = np.array([y, x], dtype=np.float32)
point_pair = (ori_yx + offset)[:, ::-1] * 4.0 / np.array(
[ratio_w, ratio_h]).reshape(-1, 2)
point_pair_list.append(point_pair)
detected_poly = point_pair2poly(point_pair_list)
detected_poly = expand_poly_along_width(
detected_poly, shrink_ratio_of_width=0.2)
detected_poly[:, 0] = np.clip(detected_poly[:, 0], a_min=0, a_max=src_w)
detected_poly[:, 1] = np.clip(detected_poly[:, 1], a_min=0, a_max=src_h)
keep_str_list.append(keep_str)
if valid_set == 'partvgg':
middle_point = len(detected_poly) // 2
detected_poly = detected_poly[
[0, middle_point - 1, middle_point, -1], :]
poly_list.append(detected_poly)
elif valid_set == 'totaltext':
poly_list.append(detected_poly)
else:
print('--> Not supported format.')
exit(-1)
return poly_list, keep_str_list
def generate_pivot_list_fast(p_score,
p_char_maps,
f_direction,
Lexicon_Table,
score_thresh=0.5):
"""
return center point and end point of TCL instance; filter with the char maps;
"""
p_score = p_score[0]
f_direction = f_direction.transpose(1, 2, 0)
p_tcl_map = (p_score > score_thresh) * 1.0
skeleton_map = thin(p_tcl_map.astype(np.uint8))
instance_count, instance_label_map = cv2.connectedComponents(
skeleton_map.astype(np.uint8), connectivity=8)
# get TCL Instance
all_pos_yxs = []
if instance_count > 0:
for instance_id in range(1, instance_count):
pos_list = []
ys, xs = np.where(instance_label_map == instance_id)
pos_list = list(zip(ys, xs))
if len(pos_list) < 3:
continue
pos_list_sorted = sort_and_expand_with_direction_v2(
pos_list, f_direction, p_tcl_map)
all_pos_yxs.append(pos_list_sorted)
p_char_maps = p_char_maps.transpose([1, 2, 0])
decoded_str, keep_yxs_list = ctc_decoder_for_image(
all_pos_yxs, logits_map=p_char_maps, Lexicon_Table=Lexicon_Table)
return keep_yxs_list, decoded_str
def extract_main_direction(pos_list, f_direction):
"""
f_direction: h x w x 2
pos_list: [[y, x], [y, x], [y, x] ...]
"""
pos_list = np.array(pos_list)
point_direction = f_direction[pos_list[:, 0], pos_list[:, 1]]
point_direction = point_direction[:, ::-1] # x, y -> y, x
average_direction = np.mean(point_direction, axis=0, keepdims=True)
average_direction = average_direction / (
np.linalg.norm(average_direction) + 1e-6)
return average_direction
def sort_by_direction_with_image_id_deprecated(pos_list, f_direction):
"""
f_direction: h x w x 2
pos_list: [[id, y, x], [id, y, x], [id, y, x] ...]
"""
pos_list_full = np.array(pos_list).reshape(-1, 3)
pos_list = pos_list_full[:, 1:]
point_direction = f_direction[pos_list[:, 0], pos_list[:, 1]] # x, y
point_direction = point_direction[:, ::-1] # x, y -> y, x
average_direction = np.mean(point_direction, axis=0, keepdims=True)
pos_proj_leng = np.sum(pos_list * average_direction, axis=1)
sorted_list = pos_list_full[np.argsort(pos_proj_leng)].tolist()
return sorted_list
def sort_by_direction_with_image_id(pos_list, f_direction):
"""
f_direction: h x w x 2
pos_list: [[y, x], [y, x], [y, x] ...]
"""
def sort_part_with_direction(pos_list_full, point_direction):
pos_list_full = np.array(pos_list_full).reshape(-1, 3)
pos_list = pos_list_full[:, 1:]
point_direction = np.array(point_direction).reshape(-1, 2)
average_direction = np.mean(point_direction, axis=0, keepdims=True)
pos_proj_leng = np.sum(pos_list * average_direction, axis=1)
sorted_list = pos_list_full[np.argsort(pos_proj_leng)].tolist()
sorted_direction = point_direction[np.argsort(pos_proj_leng)].tolist()
return sorted_list, sorted_direction
pos_list = np.array(pos_list).reshape(-1, 3)
point_direction = f_direction[pos_list[:, 1], pos_list[:, 2]] # x, y
point_direction = point_direction[:, ::-1] # x, y -> y, x
sorted_point, sorted_direction = sort_part_with_direction(pos_list,
point_direction)
point_num = len(sorted_point)
if point_num >= 16:
middle_num = point_num // 2
first_part_point = sorted_point[:middle_num]
first_point_direction = sorted_direction[:middle_num]
sorted_fist_part_point, sorted_fist_part_direction = sort_part_with_direction(
first_part_point, first_point_direction)
last_part_point = sorted_point[middle_num:]
last_point_direction = sorted_direction[middle_num:]
sorted_last_part_point, sorted_last_part_direction = sort_part_with_direction(
last_part_point, last_point_direction)
sorted_point = sorted_fist_part_point + sorted_last_part_point
sorted_direction = sorted_fist_part_direction + sorted_last_part_direction
return sorted_point
...@@ -35,6 +35,64 @@ def get_dict(character_dict_path): ...@@ -35,6 +35,64 @@ def get_dict(character_dict_path):
return dict_character return dict_character
def point_pair2poly(point_pair_list):
"""
Transfer vertical point_pairs into poly point in clockwise.
"""
pair_length_list = []
for point_pair in point_pair_list:
pair_length = np.linalg.norm(point_pair[0] - point_pair[1])
pair_length_list.append(pair_length)
pair_length_list = np.array(pair_length_list)
pair_info = (pair_length_list.max(), pair_length_list.min(),
pair_length_list.mean())
point_num = len(point_pair_list) * 2
point_list = [0] * point_num
for idx, point_pair in enumerate(point_pair_list):
point_list[idx] = point_pair[0]
point_list[point_num - 1 - idx] = point_pair[1]
return np.array(point_list).reshape(-1, 2), pair_info
def shrink_quad_along_width(quad, begin_width_ratio=0., end_width_ratio=1.):
"""
Generate shrink_quad_along_width.
"""
ratio_pair = np.array(
[[begin_width_ratio], [end_width_ratio]], dtype=np.float32)
p0_1 = quad[0] + (quad[1] - quad[0]) * ratio_pair
p3_2 = quad[3] + (quad[2] - quad[3]) * ratio_pair
return np.array([p0_1[0], p0_1[1], p3_2[1], p3_2[0]])
def expand_poly_along_width(poly, shrink_ratio_of_width=0.3):
"""
expand poly along width.
"""
point_num = poly.shape[0]
left_quad = np.array(
[poly[0], poly[1], poly[-2], poly[-1]], dtype=np.float32)
left_ratio = -shrink_ratio_of_width * np.linalg.norm(left_quad[0] - left_quad[3]) / \
(np.linalg.norm(left_quad[0] - left_quad[1]) + 1e-6)
left_quad_expand = shrink_quad_along_width(left_quad, left_ratio, 1.0)
right_quad = np.array(
[
poly[point_num // 2 - 2], poly[point_num // 2 - 1],
poly[point_num // 2], poly[point_num // 2 + 1]
],
dtype=np.float32)
right_ratio = 1.0 + \
shrink_ratio_of_width * np.linalg.norm(right_quad[0] - right_quad[3]) / \
(np.linalg.norm(right_quad[0] - right_quad[1]) + 1e-6)
right_quad_expand = shrink_quad_along_width(right_quad, 0.0, right_ratio)
poly[0] = left_quad_expand[0]
poly[-1] = left_quad_expand[-1]
poly[point_num // 2 - 1] = right_quad_expand[1]
poly[point_num // 2] = right_quad_expand[2]
return poly
def softmax(logits): def softmax(logits):
""" """
logits: N x d logits: N x d
...@@ -399,7 +457,7 @@ def generate_pivot_list_horizontal(p_score, ...@@ -399,7 +457,7 @@ def generate_pivot_list_horizontal(p_score,
return center_pos_yxs, end_points_yxs return center_pos_yxs, end_points_yxs
def generate_pivot_list(p_score, def generate_pivot_list_slow(p_score,
p_char_maps, p_char_maps,
f_direction, f_direction,
score_thresh=0.5, score_thresh=0.5,
......
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import paddle
import os
import sys
__dir__ = os.path.dirname(__file__)
sys.path.append(__dir__)
sys.path.append(os.path.join(__dir__, '..'))
from extract_textpoint_slow import *
from extract_textpoint_fast import generate_pivot_list_fast, restore_poly
class PGNet_PostProcess(object):
# two different post-process
def __init__(self, character_dict_path, valid_set, score_thresh, outs_dict,
shape_list):
self.Lexicon_Table = get_dict(character_dict_path)
self.valid_set = valid_set
self.score_thresh = score_thresh
self.outs_dict = outs_dict
self.shape_list = shape_list
def pg_postprocess_fast(self):
p_score = self.outs_dict['f_score']
p_border = self.outs_dict['f_border']
p_char = self.outs_dict['f_char']
p_direction = self.outs_dict['f_direction']
if isinstance(p_score, paddle.Tensor):
p_score = p_score[0].numpy()
p_border = p_border[0].numpy()
p_direction = p_direction[0].numpy()
p_char = p_char[0].numpy()
else:
p_score = p_score[0]
p_border = p_border[0]
p_direction = p_direction[0]
p_char = p_char[0]
src_h, src_w, ratio_h, ratio_w = self.shape_list[0]
instance_yxs_list, seq_strs = generate_pivot_list_fast(
p_score,
p_char,
p_direction,
self.Lexicon_Table,
score_thresh=self.score_thresh)
poly_list, keep_str_list = restore_poly(instance_yxs_list, seq_strs,
p_border, ratio_w, ratio_h,
src_w, src_h, self.valid_set)
data = {
'points': poly_list,
'strs': keep_str_list,
}
return data
def pg_postprocess_slow(self):
p_score = self.outs_dict['f_score']
p_border = self.outs_dict['f_border']
p_char = self.outs_dict['f_char']
p_direction = self.outs_dict['f_direction']
if isinstance(p_score, paddle.Tensor):
p_score = p_score[0].numpy()
p_border = p_border[0].numpy()
p_direction = p_direction[0].numpy()
p_char = p_char[0].numpy()
else:
p_score = p_score[0]
p_border = p_border[0]
p_direction = p_direction[0]
p_char = p_char[0]
src_h, src_w, ratio_h, ratio_w = self.shape_list[0]
is_curved = self.valid_set == "totaltext"
instance_yxs_list = generate_pivot_list_slow(
p_score,
p_char,
p_direction,
score_thresh=self.score_thresh,
is_backbone=True,
is_curved=is_curved)
p_char = paddle.to_tensor(np.expand_dims(p_char, axis=0))
char_seq_idx_set = []
for i in range(len(instance_yxs_list)):
gather_info_lod = paddle.to_tensor(instance_yxs_list[i])
f_char_map = paddle.transpose(p_char, [0, 2, 3, 1])
feature_seq = paddle.gather_nd(f_char_map, gather_info_lod)
feature_seq = np.expand_dims(feature_seq.numpy(), axis=0)
feature_len = [len(feature_seq[0])]
featyre_seq = paddle.to_tensor(feature_seq)
feature_len = np.array([feature_len]).astype(np.int64)
length = paddle.to_tensor(feature_len)
seq_pred = paddle.fluid.layers.ctc_greedy_decoder(
input=featyre_seq, blank=36, input_length=length)
seq_pred_str = seq_pred[0].numpy().tolist()[0]
seq_len = seq_pred[1].numpy()[0][0]
temp_t = []
for c in seq_pred_str[:seq_len]:
temp_t.append(c)
char_seq_idx_set.append(temp_t)
seq_strs = []
for char_idx_set in char_seq_idx_set:
pr_str = ''.join([self.Lexicon_Table[pos] for pos in char_idx_set])
seq_strs.append(pr_str)
poly_list = []
keep_str_list = []
all_point_list = []
all_point_pair_list = []
for yx_center_line, keep_str in zip(instance_yxs_list, seq_strs):
if len(yx_center_line) == 1:
yx_center_line.append(yx_center_line[-1])
offset_expand = 1.0
if self.valid_set == 'totaltext':
offset_expand = 1.2
point_pair_list = []
for batch_id, y, x in yx_center_line:
offset = p_border[:, y, x].reshape(2, 2)
if offset_expand != 1.0:
offset_length = np.linalg.norm(
offset, axis=1, keepdims=True)
expand_length = np.clip(
offset_length * (offset_expand - 1),
a_min=0.5,
a_max=3.0)
offset_detal = offset / offset_length * expand_length
offset = offset + offset_detal
ori_yx = np.array([y, x], dtype=np.float32)
point_pair = (ori_yx + offset)[:, ::-1] * 4.0 / np.array(
[ratio_w, ratio_h]).reshape(-1, 2)
point_pair_list.append(point_pair)
all_point_list.append([
int(round(x * 4.0 / ratio_w)),
int(round(y * 4.0 / ratio_h))
])
all_point_pair_list.append(point_pair.round().astype(np.int32)
.tolist())
detected_poly, pair_length_info = point_pair2poly(point_pair_list)
detected_poly = expand_poly_along_width(
detected_poly, shrink_ratio_of_width=0.2)
detected_poly[:, 0] = np.clip(
detected_poly[:, 0], a_min=0, a_max=src_w)
detected_poly[:, 1] = np.clip(
detected_poly[:, 1], a_min=0, a_max=src_h)
if len(keep_str) < 2:
continue
keep_str_list.append(keep_str)
detected_poly = np.round(detected_poly).astype('int32')
if self.valid_set == 'partvgg':
middle_point = len(detected_poly) // 2
detected_poly = detected_poly[
[0, middle_point - 1, middle_point, -1], :]
poly_list.append(detected_poly)
elif self.valid_set == 'totaltext':
poly_list.append(detected_poly)
else:
print('--> Not supported format.')
exit(-1)
data = {
'points': poly_list,
'strs': keep_str_list,
}
return data
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册