diff --git a/configs/det/det_r50_vd_sast_icdar15.yml b/configs/det/det_r50_vd_sast_icdar15.yml
index c24cae90132c68d662e9edb7a7975e358fb40d9c..c90327b22b9c73111c997e84cfdd47d0721ee5b9 100755
--- a/configs/det/det_r50_vd_sast_icdar15.yml
+++ b/configs/det/det_r50_vd_sast_icdar15.yml
@@ -14,12 +14,13 @@ Global:
load_static_weights: True
cal_metric_during_train: False
pretrained_model: ./pretrain_models/ResNet50_vd_ssld_pretrained/
- checkpoints:
+ checkpoints:
save_inference_dir:
use_visualdl: False
- infer_img:
+ infer_img:
save_res_path: ./output/sast_r50_vd_ic15/predicts_sast.txt
+
Architecture:
model_type: det
algorithm: SAST
diff --git a/configs/e2e/e2e_r50_vd_pg.yml b/configs/e2e/e2e_r50_vd_pg.yml
new file mode 100644
index 0000000000000000000000000000000000000000..0a232f7a4f3b9ca214bbc6fd1840cec186c027e4
--- /dev/null
+++ b/configs/e2e/e2e_r50_vd_pg.yml
@@ -0,0 +1,114 @@
+Global:
+ use_gpu: True
+ epoch_num: 600
+ log_smooth_window: 20
+ print_batch_step: 10
+ save_model_dir: ./output/pgnet_r50_vd_totaltext/
+ save_epoch_step: 10
+  # evaluation is run every 1000 iterations after the 0th iteration
+ eval_batch_step: [ 0, 1000 ]
+ # 1. If pretrained_model is saved in static mode, such as classification pretrained model
+ # from static branch, load_static_weights must be set as True.
+ # 2. If you want to finetune the pretrained models we provide in the docs,
+ # you should set load_static_weights as False.
+ load_static_weights: False
+ cal_metric_during_train: False
+ pretrained_model:
+ checkpoints:
+ save_inference_dir:
+ use_visualdl: False
+ infer_img:
+  valid_set: totaltext # two modes: totaltext evaluates curved text, partvgg evaluates non-curved text
+ save_res_path: ./output/pgnet_r50_vd_totaltext/predicts_pgnet.txt
+ character_dict_path: ppocr/utils/ic15_dict.txt
+ character_type: EN
+  max_text_length: 50 # the max number of characters in a text instance
+  max_text_nums: 30 # the max number of text instances in an image
+ tcl_len: 64
+
+Architecture:
+ model_type: e2e
+ algorithm: PGNet
+ Transform:
+ Backbone:
+ name: ResNet
+ layers: 50
+ Neck:
+ name: PGFPN
+ Head:
+ name: PGHead
+
+Loss:
+ name: PGLoss
+ tcl_bs: 64
+  max_text_length: 50 # must equal Global.max_text_length
+  max_text_nums: 30 # must equal Global.max_text_nums
+  pad_num: 36 # length of the dict, used as the padding index
+
+Optimizer:
+ name: Adam
+ beta1: 0.9
+ beta2: 0.999
+ lr:
+ learning_rate: 0.001
+ regularizer:
+ name: 'L2'
+ factor: 0
+
+
+PostProcess:
+ name: PGPostProcess
+ score_thresh: 0.5
+Metric:
+ name: E2EMetric
+ character_dict_path: ppocr/utils/ic15_dict.txt
+ main_indicator: f_score_e2e
+
+Train:
+ dataset:
+ name: PGDataSet
+    label_file_list: [./train_data/total_text/train/]
+ ratio_list: [1.0]
+    data_format: icdar # two data formats: icdar / textnet
+ transforms:
+ - DecodeImage: # load image
+ img_mode: BGR
+ channel_first: False
+ - PGProcessTrain:
+        batch_size: 14 # must equal loader.batch_size_per_card
+ min_crop_size: 24
+ min_text_size: 4
+ max_text_size: 512
+ - KeepKeys:
+ keep_keys: [ 'images', 'tcl_maps', 'tcl_label_maps', 'border_maps','direction_maps', 'training_masks', 'label_list', 'pos_list', 'pos_mask' ] # dataloader will return list in this order
+ loader:
+ shuffle: True
+ drop_last: True
+ batch_size_per_card: 14
+ num_workers: 16
+
+Eval:
+ dataset:
+ name: PGDataSet
+ data_dir: ./train_data/
+ label_file_list: [./train_data/total_text/test/]
+ transforms:
+ - DecodeImage: # load image
+ img_mode: RGB
+ channel_first: False
+ - E2ELabelEncode:
+ - E2EResizeForTest:
+ max_side_len: 768
+ - NormalizeImage:
+ scale: 1./255.
+ mean: [ 0.485, 0.456, 0.406 ]
+ std: [ 0.229, 0.224, 0.225 ]
+ order: 'hwc'
+ - ToCHWImage:
+ - KeepKeys:
+ keep_keys: [ 'image', 'shape', 'polys', 'strs', 'tags' ]
+ loader:
+ shuffle: False
+ drop_last: False
+ batch_size_per_card: 1 # must be 1
+ num_workers: 2
\ No newline at end of file
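
Note: the comments in this config encode cross-field constraints (the Loss section mirrors Global, and PGProcessTrain.batch_size must match the loader). A minimal sanity-check sketch, assuming PyYAML is installed and the repo root is the working directory:

```python
# sanity-check sketch (not part of this patch) for e2e_r50_vd_pg.yml
import yaml

with open("configs/e2e/e2e_r50_vd_pg.yml") as f:
    cfg = yaml.safe_load(f)

# Loss must mirror Global for sequence length and instance count.
assert cfg["Loss"]["max_text_length"] == cfg["Global"]["max_text_length"]
assert cfg["Loss"]["max_text_nums"] == cfg["Global"]["max_text_nums"]

# PGProcessTrain.batch_size must equal loader.batch_size_per_card.
transforms = cfg["Train"]["dataset"]["transforms"]
pg_train = next(op["PGProcessTrain"] for op in transforms if "PGProcessTrain" in op)
assert pg_train["batch_size"] == cfg["Train"]["loader"]["batch_size_per_card"]
```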
diff --git a/doc/doc_ch/inference.md b/doc/doc_ch/inference.md
index 7968b355ea936d465b3c173c0fcdb3e08f12f16e..1288d90692e154220b8ceb22cd7b6d98f53d3efb 100755
--- a/doc/doc_ch/inference.md
+++ b/doc/doc_ch/inference.md
@@ -12,7 +12,8 @@ inference model (saved with `paddle.jit.save`)
 - [1. Converting a Trained Model to an Inference Model](#训练模型转inference模型)
     - [Converting a detection model](#检测模型转inference模型)
     - [Converting a recognition model](#识别模型转inference模型)
-    - [Converting an angle classification model](#方向分类模型转inference模型)
+    - [Converting an angle classification model](#方向分类模型转inference模型)
+    - [Converting an end-to-end model](#端到端模型转inference模型)
 - [2. Text Detection Model Inference](#文本检测模型推理)
     - [1. Ultra-lightweight Chinese detection model inference](#超轻量中文检测模型推理)
@@ -27,10 +28,13 @@ inference model (saved with `paddle.jit.save`)
     - [4. Inference with a custom recognition dictionary](#自定义文本识别字典的推理)
     - [5. Multilingual model inference](#多语言模型的推理)
-- [4. Angle Classification Model Inference](#方向识别模型推理)
+- [4. End-to-End Model Inference](#端到端模型推理)
+    - [1. PGNet end-to-end model inference](#PGNet端到端模型推理)
+
+- [5. Angle Classification Model Inference](#方向识别模型推理)
     - [1. Angle classification model inference](#方向分类模型推理)
-- [5. Chained Inference of Text Detection, Angle Classification and Text Recognition](#文本检测、方向分类和文字识别串联推理)
+- [6. Chained Inference of Text Detection, Angle Classification and Text Recognition](#文本检测、方向分类和文字识别串联推理)
     - [1. Ultra-lightweight Chinese OCR model inference](#超轻量中文OCR模型推理)
     - [2. Inference with other models](#其他模型推理)
@@ -118,6 +122,32 @@ python3 tools/export_model.py -c configs/cls/cls_mv3.yml -o Global.pretrained_mo
 ├── inference.pdiparams         # parameter file of the classification inference model
 ├── inference.pdiparams.info    # parameter info of the classification inference model, can be ignored
 └── inference.pdmodel           # program file of the classification inference model
+
+### Converting an end-to-end model to an inference model
+
+Download the end-to-end model:
+```
+wget https://paddleocr.bj.bcebos.com/dygraph_v2.0/pgnet/en_server_pgnetA.tar && tar xf en_server_pgnetA.tar
+```
+
+An end-to-end model is converted to an inference model in the same way as a detection model:
+```
+# -c sets the yml configuration file of the training algorithm
+# -o sets optional parameters
+# Global.pretrained_model sets the path of the trained model to be converted; the .pdmodel, .pdopt and .pdparams suffixes are not needed.
+# Global.load_static_weights must be set to False.
+# Global.save_inference_dir sets the directory where the converted model is saved.
+
+python3 tools/export_model.py -c configs/e2e/e2e_r50_vd_pg.yml -o Global.pretrained_model=./en_server_pgnetA/iter_epoch_450 Global.load_static_weights=False Global.save_inference_dir=./inference/e2e/
+```
+
+After a successful conversion, the directory contains three files:
+```
+/inference/e2e/
+    ├── inference.pdiparams         # parameter file of the e2e inference model
+    ├── inference.pdiparams.info    # parameter info of the e2e inference model, can be ignored
+    └── inference.pdmodel           # program file of the e2e inference model
+```
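+
+To quickly confirm that the exported files form a loadable inference model, a minimal check can be run (a sketch assuming Paddle 2.x; the actual input/output names depend on the exported program):
+
+```python
+# load-check sketch for the exported end-to-end model
+from paddle.inference import Config, create_predictor
+
+config = Config("./inference/e2e/inference.pdmodel",
+                "./inference/e2e/inference.pdiparams")
+config.disable_gpu()  # CPU is enough for a load check
+predictor = create_predictor(config)
+print(predictor.get_input_names(), predictor.get_output_names())
+```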
 ## 2. Text Detection Model Inference
@@ -332,8 +362,38 @@ python3 tools/infer/predict_rec.py --image_dir="./doc/imgs_words/korean/1.jpg" -
Predicts of ./doc/imgs_words/korean/1.jpg:('바탕으로', 0.9948904)
```
+
+## 4. End-to-End Model Inference
+
+End-to-end model inference uses the configuration parameters of the PGNet model by default. When running inference with a model other than PGNet, the corresponding parameters must be passed in to adapt the algorithm; see the details below.
+
+### 1. PGNet End-to-End Model Inference
+#### (1) Quadrangle text detection model (ICDAR2015)
+First, convert the model saved during PGNet end-to-end training into an inference model. Taking the model trained on the ICDAR2015 English dataset with a ResNet50_vd backbone as an example ([model download link](https://paddleocr.bj.bcebos.com/dygraph_v2.0/pgnet/en_server_pgnetA.tar)), run the following conversion:
+```
+python3 tools/export_model.py -c configs/e2e/e2e_r50_vd_pg.yml -o Global.pretrained_model=./en_server_pgnetA/iter_epoch_450 Global.load_static_weights=False Global.save_inference_dir=./inference/e2e
+```
+**PGNet end-to-end model inference requires setting the parameter `--e2e_algorithm="PGNet"`**; run the following command:
+```
+python3 tools/infer/predict_e2e.py --e2e_algorithm="PGNet" --image_dir="./doc/imgs_en/img_10.jpg" --e2e_model_dir="./inference/e2e/" --e2e_pgnet_polygon=False
+```
+The visualized text detection results are saved to the `./inference_results` folder by default, with result file names prefixed by 'e2e_res'. An example result:
+
+![](../imgs_results/e2e_res_img_10_pgnet.jpg)
+
+#### (2) Curved text detection model (Total-Text)
+This mode shares the same inference model with the quadrangle text detection model.
+**PGNet end-to-end model inference requires setting `--e2e_algorithm="PGNet"` and, in addition, `--e2e_pgnet_polygon=True`**; run the following command:
+```
+python3 tools/infer/predict_e2e.py --e2e_algorithm="PGNet" --image_dir="./doc/imgs_en/img623.jpg" --e2e_model_dir="./inference/e2e/" --e2e_pgnet_polygon=True
+```
+The visualized end-to-end results are saved to the `./inference_results` folder by default, with result file names prefixed by 'e2e_res'. An example result:
+
+![](../imgs_results/e2e_res_img623_pgnet.jpg)
+
+
-## 4. Angle Classification Model Inference
+## 5. Angle Classification Model Inference
 Below, angle classification model inference is introduced.
@@ -358,7 +418,7 @@ Predicts of ./doc/imgs_words/ch/word_4.jpg:['0', 0.9999982]
```
-## 5. Chained Inference of Text Detection, Angle Classification and Text Recognition
+## 6. Chained Inference of Text Detection, Angle Classification and Text Recognition
 ### 1. Ultra-lightweight Chinese OCR Model Inference
diff --git a/doc/doc_ch/multi_languages.md b/doc/doc_ch/multi_languages.md
new file mode 100644
index 0000000000000000000000000000000000000000..a8f7c2b77f64285e0edfbd22c248e84f0bb84d42
--- /dev/null
+++ b/doc/doc_ch/multi_languages.md
@@ -0,0 +1,284 @@
+# Multilingual Models
+
+**Recent updates**
+
+- 2021.4.9 Added detection and recognition for **80** languages
+- 2021.4.9 Added **lightweight, high-accuracy** English detection and recognition models
+
+- [1 Installation](#安装)
+    - [1.1 Installing paddle](#paddle安装)
+    - [1.2 Installing the paddleocr package](#paddleocr_package_安装)
+
+- [2 Quick Start](#快速使用)
+    - [2.1 Command line](#命令行运行)
+        - [2.1.1 Whole-image prediction](#bash_检测+识别)
+        - [2.1.2 Recognition-only prediction](#bash_识别)
+        - [2.1.3 Detection-only prediction](#bash_检测)
+    - [2.2 Python script](#python_脚本运行)
+        - [2.2.1 Whole-image prediction](#python_检测+识别)
+        - [2.2.2 Recognition-only prediction](#python_识别)
+        - [2.2.3 Detection-only prediction](#python_检测)
+- [3 Custom Training](#自定义训练)
+- [4 Supported Languages and Abbreviations](#语种缩写)
+
+
+## 1 Installation
+
+
+### 1.1 Installing paddle
+```
+# cpu
+pip install paddlepaddle
+
+# gpu
+pip install paddlepaddle-gpu
+```
+
+
+### 1.2 Installing the paddleocr package
+
+
+Install with pip
+```
+pip install "paddleocr>=2.0.4" # version 2.0.4 is recommended
+```
+Build and install locally
+```
+python3 setup.py bdist_wheel
+pip3 install dist/paddleocr-x.x.x-py3-none-any.whl # x.x.x is the paddleocr version number
+```
+
+
+## 2 Quick Start
+
+
+### 2.1 Command line
+
+Show the help information
+
+```
+paddleocr -h
+```
+
+* Whole-image prediction (detection + recognition)
+
+PaddleOCR currently supports 80 languages, switched via the --lang parameter; see the table of supported [languages](#语种缩写) below.
+
+``` bash
+
+paddleocr --image_dir doc/imgs/japan_2.jpg --lang=japan
+```
+![](https://raw.githubusercontent.com/PaddlePaddle/PaddleOCR/release/2.0/doc/imgs/japan_2.jpg)
+
+The result is a list; each item contains the text box, the text, and the recognition confidence
+```text
+[[[671.0, 60.0], [847.0, 63.0], [847.0, 104.0], [671.0, 102.0]], ('もちもち', 0.9993342)]
+[[[394.0, 82.0], [536.0, 77.0], [538.0, 127.0], [396.0, 132.0]], ('天然の', 0.9919842)]
+[[[880.0, 89.0], [1014.0, 93.0], [1013.0, 127.0], [879.0, 124.0]], ('とろっと', 0.9976762)]
+[[[1067.0, 101.0], [1294.0, 101.0], [1294.0, 138.0], [1067.0, 138.0]], ('後味のよい', 0.9988712)]
+......
+```
+
+* Recognition-only prediction
+
+```bash
+paddleocr --image_dir doc/imgs_words/japan/1.jpg --det false --lang=japan
+```
+
+![](https://raw.githubusercontent.com/PaddlePaddle/PaddleOCR/release/2.0/doc/imgs_words/japan/1.jpg)
+
+The result is a tuple with the recognized text and its confidence
+
+```text
+('したがって', 0.99965394)
+```
+
+* Detection-only prediction
+
+```
+paddleocr --image_dir PaddleOCR/doc/imgs/11.jpg --rec false
+```
+
+The result is a list; each item contains only a text box
+
+```
+[[26.0, 457.0], [137.0, 457.0], [137.0, 477.0], [26.0, 477.0]]
+[[25.0, 425.0], [372.0, 425.0], [372.0, 448.0], [25.0, 448.0]]
+[[128.0, 397.0], [273.0, 397.0], [273.0, 414.0], [128.0, 414.0]]
+......
+```
+
+
+### 2.2 Python script
+
+ppocr can also be run from a Python script, which makes it easy to embed in your own code:
+
+* Whole-image prediction (detection + recognition)
+
+```
+from paddleocr import PaddleOCR, draw_ocr
+
+# switch languages by changing the lang parameter, same as on the command line
+ocr = PaddleOCR(lang="korean") # the model files are downloaded automatically on first run
+img_path = 'doc/imgs/korean_1.jpg'
+result = ocr.ocr(img_path)
+# print the detection boxes and recognition results
+for line in result:
+ print(line)
+
+# visualization
+from PIL import Image
+image = Image.open(img_path).convert('RGB')
+boxes = [line[0] for line in result]
+txts = [line[1][0] for line in result]
+scores = [line[1][1] for line in result]
+im_show = draw_ocr(image, boxes, txts, scores, font_path='/path/to/PaddleOCR/doc/korean.ttf')
+im_show = Image.fromarray(im_show)
+im_show.save('result.jpg')
+```
+
+Visualized result:
+![](https://raw.githubusercontent.com/PaddlePaddle/PaddleOCR/release/2.0/doc/imgs_results/korean.jpg)
+
+
+* Recognition-only prediction
+
+```
+from paddleocr import PaddleOCR
+ocr = PaddleOCR(lang="german")
+img_path = 'PaddleOCR/doc/imgs_words/german/1.jpg'
+result = ocr.ocr(img_path, det=False, cls=True)
+for line in result:
+ print(line)
+```
+
+![](https://raw.githubusercontent.com/PaddlePaddle/PaddleOCR/release/2.0/doc/imgs_words/german/1.jpg)
+
+The result is a tuple containing only the recognized text and its confidence
+
+```
+('leider auch jetzt', 0.97538936)
+```
+
+* Detection-only prediction
+
+```python
+from paddleocr import PaddleOCR, draw_ocr
+ocr = PaddleOCR() # need to run only once to download and load model into memory
+img_path = 'PaddleOCR/doc/imgs_en/img_12.jpg'
+result = ocr.ocr(img_path, rec=False)
+for line in result:
+ print(line)
+
+# show the results
+from PIL import Image
+
+image = Image.open(img_path).convert('RGB')
+im_show = draw_ocr(image, result, txts=None, scores=None, font_path='/path/to/PaddleOCR/doc/fonts/simfang.ttf')
+im_show = Image.fromarray(im_show)
+im_show.save('result.jpg')
+```
+The result is a list; each item contains only a text box
+```bash
+[[26.0, 457.0], [137.0, 457.0], [137.0, 477.0], [26.0, 477.0]]
+[[25.0, 425.0], [372.0, 425.0], [372.0, 448.0], [25.0, 448.0]]
+[[128.0, 397.0], [273.0, 397.0], [273.0, 414.0], [128.0, 414.0]]
+......
+```
+
+Visualized result:
+![](https://raw.githubusercontent.com/PaddlePaddle/PaddleOCR/release/2.0/doc/imgs_results/whl/12_det.jpg)
+
+ppocr also supports angle classification; for more usage, see the [whl package documentation](https://github.com/PaddlePaddle/PaddleOCR/blob/release/2.0/doc/doc_ch/whl.md).
+
+
+## 3 Custom Training
+
+ppocr supports custom training and fine-tuning on your own data. For recognition models, refer to the [French config file](../../configs/rec/multi_language/rec_french_lite_train.yml)
+and modify the training data path, dictionary and other parameters as needed.
+
+For data preparation and the training process, refer to [Text Detection](../doc_ch/detection.md) and [Text Recognition](../doc_ch/recognition.md); for more features such as inference deployment and
+data annotation, see the complete [documentation tutorials](../../README_ch.md).
+
+
+## 4 Supported Languages and Abbreviations
+
+| Language | Description | Abbreviation |
+| --- | --- | --- |
+|中文|chinese and english|ch|
+|英文|english|en|
+|法文|french|fr|
+|德文|german|german|
+|日文|japan|japan|
+|韩文|korean|korean|
+|中文繁体|chinese traditional |chinese_cht|
+|意大利文| Italian |it|
+|西班牙文|Spanish |es|
+|葡萄牙文| Portuguese|pt|
+|俄罗斯文|Russian|ru|
+|阿拉伯文|Arabic|ar|
+|印地文|Hindi|hi|
+|维吾尔|Uyghur|ug|
+|波斯文|Persian|fa|
+|乌尔都文|Urdu|ur|
+|塞尔维亚文(latin)| Serbian(latin) |rs_latin|
+|欧西坦文|Occitan |oc|
+|马拉地文|Marathi|mr|
+|尼泊尔文|Nepali|ne|
+|塞尔维亚文(cyrillic)|Serbian(cyrillic)|rs_cyrillic|
+|保加利亚文|Bulgarian |bg|
+|乌克兰文|Ukrainian|uk|
+|白俄罗斯文|Belarusian|be|
+|泰卢固文|Telugu |te|
+|卡纳达文|Kannada |kn|
+|泰米尔文|Tamil |ta|
+|南非荷兰文 |Afrikaans |af|
+|阿塞拜疆文 |Azerbaijani |az|
+|波斯尼亚文|Bosnian|bs|
+|捷克文|Czech|cs|
+|威尔士文 |Welsh |cy|
+|丹麦文 |Danish|da|
+|爱沙尼亚文 |Estonian |et|
+|爱尔兰文 |Irish |ga|
+|克罗地亚文|Croatian |hr|
+|匈牙利文|Hungarian |hu|
+|印尼文|Indonesian|id|
+|冰岛文 |Icelandic|is|
+|库尔德文 |Kurdish|ku|
+|立陶宛文|Lithuanian |lt|
+|拉脱维亚文 |Latvian |lv|
+|毛利文|Maori|mi|
+|马来文 |Malay|ms|
+|马耳他文 |Maltese |mt|
+|荷兰文 |Dutch |nl|
+|挪威文 |Norwegian |no|
+|波兰文|Polish |pl|
+| 罗马尼亚文|Romanian |ro|
+| 斯洛伐克文|Slovak |sk|
+| 斯洛文尼亚文|Slovenian |sl|
+| 阿尔巴尼亚文|Albanian |sq|
+| 瑞典文|Swedish |sv|
+| 斯瓦希里文|Swahili |sw|
+| 塔加洛文|Tagalog |tl|
+| 土耳其文|Turkish |tr|
+| 乌兹别克文|Uzbek |uz|
+| 越南文|Vietnamese |vi|
+| 蒙古文|Mongolian |mn|
+| 阿巴扎文|Abaza |abq|
+| 阿迪赫文|Adyghe |ady|
+| 卡巴丹文|Kabardian |kbd|
+| 阿瓦尔文|Avar |ava|
+| 达尔瓦文|Dargwa |dar|
+| 因古什文|Ingush |inh|
+| 拉克文|Lak |lbe|
+| 莱兹甘文|Lezghian |lez|
+|塔巴萨兰文 |Tabassaran |tab|
+| 比尔哈文|Bihari |bh|
+| 迈蒂利文|Maithili |mai|
+| 昂加文|Angika |ang|
+| 博杰普尔文|Bhojpuri |bho|
+| 摩揭陀文 |Magahi |mah|
+| 那格浦尔文|Nagpur |sck|
+| 尼瓦尔文|Newari |new|
+| 孔卡尼文 |Goan Konkani|gom|
+| 沙特阿拉伯文|Saudi Arabia|sa|
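+
+Any abbreviation from the table can be passed as the `lang` parameter shown in section 2.2. A minimal usage sketch (the image path is a placeholder):
+
+```python
+from paddleocr import PaddleOCR
+
+ocr = PaddleOCR(lang="it")          # 'it' = Italian, from the table above
+result = ocr.ocr("your_image.jpg")  # replace with a real image path
+for line in result:
+    print(line)
+```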
diff --git a/doc/doc_ch/pgnet.md b/doc/doc_ch/pgnet.md
index abe4122dfe4b5de618ff449582827eae264a7275..165b4bb44b6a151f8433952186bc07adc38b3763 100644
--- a/doc/doc_ch/pgnet.md
+++ b/doc/doc_ch/pgnet.md
@@ -1,14 +1,9 @@
-
-# End-to-End OCR Algorithm: PGNet
-
-----
 # End-to-End OCR Algorithm: PGNet
 - [1. Introduction](#简介)
 - [2. Environment Setup](#环境配置)
 - [3. Quick Start](#快速使用)
 - [4. Model Training, Evaluation and Inference](#快速训练)
-
 ## 1. Introduction
 OCR algorithms fall into two categories: two-stage algorithms and end-to-end algorithms. A two-stage OCR algorithm generally consists of two parts: a text detection algorithm first obtains the detection boxes of text lines from the image, and a text recognition algorithm then recognizes the content inside those boxes. An end-to-end OCR algorithm completes text detection and recognition within a single algorithm; its basic idea is to design a model that has both a detection unit and a recognition module, share the CNN features between the two, and train them jointly. Since a single algorithm completes text spotting, the end-to-end model is smaller and faster.
@@ -62,13 +57,12 @@ python3 tools/infer/predict_e2e.py --e2e_algorithm="PGNet" --image_dir="./doc/im
 # To run prediction on CPU, set the use_gpu parameter to False
python3 tools/infer/predict_e2e.py --e2e_algorithm="PGNet" --image_dir="./doc/imgs_en/img623.jpg" --e2e_model_dir="./inference/e2e/" --e2e_pgnet_polygon=True --use_gpu=False
```
-
 ### Visualized Results
 The visualized text detection results are saved to the ./inference_results folder by default, with result file names prefixed by 'e2e_res'. An example result:
![](../imgs_results/e2e_res_img623_pgnet.jpg)
-## 4. Quick Training
+## 4. Model Training, Evaluation and Inference
 This section takes the totaltext dataset as an example to introduce training, evaluation and testing of the end-to-end model in PaddleOCR.
 ### Preparing the Data
@@ -103,7 +97,6 @@ wget -P ./pretrain_models/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/pgnet/tr
└─ best_accuracy.states
└─ best_accuracy.pdparams
```
-
 *If you installed the CPU version, change the `use_gpu` field in the configuration file to false*
```shell
@@ -121,8 +114,13 @@ python3 -m paddle.distributed.launch --gpus '0,1,2,3' tools/train.py -c configs/
python3 tools/train.py -c configs/e2e/e2e_r50_vd_pg.yml -o Optimizer.base_lr=0.0001
```
+#### Resuming Training
+If training is interrupted and you want to resume from the saved state, specify the path of the model to load via Global.checkpoints:
+```shell
+python3 tools/train.py -c configs/e2e/e2e_r50_vd_pg.yml -o Global.checkpoints=./your/trained/model
+```
-### Model Evaluation
+**Note**: `Global.checkpoints` has higher priority than `Global.pretrained_model`; when both are specified, the model pointed to by `Global.checkpoints` is loaded first, and if that path is invalid, the model specified by `Global.pretrained_model` is loaded instead.
 PaddleOCR computes three end-to-end OCR metrics: Precision, Recall and Hmean.
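
Note: the loading priority described above can be pictured with a small hypothetical helper (illustrative only, not the actual PaddleOCR implementation):

```python
import os

def resolve_weights(checkpoints, pretrained_model):
    """Prefer Global.checkpoints; fall back to Global.pretrained_model."""
    if checkpoints and os.path.exists(checkpoints + ".pdparams"):
        return checkpoints
    return pretrained_model

print(resolve_weights("./your/trained/model", "./path/to/pretrained_model"))
```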
diff --git a/doc/imgs_results/whl/12_det.jpg b/doc/imgs_results/whl/12_det.jpg
index 1d5ccf2a6b5d3fa9516560e0cb2646ad6b917da6..71627f0b8db8fdc6e1bf0c4601f0311160d3164d 100644
Binary files a/doc/imgs_results/whl/12_det.jpg and b/doc/imgs_results/whl/12_det.jpg differ
diff --git a/paddleocr.py b/paddleocr.py
index c3741b264503534ef3e64531c2576273d8ccfd11..47e1267ac40effbe8b4ab80723c66eb5378be179 100644
--- a/paddleocr.py
+++ b/paddleocr.py
@@ -66,6 +66,46 @@ model_urls = {
'url':
'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/japan_mobile_v2.0_rec_infer.tar',
'dict_path': './ppocr/utils/dict/japan_dict.txt'
+ },
+ 'chinese_cht': {
+ 'url':
+ 'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/chinese_cht_mobile_v2.0_rec_infer.tar',
+ 'dict_path': './ppocr/utils/dict/chinese_cht_dict.txt'
+ },
+ 'ta': {
+ 'url':
+ 'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/ta_mobile_v2.0_rec_infer.tar',
+ 'dict_path': './ppocr/utils/dict/ta_dict.txt'
+ },
+ 'te': {
+ 'url':
+ 'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/te_mobile_v2.0_rec_infer.tar',
+ 'dict_path': './ppocr/utils/dict/te_dict.txt'
+ },
+ 'ka': {
+ 'url':
+ 'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/ka_mobile_v2.0_rec_infer.tar',
+ 'dict_path': './ppocr/utils/dict/ka_dict.txt'
+ },
+ 'latin': {
+ 'url':
+ 'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/latin_ppocr_mobile_v2.0_rec_infer.tar',
+ 'dict_path': './ppocr/utils/dict/latin_dict.txt'
+ },
+ 'arabic': {
+ 'url':
+ 'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/arabic_ppocr_mobile_v2.0_rec_infer.tar',
+ 'dict_path': './ppocr/utils/dict/arabic_dict.txt'
+ },
+ 'cyrillic': {
+ 'url':
+ 'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/cyrillic_ppocr_mobile_v2.0_rec_infer.tar',
+ 'dict_path': './ppocr/utils/dict/cyrillic_dict.txt'
+ },
+ 'devanagari': {
+ 'url':
+ 'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/devanagari_ppocr_mobile_v2.0_rec_infer.tar',
+ 'dict_path': './ppocr/utils/dict/devanagari_dict.txt'
}
},
'cls':
@@ -233,6 +273,29 @@ class PaddleOCR(predict_system.TextSystem):
postprocess_params.__dict__.update(**kwargs)
self.use_angle_cls = postprocess_params.use_angle_cls
lang = postprocess_params.lang
+ latin_lang = [
+ 'af', 'az', 'bs', 'cs', 'cy', 'da', 'de', 'en', 'es', 'et', 'fr',
+ 'ga', 'hr', 'hu', 'id', 'is', 'it', 'ku', 'la', 'lt', 'lv', 'mi',
+ 'ms', 'mt', 'nl', 'no', 'oc', 'pi', 'pl', 'pt', 'ro', 'rs_latin',
+ 'sk', 'sl', 'sq', 'sv', 'sw', 'tl', 'tr', 'uz', 'vi'
+ ]
+ arabic_lang = ['ar', 'fa', 'ug', 'ur']
+ cyrillic_lang = [
+ 'ru', 'rs_cyrillic', 'be', 'bg', 'uk', 'mn', 'abq', 'ady', 'kbd',
+ 'ava', 'dar', 'inh', 'che', 'lbe', 'lez', 'tab'
+ ]
+ devanagari_lang = [
+ 'hi', 'mr', 'ne', 'bh', 'mai', 'ang', 'bho', 'mah', 'sck', 'new',
+ 'gom', 'sa', 'bgc'
+ ]
+ if lang in latin_lang:
+ lang = "latin"
+ elif lang in arabic_lang:
+ lang = "arabic"
+ elif lang in cyrillic_lang:
+ lang = "cyrillic"
+ elif lang in devanagari_lang:
+ lang = "devanagari"
assert lang in model_urls[
'rec'], 'param lang must in {}, but got {}'.format(
model_urls['rec'].keys(), lang)
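
Note: with this mapping, any script-family language code transparently resolves to a shared recognition model; a usage sketch (models download on first run):

```python
from paddleocr import PaddleOCR

# 'fr' is in latin_lang, so the recognizer is remapped to the shared 'latin'
# model group (and latin_dict.txt) before the model_urls lookup above.
ocr = PaddleOCR(lang="fr")
```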
diff --git a/ppocr/data/__init__.py b/ppocr/data/__init__.py
index 7cb50d7a62aa3f24811e517768e0635ac7b7321a..728b8317f54687ee76b519cba18f4d7807493821 100644
--- a/ppocr/data/__init__.py
+++ b/ppocr/data/__init__.py
@@ -34,6 +34,7 @@ import paddle.distributed as dist
from ppocr.data.imaug import transform, create_operators
from ppocr.data.simple_dataset import SimpleDataSet
from ppocr.data.lmdb_dataset import LMDBDataSet
+from ppocr.data.pgnet_dataset import PGDataSet
__all__ = ['build_dataloader', 'transform', 'create_operators']
@@ -54,7 +55,7 @@ signal.signal(signal.SIGTERM, term_mp)
def build_dataloader(config, mode, device, logger, seed=None):
config = copy.deepcopy(config)
- support_dict = ['SimpleDataSet', 'LMDBDataSet']
+ support_dict = ['SimpleDataSet', 'LMDBDataSet', 'PGDataSet']
module_name = config[mode]['dataset']['name']
assert module_name in support_dict, Exception(
'DataSet only support {}'.format(support_dict))
@@ -72,14 +73,14 @@ def build_dataloader(config, mode, device, logger, seed=None):
else:
use_shared_memory = True
if mode == "Train":
- #Distribute data to multiple cards
+ # Distribute data to multiple cards
batch_sampler = DistributedBatchSampler(
dataset=dataset,
batch_size=batch_size,
shuffle=shuffle,
drop_last=drop_last)
else:
- #Distribute data to single card
+ # Distribute data to single card
batch_sampler = BatchSampler(
dataset=dataset,
batch_size=batch_size,
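
Note: a sketch of driving the new PGDataSet through build_dataloader with the config added in this patch (assumes the Total-Text data referenced by the yml exists locally, and that `ppocr.utils.logging.get_logger` is the repo's logger helper):

```python
import paddle
import yaml

from ppocr.data import build_dataloader
from ppocr.utils.logging import get_logger

# Eval.dataset.name in this file is 'PGDataSet'
with open("configs/e2e/e2e_r50_vd_pg.yml") as f:
    config = yaml.safe_load(f)

device = paddle.set_device("cpu")
eval_loader = build_dataloader(config, "Eval", device, get_logger())
# each batch follows the KeepKeys order: image, shape, polys, strs, tags
```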
diff --git a/ppocr/data/imaug/__init__.py b/ppocr/data/imaug/__init__.py
index 250ac75e7683df2353d9fad02ef42b9e133681d3..a808fd586b0676751da1ee31d379179b026fd51d 100644
--- a/ppocr/data/imaug/__init__.py
+++ b/ppocr/data/imaug/__init__.py
@@ -28,6 +28,7 @@ from .label_ops import *
from .east_process import *
from .sast_process import *
+from .pg_process import *
def transform(data, ops=None):
diff --git a/ppocr/data/imaug/label_ops.py b/ppocr/data/imaug/label_ops.py
index 7a32d870bfc7f532896ce6b11aac5508a6369993..47e0cbf07d8bd8b6ad838fa2d211345c65a6751a 100644
--- a/ppocr/data/imaug/label_ops.py
+++ b/ppocr/data/imaug/label_ops.py
@@ -187,6 +187,34 @@ class CTCLabelEncode(BaseRecLabelEncode):
return dict_character
+class E2ELabelEncode(BaseRecLabelEncode):
+ def __init__(self,
+ max_text_length,
+ character_dict_path=None,
+ character_type='EN',
+ use_space_char=False,
+ **kwargs):
+ super(E2ELabelEncode,
+ self).__init__(max_text_length, character_dict_path,
+ character_type, use_space_char)
+ self.pad_num = len(self.dict) # the length to pad
+
+ def __call__(self, data):
+ text_label_index_list, temp_text = [], []
+ texts = data['strs']
+ for text in texts:
+ text = text.lower()
+ temp_text = []
+ for c_ in text:
+ if c_ in self.dict:
+ temp_text.append(self.dict[c_])
+ temp_text = temp_text + [self.pad_num] * (self.max_text_len -
+ len(temp_text))
+ text_label_index_list.append(temp_text)
+ data['strs'] = np.array(text_label_index_list)
+ return data
+
+
class AttnLabelEncode(BaseRecLabelEncode):
""" Convert between text-label and text-index """
diff --git a/ppocr/data/imaug/operators.py b/ppocr/data/imaug/operators.py
index eacfdf3b243af5b9051ad726ced6edacddff45ed..9c48b09647527cf718113ea1b5df152ff7befa04 100644
--- a/ppocr/data/imaug/operators.py
+++ b/ppocr/data/imaug/operators.py
@@ -197,7 +197,6 @@ class DetResizeForTest(object):
sys.exit(0)
ratio_h = resize_h / float(h)
ratio_w = resize_w / float(w)
- # return img, np.array([h, w])
return img, [ratio_h, ratio_w]
def resize_image_type2(self, img):
@@ -206,7 +205,6 @@ class DetResizeForTest(object):
resize_w = w
resize_h = h
- # Fix the longer side
if resize_h > resize_w:
ratio = float(self.resize_long) / resize_h
else:
@@ -223,3 +221,72 @@ class DetResizeForTest(object):
ratio_w = resize_w / float(w)
return img, [ratio_h, ratio_w]
+
+
+class E2EResizeForTest(object):
+ def __init__(self, **kwargs):
+ super(E2EResizeForTest, self).__init__()
+ self.max_side_len = kwargs['max_side_len']
+ self.valid_set = kwargs['valid_set']
+
+ def __call__(self, data):
+ img = data['image']
+ src_h, src_w, _ = img.shape
+ if self.valid_set == 'totaltext':
+ im_resized, [ratio_h, ratio_w] = self.resize_image_for_totaltext(
+ img, max_side_len=self.max_side_len)
+ else:
+ im_resized, (ratio_h, ratio_w) = self.resize_image(
+ img, max_side_len=self.max_side_len)
+ data['image'] = im_resized
+ data['shape'] = np.array([src_h, src_w, ratio_h, ratio_w])
+ return data
+
+ def resize_image_for_totaltext(self, im, max_side_len=512):
+
+ h, w, _ = im.shape
+ resize_w = w
+ resize_h = h
+ ratio = 1.25
+ if h * ratio > max_side_len:
+ ratio = float(max_side_len) / resize_h
+ resize_h = int(resize_h * ratio)
+ resize_w = int(resize_w * ratio)
+
+ max_stride = 128
+ resize_h = (resize_h + max_stride - 1) // max_stride * max_stride
+ resize_w = (resize_w + max_stride - 1) // max_stride * max_stride
+ im = cv2.resize(im, (int(resize_w), int(resize_h)))
+ ratio_h = resize_h / float(h)
+ ratio_w = resize_w / float(w)
+ return im, (ratio_h, ratio_w)
+
+ def resize_image(self, im, max_side_len=512):
+ """
+ resize image to a size multiple of max_stride which is required by the network
+ :param im: the resized image
+ :param max_side_len: limit of max image size to avoid out of memory in gpu
+ :return: the resized image and the resize ratio
+ """
+ h, w, _ = im.shape
+
+ resize_w = w
+ resize_h = h
+
+ # Fix the longer side
+ if resize_h > resize_w:
+ ratio = float(max_side_len) / resize_h
+ else:
+ ratio = float(max_side_len) / resize_w
+
+ resize_h = int(resize_h * ratio)
+ resize_w = int(resize_w * ratio)
+
+ max_stride = 128
+ resize_h = (resize_h + max_stride - 1) // max_stride * max_stride
+ resize_w = (resize_w + max_stride - 1) // max_stride * max_stride
+ im = cv2.resize(im, (int(resize_w), int(resize_h)))
+ ratio_h = resize_h / float(h)
+ ratio_w = resize_w / float(w)
+
+ return im, (ratio_h, ratio_w)
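
Note: a worked example of the Total-Text branch above; both sides share one scale factor, then each side is rounded up to a multiple of max_stride (128):

```python
h, w, max_side_len, max_stride = 900, 600, 768, 128

ratio = 1.25
if h * ratio > max_side_len:
    ratio = float(max_side_len) / h                                # 768 / 900
resize_h, resize_w = int(h * ratio), int(w * ratio)                # 768, 512
resize_h = (resize_h + max_stride - 1) // max_stride * max_stride  # 768
resize_w = (resize_w + max_stride - 1) // max_stride * max_stride  # 512
print(resize_h / h, resize_w / w)  # (ratio_h, ratio_w) returned to the caller
```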
diff --git a/ppocr/data/imaug/pg_process.py b/ppocr/data/imaug/pg_process.py
new file mode 100644
index 0000000000000000000000000000000000000000..0c9439d7a274af27ca8d296d5e737bafdec3bd1f
--- /dev/null
+++ b/ppocr/data/imaug/pg_process.py
@@ -0,0 +1,906 @@
+# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import math
+import cv2
+import numpy as np
+
+__all__ = ['PGProcessTrain']
+
+
+class PGProcessTrain(object):
+ def __init__(self,
+ character_dict_path,
+ max_text_length,
+ max_text_nums,
+ tcl_len,
+ batch_size=14,
+ min_crop_size=24,
+ min_text_size=4,
+ max_text_size=512,
+ **kwargs):
+ self.tcl_len = tcl_len
+ self.max_text_length = max_text_length
+ self.max_text_nums = max_text_nums
+ self.batch_size = batch_size
+ self.min_crop_size = min_crop_size
+ self.min_text_size = min_text_size
+ self.max_text_size = max_text_size
+ self.Lexicon_Table = self.get_dict(character_dict_path)
+ self.pad_num = len(self.Lexicon_Table)
+ self.img_id = 0
+
+ def get_dict(self, character_dict_path):
+ character_str = ""
+ with open(character_dict_path, "rb") as fin:
+ lines = fin.readlines()
+ for line in lines:
+ line = line.decode('utf-8').strip("\n").strip("\r\n")
+ character_str += line
+ dict_character = list(character_str)
+ return dict_character
+
+ def quad_area(self, poly):
+ """
+ compute area of a polygon
+ :param poly:
+ :return:
+ """
+ edge = [(poly[1][0] - poly[0][0]) * (poly[1][1] + poly[0][1]),
+ (poly[2][0] - poly[1][0]) * (poly[2][1] + poly[1][1]),
+ (poly[3][0] - poly[2][0]) * (poly[3][1] + poly[2][1]),
+ (poly[0][0] - poly[3][0]) * (poly[0][1] + poly[3][1])]
+ return np.sum(edge) / 2.
+
+ def gen_quad_from_poly(self, poly):
+ """
+ Generate min area quad from poly.
+ """
+ point_num = poly.shape[0]
+ min_area_quad = np.zeros((4, 2), dtype=np.float32)
+ rect = cv2.minAreaRect(poly.astype(
+ np.int32)) # (center (x,y), (width, height), angle of rotation)
+ box = np.array(cv2.boxPoints(rect))
+
+ first_point_idx = 0
+ min_dist = 1e4
+ for i in range(4):
+ dist = np.linalg.norm(box[(i + 0) % 4] - poly[0]) + \
+ np.linalg.norm(box[(i + 1) % 4] - poly[point_num // 2 - 1]) + \
+ np.linalg.norm(box[(i + 2) % 4] - poly[point_num // 2]) + \
+ np.linalg.norm(box[(i + 3) % 4] - poly[-1])
+ if dist < min_dist:
+ min_dist = dist
+ first_point_idx = i
+ for i in range(4):
+ min_area_quad[i] = box[(first_point_idx + i) % 4]
+
+ return min_area_quad
+
+    def check_and_validate_polys(self, polys, tags, im_size):
+        """
+        Check that the text polys are in the same direction,
+        and filter out invalid polygons.
+        :param polys:
+        :param tags:
+        :param im_size: (h, w) of the image
+        :return:
+        """
+        (h, w) = im_size
+ if polys.shape[0] == 0:
+ return polys, np.array([]), np.array([])
+ polys[:, :, 0] = np.clip(polys[:, :, 0], 0, w - 1)
+ polys[:, :, 1] = np.clip(polys[:, :, 1], 0, h - 1)
+
+ validated_polys = []
+ validated_tags = []
+ hv_tags = []
+ for poly, tag in zip(polys, tags):
+ quad = self.gen_quad_from_poly(poly)
+ p_area = self.quad_area(quad)
+ if abs(p_area) < 1:
+ print('invalid poly')
+ continue
+ if p_area > 0:
+                if not tag:
+                    print('poly in wrong direction')
+                    tag = True  # reversed cases should be ignored
+ poly = poly[(0, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2,
+ 1), :]
+ quad = quad[(0, 3, 2, 1), :]
+
+ len_w = np.linalg.norm(quad[0] - quad[1]) + np.linalg.norm(quad[3] -
+ quad[2])
+ len_h = np.linalg.norm(quad[0] - quad[3]) + np.linalg.norm(quad[1] -
+ quad[2])
+ hv_tag = 1
+
+ if len_w * 2.0 < len_h:
+ hv_tag = 0
+
+ validated_polys.append(poly)
+ validated_tags.append(tag)
+ hv_tags.append(hv_tag)
+ return np.array(validated_polys), np.array(validated_tags), np.array(
+ hv_tags)
+
+ def crop_area(self,
+ im,
+ polys,
+ tags,
+ hv_tags,
+ txts,
+ crop_background=False,
+ max_tries=25):
+ """
+ make random crop from the input image
+ :param im:
+ :param polys: [b,4,2]
+ :param tags:
+ :param crop_background:
+        :param max_tries: maximum number of random crop attempts (default 25)
+ :return:
+ """
+ h, w, _ = im.shape
+ pad_h = h // 10
+ pad_w = w // 10
+ h_array = np.zeros((h + pad_h * 2), dtype=np.int32)
+ w_array = np.zeros((w + pad_w * 2), dtype=np.int32)
+ for poly in polys:
+ poly = np.round(poly, decimals=0).astype(np.int32)
+ minx = np.min(poly[:, 0])
+ maxx = np.max(poly[:, 0])
+ w_array[minx + pad_w:maxx + pad_w] = 1
+ miny = np.min(poly[:, 1])
+ maxy = np.max(poly[:, 1])
+ h_array[miny + pad_h:maxy + pad_h] = 1
+        # ensure the crop boundaries do not cut across any text
+ h_axis = np.where(h_array == 0)[0]
+ w_axis = np.where(w_array == 0)[0]
+ if len(h_axis) == 0 or len(w_axis) == 0:
+ return im, polys, tags, hv_tags, txts
+ for i in range(max_tries):
+ xx = np.random.choice(w_axis, size=2)
+ xmin = np.min(xx) - pad_w
+ xmax = np.max(xx) - pad_w
+ xmin = np.clip(xmin, 0, w - 1)
+ xmax = np.clip(xmax, 0, w - 1)
+ yy = np.random.choice(h_axis, size=2)
+ ymin = np.min(yy) - pad_h
+ ymax = np.max(yy) - pad_h
+ ymin = np.clip(ymin, 0, h - 1)
+ ymax = np.clip(ymax, 0, h - 1)
+ if xmax - xmin < self.min_crop_size or \
+ ymax - ymin < self.min_crop_size:
+ continue
+ if polys.shape[0] != 0:
+ poly_axis_in_area = (polys[:, :, 0] >= xmin) & (polys[:, :, 0] <= xmax) \
+ & (polys[:, :, 1] >= ymin) & (polys[:, :, 1] <= ymax)
+ selected_polys = np.where(
+ np.sum(poly_axis_in_area, axis=1) == 4)[0]
+ else:
+ selected_polys = []
+ if len(selected_polys) == 0:
+ # no text in this area
+ if crop_background:
+ txts_tmp = []
+ for selected_poly in selected_polys:
+ txts_tmp.append(txts[selected_poly])
+ txts = txts_tmp
+ return im[ymin: ymax + 1, xmin: xmax + 1, :], \
+ polys[selected_polys], tags[selected_polys], hv_tags[selected_polys], txts
+ else:
+ continue
+ im = im[ymin:ymax + 1, xmin:xmax + 1, :]
+ polys = polys[selected_polys]
+ tags = tags[selected_polys]
+ hv_tags = hv_tags[selected_polys]
+ txts_tmp = []
+ for selected_poly in selected_polys:
+ txts_tmp.append(txts[selected_poly])
+ txts = txts_tmp
+ polys[:, :, 0] -= xmin
+ polys[:, :, 1] -= ymin
+ return im, polys, tags, hv_tags, txts
+
+ return im, polys, tags, hv_tags, txts
+
+ def fit_and_gather_tcl_points_v2(self,
+ min_area_quad,
+ poly,
+ max_h,
+ max_w,
+ fixed_point_num=64,
+ img_id=0,
+ reference_height=3):
+ """
+ Find the center point of poly as key_points, then fit and gather.
+ """
+ key_point_xys = []
+ point_num = poly.shape[0]
+ for idx in range(point_num // 2):
+ center_point = (poly[idx] + poly[point_num - 1 - idx]) / 2.0
+ key_point_xys.append(center_point)
+
+ tmp_image = np.zeros(
+ shape=(
+ max_h,
+ max_w, ), dtype='float32')
+ cv2.polylines(tmp_image, [np.array(key_point_xys).astype('int32')],
+ False, 1.0)
+ ys, xs = np.where(tmp_image > 0)
+ xy_text = np.array(list(zip(xs, ys)), dtype='float32')
+
+ left_center_pt = (
+ (min_area_quad[0] - min_area_quad[1]) / 2.0).reshape(1, 2)
+ right_center_pt = (
+ (min_area_quad[1] - min_area_quad[2]) / 2.0).reshape(1, 2)
+ proj_unit_vec = (right_center_pt - left_center_pt) / (
+ np.linalg.norm(right_center_pt - left_center_pt) + 1e-6)
+ proj_unit_vec_tile = np.tile(proj_unit_vec,
+ (xy_text.shape[0], 1)) # (n, 2)
+ left_center_pt_tile = np.tile(left_center_pt,
+ (xy_text.shape[0], 1)) # (n, 2)
+ xy_text_to_left_center = xy_text - left_center_pt_tile
+ proj_value = np.sum(xy_text_to_left_center * proj_unit_vec_tile, axis=1)
+ xy_text = xy_text[np.argsort(proj_value)]
+
+        # convert to np and keep the number of points no greater than fixed_point_num
+ pos_info = np.array(xy_text).reshape(-1, 2)[:, ::-1] # xy-> yx
+ point_num = len(pos_info)
+ if point_num > fixed_point_num:
+ keep_ids = [
+ int((point_num * 1.0 / fixed_point_num) * x)
+ for x in range(fixed_point_num)
+ ]
+ pos_info = pos_info[keep_ids, :]
+
+ keep = int(min(len(pos_info), fixed_point_num))
+ if np.random.rand() < 0.2 and reference_height >= 3:
+ dl = (np.random.rand(keep) - 0.5) * reference_height * 0.3
+ random_float = np.array([1, 0]).reshape([1, 2]) * dl.reshape(
+ [keep, 1])
+ pos_info += random_float
+ pos_info[:, 0] = np.clip(pos_info[:, 0], 0, max_h - 1)
+ pos_info[:, 1] = np.clip(pos_info[:, 1], 0, max_w - 1)
+
+ # padding to fixed length
+ pos_l = np.zeros((self.tcl_len, 3), dtype=np.int32)
+ pos_l[:, 0] = np.ones((self.tcl_len, )) * img_id
+ pos_m = np.zeros((self.tcl_len, 1), dtype=np.float32)
+ pos_l[:keep, 1:] = np.round(pos_info).astype(np.int32)
+ pos_m[:keep] = 1.0
+ return pos_l, pos_m
+
+ def generate_direction_map(self, poly_quads, n_char, direction_map):
+        """
+        Fill direction_map with each quad's unit direction vector (scaled to the
+        average character width) and the inverse of the average quad height.
+        """
+ width_list = []
+ height_list = []
+ for quad in poly_quads:
+ quad_w = (np.linalg.norm(quad[0] - quad[1]) +
+ np.linalg.norm(quad[2] - quad[3])) / 2.0
+ quad_h = (np.linalg.norm(quad[0] - quad[3]) +
+ np.linalg.norm(quad[2] - quad[1])) / 2.0
+ width_list.append(quad_w)
+ height_list.append(quad_h)
+ norm_width = max(sum(width_list) / n_char, 1.0)
+ average_height = max(sum(height_list) / len(height_list), 1.0)
+ k = 1
+ for quad in poly_quads:
+ direct_vector_full = (
+ (quad[1] + quad[2]) - (quad[0] + quad[3])) / 2.0
+ direct_vector = direct_vector_full / (
+ np.linalg.norm(direct_vector_full) + 1e-6) * norm_width
+ direction_label = tuple(
+ map(float,
+ [direct_vector[0], direct_vector[1], 1.0 / average_height]))
+ cv2.fillPoly(direction_map,
+ quad.round().astype(np.int32)[np.newaxis, :, :],
+ direction_label)
+ k += 1
+ return direction_map
+
+ def calculate_average_height(self, poly_quads):
+        """
+        Average height over all quads (clipped to at least 1.0).
+        """
+ height_list = []
+ for quad in poly_quads:
+ quad_h = (np.linalg.norm(quad[0] - quad[3]) +
+ np.linalg.norm(quad[2] - quad[1])) / 2.0
+ height_list.append(quad_h)
+ average_height = max(sum(height_list) / len(height_list), 1.0)
+ return average_height
+
+ def generate_tcl_ctc_label(self,
+ h,
+ w,
+ polys,
+ tags,
+ text_strs,
+ ds_ratio,
+ tcl_ratio=0.3,
+ shrink_ratio_of_width=0.15):
+ """
+ Generate polygon.
+ """
+ score_map_big = np.zeros(
+ (
+ h,
+ w, ), dtype=np.float32)
+ h, w = int(h * ds_ratio), int(w * ds_ratio)
+ polys = polys * ds_ratio
+
+ score_map = np.zeros(
+ (
+ h,
+ w, ), dtype=np.float32)
+ score_label_map = np.zeros(
+ (
+ h,
+ w, ), dtype=np.float32)
+ tbo_map = np.zeros((h, w, 5), dtype=np.float32)
+ training_mask = np.ones(
+ (
+ h,
+ w, ), dtype=np.float32)
+ direction_map = np.ones((h, w, 3)) * np.array([0, 0, 1]).reshape(
+ [1, 1, 3]).astype(np.float32)
+
+ label_idx = 0
+ score_label_map_text_label_list = []
+ pos_list, pos_mask, label_list = [], [], []
+ for poly_idx, poly_tag in enumerate(zip(polys, tags)):
+ poly = poly_tag[0]
+ tag = poly_tag[1]
+
+ # generate min_area_quad
+ min_area_quad, center_point = self.gen_min_area_quad_from_poly(poly)
+ min_area_quad_h = 0.5 * (
+ np.linalg.norm(min_area_quad[0] - min_area_quad[3]) +
+ np.linalg.norm(min_area_quad[1] - min_area_quad[2]))
+ min_area_quad_w = 0.5 * (
+ np.linalg.norm(min_area_quad[0] - min_area_quad[1]) +
+ np.linalg.norm(min_area_quad[2] - min_area_quad[3]))
+
+ if min(min_area_quad_h, min_area_quad_w) < self.min_text_size * ds_ratio \
+ or min(min_area_quad_h, min_area_quad_w) > self.max_text_size * ds_ratio:
+ continue
+
+ if tag:
+ cv2.fillPoly(training_mask,
+ poly.astype(np.int32)[np.newaxis, :, :], 0.15)
+ else:
+ text_label = text_strs[poly_idx]
+ text_label = self.prepare_text_label(text_label,
+ self.Lexicon_Table)
+
+ text_label_index_list = [[self.Lexicon_Table.index(c_)]
+ for c_ in text_label
+ if c_ in self.Lexicon_Table]
+ if len(text_label_index_list) < 1:
+ continue
+
+ tcl_poly = self.poly2tcl(poly, tcl_ratio)
+ tcl_quads = self.poly2quads(tcl_poly)
+ poly_quads = self.poly2quads(poly)
+
+ stcl_quads, quad_index = self.shrink_poly_along_width(
+ tcl_quads,
+ shrink_ratio_of_width=shrink_ratio_of_width,
+ expand_height_ratio=1.0 / tcl_ratio)
+
+ cv2.fillPoly(score_map,
+ np.round(stcl_quads).astype(np.int32), 1.0)
+ cv2.fillPoly(score_map_big,
+ np.round(stcl_quads / ds_ratio).astype(np.int32),
+ 1.0)
+
+ for idx, quad in enumerate(stcl_quads):
+ quad_mask = np.zeros((h, w), dtype=np.float32)
+ quad_mask = cv2.fillPoly(
+ quad_mask,
+ np.round(quad[np.newaxis, :, :]).astype(np.int32), 1.0)
+ tbo_map = self.gen_quad_tbo(poly_quads[quad_index[idx]],
+ quad_mask, tbo_map)
+
+ # score label map and score_label_map_text_label_list for refine
+ if label_idx == 0:
+ text_pos_list_ = [[len(self.Lexicon_Table)], ]
+ score_label_map_text_label_list.append(text_pos_list_)
+
+ label_idx += 1
+ cv2.fillPoly(score_label_map,
+ np.round(poly_quads).astype(np.int32), label_idx)
+ score_label_map_text_label_list.append(text_label_index_list)
+
+ # direction info, fix-me
+ n_char = len(text_label_index_list)
+ direction_map = self.generate_direction_map(poly_quads, n_char,
+ direction_map)
+
+ # pos info
+ average_shrink_height = self.calculate_average_height(
+ stcl_quads)
+ pos_l, pos_m = self.fit_and_gather_tcl_points_v2(
+ min_area_quad,
+ poly,
+ max_h=h,
+ max_w=w,
+ fixed_point_num=64,
+ img_id=self.img_id,
+ reference_height=average_shrink_height)
+
+ label_l = text_label_index_list
+ if len(text_label_index_list) < 2:
+ continue
+
+ pos_list.append(pos_l)
+ pos_mask.append(pos_m)
+ label_list.append(label_l)
+
+ # use big score_map for smooth tcl lines
+ score_map_big_resized = cv2.resize(
+ score_map_big, dsize=None, fx=ds_ratio, fy=ds_ratio)
+ score_map = np.array(score_map_big_resized > 1e-3, dtype='float32')
+
+ return score_map, score_label_map, tbo_map, direction_map, training_mask, \
+ pos_list, pos_mask, label_list, score_label_map_text_label_list
+
+ def adjust_point(self, poly):
+ """
+ adjust point order.
+ """
+ point_num = poly.shape[0]
+ if point_num == 4:
+ len_1 = np.linalg.norm(poly[0] - poly[1])
+ len_2 = np.linalg.norm(poly[1] - poly[2])
+ len_3 = np.linalg.norm(poly[2] - poly[3])
+ len_4 = np.linalg.norm(poly[3] - poly[0])
+
+ if (len_1 + len_3) * 1.5 < (len_2 + len_4):
+ poly = poly[[1, 2, 3, 0], :]
+
+ elif point_num > 4:
+ vector_1 = poly[0] - poly[1]
+ vector_2 = poly[1] - poly[2]
+ cos_theta = np.dot(vector_1, vector_2) / (
+ np.linalg.norm(vector_1) * np.linalg.norm(vector_2) + 1e-6)
+ theta = np.arccos(np.round(cos_theta, decimals=4))
+
+ if abs(theta) > (70 / 180 * math.pi):
+ index = list(range(1, point_num)) + [0]
+ poly = poly[np.array(index), :]
+ return poly
+
+ def gen_min_area_quad_from_poly(self, poly):
+ """
+ Generate min area quad from poly.
+ """
+ point_num = poly.shape[0]
+ min_area_quad = np.zeros((4, 2), dtype=np.float32)
+ if point_num == 4:
+ min_area_quad = poly
+ center_point = np.sum(poly, axis=0) / 4
+ else:
+ rect = cv2.minAreaRect(poly.astype(
+ np.int32)) # (center (x,y), (width, height), angle of rotation)
+ center_point = rect[0]
+ box = np.array(cv2.boxPoints(rect))
+
+ first_point_idx = 0
+ min_dist = 1e4
+ for i in range(4):
+ dist = np.linalg.norm(box[(i + 0) % 4] - poly[0]) + \
+ np.linalg.norm(box[(i + 1) % 4] - poly[point_num // 2 - 1]) + \
+ np.linalg.norm(box[(i + 2) % 4] - poly[point_num // 2]) + \
+ np.linalg.norm(box[(i + 3) % 4] - poly[-1])
+ if dist < min_dist:
+ min_dist = dist
+ first_point_idx = i
+
+ for i in range(4):
+ min_area_quad[i] = box[(first_point_idx + i) % 4]
+
+ return min_area_quad, center_point
+
+ def shrink_quad_along_width(self,
+ quad,
+ begin_width_ratio=0.,
+ end_width_ratio=1.):
+ """
+ Generate shrink_quad_along_width.
+ """
+ ratio_pair = np.array(
+ [[begin_width_ratio], [end_width_ratio]], dtype=np.float32)
+ p0_1 = quad[0] + (quad[1] - quad[0]) * ratio_pair
+ p3_2 = quad[3] + (quad[2] - quad[3]) * ratio_pair
+ return np.array([p0_1[0], p0_1[1], p3_2[1], p3_2[0]])
+
+ def shrink_poly_along_width(self,
+ quads,
+ shrink_ratio_of_width,
+ expand_height_ratio=1.0):
+ """
+ shrink poly with given length.
+ """
+ upper_edge_list = []
+
+ def get_cut_info(edge_len_list, cut_len):
+ for idx, edge_len in enumerate(edge_len_list):
+ cut_len -= edge_len
+ if cut_len <= 0.000001:
+ ratio = (cut_len + edge_len_list[idx]) / edge_len_list[idx]
+ return idx, ratio
+
+ for quad in quads:
+ upper_edge_len = np.linalg.norm(quad[0] - quad[1])
+ upper_edge_list.append(upper_edge_len)
+
+ # length of left edge and right edge.
+ left_length = np.linalg.norm(quads[0][0] - quads[0][
+ 3]) * expand_height_ratio
+ right_length = np.linalg.norm(quads[-1][1] - quads[-1][
+ 2]) * expand_height_ratio
+
+ shrink_length = min(left_length, right_length,
+ sum(upper_edge_list)) * shrink_ratio_of_width
+ # shrinking length
+ upper_len_left = shrink_length
+ upper_len_right = sum(upper_edge_list) - shrink_length
+
+ left_idx, left_ratio = get_cut_info(upper_edge_list, upper_len_left)
+ left_quad = self.shrink_quad_along_width(
+ quads[left_idx], begin_width_ratio=left_ratio, end_width_ratio=1)
+ right_idx, right_ratio = get_cut_info(upper_edge_list, upper_len_right)
+ right_quad = self.shrink_quad_along_width(
+ quads[right_idx], begin_width_ratio=0, end_width_ratio=right_ratio)
+
+ out_quad_list = []
+ if left_idx == right_idx:
+ out_quad_list.append(
+ [left_quad[0], right_quad[1], right_quad[2], left_quad[3]])
+ else:
+ out_quad_list.append(left_quad)
+ for idx in range(left_idx + 1, right_idx):
+ out_quad_list.append(quads[idx])
+ out_quad_list.append(right_quad)
+
+ return np.array(out_quad_list), list(range(left_idx, right_idx + 1))
+
+ def prepare_text_label(self, label_str, Lexicon_Table):
+ """
+        Prepare the text label using the given Lexicon_Table.
+ """
+ if len(Lexicon_Table) == 36:
+ return label_str.lower()
+ else:
+ return label_str
+
+ def vector_angle(self, A, B):
+ """
+ Calculate the angle between vector AB and x-axis positive direction.
+ """
+ AB = np.array([B[1] - A[1], B[0] - A[0]])
+ return np.arctan2(*AB)
+
+ def theta_line_cross_point(self, theta, point):
+ """
+        Calculate the line through the given point at the given angle, in ax + by + c = 0 form.
+ """
+ x, y = point
+ cos = np.cos(theta)
+ sin = np.sin(theta)
+ return [sin, -cos, cos * y - sin * x]
+
+ def line_cross_two_point(self, A, B):
+ """
+        Calculate the line through points A and B, in ax + by + c = 0 form.
+ """
+ angle = self.vector_angle(A, B)
+ return self.theta_line_cross_point(angle, A)
+
+ def average_angle(self, poly):
+ """
+ Calculate the average angle between left and right edge in given poly.
+ """
+ p0, p1, p2, p3 = poly
+ angle30 = self.vector_angle(p3, p0)
+ angle21 = self.vector_angle(p2, p1)
+ return (angle30 + angle21) / 2
+
+ def line_cross_point(self, line1, line2):
+ """
+        line1 and line2 are in ax + by + c = 0 form; compute their intersection point.
+ """
+ a1, b1, c1 = line1
+ a2, b2, c2 = line2
+ d = a1 * b2 - a2 * b1
+
+ if d == 0:
+ print('Cross point does not exist')
+ return np.array([0, 0], dtype=np.float32)
+ else:
+ x = (b1 * c2 - b2 * c1) / d
+ y = (a2 * c1 - a1 * c2) / d
+
+ return np.array([x, y], dtype=np.float32)
+
+ def quad2tcl(self, poly, ratio):
+ """
+        Generate the center line from clockwise quad points of shape (4, 2).
+ """
+ ratio_pair = np.array(
+ [[0.5 - ratio / 2], [0.5 + ratio / 2]], dtype=np.float32)
+ p0_3 = poly[0] + (poly[3] - poly[0]) * ratio_pair
+ p1_2 = poly[1] + (poly[2] - poly[1]) * ratio_pair
+ return np.array([p0_3[0], p1_2[0], p1_2[1], p0_3[1]])
+
+ def poly2tcl(self, poly, ratio):
+ """
+        Generate the text center-line polygon from clockwise poly points.
+ """
+ ratio_pair = np.array(
+ [[0.5 - ratio / 2], [0.5 + ratio / 2]], dtype=np.float32)
+ tcl_poly = np.zeros_like(poly)
+ point_num = poly.shape[0]
+
+ for idx in range(point_num // 2):
+ point_pair = poly[idx] + (poly[point_num - 1 - idx] - poly[idx]
+ ) * ratio_pair
+ tcl_poly[idx] = point_pair[0]
+ tcl_poly[point_num - 1 - idx] = point_pair[1]
+ return tcl_poly
+
+ def gen_quad_tbo(self, quad, tcl_mask, tbo_map):
+ """
+        Generate tbo_map for the given quad.
+ """
+ # upper and lower line function: ax + by + c = 0;
+ up_line = self.line_cross_two_point(quad[0], quad[1])
+ lower_line = self.line_cross_two_point(quad[3], quad[2])
+
+ quad_h = 0.5 * (np.linalg.norm(quad[0] - quad[3]) +
+ np.linalg.norm(quad[1] - quad[2]))
+ quad_w = 0.5 * (np.linalg.norm(quad[0] - quad[1]) +
+ np.linalg.norm(quad[2] - quad[3]))
+
+ # average angle of left and right line.
+ angle = self.average_angle(quad)
+
+ xy_in_poly = np.argwhere(tcl_mask == 1)
+ for y, x in xy_in_poly:
+ point = (x, y)
+ line = self.theta_line_cross_point(angle, point)
+ cross_point_upper = self.line_cross_point(up_line, line)
+ cross_point_lower = self.line_cross_point(lower_line, line)
+            # note: offsets are stored in (y, x) order below
+ upper_offset_x, upper_offset_y = cross_point_upper - point
+ lower_offset_x, lower_offset_y = cross_point_lower - point
+ tbo_map[y, x, 0] = upper_offset_y
+ tbo_map[y, x, 1] = upper_offset_x
+ tbo_map[y, x, 2] = lower_offset_y
+ tbo_map[y, x, 3] = lower_offset_x
+ tbo_map[y, x, 4] = 1.0 / max(min(quad_h, quad_w), 1.0) * 2
+ return tbo_map
+
+ def poly2quads(self, poly):
+ """
+ Split poly into quads.
+ """
+ quad_list = []
+ point_num = poly.shape[0]
+
+ # point pair
+ point_pair_list = []
+ for idx in range(point_num // 2):
+ point_pair = [poly[idx], poly[point_num - 1 - idx]]
+ point_pair_list.append(point_pair)
+
+ quad_num = point_num // 2 - 1
+ for idx in range(quad_num):
+ # reshape and adjust to clock-wise
+ quad_list.append((np.array(point_pair_list)[[idx, idx + 1]]
+ ).reshape(4, 2)[[0, 2, 3, 1]])
+
+ return np.array(quad_list)
+
+ def rotate_im_poly(self, im, text_polys):
+ """
+        Randomly rotate the image and its text polys by 90 or 270 degrees.
+ """
+ im_w, im_h = im.shape[1], im.shape[0]
+ dst_im = im.copy()
+ dst_polys = []
+ rand_degree_ratio = np.random.rand()
+ rand_degree_cnt = 1
+ if rand_degree_ratio > 0.5:
+ rand_degree_cnt = 3
+ for i in range(rand_degree_cnt):
+ dst_im = np.rot90(dst_im)
+ rot_degree = -90 * rand_degree_cnt
+ rot_angle = rot_degree * math.pi / 180.0
+ n_poly = text_polys.shape[0]
+ cx, cy = 0.5 * im_w, 0.5 * im_h
+ ncx, ncy = 0.5 * dst_im.shape[1], 0.5 * dst_im.shape[0]
+ for i in range(n_poly):
+ wordBB = text_polys[i]
+ poly = []
+ for j in range(4): # 16->4
+ sx, sy = wordBB[j][0], wordBB[j][1]
+ dx = math.cos(rot_angle) * (sx - cx) - math.sin(rot_angle) * (
+ sy - cy) + ncx
+ dy = math.sin(rot_angle) * (sx - cx) + math.cos(rot_angle) * (
+ sy - cy) + ncy
+ poly.append([dx, dy])
+ dst_polys.append(poly)
+ return dst_im, np.array(dst_polys, dtype=np.float32)
+
+ def __call__(self, data):
+ input_size = 512
+ im = data['image']
+ text_polys = data['polys']
+ text_tags = data['tags']
+ text_strs = data['strs']
+ h, w, _ = im.shape
+ text_polys, text_tags, hv_tags = self.check_and_validate_polys(
+ text_polys, text_tags, (h, w))
+ if text_polys.shape[0] <= 0:
+ return None
+        # randomly jitter the aspect ratio while keeping the area fixed
+ asp_scales = np.arange(1.0, 1.55, 0.1)
+ asp_scale = np.random.choice(asp_scales)
+ if np.random.rand() < 0.5:
+ asp_scale = 1.0 / asp_scale
+ asp_scale = math.sqrt(asp_scale)
+
+ asp_wx = asp_scale
+ asp_hy = 1.0 / asp_scale
+ im = cv2.resize(im, dsize=None, fx=asp_wx, fy=asp_hy)
+ text_polys[:, :, 0] *= asp_wx
+ text_polys[:, :, 1] *= asp_hy
+
+ h, w, _ = im.shape
+ if max(h, w) > 2048:
+ rd_scale = 2048.0 / max(h, w)
+ im = cv2.resize(im, dsize=None, fx=rd_scale, fy=rd_scale)
+ text_polys *= rd_scale
+ h, w, _ = im.shape
+ if min(h, w) < 16:
+ return None
+
+        # crop an area containing text (no background-only crops)
+ im, text_polys, text_tags, hv_tags, text_strs = self.crop_area(
+ im,
+ text_polys,
+ text_tags,
+ hv_tags,
+ text_strs,
+ crop_background=False)
+
+ if text_polys.shape[0] == 0:
+ return None
+        # skip the sample if all text instances are ignore cases
+ if np.sum((text_tags * 1.0)) >= text_tags.size:
+ return None
+ new_h, new_w, _ = im.shape
+ if (new_h is None) or (new_w is None):
+ return None
+ # resize image
+ std_ratio = float(input_size) / max(new_w, new_h)
+ rand_scales = np.array(
+ [0.25, 0.375, 0.5, 0.625, 0.75, 0.875, 1.0, 1.0, 1.0, 1.0, 1.0])
+ rz_scale = std_ratio * np.random.choice(rand_scales)
+ im = cv2.resize(im, dsize=None, fx=rz_scale, fy=rz_scale)
+ text_polys[:, :, 0] *= rz_scale
+ text_polys[:, :, 1] *= rz_scale
+
+ # add gaussian blur
+ if np.random.rand() < 0.1 * 0.5:
+ ks = np.random.permutation(5)[0] + 1
+ ks = int(ks / 2) * 2 + 1
+ im = cv2.GaussianBlur(im, ksize=(ks, ks), sigmaX=0, sigmaY=0)
+        # randomly brighten
+ if np.random.rand() < 0.1 * 0.5:
+ im = im * (1.0 + np.random.rand() * 0.5)
+ im = np.clip(im, 0.0, 255.0)
+        # randomly darken
+ if np.random.rand() < 0.1 * 0.5:
+ im = im * (1.0 - np.random.rand() * 0.5)
+ im = np.clip(im, 0.0, 255.0)
+
+ # Padding the im to [input_size, input_size]
+ new_h, new_w, _ = im.shape
+ if min(new_w, new_h) < input_size * 0.5:
+ return None
+ im_padded = np.ones((input_size, input_size, 3), dtype=np.float32)
+ im_padded[:, :, 2] = 0.485 * 255
+ im_padded[:, :, 1] = 0.456 * 255
+ im_padded[:, :, 0] = 0.406 * 255
+
+        # randomize the padding start position
+ del_h = input_size - new_h
+ del_w = input_size - new_w
+ sh, sw = 0, 0
+ if del_h > 1:
+ sh = int(np.random.rand() * del_h)
+ if del_w > 1:
+ sw = int(np.random.rand() * del_w)
+
+ # Padding
+ im_padded[sh:sh + new_h, sw:sw + new_w, :] = im.copy()
+ text_polys[:, :, 0] += sw
+ text_polys[:, :, 1] += sh
+
+ score_map, score_label_map, border_map, direction_map, training_mask, \
+ pos_list, pos_mask, label_list, score_label_map_text_label = self.generate_tcl_ctc_label(input_size,
+ input_size,
+ text_polys,
+ text_tags,
+ text_strs, 0.25)
+ if len(label_list) <= 0: # eliminate negative samples
+ return None
+ pos_list_temp = np.zeros([64, 3])
+ pos_mask_temp = np.zeros([64, 1])
+ label_list_temp = np.zeros([self.max_text_length, 1]) + self.pad_num
+
+ for i, label in enumerate(label_list):
+ n = len(label)
+ if n > self.max_text_length:
+ label_list[i] = label[:self.max_text_length]
+ continue
+ while n < self.max_text_length:
+ label.append([self.pad_num])
+ n += 1
+
+ for i in range(len(label_list)):
+ label_list[i] = np.array(label_list[i])
+
+ if len(pos_list) <= 0 or len(pos_list) > self.max_text_nums:
+ return None
+ for __ in range(self.max_text_nums - len(pos_list), 0, -1):
+ pos_list.append(pos_list_temp)
+ pos_mask.append(pos_mask_temp)
+ label_list.append(label_list_temp)
+
+ if self.img_id == self.batch_size - 1:
+ self.img_id = 0
+ else:
+ self.img_id += 1
+
+ im_padded[:, :, 2] -= 0.485 * 255
+ im_padded[:, :, 1] -= 0.456 * 255
+ im_padded[:, :, 0] -= 0.406 * 255
+ im_padded[:, :, 2] /= (255.0 * 0.229)
+ im_padded[:, :, 1] /= (255.0 * 0.224)
+ im_padded[:, :, 0] /= (255.0 * 0.225)
+ im_padded = im_padded.transpose((2, 0, 1))
+ images = im_padded[::-1, :, :]
+ tcl_maps = score_map[np.newaxis, :, :]
+ tcl_label_maps = score_label_map[np.newaxis, :, :]
+ border_maps = border_map.transpose((2, 0, 1))
+ direction_maps = direction_map.transpose((2, 0, 1))
+ training_masks = training_mask[np.newaxis, :, :]
+ pos_list = np.array(pos_list)
+ pos_mask = np.array(pos_mask)
+ label_list = np.array(label_list)
+ data['images'] = images
+ data['tcl_maps'] = tcl_maps
+ data['tcl_label_maps'] = tcl_label_maps
+ data['border_maps'] = border_maps
+ data['direction_maps'] = direction_maps
+ data['training_masks'] = training_masks
+ data['label_list'] = label_list
+ data['pos_list'] = pos_list
+ data['pos_mask'] = pos_mask
+ return data
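
Note: the sign convention of quad_area drives the direction check in check_and_validate_polys; a standalone re-implementation for illustration (y grows downward in image coordinates, so a clockwise quad comes out negative):

```python
import numpy as np

def quad_area(poly):
    # same shoelace-style edge sum as PGProcessTrain.quad_area
    edge = [(poly[(i + 1) % 4][0] - poly[i][0]) *
            (poly[(i + 1) % 4][1] + poly[i][1]) for i in range(4)]
    return np.sum(edge) / 2.

quad = np.array([[0, 0], [1, 0], [1, 1], [0, 1]], dtype=np.float32)
print(quad_area(quad))        # -1.0: clockwise in image coords -> kept
print(quad_area(quad[::-1]))  #  1.0: would trigger 'poly in wrong direction'
```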
diff --git a/ppocr/data/pgnet_dataset.py b/ppocr/data/pgnet_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..ae0638350ad02f10202a67bc6cd531daf742f984
--- /dev/null
+++ b/ppocr/data/pgnet_dataset.py
@@ -0,0 +1,175 @@
+# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import numpy as np
+import os
+from paddle.io import Dataset
+from .imaug import transform, create_operators
+import random
+
+
+class PGDataSet(Dataset):
+ def __init__(self, config, mode, logger, seed=None):
+ super(PGDataSet, self).__init__()
+
+ self.logger = logger
+ self.seed = seed
+ self.mode = mode
+ global_config = config['Global']
+ dataset_config = config[mode]['dataset']
+ loader_config = config[mode]['loader']
+
+ label_file_list = dataset_config.pop('label_file_list')
+ data_source_num = len(label_file_list)
+ ratio_list = dataset_config.get("ratio_list", [1.0])
+ if isinstance(ratio_list, (float, int)):
+ ratio_list = [float(ratio_list)] * int(data_source_num)
+ self.data_format = dataset_config.get('data_format', 'icdar')
+ assert len(
+ ratio_list
+ ) == data_source_num, "The length of ratio_list should be the same as the file_list."
+ self.do_shuffle = loader_config['shuffle']
+
+        logger.info("Initialize indices of datasets: %s" % label_file_list)
+ self.data_lines = self.get_image_info_list(label_file_list, ratio_list,
+ self.data_format)
+ self.data_idx_order_list = list(range(len(self.data_lines)))
+ if mode.lower() == "train":
+ self.shuffle_data_random()
+
+ self.ops = create_operators(dataset_config['transforms'], global_config)
+
+ def shuffle_data_random(self):
+ if self.do_shuffle:
+ random.seed(self.seed)
+ random.shuffle(self.data_lines)
+ return
+
+ def extract_polys(self, poly_txt_path):
+ """
+        Read text_polys, txt_tags, txts from the given txt file.
+ """
+ text_polys, txt_tags, txts = [], [], []
+ with open(poly_txt_path) as f:
+ for line in f.readlines():
+ poly_str, txt = line.strip().split('\t')
+ poly = list(map(float, poly_str.split(',')))
+ if self.mode.lower() == "eval":
+ while len(poly) < 100:
+ poly.append(-1)
+ text_polys.append(
+ np.array(
+ poly, dtype=np.float32).reshape(-1, 2))
+ txts.append(txt)
+ txt_tags.append(txt == '###')
+
+ return np.array(list(map(np.array, text_polys))), \
+ np.array(txt_tags, dtype=np.bool), txts
+
+ def extract_info_textnet(self, im_fn, img_dir=''):
+ """
+ Extract information from line in textnet format.
+ """
+ info_list = im_fn.split('\t')
+ img_path = ''
+ for ext in [
+ 'jpg', 'bmp', 'png', 'jpeg', 'rgb', 'tif', 'tiff', 'gif', 'JPG'
+ ]:
+ if os.path.exists(os.path.join(img_dir, info_list[0] + "." + ext)):
+ img_path = os.path.join(img_dir, info_list[0] + "." + ext)
+ break
+
+ if img_path == '':
+ print('Image {0} NOT found in {1}, and it will be ignored.'.format(
+ info_list[0], img_dir))
+
+ nBox = (len(info_list) - 1) // 9
+ wordBBs, txts, txt_tags = [], [], []
+ for n in range(0, nBox):
+ wordBB = list(map(float, info_list[n * 9 + 1:(n + 1) * 9]))
+ txt = info_list[(n + 1) * 9]
+ wordBBs.append([[wordBB[0], wordBB[1]], [wordBB[2], wordBB[3]],
+ [wordBB[4], wordBB[5]], [wordBB[6], wordBB[7]]])
+ txts.append(txt)
+ if txt == '###':
+ txt_tags.append(True)
+ else:
+ txt_tags.append(False)
+ return img_path, np.array(wordBBs, dtype=np.float32), txt_tags, txts
+
+ def get_image_info_list(self, file_list, ratio_list, data_format='textnet'):
+ if isinstance(file_list, str):
+ file_list = [file_list]
+ data_lines = []
+ for idx, data_source in enumerate(file_list):
+ image_files = []
+ if data_format == 'icdar':
+ image_files = [(data_source, x) for x in
+ os.listdir(os.path.join(data_source, 'rgb'))
+ if x.split('.')[-1] in [
+ 'jpg', 'bmp', 'png', 'jpeg', 'rgb', 'tif',
+ 'tiff', 'gif', 'JPG'
+ ]]
+ elif data_format == 'textnet':
+ with open(data_source) as f:
+ image_files = [(data_source, x.strip())
+ for x in f.readlines()]
+ else:
+ print("Unrecognized data format...")
+ exit(-1)
+ random.seed(self.seed)
+ image_files = random.sample(
+ image_files, round(len(image_files) * ratio_list[idx]))
+ data_lines.extend(image_files)
+ return data_lines
+
+ def __getitem__(self, idx):
+ file_idx = self.data_idx_order_list[idx]
+ data_path, data_line = self.data_lines[file_idx]
+ try:
+ if self.data_format == 'icdar':
+ im_path = os.path.join(data_path, 'rgb', data_line)
+ if self.mode.lower() == "eval":
+ poly_path = os.path.join(data_path, 'poly_gt',
+ data_line.split('.')[0] + '.txt')
+ else:
+ poly_path = os.path.join(data_path, 'poly',
+ data_line.split('.')[0] + '.txt')
+ text_polys, text_tags, text_strs = self.extract_polys(poly_path)
+ else:
+ image_dir = os.path.join(os.path.dirname(data_path), 'image')
+ im_path, text_polys, text_tags, text_strs = self.extract_info_textnet(
+ data_line, image_dir)
+
+ data = {
+ 'img_path': im_path,
+ 'polys': text_polys,
+ 'tags': text_tags,
+ 'strs': text_strs
+ }
+ with open(data['img_path'], 'rb') as f:
+ img = f.read()
+ data['image'] = img
+ outs = transform(data, self.ops)
+
+ except Exception as e:
+ self.logger.error(
+ "Error when parsing line {}: {}".format(
+ self.data_idx_order_list[idx], e))
+ outs = None
+ if outs is None:
+ return self.__getitem__(np.random.randint(self.__len__()))
+ return outs
+
+ def __len__(self):
+ return len(self.data_idx_order_list)
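A minimal sketch of driving PGDataSet by hand; in practice it is built through PaddleOCR's dataloader factory, and the path and one-transform list below are placeholders rather than the full training pipeline:

    import logging
    from paddle.io import DataLoader

    config = {
        'Global': {},
        'Train': {
            'dataset': {
                'label_file_list': ['./train_data/total_text/train/'],  # placeholder
                'data_format': 'icdar',
                'transforms': [{'DecodeImage': {'img_mode': 'BGR',
                                                'channel_first': False}}],
            },
            'loader': {'shuffle': True},
        },
    }
    dataset = PGDataSet(config, mode='Train', logger=logging.getLogger(__name__))
    loader = DataLoader(dataset, batch_size=14, shuffle=True, drop_last=True)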
diff --git a/ppocr/losses/__init__.py b/ppocr/losses/__init__.py
index 3881abf7741b8be78306bd070afb11df15606327..223ae6b1da996478ac607e29dd37173ca51d9903 100755
--- a/ppocr/losses/__init__.py
+++ b/ppocr/losses/__init__.py
@@ -29,10 +29,11 @@ def build_loss(config):
# cls loss
from .cls_loss import ClsLoss
+ # e2e loss
+ from .e2e_pg_loss import PGLoss
support_dict = [
'DBLoss', 'EASTLoss', 'SASTLoss', 'CTCLoss', 'ClsLoss', 'AttentionLoss',
- 'SRNLoss'
- ]
+ 'SRNLoss', 'PGLoss']
config = copy.deepcopy(config)
module_name = config.pop('name')
diff --git a/ppocr/losses/e2e_pg_loss.py b/ppocr/losses/e2e_pg_loss.py
new file mode 100644
index 0000000000000000000000000000000000000000..10a8ed0aa907123b155976ba498426604f23c2b0
--- /dev/null
+++ b/ppocr/losses/e2e_pg_loss.py
@@ -0,0 +1,140 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from paddle import nn
+import paddle
+
+from .det_basic_loss import DiceLoss
+from ppocr.utils.e2e_utils.extract_batchsize import pre_process
+
+
+class PGLoss(nn.Layer):
+ def __init__(self,
+ tcl_bs,
+ max_text_length,
+ max_text_nums,
+ pad_num,
+ eps=1e-6,
+ **kwargs):
+ super(PGLoss, self).__init__()
+ self.tcl_bs = tcl_bs
+ self.max_text_nums = max_text_nums
+ self.max_text_length = max_text_length
+ self.pad_num = pad_num
+ self.dice_loss = DiceLoss(eps=eps)
+
+ def border_loss(self, f_border, l_border, l_score, l_mask):
+ l_border_split, l_border_norm = paddle.tensor.split(
+ l_border, num_or_sections=[4, 1], axis=1)
+ f_border_split = f_border
+ b, c, h, w = l_border_norm.shape
+ l_border_norm_split = paddle.expand(
+ x=l_border_norm, shape=[b, 4 * c, h, w])
+ b, c, h, w = l_score.shape
+ l_border_score = paddle.expand(x=l_score, shape=[b, 4 * c, h, w])
+ b, c, h, w = l_mask.shape
+ l_border_mask = paddle.expand(x=l_mask, shape=[b, 4 * c, h, w])
+ border_diff = l_border_split - f_border_split
+ abs_border_diff = paddle.abs(border_diff)
+ border_sign = abs_border_diff < 1.0
+ border_sign = paddle.cast(border_sign, dtype='float32')
+ border_sign.stop_gradient = True
+ border_in_loss = 0.5 * abs_border_diff * abs_border_diff * border_sign + \
+ (abs_border_diff - 0.5) * (1.0 - border_sign)
+ border_out_loss = l_border_norm_split * border_in_loss
+ border_loss = paddle.sum(border_out_loss * l_border_score * l_border_mask) / \
+ (paddle.sum(l_border_score * l_border_mask) + 1e-5)
+ return border_loss
+
+ def direction_loss(self, f_direction, l_direction, l_score, l_mask):
+ l_direction_split, l_direction_norm = paddle.tensor.split(
+ l_direction, num_or_sections=[2, 1], axis=1)
+ f_direction_split = f_direction
+ b, c, h, w = l_direction_norm.shape
+ l_direction_norm_split = paddle.expand(
+ x=l_direction_norm, shape=[b, 2 * c, h, w])
+ b, c, h, w = l_score.shape
+ l_direction_score = paddle.expand(x=l_score, shape=[b, 2 * c, h, w])
+ b, c, h, w = l_mask.shape
+ l_direction_mask = paddle.expand(x=l_mask, shape=[b, 2 * c, h, w])
+ direction_diff = l_direction_split - f_direction_split
+ abs_direction_diff = paddle.abs(direction_diff)
+ direction_sign = abs_direction_diff < 1.0
+ direction_sign = paddle.cast(direction_sign, dtype='float32')
+ direction_sign.stop_gradient = True
+ direction_in_loss = 0.5 * abs_direction_diff * abs_direction_diff * direction_sign + \
+ (abs_direction_diff - 0.5) * (1.0 - direction_sign)
+ direction_out_loss = l_direction_norm_split * direction_in_loss
+ direction_loss = paddle.sum(direction_out_loss * l_direction_score * l_direction_mask) / \
+ (paddle.sum(l_direction_score * l_direction_mask) + 1e-5)
+ return direction_loss
+
+ def ctcloss(self, f_char, tcl_pos, tcl_mask, tcl_label, label_t):
+ f_char = paddle.transpose(f_char, [0, 2, 3, 1])
+ tcl_pos = paddle.reshape(tcl_pos, [-1, 3])
+ tcl_pos = paddle.cast(tcl_pos, dtype=int)
+ f_tcl_char = paddle.gather_nd(f_char, tcl_pos)
+ f_tcl_char = paddle.reshape(f_tcl_char,
+ [-1, 64, 37]) # len(Lexicon_Table)+1
+ f_tcl_char_fg, f_tcl_char_bg = paddle.split(f_tcl_char, [36, 1], axis=2)
+ f_tcl_char_bg = f_tcl_char_bg * tcl_mask + (1.0 - tcl_mask) * 20.0
+ b, c, l = tcl_mask.shape
+ tcl_mask_fg = paddle.expand(x=tcl_mask, shape=[b, c, 36 * l])
+ tcl_mask_fg.stop_gradient = True
+ f_tcl_char_fg = f_tcl_char_fg * tcl_mask_fg + (1.0 - tcl_mask_fg) * (
+ -20.0)
+ f_tcl_char_mask = paddle.concat([f_tcl_char_fg, f_tcl_char_bg], axis=2)
+ f_tcl_char_ld = paddle.transpose(f_tcl_char_mask, (1, 0, 2))
+ N, B, _ = f_tcl_char_ld.shape
+ input_lengths = paddle.to_tensor([N] * B, dtype='int64')
+ cost = paddle.nn.functional.ctc_loss(
+ log_probs=f_tcl_char_ld,
+ labels=tcl_label,
+ input_lengths=input_lengths,
+ label_lengths=label_t,
+ blank=self.pad_num,
+ reduction='none')
+ cost = cost.mean()
+ return cost
+
+ def forward(self, predicts, labels):
+ images, tcl_maps, tcl_label_maps, border_maps \
+ , direction_maps, training_masks, label_list, pos_list, pos_mask = labels
+ # gather and pad the per-image label/position lists across the whole batch
+ pos_list, pos_mask, label_list, label_t = pre_process(
+ label_list, pos_list, pos_mask, self.max_text_length,
+ self.max_text_nums, self.pad_num, self.tcl_bs)
+
+ f_score, f_border, f_direction, f_char = predicts['f_score'], predicts['f_border'], predicts['f_direction'], \
+ predicts['f_char']
+ score_loss = self.dice_loss(f_score, tcl_maps, training_masks)
+ border_loss = self.border_loss(f_border, border_maps, tcl_maps,
+ training_masks)
+ direction_loss = self.direction_loss(f_direction, direction_maps,
+ tcl_maps, training_masks)
+ ctc_loss = self.ctcloss(f_char, pos_list, pos_mask, label_list, label_t)
+ loss_all = score_loss + border_loss + direction_loss + 5 * ctc_loss
+
+ losses = {
+ 'loss': loss_all,
+ "score_loss": score_loss,
+ "border_loss": border_loss,
+ "direction_loss": direction_loss,
+ "ctc_loss": ctc_loss
+ }
+ return losses
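The border and direction branches above share one trick: the *_sign mask (abs_diff < 1.0) selects between the quadratic and linear pieces of a smooth-L1 penalty. A standalone sketch of that branch selection:

    import paddle

    def smooth_l1(diff):
        abs_diff = paddle.abs(diff)
        inside = paddle.cast(abs_diff < 1.0, 'float32')  # 1 where |d| < 1, else 0
        return 0.5 * abs_diff * abs_diff * inside + (abs_diff - 0.5) * (1.0 - inside)

    print(smooth_l1(paddle.to_tensor([0.5, 2.0])).numpy())  # [0.125 1.5]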
diff --git a/ppocr/metrics/__init__.py b/ppocr/metrics/__init__.py
index a0e7d91207277d5c1696d99473f6bc5f685591fc..f913010dbd994633d3df1cf996abb994d246a11a 100644
--- a/ppocr/metrics/__init__.py
+++ b/ppocr/metrics/__init__.py
@@ -26,8 +26,9 @@ def build_metric(config):
from .det_metric import DetMetric
from .rec_metric import RecMetric
from .cls_metric import ClsMetric
+ from .e2e_metric import E2EMetric
- support_dict = ['DetMetric', 'RecMetric', 'ClsMetric']
+ support_dict = ['DetMetric', 'RecMetric', 'ClsMetric', 'E2EMetric']
config = copy.deepcopy(config)
module_name = config.pop('name')
diff --git a/ppocr/metrics/e2e_metric.py b/ppocr/metrics/e2e_metric.py
new file mode 100644
index 0000000000000000000000000000000000000000..684d77421c659d4150ea4a28a99b4ae43d678b69
--- /dev/null
+++ b/ppocr/metrics/e2e_metric.py
@@ -0,0 +1,81 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+__all__ = ['E2EMetric']
+
+from ppocr.utils.e2e_metric.Deteval import get_socre, combine_results
+from ppocr.utils.e2e_utils.extract_textpoint import get_dict
+
+
+class E2EMetric(object):
+ def __init__(self,
+ character_dict_path,
+ main_indicator='f_score_e2e',
+ **kwargs):
+ self.label_list = get_dict(character_dict_path)
+ self.max_index = len(self.label_list)
+ self.main_indicator = main_indicator
+ self.reset()
+
+ def __call__(self, preds, batch, **kwargs):
+ temp_gt_polygons_batch = batch[2]
+ temp_gt_strs_batch = batch[3]
+ ignore_tags_batch = batch[4]
+ gt_polygons_batch = []
+ gt_strs_batch = []
+
+ temp_gt_polygons_batch = temp_gt_polygons_batch[0].tolist()
+ for temp_list in temp_gt_polygons_batch:
+ t = []
+ for point in temp_list:
+ if point[0] != -1 and point[1] != -1:
+ t.append(point)
+ gt_polygons_batch.append(t)
+
+ temp_gt_strs_batch = temp_gt_strs_batch[0].tolist()
+ for temp_list in temp_gt_strs_batch:
+ t = ""
+ for index in temp_list:
+ if index < self.max_index:
+ t += self.label_list[index]
+ gt_strs_batch.append(t)
+
+ for pred, gt_polygons, gt_strs, ignore_tags in zip(
+ [preds], [gt_polygons_batch], [gt_strs_batch], ignore_tags_batch):
+ # prepare gt
+ gt_info_list = [{
+ 'points': gt_polygon,
+ 'text': gt_str,
+ 'ignore': ignore_tag
+ } for gt_polygon, gt_str, ignore_tag in
+ zip(gt_polygons, gt_strs, ignore_tags)]
+ # prepare det
+ e2e_info_list = [{
+ 'points': det_polyon,
+ 'text': pred_str
+ } for det_polyon, pred_str in zip(pred['points'], pred['strs'])]
+ result = get_socre(gt_info_list, e2e_info_list)
+ self.results.append(result)
+
+ def get_metric(self):
+ metrics = combine_results(self.results)
+ self.reset()
+ return metrics
+
+ def reset(self):
+ self.results = [] # clear results
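The ground-truth strings reach the metric as fixed-length integer rows, with any index at or above len(self.label_list) acting as padding. A self-contained illustration with a made-up row (ic15_dict.txt holds the 36 characters 0-9a-z):

    label_list = list('0123456789abcdefghijklmnopqrstuvwxyz')
    max_index = len(label_list)
    padded_row = [17, 14, 21, 21, 24, 99, 99]  # 99 is out of range, i.e. padding
    text = ''.join(label_list[i] for i in padded_row if i < max_index)
    print(text)  # hello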
diff --git a/ppocr/metrics/eval_det_iou.py b/ppocr/metrics/eval_det_iou.py
index a2a3f41833a9ef7615b73b70808fcb3ba2f22aa4..0e32b2d19281de9a18a1fe0343bd7e8237825b7b 100644
--- a/ppocr/metrics/eval_det_iou.py
+++ b/ppocr/metrics/eval_det_iou.py
@@ -150,7 +150,7 @@ class DetectionIoUEvaluator(object):
pairs.append({'gt': gtNum, 'det': detNum})
detMatchedNums.append(detNum)
evaluationLog += "Match GT #" + \
- str(gtNum) + " with Det #" + str(detNum) + "\n"
+ str(gtNum) + " with Det #" + str(detNum) + "\n"
numGtCare = (len(gtPols) - len(gtDontCarePolsNum))
numDetCare = (len(detPols) - len(detDontCarePolsNum))
@@ -162,7 +162,7 @@ class DetectionIoUEvaluator(object):
precision = 0 if numDetCare == 0 else float(detMatched) / numDetCare
hmean = 0 if (precision + recall) == 0 else 2.0 * \
- precision * recall / (precision + recall)
+ precision * recall / (precision + recall)
matchedSum += detMatched
numGlobalCareGt += numGtCare
@@ -200,7 +200,8 @@ class DetectionIoUEvaluator(object):
methodPrecision = 0 if numGlobalCareDet == 0 else float(
matchedSum) / numGlobalCareDet
methodHmean = 0 if methodRecall + methodPrecision == 0 else 2 * \
- methodRecall * methodPrecision / (methodRecall + methodPrecision)
+ methodRecall * methodPrecision / (
+ methodRecall + methodPrecision)
# print(methodRecall, methodPrecision, methodHmean)
# sys.exit(-1)
methodMetrics = {
diff --git a/ppocr/modeling/backbones/__init__.py b/ppocr/modeling/backbones/__init__.py
index 03c15508a58313b234a72bb3ef47ac27dc3ebb7e..fe2c9bc30a4f2abd1ba7d3d6989b9ef9b20c1f4f 100755
--- a/ppocr/modeling/backbones/__init__.py
+++ b/ppocr/modeling/backbones/__init__.py
@@ -26,6 +26,9 @@ def build_backbone(config, model_type):
from .rec_resnet_vd import ResNet
from .rec_resnet_fpn import ResNetFPN
support_dict = ['MobileNetV3', 'ResNet', 'ResNetFPN']
+ elif model_type == 'e2e':
+ from .e2e_resnet_vd_pg import ResNet
+ support_dict = ['ResNet']
else:
raise NotImplementedError
diff --git a/ppocr/modeling/backbones/e2e_resnet_vd_pg.py b/ppocr/modeling/backbones/e2e_resnet_vd_pg.py
new file mode 100644
index 0000000000000000000000000000000000000000..97afd3460d03dc078b53064fb45b6fb6d3542df9
--- /dev/null
+++ b/ppocr/modeling/backbones/e2e_resnet_vd_pg.py
@@ -0,0 +1,265 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import paddle
+from paddle import ParamAttr
+import paddle.nn as nn
+import paddle.nn.functional as F
+
+__all__ = ["ResNet"]
+
+
+class ConvBNLayer(nn.Layer):
+ def __init__(
+ self,
+ in_channels,
+ out_channels,
+ kernel_size,
+ stride=1,
+ groups=1,
+ is_vd_mode=False,
+ act=None,
+ name=None, ):
+ super(ConvBNLayer, self).__init__()
+
+ self.is_vd_mode = is_vd_mode
+ self._pool2d_avg = nn.AvgPool2D(
+ kernel_size=2, stride=2, padding=0, ceil_mode=True)
+ self._conv = nn.Conv2D(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ kernel_size=kernel_size,
+ stride=stride,
+ padding=(kernel_size - 1) // 2,
+ groups=groups,
+ weight_attr=ParamAttr(name=name + "_weights"),
+ bias_attr=False)
+ if name == "conv1":
+ bn_name = "bn_" + name
+ else:
+ bn_name = "bn" + name[3:]
+ self._batch_norm = nn.BatchNorm(
+ out_channels,
+ act=act,
+ param_attr=ParamAttr(name=bn_name + '_scale'),
+ bias_attr=ParamAttr(bn_name + '_offset'),
+ moving_mean_name=bn_name + '_mean',
+ moving_variance_name=bn_name + '_variance')
+
+ def forward(self, inputs):
+ y = self._conv(inputs)
+ y = self._batch_norm(y)
+ return y
+
+
+class BottleneckBlock(nn.Layer):
+ def __init__(self,
+ in_channels,
+ out_channels,
+ stride,
+ shortcut=True,
+ if_first=False,
+ name=None):
+ super(BottleneckBlock, self).__init__()
+
+ self.conv0 = ConvBNLayer(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ kernel_size=1,
+ act='relu',
+ name=name + "_branch2a")
+ self.conv1 = ConvBNLayer(
+ in_channels=out_channels,
+ out_channels=out_channels,
+ kernel_size=3,
+ stride=stride,
+ act='relu',
+ name=name + "_branch2b")
+ self.conv2 = ConvBNLayer(
+ in_channels=out_channels,
+ out_channels=out_channels * 4,
+ kernel_size=1,
+ act=None,
+ name=name + "_branch2c")
+
+ if not shortcut:
+ self.short = ConvBNLayer(
+ in_channels=in_channels,
+ out_channels=out_channels * 4,
+ kernel_size=1,
+ stride=stride,
+ is_vd_mode=False if if_first else True,
+ name=name + "_branch1")
+
+ self.shortcut = shortcut
+
+ def forward(self, inputs):
+ y = self.conv0(inputs)
+ conv1 = self.conv1(y)
+ conv2 = self.conv2(conv1)
+
+ if self.shortcut:
+ short = inputs
+ else:
+ short = self.short(inputs)
+ y = paddle.add(x=short, y=conv2)
+ y = F.relu(y)
+ return y
+
+
+class BasicBlock(nn.Layer):
+ def __init__(self,
+ in_channels,
+ out_channels,
+ stride,
+ shortcut=True,
+ if_first=False,
+ name=None):
+ super(BasicBlock, self).__init__()
+ self.stride = stride
+ self.conv0 = ConvBNLayer(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ kernel_size=3,
+ stride=stride,
+ act='relu',
+ name=name + "_branch2a")
+ self.conv1 = ConvBNLayer(
+ in_channels=out_channels,
+ out_channels=out_channels,
+ kernel_size=3,
+ act=None,
+ name=name + "_branch2b")
+
+ if not shortcut:
+ self.short = ConvBNLayer(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ kernel_size=1,
+ stride=1,
+ is_vd_mode=False if if_first else True,
+ name=name + "_branch1")
+
+ self.shortcut = shortcut
+
+ def forward(self, inputs):
+ y = self.conv0(inputs)
+ conv1 = self.conv1(y)
+
+ if self.shortcut:
+ short = inputs
+ else:
+ short = self.short(inputs)
+ y = paddle.add(x=short, y=conv1)
+ y = F.relu(y)
+ return y
+
+
+class ResNet(nn.Layer):
+ def __init__(self, in_channels=3, layers=50, **kwargs):
+ super(ResNet, self).__init__()
+
+ self.layers = layers
+ supported_layers = [18, 34, 50, 101, 152, 200]
+ assert layers in supported_layers, \
+ "supported layers are {} but input layer is {}".format(
+ supported_layers, layers)
+
+ if layers == 18:
+ depth = [2, 2, 2, 2]
+ elif layers == 34 or layers == 50:
+ # PGNet adds a fifth stage on top of the standard [3, 4, 6, 3] layout
+ depth = [3, 4, 6, 3, 3]
+ elif layers == 101:
+ depth = [3, 4, 23, 3]
+ elif layers == 152:
+ depth = [3, 8, 36, 3]
+ elif layers == 200:
+ depth = [3, 12, 48, 3]
+ num_channels = [64, 256, 512, 1024,
+ 2048] if layers >= 50 else [64, 64, 128, 256]
+ num_filters = [64, 128, 256, 512, 512]
+
+ self.conv1_1 = ConvBNLayer(
+ in_channels=in_channels,
+ out_channels=64,
+ kernel_size=7,
+ stride=2,
+ act='relu',
+ name="conv1_1")
+ self.pool2d_max = nn.MaxPool2D(kernel_size=3, stride=2, padding=1)
+
+ self.stages = []
+ self.out_channels = [3, 64]
+ if layers >= 50:
+ for block in range(len(depth)):
+ block_list = []
+ shortcut = False
+ for i in range(depth[block]):
+ if layers in [101, 152] and block == 2:
+ if i == 0:
+ conv_name = "res" + str(block + 2) + "a"
+ else:
+ conv_name = "res" + str(block + 2) + "b" + str(i)
+ else:
+ conv_name = "res" + str(block + 2) + chr(97 + i)
+ bottleneck_block = self.add_sublayer(
+ 'bb_%d_%d' % (block, i),
+ BottleneckBlock(
+ in_channels=num_channels[block]
+ if i == 0 else num_filters[block] * 4,
+ out_channels=num_filters[block],
+ stride=2 if i == 0 and block != 0 else 1,
+ shortcut=shortcut,
+ if_first=block == i == 0,
+ name=conv_name))
+ shortcut = True
+ block_list.append(bottleneck_block)
+ self.out_channels.append(num_filters[block] * 4)
+ self.stages.append(nn.Sequential(*block_list))
+ else:
+ for block in range(len(depth)):
+ block_list = []
+ shortcut = False
+ for i in range(depth[block]):
+ conv_name = "res" + str(block + 2) + chr(97 + i)
+ basic_block = self.add_sublayer(
+ 'bb_%d_%d' % (block, i),
+ BasicBlock(
+ in_channels=num_channels[block]
+ if i == 0 else num_filters[block],
+ out_channels=num_filters[block],
+ stride=2 if i == 0 and block != 0 else 1,
+ shortcut=shortcut,
+ if_first=block == i == 0,
+ name=conv_name))
+ shortcut = True
+ block_list.append(basic_block)
+ self.out_channels.append(num_filters[block])
+ self.stages.append(nn.Sequential(*block_list))
+
+ def forward(self, inputs):
+ out = [inputs]
+ y = self.conv1_1(inputs)
+ out.append(y)
+ y = self.pool2d_max(y)
+ for block in self.stages:
+ y = block(y)
+ out.append(y)
+ return out
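Unlike the detection backbones, this variant returns the raw input and the stem output alongside the five stage outputs, seven tensors in total, which is exactly the layout PGFPN consumes below. A quick shape check, assuming a 1x3x640x640 input:

    import paddle

    model = ResNet(in_channels=3, layers=50)
    feats = model(paddle.randn([1, 3, 640, 640]))
    print([tuple(f.shape) for f in feats])
    # (1, 3, 640, 640), (1, 64, 320, 320), (1, 256, 160, 160),
    # (1, 512, 80, 80), (1, 1024, 40, 40), (1, 2048, 20, 20), (1, 2048, 10, 10)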
diff --git a/ppocr/modeling/heads/__init__.py b/ppocr/modeling/heads/__init__.py
index efe05718506e94a5ae8ad5ff47bcff26d44c1473..4852c7f2d14d72b9e4d59f40532469f7226c966d 100755
--- a/ppocr/modeling/heads/__init__.py
+++ b/ppocr/modeling/heads/__init__.py
@@ -20,6 +20,7 @@ def build_head(config):
from .det_db_head import DBHead
from .det_east_head import EASTHead
from .det_sast_head import SASTHead
+ from .e2e_pg_head import PGHead
# rec head
from .rec_ctc_head import CTCHead
@@ -30,8 +31,8 @@ def build_head(config):
from .cls_head import ClsHead
support_dict = [
'DBHead', 'EASTHead', 'SASTHead', 'CTCHead', 'ClsHead', 'AttentionHead',
- 'SRNHead'
- ]
+ 'SRNHead', 'PGHead']
+
module_name = config.pop('name')
assert module_name in support_dict, Exception('head only support {}'.format(
diff --git a/ppocr/modeling/heads/e2e_pg_head.py b/ppocr/modeling/heads/e2e_pg_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..0da9de7580a0ceb473f971b2246c966497026a5d
--- /dev/null
+++ b/ppocr/modeling/heads/e2e_pg_head.py
@@ -0,0 +1,253 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import math
+import paddle
+from paddle import nn
+import paddle.nn.functional as F
+from paddle import ParamAttr
+
+
+class ConvBNLayer(nn.Layer):
+ def __init__(self,
+ in_channels,
+ out_channels,
+ kernel_size,
+ stride,
+ padding,
+ groups=1,
+ if_act=True,
+ act=None,
+ name=None):
+ super(ConvBNLayer, self).__init__()
+ self.if_act = if_act
+ self.act = act
+ self.conv = nn.Conv2D(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ kernel_size=kernel_size,
+ stride=stride,
+ padding=padding,
+ groups=groups,
+ weight_attr=ParamAttr(name=name + '_weights'),
+ bias_attr=False)
+
+ self.bn = nn.BatchNorm(
+ num_channels=out_channels,
+ act=act,
+ param_attr=ParamAttr(name="bn_" + name + "_scale"),
+ bias_attr=ParamAttr(name="bn_" + name + "_offset"),
+ moving_mean_name="bn_" + name + "_mean",
+ moving_variance_name="bn_" + name + "_variance",
+ use_global_stats=False)
+
+ def forward(self, x):
+ x = self.conv(x)
+ x = self.bn(x)
+ return x
+
+
+class PGHead(nn.Layer):
+ """
+ """
+
+ def __init__(self, in_channels, **kwargs):
+ super(PGHead, self).__init__()
+ self.conv_f_score1 = ConvBNLayer(
+ in_channels=in_channels,
+ out_channels=64,
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ act='relu',
+ name="conv_f_score{}".format(1))
+ self.conv_f_score2 = ConvBNLayer(
+ in_channels=64,
+ out_channels=64,
+ kernel_size=3,
+ stride=1,
+ padding=1,
+ act='relu',
+ name="conv_f_score{}".format(2))
+ self.conv_f_score3 = ConvBNLayer(
+ in_channels=64,
+ out_channels=128,
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ act='relu',
+ name="conv_f_score{}".format(3))
+
+ self.conv1 = nn.Conv2D(
+ in_channels=128,
+ out_channels=1,
+ kernel_size=3,
+ stride=1,
+ padding=1,
+ groups=1,
+ weight_attr=ParamAttr(name="conv_f_score{}".format(4)),
+ bias_attr=False)
+
+ self.conv_f_boder1 = ConvBNLayer(
+ in_channels=in_channels,
+ out_channels=64,
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ act='relu',
+ name="conv_f_boder{}".format(1))
+ self.conv_f_boder2 = ConvBNLayer(
+ in_channels=64,
+ out_channels=64,
+ kernel_size=3,
+ stride=1,
+ padding=1,
+ act='relu',
+ name="conv_f_boder{}".format(2))
+ self.conv_f_boder3 = ConvBNLayer(
+ in_channels=64,
+ out_channels=128,
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ act='relu',
+ name="conv_f_boder{}".format(3))
+ self.conv2 = nn.Conv2D(
+ in_channels=128,
+ out_channels=4,
+ kernel_size=3,
+ stride=1,
+ padding=1,
+ groups=1,
+ weight_attr=ParamAttr(name="conv_f_boder{}".format(4)),
+ bias_attr=False)
+ self.conv_f_char1 = ConvBNLayer(
+ in_channels=in_channels,
+ out_channels=128,
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ act='relu',
+ name="conv_f_char{}".format(1))
+ self.conv_f_char2 = ConvBNLayer(
+ in_channels=128,
+ out_channels=128,
+ kernel_size=3,
+ stride=1,
+ padding=1,
+ act='relu',
+ name="conv_f_char{}".format(2))
+ self.conv_f_char3 = ConvBNLayer(
+ in_channels=128,
+ out_channels=256,
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ act='relu',
+ name="conv_f_char{}".format(3))
+ self.conv_f_char4 = ConvBNLayer(
+ in_channels=256,
+ out_channels=256,
+ kernel_size=3,
+ stride=1,
+ padding=1,
+ act='relu',
+ name="conv_f_char{}".format(4))
+ self.conv_f_char5 = ConvBNLayer(
+ in_channels=256,
+ out_channels=256,
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ act='relu',
+ name="conv_f_char{}".format(5))
+ self.conv3 = nn.Conv2D(
+ in_channels=256,
+ out_channels=37,
+ kernel_size=3,
+ stride=1,
+ padding=1,
+ groups=1,
+ weight_attr=ParamAttr(name="conv_f_char{}".format(6)),
+ bias_attr=False)
+
+ self.conv_f_direc1 = ConvBNLayer(
+ in_channels=in_channels,
+ out_channels=64,
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ act='relu',
+ name="conv_f_direc{}".format(1))
+ self.conv_f_direc2 = ConvBNLayer(
+ in_channels=64,
+ out_channels=64,
+ kernel_size=3,
+ stride=1,
+ padding=1,
+ act='relu',
+ name="conv_f_direc{}".format(2))
+ self.conv_f_direc3 = ConvBNLayer(
+ in_channels=64,
+ out_channels=128,
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ act='relu',
+ name="conv_f_direc{}".format(3))
+ self.conv4 = nn.Conv2D(
+ in_channels=128,
+ out_channels=2,
+ kernel_size=3,
+ stride=1,
+ padding=1,
+ groups=1,
+ weight_attr=ParamAttr(name="conv_f_direc{}".format(4)),
+ bias_attr=False)
+
+ def forward(self, x):
+ f_score = self.conv_f_score1(x)
+ f_score = self.conv_f_score2(f_score)
+ f_score = self.conv_f_score3(f_score)
+ f_score = self.conv1(f_score)
+ f_score = F.sigmoid(f_score)
+
+ # f_border
+ f_border = self.conv_f_boder1(x)
+ f_border = self.conv_f_boder2(f_border)
+ f_border = self.conv_f_boder3(f_border)
+ f_border = self.conv2(f_border)
+
+ f_char = self.conv_f_char1(x)
+ f_char = self.conv_f_char2(f_char)
+ f_char = self.conv_f_char3(f_char)
+ f_char = self.conv_f_char4(f_char)
+ f_char = self.conv_f_char5(f_char)
+ f_char = self.conv3(f_char)
+
+ f_direction = self.conv_f_direc1(x)
+ f_direction = self.conv_f_direc2(f_direction)
+ f_direction = self.conv_f_direc3(f_direction)
+ f_direction = self.conv4(f_direction)
+
+ predicts = {}
+ predicts['f_score'] = f_score
+ predicts['f_border'] = f_border
+ predicts['f_char'] = f_char
+ predicts['f_direction'] = f_direction
+ return predicts
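All four branches read the same FPN feature and differ only in depth and output width: 1 channel for the center-line score, 4 for border offsets, 37 for characters (the 36-symbol dict plus the CTC blank) and 2 for direction. A shape sketch on a hypothetical quarter-resolution feature map:

    import paddle

    head = PGHead(in_channels=128)
    outs = head(paddle.randn([1, 128, 160, 160]))
    for name, tensor in outs.items():
        print(name, tuple(tensor.shape))
    # f_score (1, 1, 160, 160), f_border (1, 4, 160, 160),
    # f_char (1, 37, 160, 160), f_direction (1, 2, 160, 160)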
diff --git a/ppocr/modeling/necks/__init__.py b/ppocr/modeling/necks/__init__.py
index 405e062b352da759b743d5997fd6e4c8b89c038b..37a5cf7863cb386884d82ed88c756c9fc06a541d 100644
--- a/ppocr/modeling/necks/__init__.py
+++ b/ppocr/modeling/necks/__init__.py
@@ -14,12 +14,14 @@
__all__ = ['build_neck']
+
def build_neck(config):
from .db_fpn import DBFPN
from .east_fpn import EASTFPN
from .sast_fpn import SASTFPN
from .rnn import SequenceEncoder
- support_dict = ['DBFPN', 'EASTFPN', 'SASTFPN', 'SequenceEncoder']
+ from .pg_fpn import PGFPN
+ support_dict = ['DBFPN', 'EASTFPN', 'SASTFPN', 'SequenceEncoder', 'PGFPN']
module_name = config.pop('name')
assert module_name in support_dict, Exception('neck only support {}'.format(
diff --git a/ppocr/modeling/necks/pg_fpn.py b/ppocr/modeling/necks/pg_fpn.py
new file mode 100644
index 0000000000000000000000000000000000000000..3f64539f790b55bb1f95adc8d3c78b84ca2fccc5
--- /dev/null
+++ b/ppocr/modeling/necks/pg_fpn.py
@@ -0,0 +1,314 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import paddle
+from paddle import nn
+import paddle.nn.functional as F
+from paddle import ParamAttr
+
+
+class ConvBNLayer(nn.Layer):
+ def __init__(self,
+ in_channels,
+ out_channels,
+ kernel_size,
+ stride=1,
+ groups=1,
+ is_vd_mode=False,
+ act=None,
+ name=None):
+ super(ConvBNLayer, self).__init__()
+
+ self.is_vd_mode = is_vd_mode
+ self._pool2d_avg = nn.AvgPool2D(
+ kernel_size=2, stride=2, padding=0, ceil_mode=True)
+ self._conv = nn.Conv2D(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ kernel_size=kernel_size,
+ stride=stride,
+ padding=(kernel_size - 1) // 2,
+ groups=groups,
+ weight_attr=ParamAttr(name=name + "_weights"),
+ bias_attr=False)
+ if name == "conv1":
+ bn_name = "bn_" + name
+ else:
+ bn_name = "bn" + name[3:]
+ self._batch_norm = nn.BatchNorm(
+ out_channels,
+ act=act,
+ param_attr=ParamAttr(name=bn_name + '_scale'),
+ bias_attr=ParamAttr(bn_name + '_offset'),
+ moving_mean_name=bn_name + '_mean',
+ moving_variance_name=bn_name + '_variance',
+ use_global_stats=False)
+
+ def forward(self, inputs):
+ y = self._conv(inputs)
+ y = self._batch_norm(y)
+ return y
+
+
+class DeConvBNLayer(nn.Layer):
+ def __init__(self,
+ in_channels,
+ out_channels,
+ kernel_size=4,
+ stride=2,
+ padding=1,
+ groups=1,
+ if_act=True,
+ act=None,
+ name=None):
+ super(DeConvBNLayer, self).__init__()
+
+ self.if_act = if_act
+ self.act = act
+ self.deconv = nn.Conv2DTranspose(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ kernel_size=kernel_size,
+ stride=stride,
+ padding=padding,
+ groups=groups,
+ weight_attr=ParamAttr(name=name + '_weights'),
+ bias_attr=False)
+ self.bn = nn.BatchNorm(
+ num_channels=out_channels,
+ act=act,
+ param_attr=ParamAttr(name="bn_" + name + "_scale"),
+ bias_attr=ParamAttr(name="bn_" + name + "_offset"),
+ moving_mean_name="bn_" + name + "_mean",
+ moving_variance_name="bn_" + name + "_variance",
+ use_global_stats=False)
+
+ def forward(self, x):
+ x = self.deconv(x)
+ x = self.bn(x)
+ return x
+
+
+class PGFPN(nn.Layer):
+ def __init__(self, in_channels, **kwargs):
+ super(PGFPN, self).__init__()
+ num_inputs = [2048, 2048, 1024, 512, 256]
+ num_outputs = [256, 256, 192, 192, 128]
+ self.out_channels = 128
+ self.conv_bn_layer_1 = ConvBNLayer(
+ in_channels=3,
+ out_channels=32,
+ kernel_size=3,
+ stride=1,
+ act=None,
+ name='FPN_d1')
+ self.conv_bn_layer_2 = ConvBNLayer(
+ in_channels=64,
+ out_channels=64,
+ kernel_size=3,
+ stride=1,
+ act=None,
+ name='FPN_d2')
+ self.conv_bn_layer_3 = ConvBNLayer(
+ in_channels=256,
+ out_channels=128,
+ kernel_size=3,
+ stride=1,
+ act=None,
+ name='FPN_d3')
+ self.conv_bn_layer_4 = ConvBNLayer(
+ in_channels=32,
+ out_channels=64,
+ kernel_size=3,
+ stride=2,
+ act=None,
+ name='FPN_d4')
+ self.conv_bn_layer_5 = ConvBNLayer(
+ in_channels=64,
+ out_channels=64,
+ kernel_size=3,
+ stride=1,
+ act='relu',
+ name='FPN_d5')
+ self.conv_bn_layer_6 = ConvBNLayer(
+ in_channels=64,
+ out_channels=128,
+ kernel_size=3,
+ stride=2,
+ act=None,
+ name='FPN_d6')
+ self.conv_bn_layer_7 = ConvBNLayer(
+ in_channels=128,
+ out_channels=128,
+ kernel_size=3,
+ stride=1,
+ act='relu',
+ name='FPN_d7')
+ self.conv_bn_layer_8 = ConvBNLayer(
+ in_channels=128,
+ out_channels=128,
+ kernel_size=1,
+ stride=1,
+ act=None,
+ name='FPN_d8')
+
+ self.conv_h0 = ConvBNLayer(
+ in_channels=num_inputs[0],
+ out_channels=num_outputs[0],
+ kernel_size=1,
+ stride=1,
+ act=None,
+ name="conv_h{}".format(0))
+ self.conv_h1 = ConvBNLayer(
+ in_channels=num_inputs[1],
+ out_channels=num_outputs[1],
+ kernel_size=1,
+ stride=1,
+ act=None,
+ name="conv_h{}".format(1))
+ self.conv_h2 = ConvBNLayer(
+ in_channels=num_inputs[2],
+ out_channels=num_outputs[2],
+ kernel_size=1,
+ stride=1,
+ act=None,
+ name="conv_h{}".format(2))
+ self.conv_h3 = ConvBNLayer(
+ in_channels=num_inputs[3],
+ out_channels=num_outputs[3],
+ kernel_size=1,
+ stride=1,
+ act=None,
+ name="conv_h{}".format(3))
+ self.conv_h4 = ConvBNLayer(
+ in_channels=num_inputs[4],
+ out_channels=num_outputs[4],
+ kernel_size=1,
+ stride=1,
+ act=None,
+ name="conv_h{}".format(4))
+
+ self.dconv0 = DeConvBNLayer(
+ in_channels=num_outputs[0],
+ out_channels=num_outputs[0 + 1],
+ name="dconv_{}".format(0))
+ self.dconv1 = DeConvBNLayer(
+ in_channels=num_outputs[1],
+ out_channels=num_outputs[1 + 1],
+ act=None,
+ name="dconv_{}".format(1))
+ self.dconv2 = DeConvBNLayer(
+ in_channels=num_outputs[2],
+ out_channels=num_outputs[2 + 1],
+ act=None,
+ name="dconv_{}".format(2))
+ self.dconv3 = DeConvBNLayer(
+ in_channels=num_outputs[3],
+ out_channels=num_outputs[3 + 1],
+ act=None,
+ name="dconv_{}".format(3))
+ self.conv_g1 = ConvBNLayer(
+ in_channels=num_outputs[1],
+ out_channels=num_outputs[1],
+ kernel_size=3,
+ stride=1,
+ act='relu',
+ name="conv_g{}".format(1))
+ self.conv_g2 = ConvBNLayer(
+ in_channels=num_outputs[2],
+ out_channels=num_outputs[2],
+ kernel_size=3,
+ stride=1,
+ act='relu',
+ name="conv_g{}".format(2))
+ self.conv_g3 = ConvBNLayer(
+ in_channels=num_outputs[3],
+ out_channels=num_outputs[3],
+ kernel_size=3,
+ stride=1,
+ act='relu',
+ name="conv_g{}".format(3))
+ self.conv_g4 = ConvBNLayer(
+ in_channels=num_outputs[4],
+ out_channels=num_outputs[4],
+ kernel_size=3,
+ stride=1,
+ act='relu',
+ name="conv_g{}".format(4))
+ self.convf = ConvBNLayer(
+ in_channels=num_outputs[4],
+ out_channels=num_outputs[4],
+ kernel_size=1,
+ stride=1,
+ act=None,
+ name="conv_f{}".format(4))
+
+ def forward(self, x):
+ c0, c1, c2, c3, c4, c5, c6 = x
+ # FPN_Down_Fusion
+ f = [c0, c1, c2]
+ g = [None, None, None]
+ h = [None, None, None]
+ h[0] = self.conv_bn_layer_1(f[0])
+ h[1] = self.conv_bn_layer_2(f[1])
+ h[2] = self.conv_bn_layer_3(f[2])
+
+ g[0] = self.conv_bn_layer_4(h[0])
+ g[1] = paddle.add(g[0], h[1])
+ g[1] = F.relu(g[1])
+ g[1] = self.conv_bn_layer_5(g[1])
+ g[1] = self.conv_bn_layer_6(g[1])
+
+ g[2] = paddle.add(g[1], h[2])
+ g[2] = F.relu(g[2])
+ g[2] = self.conv_bn_layer_7(g[2])
+ f_down = self.conv_bn_layer_8(g[2])
+
+ # FPN UP Fusion
+ f1 = [c6, c5, c4, c3, c2]
+ g = [None, None, None, None, None]
+ h = [None, None, None, None, None]
+ h[0] = self.conv_h0(f1[0])
+ h[1] = self.conv_h1(f1[1])
+ h[2] = self.conv_h2(f1[2])
+ h[3] = self.conv_h3(f1[3])
+ h[4] = self.conv_h4(f1[4])
+
+ g[0] = self.dconv0(h[0])
+ g[1] = paddle.add(g[0], h[1])
+ g[1] = F.relu(g[1])
+ g[1] = self.conv_g1(g[1])
+ g[1] = self.dconv1(g[1])
+
+ g[2] = paddle.add(g[1], h[2])
+ g[2] = F.relu(g[2])
+ g[2] = self.conv_g2(g[2])
+ g[2] = self.dconv2(g[2])
+
+ g[3] = paddle.add(g[2], h[3])
+ g[3] = F.relu(g[3])
+ g[3] = self.conv_g3(g[3])
+ g[3] = self.dconv3(g[3])
+
+ g[4] = paddle.add(x=g[3], y=h[4])
+ g[4] = F.relu(g[4])
+ g[4] = self.conv_g4(g[4])
+ f_up = self.convf(g[4])
+ f_common = paddle.add(f_down, f_up)
+ f_common = F.relu(f_common)
+ return f_common
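The down path compresses the three shallow features, the up path deconvolves the four deep ones, and the two meet in a single 128-channel map at quarter resolution, which is what PGHead expects. A wiring sketch combining this neck with the e2e ResNet above (shapes again assume a 640x640 input):

    import paddle
    from ppocr.modeling.backbones.e2e_resnet_vd_pg import ResNet

    backbone = ResNet(in_channels=3, layers=50)
    neck = PGFPN(in_channels=backbone.out_channels)
    f_common = neck(backbone(paddle.randn([1, 3, 640, 640])))
    print(tuple(f_common.shape))  # (1, 128, 160, 160)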
diff --git a/ppocr/postprocess/__init__.py b/ppocr/postprocess/__init__.py
index 0156e438e9e24820943c9e48b04565710ea2fd4b..042654a19d2d2d2f1363fedbb9ac3530696e6903 100644
--- a/ppocr/postprocess/__init__.py
+++ b/ppocr/postprocess/__init__.py
@@ -28,10 +28,11 @@ def build_post_process(config, global_config=None):
from .sast_postprocess import SASTPostProcess
from .rec_postprocess import CTCLabelDecode, AttnLabelDecode, SRNLabelDecode
from .cls_postprocess import ClsPostProcess
+ from .pg_postprocess import PGPostProcess
support_dict = [
'DBPostProcess', 'EASTPostProcess', 'SASTPostProcess', 'CTCLabelDecode',
- 'AttnLabelDecode', 'ClsPostProcess', 'SRNLabelDecode'
+ 'AttnLabelDecode', 'ClsPostProcess', 'SRNLabelDecode', 'PGPostProcess'
]
config = copy.deepcopy(config)
diff --git a/ppocr/postprocess/pg_postprocess.py b/ppocr/postprocess/pg_postprocess.py
new file mode 100644
index 0000000000000000000000000000000000000000..d9c0048f20ff46850ab8a26554af31532c73efd6
--- /dev/null
+++ b/ppocr/postprocess/pg_postprocess.py
@@ -0,0 +1,155 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os
+import sys
+
+__dir__ = os.path.dirname(__file__)
+sys.path.append(__dir__)
+sys.path.append(os.path.join(__dir__, '..'))
+
+from ppocr.utils.e2e_utils.extract_textpoint import *
+from ppocr.utils.e2e_utils.visual import *
+import numpy as np
+import paddle
+
+
+class PGPostProcess(object):
+ """
+ The post process for PGNet.
+ """
+
+ def __init__(self, character_dict_path, valid_set, score_thresh, **kwargs):
+ self.Lexicon_Table = get_dict(character_dict_path)
+ self.valid_set = valid_set
+ self.score_thresh = score_thresh
+
+ # the C++ la-nms is faster, but it only supports Python 3.5
+ self.is_python35 = False
+ if sys.version_info.major == 3 and sys.version_info.minor == 5:
+ self.is_python35 = True
+
+ def __call__(self, outs_dict, shape_list):
+ p_score = outs_dict['f_score']
+ p_border = outs_dict['f_border']
+ p_char = outs_dict['f_char']
+ p_direction = outs_dict['f_direction']
+ if isinstance(p_score, paddle.Tensor):
+ p_score = p_score[0].numpy()
+ p_border = p_border[0].numpy()
+ p_direction = p_direction[0].numpy()
+ p_char = p_char[0].numpy()
+ else:
+ p_score = p_score[0]
+ p_border = p_border[0]
+ p_direction = p_direction[0]
+ p_char = p_char[0]
+ src_h, src_w, ratio_h, ratio_w = shape_list[0]
+ is_curved = self.valid_set == "totaltext"
+ instance_yxs_list = generate_pivot_list(
+ p_score,
+ p_char,
+ p_direction,
+ score_thresh=self.score_thresh,
+ is_backbone=True,
+ is_curved=is_curved)
+ p_char = paddle.to_tensor(np.expand_dims(p_char, axis=0))
+ char_seq_idx_set = []
+ for i in range(len(instance_yxs_list)):
+ gather_info_lod = paddle.to_tensor(instance_yxs_list[i])
+ f_char_map = paddle.transpose(p_char, [0, 2, 3, 1])
+ feature_seq = paddle.gather_nd(f_char_map, gather_info_lod)
+ feature_seq = np.expand_dims(feature_seq.numpy(), axis=0)
+ feature_len = [len(feature_seq[0])]
+ feature_seq = paddle.to_tensor(feature_seq)
+ feature_len = np.array([feature_len]).astype(np.int64)
+ length = paddle.to_tensor(feature_len)
+ seq_pred = paddle.fluid.layers.ctc_greedy_decoder(
+ input=feature_seq, blank=36, input_length=length)
+ seq_pred_str = seq_pred[0].numpy().tolist()[0]
+ seq_len = seq_pred[1].numpy()[0][0]
+ temp_t = []
+ for c in seq_pred_str[:seq_len]:
+ temp_t.append(c)
+ char_seq_idx_set.append(temp_t)
+ seq_strs = []
+ for char_idx_set in char_seq_idx_set:
+ pr_str = ''.join([self.Lexicon_Table[pos] for pos in char_idx_set])
+ seq_strs.append(pr_str)
+ poly_list = []
+ keep_str_list = []
+ all_point_list = []
+ all_point_pair_list = []
+ for yx_center_line, keep_str in zip(instance_yxs_list, seq_strs):
+ if len(yx_center_line) == 1:
+ yx_center_line.append(yx_center_line[-1])
+
+ offset_expand = 1.0
+ if self.valid_set == 'totaltext':
+ offset_expand = 1.2
+
+ point_pair_list = []
+ for batch_id, y, x in yx_center_line:
+ offset = p_border[:, y, x].reshape(2, 2)
+ if offset_expand != 1.0:
+ offset_length = np.linalg.norm(
+ offset, axis=1, keepdims=True)
+ expand_length = np.clip(
+ offset_length * (offset_expand - 1),
+ a_min=0.5,
+ a_max=3.0)
+ offset_delta = offset / offset_length * expand_length
+ offset = offset + offset_delta
+ ori_yx = np.array([y, x], dtype=np.float32)
+ point_pair = (ori_yx + offset)[:, ::-1] * 4.0 / np.array(
+ [ratio_w, ratio_h]).reshape(-1, 2)
+ point_pair_list.append(point_pair)
+
+ all_point_list.append([
+ int(round(x * 4.0 / ratio_w)),
+ int(round(y * 4.0 / ratio_h))
+ ])
+ all_point_pair_list.append(point_pair.round().astype(np.int32)
+ .tolist())
+
+ detected_poly, pair_length_info = point_pair2poly(point_pair_list)
+ detected_poly = expand_poly_along_width(
+ detected_poly, shrink_ratio_of_width=0.2)
+ detected_poly[:, 0] = np.clip(
+ detected_poly[:, 0], a_min=0, a_max=src_w)
+ detected_poly[:, 1] = np.clip(
+ detected_poly[:, 1], a_min=0, a_max=src_h)
+
+ if len(keep_str) < 2:
+ continue
+
+ keep_str_list.append(keep_str)
+ if self.valid_set == 'partvgg':
+ middle_point = len(detected_poly) // 2
+ detected_poly = detected_poly[
+ [0, middle_point - 1, middle_point, -1], :]
+ poly_list.append(detected_poly)
+ elif self.valid_set == 'totaltext':
+ poly_list.append(detected_poly)
+ else:
+ print('--> Unsupported valid_set: {}.'.format(self.valid_set))
+ exit(-1)
+ data = {
+ 'points': poly_list,
+ 'strs': keep_str_list,
+ }
+ return data
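The geometric step above is dense: at each sampled center-line pixel the 4-channel border map is read as two (dy, dx) offsets to the upper and lower text boundary, the coordinates are swapped to (x, y), scaled by the 4x output stride, and divided by the resize ratios to land back in source-image coordinates. A numeric sketch with made-up values:

    import numpy as np

    y, x = 12, 30                                  # a center-line pixel
    offset = np.array([[-2.0, 0.5], [2.0, -0.5]])  # from p_border[:, y, x].reshape(2, 2)
    ratio_w, ratio_h = 0.5, 0.5                    # resize ratios from shape_list

    ori_yx = np.array([y, x], dtype=np.float32)
    point_pair = (ori_yx + offset)[:, ::-1] * 4.0 / np.array(
        [ratio_w, ratio_h]).reshape(-1, 2)
    print(point_pair)  # [[244.  80.] [236. 112.]]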
diff --git a/ppocr/postprocess/sast_postprocess.py b/ppocr/postprocess/sast_postprocess.py
index f011e7e571cf4c2297a81a7f7772aa0c09f0aaf1..bee75c05b1a3ea59193d566f91378c96797f533b 100755
--- a/ppocr/postprocess/sast_postprocess.py
+++ b/ppocr/postprocess/sast_postprocess.py
@@ -18,6 +18,7 @@ from __future__ import print_function
import os
import sys
+
__dir__ = os.path.dirname(__file__)
sys.path.append(__dir__)
sys.path.append(os.path.join(__dir__, '..'))
@@ -49,12 +50,12 @@ class SASTPostProcess(object):
self.shrink_ratio_of_width = shrink_ratio_of_width
self.expand_scale = expand_scale
self.tcl_map_thresh = tcl_map_thresh
-
+
# c++ la-nms is faster, but only support python 3.5
self.is_python35 = False
if sys.version_info.major == 3 and sys.version_info.minor == 5:
self.is_python35 = True
-
+
def point_pair2poly(self, point_pair_list):
"""
Transfer vertical point_pairs into poly point in clockwise.
@@ -66,31 +67,42 @@ class SASTPostProcess(object):
point_list[idx] = point_pair[0]
point_list[point_num - 1 - idx] = point_pair[1]
return np.array(point_list).reshape(-1, 2)
-
- def shrink_quad_along_width(self, quad, begin_width_ratio=0., end_width_ratio=1.):
+
+ def shrink_quad_along_width(self,
+ quad,
+ begin_width_ratio=0.,
+ end_width_ratio=1.):
"""
Generate shrink_quad_along_width.
"""
- ratio_pair = np.array([[begin_width_ratio], [end_width_ratio]], dtype=np.float32)
+ ratio_pair = np.array(
+ [[begin_width_ratio], [end_width_ratio]], dtype=np.float32)
p0_1 = quad[0] + (quad[1] - quad[0]) * ratio_pair
p3_2 = quad[3] + (quad[2] - quad[3]) * ratio_pair
return np.array([p0_1[0], p0_1[1], p3_2[1], p3_2[0]])
-
+
def expand_poly_along_width(self, poly, shrink_ratio_of_width=0.3):
"""
expand poly along width.
"""
point_num = poly.shape[0]
- left_quad = np.array([poly[0], poly[1], poly[-2], poly[-1]], dtype=np.float32)
+ left_quad = np.array(
+ [poly[0], poly[1], poly[-2], poly[-1]], dtype=np.float32)
left_ratio = -shrink_ratio_of_width * np.linalg.norm(left_quad[0] - left_quad[3]) / \
- (np.linalg.norm(left_quad[0] - left_quad[1]) + 1e-6)
- left_quad_expand = self.shrink_quad_along_width(left_quad, left_ratio, 1.0)
- right_quad = np.array([poly[point_num // 2 - 2], poly[point_num // 2 - 1],
- poly[point_num // 2], poly[point_num // 2 + 1]], dtype=np.float32)
+ (np.linalg.norm(left_quad[0] - left_quad[1]) + 1e-6)
+ left_quad_expand = self.shrink_quad_along_width(left_quad, left_ratio,
+ 1.0)
+ right_quad = np.array(
+ [
+ poly[point_num // 2 - 2], poly[point_num // 2 - 1],
+ poly[point_num // 2], poly[point_num // 2 + 1]
+ ],
+ dtype=np.float32)
right_ratio = 1.0 + \
- shrink_ratio_of_width * np.linalg.norm(right_quad[0] - right_quad[3]) / \
- (np.linalg.norm(right_quad[0] - right_quad[1]) + 1e-6)
- right_quad_expand = self.shrink_quad_along_width(right_quad, 0.0, right_ratio)
+ shrink_ratio_of_width * np.linalg.norm(right_quad[0] - right_quad[3]) / \
+ (np.linalg.norm(right_quad[0] - right_quad[1]) + 1e-6)
+ right_quad_expand = self.shrink_quad_along_width(right_quad, 0.0,
+ right_ratio)
poly[0] = left_quad_expand[0]
poly[-1] = left_quad_expand[-1]
poly[point_num // 2 - 1] = right_quad_expand[1]
@@ -100,7 +112,7 @@ class SASTPostProcess(object):
def restore_quad(self, tcl_map, tcl_map_thresh, tvo_map):
"""Restore quad."""
xy_text = np.argwhere(tcl_map[:, :, 0] > tcl_map_thresh)
- xy_text = xy_text[:, ::-1] # (n, 2)
+ xy_text = xy_text[:, ::-1] # (n, 2)
# Sort the text boxes via the y axis
xy_text = xy_text[np.argsort(xy_text[:, 1])]
@@ -112,7 +124,7 @@ class SASTPostProcess(object):
point_num = int(tvo_map.shape[-1] / 2)
assert point_num == 4
tvo_map = tvo_map[xy_text[:, 1], xy_text[:, 0], :]
- xy_text_tile = np.tile(xy_text, (1, point_num)) # (n, point_num * 2)
+ xy_text_tile = np.tile(xy_text, (1, point_num)) # (n, point_num * 2)
quads = xy_text_tile - tvo_map
return scores, quads, xy_text
@@ -121,14 +133,12 @@ class SASTPostProcess(object):
"""
compute area of a quad.
"""
- edge = [
- (quad[1][0] - quad[0][0]) * (quad[1][1] + quad[0][1]),
- (quad[2][0] - quad[1][0]) * (quad[2][1] + quad[1][1]),
- (quad[3][0] - quad[2][0]) * (quad[3][1] + quad[2][1]),
- (quad[0][0] - quad[3][0]) * (quad[0][1] + quad[3][1])
- ]
+ edge = [(quad[1][0] - quad[0][0]) * (quad[1][1] + quad[0][1]),
+ (quad[2][0] - quad[1][0]) * (quad[2][1] + quad[1][1]),
+ (quad[3][0] - quad[2][0]) * (quad[3][1] + quad[2][1]),
+ (quad[0][0] - quad[3][0]) * (quad[0][1] + quad[3][1])]
return np.sum(edge) / 2.
-
+
def nms(self, dets):
if self.is_python35:
import lanms
@@ -141,7 +151,7 @@ class SASTPostProcess(object):
"""
Cluster pixels in tcl_map based on quads.
"""
- instance_count = quads.shape[0] + 1 # contain background
+ instance_count = quads.shape[0] + 1 # contain background
instance_label_map = np.zeros(tcl_map.shape[:2], dtype=np.int32)
if instance_count == 1:
return instance_count, instance_label_map
@@ -149,18 +159,19 @@ class SASTPostProcess(object):
# predict text center
xy_text = np.argwhere(tcl_map[:, :, 0] > tcl_map_thresh)
n = xy_text.shape[0]
- xy_text = xy_text[:, ::-1] # (n, 2)
- tco = tco_map[xy_text[:, 1], xy_text[:, 0], :] # (n, 2)
+ xy_text = xy_text[:, ::-1] # (n, 2)
+ tco = tco_map[xy_text[:, 1], xy_text[:, 0], :] # (n, 2)
pred_tc = xy_text - tco
-
+
# get gt text center
m = quads.shape[0]
- gt_tc = np.mean(quads, axis=1) # (m, 2)
+ gt_tc = np.mean(quads, axis=1) # (m, 2)
- pred_tc_tile = np.tile(pred_tc[:, np.newaxis, :], (1, m, 1)) # (n, m, 2)
- gt_tc_tile = np.tile(gt_tc[np.newaxis, :, :], (n, 1, 1)) # (n, m, 2)
- dist_mat = np.linalg.norm(pred_tc_tile - gt_tc_tile, axis=2) # (n, m)
- xy_text_assign = np.argmin(dist_mat, axis=1) + 1 # (n,)
+ pred_tc_tile = np.tile(pred_tc[:, np.newaxis, :],
+ (1, m, 1)) # (n, m, 2)
+ gt_tc_tile = np.tile(gt_tc[np.newaxis, :, :], (n, 1, 1)) # (n, m, 2)
+ dist_mat = np.linalg.norm(pred_tc_tile - gt_tc_tile, axis=2) # (n, m)
+ xy_text_assign = np.argmin(dist_mat, axis=1) + 1 # (n,)
instance_label_map[xy_text[:, 1], xy_text[:, 0]] = xy_text_assign
return instance_count, instance_label_map
@@ -169,26 +180,47 @@ class SASTPostProcess(object):
"""
Estimate sample points number.
"""
- eh = (np.linalg.norm(quad[0] - quad[3]) + np.linalg.norm(quad[1] - quad[2])) / 2.0
- ew = (np.linalg.norm(quad[0] - quad[1]) + np.linalg.norm(quad[2] - quad[3])) / 2.0
+ eh = (np.linalg.norm(quad[0] - quad[3]) +
+ np.linalg.norm(quad[1] - quad[2])) / 2.0
+ ew = (np.linalg.norm(quad[0] - quad[1]) +
+ np.linalg.norm(quad[2] - quad[3])) / 2.0
dense_sample_pts_num = max(2, int(ew))
- dense_xy_center_line = xy_text[np.linspace(0, xy_text.shape[0] - 1, dense_sample_pts_num,
- endpoint=True, dtype=np.float32).astype(np.int32)]
-
- dense_xy_center_line_diff = dense_xy_center_line[1:] - dense_xy_center_line[:-1]
- estimate_arc_len = np.sum(np.linalg.norm(dense_xy_center_line_diff, axis=1))
+ dense_xy_center_line = xy_text[np.linspace(
+ 0,
+ xy_text.shape[0] - 1,
+ dense_sample_pts_num,
+ endpoint=True,
+ dtype=np.float32).astype(np.int32)]
+
+ dense_xy_center_line_diff = dense_xy_center_line[
+ 1:] - dense_xy_center_line[:-1]
+ estimate_arc_len = np.sum(
+ np.linalg.norm(
+ dense_xy_center_line_diff, axis=1))
sample_pts_num = max(2, int(estimate_arc_len / eh))
return sample_pts_num
- def detect_sast(self, tcl_map, tvo_map, tbo_map, tco_map, ratio_w, ratio_h, src_w, src_h,
- shrink_ratio_of_width=0.3, tcl_map_thresh=0.5, offset_expand=1.0, out_strid=4.0):
+ def detect_sast(self,
+ tcl_map,
+ tvo_map,
+ tbo_map,
+ tco_map,
+ ratio_w,
+ ratio_h,
+ src_w,
+ src_h,
+ shrink_ratio_of_width=0.3,
+ tcl_map_thresh=0.5,
+ offset_expand=1.0,
+ out_stride=4.0):
"""
first resize the tcl_map, tvo_map and tbo_map to the input_size, then restore the polys
"""
# restore quad
- scores, quads, xy_text = self.restore_quad(tcl_map, tcl_map_thresh, tvo_map)
+ scores, quads, xy_text = self.restore_quad(tcl_map, tcl_map_thresh,
+ tvo_map)
dets = np.hstack((quads, scores)).astype(np.float32, copy=False)
dets = self.nms(dets)
if dets.shape[0] == 0:
@@ -202,7 +234,8 @@ class SASTPostProcess(object):
# instance segmentation
# instance_count, instance_label_map = cv2.connectedComponents(tcl_map.astype(np.uint8), connectivity=8)
- instance_count, instance_label_map = self.cluster_by_quads_tco(tcl_map, tcl_map_thresh, quads, tco_map)
+ instance_count, instance_label_map = self.cluster_by_quads_tco(
+ tcl_map, tcl_map_thresh, quads, tco_map)
# restore single poly with tcl instance.
poly_list = []
@@ -212,10 +245,10 @@ class SASTPostProcess(object):
q_area = quad_areas[instance_idx - 1]
if q_area < 5:
continue
-
+
#
- len1 = float(np.linalg.norm(quad[0] -quad[1]))
- len2 = float(np.linalg.norm(quad[1] -quad[2]))
+ len1 = float(np.linalg.norm(quad[0] - quad[1]))
+ len2 = float(np.linalg.norm(quad[1] - quad[2]))
min_len = min(len1, len2)
if min_len < 3:
continue
@@ -225,16 +258,18 @@ class SASTPostProcess(object):
continue
# filter low confidence instance
- xy_text_scores = tcl_map[xy_text[:, 1], xy_text[:, 0], 0]
+ xy_text_scores = tcl_map[xy_text[:, 1], xy_text[:, 0], 0]
if np.sum(xy_text_scores) / quad_areas[instance_idx - 1] < 0.1:
- # if np.sum(xy_text_scores) / quad_areas[instance_idx - 1] < 0.05:
+ # if np.sum(xy_text_scores) / quad_areas[instance_idx - 1] < 0.05:
continue
# sort xy_text
- left_center_pt = np.array([[(quad[0, 0] + quad[-1, 0]) / 2.0,
- (quad[0, 1] + quad[-1, 1]) / 2.0]]) # (1, 2)
- right_center_pt = np.array([[(quad[1, 0] + quad[2, 0]) / 2.0,
- (quad[1, 1] + quad[2, 1]) / 2.0]]) # (1, 2)
+ left_center_pt = np.array(
+ [[(quad[0, 0] + quad[-1, 0]) / 2.0,
+ (quad[0, 1] + quad[-1, 1]) / 2.0]]) # (1, 2)
+ right_center_pt = np.array(
+ [[(quad[1, 0] + quad[2, 0]) / 2.0,
+ (quad[1, 1] + quad[2, 1]) / 2.0]]) # (1, 2)
proj_unit_vec = (right_center_pt - left_center_pt) / \
(np.linalg.norm(right_center_pt - left_center_pt) + 1e-6)
proj_value = np.sum(xy_text * proj_unit_vec, axis=1)
@@ -245,33 +280,45 @@ class SASTPostProcess(object):
sample_pts_num = self.estimate_sample_pts_num(quad, xy_text)
else:
sample_pts_num = self.sample_pts_num
- xy_center_line = xy_text[np.linspace(0, xy_text.shape[0] - 1, sample_pts_num,
- endpoint=True, dtype=np.float32).astype(np.int32)]
+ xy_center_line = xy_text[np.linspace(
+ 0,
+ xy_text.shape[0] - 1,
+ sample_pts_num,
+ endpoint=True,
+ dtype=np.float32).astype(np.int32)]
point_pair_list = []
for x, y in xy_center_line:
# get corresponding offset
offset = tbo_map[y, x, :].reshape(2, 2)
if offset_expand != 1.0:
- offset_length = np.linalg.norm(offset, axis=1, keepdims=True)
- expand_length = np.clip(offset_length * (offset_expand - 1), a_min=0.5, a_max=3.0)
+ offset_length = np.linalg.norm(
+ offset, axis=1, keepdims=True)
+ expand_length = np.clip(
+ offset_length * (offset_expand - 1),
+ a_min=0.5,
+ a_max=3.0)
offset_detal = offset / offset_length * expand_length
- offset = offset + offset_detal
- # original point
+ offset = offset + offset_detal
+ # original point
ori_yx = np.array([y, x], dtype=np.float32)
- point_pair = (ori_yx + offset)[:, ::-1]* out_strid / np.array([ratio_w, ratio_h]).reshape(-1, 2)
+ point_pair = (ori_yx + offset)[:, ::-1] * out_stride / np.array(
+ [ratio_w, ratio_h]).reshape(-1, 2)
point_pair_list.append(point_pair)
# ndarry: (x, 2), expand poly along width
detected_poly = self.point_pair2poly(point_pair_list)
- detected_poly = self.expand_poly_along_width(detected_poly, shrink_ratio_of_width)
- detected_poly[:, 0] = np.clip(detected_poly[:, 0], a_min=0, a_max=src_w)
- detected_poly[:, 1] = np.clip(detected_poly[:, 1], a_min=0, a_max=src_h)
+ detected_poly = self.expand_poly_along_width(detected_poly,
+ shrink_ratio_of_width)
+ detected_poly[:, 0] = np.clip(
+ detected_poly[:, 0], a_min=0, a_max=src_w)
+ detected_poly[:, 1] = np.clip(
+ detected_poly[:, 1], a_min=0, a_max=src_h)
poly_list.append(detected_poly)
return poly_list
- def __call__(self, outs_dict, shape_list):
+ def __call__(self, outs_dict, shape_list):
score_list = outs_dict['f_score']
border_list = outs_dict['f_border']
tvo_list = outs_dict['f_tvo']
@@ -281,20 +328,28 @@ class SASTPostProcess(object):
border_list = border_list.numpy()
tvo_list = tvo_list.numpy()
tco_list = tco_list.numpy()
-
+
img_num = len(shape_list)
poly_lists = []
for ino in range(img_num):
- p_score = score_list[ino].transpose((1,2,0))
- p_border = border_list[ino].transpose((1,2,0))
- p_tvo = tvo_list[ino].transpose((1,2,0))
- p_tco = tco_list[ino].transpose((1,2,0))
+ p_score = score_list[ino].transpose((1, 2, 0))
+ p_border = border_list[ino].transpose((1, 2, 0))
+ p_tvo = tvo_list[ino].transpose((1, 2, 0))
+ p_tco = tco_list[ino].transpose((1, 2, 0))
src_h, src_w, ratio_h, ratio_w = shape_list[ino]
- poly_list = self.detect_sast(p_score, p_tvo, p_border, p_tco, ratio_w, ratio_h, src_w, src_h,
- shrink_ratio_of_width=self.shrink_ratio_of_width,
- tcl_map_thresh=self.tcl_map_thresh, offset_expand=self.expand_scale)
+ poly_list = self.detect_sast(
+ p_score,
+ p_tvo,
+ p_border,
+ p_tco,
+ ratio_w,
+ ratio_h,
+ src_w,
+ src_h,
+ shrink_ratio_of_width=self.shrink_ratio_of_width,
+ tcl_map_thresh=self.tcl_map_thresh,
+ offset_expand=self.expand_scale)
poly_lists.append({'points': np.array(poly_list)})
return poly_lists
-
diff --git a/ppocr/utils/dict/arabic_dict.txt b/ppocr/utils/dict/arabic_dict.txt
new file mode 100644
index 0000000000000000000000000000000000000000..e97abf39274df77fbad066ee4635aebc6743140c
--- /dev/null
+++ b/ppocr/utils/dict/arabic_dict.txt
@@ -0,0 +1,162 @@
+
+!
+#
+$
+%
+&
+'
+(
++
+,
+-
+.
+/
+0
+1
+2
+3
+4
+5
+6
+7
+8
+9
+:
+?
+@
+A
+B
+C
+D
+E
+F
+G
+H
+I
+J
+K
+L
+M
+N
+O
+P
+Q
+R
+S
+T
+U
+V
+W
+X
+Y
+Z
+_
+a
+b
+c
+d
+e
+f
+g
+h
+i
+j
+k
+l
+m
+n
+o
+p
+q
+r
+s
+t
+u
+v
+w
+x
+y
+z
+É
+é
+ء
+آ
+أ
+ؤ
+إ
+ئ
+ا
+ب
+ة
+ت
+ث
+ج
+ح
+خ
+د
+ذ
+ر
+ز
+س
+ش
+ص
+ض
+ط
+ظ
+ع
+غ
+ف
+ق
+ك
+ل
+م
+ن
+ه
+و
+ى
+ي
+ً
+ٌ
+ٍ
+َ
+ُ
+ِ
+ّ
+ْ
+ٓ
+ٔ
+ٰ
+ٱ
+ٹ
+پ
+چ
+ڈ
+ڑ
+ژ
+ک
+ڭ
+گ
+ں
+ھ
+ۀ
+ہ
+ۂ
+ۃ
+ۆ
+ۇ
+ۈ
+ۋ
+ی
+ې
+ے
+ۓ
+ە
+١
+٢
+٣
+٤
+٥
+٦
+٧
+٨
+٩
diff --git a/ppocr/utils/dict/cyrillic_dict.txt b/ppocr/utils/dict/cyrillic_dict.txt
new file mode 100644
index 0000000000000000000000000000000000000000..2b6f66494d5417e18bbd225719aa72690e09e126
--- /dev/null
+++ b/ppocr/utils/dict/cyrillic_dict.txt
@@ -0,0 +1,163 @@
+
+!
+#
+$
+%
+&
+'
+(
++
+,
+-
+.
+/
+0
+1
+2
+3
+4
+5
+6
+7
+8
+9
+:
+?
+@
+A
+B
+C
+D
+E
+F
+G
+H
+I
+J
+K
+L
+M
+N
+O
+P
+Q
+R
+S
+T
+U
+V
+W
+X
+Y
+Z
+_
+a
+b
+c
+d
+e
+f
+g
+h
+i
+j
+k
+l
+m
+n
+o
+p
+q
+r
+s
+t
+u
+v
+w
+x
+y
+z
+É
+é
+Ё
+Є
+І
+Ј
+Љ
+Ў
+А
+Б
+В
+Г
+Д
+Е
+Ж
+З
+И
+Й
+К
+Л
+М
+Н
+О
+П
+Р
+С
+Т
+У
+Ф
+Х
+Ц
+Ч
+Ш
+Щ
+Ъ
+Ы
+Ь
+Э
+Ю
+Я
+а
+б
+в
+г
+д
+е
+ж
+з
+и
+й
+к
+л
+м
+н
+о
+п
+р
+с
+т
+у
+ф
+х
+ц
+ч
+ш
+щ
+ъ
+ы
+ь
+э
+ю
+я
+ё
+ђ
+є
+і
+ј
+љ
+њ
+ћ
+ў
+џ
+Ґ
+ґ
diff --git a/ppocr/utils/dict/devanagari_dict.txt b/ppocr/utils/dict/devanagari_dict.txt
new file mode 100644
index 0000000000000000000000000000000000000000..f55923061bfd480b875bb3679d7a75a9157387a9
--- /dev/null
+++ b/ppocr/utils/dict/devanagari_dict.txt
@@ -0,0 +1,167 @@
+
+!
+#
+$
+%
+&
+'
+(
++
+,
+-
+.
+/
+0
+1
+2
+3
+4
+5
+6
+7
+8
+9
+:
+?
+@
+A
+B
+C
+D
+E
+F
+G
+H
+I
+J
+K
+L
+M
+N
+O
+P
+Q
+R
+S
+T
+U
+V
+W
+X
+Y
+Z
+_
+a
+b
+c
+d
+e
+f
+g
+h
+i
+j
+k
+l
+m
+n
+o
+p
+q
+r
+s
+t
+u
+v
+w
+x
+y
+z
+É
+é
+ँ
+ं
+ः
+अ
+आ
+इ
+ई
+उ
+ऊ
+ऋ
+ए
+ऐ
+ऑ
+ओ
+औ
+क
+ख
+ग
+घ
+ङ
+च
+छ
+ज
+झ
+ञ
+ट
+ठ
+ड
+ढ
+ण
+त
+थ
+द
+ध
+न
+ऩ
+प
+फ
+ब
+भ
+म
+य
+र
+ऱ
+ल
+ळ
+व
+श
+ष
+स
+ह
+़
+ा
+ि
+ी
+ु
+ू
+ृ
+ॅ
+े
+ै
+ॉ
+ो
+ौ
+्
+॒
+क़
+ख़
+ग़
+ज़
+ड़
+ढ़
+फ़
+ॠ
+।
+०
+१
+२
+३
+४
+५
+६
+७
+८
+९
+॰
diff --git a/ppocr/utils/dict/latin_dict.txt b/ppocr/utils/dict/latin_dict.txt
new file mode 100644
index 0000000000000000000000000000000000000000..e166bf33ecfbdc90ddb3d9743fded23306acabd5
--- /dev/null
+++ b/ppocr/utils/dict/latin_dict.txt
@@ -0,0 +1,185 @@
+
+!
+"
+#
+$
+%
+&
+'
+(
+)
+*
++
+,
+-
+.
+/
+0
+1
+2
+3
+4
+5
+6
+7
+8
+9
+:
+;
+<
+=
+>
+?
+@
+A
+B
+C
+D
+E
+F
+G
+H
+I
+J
+K
+L
+M
+N
+O
+P
+Q
+R
+S
+T
+U
+V
+W
+X
+Y
+Z
+[
+]
+_
+`
+a
+b
+c
+d
+e
+f
+g
+h
+i
+j
+k
+l
+m
+n
+o
+p
+q
+r
+s
+t
+u
+v
+w
+x
+y
+z
+{
+}
+¡
+£
+§
+ª
+«
+
+°
+²
+³
+´
+µ
+·
+º
+»
+¿
+À
+Á
+Â
+Ä
+Å
+Ç
+È
+É
+Ê
+Ë
+Ì
+Í
+Î
+Ï
+Ò
+Ó
+Ô
+Õ
+Ö
+Ú
+Ü
+Ý
+ß
+à
+á
+â
+ã
+ä
+å
+æ
+ç
+è
+é
+ê
+ë
+ì
+í
+î
+ï
+ñ
+ò
+ó
+ô
+õ
+ö
+ø
+ù
+ú
+û
+ü
+ý
+ą
+Ć
+ć
+Č
+č
+Đ
+đ
+ę
+ı
+Ł
+ł
+ō
+Œ
+œ
+Š
+š
+Ÿ
+Ž
+ž
+ʒ
+β
+δ
+ε
+з
+Ṡ
+‘
+€
+™
diff --git a/ppocr/utils/e2e_metric/Deteval.py b/ppocr/utils/e2e_metric/Deteval.py
new file mode 100755
index 0000000000000000000000000000000000000000..8033a9ff9f1f55200d43472f405d5805e238085b
--- /dev/null
+++ b/ppocr/utils/e2e_metric/Deteval.py
@@ -0,0 +1,458 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import numpy as np
+from ppocr.utils.e2e_metric.polygon_fast import iod, area_of_intersection, area
+
+
+def get_socre(gt_dict, pred_dict):
+ allInputs = 1
+
+ def input_reading_mod(pred_dict):
+ """This helper reads input from txt files"""
+ det = []
+ n = len(pred_dict)
+ for i in range(n):
+ points = pred_dict[i]['points']
+ text = pred_dict[i]['text']
+ point = ",".join(map(str, points.reshape(-1, )))
+ det.append([point, text])
+ return det
+
+ def gt_reading_mod(gt_dict):
+ """This helper reads groundtruths from mat files"""
+ gt = []
+ n = len(gt_dict)
+ for i in range(n):
+ points = gt_dict[i]['points']
+ h = len(points)
+ text = gt_dict[i]['text']
+            xx = [
+                np.array(
+                    ['x:'], dtype='<U2'), np.array(points)[:, 0].reshape(1, h),
+                np.array(
+                    ['y:'], dtype='<U2'), np.array(points)[:, 1].reshape(1, h),
+                np.array([text]), text
+            ]
+            gt.append(xx)
+        return gt
+
+    def detection_filtering(detections, groundtruths, threshold=0.5):
+        # each gt row: ['x:', x-coords, 'y:', y-coords, text array, text]
+        for gt_id, gt in enumerate(groundtruths):
+            if (gt[5] == '#') and (gt[1].shape[1] > 1):
+ gt_x = list(map(int, np.squeeze(gt[1])))
+ gt_y = list(map(int, np.squeeze(gt[3])))
+ for det_id, detection in enumerate(detections):
+ detection_orig = detection
+ detection = [float(x) for x in detection[0].split(',')]
+ detection = list(map(int, detection))
+ det_x = detection[0::2]
+ det_y = detection[1::2]
+ det_gt_iou = iod(det_x, det_y, gt_x, gt_y)
+ if det_gt_iou > threshold:
+ detections[det_id] = []
+
+ detections[:] = [item for item in detections if item != []]
+ return detections
+
+ def sigma_calculation(det_x, det_y, gt_x, gt_y):
+ """
+ sigma = inter_area / gt_area
+ """
+ return np.round((area_of_intersection(det_x, det_y, gt_x, gt_y) /
+ area(gt_x, gt_y)), 2)
+
+ def tau_calculation(det_x, det_y, gt_x, gt_y):
+ if area(det_x, det_y) == 0.0:
+ return 0
+ return np.round((area_of_intersection(det_x, det_y, gt_x, gt_y) /
+ area(det_x, det_y)), 2)
+
+ ##############################Initialization###################################
+ # global_sigma = []
+ # global_tau = []
+ # global_pred_str = []
+ # global_gt_str = []
+ ###############################################################################
+
+ for input_id in range(allInputs):
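+        # input_id is an int from range(allInputs); the filename checks below
+        # are a holdover from the original file-based DetEval script.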
+ if (input_id != '.DS_Store') and (input_id != 'Pascal_result.txt') and (
+ input_id != 'Pascal_result_curved.txt') and (input_id != 'Pascal_result_non_curved.txt') and (
+ input_id != 'Deteval_result.txt') and (input_id != 'Deteval_result_curved.txt') \
+ and (input_id != 'Deteval_result_non_curved.txt'):
+ detections = input_reading_mod(pred_dict)
+ groundtruths = gt_reading_mod(gt_dict)
+ detections = detection_filtering(
+ detections,
+ groundtruths) # filters detections overlapping with DC area
+ dc_id = []
+ for i in range(len(groundtruths)):
+ if groundtruths[i][5] == '#':
+ dc_id.append(i)
+ cnt = 0
+ for a in dc_id:
+ num = a - cnt
+ del groundtruths[num]
+ cnt += 1
+
+ local_sigma_table = np.zeros((len(groundtruths), len(detections)))
+ local_tau_table = np.zeros((len(groundtruths), len(detections)))
+ local_pred_str = {}
+ local_gt_str = {}
+
+ for gt_id, gt in enumerate(groundtruths):
+ if len(detections) > 0:
+ for det_id, detection in enumerate(detections):
+ detection_orig = detection
+ detection = [float(x) for x in detection[0].split(',')]
+ detection = list(map(int, detection))
+ pred_seq_str = detection_orig[1].strip()
+ det_x = detection[0::2]
+ det_y = detection[1::2]
+ gt_x = list(map(int, np.squeeze(gt[1])))
+ gt_y = list(map(int, np.squeeze(gt[3])))
+ gt_seq_str = str(gt[4].tolist()[0])
+
+ local_sigma_table[gt_id, det_id] = sigma_calculation(
+ det_x, det_y, gt_x, gt_y)
+ local_tau_table[gt_id, det_id] = tau_calculation(
+ det_x, det_y, gt_x, gt_y)
+ local_pred_str[det_id] = pred_seq_str
+ local_gt_str[gt_id] = gt_seq_str
+
+ global_sigma = local_sigma_table
+ global_tau = local_tau_table
+ global_pred_str = local_pred_str
+ global_gt_str = local_gt_str
+
+ single_data = {}
+ single_data['sigma'] = global_sigma
+ single_data['global_tau'] = global_tau
+ single_data['global_pred_str'] = global_pred_str
+ single_data['global_gt_str'] = global_gt_str
+ return single_data
+
+
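+# Illustrative usage sketch (variable names are hypothetical): score each
+# image with get_socre, then merge the per-image tables, e.g.
+#   single = get_socre(gt_dict, pred_dict)   # sigma/tau tables for one image
+#   metrics = combine_results([single])      # aggregated det/e2e P, R, F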
+def combine_results(all_data):
+ tr = 0.7
+ tp = 0.6
+ fsc_k = 0.8
+ k = 2
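+    # DetEval matching constants: sigma must exceed tr (recall side) and tau
+    # must exceed tp (precision side); fsc_k down-weights split/merge matches,
+    # and k is the minimum number of overlapping fragments for the
+    # one-to-many / many-to-one cases.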
+ global_sigma = []
+ global_tau = []
+ global_pred_str = []
+ global_gt_str = []
+ for data in all_data:
+ global_sigma.append(data['sigma'])
+ global_tau.append(data['global_tau'])
+ global_pred_str.append(data['global_pred_str'])
+ global_gt_str.append(data['global_gt_str'])
+
+ global_accumulative_recall = 0
+ global_accumulative_precision = 0
+ total_num_gt = 0
+ total_num_det = 0
+ hit_str_count = 0
+ hit_count = 0
+
+ def one_to_one(local_sigma_table, local_tau_table,
+ local_accumulative_recall, local_accumulative_precision,
+ global_accumulative_recall, global_accumulative_precision,
+ gt_flag, det_flag, idy):
+ hit_str_num = 0
+ for gt_id in range(num_gt):
+ gt_matching_qualified_sigma_candidates = np.where(
+ local_sigma_table[gt_id, :] > tr)
+ gt_matching_num_qualified_sigma_candidates = gt_matching_qualified_sigma_candidates[
+ 0].shape[0]
+ gt_matching_qualified_tau_candidates = np.where(
+ local_tau_table[gt_id, :] > tp)
+ gt_matching_num_qualified_tau_candidates = gt_matching_qualified_tau_candidates[
+ 0].shape[0]
+
+ det_matching_qualified_sigma_candidates = np.where(
+ local_sigma_table[:, gt_matching_qualified_sigma_candidates[0]]
+ > tr)
+ det_matching_num_qualified_sigma_candidates = det_matching_qualified_sigma_candidates[
+ 0].shape[0]
+ det_matching_qualified_tau_candidates = np.where(
+ local_tau_table[:, gt_matching_qualified_tau_candidates[0]] >
+ tp)
+ det_matching_num_qualified_tau_candidates = det_matching_qualified_tau_candidates[
+ 0].shape[0]
+
+ if (gt_matching_num_qualified_sigma_candidates == 1) and (gt_matching_num_qualified_tau_candidates == 1) and \
+ (det_matching_num_qualified_sigma_candidates == 1) and (
+ det_matching_num_qualified_tau_candidates == 1):
+ global_accumulative_recall = global_accumulative_recall + 1.0
+ global_accumulative_precision = global_accumulative_precision + 1.0
+ local_accumulative_recall = local_accumulative_recall + 1.0
+ local_accumulative_precision = local_accumulative_precision + 1.0
+
+ gt_flag[0, gt_id] = 1
+ matched_det_id = np.where(local_sigma_table[gt_id, :] > tr)
+ # recg start
+ gt_str_cur = global_gt_str[idy][gt_id]
+ pred_str_cur = global_pred_str[idy][matched_det_id[0].tolist()[
+ 0]]
+ if pred_str_cur == gt_str_cur:
+ hit_str_num += 1
+ else:
+ if pred_str_cur.lower() == gt_str_cur.lower():
+ hit_str_num += 1
+ # recg end
+ det_flag[0, matched_det_id] = 1
+ return local_accumulative_recall, local_accumulative_precision, global_accumulative_recall, global_accumulative_precision, gt_flag, det_flag, hit_str_num
+
+ def one_to_many(local_sigma_table, local_tau_table,
+ local_accumulative_recall, local_accumulative_precision,
+ global_accumulative_recall, global_accumulative_precision,
+ gt_flag, det_flag, idy):
+ hit_str_num = 0
+ for gt_id in range(num_gt):
+ # skip the following if the groundtruth was matched
+ if gt_flag[0, gt_id] > 0:
+ continue
+
+ non_zero_in_sigma = np.where(local_sigma_table[gt_id, :] > 0)
+ num_non_zero_in_sigma = non_zero_in_sigma[0].shape[0]
+
+ if num_non_zero_in_sigma >= k:
+ ####search for all detections that overlaps with this groundtruth
+ qualified_tau_candidates = np.where((local_tau_table[
+ gt_id, :] >= tp) & (det_flag[0, :] == 0))
+ num_qualified_tau_candidates = qualified_tau_candidates[
+ 0].shape[0]
+
+ if num_qualified_tau_candidates == 1:
+ if ((local_tau_table[gt_id, qualified_tau_candidates] >= tp)
+ and
+ (local_sigma_table[gt_id, qualified_tau_candidates] >=
+ tr)):
+ # became an one-to-one case
+ global_accumulative_recall = global_accumulative_recall + 1.0
+ global_accumulative_precision = global_accumulative_precision + 1.0
+ local_accumulative_recall = local_accumulative_recall + 1.0
+ local_accumulative_precision = local_accumulative_precision + 1.0
+
+ gt_flag[0, gt_id] = 1
+ det_flag[0, qualified_tau_candidates] = 1
+ # recg start
+ gt_str_cur = global_gt_str[idy][gt_id]
+ pred_str_cur = global_pred_str[idy][
+ qualified_tau_candidates[0].tolist()[0]]
+ if pred_str_cur == gt_str_cur:
+ hit_str_num += 1
+ else:
+ if pred_str_cur.lower() == gt_str_cur.lower():
+ hit_str_num += 1
+ # recg end
+ elif (np.sum(local_sigma_table[gt_id, qualified_tau_candidates])
+ >= tr):
+ gt_flag[0, gt_id] = 1
+ det_flag[0, qualified_tau_candidates] = 1
+ # recg start
+ gt_str_cur = global_gt_str[idy][gt_id]
+ pred_str_cur = global_pred_str[idy][
+ qualified_tau_candidates[0].tolist()[0]]
+ if pred_str_cur == gt_str_cur:
+ hit_str_num += 1
+ else:
+ if pred_str_cur.lower() == gt_str_cur.lower():
+ hit_str_num += 1
+ # recg end
+
+ global_accumulative_recall = global_accumulative_recall + fsc_k
+ global_accumulative_precision = global_accumulative_precision + num_qualified_tau_candidates * fsc_k
+
+ local_accumulative_recall = local_accumulative_recall + fsc_k
+ local_accumulative_precision = local_accumulative_precision + num_qualified_tau_candidates * fsc_k
+
+ return local_accumulative_recall, local_accumulative_precision, global_accumulative_recall, global_accumulative_precision, gt_flag, det_flag, hit_str_num
+
+ def many_to_one(local_sigma_table, local_tau_table,
+ local_accumulative_recall, local_accumulative_precision,
+ global_accumulative_recall, global_accumulative_precision,
+ gt_flag, det_flag, idy):
+ hit_str_num = 0
+ for det_id in range(num_det):
+ # skip the following if the detection was matched
+ if det_flag[0, det_id] > 0:
+ continue
+
+ non_zero_in_tau = np.where(local_tau_table[:, det_id] > 0)
+ num_non_zero_in_tau = non_zero_in_tau[0].shape[0]
+
+ if num_non_zero_in_tau >= k:
+ ####search for all detections that overlaps with this groundtruth
+ qualified_sigma_candidates = np.where((
+ local_sigma_table[:, det_id] >= tp) & (gt_flag[0, :] == 0))
+ num_qualified_sigma_candidates = qualified_sigma_candidates[
+ 0].shape[0]
+
+ if num_qualified_sigma_candidates == 1:
+ if ((local_tau_table[qualified_sigma_candidates, det_id] >=
+ tp) and
+ (local_sigma_table[qualified_sigma_candidates, det_id]
+ >= tr)):
+ # became an one-to-one case
+ global_accumulative_recall = global_accumulative_recall + 1.0
+ global_accumulative_precision = global_accumulative_precision + 1.0
+ local_accumulative_recall = local_accumulative_recall + 1.0
+ local_accumulative_precision = local_accumulative_precision + 1.0
+
+ gt_flag[0, qualified_sigma_candidates] = 1
+ det_flag[0, det_id] = 1
+ # recg start
+ pred_str_cur = global_pred_str[idy][det_id]
+ gt_len = len(qualified_sigma_candidates[0])
+ for idx in range(gt_len):
+ ele_gt_id = qualified_sigma_candidates[0].tolist()[
+ idx]
+ if ele_gt_id not in global_gt_str[idy]:
+ continue
+ gt_str_cur = global_gt_str[idy][ele_gt_id]
+ if pred_str_cur == gt_str_cur:
+ hit_str_num += 1
+ break
+ else:
+ if pred_str_cur.lower() == gt_str_cur.lower():
+ hit_str_num += 1
+ break
+ # recg end
+ elif (np.sum(local_tau_table[qualified_sigma_candidates,
+ det_id]) >= tp):
+ det_flag[0, det_id] = 1
+ gt_flag[0, qualified_sigma_candidates] = 1
+ # recg start
+ pred_str_cur = global_pred_str[idy][det_id]
+ gt_len = len(qualified_sigma_candidates[0])
+ for idx in range(gt_len):
+ ele_gt_id = qualified_sigma_candidates[0].tolist()[idx]
+ if ele_gt_id not in global_gt_str[idy]:
+ continue
+ gt_str_cur = global_gt_str[idy][ele_gt_id]
+ if pred_str_cur == gt_str_cur:
+ hit_str_num += 1
+ break
+ else:
+ if pred_str_cur.lower() == gt_str_cur.lower():
+ hit_str_num += 1
+ break
+ # recg end
+
+ global_accumulative_recall = global_accumulative_recall + num_qualified_sigma_candidates * fsc_k
+ global_accumulative_precision = global_accumulative_precision + fsc_k
+
+ local_accumulative_recall = local_accumulative_recall + num_qualified_sigma_candidates * fsc_k
+ local_accumulative_precision = local_accumulative_precision + fsc_k
+ return local_accumulative_recall, local_accumulative_precision, global_accumulative_recall, global_accumulative_precision, gt_flag, det_flag, hit_str_num
+
+ for idx in range(len(global_sigma)):
+ local_sigma_table = np.array(global_sigma[idx])
+ local_tau_table = global_tau[idx]
+
+ num_gt = local_sigma_table.shape[0]
+ num_det = local_sigma_table.shape[1]
+
+ total_num_gt = total_num_gt + num_gt
+ total_num_det = total_num_det + num_det
+
+ local_accumulative_recall = 0
+ local_accumulative_precision = 0
+ gt_flag = np.zeros((1, num_gt))
+ det_flag = np.zeros((1, num_det))
+
+ #######first check for one-to-one case##########
+ local_accumulative_recall, local_accumulative_precision, global_accumulative_recall, global_accumulative_precision, \
+ gt_flag, det_flag, hit_str_num = one_to_one(local_sigma_table, local_tau_table,
+ local_accumulative_recall, local_accumulative_precision,
+ global_accumulative_recall, global_accumulative_precision,
+ gt_flag, det_flag, idx)
+
+ hit_str_count += hit_str_num
+ #######then check for one-to-many case##########
+ local_accumulative_recall, local_accumulative_precision, global_accumulative_recall, global_accumulative_precision, \
+ gt_flag, det_flag, hit_str_num = one_to_many(local_sigma_table, local_tau_table,
+ local_accumulative_recall, local_accumulative_precision,
+ global_accumulative_recall, global_accumulative_precision,
+ gt_flag, det_flag, idx)
+ hit_str_count += hit_str_num
+ #######then check for many-to-one case##########
+ local_accumulative_recall, local_accumulative_precision, global_accumulative_recall, global_accumulative_precision, \
+ gt_flag, det_flag, hit_str_num = many_to_one(local_sigma_table, local_tau_table,
+ local_accumulative_recall, local_accumulative_precision,
+ global_accumulative_recall, global_accumulative_precision,
+ gt_flag, det_flag, idx)
+ hit_str_count += hit_str_num
+
+ try:
+ recall = global_accumulative_recall / total_num_gt
+ except ZeroDivisionError:
+ recall = 0
+
+ try:
+ precision = global_accumulative_precision / total_num_det
+ except ZeroDivisionError:
+ precision = 0
+
+ try:
+ f_score = 2 * precision * recall / (precision + recall)
+ except ZeroDivisionError:
+ f_score = 0
+
+ try:
+ seqerr = 1 - float(hit_str_count) / global_accumulative_recall
+ except ZeroDivisionError:
+ seqerr = 1
+
+ try:
+ recall_e2e = float(hit_str_count) / total_num_gt
+ except ZeroDivisionError:
+ recall_e2e = 0
+
+ try:
+ precision_e2e = float(hit_str_count) / total_num_det
+ except ZeroDivisionError:
+ precision_e2e = 0
+
+ try:
+ f_score_e2e = 2 * precision_e2e * recall_e2e / (
+ precision_e2e + recall_e2e)
+ except ZeroDivisionError:
+ f_score_e2e = 0
+
+ final = {
+ 'total_num_gt': total_num_gt,
+ 'total_num_det': total_num_det,
+ 'global_accumulative_recall': global_accumulative_recall,
+ 'hit_str_count': hit_str_count,
+ 'recall': recall,
+ 'precision': precision,
+ 'f_score': f_score,
+ 'seqerr': seqerr,
+ 'recall_e2e': recall_e2e,
+ 'precision_e2e': precision_e2e,
+ 'f_score_e2e': f_score_e2e
+ }
+ return final
diff --git a/ppocr/utils/e2e_metric/polygon_fast.py b/ppocr/utils/e2e_metric/polygon_fast.py
new file mode 100755
index 0000000000000000000000000000000000000000..81c9ad70675bb37a95968283b6dc6f42f709df27
--- /dev/null
+++ b/ppocr/utils/e2e_metric/polygon_fast.py
@@ -0,0 +1,83 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import numpy as np
+from shapely.geometry import Polygon
+"""
+:param det_x: [1, N] Xs of detection's vertices
+:param det_y: [1, N] Ys of detection's vertices
+:param gt_x: [1, N] Xs of groundtruth's vertices
+:param gt_y: [1, N] Ys of groundtruth's vertices
+
+##############
+All the calculation of 'AREA' in this script is handled by:
+1) First generating a binary mask with the polygon area filled up with 1's
+2) Summing up all the 1's
+"""
+
+
+def area(x, y):
+ polygon = Polygon(np.stack([x, y], axis=1))
+ return float(polygon.area)
+
+
+def approx_area_of_intersection(det_x, det_y, gt_x, gt_y):
+ """
+    This helper determines whether the two polygons intersect, using an
+    approximation: each polygon is reduced to its axis-aligned bounding
+    rectangle [xmin, ymin, xmax, ymax] and the rectangles' overlap area
+    is returned.
+ """
+ det_ymax = np.max(det_y)
+ det_xmax = np.max(det_x)
+ det_ymin = np.min(det_y)
+ det_xmin = np.min(det_x)
+
+ gt_ymax = np.max(gt_y)
+ gt_xmax = np.max(gt_x)
+ gt_ymin = np.min(gt_y)
+ gt_xmin = np.min(gt_x)
+
+ all_min_ymax = np.minimum(det_ymax, gt_ymax)
+ all_max_ymin = np.maximum(det_ymin, gt_ymin)
+
+ intersect_heights = np.maximum(0.0, (all_min_ymax - all_max_ymin))
+
+ all_min_xmax = np.minimum(det_xmax, gt_xmax)
+ all_max_xmin = np.maximum(det_xmin, gt_xmin)
+ intersect_widths = np.maximum(0.0, (all_min_xmax - all_max_xmin))
+
+ return intersect_heights * intersect_widths
+
+
+def area_of_intersection(det_x, det_y, gt_x, gt_y):
+ p1 = Polygon(np.stack([det_x, det_y], axis=1)).buffer(0)
+ p2 = Polygon(np.stack([gt_x, gt_y], axis=1)).buffer(0)
+ return float(p1.intersection(p2).area)
+
+
+def area_of_union(det_x, det_y, gt_x, gt_y):
+ p1 = Polygon(np.stack([det_x, det_y], axis=1)).buffer(0)
+ p2 = Polygon(np.stack([gt_x, gt_y], axis=1)).buffer(0)
+ return float(p1.union(p2).area)
+
+
+def iou(det_x, det_y, gt_x, gt_y):
+ return area_of_intersection(det_x, det_y, gt_x, gt_y) / (
+ area_of_union(det_x, det_y, gt_x, gt_y) + 1.0)
+
+
+def iod(det_x, det_y, gt_x, gt_y):
+ """
+    This helper determines the fraction of the intersection area over the
+    detection area.
+ """
+ return area_of_intersection(det_x, det_y, gt_x, gt_y) / (
+ area(det_x, det_y) + 1.0)
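+
+
+# Minimal sanity check (illustrative only): two overlapping unit squares.
+#   area([0, 1, 1, 0], [0, 0, 1, 1]) == 1.0
+#   area_of_intersection([0, 1, 1, 0], [0, 0, 1, 1],
+#                        [0.5, 1.5, 1.5, 0.5], [0, 0, 1, 1]) == 0.5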
diff --git a/ppocr/utils/e2e_utils/extract_batchsize.py b/ppocr/utils/e2e_utils/extract_batchsize.py
new file mode 100644
index 0000000000000000000000000000000000000000..e99a833ea76a81e02d39b16fe1a01e22f15bf3a4
--- /dev/null
+++ b/ppocr/utils/e2e_utils/extract_batchsize.py
@@ -0,0 +1,87 @@
+import paddle
+import numpy as np
+import copy
+
+
+def org_tcl_rois(batch_size, pos_lists, pos_masks, label_lists, tcl_bs):
+ """
+ """
+ pos_lists_, pos_masks_, label_lists_ = [], [], []
+ img_bs = batch_size
+ ngpu = int(batch_size / img_bs)
+ img_ids = np.array(pos_lists, dtype=np.int32)[:, 0, 0].copy()
+ pos_lists_split, pos_masks_split, label_lists_split = [], [], []
+ for i in range(ngpu):
+ pos_lists_split.append([])
+ pos_masks_split.append([])
+ label_lists_split.append([])
+
+ for i in range(img_ids.shape[0]):
+ img_id = img_ids[i]
+ gpu_id = int(img_id / img_bs)
+ img_id = img_id % img_bs
+ pos_list = pos_lists[i].copy()
+ pos_list[:, 0] = img_id
+ pos_lists_split[gpu_id].append(pos_list)
+ pos_masks_split[gpu_id].append(pos_masks[i].copy())
+ label_lists_split[gpu_id].append(copy.deepcopy(label_lists[i]))
+ # repeat or delete
+ for i in range(ngpu):
+ vp_len = len(pos_lists_split[i])
+ if vp_len <= tcl_bs:
+ for j in range(0, tcl_bs - vp_len):
+ pos_list = pos_lists_split[i][j].copy()
+ pos_lists_split[i].append(pos_list)
+ pos_mask = pos_masks_split[i][j].copy()
+ pos_masks_split[i].append(pos_mask)
+ label_list = copy.deepcopy(label_lists_split[i][j])
+ label_lists_split[i].append(label_list)
+ else:
+ for j in range(0, vp_len - tcl_bs):
+ c_len = len(pos_lists_split[i])
+ pop_id = np.random.permutation(c_len)[0]
+ pos_lists_split[i].pop(pop_id)
+ pos_masks_split[i].pop(pop_id)
+ label_lists_split[i].pop(pop_id)
+ # merge
+ for i in range(ngpu):
+ pos_lists_.extend(pos_lists_split[i])
+ pos_masks_.extend(pos_masks_split[i])
+ label_lists_.extend(label_lists_split[i])
+ return pos_lists_, pos_masks_, label_lists_
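+# e.g. with tcl_bs=64: a group with 50 valid ROIs is padded by repeating its
+# first 14 entries, while one with 70 ROIs has 6 entries randomly dropped.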
+
+
+def pre_process(label_list, pos_list, pos_mask, max_text_length, max_text_nums,
+ pad_num, tcl_bs):
+ label_list = label_list.numpy()
+ batch, _, _, _ = label_list.shape
+ pos_list = pos_list.numpy()
+ pos_mask = pos_mask.numpy()
+ pos_list_t = []
+ pos_mask_t = []
+ label_list_t = []
+ for i in range(batch):
+ for j in range(max_text_nums):
+ if pos_mask[i, j].any():
+ pos_list_t.append(pos_list[i][j])
+ pos_mask_t.append(pos_mask[i][j])
+ label_list_t.append(label_list[i][j])
+ pos_list, pos_mask, label_list = org_tcl_rois(batch, pos_list_t, pos_mask_t,
+ label_list_t, tcl_bs)
+ label = []
+ tt = [l.tolist() for l in label_list]
+ for i in range(tcl_bs):
+ k = 0
+ for j in range(max_text_length):
+ if tt[i][j][0] != pad_num:
+ k += 1
+ else:
+ break
+ label.append(k)
+ label = paddle.to_tensor(label)
+ label = paddle.cast(label, dtype='int64')
+ pos_list = paddle.to_tensor(pos_list)
+ pos_mask = paddle.to_tensor(pos_mask)
+ label_list = paddle.squeeze(paddle.to_tensor(label_list), axis=2)
+ label_list = paddle.cast(label_list, dtype='int32')
+ return pos_list, pos_mask, label_list, label
diff --git a/ppocr/utils/e2e_utils/extract_textpoint.py b/ppocr/utils/e2e_utils/extract_textpoint.py
new file mode 100644
index 0000000000000000000000000000000000000000..975ca16174f2ee1c7f985a5eb9ae1ec66aa7ca28
--- /dev/null
+++ b/ppocr/utils/e2e_utils/extract_textpoint.py
@@ -0,0 +1,532 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Contains various CTC decoders."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import cv2
+import math
+
+import numpy as np
+from itertools import groupby
+from skimage.morphology._skeletonize import thin
+
+
+def get_dict(character_dict_path):
+ character_str = ""
+ with open(character_dict_path, "rb") as fin:
+ lines = fin.readlines()
+ for line in lines:
+ line = line.decode('utf-8').strip("\n").strip("\r\n")
+ character_str += line
+ dict_character = list(character_str)
+ return dict_character
+
+
+def softmax(logits):
+ """
+ logits: N x d
+ """
+ max_value = np.max(logits, axis=1, keepdims=True)
+ exp = np.exp(logits - max_value)
+ exp_sum = np.sum(exp, axis=1, keepdims=True)
+ dist = exp / exp_sum
+ return dist
+
+
+def get_keep_pos_idxs(labels, remove_blank=None):
+ """
+    Remove duplicates and get the position indices of the kept items.
+    remove_blank should be either None (keep blanks) or the blank index
+    (e.g. 95).
+ """
+ duplicate_len_list = []
+ keep_pos_idx_list = []
+ keep_char_idx_list = []
+ for k, v_ in groupby(labels):
+ current_len = len(list(v_))
+ if k != remove_blank:
+ current_idx = int(sum(duplicate_len_list) + current_len // 2)
+ keep_pos_idx_list.append(current_idx)
+ keep_char_idx_list.append(k)
+ duplicate_len_list.append(current_len)
+ return keep_char_idx_list, keep_pos_idx_list
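+# e.g. get_keep_pos_idxs([1, 1, 0, 2, 2]) -> ([1, 0, 2], [1, 2, 4]):
+# each kept label keeps the middle position of its duplicate run.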
+
+
+def remove_blank(labels, blank=0):
+ new_labels = [x for x in labels if x != blank]
+ return new_labels
+
+
+def insert_blank(labels, blank=0):
+ new_labels = [blank]
+ for l in labels:
+ new_labels += [l, blank]
+ return new_labels
+
+
+def ctc_greedy_decoder(probs_seq, blank=95, keep_blank_in_idxs=True):
+ """
+ CTC greedy (best path) decoder.
+ """
+ raw_str = np.argmax(np.array(probs_seq), axis=1)
+ remove_blank_in_pos = None if keep_blank_in_idxs else blank
+ dedup_str, keep_idx_list = get_keep_pos_idxs(
+ raw_str, remove_blank=remove_blank_in_pos)
+ dst_str = remove_blank(dedup_str, blank=blank)
+ return dst_str, keep_idx_list
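+# e.g. for argmax labels [5, 5, 95, 7] with blank=95:
+#   keep_blank_in_idxs=True  -> ([5, 7], [1, 2, 3])
+#   keep_blank_in_idxs=False -> ([5, 7], [1, 3])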
+
+
+def instance_ctc_greedy_decoder(gather_info,
+ logits_map,
+ keep_blank_in_idxs=True):
+ """
+ gather_info: [[x, y], [x, y] ...]
+ logits_map: H x W X (n_chars + 1)
+ """
+ _, _, C = logits_map.shape
+ ys, xs = zip(*gather_info)
+ logits_seq = logits_map[list(ys), list(xs)] # n x 96
+ probs_seq = softmax(logits_seq)
+ dst_str, keep_idx_list = ctc_greedy_decoder(
+ probs_seq, blank=C - 1, keep_blank_in_idxs=keep_blank_in_idxs)
+ keep_gather_list = [gather_info[idx] for idx in keep_idx_list]
+ return dst_str, keep_gather_list
+
+
+def ctc_decoder_for_image(gather_info_list, logits_map,
+ keep_blank_in_idxs=True):
+ """
+ CTC decoder using multiple processes.
+ """
+ decoder_results = []
+ for gather_info in gather_info_list:
+ res = instance_ctc_greedy_decoder(
+ gather_info, logits_map, keep_blank_in_idxs=keep_blank_in_idxs)
+ decoder_results.append(res)
+ return decoder_results
+
+
+def sort_with_direction(pos_list, f_direction):
+ """
+ f_direction: h x w x 2
+ pos_list: [[y, x], [y, x], [y, x] ...]
+ """
+
+ def sort_part_with_direction(pos_list, point_direction):
+ pos_list = np.array(pos_list).reshape(-1, 2)
+ point_direction = np.array(point_direction).reshape(-1, 2)
+ average_direction = np.mean(point_direction, axis=0, keepdims=True)
+ pos_proj_leng = np.sum(pos_list * average_direction, axis=1)
+ sorted_list = pos_list[np.argsort(pos_proj_leng)].tolist()
+ sorted_direction = point_direction[np.argsort(pos_proj_leng)].tolist()
+ return sorted_list, sorted_direction
+
+ pos_list = np.array(pos_list).reshape(-1, 2)
+ point_direction = f_direction[pos_list[:, 0], pos_list[:, 1]] # x, y
+ point_direction = point_direction[:, ::-1] # x, y -> y, x
+ sorted_point, sorted_direction = sort_part_with_direction(pos_list,
+ point_direction)
+
+ point_num = len(sorted_point)
+ if point_num >= 16:
+ middle_num = point_num // 2
+ first_part_point = sorted_point[:middle_num]
+ first_point_direction = sorted_direction[:middle_num]
+        sorted_first_part_point, sorted_first_part_direction = sort_part_with_direction(
+ first_part_point, first_point_direction)
+
+ last_part_point = sorted_point[middle_num:]
+ last_point_direction = sorted_direction[middle_num:]
+ sorted_last_part_point, sorted_last_part_direction = sort_part_with_direction(
+ last_part_point, last_point_direction)
+        sorted_point = sorted_first_part_point + sorted_last_part_point
+        sorted_direction = sorted_first_part_direction + sorted_last_part_direction
+
+ return sorted_point, np.array(sorted_direction)
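+# Points are ordered by projection onto the mean direction vector; instances
+# with >= 16 points are re-sorted per half with local mean directions, which
+# stabilizes the ordering for strongly curved text lines.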
+
+
+def add_id(pos_list, image_id=0):
+ """
+ Add id for gather feature, for inference.
+ """
+ new_list = []
+ for item in pos_list:
+ new_list.append((image_id, item[0], item[1]))
+ return new_list
+
+
+def sort_and_expand_with_direction(pos_list, f_direction):
+ """
+ f_direction: h x w x 2
+ pos_list: [[y, x], [y, x], [y, x] ...]
+ """
+ h, w, _ = f_direction.shape
+ sorted_list, point_direction = sort_with_direction(pos_list, f_direction)
+
+ # expand along
+ point_num = len(sorted_list)
+ sub_direction_len = max(point_num // 3, 2)
+ left_direction = point_direction[:sub_direction_len, :]
+    right_direction = point_direction[point_num - sub_direction_len:, :]
+
+ left_average_direction = -np.mean(left_direction, axis=0, keepdims=True)
+ left_average_len = np.linalg.norm(left_average_direction)
+ left_start = np.array(sorted_list[0])
+ left_step = left_average_direction / (left_average_len + 1e-6)
+
+    right_average_direction = np.mean(right_direction, axis=0, keepdims=True)
+ right_average_len = np.linalg.norm(right_average_direction)
+ right_step = right_average_direction / (right_average_len + 1e-6)
+ right_start = np.array(sorted_list[-1])
+
+ append_num = max(
+ int((left_average_len + right_average_len) / 2.0 * 0.15), 1)
+ left_list = []
+ right_list = []
+ for i in range(append_num):
+ ly, lx = np.round(left_start + left_step * (i + 1)).flatten().astype(
+ 'int32').tolist()
+ if ly < h and lx < w and (ly, lx) not in left_list:
+ left_list.append((ly, lx))
+ ry, rx = np.round(right_start + right_step * (i + 1)).flatten().astype(
+ 'int32').tolist()
+ if ry < h and rx < w and (ry, rx) not in right_list:
+ right_list.append((ry, rx))
+
+ all_list = left_list[::-1] + sorted_list + right_list
+ return all_list
+
+
+def sort_and_expand_with_direction_v2(pos_list, f_direction, binary_tcl_map):
+ """
+ f_direction: h x w x 2
+ pos_list: [[y, x], [y, x], [y, x] ...]
+ binary_tcl_map: h x w
+ """
+ h, w, _ = f_direction.shape
+ sorted_list, point_direction = sort_with_direction(pos_list, f_direction)
+
+ # expand along
+ point_num = len(sorted_list)
+ sub_direction_len = max(point_num // 3, 2)
+ left_direction = point_direction[:sub_direction_len, :]
+    right_direction = point_direction[point_num - sub_direction_len:, :]
+
+ left_average_direction = -np.mean(left_direction, axis=0, keepdims=True)
+ left_average_len = np.linalg.norm(left_average_direction)
+ left_start = np.array(sorted_list[0])
+ left_step = left_average_direction / (left_average_len + 1e-6)
+
+    right_average_direction = np.mean(right_direction, axis=0, keepdims=True)
+ right_average_len = np.linalg.norm(right_average_direction)
+ right_step = right_average_direction / (right_average_len + 1e-6)
+ right_start = np.array(sorted_list[-1])
+
+ append_num = max(
+ int((left_average_len + right_average_len) / 2.0 * 0.15), 1)
+ max_append_num = 2 * append_num
+
+ left_list = []
+ right_list = []
+ for i in range(max_append_num):
+ ly, lx = np.round(left_start + left_step * (i + 1)).flatten().astype(
+ 'int32').tolist()
+ if ly < h and lx < w and (ly, lx) not in left_list:
+ if binary_tcl_map[ly, lx] > 0.5:
+ left_list.append((ly, lx))
+ else:
+ break
+
+ for i in range(max_append_num):
+ ry, rx = np.round(right_start + right_step * (i + 1)).flatten().astype(
+ 'int32').tolist()
+ if ry < h and rx < w and (ry, rx) not in right_list:
+ if binary_tcl_map[ry, rx] > 0.5:
+ right_list.append((ry, rx))
+ else:
+ break
+
+ all_list = left_list[::-1] + sorted_list + right_list
+ return all_list
+
+
+def generate_pivot_list_curved(p_score,
+ p_char_maps,
+ f_direction,
+ score_thresh=0.5,
+ is_expand=True,
+ is_backbone=False,
+ image_id=0):
+ """
+    Return the center points and end points of each TCL instance, filtered
+    with the character maps.
+ """
+ p_score = p_score[0]
+ f_direction = f_direction.transpose(1, 2, 0)
+ p_tcl_map = (p_score > score_thresh) * 1.0
+ skeleton_map = thin(p_tcl_map)
+ instance_count, instance_label_map = cv2.connectedComponents(
+ skeleton_map.astype(np.uint8), connectivity=8)
+
+ # get TCL Instance
+ all_pos_yxs = []
+ center_pos_yxs = []
+ end_points_yxs = []
+ instance_center_pos_yxs = []
+ if instance_count > 0:
+ for instance_id in range(1, instance_count):
+ pos_list = []
+ ys, xs = np.where(instance_label_map == instance_id)
+ pos_list = list(zip(ys, xs))
+
+ ### FIX-ME, eliminate outlier
+ if len(pos_list) < 3:
+ continue
+
+ if is_expand:
+ pos_list_sorted = sort_and_expand_with_direction_v2(
+ pos_list, f_direction, p_tcl_map)
+ else:
+ pos_list_sorted, _ = sort_with_direction(pos_list, f_direction)
+ all_pos_yxs.append(pos_list_sorted)
+
+    # use decoder to filter background points.
+ p_char_maps = p_char_maps.transpose([1, 2, 0])
+ decode_res = ctc_decoder_for_image(
+ all_pos_yxs, logits_map=p_char_maps, keep_blank_in_idxs=True)
+ for decoded_str, keep_yxs_list in decode_res:
+ if is_backbone:
+ keep_yxs_list_with_id = add_id(keep_yxs_list, image_id=image_id)
+ instance_center_pos_yxs.append(keep_yxs_list_with_id)
+ else:
+ end_points_yxs.extend((keep_yxs_list[0], keep_yxs_list[-1]))
+ center_pos_yxs.extend(keep_yxs_list)
+
+ if is_backbone:
+ return instance_center_pos_yxs
+ else:
+ return center_pos_yxs, end_points_yxs
+
+
+def generate_pivot_list_horizontal(p_score,
+ p_char_maps,
+ f_direction,
+ score_thresh=0.5,
+ is_backbone=False,
+ image_id=0):
+ """
+    Return the center points and end points of each TCL instance, filtered
+    with the character maps.
+ """
+ p_score = p_score[0]
+ f_direction = f_direction.transpose(1, 2, 0)
+ p_tcl_map_bi = (p_score > score_thresh) * 1.0
+ instance_count, instance_label_map = cv2.connectedComponents(
+ p_tcl_map_bi.astype(np.uint8), connectivity=8)
+
+ # get TCL Instance
+ all_pos_yxs = []
+ center_pos_yxs = []
+ end_points_yxs = []
+ instance_center_pos_yxs = []
+
+ if instance_count > 0:
+ for instance_id in range(1, instance_count):
+ pos_list = []
+ ys, xs = np.where(instance_label_map == instance_id)
+ pos_list = list(zip(ys, xs))
+
+ ### FIX-ME, eliminate outlier
+ if len(pos_list) < 5:
+ continue
+
+ # add rule here
+ main_direction = extract_main_direction(pos_list,
+ f_direction) # y x
+            reference_direction = np.array([0, 1]).reshape([-1, 2])  # y x
+            is_h_angle = abs(np.sum(
+                main_direction * reference_direction)) < math.cos(math.pi /
+                                                                  180 * 70)
+
+ point_yxs = np.array(pos_list)
+ max_y, max_x = np.max(point_yxs, axis=0)
+ min_y, min_x = np.min(point_yxs, axis=0)
+ is_h_len = (max_y - min_y) < 1.5 * (max_x - min_x)
+
+ pos_list_final = []
+ if is_h_len:
+ xs = np.unique(xs)
+ for x in xs:
+ ys = instance_label_map[:, x].copy().reshape((-1, ))
+ y = int(np.where(ys == instance_id)[0].mean())
+ pos_list_final.append((y, x))
+ else:
+ ys = np.unique(ys)
+ for y in ys:
+ xs = instance_label_map[y, :].copy().reshape((-1, ))
+ x = int(np.where(xs == instance_id)[0].mean())
+ pos_list_final.append((y, x))
+
+ pos_list_sorted, _ = sort_with_direction(pos_list_final,
+ f_direction)
+ all_pos_yxs.append(pos_list_sorted)
+
+    # use decoder to filter background points.
+ p_char_maps = p_char_maps.transpose([1, 2, 0])
+ decode_res = ctc_decoder_for_image(
+ all_pos_yxs, logits_map=p_char_maps, keep_blank_in_idxs=True)
+ for decoded_str, keep_yxs_list in decode_res:
+ if is_backbone:
+ keep_yxs_list_with_id = add_id(keep_yxs_list, image_id=image_id)
+ instance_center_pos_yxs.append(keep_yxs_list_with_id)
+ else:
+ end_points_yxs.extend((keep_yxs_list[0], keep_yxs_list[-1]))
+ center_pos_yxs.extend(keep_yxs_list)
+
+ if is_backbone:
+ return instance_center_pos_yxs
+ else:
+ return center_pos_yxs, end_points_yxs
+
+
+def generate_pivot_list(p_score,
+ p_char_maps,
+ f_direction,
+ score_thresh=0.5,
+ is_backbone=False,
+ is_curved=True,
+ image_id=0):
+ """
+    Wrap all the functions together.
+ """
+ if is_curved:
+ return generate_pivot_list_curved(
+ p_score,
+ p_char_maps,
+ f_direction,
+ score_thresh=score_thresh,
+ is_expand=True,
+ is_backbone=is_backbone,
+ image_id=image_id)
+ else:
+ return generate_pivot_list_horizontal(
+ p_score,
+ p_char_maps,
+ f_direction,
+ score_thresh=score_thresh,
+ is_backbone=is_backbone,
+ image_id=image_id)
+
+
+# for refine module
+def extract_main_direction(pos_list, f_direction):
+ """
+ f_direction: h x w x 2
+ pos_list: [[y, x], [y, x], [y, x] ...]
+ """
+ pos_list = np.array(pos_list)
+ point_direction = f_direction[pos_list[:, 0], pos_list[:, 1]]
+ point_direction = point_direction[:, ::-1] # x, y -> y, x
+ average_direction = np.mean(point_direction, axis=0, keepdims=True)
+ average_direction = average_direction / (
+ np.linalg.norm(average_direction) + 1e-6)
+ return average_direction
+
+
+def sort_by_direction_with_image_id_deprecated(pos_list, f_direction):
+ """
+ f_direction: h x w x 2
+ pos_list: [[id, y, x], [id, y, x], [id, y, x] ...]
+ """
+ pos_list_full = np.array(pos_list).reshape(-1, 3)
+ pos_list = pos_list_full[:, 1:]
+ point_direction = f_direction[pos_list[:, 0], pos_list[:, 1]] # x, y
+ point_direction = point_direction[:, ::-1] # x, y -> y, x
+ average_direction = np.mean(point_direction, axis=0, keepdims=True)
+ pos_proj_leng = np.sum(pos_list * average_direction, axis=1)
+ sorted_list = pos_list_full[np.argsort(pos_proj_leng)].tolist()
+ return sorted_list
+
+
+def sort_by_direction_with_image_id(pos_list, f_direction):
+ """
+ f_direction: h x w x 2
+ pos_list: [[y, x], [y, x], [y, x] ...]
+ """
+
+ def sort_part_with_direction(pos_list_full, point_direction):
+ pos_list_full = np.array(pos_list_full).reshape(-1, 3)
+ pos_list = pos_list_full[:, 1:]
+ point_direction = np.array(point_direction).reshape(-1, 2)
+ average_direction = np.mean(point_direction, axis=0, keepdims=True)
+ pos_proj_leng = np.sum(pos_list * average_direction, axis=1)
+ sorted_list = pos_list_full[np.argsort(pos_proj_leng)].tolist()
+ sorted_direction = point_direction[np.argsort(pos_proj_leng)].tolist()
+ return sorted_list, sorted_direction
+
+ pos_list = np.array(pos_list).reshape(-1, 3)
+ point_direction = f_direction[pos_list[:, 1], pos_list[:, 2]] # x, y
+ point_direction = point_direction[:, ::-1] # x, y -> y, x
+ sorted_point, sorted_direction = sort_part_with_direction(pos_list,
+ point_direction)
+
+ point_num = len(sorted_point)
+ if point_num >= 16:
+ middle_num = point_num // 2
+ first_part_point = sorted_point[:middle_num]
+ first_point_direction = sorted_direction[:middle_num]
+        sorted_first_part_point, sorted_first_part_direction = sort_part_with_direction(
+ first_part_point, first_point_direction)
+
+ last_part_point = sorted_point[middle_num:]
+ last_point_direction = sorted_direction[middle_num:]
+ sorted_last_part_point, sorted_last_part_direction = sort_part_with_direction(
+ last_part_point, last_point_direction)
+        sorted_point = sorted_first_part_point + sorted_last_part_point
+        sorted_direction = sorted_first_part_direction + sorted_last_part_direction
+
+ return sorted_point
+
+
+def generate_pivot_list_tt_inference(p_score,
+ p_char_maps,
+ f_direction,
+ score_thresh=0.5,
+ is_backbone=False,
+ is_curved=True,
+ image_id=0):
+ """
+    Return the center points and end points of each TCL instance, filtered
+    with the character maps.
+ """
+ p_score = p_score[0]
+ f_direction = f_direction.transpose(1, 2, 0)
+ p_tcl_map = (p_score > score_thresh) * 1.0
+ skeleton_map = thin(p_tcl_map)
+ instance_count, instance_label_map = cv2.connectedComponents(
+ skeleton_map.astype(np.uint8), connectivity=8)
+
+ # get TCL Instance
+ all_pos_yxs = []
+ if instance_count > 0:
+ for instance_id in range(1, instance_count):
+ pos_list = []
+ ys, xs = np.where(instance_label_map == instance_id)
+ pos_list = list(zip(ys, xs))
+ ### FIX-ME, eliminate outlier
+ if len(pos_list) < 3:
+ continue
+ pos_list_sorted = sort_and_expand_with_direction_v2(
+ pos_list, f_direction, p_tcl_map)
+ pos_list_sorted_with_id = add_id(pos_list_sorted, image_id=image_id)
+ all_pos_yxs.append(pos_list_sorted_with_id)
+ return all_pos_yxs
diff --git a/ppocr/utils/e2e_utils/visual.py b/ppocr/utils/e2e_utils/visual.py
new file mode 100644
index 0000000000000000000000000000000000000000..e6e4fd0667dbf4a42dbc0fd9bf26e6fd91be0d82
--- /dev/null
+++ b/ppocr/utils/e2e_utils/visual.py
@@ -0,0 +1,162 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import numpy as np
+import cv2
+import time
+
+
+def resize_image(im, max_side_len=512):
+ """
+    Resize the image so that both sides are multiples of max_stride, as
+    required by the network.
+    :param im: the input image
+    :param max_side_len: limit on the max side length to avoid GPU OOM
+    :return: the resized image and the resize ratios (ratio_h, ratio_w)
+ """
+ h, w, _ = im.shape
+
+ resize_w = w
+ resize_h = h
+
+ if resize_h > resize_w:
+ ratio = float(max_side_len) / resize_h
+ else:
+ ratio = float(max_side_len) / resize_w
+
+ resize_h = int(resize_h * ratio)
+ resize_w = int(resize_w * ratio)
+
+ max_stride = 128
+ resize_h = (resize_h + max_stride - 1) // max_stride * max_stride
+ resize_w = (resize_w + max_stride - 1) // max_stride * max_stride
+ im = cv2.resize(im, (int(resize_w), int(resize_h)))
+ ratio_h = resize_h / float(h)
+ ratio_w = resize_w / float(w)
+
+ return im, (ratio_h, ratio_w)
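+# e.g. a 720 x 1280 (h x w) image with max_side_len=768 is scaled by
+# 768 / 1280 = 0.6 to 432 x 768, then rounded up to 512 x 768.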
+
+
+def resize_image_min(im, max_side_len=512):
+ """
+ """
+ h, w, _ = im.shape
+
+ resize_w = w
+ resize_h = h
+
+ if resize_h < resize_w:
+ ratio = float(max_side_len) / resize_h
+ else:
+ ratio = float(max_side_len) / resize_w
+
+ resize_h = int(resize_h * ratio)
+ resize_w = int(resize_w * ratio)
+
+ max_stride = 128
+ resize_h = (resize_h + max_stride - 1) // max_stride * max_stride
+ resize_w = (resize_w + max_stride - 1) // max_stride * max_stride
+ im = cv2.resize(im, (int(resize_w), int(resize_h)))
+ ratio_h = resize_h / float(h)
+ ratio_w = resize_w / float(w)
+ return im, (ratio_h, ratio_w)
+
+
+def resize_image_for_totaltext(im, max_side_len=512):
+ """
+ """
+ h, w, _ = im.shape
+
+ resize_w = w
+ resize_h = h
+ ratio = 1.25
+ if h * ratio > max_side_len:
+ ratio = float(max_side_len) / resize_h
+
+ resize_h = int(resize_h * ratio)
+ resize_w = int(resize_w * ratio)
+
+ max_stride = 128
+ resize_h = (resize_h + max_stride - 1) // max_stride * max_stride
+ resize_w = (resize_w + max_stride - 1) // max_stride * max_stride
+ im = cv2.resize(im, (int(resize_w), int(resize_h)))
+ ratio_h = resize_h / float(h)
+ ratio_w = resize_w / float(w)
+ return im, (ratio_h, ratio_w)
+
+
+def point_pair2poly(point_pair_list):
+ """
+    Transfer vertical point pairs into polygon points in clockwise order.
+ """
+ pair_length_list = []
+ for point_pair in point_pair_list:
+ pair_length = np.linalg.norm(point_pair[0] - point_pair[1])
+ pair_length_list.append(pair_length)
+ pair_length_list = np.array(pair_length_list)
+ pair_info = (pair_length_list.max(), pair_length_list.min(),
+ pair_length_list.mean())
+
+ point_num = len(point_pair_list) * 2
+ point_list = [0] * point_num
+ for idx, point_pair in enumerate(point_pair_list):
+ point_list[idx] = point_pair[0]
+ point_list[point_num - 1 - idx] = point_pair[1]
+ return np.array(point_list).reshape(-1, 2), pair_info
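+# e.g. three point pairs [(t0, b0), (t1, b1), (t2, b2)] unroll clockwise
+# into [t0, t1, t2, b2, b1, b0].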
+
+
+def shrink_quad_along_width(quad, begin_width_ratio=0., end_width_ratio=1.):
+ """
+ Generate shrink_quad_along_width.
+ """
+ ratio_pair = np.array(
+ [[begin_width_ratio], [end_width_ratio]], dtype=np.float32)
+ p0_1 = quad[0] + (quad[1] - quad[0]) * ratio_pair
+ p3_2 = quad[3] + (quad[2] - quad[3]) * ratio_pair
+ return np.array([p0_1[0], p0_1[1], p3_2[1], p3_2[0]])
+
+
+def expand_poly_along_width(poly, shrink_ratio_of_width=0.3):
+ """
+ expand poly along width.
+ """
+ point_num = poly.shape[0]
+ left_quad = np.array(
+ [poly[0], poly[1], poly[-2], poly[-1]], dtype=np.float32)
+ left_ratio = -shrink_ratio_of_width * np.linalg.norm(left_quad[0] - left_quad[3]) / \
+ (np.linalg.norm(left_quad[0] - left_quad[1]) + 1e-6)
+ left_quad_expand = shrink_quad_along_width(left_quad, left_ratio, 1.0)
+ right_quad = np.array(
+ [
+ poly[point_num // 2 - 2], poly[point_num // 2 - 1],
+ poly[point_num // 2], poly[point_num // 2 + 1]
+ ],
+ dtype=np.float32)
+ right_ratio = 1.0 + \
+ shrink_ratio_of_width * np.linalg.norm(right_quad[0] - right_quad[3]) / \
+ (np.linalg.norm(right_quad[0] - right_quad[1]) + 1e-6)
+ right_quad_expand = shrink_quad_along_width(right_quad, 0.0, right_ratio)
+ poly[0] = left_quad_expand[0]
+ poly[-1] = left_quad_expand[-1]
+ poly[point_num // 2 - 1] = right_quad_expand[1]
+ poly[point_num // 2] = right_quad_expand[2]
+ return poly
+
+
+def norm2(x, axis=None):
+ if axis:
+ return np.sqrt(np.sum(x**2, axis=axis))
+ return np.sqrt(np.sum(x**2))
+
+
+def cos(p1, p2):
+ return (p1 * p2).sum() / (norm2(p1) * norm2(p2))
diff --git a/requirements.txt b/requirements.txt
index 2401d52b48c10bad5ea5b244a0fd4c4365b94f09..1b01e690f77d2bf5e570c86b268c7128c8bf79fb 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -7,4 +7,5 @@ opencv-python==4.2.0.32
tqdm
numpy
visualdl
-python-Levenshtein
\ No newline at end of file
+python-Levenshtein
+opencv-contrib-python
\ No newline at end of file
diff --git a/setup.py b/setup.py
index 70400df484128ba751da5f97503cc7f84e260d86..d491adb17e6251355c0190d0ddecb9a82b09bc2e 100644
--- a/setup.py
+++ b/setup.py
@@ -32,7 +32,7 @@ setup(
package_dir={'paddleocr': ''},
include_package_data=True,
entry_points={"console_scripts": ["paddleocr= paddleocr.paddleocr:main"]},
- version='2.0.3',
+ version='2.0.4',
install_requires=requirements,
license='Apache License 2.0',
description='Awesome OCR toolkits based on PaddlePaddle (8.6M ultra-lightweight pre-trained model, support training and deployment among server, mobile, embeded and IoT devices',
diff --git a/tools/infer/predict_e2e.py b/tools/infer/predict_e2e.py
new file mode 100755
index 0000000000000000000000000000000000000000..a5c57914173b7d44c9479f7bb120e4ff409b91e3
--- /dev/null
+++ b/tools/infer/predict_e2e.py
@@ -0,0 +1,158 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import os
+import sys
+
+__dir__ = os.path.dirname(os.path.abspath(__file__))
+sys.path.append(__dir__)
+sys.path.append(os.path.abspath(os.path.join(__dir__, '../..')))
+
+os.environ["FLAGS_allocator_strategy"] = 'auto_growth'
+
+import cv2
+import numpy as np
+import time
+import sys
+
+import tools.infer.utility as utility
+from ppocr.utils.logging import get_logger
+from ppocr.utils.utility import get_image_file_list, check_and_read_gif
+from ppocr.data import create_operators, transform
+from ppocr.postprocess import build_post_process
+
+logger = get_logger()
+
+
+class TextE2E(object):
+ def __init__(self, args):
+ self.args = args
+ self.e2e_algorithm = args.e2e_algorithm
+ pre_process_list = [{
+ 'E2EResizeForTest': {}
+ }, {
+ 'NormalizeImage': {
+ 'std': [0.229, 0.224, 0.225],
+ 'mean': [0.485, 0.456, 0.406],
+ 'scale': '1./255.',
+ 'order': 'hwc'
+ }
+ }, {
+ 'ToCHWImage': None
+ }, {
+ 'KeepKeys': {
+ 'keep_keys': ['image', 'shape']
+ }
+ }]
+ postprocess_params = {}
+ if self.e2e_algorithm == "PGNet":
+ pre_process_list[0] = {
+ 'E2EResizeForTest': {
+ 'max_side_len': args.e2e_limit_side_len,
+ 'valid_set': 'totaltext'
+ }
+ }
+ postprocess_params['name'] = 'PGPostProcess'
+ postprocess_params["score_thresh"] = args.e2e_pgnet_score_thresh
+ postprocess_params["character_dict_path"] = args.e2e_char_dict_path
+ postprocess_params["valid_set"] = args.e2e_pgnet_valid_set
+ self.e2e_pgnet_polygon = args.e2e_pgnet_polygon
+ else:
+ logger.info("unknown e2e_algorithm:{}".format(self.e2e_algorithm))
+ sys.exit(0)
+
+ self.preprocess_op = create_operators(pre_process_list)
+ self.postprocess_op = build_post_process(postprocess_params)
+ self.predictor, self.input_tensor, self.output_tensors = utility.create_predictor(
+ args, 'e2e', logger) # paddle.jit.load(args.det_model_dir)
+ # self.predictor.eval()
+
+ def clip_det_res(self, points, img_height, img_width):
+ for pno in range(points.shape[0]):
+ points[pno, 0] = int(min(max(points[pno, 0], 0), img_width - 1))
+ points[pno, 1] = int(min(max(points[pno, 1], 0), img_height - 1))
+ return points
+
+ def filter_tag_det_res_only_clip(self, dt_boxes, image_shape):
+ img_height, img_width = image_shape[0:2]
+ dt_boxes_new = []
+ for box in dt_boxes:
+ box = self.clip_det_res(box, img_height, img_width)
+ dt_boxes_new.append(box)
+ dt_boxes = np.array(dt_boxes_new)
+ return dt_boxes
+
+ def __call__(self, img):
+
+ ori_im = img.copy()
+ data = {'image': img}
+ data = transform(data, self.preprocess_op)
+ img, shape_list = data
+ if img is None:
+            return None, None, 0
+ img = np.expand_dims(img, axis=0)
+ shape_list = np.expand_dims(shape_list, axis=0)
+ img = img.copy()
+ starttime = time.time()
+
+ self.input_tensor.copy_from_cpu(img)
+ self.predictor.run()
+ outputs = []
+ for output_tensor in self.output_tensors:
+ output = output_tensor.copy_to_cpu()
+ outputs.append(output)
+
+ preds = {}
+ if self.e2e_algorithm == 'PGNet':
+ preds['f_border'] = outputs[0]
+ preds['f_char'] = outputs[1]
+ preds['f_direction'] = outputs[2]
+ preds['f_score'] = outputs[3]
+ else:
+ raise NotImplementedError
+ post_result = self.postprocess_op(preds, shape_list)
+ points, strs = post_result['points'], post_result['strs']
+ dt_boxes = self.filter_tag_det_res_only_clip(points, ori_im.shape)
+ elapse = time.time() - starttime
+ return dt_boxes, strs, elapse
+
+
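+# Example usage (paths are illustrative):
+#   python3 tools/infer/predict_e2e.py --e2e_algorithm="PGNet" \
+#       --image_dir="./doc/imgs_en/" --e2e_model_dir="./inference/e2e_pgnet/"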
+if __name__ == "__main__":
+ args = utility.parse_args()
+ image_file_list = get_image_file_list(args.image_dir)
+ text_detector = TextE2E(args)
+ count = 0
+ total_time = 0
+ draw_img_save = "./inference_results"
+ if not os.path.exists(draw_img_save):
+ os.makedirs(draw_img_save)
+ for image_file in image_file_list:
+ img, flag = check_and_read_gif(image_file)
+ if not flag:
+ img = cv2.imread(image_file)
+ if img is None:
+ logger.info("error in loading image:{}".format(image_file))
+ continue
+ points, strs, elapse = text_detector(img)
+ if count > 0:
+ total_time += elapse
+ count += 1
+ logger.info("Predict time of {}: {}".format(image_file, elapse))
+ src_im = utility.draw_e2e_res(points, strs, image_file)
+ img_name_pure = os.path.split(image_file)[-1]
+ img_path = os.path.join(draw_img_save,
+ "e2e_res_{}".format(img_name_pure))
+ cv2.imwrite(img_path, src_im)
+ logger.info("The visualized image saved in {}".format(img_path))
+ if count > 1:
+ logger.info("Avg Time: {}".format(total_time / (count - 1)))
diff --git a/tools/infer/utility.py b/tools/infer/utility.py
index 911ca7fcb727d9f04002f367cda8928a2aefde90..9019f003b44d9ecb69ed390fba8cc97d4d074cd5 100755
--- a/tools/infer/utility.py
+++ b/tools/infer/utility.py
@@ -74,6 +74,19 @@ def parse_args():
"--vis_font_path", type=str, default="./doc/fonts/simfang.ttf")
parser.add_argument("--drop_score", type=float, default=0.5)
+ # params for e2e
+ parser.add_argument("--e2e_algorithm", type=str, default='PGNet')
+ parser.add_argument("--e2e_model_dir", type=str)
+ parser.add_argument("--e2e_limit_side_len", type=float, default=768)
+ parser.add_argument("--e2e_limit_type", type=str, default='max')
+
+ # PGNet parmas
+ parser.add_argument("--e2e_pgnet_score_thresh", type=float, default=0.5)
+ parser.add_argument(
+ "--e2e_char_dict_path", type=str, default="./ppocr/utils/ic15_dict.txt")
+ parser.add_argument("--e2e_pgnet_valid_set", type=str, default='totaltext')
+ parser.add_argument("--e2e_pgnet_polygon", type=bool, default=True)
+
# params for text classifier
parser.add_argument("--use_angle_cls", type=str2bool, default=False)
parser.add_argument("--cls_model_dir", type=str)
@@ -93,8 +106,10 @@ def create_predictor(args, mode, logger):
model_dir = args.det_model_dir
elif mode == 'cls':
model_dir = args.cls_model_dir
- else:
+ elif mode == 'rec':
model_dir = args.rec_model_dir
+ else:
+ model_dir = args.e2e_model_dir
if model_dir is None:
logger.info("not find {} model file path {}".format(mode, model_dir))
@@ -148,6 +163,22 @@ def create_predictor(args, mode, logger):
return predictor, input_tensor, output_tensors
+def draw_e2e_res(dt_boxes, strs, img_path):
+ src_im = cv2.imread(img_path)
+    for box, text in zip(dt_boxes, strs):
+ box = box.astype(np.int32).reshape((-1, 1, 2))
+ cv2.polylines(src_im, [box], True, color=(255, 255, 0), thickness=2)
+ cv2.putText(
+ src_im,
+            text,
+ org=(int(box[0, 0, 0]), int(box[0, 0, 1])),
+ fontFace=cv2.FONT_HERSHEY_COMPLEX,
+ fontScale=0.7,
+ color=(0, 255, 0),
+ thickness=1)
+ return src_im
+
+
def draw_text_det_res(dt_boxes, img_path):
src_im = cv2.imread(img_path)
for box in dt_boxes:
diff --git a/tools/infer_e2e.py b/tools/infer_e2e.py
new file mode 100755
index 0000000000000000000000000000000000000000..b7503adb94eb797d4fb12cf47b377fa72d02158b
--- /dev/null
+++ b/tools/infer_e2e.py
@@ -0,0 +1,122 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+import os
+import sys
+
+__dir__ = os.path.dirname(os.path.abspath(__file__))
+sys.path.append(__dir__)
+sys.path.append(os.path.abspath(os.path.join(__dir__, '..')))
+
+os.environ["FLAGS_allocator_strategy"] = 'auto_growth'
+
+import cv2
+import json
+import paddle
+
+from ppocr.data import create_operators, transform
+from ppocr.modeling.architectures import build_model
+from ppocr.postprocess import build_post_process
+from ppocr.utils.save_load import init_model
+from ppocr.utils.utility import get_image_file_list
+import tools.program as program
+
+
+def draw_e2e_res(dt_boxes, strs, config, img, img_name):
+ if len(dt_boxes) > 0:
+ src_im = img
+        for box, text in zip(dt_boxes, strs):
+ box = box.astype(np.int32).reshape((-1, 1, 2))
+ cv2.polylines(src_im, [box], True, color=(255, 255, 0), thickness=2)
+ cv2.putText(
+ src_im,
+                text,
+ org=(int(box[0, 0, 0]), int(box[0, 0, 1])),
+ fontFace=cv2.FONT_HERSHEY_COMPLEX,
+ fontScale=0.7,
+ color=(0, 255, 0),
+ thickness=1)
+ save_det_path = os.path.dirname(config['Global'][
+ 'save_res_path']) + "/e2e_results/"
+ if not os.path.exists(save_det_path):
+ os.makedirs(save_det_path)
+ save_path = os.path.join(save_det_path, os.path.basename(img_name))
+ cv2.imwrite(save_path, src_im)
+ logger.info("The e2e Image saved in {}".format(save_path))
+
+
+def main():
+ global_config = config['Global']
+
+ # build model
+ model = build_model(config['Architecture'])
+
+ init_model(config, model, logger)
+
+ # build post process
+ post_process_class = build_post_process(config['PostProcess'],
+ global_config)
+
+ # create data ops
+ transforms = []
+ for op in config['Eval']['dataset']['transforms']:
+ op_name = list(op)[0]
+ if 'Label' in op_name:
+ continue
+ elif op_name == 'KeepKeys':
+ op[op_name]['keep_keys'] = ['image', 'shape']
+ transforms.append(op)
+
+ ops = create_operators(transforms, global_config)
+
+ save_res_path = config['Global']['save_res_path']
+ if not os.path.exists(os.path.dirname(save_res_path)):
+ os.makedirs(os.path.dirname(save_res_path))
+
+ model.eval()
+ with open(save_res_path, "wb") as fout:
+ for file in get_image_file_list(config['Global']['infer_img']):
+ logger.info("infer_img: {}".format(file))
+ with open(file, 'rb') as f:
+ img = f.read()
+ data = {'image': img}
+ batch = transform(data, ops)
+ images = np.expand_dims(batch[0], axis=0)
+ shape_list = np.expand_dims(batch[1], axis=0)
+ images = paddle.to_tensor(images)
+ preds = model(images)
+ post_result = post_process_class(preds, shape_list)
+ points, strs = post_result['points'], post_result['strs']
+                # write result
+ dt_boxes_json = []
+                for poly, text in zip(points, strs):
+                    tmp_json = {"transcription": text}
+ tmp_json['points'] = poly.tolist()
+ dt_boxes_json.append(tmp_json)
+ otstr = file + "\t" + json.dumps(dt_boxes_json) + "\n"
+ fout.write(otstr.encode())
+ src_img = cv2.imread(file)
+ draw_e2e_res(points, strs, config, src_img, file)
+ logger.info("success!")
+
+
+if __name__ == '__main__':
+ config, device, logger, vdl_writer = program.preprocess()
+ main()
diff --git a/tools/program.py b/tools/program.py
index cff43102d29e65c5375c001e7ee91ecf309132bc..c22bf18b991a8aed6d47a1ea242aa3b7bb02aacc 100755
--- a/tools/program.py
+++ b/tools/program.py
@@ -375,7 +375,8 @@ def preprocess(is_train=False):
alg = config['Architecture']['algorithm']
assert alg in [
- 'EAST', 'DB', 'SAST', 'Rosetta', 'CRNN', 'STARNet', 'RARE', 'SRN', 'CLS'
+ 'EAST', 'DB', 'SAST', 'Rosetta', 'CRNN', 'STARNet', 'RARE', 'SRN',
+ 'CLS', 'PGNet'
]
device = 'gpu:{}'.format(dist.ParallelEnv().dev_id) if use_gpu else 'cpu'