diff --git a/README.md b/README.md
index f5bc7406b9649730e43e9fe5cdb5b71eba7dc3aa..06343f5d0bbe7479e2207d88a6e29f8e7fb4215e 100644
--- a/README.md
+++ b/README.md
@@ -13,7 +13,6 @@ English | [简体中文](README_ch.md)
-
@@ -24,7 +23,8 @@ PaddleOCR aims to create multilingual, awesome, leading, and practical OCR tools
**Recent updates**
-
+- 2021.12.21 The OCR open source online course starts on December 21st. Lessons begin at 8:30 pm every night and run for ten days. Free registration: https://aistudio.baidu.com/aistudio/course/introduce/25207
+- 2021.12.21 Release PaddleOCR v2.4, which adds 1 text detection algorithm (PSENet), 3 text recognition algorithms (NRTR, SEED, SAR), 1 key information extraction algorithm (SDMGR) and 3 DocVQA algorithms (LayoutLM, LayoutLMv2, LayoutXLM).
- PaddleOCR R&D team would like to share the key points of PP-OCRv2, at 20:15 pm on September 8th, [Course Address](https://aistudio.baidu.com/aistudio/education/group/info/6758).
- 2021.9.7 release PaddleOCR v2.3, [PP-OCRv2](#PP-OCRv2) is proposed. The inference speed of PP-OCRv2 is 220% higher than that of PP-OCR server in CPU device. The F-score of PP-OCRv2 is 7% higher than that of PP-OCR mobile.
- 2021.8.3 released PaddleOCR v2.2, add a new structured documents analysis toolkit, i.e., [PP-Structure](https://github.com/PaddlePaddle/PaddleOCR/blob/release/2.2/ppstructure/README.md), support layout analysis and table recognition (One-key to export chart images to Excel files).
@@ -38,7 +38,11 @@ PaddleOCR aims to create multilingual, awesome, leading, and practical OCR tools
- Ultra lightweight PP-OCR mobile series models: detection (3.0M) + direction classifier (1.4M) + recognition (5.0M) = 9.4M
- General PP-OCR server series models: detection (47.1M) + direction classifier (1.4M) + recognition (94.9M) = 143.4M
- Support Chinese, English, and digit recognition, vertical text recognition, and long text recognition
- - Support multi-language recognition: Korean, Japanese, German, French
+    - Support multi-language recognition: about 80 languages such as Korean, Japanese, German, and French
+- Document structure analysis system PP-Structure
+  - Support layout analysis and table recognition (with one-click export to Excel)
+  - Support key information extraction
+  - Support DocVQA
- Rich toolkits related to the OCR areas
- Semi-automatic data annotation tool, i.e., PPOCRLabel: support fast and efficient data annotation
- Data synthesis tool, i.e., Style-Text: easy to synthesize a large number of images which are similar to the target scene image
diff --git a/README_ch.md b/README_ch.md
index ca757a6356ec44232bc99bbadfa0d0839d751bb5..fae2a50f318b0282009f049373d5e3cf97e407f2 100755
--- a/README_ch.md
+++ b/README_ch.md
@@ -9,7 +9,6 @@
-
@@ -20,11 +19,13 @@ PaddleOCR旨在打造一套丰富、领先、且实用的OCR工具库,助力
## 近期更新
+- 2021.12.21 《OCR十讲》课程开讲,12月21日起每晚八点半线上授课! 【免费】报名地址:https://aistudio.baidu.com/aistudio/course/introduce/25207
+- 2021.12.21 发布PaddleOCR v2.4。OCR算法新增1种文本检测算法(PSENet),3种文本识别算法(NRTR、SEED、SAR);文档结构化算法新增1种关键信息提取算法(SDMGR),3种DocVQA算法(LayoutLM、LayoutLMv2,LayoutXLM)。
- PaddleOCR研发团队对最新发版内容技术深入解读,9月8日晚上20:15,[课程回放](https://aistudio.baidu.com/aistudio/education/group/info/6758)。
- 2021.9.7 发布PaddleOCR v2.3与[PP-OCRv2](#PP-OCRv2),CPU推理速度相比于PP-OCR server提升220%;效果相比于PP-OCR mobile 提升7%。
- 2021.8.3 发布PaddleOCR v2.2,新增文档结构分析[PP-Structure](https://github.com/PaddlePaddle/PaddleOCR/blob/release/2.2/ppstructure/README_ch.md)工具包,支持版面分析与表格识别(含Excel导出)。
-> 完整PaddleOCR更新时间线可参考[文档](https://github.com/PaddlePaddle/PaddleOCR/blob/release/2.2/doc/doc_ch/update.md)。
+> [更多](./doc/doc_ch/update.md)
## 特性
@@ -33,11 +34,14 @@ PaddleOCR旨在打造一套丰富、领先、且实用的OCR工具库,助力
- 超轻量PP-OCR mobile移动端系列:检测(3.0M)+方向分类器(1.4M)+ 识别(5.0M)= 9.4M
- 通用PPOCR server系列:检测(47.1M)+方向分类器(1.4M)+ 识别(94.9M)= 143.4M
- 支持中英文数字组合识别、竖排文本识别、长文本识别
- - 支持多语言识别:韩语、日语、德语、法语等
+ - 支持多语言识别:韩语、日语、德语、法语等约80种语言
+- PP-Structure文档结构化系统
+ - 支持版面分析与表格识别(含Excel导出)
+ - 支持关键信息提取任务
+ - 支持DocVQA任务
- 丰富易用的OCR相关工具组件
- 半自动数据标注工具PPOCRLabel:支持快速高效的数据标注
- 数据合成工具Style-Text:批量合成大量与目标场景类似的图像
- - 文档分析能力PP-Structure:支持版面分析与表格识别(含Excel导出)
- 支持用户自定义训练,提供丰富的预测推理部署方案
- 支持PIP快速安装使用
- 可运行于Linux、Windows、MacOS等多种系统
@@ -56,6 +60,7 @@ PaddleOCR旨在打造一套丰富、领先、且实用的OCR工具库,助力
+
## 零代码体验
- 在线网站体验:超轻量PP-OCR mobile模型体验地址:https://www.paddlepaddle.org.cn/hub/scene/ocr
diff --git a/configs/kie/kie_unet_sdmgr.yml b/configs/kie/kie_unet_sdmgr.yml
new file mode 100644
index 0000000000000000000000000000000000000000..a6968aaa3aa7a717a848416efc5ccc567f774b4d
--- /dev/null
+++ b/configs/kie/kie_unet_sdmgr.yml
@@ -0,0 +1,111 @@
+Global:
+ use_gpu: True
+ epoch_num: 60
+ log_smooth_window: 20
+ print_batch_step: 50
+ save_model_dir: ./output/kie_5/
+ save_epoch_step: 50
+  # evaluation is run every 80 iterations
+ eval_batch_step: [ 0, 80 ]
+ # 1. If pretrained_model is saved in static mode, such as classification pretrained model
+ # from static branch, load_static_weights must be set as True.
+ # 2. If you want to finetune the pretrained models we provide in the docs,
+ # you should set load_static_weights as False.
+ load_static_weights: False
+ cal_metric_during_train: False
+ pretrained_model:
+ checkpoints:
+ save_inference_dir:
+ use_visualdl: False
+ class_path: ./train_data/wildreceipt/class_list.txt
+ infer_img: ./train_data/wildreceipt/1.txt
+ save_res_path: ./output/sdmgr_kie/predicts_kie.txt
+ img_scale: [ 1024, 512 ]
+
+Architecture:
+ model_type: kie
+ algorithm: SDMGR
+ Transform:
+ Backbone:
+ name: Kie_backbone
+ Head:
+ name: SDMGRHead
+
+Loss:
+ name: SDMGRLoss
+
+Optimizer:
+ name: Adam
+ beta1: 0.9
+ beta2: 0.999
+ lr:
+ name: Piecewise
+ learning_rate: 0.001
+ decay_epochs: [ 60, 80, 100]
+ values: [ 0.001, 0.0001, 0.00001]
+ warmup_epoch: 2
+ regularizer:
+ name: 'L2'
+ factor: 0.00005
+
+PostProcess:
+ name: None
+
+Metric:
+ name: KIEMetric
+ main_indicator: hmean
+
+Train:
+ dataset:
+ name: SimpleDataSet
+ data_dir: ./train_data/wildreceipt/
+ label_file_list: [ './train_data/wildreceipt/wildreceipt_train.txt' ]
+ ratio_list: [ 1.0 ]
+ transforms:
+ - DecodeImage: # load image
+ img_mode: RGB
+ channel_first: False
+ - NormalizeImage:
+ scale: 1
+ mean: [ 123.675, 116.28, 103.53 ]
+ std: [ 58.395, 57.12, 57.375 ]
+ order: 'hwc'
+ - KieLabelEncode: # Class handling label
+ character_dict_path: ./train_data/wildreceipt/dict.txt
+ - KieResize:
+ - ToCHWImage:
+ - KeepKeys:
+ keep_keys: [ 'image', 'relations', 'texts', 'points', 'labels', 'tag', 'shape'] # dataloader will return list in this order
+ loader:
+ shuffle: True
+ drop_last: False
+ batch_size_per_card: 4
+ num_workers: 4
+
+Eval:
+ dataset:
+ name: SimpleDataSet
+ data_dir: ./train_data/wildreceipt
+ label_file_list:
+ - ./train_data/wildreceipt/wildreceipt_test.txt
+ # - /paddle/data/PaddleOCR/train_data/wildreceipt/1.txt
+ transforms:
+ - DecodeImage: # load image
+ img_mode: RGB
+ channel_first: False
+ - KieLabelEncode: # Class handling label
+ character_dict_path: ./train_data/wildreceipt/dict.txt
+ - KieResize:
+ - NormalizeImage:
+ scale: 1
+ mean: [ 123.675, 116.28, 103.53 ]
+ std: [ 58.395, 57.12, 57.375 ]
+ order: 'hwc'
+ - ToCHWImage:
+ - KeepKeys:
+ keep_keys: [ 'image', 'relations', 'texts', 'points', 'labels', 'tag', 'ori_image', 'ori_boxes', 'shape']
+ loader:
+ shuffle: False
+ drop_last: False
+ batch_size_per_card: 1 # must be 1
+ num_workers: 4
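+
+# A hypothetical training launch for this config (assuming the wildreceipt data
+# has been prepared under ./train_data/wildreceipt/ as referenced above):
+#   python3 tools/train.py -c configs/kie/kie_unet_sdmgr.yml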
diff --git a/configs/rec/ch_ppocr_v2.0/rec_chinese_common_train_v2.0.yml b/configs/rec/ch_ppocr_v2.0/rec_chinese_common_train_v2.0.yml
index c76063d5cedc31985404ddfff5147e1e0c100d20..3e427b63ae25bdd3f935126443cc5710eb09f5f6 100644
--- a/configs/rec/ch_ppocr_v2.0/rec_chinese_common_train_v2.0.yml
+++ b/configs/rec/ch_ppocr_v2.0/rec_chinese_common_train_v2.0.yml
@@ -28,6 +28,7 @@ Optimizer:
lr:
name: Cosine
learning_rate: 0.001
+ warmup_epoch: 5
regularizer:
name: 'L2'
factor: 0.00004
diff --git a/configs/rec/ch_ppocr_v2.0/rec_chinese_lite_train_v2.0.yml b/configs/rec/ch_ppocr_v2.0/rec_chinese_lite_train_v2.0.yml
index 563ce110b865adabf320616227bdf8d2eb465c11..abd5cd9f45a14b1255ab585b18aca336dd48825d 100644
--- a/configs/rec/ch_ppocr_v2.0/rec_chinese_lite_train_v2.0.yml
+++ b/configs/rec/ch_ppocr_v2.0/rec_chinese_lite_train_v2.0.yml
@@ -28,6 +28,7 @@ Optimizer:
lr:
name: Cosine
learning_rate: 0.001
+ warmup_epoch: 5
regularizer:
name: 'L2'
factor: 0.00001
diff --git a/configs/rec/rec_resnet_stn_bilstm_att.yml b/configs/rec/rec_resnet_stn_bilstm_att.yml
index 0f599258d46e2ce89a6b7deccf8287a2ec0f7e4e..0bb90b35264b424c58a45685f5a2a066843298a6 100644
--- a/configs/rec/rec_resnet_stn_bilstm_att.yml
+++ b/configs/rec/rec_resnet_stn_bilstm_att.yml
@@ -75,7 +75,7 @@ Train:
channel_first: False
- SEEDLabelEncode: # Class handling label
- RecResizeImg:
- character_type: en
+ character_dict_path:
image_shape: [3, 64, 256]
padding: False
- KeepKeys:
@@ -96,7 +96,7 @@ Eval:
channel_first: False
- SEEDLabelEncode: # Class handling label
- RecResizeImg:
- character_type: en
+ character_dict_path:
image_shape: [3, 64, 256]
padding: False
- KeepKeys:
diff --git a/deploy/cpp_infer/readme.md b/deploy/cpp_infer/readme.md
index 92ef70b642dc1eaed9e694b1ae756f76ee548703..d901366235db21727ceac88528d83ae1120fd030 100644
--- a/deploy/cpp_infer/readme.md
+++ b/deploy/cpp_infer/readme.md
@@ -103,7 +103,7 @@ opencv3/
#### 1.2.1 直接下载安装
-* [Paddle预测库官网](https://www.paddlepaddle.org.cn/documentation/docs/zh/2.0/guides/05_inference_deployment/inference/build_and_install_lib_cn.html) 上提供了不同cuda版本的Linux预测库,可以在官网查看并选择合适的预测库版本(*建议选择paddle版本>=2.0.1版本的预测库* )。
+* [Paddle预测库官网](https://paddle-inference.readthedocs.io/en/latest/user_guides/download_lib.html) 上提供了不同cuda版本的Linux预测库,可以在官网查看并选择合适的预测库版本(*建议选择paddle版本>=2.0.1版本的预测库* )。
* 下载之后使用下面的方法解压。
@@ -119,7 +119,7 @@ tar -xf paddle_inference.tgz
```shell
git clone https://github.com/PaddlePaddle/Paddle.git
-git checkout release/2.1
+git checkout develop
```
* 进入Paddle目录后,编译方法如下。
diff --git a/deploy/cpp_infer/readme_en.md b/deploy/cpp_infer/readme_en.md
index fd6d953de1f9168da734d6c5eda945c670cfce37..4daa73453507959ea10e21a7383d03d00aedf438 100644
--- a/deploy/cpp_infer/readme_en.md
+++ b/deploy/cpp_infer/readme_en.md
@@ -79,7 +79,7 @@ opencv3/
#### 1.2.1 Direct download and installation
-[Paddle inference library official website](https://www.paddlepaddle.org.cn/documentation/docs/zh/2.0/guides/05_inference_deployment/inference/build_and_install_lib_cn.html). You can view and select the appropriate version of the inference library on the official website.
+[Paddle inference library official website](https://paddle-inference.readthedocs.io/en/latest/user_guides/download_lib.html). You can view and select the appropriate version of the inference library on the official website.
* After downloading, use the following method to uncompress.
@@ -97,7 +97,7 @@ Finally you can see the following files in the folder of `paddle_inference/`.
```shell
git clone https://github.com/PaddlePaddle/Paddle.git
-git checkout release/2.1
+git checkout develop
```
* After entering the Paddle directory, the commands to compile the paddle inference library are as follows.
diff --git a/deploy/pdserving/README.md b/deploy/pdserving/README.md
index cb2845c581d244e80ca597e0eb485a16ad369f20..c461fd5e54d3a51ad3427f83a1fca35cbe3ab2d8 100644
--- a/deploy/pdserving/README.md
+++ b/deploy/pdserving/README.md
@@ -45,63 +45,67 @@ PaddleOCR operating environment and Paddle Serving operating environment are nee
```
3. Install the client to send requests to the service
- In [download link](https://github.com/PaddlePaddle/Serving/blob/develop/doc/LATEST_PACKAGES.md) find the client installation package corresponding to the python version.
- The python3.7 version is recommended here:
- ```
- wget https://paddle-serving.bj.bcebos.com/test-dev/whl/paddle_serving_client-0.0.0-cp37-none-any.whl
- pip3 install paddle_serving_client-0.0.0-cp37-none-any.whl
- ```
-
-4. Install serving-app
- ```
- pip3 install paddle-serving-app==0.6.1
- ```
+```bash
+# Install paddle-serving-server, used to start the service
+wget https://paddle-serving.bj.bcebos.com/test-dev/whl/paddle_serving_server_gpu-0.7.0.post102-py3-none-any.whl
+pip3 install paddle_serving_server_gpu-0.7.0.post102-py3-none-any.whl
+# If you use CUDA 10.1, install paddle-serving-server with the following commands
+# wget https://paddle-serving.bj.bcebos.com/test-dev/whl/paddle_serving_server_gpu-0.7.0.post101-py3-none-any.whl
+# pip3 install paddle_serving_server_gpu-0.7.0.post101-py3-none-any.whl
+
+# Install paddle-serving-client, used to send requests to the service
+wget https://paddle-serving.bj.bcebos.com/test-dev/whl/paddle_serving_client-0.7.0-cp37-none-any.whl
+pip3 install paddle_serving_client-0.7.0-cp37-none-any.whl
+
+# Install paddle-serving-app
+wget https://paddle-serving.bj.bcebos.com/test-dev/whl/paddle_serving_app-0.7.0-py3-none-any.whl
+pip3 install paddle_serving_app-0.7.0-py3-none-any.whl
+```
- **note:** If you want to install the latest version of PaddleServing, refer to [link](https://github.com/PaddlePaddle/Serving/blob/develop/doc/LATEST_PACKAGES.md).
+ **note:** If you want to install the latest version of PaddleServing, refer to [link](https://github.com/PaddlePaddle/Serving/blob/v0.7.0/doc/Latest_Packages_CN.md).
## Model conversion
When using PaddleServing for service deployment, you need to convert the saved inference model into a serving model that is easy to deploy.
-Firstly, download the [inference model](https://github.com/PaddlePaddle/PaddleOCR#pp-ocr-20-series-model-listupdate-on-dec-15) of PPOCR
+Firstly, download the [inference model](https://github.com/PaddlePaddle/PaddleOCR/blob/release/2.3/README_ch.md#pp-ocr%E7%B3%BB%E5%88%97%E6%A8%A1%E5%9E%8B%E5%88%97%E8%A1%A8%E6%9B%B4%E6%96%B0%E4%B8%AD) of PPOCR
```
# Download and unzip the OCR text detection model
-wget https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_det_infer.tar && tar xf ch_ppocr_mobile_v2.0_det_infer.tar
+wget https://paddleocr.bj.bcebos.com/PP-OCRv2/chinese/ch_PP-OCRv2_det_infer.tar -O ch_PP-OCRv2_det_infer.tar && tar -xf ch_PP-OCRv2_det_infer.tar
# Download and unzip the OCR text recognition model
-wget https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_rec_infer.tar && tar xf ch_ppocr_mobile_v2.0_rec_infer.tar
-
+wget https://paddleocr.bj.bcebos.com/PP-OCRv2/chinese/ch_PP-OCRv2_rec_infer.tar -O ch_PP-OCRv2_rec_infer.tar && tar -xf ch_PP-OCRv2_rec_infer.tar
```
Then, you can use the installed paddle_serving_client tool to convert the inference model into a serving model.
```
# Detection model conversion
-python3 -m paddle_serving_client.convert --dirname ./ch_ppocr_mobile_v2.0_det_infer/ \
+python3 -m paddle_serving_client.convert --dirname ./ch_PP-OCRv2_det_infer/ \
--model_filename inference.pdmodel \
--params_filename inference.pdiparams \
- --serving_server ./ppocr_det_mobile_2.0_serving/ \
- --serving_client ./ppocr_det_mobile_2.0_client/
+ --serving_server ./ppocrv2_det_serving/ \
+ --serving_client ./ppocrv2_det_client/
# Recognition model conversion
-python3 -m paddle_serving_client.convert --dirname ./ch_ppocr_mobile_v2.0_rec_infer/ \
+python3 -m paddle_serving_client.convert --dirname ./ch_PP-OCRv2_rec_infer/ \
--model_filename inference.pdmodel \
--params_filename inference.pdiparams \
- --serving_server ./ppocr_rec_mobile_2.0_serving/ \
- --serving_client ./ppocr_rec_mobile_2.0_client/
+ --serving_server ./ppocrv2_rec_serving/ \
+ --serving_client ./ppocrv2_rec_client/
```
-After the detection model is converted, there will be additional folders of `ppocr_det_mobile_2.0_serving` and `ppocr_det_mobile_2.0_client` in the current folder, with the following format:
+After the detection model is converted, there will be additional folders of `ppocrv2_det_serving` and `ppocrv2_det_client` in the current folder, with the following format:
```
-|- ppocr_det_mobile_2.0_serving/
- |- __model__
- |- __params__
- |- serving_server_conf.prototxt
- |- serving_server_conf.stream.prototxt
-
-|- ppocr_det_mobile_2.0_client
- |- serving_client_conf.prototxt
- |- serving_client_conf.stream.prototxt
+|- ppocrv2_det_serving/
+ |- __model__
+ |- __params__
+ |- serving_server_conf.prototxt
+ |- serving_server_conf.stream.prototxt
+
+|- ppocrv2_det_client
+ |- serving_client_conf.prototxt
+ |- serving_client_conf.stream.prototxt
```
The recognition model conversion works the same way.
diff --git a/deploy/pdserving/README_CN.md b/deploy/pdserving/README_CN.md
index 067be8bbda10d971b709afdf822aea96a979d000..00024639b0b108225a0835499f62174b6618ae47 100644
--- a/deploy/pdserving/README_CN.md
+++ b/deploy/pdserving/README_CN.md
@@ -34,70 +34,66 @@ PaddleOCR提供2种服务部署方式:
- 准备PaddleServing的运行环境,步骤如下
-1. 安装serving,用于启动服务
- ```
- pip3 install paddle-serving-server==0.6.1 # for CPU
- pip3 install paddle-serving-server-gpu==0.6.1 # for GPU
- # 其他GPU环境需要确认环境再选择执行如下命令
- pip3 install paddle-serving-server-gpu==0.6.1.post101 # GPU with CUDA10.1 + TensorRT6
- pip3 install paddle-serving-server-gpu==0.6.1.post11 # GPU with CUDA11 + TensorRT7
- ```
-
-2. 安装client,用于向服务发送请求
- 在[下载链接](https://github.com/PaddlePaddle/Serving/blob/develop/doc/LATEST_PACKAGES.md)中找到对应python版本的client安装包,这里推荐python3.7版本:
-
- ```
- wget https://paddle-serving.bj.bcebos.com/test-dev/whl/paddle_serving_client-0.0.0-cp37-none-any.whl
- pip3 install paddle_serving_client-0.0.0-cp37-none-any.whl
- ```
-
-3. 安装serving-app
- ```
- pip3 install paddle-serving-app==0.6.1
- ```
+```bash
+# 安装serving,用于启动服务
+wget https://paddle-serving.bj.bcebos.com/test-dev/whl/paddle_serving_server_gpu-0.7.0.post102-py3-none-any.whl
+pip3 install paddle_serving_server_gpu-0.7.0.post102-py3-none-any.whl
+# 如果是cuda10.1环境,可以使用下面的命令安装paddle-serving-server
+# wget https://paddle-serving.bj.bcebos.com/test-dev/whl/paddle_serving_server_gpu-0.7.0.post101-py3-none-any.whl
+# pip3 install paddle_serving_server_gpu-0.7.0.post101-py3-none-any.whl
+
+# 安装client,用于向服务发送请求
+wget https://paddle-serving.bj.bcebos.com/test-dev/whl/paddle_serving_client-0.7.0-cp37-none-any.whl
+pip3 install paddle_serving_client-0.7.0-cp37-none-any.whl
+
+# 安装serving-app
+wget https://paddle-serving.bj.bcebos.com/test-dev/whl/paddle_serving_app-0.7.0-py3-none-any.whl
+pip3 install paddle_serving_app-0.7.0-py3-none-any.whl
+```
- **Note:** 如果要安装最新版本的PaddleServing参考[链接](https://github.com/PaddlePaddle/Serving/blob/develop/doc/LATEST_PACKAGES.md)。
+**Note:** 如果要安装最新版本的PaddleServing参考[链接](https://github.com/PaddlePaddle/Serving/blob/v0.7.0/doc/Latest_Packages_CN.md)。
## 模型转换
使用PaddleServing做服务化部署时,需要将保存的inference模型转换为serving易于部署的模型。
-首先,下载PPOCR的[inference模型](https://github.com/PaddlePaddle/PaddleOCR#pp-ocr-20-series-model-listupdate-on-dec-15)
-```
+首先,下载PPOCR的[inference模型](https://github.com/PaddlePaddle/PaddleOCR#pp-ocr-series-model-listupdate-on-september-8th)
+
+```bash
# 下载并解压 OCR 文本检测模型
-wget https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_det_infer.tar && tar xf ch_ppocr_mobile_v2.0_det_infer.tar
+wget https://paddleocr.bj.bcebos.com/PP-OCRv2/chinese/ch_PP-OCRv2_det_infer.tar -O ch_PP-OCRv2_det_infer.tar && tar -xf ch_PP-OCRv2_det_infer.tar
# 下载并解压 OCR 文本识别模型
-wget https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_rec_infer.tar && tar xf ch_ppocr_mobile_v2.0_rec_infer.tar
+wget https://paddleocr.bj.bcebos.com/PP-OCRv2/chinese/ch_PP-OCRv2_rec_infer.tar -O ch_PP-OCRv2_rec_infer.tar && tar -xf ch_PP-OCRv2_rec_infer.tar
```
接下来,用安装的paddle_serving_client把下载的inference模型转换成易于server部署的模型格式。
-```
+```bash
# 转换检测模型
-python3 -m paddle_serving_client.convert --dirname ./ch_ppocr_mobile_v2.0_det_infer/ \
+python3 -m paddle_serving_client.convert --dirname ./ch_PP-OCRv2_det_infer/ \
--model_filename inference.pdmodel \
--params_filename inference.pdiparams \
- --serving_server ./ppocr_det_mobile_2.0_serving/ \
- --serving_client ./ppocr_det_mobile_2.0_client/
+ --serving_server ./ppocrv2_det_serving/ \
+ --serving_client ./ppocrv2_det_client/
# 转换识别模型
-python3 -m paddle_serving_client.convert --dirname ./ch_ppocr_mobile_v2.0_rec_infer/ \
+python3 -m paddle_serving_client.convert --dirname ./ch_PP-OCRv2_rec_infer/ \
--model_filename inference.pdmodel \
--params_filename inference.pdiparams \
- --serving_server ./ppocr_rec_mobile_2.0_serving/ \
- --serving_client ./ppocr_rec_mobile_2.0_client/
+ --serving_server ./ppocrv2_rec_serving/ \
+ --serving_client ./ppocrv2_rec_client/
```
-检测模型转换完成后,会在当前文件夹多出`ppocr_det_mobile_2.0_serving` 和`ppocr_det_mobile_2.0_client`的文件夹,具备如下格式:
+检测模型转换完成后,会在当前文件夹多出`ppocrv2_det_serving` 和`ppocrv2_det_client`的文件夹,具备如下格式:
```
-|- ppocr_det_mobile_2.0_serving/
+|- ppocrv2_det_serving/
|- __model__
|- __params__
|- serving_server_conf.prototxt
|- serving_server_conf.stream.prototxt
-|- ppocr_det_mobile_2.0_client
+|- ppocrv2_det_client
|- serving_client_conf.prototxt
|- serving_client_conf.stream.prototxt
diff --git a/deploy/pdserving/config.yml b/deploy/pdserving/config.yml
index 2aae922dfa12f46d1c0ebd352e8d3a7077065cf8..f3b0f7ec5a47bb9c513ab3d75f7d2d4138f88c4a 100644
--- a/deploy/pdserving/config.yml
+++ b/deploy/pdserving/config.yml
@@ -34,7 +34,7 @@ op:
client_type: local_predictor
#det模型路径
- model_config: ./ppocr_det_mobile_2.0_serving
+ model_config: ./ppocrv2_det_serving
#Fetch结果列表,以client_config中fetch_var的alias_name为准
fetch_list: ["save_infer_model/scale_0.tmp_1"]
@@ -60,7 +60,7 @@ op:
client_type: local_predictor
#rec模型路径
- model_config: ./ppocr_rec_mobile_2.0_serving
+ model_config: ./ppocrv2_rec_serving
#Fetch结果列表,以client_config中fetch_var的alias_name为准
fetch_list: ["save_infer_model/scale_0.tmp_1"]
diff --git a/deploy/pdserving/web_service.py b/deploy/pdserving/web_service.py
index 21db1e1411a8706dbbd9a22ce2ce7db8e16da5ec..b97c6e1f564a61bb9792542b9e9f1e88d782e80d 100644
--- a/deploy/pdserving/web_service.py
+++ b/deploy/pdserving/web_service.py
@@ -54,7 +54,7 @@ class DetOp(Op):
_, self.new_h, self.new_w = det_img.shape
return {"x": det_img[np.newaxis, :].copy()}, False, None, ""
- def postprocess(self, input_dicts, fetch_dict, log_id):
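+    # newer Paddle Serving pipeline releases (0.7.0 here) pass an extra data_id argument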
+ def postprocess(self, input_dicts, fetch_dict, data_id, log_id):
det_out = fetch_dict["save_infer_model/scale_0.tmp_1"]
ratio_list = [
float(self.new_h) / self.ori_h, float(self.new_w) / self.ori_w
@@ -129,7 +129,7 @@ class RecOp(Op):
return feed_list, False, None, ""
- def postprocess(self, input_dicts, fetch_data, log_id):
+ def postprocess(self, input_dicts, fetch_data, data_id, log_id):
res_list = []
if isinstance(fetch_data, dict):
if len(fetch_data) > 0:
diff --git a/deploy/pdserving/web_service_det.py b/deploy/pdserving/web_service_det.py
index 25ac2f37dbd3cdf05b3503abaab0c5651867fae9..ee39388425763d789ada76cf0a9db9f812fe8d2a 100644
--- a/deploy/pdserving/web_service_det.py
+++ b/deploy/pdserving/web_service_det.py
@@ -54,7 +54,7 @@ class DetOp(Op):
_, self.new_h, self.new_w = det_img.shape
return {"x": det_img[np.newaxis, :].copy()}, False, None, ""
- def postprocess(self, input_dicts, fetch_dict, log_id):
+ def postprocess(self, input_dicts, fetch_dict, data_id, log_id):
det_out = fetch_dict["save_infer_model/scale_0.tmp_1"]
ratio_list = [
float(self.new_h) / self.ori_h, float(self.new_w) / self.ori_w
diff --git a/deploy/pdserving/web_service_rec.py b/deploy/pdserving/web_service_rec.py
index 6b3cf707f0f19034a0734fd27824feb4fb6cce20..f5cd8bf053c604786fecb9b71749b3c98f2552a2 100644
--- a/deploy/pdserving/web_service_rec.py
+++ b/deploy/pdserving/web_service_rec.py
@@ -56,7 +56,7 @@ class RecOp(Op):
feed_list.append(feed)
return feed_list, False, None, ""
- def postprocess(self, input_dicts, fetch_data, log_id):
+ def postprocess(self, input_dicts, fetch_data, data_id, log_id):
res_list = []
if isinstance(fetch_data, dict):
if len(fetch_data) > 0:
diff --git a/doc/doc_ch/update.md b/doc/doc_ch/update.md
index 0852e240886b4ca736a830c8c44651ca35ec1f25..de5cdaf2aa24aa4c32e81001cdccec1156ee8605 100644
--- a/doc/doc_ch/update.md
+++ b/doc/doc_ch/update.md
@@ -1,4 +1,6 @@
# 更新
+- 2021.12.21 《OCR十讲》课程开讲,12月21日起每晚八点半线上授课! 【免费】报名地址:https://aistudio.baidu.com/aistudio/course/introduce/25207
+- 2021.12.21 发布PaddleOCR v2.4。OCR算法新增1种文本检测算法(PSENet),3种文本识别算法(NRTR、SEED、SAR);文档结构化算法新增1种关键信息提取算法(SDMGR),3种DocVQA算法(LayoutLM、LayoutLMv2,LayoutXLM)。
- 2021.9.7 发布PaddleOCR v2.3,发布[PP-OCRv2](#PP-OCRv2),CPU推理速度相比于PP-OCR server提升220%;效果相比于PP-OCR mobile 提升7%。
- 2021.8.3 发布PaddleOCR v2.2,新增文档结构分析[PP-Structure](https://github.com/PaddlePaddle/PaddleOCR/blob/release/2.2/ppstructure/README_ch.md)工具包,支持版面分析与表格识别(含Excel导出)。
- 2021.6.29 [FAQ](https://github.com/PaddlePaddle/PaddleOCR/blob/release/2.2/doc/doc_ch/FAQ.md)新增5个高频问题,总数248个,每周一都会更新,欢迎大家持续关注。
diff --git a/doc/doc_en/update_en.md b/doc/doc_en/update_en.md
index 660688c6d6991a4744dbc327d24e9c677afa0fc1..6a95b5be279d7a0b8a204cadd46b283b5eb26690 100644
--- a/doc/doc_en/update_en.md
+++ b/doc/doc_en/update_en.md
@@ -1,4 +1,6 @@
# RECENT UPDATES
+- 2021.12.21 The OCR open source online course starts on December 21st. Lessons begin at 8:30 pm every night and run for ten days. Free registration: https://aistudio.baidu.com/aistudio/course/introduce/25207
+- 2021.12.21 Release PaddleOCR v2.4, which adds 1 text detection algorithm (PSENet), 3 text recognition algorithms (NRTR, SEED, SAR), 1 key information extraction algorithm (SDMGR) and 3 DocVQA algorithms (LayoutLM, LayoutLMv2, LayoutXLM).
- 2021.9.7 release PaddleOCR v2.3, [PP-OCRv2](#PP-OCRv2) is proposed. The CPU inference speed of PP-OCRv2 is 220% higher than that of PP-OCR server. The F-score of PP-OCRv2 is 7% higher than that of PP-OCR mobile.
- 2021.8.3 released PaddleOCR v2.2, add a new structured documents analysis toolkit, i.e., [PP-Structure](https://github.com/PaddlePaddle/PaddleOCR/blob/release/2.2/ppstructure/README.md), support layout analysis and table recognition (One-key to export chart images to Excel files).
- 2021.4.8 release end-to-end text recognition algorithm [PGNet](https://www.aaai.org/AAAI21Papers/AAAI-2885.WangP.pdf) which is published in AAAI 2021. Find the tutorial [here](https://github.com/PaddlePaddle/PaddleOCR/blob/release/2.1/doc/doc_en/pgnet_en.md); release multi-language recognition [models](https://github.com/PaddlePaddle/PaddleOCR/blob/release/2.1/doc/doc_en/multi_languages_en.md), supporting recognition of more than 80 languages; especially, the performance of the [English recognition model](https://github.com/PaddlePaddle/PaddleOCR/blob/release/2.1/doc/doc_en/models_list_en.md#English) is optimized.
diff --git a/ppocr/data/imaug/label_ops.py b/ppocr/data/imaug/label_ops.py
index 0a4fad621a9038e71a9d43eb4e12f78e7e92d73d..fc14fdbcf13a61b591d9ea6c2535aefe6e437ec6 100644
--- a/ppocr/data/imaug/label_ops.py
+++ b/ppocr/data/imaug/label_ops.py
@@ -19,6 +19,7 @@ from __future__ import unicode_literals
import numpy as np
import string
+from shapely.geometry import LineString, Point, Polygon
import json
from ppocr.utils.logging import get_logger
@@ -286,6 +287,168 @@ class E2ELabelEncodeTrain(object):
return data
+class KieLabelEncode(object):
+ def __init__(self, character_dict_path, norm=10, directed=False, **kwargs):
+ super(KieLabelEncode, self).__init__()
+ self.dict = dict({'': 0})
+ with open(character_dict_path, 'r', encoding='utf-8') as fr:
+ idx = 1
+ for line in fr:
+ char = line.strip()
+ self.dict[char] = idx
+ idx += 1
+ self.norm = norm
+ self.directed = directed
+
+ def compute_relation(self, boxes):
+ """Compute relation between every two boxes."""
+ x1s, y1s = boxes[:, 0:1], boxes[:, 1:2]
+ x2s, y2s = boxes[:, 4:5], boxes[:, 5:6]
+ ws, hs = x2s - x1s + 1, np.maximum(y2s - y1s + 1, 1)
+ dxs = (x1s[:, 0][None] - x1s) / self.norm
+ dys = (y1s[:, 0][None] - y1s) / self.norm
+ xhhs, xwhs = hs[:, 0][None] / hs, ws[:, 0][None] / hs
+ whs = ws / hs + np.zeros_like(xhhs)
+ relations = np.stack([dxs, dys, whs, xhhs, xwhs], -1)
+ bboxes = np.concatenate([x1s, y1s, x2s, y2s], -1).astype(np.float32)
+ return relations, bboxes
+
+ def pad_text_indices(self, text_inds):
+ """Pad text index to same length."""
+ max_len = 300
+ recoder_len = max([len(text_ind) for text_ind in text_inds])
+ padded_text_inds = -np.ones((len(text_inds), max_len), np.int32)
+ for idx, text_ind in enumerate(text_inds):
+ padded_text_inds[idx, :len(text_ind)] = np.array(text_ind)
+ return padded_text_inds, recoder_len
+
+ def list_to_numpy(self, ann_infos):
+ """Convert bboxes, relations, texts and labels to ndarray."""
+ boxes, text_inds = ann_infos['points'], ann_infos['text_inds']
+ boxes = np.array(boxes, np.int32)
+ relations, bboxes = self.compute_relation(boxes)
+
+ labels = ann_infos.get('labels', None)
+ if labels is not None:
+ labels = np.array(labels, np.int32)
+ edges = ann_infos.get('edges', None)
+ if edges is not None:
+ labels = labels[:, None]
+ edges = np.array(edges)
+ edges = (edges[:, None] == edges[None, :]).astype(np.int32)
+ if self.directed:
+ edges = (edges & labels == 1).astype(np.int32)
+ np.fill_diagonal(edges, -1)
+ labels = np.concatenate([labels, edges], -1)
+ padded_text_inds, recoder_len = self.pad_text_indices(text_inds)
+ max_num = 300
+ temp_bboxes = np.zeros([max_num, 4])
+ h, _ = bboxes.shape
+        temp_bboxes[:h, :] = bboxes
+
+ temp_relations = np.zeros([max_num, max_num, 5])
+ temp_relations[:h, :h, :] = relations
+
+ temp_padded_text_inds = np.zeros([max_num, max_num])
+ temp_padded_text_inds[:h, :] = padded_text_inds
+
+ temp_labels = np.zeros([max_num, max_num])
+ temp_labels[:h, :h + 1] = labels
+
+ tag = np.array([h, recoder_len])
+ return dict(
+ image=ann_infos['image'],
+ points=temp_bboxes,
+ relations=temp_relations,
+ texts=temp_padded_text_inds,
+ labels=temp_labels,
+ tag=tag)
+
+ def convert_canonical(self, points_x, points_y):
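+        """Rotate the point order so the corner nearest the bbox top-left comes first."""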
+
+ assert len(points_x) == 4
+ assert len(points_y) == 4
+
+ points = [Point(points_x[i], points_y[i]) for i in range(4)]
+
+ polygon = Polygon([(p.x, p.y) for p in points])
+ min_x, min_y, _, _ = polygon.bounds
+ points_to_lefttop = [
+ LineString([points[i], Point(min_x, min_y)]) for i in range(4)
+ ]
+ distances = np.array([line.length for line in points_to_lefttop])
+ sort_dist_idx = np.argsort(distances)
+ lefttop_idx = sort_dist_idx[0]
+
+ if lefttop_idx == 0:
+ point_orders = [0, 1, 2, 3]
+ elif lefttop_idx == 1:
+ point_orders = [1, 2, 3, 0]
+ elif lefttop_idx == 2:
+ point_orders = [2, 3, 0, 1]
+ else:
+ point_orders = [3, 0, 1, 2]
+
+ sorted_points_x = [points_x[i] for i in point_orders]
+ sorted_points_y = [points_y[j] for j in point_orders]
+
+ return sorted_points_x, sorted_points_y
+
+ def sort_vertex(self, points_x, points_y):
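+        """Sort the 4 corners by angle around the box center, then canonicalize the start corner."""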
+
+ assert len(points_x) == 4
+ assert len(points_y) == 4
+
+ x = np.array(points_x)
+ y = np.array(points_y)
+ center_x = np.sum(x) * 0.25
+ center_y = np.sum(y) * 0.25
+
+ x_arr = np.array(x - center_x)
+ y_arr = np.array(y - center_y)
+
+ angle = np.arctan2(y_arr, x_arr) * 180.0 / np.pi
+ sort_idx = np.argsort(angle)
+
+ sorted_points_x, sorted_points_y = [], []
+ for i in range(4):
+ sorted_points_x.append(points_x[sort_idx[i]])
+ sorted_points_y.append(points_y[sort_idx[i]])
+
+ return self.convert_canonical(sorted_points_x, sorted_points_y)
+
+ def __call__(self, data):
+ label = data['label']
+ annotations = json.loads(label)
+ boxes, texts, text_inds, labels, edges = [], [], [], [], []
+ for ann in annotations:
+ box = ann['points']
+ x_list = [box[i][0] for i in range(4)]
+ y_list = [box[i][1] for i in range(4)]
+ sorted_x_list, sorted_y_list = self.sort_vertex(x_list, y_list)
+ sorted_box = []
+ for x, y in zip(sorted_x_list, sorted_y_list):
+ sorted_box.append(x)
+ sorted_box.append(y)
+ boxes.append(sorted_box)
+ text = ann['transcription']
+ texts.append(ann['transcription'])
+ text_ind = [self.dict[c] for c in text if c in self.dict]
+ text_inds.append(text_ind)
+ labels.append(ann['label'])
+ edges.append(ann.get('edge', 0))
+ ann_infos = dict(
+ image=data['image'],
+ points=boxes,
+ texts=texts,
+ text_inds=text_inds,
+ edges=edges,
+ labels=labels)
+
+ return self.list_to_numpy(ann_infos)
+
+
class AttnLabelEncode(BaseRecLabelEncode):
""" Convert between text-label and text-index """
@@ -344,8 +507,12 @@ class SEEDLabelEncode(BaseRecLabelEncode):
max_text_length, character_dict_path, use_space_char)
def add_special_char(self, dict_character):
+ self.padding = "padding"
self.end_str = "eos"
- dict_character = dict_character + [self.end_str]
+ self.unknown = "unknown"
+ dict_character = dict_character + [
+ self.end_str, self.padding, self.unknown
+ ]
return dict_character
def __call__(self, data):
@@ -356,8 +523,8 @@ class SEEDLabelEncode(BaseRecLabelEncode):
if len(text) >= self.max_text_len:
return None
data['length'] = np.array(len(text)) + 1  # include eos
- text = text + [len(self.character) - 1] * (self.max_text_len - len(text)
- )
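+        # append eos (index len-3), then fill with the padding token (index len-2);
+        # see add_special_char: the dict now ends with [eos, padding, unknown]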
+ text = text + [len(self.character) - 3] + [len(self.character) - 2] * (
+ self.max_text_len - len(text) - 1)
data['label'] = np.array(text)
return data
diff --git a/ppocr/data/imaug/operators.py b/ppocr/data/imaug/operators.py
index 87e3088d07a8c5a2eea5d4deff87c69a753e215b..c3dfd316f86d88b5c7fd52eb6ae23d22a4dd32eb 100644
--- a/ppocr/data/imaug/operators.py
+++ b/ppocr/data/imaug/operators.py
@@ -111,7 +111,6 @@ class NormalizeImage(object):
from PIL import Image
if isinstance(img, Image.Image):
img = np.array(img)
-
assert isinstance(img,
np.ndarray), "invalid input 'img' in NormalizeImage"
data['image'] = (
@@ -367,3 +366,53 @@ class E2EResizeForTest(object):
ratio_w = resize_w / float(w)
return im, (ratio_h, ratio_w)
+
+
+class KieResize(object):
+ def __init__(self, **kwargs):
+ super(KieResize, self).__init__()
+ self.max_side, self.min_side = kwargs['img_scale'][0], kwargs[
+ 'img_scale'][1]
+
+ def __call__(self, data):
+ img = data['image']
+ points = data['points']
+ src_h, src_w, _ = img.shape
+        im_resized, scale_factor, [ratio_h, ratio_w], [new_h, new_w] = self.resize_image(img)
+ resize_points = self.resize_boxes(img, points, scale_factor)
+ data['ori_image'] = img
+ data['ori_boxes'] = points
+ data['points'] = resize_points
+ data['image'] = im_resized
+ data['shape'] = np.array([new_h, new_w])
+ return data
+
+ def resize_image(self, img):
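+        # resize within scale [512, 1024] keeping the aspect ratio, round sizes up to a
+        # multiple of 32, then paste the result into a fixed 1024x1024 zero canvas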
+ norm_img = np.zeros([1024, 1024, 3], dtype='float32')
+ scale = [512, 1024]
+ h, w = img.shape[:2]
+ max_long_edge = max(scale)
+ max_short_edge = min(scale)
+ scale_factor = min(max_long_edge / max(h, w),
+ max_short_edge / min(h, w))
+ resize_w, resize_h = int(w * float(scale_factor) + 0.5), int(h * float(
+ scale_factor) + 0.5)
+ max_stride = 32
+ resize_h = (resize_h + max_stride - 1) // max_stride * max_stride
+ resize_w = (resize_w + max_stride - 1) // max_stride * max_stride
+ im = cv2.resize(img, (resize_w, resize_h))
+ new_h, new_w = im.shape[:2]
+ w_scale = new_w / w
+ h_scale = new_h / h
+ scale_factor = np.array(
+ [w_scale, h_scale, w_scale, h_scale], dtype=np.float32)
+ norm_img[:new_h, :new_w, :] = im
+ return norm_img, scale_factor, [h_scale, w_scale], [new_h, new_w]
+
+ def resize_boxes(self, im, points, scale_factor):
+ points = points * scale_factor
+ img_shape = im.shape[:2]
+ points[:, 0::2] = np.clip(points[:, 0::2], 0, img_shape[1])
+ points[:, 1::2] = np.clip(points[:, 1::2], 0, img_shape[0])
+ return points
diff --git a/ppocr/losses/__init__.py b/ppocr/losses/__init__.py
index f3f4cd49332b605ec3a0e65e688d965fd91a5cdf..62ad2b6ad86edf9b5446aea03f9333f9d4981336 100755
--- a/ppocr/losses/__init__.py
+++ b/ppocr/losses/__init__.py
@@ -35,6 +35,7 @@ from .cls_loss import ClsLoss
# e2e loss
from .e2e_pg_loss import PGLoss
+from .kie_sdmgr_loss import SDMGRLoss
# basic loss function
from .basic_loss import DistanceLoss
@@ -50,7 +51,7 @@ def build_loss(config):
support_dict = [
'DBLoss', 'PSELoss', 'EASTLoss', 'SASTLoss', 'CTCLoss', 'ClsLoss',
'AttentionLoss', 'SRNLoss', 'PGLoss', 'CombinedLoss', 'NRTRLoss',
- 'TableAttentionLoss', 'SARLoss', 'AsterLoss'
+ 'TableAttentionLoss', 'SARLoss', 'AsterLoss', 'SDMGRLoss'
]
config = copy.deepcopy(config)
module_name = config.pop('name')
diff --git a/ppocr/losses/kie_sdmgr_loss.py b/ppocr/losses/kie_sdmgr_loss.py
new file mode 100644
index 0000000000000000000000000000000000000000..8f2173e49904926ebab2c450890c4fafe3f36b50
--- /dev/null
+++ b/ppocr/losses/kie_sdmgr_loss.py
@@ -0,0 +1,113 @@
+# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from paddle import nn
+import paddle
+
+
+class SDMGRLoss(nn.Layer):
+ def __init__(self, node_weight=1.0, edge_weight=1.0, ignore=0):
+ super().__init__()
+ self.loss_node = nn.CrossEntropyLoss(ignore_index=ignore)
+ self.loss_edge = nn.CrossEntropyLoss(ignore_index=-1)
+ self.node_weight = node_weight
+ self.edge_weight = edge_weight
+ self.ignore = ignore
+
+ def pre_process(self, gts, tag):
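+        # tag[i] = [num_boxes, recoder_len]; crop each padded label map back to its true size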
+ gts, tag = gts.numpy(), tag.numpy().tolist()
+ temp_gts = []
+ batch = len(tag)
+ for i in range(batch):
+ num, recoder_len = tag[i][0], tag[i][1]
+ temp_gts.append(
+ paddle.to_tensor(
+ gts[i, :num, :num + 1], dtype='int64'))
+ return temp_gts
+
+ def accuracy(self, pred, target, topk=1, thresh=None):
+ """Calculate accuracy according to the prediction and target.
+
+ Args:
+            pred (paddle.Tensor): The model prediction, shape (N, num_class)
+            target (paddle.Tensor): The target of each prediction, shape (N, )
+ topk (int | tuple[int], optional): If the predictions in ``topk``
+ matches the target, the predictions will be regarded as
+ correct ones. Defaults to 1.
+ thresh (float, optional): If not None, predictions with scores under
+ this threshold are considered incorrect. Default to None.
+
+ Returns:
+ float | tuple[float]: If the input ``topk`` is a single integer,
+ the function will return a single float as accuracy. If
+ ``topk`` is a tuple containing multiple integers, the
+ function will return a tuple containing accuracies of
+ each ``topk`` number.
+ """
+ assert isinstance(topk, (int, tuple))
+ if isinstance(topk, int):
+ topk = (topk, )
+ return_single = True
+ else:
+ return_single = False
+
+ maxk = max(topk)
+ if pred.shape[0] == 0:
+            accu = [paddle.to_tensor(0.) for _ in range(len(topk))]
+ return accu[0] if return_single else accu
+ pred_value, pred_label = paddle.topk(pred, maxk, axis=1)
+ pred_label = pred_label.transpose(
+ [1, 0]) # transpose to shape (maxk, N)
+ correct = paddle.equal(pred_label,
+ (target.reshape([1, -1]).expand_as(pred_label)))
+ res = []
+ for k in topk:
+ correct_k = paddle.sum(correct[:k].reshape([-1]).astype('float32'),
+ axis=0,
+ keepdim=True)
+ res.append(
+ paddle.multiply(correct_k,
+ paddle.to_tensor(100.0 / pred.shape[0])))
+ return res[0] if return_single else res
+
+ def forward(self, pred, batch):
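+        # batch follows the KeepKeys order in the config: batch[4] is labels, batch[5] is tag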
+ node_preds, edge_preds = pred
+ gts, tag = batch[4], batch[5]
+ gts = self.pre_process(gts, tag)
+ node_gts, edge_gts = [], []
+ for gt in gts:
+ node_gts.append(gt[:, 0])
+ edge_gts.append(gt[:, 1:].reshape([-1]))
+ node_gts = paddle.concat(node_gts)
+ edge_gts = paddle.concat(edge_gts)
+
+ node_valids = paddle.nonzero(node_gts != self.ignore).reshape([-1])
+ edge_valids = paddle.nonzero(edge_gts != -1).reshape([-1])
+ loss_node = self.loss_node(node_preds, node_gts)
+ loss_edge = self.loss_edge(edge_preds, edge_gts)
+ loss = self.node_weight * loss_node + self.edge_weight * loss_edge
+ return dict(
+ loss=loss,
+ loss_node=loss_node,
+ loss_edge=loss_edge,
+ acc_node=self.accuracy(
+ paddle.gather(node_preds, node_valids),
+ paddle.gather(node_gts, node_valids)),
+ acc_edge=self.accuracy(
+ paddle.gather(edge_preds, edge_valids),
+ paddle.gather(edge_gts, edge_valids)))
diff --git a/ppocr/metrics/__init__.py b/ppocr/metrics/__init__.py
index 64f62e51cdf922773c03bb784a4edffdc17f506f..28bff3cb4eb7784db876940f761208f1b084f0e2 100644
--- a/ppocr/metrics/__init__.py
+++ b/ppocr/metrics/__init__.py
@@ -27,10 +27,13 @@ from .cls_metric import ClsMetric
from .e2e_metric import E2EMetric
from .distillation_metric import DistillationMetric
from .table_metric import TableMetric
+from .kie_metric import KIEMetric
+
def build_metric(config):
support_dict = [
- "DetMetric", "RecMetric", "ClsMetric", "E2EMetric", "DistillationMetric", "TableMetric"
+ "DetMetric", "RecMetric", "ClsMetric", "E2EMetric",
+ "DistillationMetric", "TableMetric", 'KIEMetric'
]
config = copy.deepcopy(config)
diff --git a/ppocr/metrics/kie_metric.py b/ppocr/metrics/kie_metric.py
new file mode 100644
index 0000000000000000000000000000000000000000..761965cfcc25d2a6de30342769d01b36d6212d98
--- /dev/null
+++ b/ppocr/metrics/kie_metric.py
@@ -0,0 +1,70 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+import paddle
+
+__all__ = ['KIEMetric']
+
+
+class KIEMetric(object):
+ def __init__(self, main_indicator='hmean', **kwargs):
+ self.main_indicator = main_indicator
+ self.reset()
+ self.node = []
+ self.gt = []
+
+ def __call__(self, preds, batch, **kwargs):
+ nodes, _ = preds
+ gts, tag = batch[4].squeeze(0), batch[5].tolist()[0]
+ gts = gts[:tag[0], :1].reshape([-1])
+ self.node.append(nodes.numpy())
+ self.gt.append(gts)
+ # result = self.compute_f1_score(nodes, gts)
+ # self.results.append(result)
+
+ def compute_f1_score(self, preds, gts):
+ ignores = [0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 25]
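+        # class ids excluded from the macro F1 score over node predictions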
+ C = preds.shape[1]
+ classes = np.array(sorted(set(range(C)) - set(ignores)))
+        hist = np.bincount(
+            (gts * C).astype('int64') + preds.argmax(1),
+            minlength=C**2).reshape([C, C]).astype('float32')
+ diag = np.diag(hist)
+ recalls = diag / hist.sum(1).clip(min=1)
+ precisions = diag / hist.sum(0).clip(min=1)
+ f1 = 2 * recalls * precisions / (recalls + precisions).clip(min=1e-8)
+ return f1[classes]
+
+ def combine_results(self, results):
+ node = np.concatenate(self.node, 0)
+ gts = np.concatenate(self.gt, 0)
+ results = self.compute_f1_score(node, gts)
+ data = {'hmean': results.mean()}
+ return data
+
+    def get_metric(self):
+        metrics = self.combine_results(self.results)
+        self.reset()
+        return metrics
+
+ def reset(self):
+ self.results = [] # clear results
+ self.node = []
+ self.gt = []
diff --git a/ppocr/metrics/rec_metric.py b/ppocr/metrics/rec_metric.py
index db2f41c3a140ecebc42b71ee03f0ecb5cf50ca80..b0ccd974f24f1c7e0c9a8e1d414373021c4288e6 100644
--- a/ppocr/metrics/rec_metric.py
+++ b/ppocr/metrics/rec_metric.py
@@ -48,7 +48,7 @@ class RecMetric(object):
self.norm_edit_dis += norm_edit_dis
return {
'acc': correct_num / all_num,
- 'norm_edit_dis': 1 - norm_edit_dis / all_num
+ 'norm_edit_dis': 1 - norm_edit_dis / (all_num + 1e-3)
}
def get_metric(self):
@@ -58,8 +58,8 @@ class RecMetric(object):
'norm_edit_dis': 0,
}
"""
- acc = 1.0 * self.correct_num / self.all_num
- norm_edit_dis = 1 - self.norm_edit_dis / self.all_num
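+        # the small epsilon keeps the division safe when no samples were accumulated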
+ acc = 1.0 * self.correct_num / (self.all_num + 1e-3)
+ norm_edit_dis = 1 - self.norm_edit_dis / (self.all_num + 1e-3)
self.reset()
return {'acc': acc, 'norm_edit_dis': norm_edit_dis}
diff --git a/ppocr/modeling/backbones/__init__.py b/ppocr/modeling/backbones/__init__.py
index 66b507fd24158ddf64d68dd7392f828a2e17c399..d10983487bedb0fc4278095db08d1f234ef5c595 100755
--- a/ppocr/modeling/backbones/__init__.py
+++ b/ppocr/modeling/backbones/__init__.py
@@ -35,7 +35,14 @@ def build_backbone(config, model_type):
]
elif model_type == "e2e":
from .e2e_resnet_vd_pg import ResNet
- support_dict = ["ResNet"]
+ support_dict = ['ResNet']
+ elif model_type == 'kie':
+ from .kie_unet_sdmgr import Kie_backbone
+ support_dict = ['Kie_backbone']
+ elif model_type == "table":
+ from .table_resnet_vd import ResNet
+ from .table_mobilenet_v3 import MobileNetV3
+ support_dict = ["ResNet", "MobileNetV3"]
else:
raise NotImplementedError
diff --git a/ppocr/modeling/backbones/kie_unet_sdmgr.py b/ppocr/modeling/backbones/kie_unet_sdmgr.py
new file mode 100644
index 0000000000000000000000000000000000000000..545e4e7511e58c3d8220e9ec0be35474deba8806
--- /dev/null
+++ b/ppocr/modeling/backbones/kie_unet_sdmgr.py
@@ -0,0 +1,186 @@
+# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import paddle
+from paddle import nn
+import numpy as np
+import cv2
+
+__all__ = ["Kie_backbone"]
+
+
+class Encoder(nn.Layer):
+ def __init__(self, num_channels, num_filters):
+ super(Encoder, self).__init__()
+ self.conv1 = nn.Conv2D(
+ num_channels,
+ num_filters,
+ kernel_size=3,
+ stride=1,
+ padding=1,
+ bias_attr=False)
+ self.bn1 = nn.BatchNorm(num_filters, act='relu')
+
+ self.conv2 = nn.Conv2D(
+ num_filters,
+ num_filters,
+ kernel_size=3,
+ stride=1,
+ padding=1,
+ bias_attr=False)
+ self.bn2 = nn.BatchNorm(num_filters, act='relu')
+
+ self.pool = nn.MaxPool2D(kernel_size=3, stride=2, padding=1)
+
+ def forward(self, inputs):
+ x = self.conv1(inputs)
+ x = self.bn1(x)
+ x = self.conv2(x)
+ x = self.bn2(x)
+ x_pooled = self.pool(x)
+ return x, x_pooled
+
+
+class Decoder(nn.Layer):
+ def __init__(self, num_channels, num_filters):
+ super(Decoder, self).__init__()
+
+ self.conv1 = nn.Conv2D(
+ num_channels,
+ num_filters,
+ kernel_size=3,
+ stride=1,
+ padding=1,
+ bias_attr=False)
+ self.bn1 = nn.BatchNorm(num_filters, act='relu')
+
+ self.conv2 = nn.Conv2D(
+ num_filters,
+ num_filters,
+ kernel_size=3,
+ stride=1,
+ padding=1,
+ bias_attr=False)
+ self.bn2 = nn.BatchNorm(num_filters, act='relu')
+
+ self.conv0 = nn.Conv2D(
+ num_channels,
+ num_filters,
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ bias_attr=False)
+ self.bn0 = nn.BatchNorm(num_filters, act='relu')
+
+ def forward(self, inputs_prev, inputs):
+ x = self.conv0(inputs)
+ x = self.bn0(x)
+ x = paddle.nn.functional.interpolate(
+ x, scale_factor=2, mode='bilinear', align_corners=False)
+ x = paddle.concat([inputs_prev, x], axis=1)
+ x = self.conv1(x)
+ x = self.bn1(x)
+ x = self.conv2(x)
+ x = self.bn2(x)
+ return x
+
+
+class UNet(nn.Layer):
+ def __init__(self):
+ super(UNet, self).__init__()
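+        # five encoder stages (3 -> 16 -> 32 -> 64 -> 128 -> 256 channels) and four
+        # decoder stages fusing the skip connections back down to 16 channels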
+ self.down1 = Encoder(num_channels=3, num_filters=16)
+ self.down2 = Encoder(num_channels=16, num_filters=32)
+ self.down3 = Encoder(num_channels=32, num_filters=64)
+ self.down4 = Encoder(num_channels=64, num_filters=128)
+ self.down5 = Encoder(num_channels=128, num_filters=256)
+
+ self.up1 = Decoder(32, 16)
+ self.up2 = Decoder(64, 32)
+ self.up3 = Decoder(128, 64)
+ self.up4 = Decoder(256, 128)
+ self.out_channels = 16
+
+ def forward(self, inputs):
+ x1, _ = self.down1(inputs)
+ _, x2 = self.down2(x1)
+ _, x3 = self.down3(x2)
+ _, x4 = self.down4(x3)
+ _, x5 = self.down5(x4)
+
+ x = self.up4(x4, x5)
+ x = self.up3(x3, x)
+ x = self.up2(x2, x)
+ x = self.up1(x1, x)
+ return x
+
+
+class Kie_backbone(nn.Layer):
+ def __init__(self, in_channels, **kwargs):
+ super(Kie_backbone, self).__init__()
+ self.out_channels = 16
+ self.img_feat = UNet()
+ self.maxpool = nn.MaxPool2D(kernel_size=7)
+
+ def bbox2roi(self, bbox_list):
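+        # flatten per-image boxes into one (sum of N_i, 4) tensor plus per-image counts for roi_align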
+ rois_list = []
+ rois_num = []
+ for img_id, bboxes in enumerate(bbox_list):
+ rois_num.append(bboxes.shape[0])
+ rois_list.append(bboxes)
+ rois = paddle.concat(rois_list, 0)
+ rois_num = paddle.to_tensor(rois_num, dtype='int32')
+ return rois, rois_num
+
+ def pre_process(self, img, relations, texts, gt_bboxes, tag, img_size):
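+        # strip the batch padding: keep only the first num boxes/characters of each sample, as recorded in tag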
+        img = img.numpy()
+        relations = relations.numpy()
+        texts = texts.numpy()
+        gt_bboxes = gt_bboxes.numpy()
+        tag = tag.numpy().tolist()
+        img_size = img_size.numpy()
+ temp_relations, temp_texts, temp_gt_bboxes = [], [], []
+ h, w = int(np.max(img_size[:, 0])), int(np.max(img_size[:, 1]))
+ img = paddle.to_tensor(img[:, :, :h, :w])
+ batch = len(tag)
+ for i in range(batch):
+ num, recoder_len = tag[i][0], tag[i][1]
+ temp_relations.append(
+ paddle.to_tensor(
+ relations[i, :num, :num, :], dtype='float32'))
+ temp_texts.append(
+ paddle.to_tensor(
+ texts[i, :num, :recoder_len], dtype='float32'))
+ temp_gt_bboxes.append(
+ paddle.to_tensor(
+ gt_bboxes[i, :num, ...], dtype='float32'))
+ return img, temp_relations, temp_texts, temp_gt_bboxes
+
+ def forward(self, inputs):
+ img = inputs[0]
+        relations, texts, gt_bboxes, tag, img_size = \
+            inputs[1], inputs[2], inputs[3], inputs[5], inputs[-1]
+ img, relations, texts, gt_bboxes = self.pre_process(
+ img, relations, texts, gt_bboxes, tag, img_size)
+ x = self.img_feat(img)
+ boxes, rois_num = self.bbox2roi(gt_bboxes)
+ feats = paddle.fluid.layers.roi_align(
+ x,
+ boxes,
+ spatial_scale=1.0,
+ pooled_height=7,
+ pooled_width=7,
+ rois_num=rois_num)
+ feats = self.maxpool(feats).squeeze(-1).squeeze(-1)
+ return [relations, texts, feats]
diff --git a/ppocr/modeling/heads/__init__.py b/ppocr/modeling/heads/__init__.py
index fdadfed5e3fe30b6bd311a07d6ba36869f175488..4a27ce52a64da5a53d524f58d7613669171d5662 100755
--- a/ppocr/modeling/heads/__init__.py
+++ b/ppocr/modeling/heads/__init__.py
@@ -33,14 +33,19 @@ def build_head(config):
# cls head
from .cls_head import ClsHead
+
+    # kie head
+ from .kie_sdmgr_head import SDMGRHead
+
+ from .table_att_head import TableAttentionHead
+
support_dict = [
'DBHead', 'PSEHead', 'EASTHead', 'SASTHead', 'CTCHead', 'ClsHead',
'AttentionHead', 'SRNHead', 'PGHead', 'Transformer',
- 'TableAttentionHead', 'SARHead', 'AsterHead'
+ 'TableAttentionHead', 'SARHead', 'AsterHead', 'SDMGRHead'
]
#table head
- from .table_att_head import TableAttentionHead
module_name = config.pop('name')
assert module_name in support_dict, Exception('head only support {}'.format(
diff --git a/ppocr/modeling/heads/kie_sdmgr_head.py b/ppocr/modeling/heads/kie_sdmgr_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..46ac0ed8dcaccb7628ef87fbe851a2b6acd60d55
--- /dev/null
+++ b/ppocr/modeling/heads/kie_sdmgr_head.py
@@ -0,0 +1,206 @@
+# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import math
+import paddle
+from paddle import nn
+import paddle.nn.functional as F
+from paddle import ParamAttr
+
+
+class SDMGRHead(nn.Layer):
+ def __init__(self,
+ in_channels,
+ num_chars=92,
+ visual_dim=16,
+ fusion_dim=1024,
+ node_input=32,
+ node_embed=256,
+ edge_input=5,
+ edge_embed=256,
+ num_gnn=2,
+ num_classes=26,
+ bidirectional=False):
+ super().__init__()
+
+ self.fusion = Block([visual_dim, node_embed], node_embed, fusion_dim)
+ self.node_embed = nn.Embedding(num_chars, node_input, 0)
+ hidden = node_embed // 2 if bidirectional else node_embed
+ self.rnn = nn.LSTM(
+ input_size=node_input, hidden_size=hidden, num_layers=1)
+ self.edge_embed = nn.Linear(edge_input, edge_embed)
+ self.gnn_layers = nn.LayerList(
+ [GNNLayer(node_embed, edge_embed) for _ in range(num_gnn)])
+ self.node_cls = nn.Linear(node_embed, num_classes)
+ self.edge_cls = nn.Linear(edge_embed, 2)
+
+ def forward(self, input, targets):
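+        # input comes from Kie_backbone: pairwise relation features, padded text indices
+        # and RoI visual features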
+ relations, texts, x = input
+ node_nums, char_nums = [], []
+ for text in texts:
+ node_nums.append(text.shape[0])
+ char_nums.append(paddle.sum((text > -1).astype(int), axis=-1))
+
+ max_num = max([char_num.max() for char_num in char_nums])
+ all_nodes = paddle.concat([
+ paddle.concat(
+ [text, paddle.zeros(
+ (text.shape[0], max_num - text.shape[1]))], -1)
+ for text in texts
+ ])
+ temp = paddle.clip(all_nodes, min=0).astype(int)
+ embed_nodes = self.node_embed(temp)
+ rnn_nodes, _ = self.rnn(embed_nodes)
+
+ b, h, w = rnn_nodes.shape
+ nodes = paddle.zeros([b, w])
+ all_nums = paddle.concat(char_nums)
+ valid = paddle.nonzero((all_nums > 0).astype(int))
+ temp_all_nums = (
+ paddle.gather(all_nums, valid) - 1).unsqueeze(-1).unsqueeze(-1)
+ temp_all_nums = paddle.expand(temp_all_nums, [
+ temp_all_nums.shape[0], temp_all_nums.shape[1], rnn_nodes.shape[-1]
+ ])
+ temp_all_nodes = paddle.gather(rnn_nodes, valid)
+ N, C, A = temp_all_nodes.shape
+ one_hot = F.one_hot(
+ temp_all_nums[:, 0, :], num_classes=C).transpose([0, 2, 1])
+ one_hot = paddle.multiply(
+ temp_all_nodes, one_hot.astype("float32")).sum(axis=1, keepdim=True)
+ t = one_hot.expand([N, 1, A]).squeeze(1)
+ nodes = paddle.scatter(nodes, valid.squeeze(1), t)
+
+ if x is not None:
+ nodes = self.fusion([x, nodes])
+
+ all_edges = paddle.concat(
+ [rel.reshape([-1, rel.shape[-1]]) for rel in relations])
+ embed_edges = self.edge_embed(all_edges.astype('float32'))
+ embed_edges = F.normalize(embed_edges)
+
+ for gnn_layer in self.gnn_layers:
+ nodes, cat_nodes = gnn_layer(nodes, embed_edges, node_nums)
+
+ node_cls, edge_cls = self.node_cls(nodes), self.edge_cls(cat_nodes)
+ return node_cls, edge_cls
+
+
+class GNNLayer(nn.Layer):
+ def __init__(self, node_dim=256, edge_dim=256):
+ super().__init__()
+ self.in_fc = nn.Linear(node_dim * 2 + edge_dim, node_dim)
+ self.coef_fc = nn.Linear(node_dim, 1)
+ self.out_fc = nn.Linear(node_dim, node_dim)
+ self.relu = nn.ReLU()
+
+ def forward(self, nodes, edges, nums):
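+        # concatenate every node pair (plus its edge embedding) per sample, score the pairs,
+        # and aggregate neighbours by softmax attention; -1e9 on the diagonal masks self-loops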
+ start, cat_nodes = 0, []
+ for num in nums:
+ sample_nodes = nodes[start:start + num]
+ cat_nodes.append(
+ paddle.concat([
+ paddle.expand(sample_nodes.unsqueeze(1), [-1, num, -1]),
+ paddle.expand(sample_nodes.unsqueeze(0), [num, -1, -1])
+ ], -1).reshape([num**2, -1]))
+ start += num
+ cat_nodes = paddle.concat([paddle.concat(cat_nodes), edges], -1)
+ cat_nodes = self.relu(self.in_fc(cat_nodes))
+ coefs = self.coef_fc(cat_nodes)
+
+ start, residuals = 0, []
+ for num in nums:
+ residual = F.softmax(
+ -paddle.eye(num).unsqueeze(-1) * 1e9 +
+ coefs[start:start + num**2].reshape([num, num, -1]), 1)
+ residuals.append((residual * cat_nodes[start:start + num**2]
+ .reshape([num, num, -1])).sum(1))
+ start += num**2
+
+ nodes += self.relu(self.out_fc(paddle.concat(residuals)))
+ return [nodes, cat_nodes]
+
+
+class Block(nn.Layer):
+ def __init__(self,
+ input_dims,
+ output_dim,
+ mm_dim=1600,
+ chunks=20,
+ rank=15,
+ shared=False,
+ dropout_input=0.,
+ dropout_pre_lin=0.,
+ dropout_output=0.,
+ pos_norm='before_cat'):
+ super().__init__()
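+        # chunked low-rank bilinear fusion of two inputs (visual and textual node features here)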
+ self.rank = rank
+ self.dropout_input = dropout_input
+ self.dropout_pre_lin = dropout_pre_lin
+ self.dropout_output = dropout_output
+ assert (pos_norm in ['before_cat', 'after_cat'])
+ self.pos_norm = pos_norm
+ # Modules
+ self.linear0 = nn.Linear(input_dims[0], mm_dim)
+ self.linear1 = (self.linear0
+ if shared else nn.Linear(input_dims[1], mm_dim))
+ self.merge_linears0 = nn.LayerList()
+ self.merge_linears1 = nn.LayerList()
+ self.chunks = self.chunk_sizes(mm_dim, chunks)
+ for size in self.chunks:
+ ml0 = nn.Linear(size, size * rank)
+ self.merge_linears0.append(ml0)
+ ml1 = ml0 if shared else nn.Linear(size, size * rank)
+ self.merge_linears1.append(ml1)
+ self.linear_out = nn.Linear(mm_dim, output_dim)
+
+ def forward(self, x):
+ x0 = self.linear0(x[0])
+ x1 = self.linear1(x[1])
+ bs = x1.shape[0]
+ if self.dropout_input > 0:
+ x0 = F.dropout(x0, p=self.dropout_input, training=self.training)
+ x1 = F.dropout(x1, p=self.dropout_input, training=self.training)
+ x0_chunks = paddle.split(x0, self.chunks, -1)
+ x1_chunks = paddle.split(x1, self.chunks, -1)
+ zs = []
+ for x0_c, x1_c, m0, m1 in zip(x0_chunks, x1_chunks, self.merge_linears0,
+ self.merge_linears1):
+ m = m0(x0_c) * m1(x1_c) # bs x split_size*rank
+ m = m.reshape([bs, self.rank, -1])
+ z = paddle.sum(m, 1)
+ if self.pos_norm == 'before_cat':
+ z = paddle.sqrt(F.relu(z)) - paddle.sqrt(F.relu(-z))
+ z = F.normalize(z)
+ zs.append(z)
+ z = paddle.concat(zs, 1)
+ if self.pos_norm == 'after_cat':
+ z = paddle.sqrt(F.relu(z)) - paddle.sqrt(F.relu(-z))
+ z = F.normalize(z)
+
+ if self.dropout_pre_lin > 0:
+ z = F.dropout(z, p=self.dropout_pre_lin, training=self.training)
+ z = self.linear_out(z)
+ if self.dropout_output > 0:
+ z = F.dropout(z, p=self.dropout_output, training=self.training)
+ return z
+
+ def chunk_sizes(self, dim, chunks):
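+        # split `dim` into `chunks` nearly equal parts; the last chunk
+        # absorbs the rounding remainder so the sizes sum exactly to dim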
+ split_size = (dim + chunks - 1) // chunks
+ sizes_list = [split_size] * chunks
+ sizes_list[-1] = sizes_list[-1] - (sum(sizes_list) - dim)
+ return sizes_list
diff --git a/ppocr/modeling/heads/rec_aster_head.py b/ppocr/modeling/heads/rec_aster_head.py
index 9240f002d3a8bcbde517142be6b45559430de610..c95e8fd31f84c26cf58f7fbbdaab6c825b10eea8 100644
--- a/ppocr/modeling/heads/rec_aster_head.py
+++ b/ppocr/modeling/heads/rec_aster_head.py
@@ -47,7 +47,7 @@ class AsterHead(nn.Layer):
self.time_step = time_step
self.embeder = Embedding(self.time_step, in_channels)
self.beam_width = beam_width
- self.eos = self.num_classes - 1
+ self.eos = self.num_classes - 3
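+        # three special tokens ([eos, padding, unknown]) are appended to the
+        # dict (see SEEDLabelDecode), so eos is the third token from the end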
def forward(self, x, targets=None, embed=None):
return_dict = {}
diff --git a/ppocr/postprocess/__init__.py b/ppocr/postprocess/__init__.py
index c6cb0144f7efd9ff7976ad67a658a554eafce754..37dadd12d3f628b1802b6a31f611f49f3ac600c2 100644
--- a/ppocr/postprocess/__init__.py
+++ b/ppocr/postprocess/__init__.py
@@ -45,6 +45,8 @@ def build_post_process(config, global_config=None):
config = copy.deepcopy(config)
module_name = config.pop('name')
+ if module_name == "None":
+ return
if global_config is not None:
config.update(global_config)
assert module_name in support_dict, Exception(
diff --git a/ppocr/postprocess/rec_postprocess.py b/ppocr/postprocess/rec_postprocess.py
index ef1a43fd0ee65f3e55a8f72dfd2f96c478da1a9a..caaa2948522cb6ea7ed74b8ab79a3d0b465059a3 100644
--- a/ppocr/postprocess/rec_postprocess.py
+++ b/ppocr/postprocess/rec_postprocess.py
@@ -287,9 +287,12 @@ class SEEDLabelDecode(BaseRecLabelDecode):
use_space_char)
def add_special_char(self, dict_character):
- self.beg_str = "sos"
+ self.padding_str = "padding"
self.end_str = "eos"
- dict_character = dict_character + [self.end_str]
+ self.unknown = "unknown"
+ dict_character = dict_character + [
+ self.end_str, self.padding_str, self.unknown
+ ]
return dict_character
def get_ignored_tokens(self):
diff --git a/ppstructure/README.md b/ppstructure/README.md
index 849c5c5667ff0532dfee35479715880192df0dc5..a09a43299b11dccf99897d5a6c69704191253aaf 100644
--- a/ppstructure/README.md
+++ b/ppstructure/README.md
@@ -153,13 +153,12 @@ wget https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_rec_in
wget https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_table_structure_infer.tar && tar xf en_ppocr_mobile_v2.0_table_structure_infer.tar
cd ..
-python3 predict_system.py --det_model_dir=inference/ch_ppocr_mobile_v2.0_det_infer --rec_model_dir=inference/ch_ppocr_mobile_v2.0_rec_infer --table_model_dir=inference/en_ppocr_mobile_v2.0_table_structure_infer --image_dir=../doc/table/1.png --rec_char_dict_path=../ppocr/utils/ppocr_keys_v1.txt --table_char_dict_path=../ppocr/utils/dict/table_structure_dict.txt --rec_char_type=ch --output=../output/table --vis_font_path=../doc/fonts/simfang.ttf
+python3 predict_system.py --det_model_dir=inference/ch_ppocr_mobile_v2.0_det_infer --rec_model_dir=inference/ch_ppocr_mobile_v2.0_rec_infer --table_model_dir=inference/en_ppocr_mobile_v2.0_table_structure_infer --image_dir=../doc/table/1.png --rec_char_dict_path=../ppocr/utils/ppocr_keys_v1.txt --table_char_dict_path=../ppocr/utils/dict/table_structure_dict.txt --output=../output/table --vis_font_path=../doc/fonts/simfang.ttf
```
After running, each image will have a directory with the same name under the directory specified by the `output` field. Each table in the image is stored as an Excel file and each figure area is cropped and saved; the Excel and image file names are the coordinates of the table in the image.
**Model List**
-
|model name|description|config|model size|download|
| --- | --- | --- | --- | --- |
|en_ppocr_mobile_v2.0_table_structure|Table structure prediction for English table scenarios|[table_mv3.yml](../configs/table/table_mv3.yml)|18.6M|[inference model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_table_structure_infer.tar) |
@@ -184,4 +183,5 @@ OCR and table recognition model
|en_ppocr_mobile_v2.0_table_rec|Text recognition of English table scene trained on PubLayNet dataset|6.9M|[inference model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_table_rec_infer.tar) [trained model](https://paddleocr.bj.bcebos.com/dygraph_v2.1/table/en_ppocr_mobile_v2.0_table_rec_train.tar) |
|en_ppocr_mobile_v2.0_table_structure|Table structure prediction of English table scene trained on PubLayNet dataset|18.6M|[inference model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_table_structure_infer.tar) / [trained model](https://paddleocr.bj.bcebos.com/dygraph_v2.1/table/en_ppocr_mobile_v2.0_table_structure_train.tar) |
+
If you need to use other models, you can download the model in [model_list](../doc/doc_en/models_list_en.md) or use your own trained model to configure it to the three fields of `det_model_dir`, `rec_model_dir`, `table_model_dir` .
diff --git a/ppstructure/README_ch.md b/ppstructure/README_ch.md
index 821a6c3e36361abefa4d754537fdbd694e844efe..607efac1bf6bfaa58f0e96ceef1a0ee344189e9c 100644
--- a/ppstructure/README_ch.md
+++ b/ppstructure/README_ch.md
@@ -1,6 +1,12 @@
[English](README.md) | 简体中文
-# PP-Structure
+## 简介
+PP-Structure是一个可用于复杂文档结构分析和处理的OCR工具包,旨在帮助开发者更好地完成文档理解相关任务。
+
+## 近期更新
+* 2021.12.07 新增VQA任务-SER和RE。
+
+## 特性
PP-Structure是一个可用于复杂文档结构分析和处理的OCR工具包,主要特性如下:
- 支持对图片形式的文档进行版面分析,可以划分**文字、标题、表格、图片以及列表**5类区域(与Layout-Parser联合使用)
@@ -8,181 +14,88 @@ PP-Structure是一个可用于复杂文档结构分析和处理的OCR工具包
- 支持表格区域进行结构化分析,最终结果输出Excel文件
- 支持python whl包和命令行两种方式,简单易用
- 支持版面分析和表格结构化两类任务自定义训练
+- 支持文档视觉问答(Document Visual Question Answering,DOC-VQA)任务-语义实体识别(Semantic Entity Recognition,SER)和关系抽取(Relation Extraction,RE)
-## 1. 效果展示
-
-
-
-
-
-## 2. 安装
-
-### 2.1 安装依赖
-
-- **(1) 安装PaddlePaddle**
-
-```bash
-pip3 install --upgrade pip
-
-# GPU安装
-python3 -m pip install paddlepaddle-gpu==2.1.1 -i https://mirror.baidu.com/pypi/simple
-
-# CPU安装
- python3 -m pip install paddlepaddle==2.1.1 -i https://mirror.baidu.com/pypi/simple
-
-```
-更多需求,请参照[安装文档](https://www.paddlepaddle.org.cn/install/quick)中的说明进行操作。
-
-- **(2) 安装 Layout-Parser**
-
-```bash
-pip3 install -U https://paddleocr.bj.bcebos.com/whl/layoutparser-0.0.0-py3-none-any.whl
-```
-
-### 2.2 安装PaddleOCR(包含PP-OCR和PP-Structure)
-
-- **(1) PIP快速安装PaddleOCR whl包(仅预测)**
-```bash
-pip install "paddleocr>=2.2" # 推荐使用2.2+版本
-```
-
-- **(2) 完整克隆PaddleOCR源码(预测+训练)**
-
-```bash
-【推荐】git clone https://github.com/PaddlePaddle/PaddleOCR
-
-#如果因为网络问题无法pull成功,也可选择使用码云上的托管:
-git clone https://gitee.com/paddlepaddle/PaddleOCR
-
-#注:码云托管代码可能无法实时同步本github项目更新,存在3~5天延时,请优先使用推荐方式。
-```
-
-
-## 3. PP-Structure 快速开始
-
-### 3.1 命令行使用(默认参数,极简)
-
-```bash
-paddleocr --image_dir=../doc/table/1.png --type=structure
-```
-
-### 3.2 Python脚本使用(自定义参数,灵活)
+## 1. 效果展示
-```python
-import os
-import cv2
-from paddleocr import PPStructure,draw_structure_result,save_structure_res
+### 1.1 版面分析和表格识别
-table_engine = PPStructure(show_log=True)
+
-save_folder = './output/table'
-img_path = '../doc/table/1.png'
-img = cv2.imread(img_path)
-result = table_engine(img)
-save_structure_res(result, save_folder,os.path.basename(img_path).split('.')[0])
+### 1.2 VQA
-for line in result:
- line.pop('img')
- print(line)
+* SER
-from PIL import Image
+(SER可视化效果图,此处从略)
-font_path = '../doc/fonts/simfang.ttf' # PaddleOCR下提供字体包
-image = Image.open(img_path).convert('RGB')
-im_show = draw_structure_result(image, result,font_path=font_path)
-im_show = Image.fromarray(im_show)
-im_show.save('result.jpg')
-```
+图中不同颜色的框表示不同的类别,对于XFUN数据集,有`QUESTION`, `ANSWER`, `HEADER` 3种类别
-### 3.3 返回结果说明
-PP-Structure的返回结果为一个dict组成的list,示例如下
+* 深紫色:HEADER
+* 浅紫色:QUESTION
+* 军绿色:ANSWER
-```shell
-[
- { 'type': 'Text',
- 'bbox': [34, 432, 345, 462],
- 'res': ([[36.0, 437.0, 341.0, 437.0, 341.0, 446.0, 36.0, 447.0], [41.0, 454.0, 125.0, 453.0, 125.0, 459.0, 41.0, 460.0]],
- [('Tigure-6. The performance of CNN and IPT models using difforen', 0.90060663), ('Tent ', 0.465441)])
- }
-]
-```
-dict 里各个字段说明如下
+在OCR检测框的左上方也标出了对应的类别和OCR识别结果。
-| 字段 | 说明 |
-| --------------- | -------------|
-|type|图片区域的类型|
-|bbox|图片区域的在原图的坐标,分别[左上角x,左上角y,右下角x,右下角y]|
-|res|图片区域的OCR或表格识别结果。<br> 表格: 表格的HTML字符串; <br> OCR: 一个包含各个单行文字的检测坐标和识别结果的元组|
+* RE
+(RE可视化效果图,此处从略)
-### 3.4 参数说明
-| 字段 | 说明 | 默认值 |
-| --------------- | ---------------------------------------- | ------------------------------------------- |
-| output | excel和识别结果保存的地址 | ./output/table |
-| table_max_len | 表格结构模型预测时,图像的长边resize尺度 | 488 |
-| table_model_dir | 表格结构模型 inference 模型地址 | None |
-| table_char_type | 表格结构模型所用字典地址 | ../ppocr/utils/dict/table_structure_dict.tx |
+图中红色框表示问题,蓝色框表示答案,问题和答案之间使用绿色线连接。在OCR检测框的左上方也标出了对应的类别和OCR识别结果。
-大部分参数和paddleocr whl包保持一致,见 [whl包文档](../doc/doc_ch/whl.md)
+## 2. 快速体验
-运行完成后,每张图片会在`output`字段指定的目录下有一个同名目录,图片里的每个表格会存储为一个excel,图片区域会被裁剪之后保存下来,excel文件和图片名名为表格在图片里的坐标。
+代码体验:从 [快速开始](./docs/quickstart.md) 开始
+## 3. PP-Structure Pipeline介绍
-## 4. PP-Structure Pipeline介绍
+### 3.1 版面分析+表格识别

在PP-Structure中,图片会先经由Layout-Parser进行版面分析,在版面分析中,会对图片里的区域进行分类,包括**文字、标题、图片、列表和表格**5类。对于前4类区域,直接使用PP-OCR完成对应区域文字检测与识别。对于表格类区域,经过表格结构化处理后,表格图片转换为相同表格样式的Excel文件。
-### 4.1 版面分析
+#### 3.1.1 版面分析
版面分析对文档数据进行区域分类,其中包括版面分析工具的Python脚本使用、提取指定类别检测框、性能指标以及自定义训练版面分析模型,详细内容可以参考[文档](layout/README_ch.md)。
-### 4.2 表格识别
+#### 3.1.2 表格识别
表格识别将表格图片转换为excel文档,其中包含对于表格文本的检测和识别以及对于表格结构和单元格坐标的预测,详细说明参考[文档](table/README_ch.md)
-## 5. 预测引擎推理(与whl包效果相同)
-使用如下命令即可完成预测引擎的推理
+### 3.2 VQA
-```python
-cd ppstructure
+coming soon
-# 下载模型
-mkdir inference && cd inference
-# 下载超轻量级中文OCR模型的检测模型并解压
-wget https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_det_infer.tar && tar xf ch_ppocr_mobile_v2.0_det_infer.tar
-# 下载超轻量级中文OCR模型的识别模型并解压
-wget https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_rec_infer.tar && tar xf ch_ppocr_mobile_v2.0_rec_infer.tar
-# 下载超轻量级英文表格英寸模型并解压
-wget https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_table_structure_infer.tar && tar xf en_ppocr_mobile_v2.0_table_structure_infer.tar
-cd ..
+## 4. 模型库
-python3 predict_system.py --det_model_dir=inference/ch_ppocr_mobile_v2.0_det_infer --rec_model_dir=inference/ch_ppocr_mobile_v2.0_rec_infer --table_model_dir=inference/en_ppocr_mobile_v2.0_table_structure_infer --image_dir=../doc/table/1.png --rec_char_dict_path=../ppocr/utils/ppocr_keys_v1.txt --table_char_dict_path=../ppocr/utils/dict/table_structure_dict.txt --rec_char_type=ch --output=../output/table --vis_font_path=../doc/fonts/simfang.ttf
-```
-运行完成后,每张图片会在`output`字段指定的目录下有一个同名目录,图片里的每个表格会存储为一个excel,图片区域会被裁剪之后保存下来,excel文件和图片名名为表格在图片里的坐标。
+PP-Structure系列模型列表(更新中)
-**Model List**
-
-LayoutParser 模型
+* LayoutParser 模型
|模型名称|模型简介|下载地址|
| --- | --- | --- |
| ppyolov2_r50vd_dcn_365e_publaynet | PubLayNet 数据集训练的版面分析模型,可以划分**文字、标题、表格、图片以及列表**5类区域 | [PubLayNet](https://paddle-model-ecology.bj.bcebos.com/model/layout-parser/ppyolov2_r50vd_dcn_365e_publaynet.tar) |
-| ppyolov2_r50vd_dcn_365e_tableBank_word | TableBank Word 数据集训练的版面分析模型,只能检测表格 | [TableBank Word](https://paddle-model-ecology.bj.bcebos.com/model/layout-parser/ppyolov2_r50vd_dcn_365e_tableBank_word.tar) |
-| ppyolov2_r50vd_dcn_365e_tableBank_latex | TableBank Latex 数据集训练的版面分析模型,只能检测表格 | [TableBank Latex](https://paddle-model-ecology.bj.bcebos.com/model/layout-parser/ppyolov2_r50vd_dcn_365e_tableBank_latex.tar) |
-OCR和表格识别模型
-|模型名称|模型简介|推理模型大小|下载地址|
+* OCR和表格识别模型
+
+|模型名称|模型简介|模型大小|下载地址|
| --- | --- | --- | --- |
|ch_ppocr_mobile_slim_v2.0_det|slim裁剪版超轻量模型,支持中英文、多语种文本检测|2.6M|[推理模型](https://paddleocr.bj.bcebos.com/dygraph_v2.0/slim/ch_ppocr_mobile_v2.0_det_prune_infer.tar) / [训练模型](https://paddleocr.bj.bcebos.com/dygraph_v2.0/slim/ch_ppocr_mobile_v2.0_det_prune_infer.tar) |
|ch_ppocr_mobile_slim_v2.0_rec|slim裁剪量化版超轻量模型,支持中英文、数字识别|6M|[推理模型](https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_rec_slim_infer.tar) / [训练模型](https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_rec_slim_train.tar) |
-|en_ppocr_mobile_v2.0_table_det|PubLayNet数据集训练的英文表格场景的文字检测|4.7M|[推理模型](https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_table_det_infer.tar) / [训练模型](https://paddleocr.bj.bcebos.com/dygraph_v2.1/table/en_ppocr_mobile_v2.0_table_det_train.tar) |
-|en_ppocr_mobile_v2.0_table_rec|PubLayNet数据集训练的英文表格场景的文字识别|6.9M|[推理模型](https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_table_rec_infer.tar) / [训练模型](https://paddleocr.bj.bcebos.com/dygraph_v2.1/table/en_ppocr_mobile_v2.0_table_rec_train.tar) |
|en_ppocr_mobile_v2.0_table_structure|PubLayNet数据集训练的英文表格场景的表格结构预测|18.6M|[推理模型](https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_table_structure_infer.tar) / [训练模型](https://paddleocr.bj.bcebos.com/dygraph_v2.1/table/en_ppocr_mobile_v2.0_table_structure_train.tar) |
-如需要使用其他模型,可以在 [model_list](../doc/doc_ch/models_list.md) 下载模型或者使用自己训练好的模型配置到`det_model_dir`,`rec_model_dir`,`table_model_dir`三个字段即可。
+* VQA模型
+
+|模型名称|模型简介|模型大小|下载地址|
+| --- | --- | --- | --- |
+|PP-Layout_v1.0_ser_pretrained|基于LayoutXLM在xfun中文数据集上训练的SER模型|1.4G|[推理模型 coming soon]() / [训练模型](https://paddleocr.bj.bcebos.com/pplayout/PP-Layout_v1.0_ser_pretrained.tar) |
+|PP-Layout_v1.0_re_pretrained|基于LayoutXLM在xfun中文数据集上训练的RE模型|1.4G|[推理模型 coming soon]() / [训练模型](https://paddleocr.bj.bcebos.com/pplayout/PP-Layout_v1.0_re_pretrained.tar) |
+
+
+更多模型下载,可以参考 [模型库](./docs/model_list.md)
diff --git a/ppstructure/docs/imgs/0.png b/ppstructure/docs/imgs/0.png
new file mode 100644
index 0000000000000000000000000000000000000000..6fa4fe8be7589dcc70fa93e6ad0c12f641eccdc4
Binary files /dev/null and b/ppstructure/docs/imgs/0.png differ
diff --git a/ppstructure/docs/installation.md b/ppstructure/docs/installation.md
new file mode 100644
index 0000000000000000000000000000000000000000..30c25d5dc92f6ccdb0d93dafe9707f30eca0c0a9
--- /dev/null
+++ b/ppstructure/docs/installation.md
@@ -0,0 +1,28 @@
+# 快速安装
+
+## 1. PaddlePaddle 和 PaddleOCR
+
+可参考[PaddleOCR安装文档](../../doc/doc_ch/installation.md)
+
+## 2. 安装其他依赖
+
+### 2.1 版面分析所需 Layout-Parser
+
+Layout-Parser 可通过如下命令安装
+
+```bash
+pip3 install -U https://paddleocr.bj.bcebos.com/whl/layoutparser-0.0.0-py3-none-any.whl
+```
+### 2.2 VQA所需依赖
+* paddleocr
+
+```bash
+pip3 install paddleocr
+```
+
+* PaddleNLP
+```bash
+git clone https://github.com/PaddlePaddle/PaddleNLP -b develop
+cd PaddleNLP
+pip3 install -e .
+```
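+
+安装完成后,可以参考如下命令简单验证 PaddleNLP 是否安装成功(仅为示意):
+
+```bash
+python3 -c "import paddlenlp; print(paddlenlp.__version__)"
+```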
diff --git a/ppstructure/docs/kie.md b/ppstructure/docs/kie.md
new file mode 100644
index 0000000000000000000000000000000000000000..21854b0d24b0b2bbe6a4612b1112b201c5df255d
--- /dev/null
+++ b/ppstructure/docs/kie.md
@@ -0,0 +1,74 @@
+
+
+# 关键信息提取(Key Information Extraction)
+
+本节介绍PaddleOCR中关键信息提取SDMGR方法的快速使用和训练方法。
+
+SDMGR是一个关键信息提取算法,将每个检测到的文本区域分类为预定义的类别,如订单ID、发票号码、金额等。
+
+
+* [1. 快速使用](#1-----)
+* [2. 执行训练](#2-----)
+* [3. 执行评估](#3-----)
+
+
+## 1. 快速使用
+
+训练和测试的数据采用wildreceipt数据集,通过如下指令下载数据集:
+
+```
+wget https://paddleocr.bj.bcebos.com/dygraph_v2.1/kie/wildreceipt.tar && tar xf wildreceipt.tar
+```
+
+执行预测:
+
+```
+cd PaddleOCR/
+wget https://paddleocr.bj.bcebos.com/dygraph_v2.1/kie/kie_vgg16.tar && tar xf kie_vgg16.tar
+python3.7 tools/infer_kie.py -c configs/kie/kie_unet_sdmgr.yml -o Global.checkpoints=kie_vgg16/best_accuracy Global.infer_img=../wildreceipt/1.txt
+```
+
+执行预测后的结果保存在`./output/sdmgr_kie/predicts_kie.txt`文件中,可视化结果保存在`./output/sdmgr_kie/kie_results/`目录下。
+
+可视化结果如下图所示:
+
+(SDMGR可视化结果示例图,此处从略)
+
+## 2. 执行训练
+
+创建数据集软链到PaddleOCR/train_data目录下:
+```
+cd PaddleOCR/ && mkdir train_data && cd train_data
+
+ln -s ../../wildreceipt ./
+```
+
+训练采用的配置文件是configs/kie/kie_unet_sdmgr.yml,配置文件中默认训练数据路径是`train_data/wildreceipt`,准备好数据后,可以通过如下指令执行训练:
+```
+python3.7 tools/train.py -c configs/kie/kie_unet_sdmgr.yml -o Global.save_model_dir=./output/kie/
+```
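+
+如果数据没有放在默认路径,也可以参考如下方式在启动训练时通过`-o`覆盖配置中的数据路径(仅为示意,字段名以`configs/kie/kie_unet_sdmgr.yml`中的实际配置为准):
+```
+python3.7 tools/train.py -c configs/kie/kie_unet_sdmgr.yml \
+    -o Global.save_model_dir=./output/kie/ \
+       Train.dataset.data_dir=./train_data/wildreceipt
+```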
+
+## 3. 执行评估
+
+```
+python3.7 tools/eval.py -c configs/kie/kie_unet_sdmgr.yml -o Global.checkpoints=./output/kie/best_accuracy
+```
+
+
+**参考文献:**
+
+
+
+```bibtex
+@misc{sun2021spatial,
+ title={Spatial Dual-Modality Graph Reasoning for Key Information Extraction},
+ author={Hongbin Sun and Zhanghui Kuang and Xiaoyu Yue and Chenhao Lin and Wayne Zhang},
+ year={2021},
+ eprint={2103.14470},
+ archivePrefix={arXiv},
+ primaryClass={cs.CV}
+}
+```
diff --git a/ppstructure/docs/model_list.md b/ppstructure/docs/model_list.md
new file mode 100644
index 0000000000000000000000000000000000000000..45004490c1c4b0ea01a5fb409024f1eeb922f1a3
--- /dev/null
+++ b/ppstructure/docs/model_list.md
@@ -0,0 +1,34 @@
+# Model List
+
+## 1. LayoutParser 模型
+
+|模型名称|模型简介|下载地址|
+| --- | --- | --- |
+| ppyolov2_r50vd_dcn_365e_publaynet | PubLayNet 数据集训练的版面分析模型,可以划分**文字、标题、表格、图片以及列表**5类区域 | [PubLayNet](https://paddle-model-ecology.bj.bcebos.com/model/layout-parser/ppyolov2_r50vd_dcn_365e_publaynet.tar) |
+| ppyolov2_r50vd_dcn_365e_tableBank_word | TableBank Word 数据集训练的版面分析模型,只能检测表格 | [TableBank Word](https://paddle-model-ecology.bj.bcebos.com/model/layout-parser/ppyolov2_r50vd_dcn_365e_tableBank_word.tar) |
+| ppyolov2_r50vd_dcn_365e_tableBank_latex | TableBank Latex 数据集训练的版面分析模型,只能检测表格 | [TableBank Latex](https://paddle-model-ecology.bj.bcebos.com/model/layout-parser/ppyolov2_r50vd_dcn_365e_tableBank_latex.tar) |
+
+## 2. OCR和表格识别模型
+
+|模型名称|模型简介|推理模型大小|下载地址|
+| --- | --- | --- | --- |
+|ch_ppocr_mobile_slim_v2.0_det|slim裁剪版超轻量模型,支持中英文、多语种文本检测|2.6M|[推理模型](https://paddleocr.bj.bcebos.com/dygraph_v2.0/slim/ch_ppocr_mobile_v2.0_det_prune_infer.tar) / [训练模型](https://paddleocr.bj.bcebos.com/dygraph_v2.0/slim/ch_ppocr_mobile_v2.0_det_prune_infer.tar) |
+|ch_ppocr_mobile_slim_v2.0_rec|slim裁剪量化版超轻量模型,支持中英文、数字识别|6M|[推理模型](https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_rec_slim_infer.tar) / [训练模型](https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_rec_slim_train.tar) |
+|en_ppocr_mobile_v2.0_table_det|PubLayNet数据集训练的英文表格场景的文字检测|4.7M|[推理模型](https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_table_det_infer.tar) / [训练模型](https://paddleocr.bj.bcebos.com/dygraph_v2.1/table/en_ppocr_mobile_v2.0_table_det_train.tar) |
+|en_ppocr_mobile_v2.0_table_rec|PubLayNet数据集训练的英文表格场景的文字识别|6.9M|[推理模型](https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_table_rec_infer.tar) / [训练模型](https://paddleocr.bj.bcebos.com/dygraph_v2.1/table/en_ppocr_mobile_v2.0_table_rec_train.tar) |
+|en_ppocr_mobile_v2.0_table_structure|PubLayNet数据集训练的英文表格场景的表格结构预测|18.6M|[推理模型](https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_table_structure_infer.tar) / [训练模型](https://paddleocr.bj.bcebos.com/dygraph_v2.1/table/en_ppocr_mobile_v2.0_table_structure_train.tar) |
+
+如需要使用其他OCR模型,可以在 [model_list](../../doc/doc_ch/models_list.md) 下载模型或者使用自己训练好的模型配置到`det_model_dir`,`rec_model_dir`两个字段即可。
+
+## 3. VQA模型
+
+|模型名称|模型简介|推理模型大小|下载地址|
+| --- | --- | --- | --- |
+|PP-Layout_v1.0_ser_pretrained|基于LayoutXLM在xfun中文数据集上训练的SER模型|1.4G|[推理模型 coming soon]() / [训练模型](https://paddleocr.bj.bcebos.com/pplayout/PP-Layout_v1.0_ser_pretrained.tar) |
+|PP-Layout_v1.0_re_pretrained|基于LayoutXLM在xfun中文数据集上训练的RE模型|1.4G|[推理模型 coming soon]() / [训练模型](https://paddleocr.bj.bcebos.com/pplayout/PP-Layout_v1.0_re_pretrained.tar) |
+
+## 4. KIE模型
+
+|模型名称|模型简介|模型大小|下载地址|
+| --- | --- | --- | --- |
+|SDMGR|关键信息提取模型|-|[推理模型 coming soon]() / [训练模型](https://paddleocr.bj.bcebos.com/dygraph_v2.1/kie/kie_vgg16.tar)|
diff --git a/ppstructure/docs/quickstart.md b/ppstructure/docs/quickstart.md
new file mode 100644
index 0000000000000000000000000000000000000000..446c577ec39cf24dd4b8699558c633a1308fa444
--- /dev/null
+++ b/ppstructure/docs/quickstart.md
@@ -0,0 +1,171 @@
+# PP-Structure 快速开始
+
+* [1. 安装依赖包](#1)
+* [2. 便捷使用](#2)
+ + [2.1 命令行使用](#21)
+ + [2.2 Python脚本使用](#22)
+ + [2.3 返回结果说明](#23)
+ + [2.4 参数说明](#24)
+* [3. 预测引擎推理(与whl包效果相同)](#3)
+
+
+
+
+## 1. 安装依赖包
+
+```bash
+pip install "paddleocr>=2.3.0.2" # 推荐使用2.3.0.2+版本
+pip3 install -U https://paddleocr.bj.bcebos.com/whl/layoutparser-0.0.0-py3-none-any.whl
+
+# 安装 PaddleNLP
+git clone https://github.com/PaddlePaddle/PaddleNLP -b develop
+cd PaddleNLP
+pip3 install -e .
+
+```
+
+
+
+## 2. 便捷使用
+
+
+
+### 2.1 命令行使用
+
+* 版面分析+表格识别
+```bash
+paddleocr --image_dir=../doc/table/1.png --type=structure
+```
+
+* VQA
+
+coming soon
+
+
+
+### 2.2 Python脚本使用
+
+* 版面分析+表格识别
+```python
+import os
+import cv2
+from paddleocr import PPStructure,draw_structure_result,save_structure_res
+
+table_engine = PPStructure(show_log=True)
+
+save_folder = './output/table'
+img_path = '../doc/table/1.png'
+img = cv2.imread(img_path)
+result = table_engine(img)
+save_structure_res(result, save_folder,os.path.basename(img_path).split('.')[0])
+
+for line in result:
+ line.pop('img')
+ print(line)
+
+from PIL import Image
+
+font_path = '../doc/fonts/simfang.ttf' # PaddleOCR下提供字体包
+image = Image.open(img_path).convert('RGB')
+im_show = draw_structure_result(image, result,font_path=font_path)
+im_show = Image.fromarray(im_show)
+im_show.save('result.jpg')
+```
+
+* VQA
+
+coming soon
+
+
+
+### 2.3 返回结果说明
+PP-Structure的返回结果为一个dict组成的list,示例如下
+
+* 版面分析+表格识别
+```shell
+[
+ { 'type': 'Text',
+ 'bbox': [34, 432, 345, 462],
+ 'res': ([[36.0, 437.0, 341.0, 437.0, 341.0, 446.0, 36.0, 447.0], [41.0, 454.0, 125.0, 453.0, 125.0, 459.0, 41.0, 460.0]],
+ [('Tigure-6. The performance of CNN and IPT models using difforen', 0.90060663), ('Tent ', 0.465441)])
+ }
+]
+```
+dict 里各个字段说明如下
+
+| 字段 | 说明 |
+| --------------- | -------------|
+|type|图片区域的类型|
+|bbox|图片区域的在原图的坐标,分别[左上角x,左上角y,右下角x,右下角y]|
+|res|图片区域的OCR或表格识别结果。<br> 表格: 表格的HTML字符串; <br> OCR: 一个包含各个单行文字的检测坐标和识别结果的元组|
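+
+结合上表字段,可以参考如下方式遍历返回结果(仅为示意,假设`result`为2.2节脚本中得到的返回值):
+
+```python
+# 按区域类型分别处理PP-Structure返回的结果
+for region in result:
+    if region['type'] == 'Table':
+        # 表格区域:res为表格的HTML字符串
+        print(region['bbox'], region['res'][:100])
+    else:
+        # 其他区域:res为(检测坐标列表, 识别结果列表)组成的元组
+        boxes, rec_res = region['res']
+        for box, (text, score) in zip(boxes, rec_res):
+            print(region['type'], text, score)
+```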
+
+* VQA
+
+coming soon
+
+
+
+### 2.4 参数说明
+
+| 字段 | 说明 | 默认值 |
+| --------------- | ---------------------------------------- | ------------------------------------------- |
+| output | excel和识别结果保存的地址 | ./output/table |
+| table_max_len | 表格结构模型预测时,图像的长边resize尺度 | 488 |
+| table_model_dir | 表格结构模型 inference 模型地址 | None |
+| table_char_dict_path | 表格结构模型所用字典地址 | ../ppocr/utils/dict/table_structure_dict.txt |
+| model_name_or_path | VQA SER模型地址 | None |
+| max_seq_length | VQA SER模型最大支持token长度 | 512 |
+| label_map_path | VQA SER 标签文件地址 | ./vqa/labels/labels_ser.txt |
+| mode | pipeline预测模式,structure: 版面分析+表格识别; vqa: ser文档信息抽取 | structure |
+
+大部分参数和paddleocr whl包保持一致,见 [whl包文档](../../doc/doc_ch/whl.md)
+
+运行完成后,每张图片会在`output`字段指定的目录下有一个同名目录,图片里的每个表格会存储为一个excel,图片区域会被裁剪之后保存下来,excel文件和图片的文件名为表格在图片里的坐标。
+
+
+
+## 3. 预测引擎推理(与whl包效果相同)
+
+* 版面分析+表格识别
+
+```bash
+cd ppstructure
+
+# 下载模型
+mkdir inference && cd inference
+# 下载超轻量级中文OCR模型的检测模型并解压
+wget https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_det_infer.tar && tar xf ch_ppocr_mobile_v2.0_det_infer.tar
+# 下载超轻量级中文OCR模型的识别模型并解压
+wget https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_rec_infer.tar && tar xf ch_ppocr_mobile_v2.0_rec_infer.tar
+# 下载超轻量级英文表格结构模型并解压
+wget https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_table_structure_infer.tar && tar xf en_ppocr_mobile_v2.0_table_structure_infer.tar
+cd ..
+
+python3 predict_system.py --det_model_dir=inference/ch_ppocr_mobile_v2.0_det_infer \
+ --rec_model_dir=inference/ch_ppocr_mobile_v2.0_rec_infer \
+ --table_model_dir=inference/en_ppocr_mobile_v2.0_table_structure_infer \
+ --image_dir=../doc/table/1.png \
+ --rec_char_dict_path=../ppocr/utils/ppocr_keys_v1.txt \
+ --table_char_dict_path=../ppocr/utils/dict/table_structure_dict.txt \
+ --output=../output/table \
+ --vis_font_path=../doc/fonts/simfang.ttf
+```
+运行完成后,每张图片会在`output`字段指定的目录下的`structure`目录下有一个同名目录,图片里的每个表格会存储为一个excel,图片区域会被裁剪之后保存下来,excel文件和图片的文件名为表格在图片里的坐标。
+
+* VQA
+
+```bash
+cd ppstructure
+
+# 下载模型
+mkdir inference && cd inference
+# 下载SER xfun 模型并解压
+wget https://paddleocr.bj.bcebos.com/pplayout/PP-Layout_v1.0_ser_pretrained.tar && tar xf PP-Layout_v1.0_ser_pretrained.tar
+cd ..
+
+python3 predict_system.py --model_name_or_path=vqa/PP-Layout_v1.0_ser_pretrained/ \
+ --mode=vqa \
+ --image_dir=vqa/images/input/zh_val_0.jpg \
+ --vis_font_path=../doc/fonts/simfang.ttf
+```
+运行完成后,每张图片会在`output`字段指定的目录下的`vqa`目录下存放可视化之后的图片,图片名和输入图片名一致。
diff --git a/ppstructure/predict_system.py b/ppstructure/predict_system.py
index b2de3d4de80b39f046cf6cbc8a9ebbc52bf69334..e87499ccc410ae67a170f63301e5a99ef948b161 100644
--- a/ppstructure/predict_system.py
+++ b/ppstructure/predict_system.py
@@ -30,6 +30,7 @@ from ppocr.utils.utility import get_image_file_list, check_and_read_gif
from ppocr.utils.logging import get_logger
from tools.infer.predict_system import TextSystem
from ppstructure.table.predict_table import TableSystem, to_excel
+from ppstructure.vqa.infer_ser_e2e import SerPredictor, draw_ser_results
from ppstructure.utility import parse_args, draw_structure_result
logger = get_logger()
@@ -37,53 +38,75 @@ logger = get_logger()
class OCRSystem(object):
def __init__(self, args):
- import layoutparser as lp
- # args.det_limit_type = 'resize_long'
- args.drop_score = 0
- if not args.show_log:
- logger.setLevel(logging.INFO)
- self.text_system = TextSystem(args)
- self.table_system = TableSystem(args, self.text_system.text_detector, self.text_system.text_recognizer)
-
- config_path = None
- model_path = None
- if os.path.isdir(args.layout_path_model):
- model_path = args.layout_path_model
- else:
- config_path = args.layout_path_model
- self.table_layout = lp.PaddleDetectionLayoutModel(config_path=config_path,
- model_path=model_path,
- threshold=0.5, enable_mkldnn=args.enable_mkldnn,
- enforce_cpu=not args.use_gpu, thread_num=args.cpu_threads)
- self.use_angle_cls = args.use_angle_cls
- self.drop_score = args.drop_score
+ self.mode = args.mode
+ if self.mode == 'structure':
+ import layoutparser as lp
+ # args.det_limit_type = 'resize_long'
+ args.drop_score = 0
+ if not args.show_log:
+ logger.setLevel(logging.INFO)
+ self.text_system = TextSystem(args)
+ self.table_system = TableSystem(args,
+ self.text_system.text_detector,
+ self.text_system.text_recognizer)
+
+ config_path = None
+ model_path = None
+ if os.path.isdir(args.layout_path_model):
+ model_path = args.layout_path_model
+ else:
+ config_path = args.layout_path_model
+ self.table_layout = lp.PaddleDetectionLayoutModel(
+ config_path=config_path,
+ model_path=model_path,
+ threshold=0.5,
+ enable_mkldnn=args.enable_mkldnn,
+ enforce_cpu=not args.use_gpu,
+ thread_num=args.cpu_threads)
+ self.use_angle_cls = args.use_angle_cls
+ self.drop_score = args.drop_score
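+        # 'vqa' mode only loads the SER predictor; layout and table models
+        # are not needed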
+ elif self.mode == 'vqa':
+ self.vqa_engine = SerPredictor(args)
def __call__(self, img):
- ori_im = img.copy()
- layout_res = self.table_layout.detect(img[..., ::-1])
- res_list = []
- for region in layout_res:
- x1, y1, x2, y2 = region.coordinates
- x1, y1, x2, y2 = int(x1), int(y1), int(x2), int(y2)
- roi_img = ori_im[y1:y2, x1:x2, :]
- if region.type == 'Table':
- res = self.table_system(roi_img)
- else:
- filter_boxes, filter_rec_res = self.text_system(roi_img)
- filter_boxes = [x + [x1, y1] for x in filter_boxes]
- filter_boxes = [x.reshape(-1).tolist() for x in filter_boxes]
- # remove style char
-                style_token = ['<strike>', '<strike>', '<sup>', '</sub>', '<b>', '</b>', '<sub>', '</sup>',
-                               '<overline>', '</overline>', '<underline>', '</underline>', '<i>', '</i>']
- filter_rec_res_tmp = []
- for rec_res in filter_rec_res:
- rec_str, rec_conf = rec_res
- for token in style_token:
- if token in rec_str:
- rec_str = rec_str.replace(token, '')
- filter_rec_res_tmp.append((rec_str, rec_conf))
- res = (filter_boxes, filter_rec_res_tmp)
- res_list.append({'type': region.type, 'bbox': [x1, y1, x2, y2], 'img': roi_img, 'res': res})
+ if self.mode == 'structure':
+ ori_im = img.copy()
+ layout_res = self.table_layout.detect(img[..., ::-1])
+ res_list = []
+ for region in layout_res:
+ x1, y1, x2, y2 = region.coordinates
+ x1, y1, x2, y2 = int(x1), int(y1), int(x2), int(y2)
+ roi_img = ori_im[y1:y2, x1:x2, :]
+ if region.type == 'Table':
+ res = self.table_system(roi_img)
+ else:
+ filter_boxes, filter_rec_res = self.text_system(roi_img)
+ filter_boxes = [x + [x1, y1] for x in filter_boxes]
+ filter_boxes = [
+ x.reshape(-1).tolist() for x in filter_boxes
+ ]
+ # remove style char
+                    style_token = [
+                        '<strike>', '<strike>', '<sup>', '</sub>', '<b>',
+                        '</b>', '<sub>', '</sup>', '<overline>', '</overline>',
+                        '<underline>', '</underline>', '<i>', '</i>'
+                    ]
+ filter_rec_res_tmp = []
+ for rec_res in filter_rec_res:
+ rec_str, rec_conf = rec_res
+ for token in style_token:
+ if token in rec_str:
+ rec_str = rec_str.replace(token, '')
+ filter_rec_res_tmp.append((rec_str, rec_conf))
+ res = (filter_boxes, filter_rec_res_tmp)
+ res_list.append({
+ 'type': region.type,
+ 'bbox': [x1, y1, x2, y2],
+ 'img': roi_img,
+ 'res': res
+ })
+ elif self.mode == 'vqa':
+ res_list, _ = self.vqa_engine(img)
return res_list
@@ -91,29 +114,35 @@ def save_structure_res(res, save_folder, img_name):
excel_save_folder = os.path.join(save_folder, img_name)
os.makedirs(excel_save_folder, exist_ok=True)
# save res
- with open(os.path.join(excel_save_folder, 'res.txt'), 'w', encoding='utf8') as f:
+ with open(
+ os.path.join(excel_save_folder, 'res.txt'), 'w',
+ encoding='utf8') as f:
for region in res:
if region['type'] == 'Table':
- excel_path = os.path.join(excel_save_folder, '{}.xlsx'.format(region['bbox']))
+ excel_path = os.path.join(excel_save_folder,
+ '{}.xlsx'.format(region['bbox']))
to_excel(region['res'], excel_path)
if region['type'] == 'Figure':
roi_img = region['img']
- img_path = os.path.join(excel_save_folder, '{}.jpg'.format(region['bbox']))
+ img_path = os.path.join(excel_save_folder,
+ '{}.jpg'.format(region['bbox']))
cv2.imwrite(img_path, roi_img)
else:
for box, rec_res in zip(region['res'][0], region['res'][1]):
- f.write('{}\t{}\n'.format(np.array(box).reshape(-1).tolist(), rec_res))
+ f.write('{}\t{}\n'.format(
+ np.array(box).reshape(-1).tolist(), rec_res))
def main(args):
image_file_list = get_image_file_list(args.image_dir)
image_file_list = image_file_list
image_file_list = image_file_list[args.process_id::args.total_process_num]
- save_folder = args.output
- os.makedirs(save_folder, exist_ok=True)
structure_sys = OCRSystem(args)
img_num = len(image_file_list)
+ save_folder = os.path.join(args.output, structure_sys.mode)
+ os.makedirs(save_folder, exist_ok=True)
+
for i, image_file in enumerate(image_file_list):
logger.info("[{}/{}] {}".format(i, img_num, image_file))
img, flag = check_and_read_gif(image_file)
@@ -126,10 +155,16 @@ def main(args):
continue
starttime = time.time()
res = structure_sys(img)
- save_structure_res(res, save_folder, img_name)
- draw_img = draw_structure_result(img, res, args.vis_font_path)
- cv2.imwrite(os.path.join(save_folder, img_name, 'show.jpg'), draw_img)
- logger.info('result save to {}'.format(os.path.join(save_folder, img_name)))
+
+ if structure_sys.mode == 'structure':
+ save_structure_res(res, save_folder, img_name)
+ draw_img = draw_structure_result(img, res, args.vis_font_path)
+ img_save_path = os.path.join(save_folder, img_name, 'show.jpg')
+ elif structure_sys.mode == 'vqa':
+ draw_img = draw_ser_results(img, res, args.vis_font_path)
+ img_save_path = os.path.join(save_folder, img_name + '.jpg')
+ cv2.imwrite(img_save_path, draw_img)
+ logger.info('result save to {}'.format(img_save_path))
elapse = time.time() - starttime
logger.info("Predict time : {:.3f}s".format(elapse))
diff --git a/ppstructure/table/README.md b/ppstructure/table/README.md
index 67c4d8e26d5c615f4a930752005420ba1abcc834..30a11a20e5de90500d1408f671ba914f336a0b43 100644
--- a/ppstructure/table/README.md
+++ b/ppstructure/table/README.md
@@ -20,9 +20,9 @@ We evaluated the algorithm on the PubTabNet[1] eval dataset, and the
|Method|[TEDS(Tree-Edit-Distance-based Similarity)](https://github.com/ibm-aur-nlp/PubTabNet/tree/master/src)|
-| --- | --- |
-| EDD[2] | 88.3 |
-| Ours | 93.32 |
+| --- | --- |
+| EDD[2] | 88.3 |
+| Ours | 93.32 |
## 3. How to use
@@ -41,7 +41,7 @@ wget https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_tab
wget https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_table_structure_infer.tar && tar xf en_ppocr_mobile_v2.0_table_structure_infer.tar
cd ..
# run
-python3 table/predict_table.py --det_model_dir=inference/en_ppocr_mobile_v2.0_table_det_infer --rec_model_dir=inference/en_ppocr_mobile_v2.0_table_rec_infer --table_model_dir=inference/en_ppocr_mobile_v2.0_table_structure_infer --image_dir=../doc/table/table.jpg --rec_char_dict_path=../ppocr/utils/dict/table_dict.txt --table_char_dict_path=../ppocr/utils/dict/table_structure_dict.txt --rec_char_type=EN --det_limit_side_len=736 --det_limit_type=min --output ../output/table
+python3 table/predict_table.py --det_model_dir=inference/en_ppocr_mobile_v2.0_table_det_infer --rec_model_dir=inference/en_ppocr_mobile_v2.0_table_rec_infer --table_model_dir=inference/en_ppocr_mobile_v2.0_table_structure_infer --image_dir=../doc/table/table.jpg --rec_char_dict_path=../ppocr/utils/dict/table_dict.txt --table_char_dict_path=../ppocr/utils/dict/table_structure_dict.txt --det_limit_side_len=736 --det_limit_type=min --output ../output/table
```
Note: The above model is trained on the PubLayNet dataset and only supports English scanning scenarios. If you need to identify other scenarios, you need to train the model yourself and replace the three fields `det_model_dir`, `rec_model_dir`, `table_model_dir`.
@@ -82,8 +82,8 @@ python3 tools/train.py -c configs/table/table_mv3.yml -o Global.checkpoints=./yo
The table uses [TEDS(Tree-Edit-Distance-based Similarity)](https://github.com/ibm-aur-nlp/PubTabNet/tree/master/src) as the evaluation metric of the model. Before the model evaluation, the three models in the pipeline need to be exported as inference models (we have provided them), and the gt for evaluation needs to be prepared. Examples of gt are as follows:
```json
{"PMC4289340_004_00.png": [
- ["", "", "", "", "", "", " | ", "", " | ", "", " | ", "
", "", "", "", "", " | ", "", " | ", "", " | ", "
", "", "
", "", ""],
- [[1, 4, 29, 13], [137, 4, 161, 13], [215, 4, 236, 13], [1, 17, 30, 27], [137, 17, 147, 27], [215, 17, 225, 27]],
+ ["", "", "", "", "", "", " | ", "", " | ", "", " | ", "
", "", "", "", "", " | ", "", " | ", "", " | ", "
", "", "
", "", ""],
+ [[1, 4, 29, 13], [137, 4, 161, 13], [215, 4, 236, 13], [1, 17, 30, 27], [137, 17, 147, 27], [215, 17, 225, 27]],
[["", "F", "e", "a", "t", "u", "r", "e", ""], ["", "G", "b", "3", " ", "+", ""], ["", "G", "b", "3", " ", "-", ""], ["", "P", "a", "t", "i", "e", "n", "t", "s", ""], ["6", "2"], ["4", "5"]]
]}
```
@@ -95,7 +95,7 @@ In gt json, the key is the image name, the value is the corresponding gt, and gt
Use the following command to evaluate. After the evaluation is completed, the teds indicator will be output.
```python
cd PaddleOCR/ppstructure
-python3 table/eval_table.py --det_model_dir=path/to/det_model_dir --rec_model_dir=path/to/rec_model_dir --table_model_dir=path/to/table_model_dir --image_dir=../doc/table/1.png --rec_char_dict_path=../ppocr/utils/dict/table_dict.txt --table_char_dict_path=../ppocr/utils/dict/table_structure_dict.txt --rec_char_type=EN --det_limit_side_len=736 --det_limit_type=min --gt_path=path/to/gt.json
+python3 table/eval_table.py --det_model_dir=path/to/det_model_dir --rec_model_dir=path/to/rec_model_dir --table_model_dir=path/to/table_model_dir --image_dir=../doc/table/1.png --rec_char_dict_path=../ppocr/utils/dict/table_dict.txt --table_char_dict_path=../ppocr/utils/dict/table_structure_dict.txt --det_limit_side_len=736 --det_limit_type=min --gt_path=path/to/gt.json
```
If the PubTabNet eval dataset is used, the following will be output
@@ -113,4 +113,4 @@ After running, the excel sheet of each picture will be saved in the directory sp
Reference
1. https://github.com/ibm-aur-nlp/PubTabNet
-2. https://arxiv.org/pdf/1911.10683
\ No newline at end of file
+2. https://arxiv.org/pdf/1911.10683
diff --git a/ppstructure/table/README_ch.md b/ppstructure/table/README_ch.md
index 2e90ad33423da347b5a51444f2be53ed2eb67a7a..33276b36e4973e83d7efa673b90013cf5727dfe2 100644
--- a/ppstructure/table/README_ch.md
+++ b/ppstructure/table/README_ch.md
@@ -34,9 +34,9 @@
|算法|[TEDS(Tree-Edit-Distance-based Similarity)](https://github.com/ibm-aur-nlp/PubTabNet/tree/master/src)|
-| --- | --- |
-| EDD[2] | 88.3 |
-| Ours | 93.32 |
+| --- | --- |
+| EDD[2] | 88.3 |
+| Ours | 93.32 |
## 3. 使用
@@ -56,7 +56,7 @@ wget https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_tab
wget https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_table_structure_infer.tar && tar xf en_ppocr_mobile_v2.0_table_structure_infer.tar
cd ..
# 执行预测
-python3 table/predict_table.py --det_model_dir=inference/en_ppocr_mobile_v2.0_table_det_infer --rec_model_dir=inference/en_ppocr_mobile_v2.0_table_rec_infer --table_model_dir=inference/en_ppocr_mobile_v2.0_table_structure_infer --image_dir=../doc/table/table.jpg --rec_char_dict_path=../ppocr/utils/dict/table_dict.txt --table_char_dict_path=../ppocr/utils/dict/table_structure_dict.txt --rec_char_type=EN --det_limit_side_len=736 --det_limit_type=min --output ../output/table
+python3 table/predict_table.py --det_model_dir=inference/en_ppocr_mobile_v2.0_table_det_infer --rec_model_dir=inference/en_ppocr_mobile_v2.0_table_rec_infer --table_model_dir=inference/en_ppocr_mobile_v2.0_table_structure_infer --image_dir=../doc/table/table.jpg --rec_char_dict_path=../ppocr/utils/dict/table_dict.txt --table_char_dict_path=../ppocr/utils/dict/table_structure_dict.txt --det_limit_side_len=736 --det_limit_type=min --output ../output/table
```
运行完成后,每张图片的excel表格会保存到output字段指定的目录下
@@ -94,8 +94,8 @@ python3 tools/train.py -c configs/table/table_mv3.yml -o Global.checkpoints=./yo
表格使用 [TEDS(Tree-Edit-Distance-based Similarity)](https://github.com/ibm-aur-nlp/PubTabNet/tree/master/src) 作为模型的评估指标。在进行模型评估之前,需要将pipeline中的三个模型分别导出为inference模型(我们已经提供好),还需要准备评估的gt, gt示例如下:
```json
{"PMC4289340_004_00.png": [
- ["", "", "", "", "", "", " | ", "", " | ", "", " | ", "
", "", "", "", "", " | ", "", " | ", "", " | ", "
", "", "
", "", ""],
- [[1, 4, 29, 13], [137, 4, 161, 13], [215, 4, 236, 13], [1, 17, 30, 27], [137, 17, 147, 27], [215, 17, 225, 27]],
+ ["", "", "", "", "", "", " | ", "", " | ", "", " | ", "
", "", "", "", "", " | ", "", " | ", "", " | ", "
", "", "
", "", ""],
+ [[1, 4, 29, 13], [137, 4, 161, 13], [215, 4, 236, 13], [1, 17, 30, 27], [137, 17, 147, 27], [215, 17, 225, 27]],
[["", "F", "e", "a", "t", "u", "r", "e", ""], ["", "G", "b", "3", " ", "+", ""], ["", "G", "b", "3", " ", "-", ""], ["", "P", "a", "t", "i", "e", "n", "t", "s", ""], ["6", "2"], ["4", "5"]]
]}
```
@@ -107,7 +107,7 @@ json 中,key为图片名,value为对应的gt,gt是一个由三个item组
准备完成后使用如下命令进行评估,评估完成后会输出teds指标。
```python
cd PaddleOCR/ppstructure
-python3 table/eval_table.py --det_model_dir=path/to/det_model_dir --rec_model_dir=path/to/rec_model_dir --table_model_dir=path/to/table_model_dir --image_dir=../doc/table/1.png --rec_char_dict_path=../ppocr/utils/dict/table_dict.txt --table_char_dict_path=../ppocr/utils/dict/table_structure_dict.txt --rec_char_type=EN --det_limit_side_len=736 --det_limit_type=min --gt_path=path/to/gt.json
+python3 table/eval_table.py --det_model_dir=path/to/det_model_dir --rec_model_dir=path/to/rec_model_dir --table_model_dir=path/to/table_model_dir --image_dir=../doc/table/1.png --rec_char_dict_path=../ppocr/utils/dict/table_dict.txt --table_char_dict_path=../ppocr/utils/dict/table_structure_dict.txt --det_limit_side_len=736 --det_limit_type=min --gt_path=path/to/gt.json
```
如使用PubTabNet评估数据集,将会输出
```bash
@@ -123,4 +123,4 @@ python3 table/predict_table.py --det_model_dir=path/to/det_model_dir --rec_model
Reference
1. https://github.com/ibm-aur-nlp/PubTabNet
-2. https://arxiv.org/pdf/1911.10683
\ No newline at end of file
+2. https://arxiv.org/pdf/1911.10683
diff --git a/ppstructure/utility.py b/ppstructure/utility.py
index 7d9fa76d0ada58e363243c114519d001de3fbf2a..ce7a801b1bb4094d3f4d2ba467332c6763ad6287 100644
--- a/ppstructure/utility.py
+++ b/ppstructure/utility.py
@@ -21,13 +21,31 @@ def init_args():
parser = infer_args()
# params for output
- parser.add_argument("--output", type=str, default='./output/table')
+ parser.add_argument("--output", type=str, default='./output')
# params for table structure
parser.add_argument("--table_max_len", type=int, default=488)
parser.add_argument("--table_model_dir", type=str)
parser.add_argument("--table_char_type", type=str, default='en')
- parser.add_argument("--table_char_dict_path", type=str, default="../ppocr/utils/dict/table_structure_dict.txt")
- parser.add_argument("--layout_path_model", type=str, default="lp://PubLayNet/ppyolov2_r50vd_dcn_365e_publaynet/config")
+ parser.add_argument(
+ "--table_char_dict_path",
+ type=str,
+ default="../ppocr/utils/dict/table_structure_dict.txt")
+ parser.add_argument(
+ "--layout_path_model",
+ type=str,
+ default="lp://PubLayNet/ppyolov2_r50vd_dcn_365e_publaynet/config")
+
+ # params for ser
+ parser.add_argument("--model_name_or_path", type=str)
+ parser.add_argument("--max_seq_length", type=int, default=512)
+ parser.add_argument(
+ "--label_map_path", type=str, default='./vqa/labels/labels_ser.txt')
+
+ parser.add_argument(
+ "--mode",
+ type=str,
+ default='structure',
+        help='structure and vqa are supported')
return parser
@@ -48,5 +66,6 @@ def draw_structure_result(image, result, font_path):
boxes.append(np.array(box).reshape(-1, 2))
txts.append(rec_res[0])
scores.append(rec_res[1])
- im_show = draw_ocr_box_txt(image, boxes, txts, scores, font_path=font_path,drop_score=0)
- return im_show
\ No newline at end of file
+ im_show = draw_ocr_box_txt(
+ image, boxes, txts, scores, font_path=font_path, drop_score=0)
+ return im_show
diff --git a/ppstructure/vqa/README.md b/ppstructure/vqa/README.md
index 23fe28f8494ce84e774c3dd21811003f772c41f8..708f0ea4cfa9c38d96abff0768932eba65259c78 100644
--- a/ppstructure/vqa/README.md
+++ b/ppstructure/vqa/README.md
@@ -18,12 +18,13 @@ PP-Structure 里的 DOC-VQA算法基于PaddleNLP自然语言处理算法库进
## 1 性能
-我们在 [XFUN](https://github.com/doc-analysis/XFUND) 评估数据集上对算法进行了评估,性能如下
+我们在 [XFUN](https://github.com/doc-analysis/XFUND) 的中文数据集上对算法进行了评估,性能如下
-|任务| f1 | 模型下载地址|
-|:---:|:---:| :---:|
-|SER|0.9056| [链接](https://paddleocr.bj.bcebos.com/pplayout/PP-Layout_v1.0_ser_pretrained.tar)|
-|RE|0.7113| [链接](https://paddleocr.bj.bcebos.com/pplayout/PP-Layout_v1.0_re_pretrained.tar)|
+| 模型 | 任务 | f1 | 模型下载地址 |
+|:---:|:---:|:---:| :---:|
+| LayoutXLM | RE | 0.7113 | [链接](https://paddleocr.bj.bcebos.com/pplayout/PP-Layout_v1.0_re_pretrained.tar) |
+| LayoutXLM | SER | 0.9056 | [链接](https://paddleocr.bj.bcebos.com/pplayout/PP-Layout_v1.0_ser_pretrained.tar) |
+| LayoutLM | SER | 0.78 | [链接](https://paddleocr.bj.bcebos.com/pplayout/LayoutLM_ser_pretrained.tar) |
@@ -98,7 +99,7 @@ git clone https://gitee.com/paddlepaddle/PaddleOCR
# 需要使用PaddleNLP最新的代码版本进行安装
git clone https://github.com/PaddlePaddle/PaddleNLP -b develop
cd PaddleNLP
-pip install -e .
+pip3 install -e .
```
@@ -135,13 +136,13 @@ wget https://paddleocr.bj.bcebos.com/dataset/XFUND.tar
```shell
python3.7 train_ser.py \
--model_name_or_path "layoutxlm-base-uncased" \
+ --ser_model_type "LayoutXLM" \
--train_data_dir "XFUND/zh_train/image" \
--train_label_path "XFUND/zh_train/xfun_normalize_train.json" \
--eval_data_dir "XFUND/zh_val/image" \
--eval_label_path "XFUND/zh_val/xfun_normalize_val.json" \
--num_train_epochs 200 \
--eval_steps 10 \
- --save_steps 500 \
--output_dir "./output/ser/" \
--learning_rate 5e-5 \
--warmup_steps 50 \
@@ -151,13 +152,50 @@ python3.7 train_ser.py \
最终会打印出`precision`, `recall`, `f1`等指标,模型和训练日志会保存在`./output/ser/`文件夹中。
+* 恢复训练
+
+```shell
+python3.7 train_ser.py \
+ --model_name_or_path "model_path" \
+ --ser_model_type "LayoutXLM" \
+ --train_data_dir "XFUND/zh_train/image" \
+ --train_label_path "XFUND/zh_train/xfun_normalize_train.json" \
+ --eval_data_dir "XFUND/zh_val/image" \
+ --eval_label_path "XFUND/zh_val/xfun_normalize_val.json" \
+ --num_train_epochs 200 \
+ --eval_steps 10 \
+ --output_dir "./output/ser/" \
+ --learning_rate 5e-5 \
+ --warmup_steps 50 \
+ --evaluate_during_training \
+ --num_workers 8 \
+ --seed 2048 \
+ --resume
+```
+
+* 评估
+```shell
+export CUDA_VISIBLE_DEVICES=0
+python3 eval_ser.py \
+ --model_name_or_path "PP-Layout_v1.0_ser_pretrained/" \
+ --ser_model_type "LayoutXLM" \
+ --eval_data_dir "XFUND/zh_val/image" \
+ --eval_label_path "XFUND/zh_val/xfun_normalize_val.json" \
+ --per_gpu_eval_batch_size 8 \
+ --num_workers 8 \
+ --output_dir "output/ser/" \
+ --seed 2048
+```
+最终会打印出`precision`, `recall`, `f1`等指标
+
* 使用评估集合中提供的OCR识别结果进行预测
```shell
export CUDA_VISIBLE_DEVICES=0
python3.7 infer_ser.py \
- --model_name_or_path "./PP-Layout_v1.0_ser_pretrained/" \
- --output_dir "output_res/" \
+ --model_name_or_path "PP-Layout_v1.0_ser_pretrained/" \
+ --ser_model_type "LayoutXLM" \
+ --output_dir "output/ser/" \
--infer_imgs "XFUND/zh_val/image/" \
--ocr_json_path "XFUND/zh_val/xfun_normalize_val.json"
```
@@ -169,9 +207,10 @@ python3.7 infer_ser.py \
```shell
export CUDA_VISIBLE_DEVICES=0
python3.7 infer_ser_e2e.py \
- --model_name_or_path "./output/PP-Layout_v1.0_ser_pretrained/" \
+ --model_name_or_path "PP-Layout_v1.0_ser_pretrained/" \
+ --ser_model_type "LayoutXLM" \
--max_seq_length 512 \
- --output_dir "output_res_e2e/" \
+ --output_dir "output/ser_e2e/" \
--infer_imgs "images/input/zh_val_0.jpg"
```
@@ -188,6 +227,7 @@ python3.7 helper/eval_with_label_end2end.py --gt_json_path XFUND/zh_val/xfun_nor
* 启动训练
```shell
+export CUDA_VISIBLE_DEVICES=0
python3 train_re.py \
--model_name_or_path "layoutxlm-base-uncased" \
--train_data_dir "XFUND/zh_train/image" \
@@ -195,32 +235,74 @@ python3 train_re.py \
--eval_data_dir "XFUND/zh_val/image" \
--eval_label_path "XFUND/zh_val/xfun_normalize_val.json" \
--label_map_path 'labels/labels_ser.txt' \
- --num_train_epochs 2 \
+ --num_train_epochs 200 \
--eval_steps 10 \
- --save_steps 500 \
--output_dir "output/re/" \
--learning_rate 5e-5 \
--warmup_steps 50 \
--per_gpu_train_batch_size 8 \
--per_gpu_eval_batch_size 8 \
+ --num_workers 8 \
--evaluate_during_training \
--seed 2048
```
+* 恢复训练
+
+```shell
+export CUDA_VISIBLE_DEVICES=0
+python3 train_re.py \
+ --model_name_or_path "model_path" \
+ --train_data_dir "XFUND/zh_train/image" \
+ --train_label_path "XFUND/zh_train/xfun_normalize_train.json" \
+ --eval_data_dir "XFUND/zh_val/image" \
+ --eval_label_path "XFUND/zh_val/xfun_normalize_val.json" \
+ --label_map_path 'labels/labels_ser.txt' \
+ --num_train_epochs 2 \
+ --eval_steps 10 \
+ --output_dir "output/re/" \
+ --learning_rate 5e-5 \
+ --warmup_steps 50 \
+ --per_gpu_train_batch_size 8 \
+ --per_gpu_eval_batch_size 8 \
+ --num_workers 8 \
+ --evaluate_during_training \
+ --seed 2048 \
+ --resume
+
+```
+
最终会打印出`precision`, `recall`, `f1`等指标,模型和训练日志会保存在`./output/re/`文件夹中。
+* 评估
+```shell
+export CUDA_VISIBLE_DEVICES=0
+python3 eval_re.py \
+ --model_name_or_path "PP-Layout_v1.0_re_pretrained/" \
+ --max_seq_length 512 \
+ --eval_data_dir "XFUND/zh_val/image" \
+ --eval_label_path "XFUND/zh_val/xfun_normalize_val.json" \
+ --label_map_path 'labels/labels_ser.txt' \
+ --output_dir "output/re/" \
+ --per_gpu_eval_batch_size 8 \
+ --num_workers 8 \
+ --seed 2048
+```
+最终会打印出`precision`, `recall`, `f1`等指标
+
+
* 使用评估集合中提供的OCR识别结果进行预测
```shell
export CUDA_VISIBLE_DEVICES=0
python3 infer_re.py \
- --model_name_or_path "./PP-Layout_v1.0_re_pretrained/" \
+ --model_name_or_path "PP-Layout_v1.0_re_pretrained/" \
--max_seq_length 512 \
--eval_data_dir "XFUND/zh_val/image" \
--eval_label_path "XFUND/zh_val/xfun_normalize_val.json" \
--label_map_path 'labels/labels_ser.txt' \
- --output_dir "output_res" \
+ --output_dir "output/re/" \
--per_gpu_eval_batch_size 1 \
--seed 2048
```
@@ -231,11 +313,12 @@ python3 infer_re.py \
```shell
export CUDA_VISIBLE_DEVICES=0
-# python3.7 infer_ser_re_e2e.py \
- --model_name_or_path "./PP-Layout_v1.0_ser_pretrained/" \
- --re_model_name_or_path "./PP-Layout_v1.0_re_pretrained/" \
+python3.7 infer_ser_re_e2e.py \
+ --model_name_or_path "PP-Layout_v1.0_ser_pretrained/" \
+ --re_model_name_or_path "PP-Layout_v1.0_re_pretrained/" \
+ --ser_model_type "LayoutXLM" \
--max_seq_length 512 \
- --output_dir "output_ser_re_e2e_train/" \
+ --output_dir "output/ser_re_e2e/" \
--infer_imgs "images/input/zh_val_21.jpg"
```
diff --git a/ppstructure/vqa/eval_re.py b/ppstructure/vqa/eval_re.py
new file mode 100644
index 0000000000000000000000000000000000000000..12bb9cabdb8b4d6482a121ca3b73089b3d0244ff
--- /dev/null
+++ b/ppstructure/vqa/eval_re.py
@@ -0,0 +1,125 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import sys
+
+__dir__ = os.path.dirname(os.path.abspath(__file__))
+sys.path.append(__dir__)
+sys.path.append(os.path.abspath(os.path.join(__dir__, '../..')))
+
+import paddle
+
+from paddlenlp.transformers import LayoutXLMTokenizer, LayoutXLMModel, LayoutXLMForRelationExtraction
+
+from xfun import XFUNDataset
+from utils import parse_args, get_bio_label_maps, print_arguments
+from data_collator import DataCollator
+from metric import re_score
+
+from ppocr.utils.logging import get_logger
+
+
+def cal_metric(re_preds, re_labels, entities):
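+    # rebuild the gold relation list (head/tail spans and their entity types)
+    # from the batch annotations so it can be scored against the predictions
+    # by re_score in "boundaries" mode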
+ gt_relations = []
+ for b in range(len(re_labels)):
+ rel_sent = []
+ for head, tail in zip(re_labels[b]["head"], re_labels[b]["tail"]):
+ rel = {}
+ rel["head_id"] = head
+ rel["head"] = (entities[b]["start"][rel["head_id"]],
+ entities[b]["end"][rel["head_id"]])
+ rel["head_type"] = entities[b]["label"][rel["head_id"]]
+
+ rel["tail_id"] = tail
+ rel["tail"] = (entities[b]["start"][rel["tail_id"]],
+ entities[b]["end"][rel["tail_id"]])
+ rel["tail_type"] = entities[b]["label"][rel["tail_id"]]
+
+ rel["type"] = 1
+ rel_sent.append(rel)
+ gt_relations.append(rel_sent)
+ re_metrics = re_score(re_preds, gt_relations, mode="boundaries")
+ return re_metrics
+
+
+def evaluate(model, eval_dataloader, logger, prefix=""):
+ # Eval!
+ logger.info("***** Running evaluation {} *****".format(prefix))
+ logger.info(" Num examples = {}".format(len(eval_dataloader.dataset)))
+
+ re_preds = []
+ re_labels = []
+ entities = []
+ eval_loss = 0.0
+ model.eval()
+ for idx, batch in enumerate(eval_dataloader):
+ with paddle.no_grad():
+ outputs = model(**batch)
+ loss = outputs['loss'].mean().item()
+ if paddle.distributed.get_rank() == 0:
+ logger.info("[Eval] process: {}/{}, loss: {:.5f}".format(
+ idx, len(eval_dataloader), loss))
+
+ eval_loss += loss
+ re_preds.extend(outputs['pred_relations'])
+ re_labels.extend(batch['relations'])
+ entities.extend(batch['entities'])
+ re_metrics = cal_metric(re_preds, re_labels, entities)
+ re_metrics = {
+ "precision": re_metrics["ALL"]["p"],
+ "recall": re_metrics["ALL"]["r"],
+ "f1": re_metrics["ALL"]["f1"],
+ }
+ model.train()
+ return re_metrics
+
+
+def eval(args):
+ logger = get_logger()
+ label2id_map, id2label_map = get_bio_label_maps(args.label_map_path)
+ pad_token_label_id = paddle.nn.CrossEntropyLoss().ignore_index
+
+ tokenizer = LayoutXLMTokenizer.from_pretrained(args.model_name_or_path)
+
+ model = LayoutXLMForRelationExtraction.from_pretrained(
+ args.model_name_or_path)
+
+ eval_dataset = XFUNDataset(
+ tokenizer,
+ data_dir=args.eval_data_dir,
+ label_path=args.eval_label_path,
+ label2id_map=label2id_map,
+ img_size=(224, 224),
+ max_seq_len=args.max_seq_length,
+ pad_token_label_id=pad_token_label_id,
+ contains_re=True,
+ add_special_ids=False,
+ return_attention_mask=True,
+ load_mode='all')
+
+ eval_dataloader = paddle.io.DataLoader(
+ eval_dataset,
+ batch_size=args.per_gpu_eval_batch_size,
+ num_workers=args.num_workers,
+ shuffle=False,
+ collate_fn=DataCollator())
+
+ results = evaluate(model, eval_dataloader, logger)
+ logger.info("eval results: {}".format(results))
+
+
+if __name__ == "__main__":
+ args = parse_args()
+ eval(args)
diff --git a/ppstructure/vqa/eval_ser.py b/ppstructure/vqa/eval_ser.py
new file mode 100644
index 0000000000000000000000000000000000000000..52eeb8a1da82d6e7dcdc726c1a77fd9ab18f0608
--- /dev/null
+++ b/ppstructure/vqa/eval_ser.py
@@ -0,0 +1,177 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import sys
+
+__dir__ = os.path.dirname(os.path.abspath(__file__))
+sys.path.append(__dir__)
+sys.path.append(os.path.abspath(os.path.join(__dir__, '../..')))
+
+import random
+import time
+import copy
+import logging
+
+import argparse
+import paddle
+import numpy as np
+from seqeval.metrics import classification_report, f1_score, precision_score, recall_score
+from paddlenlp.transformers import LayoutXLMModel, LayoutXLMTokenizer, LayoutXLMForTokenClassification
+from paddlenlp.transformers import LayoutLMModel, LayoutLMTokenizer, LayoutLMForTokenClassification
+
+from xfun import XFUNDataset
+from losses import SERLoss
+from utils import parse_args, get_bio_label_maps, print_arguments
+
+from ppocr.utils.logging import get_logger
+
+MODELS = {
+ 'LayoutXLM':
+ (LayoutXLMTokenizer, LayoutXLMModel, LayoutXLMForTokenClassification),
+ 'LayoutLM':
+ (LayoutLMTokenizer, LayoutLMModel, LayoutLMForTokenClassification)
+}
+
+
+def eval(args):
+ logger = get_logger()
+ print_arguments(args, logger)
+
+ label2id_map, id2label_map = get_bio_label_maps(args.label_map_path)
+ pad_token_label_id = paddle.nn.CrossEntropyLoss().ignore_index
+
+ tokenizer_class, base_model_class, model_class = MODELS[args.ser_model_type]
+ tokenizer = tokenizer_class.from_pretrained(args.model_name_or_path)
+ model = model_class.from_pretrained(args.model_name_or_path)
+
+ eval_dataset = XFUNDataset(
+ tokenizer,
+ data_dir=args.eval_data_dir,
+ label_path=args.eval_label_path,
+ label2id_map=label2id_map,
+ img_size=(224, 224),
+ pad_token_label_id=pad_token_label_id,
+ contains_re=False,
+ add_special_ids=False,
+ return_attention_mask=True,
+ load_mode='all')
+
+ eval_dataloader = paddle.io.DataLoader(
+ eval_dataset,
+ batch_size=args.per_gpu_eval_batch_size,
+ num_workers=args.num_workers,
+ use_shared_memory=True,
+ collate_fn=None, )
+
+ loss_class = SERLoss(len(label2id_map))
+
+ results, _ = evaluate(args, model, tokenizer, loss_class, eval_dataloader,
+ label2id_map, id2label_map, pad_token_label_id,
+ logger)
+
+ logger.info(results)
+
+
+def evaluate(args,
+ model,
+ tokenizer,
+ loss_class,
+ eval_dataloader,
+ label2id_map,
+ id2label_map,
+ pad_token_label_id,
+ logger,
+ prefix=""):
+
+ eval_loss = 0.0
+ nb_eval_steps = 0
+ preds = None
+ out_label_ids = None
+ model.eval()
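+    # LayoutLM takes no visual input, so the image tensor is dropped from the
+    # batch below; LayoutXLM returns a tuple whose first element contains the
+    # token logits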
+ for idx, batch in enumerate(eval_dataloader):
+ with paddle.no_grad():
+ if args.ser_model_type == 'LayoutLM':
+ if 'image' in batch:
+ batch.pop('image')
+ labels = batch.pop('labels')
+ outputs = model(**batch)
+ if args.ser_model_type == 'LayoutXLM':
+ outputs = outputs[0]
+ loss = loss_class(labels, outputs, batch['attention_mask'])
+
+ loss = loss.mean()
+
+ if paddle.distributed.get_rank() == 0:
+ logger.info("[Eval]process: {}/{}, loss: {:.5f}".format(
+ idx, len(eval_dataloader), loss.numpy()[0]))
+
+ eval_loss += loss.item()
+ nb_eval_steps += 1
+ if preds is None:
+ preds = outputs.numpy()
+ out_label_ids = labels.numpy()
+ else:
+ preds = np.append(preds, outputs.numpy(), axis=0)
+ out_label_ids = np.append(out_label_ids, labels.numpy(), axis=0)
+
+ eval_loss = eval_loss / nb_eval_steps
+ preds = np.argmax(preds, axis=2)
+
+ # label_map = {i: label.upper() for i, label in enumerate(labels)}
+
+ out_label_list = [[] for _ in range(out_label_ids.shape[0])]
+ preds_list = [[] for _ in range(out_label_ids.shape[0])]
+
+ for i in range(out_label_ids.shape[0]):
+ for j in range(out_label_ids.shape[1]):
+ if out_label_ids[i, j] != pad_token_label_id:
+ out_label_list[i].append(id2label_map[out_label_ids[i][j]])
+ preds_list[i].append(id2label_map[preds[i][j]])
+
+ results = {
+ "loss": eval_loss,
+ "precision": precision_score(out_label_list, preds_list),
+ "recall": recall_score(out_label_list, preds_list),
+ "f1": f1_score(out_label_list, preds_list),
+ }
+
+ with open(
+ os.path.join(args.output_dir, "test_gt.txt"), "w",
+ encoding='utf-8') as fout:
+ for lbl in out_label_list:
+ for l in lbl:
+ fout.write(l + "\t")
+ fout.write("\n")
+ with open(
+ os.path.join(args.output_dir, "test_pred.txt"), "w",
+ encoding='utf-8') as fout:
+ for lbl in preds_list:
+ for l in lbl:
+ fout.write(l + "\t")
+ fout.write("\n")
+
+ report = classification_report(out_label_list, preds_list)
+ logger.info("\n" + report)
+
+ logger.info("***** Eval results %s *****", prefix)
+ for key in sorted(results.keys()):
+ logger.info(" %s = %s", key, str(results[key]))
+ model.train()
+ return results, preds_list
+
+
+if __name__ == "__main__":
+ args = parse_args()
+ eval(args)
diff --git a/ppstructure/vqa/helper/eval_with_label_end2end.py b/ppstructure/vqa/helper/eval_with_label_end2end.py
index c8dd3e0ad437e51e21ebc53daeec9fdf9aa76b63..3aa439acb269d74165543fac7e0042cfc213f08d 100644
--- a/ppstructure/vqa/helper/eval_with_label_end2end.py
+++ b/ppstructure/vqa/helper/eval_with_label_end2end.py
@@ -15,13 +15,12 @@
import os
import re
import sys
-# import Polygon
import shapely
from shapely.geometry import Polygon
import numpy as np
from collections import defaultdict
import operator
-import editdistance
+import Levenshtein
import argparse
import json
import copy
@@ -38,7 +37,7 @@ def parse_ser_results_fp(fp, fp_type="gt", ignore_background=True):
assert fp_type in ["gt", "pred"]
key = "label" if fp_type == "gt" else "pred"
res_dict = dict()
- with open(fp, "r") as fin:
+ with open(fp, "r", encoding='utf-8') as fin:
lines = fin.readlines()
for _, line in enumerate(lines):
@@ -95,7 +94,7 @@ def ed(args, str1, str2):
if args.ignore_case:
str1 = str1.lower()
str2 = str2.lower()
- return editdistance.eval(str1, str2)
+ return Levenshtein.distance(str1, str2)
def convert_bbox_to_polygon(bbox):
@@ -115,8 +114,6 @@ def eval_e2e(args):
# pred
dt_results = parse_ser_results_fp(args.pred_json_path, "pred",
args.ignore_background)
- assert set(gt_results.keys()) == set(dt_results.keys())
-
iou_thresh = args.iou_thres
num_gt_chars = 0
gt_count = 0
@@ -124,7 +121,7 @@ def eval_e2e(args):
hit = 0
ed_sum = 0
- for img_name in gt_results:
+ for img_name in dt_results:
gt_info = gt_results[img_name]
gt_count += len(gt_info)
diff --git a/ppstructure/vqa/helper/trans_xfun_data.py b/ppstructure/vqa/helper/trans_xfun_data.py
index b5ebd5dfbd8addda0701a7cfd2387133f7a8776b..25b3963d8362d28ea1df4c62d1491095b8c49253 100644
--- a/ppstructure/vqa/helper/trans_xfun_data.py
+++ b/ppstructure/vqa/helper/trans_xfun_data.py
@@ -16,13 +16,13 @@ import json
def transfer_xfun_data(json_path=None, output_file=None):
- with open(json_path, "r") as fin:
+ with open(json_path, "r", encoding='utf-8') as fin:
lines = fin.readlines()
json_info = json.loads(lines[0])
documents = json_info["documents"]
label_info = {}
- with open(output_file, "w") as fout:
+ with open(output_file, "w", encoding='utf-8') as fout:
for idx, document in enumerate(documents):
img_info = document["img"]
document = document["document"]
diff --git a/ppstructure/vqa/infer.sh b/ppstructure/vqa/infer.sh
new file mode 100644
index 0000000000000000000000000000000000000000..2cd1cea4476672732b3a7f9ad97a3e42172dbb92
--- /dev/null
+++ b/ppstructure/vqa/infer.sh
@@ -0,0 +1,61 @@
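+# SER / RE inference launcher: uncomment the command variant to run; model and data paths are examples.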
+export CUDA_VISIBLE_DEVICES=6
+# python3.7 infer_ser_e2e.py \
+# --model_name_or_path "output/ser_distributed/best_model" \
+# --max_seq_length 512 \
+# --output_dir "output_res_e2e/" \
+# --infer_imgs "/ssd1/zhoujun20/VQA/data/XFUN_v1.0_data/zh.val/zh_val_0.jpg"
+
+
+# python3.7 infer_ser_re_e2e.py \
+# --model_name_or_path "output/ser_distributed/best_model" \
+# --re_model_name_or_path "output/re_test/best_model" \
+# --max_seq_length 512 \
+# --output_dir "output_ser_re_e2e_train/" \
+# --infer_imgs "images/input/zh_val_21.jpg"
+
+# python3.7 infer_ser.py \
+# --model_name_or_path "output/ser_LayoutLM/best_model" \
+# --ser_model_type "LayoutLM" \
+# --output_dir "ser_LayoutLM/" \
+# --infer_imgs "images/input/zh_val_21.jpg" \
+# --ocr_json_path "/ssd1/zhoujun20/VQA/data/XFUN_v1.0_data/xfun_normalize_val.json"
+
+python3.7 infer_ser.py \
+ --model_name_or_path "output/ser_new/best_model" \
+ --ser_model_type "LayoutXLM" \
+ --output_dir "ser_new/" \
+ --infer_imgs "images/input/zh_val_21.jpg" \
+ --ocr_json_path "/ssd1/zhoujun20/VQA/data/XFUN_v1.0_data/xfun_normalize_val.json"
+
+# python3.7 infer_ser_e2e.py \
+# --model_name_or_path "output/ser_new/best_model" \
+# --ser_model_type "LayoutXLM" \
+# --max_seq_length 512 \
+# --output_dir "output/ser_new/" \
+# --infer_imgs "images/input/zh_val_0.jpg"
+
+
+# python3.7 infer_ser_e2e.py \
+# --model_name_or_path "output/ser_LayoutLM/best_model" \
+# --ser_model_type "LayoutLM" \
+# --max_seq_length 512 \
+# --output_dir "output/ser_LayoutLM/" \
+# --infer_imgs "images/input/zh_val_0.jpg"
+
+# python3 infer_re.py \
+# --model_name_or_path "/ssd1/zhoujun20/VQA/PaddleOCR/ppstructure/vqa/output/re_test/best_model/" \
+# --max_seq_length 512 \
+# --eval_data_dir "/ssd1/zhoujun20/VQA/data/XFUN_v1.0_data/zh.val" \
+# --eval_label_path "/ssd1/zhoujun20/VQA/data/XFUN_v1.0_data/xfun_normalize_val.json" \
+# --label_map_path 'labels/labels_ser.txt' \
+# --output_dir "output_res" \
+# --per_gpu_eval_batch_size 1 \
+# --seed 2048
+
+# python3.7 infer_ser_re_e2e.py \
+# --model_name_or_path "output/ser_LayoutLM/best_model" \
+# --ser_model_type "LayoutLM" \
+# --re_model_name_or_path "output/re_new/best_model" \
+# --max_seq_length 512 \
+# --output_dir "output_ser_re_e2e/" \
+# --infer_imgs "images/input/zh_val_21.jpg"
\ No newline at end of file
diff --git a/ppstructure/vqa/infer_re.py b/ppstructure/vqa/infer_re.py
index ae2f52550294b072179c3bdba28c3572369e11a3..7937700a78c28eff19a78dc5ddc8ccef82cc0148 100644
--- a/ppstructure/vqa/infer_re.py
+++ b/ppstructure/vqa/infer_re.py
@@ -56,15 +56,19 @@ def infer(args):
ocr_info_list = load_ocr(args.eval_data_dir, args.eval_label_path)
for idx, batch in enumerate(eval_dataloader):
- logger.info("[Infer] process: {}/{}".format(idx, len(eval_dataloader)))
- with paddle.no_grad():
- outputs = model(**batch)
- pred_relations = outputs['pred_relations']
-
ocr_info = ocr_info_list[idx]
image_path = ocr_info['image_path']
ocr_info = ocr_info['ocr_info']
+ save_img_path = os.path.join(
+ args.output_dir,
+ os.path.splitext(os.path.basename(image_path))[0] + "_re.jpg")
+ logger.info("[Infer] process: {}/{}, save result to {}".format(
+ idx, len(eval_dataloader), save_img_path))
+ with paddle.no_grad():
+ outputs = model(**batch)
+ pred_relations = outputs['pred_relations']
+
        # use the entity info to decode tokens, then filter out unwanted ocr_info
ocr_info = filter_bg_by_txt(ocr_info, batch, tokenizer)
@@ -85,14 +89,13 @@ def infer(args):
img = cv2.imread(image_path)
img_show = draw_re_results(img, result)
- save_path = os.path.join(args.output_dir, os.path.basename(image_path))
- cv2.imwrite(save_path, img_show)
+ cv2.imwrite(save_img_path, img_show)
def load_ocr(img_folder, json_path):
import json
d = []
- with open(json_path, "r") as fin:
+ with open(json_path, "r", encoding='utf-8') as fin:
lines = fin.readlines()
for line in lines:
image_name, info_str = line.split("\t")
diff --git a/ppstructure/vqa/infer_ser.py b/ppstructure/vqa/infer_ser.py
index 4ad220094a26b330555fbe9122a46fb56e64fe1e..7994b5449af97988e9ad458f4047b59b81f72560 100644
--- a/ppstructure/vqa/infer_ser.py
+++ b/ppstructure/vqa/infer_ser.py
@@ -24,6 +24,14 @@ import paddle
# relative reference
from utils import parse_args, get_image_file_list, draw_ser_results, get_bio_label_maps
from paddlenlp.transformers import LayoutXLMModel, LayoutXLMTokenizer, LayoutXLMForTokenClassification
+from paddlenlp.transformers import LayoutLMModel, LayoutLMTokenizer, LayoutLMForTokenClassification
+
+MODELS = {
+ 'LayoutXLM':
+ (LayoutXLMTokenizer, LayoutXLMModel, LayoutXLMForTokenClassification),
+ 'LayoutLM':
+ (LayoutLMTokenizer, LayoutLMModel, LayoutLMForTokenClassification)
+}
def pad_sentences(tokenizer,
@@ -59,7 +67,8 @@ def pad_sentences(tokenizer,
encoded_inputs["bbox"] = encoded_inputs["bbox"] + [[0, 0, 0, 0]
] * difference
else:
- assert False, f"padding_side of tokenizer just supports [\"right\"] but got {tokenizer.padding_side}"
+ assert False, "padding_side of tokenizer just supports [\"right\"] but got {}".format(
+ tokenizer.padding_side)
else:
if return_attention_mask:
encoded_inputs["attention_mask"] = [1] * len(encoded_inputs[
@@ -216,15 +225,15 @@ def infer(args):
os.makedirs(args.output_dir, exist_ok=True)
# init token and model
- tokenizer = LayoutXLMTokenizer.from_pretrained(args.model_name_or_path)
- # model = LayoutXLMModel.from_pretrained(args.model_name_or_path)
- model = LayoutXLMForTokenClassification.from_pretrained(
- args.model_name_or_path)
+ tokenizer_class, base_model_class, model_class = MODELS[args.ser_model_type]
+ tokenizer = tokenizer_class.from_pretrained(args.model_name_or_path)
+ model = model_class.from_pretrained(args.model_name_or_path)
+
model.eval()
# load ocr results json
ocr_results = dict()
- with open(args.ocr_json_path, "r") as fin:
+ with open(args.ocr_json_path, "r", encoding='utf-8') as fin:
lines = fin.readlines()
for line in lines:
img_name, json_info = line.split("\t")
@@ -234,9 +243,15 @@ def infer(args):
infer_imgs = get_image_file_list(args.infer_imgs)
# loop for infer
- with open(os.path.join(args.output_dir, "infer_results.txt"), "w") as fout:
+ with open(
+ os.path.join(args.output_dir, "infer_results.txt"),
+ "w",
+ encoding='utf-8') as fout:
for idx, img_path in enumerate(infer_imgs):
- print("process: [{}/{}]".format(idx, len(infer_imgs), img_path))
+ save_img_path = os.path.join(args.output_dir,
+ os.path.basename(img_path))
+ print("process: [{}/{}], save result to {}".format(
+ idx, len(infer_imgs), save_img_path))
img = cv2.imread(img_path)
@@ -246,15 +261,21 @@ def infer(args):
ori_img=img,
ocr_info=ocr_info,
max_seq_len=args.max_seq_length)
+ if args.ser_model_type == 'LayoutLM':
+ preds = model(
+ input_ids=inputs["input_ids"],
+ bbox=inputs["bbox"],
+ token_type_ids=inputs["token_type_ids"],
+ attention_mask=inputs["attention_mask"])
+ elif args.ser_model_type == 'LayoutXLM':
+ preds = model(
+ input_ids=inputs["input_ids"],
+ bbox=inputs["bbox"],
+ image=inputs["image"],
+ token_type_ids=inputs["token_type_ids"],
+ attention_mask=inputs["attention_mask"])
+ preds = preds[0]
- outputs = model(
- input_ids=inputs["input_ids"],
- bbox=inputs["bbox"],
- image=inputs["image"],
- token_type_ids=inputs["token_type_ids"],
- attention_mask=inputs["attention_mask"])
-
- preds = outputs[0]
preds = postprocess(inputs["attention_mask"], preds,
args.label_map_path)
ocr_info = merge_preds_list_with_ocr_info(
@@ -267,9 +288,7 @@ def infer(args):
}, ensure_ascii=False) + "\n")
img_res = draw_ser_results(img, ocr_info)
- cv2.imwrite(
- os.path.join(args.output_dir, os.path.basename(img_path)),
- img_res)
+ cv2.imwrite(save_img_path, img_res)
return
diff --git a/ppstructure/vqa/infer_ser_e2e.py b/ppstructure/vqa/infer_ser_e2e.py
index 1638e78a11105feb1cb037a545005b2384672eb8..6bb0247501bc98ea709d3ad47c284268b3d9503b 100644
--- a/ppstructure/vqa/infer_ser_e2e.py
+++ b/ppstructure/vqa/infer_ser_e2e.py
@@ -22,14 +22,20 @@ from PIL import Image
import paddle
from paddlenlp.transformers import LayoutXLMModel, LayoutXLMTokenizer, LayoutXLMForTokenClassification
-
-from paddleocr import PaddleOCR
+from paddlenlp.transformers import LayoutLMModel, LayoutLMTokenizer, LayoutLMForTokenClassification
# relative reference
from utils import parse_args, get_image_file_list, draw_ser_results, get_bio_label_maps
from utils import pad_sentences, split_page, preprocess, postprocess, merge_preds_list_with_ocr_info
+MODELS = {
+ 'LayoutXLM':
+ (LayoutXLMTokenizer, LayoutXLMModel, LayoutXLMForTokenClassification),
+ 'LayoutLM':
+ (LayoutLMTokenizer, LayoutLMModel, LayoutLMForTokenClassification)
+}
+
def trans_poly_to_bbox(poly):
x1 = np.min([p[0] for p in poly])
@@ -52,19 +58,23 @@ def parse_ocr_info_for_ser(ocr_result):
class SerPredictor(object):
def __init__(self, args):
+ self.args = args
self.max_seq_length = args.max_seq_length
# init ser token and model
- self.tokenizer = LayoutXLMTokenizer.from_pretrained(
- args.model_name_or_path)
- self.model = LayoutXLMForTokenClassification.from_pretrained(
+ tokenizer_class, base_model_class, model_class = MODELS[
+ args.ser_model_type]
+ self.tokenizer = tokenizer_class.from_pretrained(
args.model_name_or_path)
+ self.model = model_class.from_pretrained(args.model_name_or_path)
self.model.eval()
# init ocr_engine
+ from paddleocr import PaddleOCR
+
self.ocr_engine = PaddleOCR(
- rec_model_dir=args.ocr_rec_model_dir,
- det_model_dir=args.ocr_det_model_dir,
+ rec_model_dir=args.rec_model_dir,
+ det_model_dir=args.det_model_dir,
use_angle_cls=False,
show_log=False)
# init dict
@@ -88,14 +98,21 @@ class SerPredictor(object):
ocr_info=ocr_info,
max_seq_len=self.max_seq_length)
- outputs = self.model(
- input_ids=inputs["input_ids"],
- bbox=inputs["bbox"],
- image=inputs["image"],
- token_type_ids=inputs["token_type_ids"],
- attention_mask=inputs["attention_mask"])
+ if self.args.ser_model_type == 'LayoutLM':
+ preds = self.model(
+ input_ids=inputs["input_ids"],
+ bbox=inputs["bbox"],
+ token_type_ids=inputs["token_type_ids"],
+ attention_mask=inputs["attention_mask"])
+ elif self.args.ser_model_type == 'LayoutXLM':
+ preds = self.model(
+ input_ids=inputs["input_ids"],
+ bbox=inputs["bbox"],
+ image=inputs["image"],
+ token_type_ids=inputs["token_type_ids"],
+ attention_mask=inputs["attention_mask"])
+ preds = preds[0]
- preds = outputs[0]
preds = postprocess(inputs["attention_mask"], preds, self.id2label_map)
ocr_info = merge_preds_list_with_ocr_info(
ocr_info, inputs["segment_offset_id"], preds,
@@ -112,9 +129,16 @@ if __name__ == "__main__":
# loop for infer
ser_engine = SerPredictor(args)
- with open(os.path.join(args.output_dir, "infer_results.txt"), "w") as fout:
+ with open(
+ os.path.join(args.output_dir, "infer_results.txt"),
+ "w",
+ encoding='utf-8') as fout:
for idx, img_path in enumerate(infer_imgs):
- print("process: [{}/{}], {}".format(idx, len(infer_imgs), img_path))
+ save_img_path = os.path.join(
+ args.output_dir,
+ os.path.splitext(os.path.basename(img_path))[0] + "_ser.jpg")
+ print("process: [{}/{}], save result to {}".format(
+ idx, len(infer_imgs), save_img_path))
img = cv2.imread(img_path)
@@ -125,7 +149,4 @@ if __name__ == "__main__":
}, ensure_ascii=False) + "\n")
img_res = draw_ser_results(img, result)
- cv2.imwrite(
- os.path.join(args.output_dir,
- os.path.splitext(os.path.basename(img_path))[0] +
- "_ser.jpg"), img_res)
+ cv2.imwrite(save_img_path, img_res)
diff --git a/ppstructure/vqa/infer_ser_re_e2e.py b/ppstructure/vqa/infer_ser_re_e2e.py
index a1d0f52eeecbc6c2ceba5964355008f638f371dd..32d8850a16eebee7e43fd58cbaa604fc2bc00b7c 100644
--- a/ppstructure/vqa/infer_ser_re_e2e.py
+++ b/ppstructure/vqa/infer_ser_re_e2e.py
@@ -112,9 +112,16 @@ if __name__ == "__main__":
# loop for infer
ser_re_engine = SerReSystem(args)
- with open(os.path.join(args.output_dir, "infer_results.txt"), "w") as fout:
+ with open(
+ os.path.join(args.output_dir, "infer_results.txt"),
+ "w",
+ encoding='utf-8') as fout:
for idx, img_path in enumerate(infer_imgs):
- print("process: [{}/{}], {}".format(idx, len(infer_imgs), img_path))
+ save_img_path = os.path.join(
+ args.output_dir,
+ os.path.splitext(os.path.basename(img_path))[0] + "_re.jpg")
+ print("process: [{}/{}], save result to {}".format(
+ idx, len(infer_imgs), save_img_path))
img = cv2.imread(img_path)
@@ -125,7 +132,4 @@ if __name__ == "__main__":
}, ensure_ascii=False) + "\n")
img_res = draw_re_results(img, result)
- cv2.imwrite(
- os.path.join(args.output_dir,
- os.path.splitext(os.path.basename(img_path))[0] +
- "_re.jpg"), img_res)
+ cv2.imwrite(save_img_path, img_res)
diff --git a/ppstructure/vqa/losses.py b/ppstructure/vqa/losses.py
new file mode 100644
index 0000000000000000000000000000000000000000..e8dad01c3198f200788c7898d1b77b38d917d1ca
--- /dev/null
+++ b/ppstructure/vqa/losses.py
@@ -0,0 +1,35 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from paddle import nn
+
+
+class SERLoss(nn.Layer):
+ def __init__(self, num_classes):
+ super().__init__()
+ self.loss_class = nn.CrossEntropyLoss()
+ self.num_classes = num_classes
+ self.ignore_index = self.loss_class.ignore_index
+
+ def forward(self, labels, outputs, attention_mask):
+ if attention_mask is not None:
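+            # compute the loss only over positions the attention mask marks as valid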
+ active_loss = attention_mask.reshape([-1, ]) == 1
+ active_outputs = outputs.reshape(
+ [-1, self.num_classes])[active_loss]
+ active_labels = labels.reshape([-1, ])[active_loss]
+ loss = self.loss_class(active_outputs, active_labels)
+ else:
+ loss = self.loss_class(
+ outputs.reshape([-1, self.num_classes]), labels.reshape([-1, ]))
+ return loss
diff --git a/ppstructure/vqa/requirements.txt b/ppstructure/vqa/requirements.txt
index c259fadc395335b336cb0ecdb5aa6bca48631987..9c935ae619024c9f47ced820eae35a3a1c976953 100644
--- a/ppstructure/vqa/requirements.txt
+++ b/ppstructure/vqa/requirements.txt
@@ -1,2 +1,3 @@
sentencepiece
yacs
+seqeval
\ No newline at end of file
diff --git a/ppstructure/vqa/train_re.py b/ppstructure/vqa/train_re.py
index ed19646cf57e69ac99e417ae27568655a4e00039..47d694678013295a1c664a7bdb6a7fe13a0b36a5 100644
--- a/ppstructure/vqa/train_re.py
+++ b/ppstructure/vqa/train_re.py
@@ -20,82 +20,25 @@ sys.path.append(__dir__)
sys.path.append(os.path.abspath(os.path.join(__dir__, '../..')))
import random
+import time
import numpy as np
import paddle
from paddlenlp.transformers import LayoutXLMTokenizer, LayoutXLMModel, LayoutXLMForRelationExtraction
from xfun import XFUNDataset
-from utils import parse_args, get_bio_label_maps, print_arguments
+from utils import parse_args, get_bio_label_maps, print_arguments, set_seed
from data_collator import DataCollator
-from metric import re_score
+from eval_re import evaluate
from ppocr.utils.logging import get_logger
-def set_seed(seed):
- random.seed(seed)
- np.random.seed(seed)
- paddle.seed(seed)
-
-
-def cal_metric(re_preds, re_labels, entities):
- gt_relations = []
- for b in range(len(re_labels)):
- rel_sent = []
- for head, tail in zip(re_labels[b]["head"], re_labels[b]["tail"]):
- rel = {}
- rel["head_id"] = head
- rel["head"] = (entities[b]["start"][rel["head_id"]],
- entities[b]["end"][rel["head_id"]])
- rel["head_type"] = entities[b]["label"][rel["head_id"]]
-
- rel["tail_id"] = tail
- rel["tail"] = (entities[b]["start"][rel["tail_id"]],
- entities[b]["end"][rel["tail_id"]])
- rel["tail_type"] = entities[b]["label"][rel["tail_id"]]
-
- rel["type"] = 1
- rel_sent.append(rel)
- gt_relations.append(rel_sent)
- re_metrics = re_score(re_preds, gt_relations, mode="boundaries")
- return re_metrics
-
-
-def evaluate(model, eval_dataloader, logger, prefix=""):
- # Eval!
- logger.info("***** Running evaluation {} *****".format(prefix))
- logger.info(" Num examples = {}".format(len(eval_dataloader.dataset)))
-
- re_preds = []
- re_labels = []
- entities = []
- eval_loss = 0.0
- model.eval()
- for idx, batch in enumerate(eval_dataloader):
- with paddle.no_grad():
- outputs = model(**batch)
- loss = outputs['loss'].mean().item()
- if paddle.distributed.get_rank() == 0:
- logger.info("[Eval] process: {}/{}, loss: {:.5f}".format(
- idx, len(eval_dataloader), loss))
-
- eval_loss += loss
- re_preds.extend(outputs['pred_relations'])
- re_labels.extend(batch['relations'])
- entities.extend(batch['entities'])
- re_metrics = cal_metric(re_preds, re_labels, entities)
- re_metrics = {
- "precision": re_metrics["ALL"]["p"],
- "recall": re_metrics["ALL"]["r"],
- "f1": re_metrics["ALL"]["f1"],
- }
- model.train()
- return re_metrics
-
-
def train(args):
logger = get_logger(log_file=os.path.join(args.output_dir, "train.log"))
+ rank = paddle.distributed.get_rank()
+ distributed = paddle.distributed.get_world_size() > 1
+
print_arguments(args, logger)
# Added here for reproducibility (even between python 2 and 3)
@@ -105,17 +48,22 @@ def train(args):
pad_token_label_id = paddle.nn.CrossEntropyLoss().ignore_index
# dist mode
- if paddle.distributed.get_world_size() > 1:
+ if distributed:
paddle.distributed.init_parallel_env()
tokenizer = LayoutXLMTokenizer.from_pretrained(args.model_name_or_path)
-
- model = LayoutXLMModel.from_pretrained(args.model_name_or_path)
- model = LayoutXLMForRelationExtraction(model, dropout=None)
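+    # build the RE model from the pretrained base, or resume a saved task checkpoint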
+ if not args.resume:
+ model = LayoutXLMModel.from_pretrained(args.model_name_or_path)
+ model = LayoutXLMForRelationExtraction(model, dropout=None)
+ logger.info('train from scratch')
+ else:
+ logger.info('resume from {}'.format(args.model_name_or_path))
+ model = LayoutXLMForRelationExtraction.from_pretrained(
+ args.model_name_or_path)
# dist mode
- if paddle.distributed.get_world_size() > 1:
- model = paddle.distributed.DataParallel(model)
+ if distributed:
+ model = paddle.DataParallel(model)
train_dataset = XFUNDataset(
tokenizer,
@@ -145,19 +93,18 @@ def train(args):
train_sampler = paddle.io.DistributedBatchSampler(
train_dataset, batch_size=args.per_gpu_train_batch_size, shuffle=True)
- args.train_batch_size = args.per_gpu_train_batch_size * \
- max(1, paddle.distributed.get_world_size())
+
train_dataloader = paddle.io.DataLoader(
train_dataset,
batch_sampler=train_sampler,
- num_workers=8,
+ num_workers=args.num_workers,
use_shared_memory=True,
collate_fn=DataCollator())
eval_dataloader = paddle.io.DataLoader(
eval_dataset,
batch_size=args.per_gpu_eval_batch_size,
- num_workers=8,
+ num_workers=args.num_workers,
shuffle=False,
collate_fn=DataCollator())
@@ -191,7 +138,8 @@ def train(args):
args.per_gpu_train_batch_size))
logger.info(
" Total train batch size (w. parallel, distributed & accumulation) = {}".
- format(args.train_batch_size * paddle.distributed.get_world_size()))
+ format(args.per_gpu_train_batch_size *
+ paddle.distributed.get_world_size()))
logger.info(" Total optimization steps = {}".format(t_total))
global_step = 0
@@ -200,58 +148,78 @@ def train(args):
best_metirc = {'f1': 0}
model.train()
+ train_reader_cost = 0.0
+ train_run_cost = 0.0
+ total_samples = 0
+ reader_start = time.time()
+
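+    # log loss and averaged timing statistics every print_step iterations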
+ print_step = 1
+
for epoch in range(int(args.num_train_epochs)):
for step, batch in enumerate(train_dataloader):
+ train_reader_cost += time.time() - reader_start
+ train_start = time.time()
outputs = model(**batch)
+ train_run_cost += time.time() - train_start
            # model outputs are always a tuple in ppnlp (see doc)
loss = outputs['loss']
loss = loss.mean()
- logger.info(
- "epoch: [{}/{}], iter: [{}/{}], global_step:{}, train loss: {}, lr: {}".
- format(epoch, args.num_train_epochs, step, train_dataloader_len,
- global_step, np.mean(loss.numpy()), optimizer.get_lr()))
-
loss.backward()
optimizer.step()
optimizer.clear_grad()
# lr_scheduler.step() # Update learning rate schedule
global_step += 1
-
- if (paddle.distributed.get_rank() == 0 and args.eval_steps > 0 and
- global_step % args.eval_steps == 0):
+ total_samples += batch['image'].shape[0]
+
+ if rank == 0 and step % print_step == 0:
+ logger.info(
+ "epoch: [{}/{}], iter: [{}/{}], global_step:{}, train loss: {:.6f}, lr: {:.6f}, avg_reader_cost: {:.5f} sec, avg_batch_cost: {:.5f} sec, avg_samples: {:.5f}, ips: {:.5f} images/sec".
+ format(epoch, args.num_train_epochs, step,
+ train_dataloader_len, global_step,
+ np.mean(loss.numpy()),
+ optimizer.get_lr(), train_reader_cost / print_step, (
+ train_reader_cost + train_run_cost) / print_step,
+ total_samples / print_step, total_samples / (
+ train_reader_cost + train_run_cost)))
+
+ train_reader_cost = 0.0
+ train_run_cost = 0.0
+ total_samples = 0
+
+ if rank == 0 and args.eval_steps > 0 and global_step % args.eval_steps == 0 and args.evaluate_during_training:
# Log metrics
- if (paddle.distributed.get_rank() == 0 and args.
- evaluate_during_training): # Only evaluate when single GPU otherwise metrics may not average well
- results = evaluate(model, eval_dataloader, logger)
- if results['f1'] > best_metirc['f1']:
- best_metirc = results
- output_dir = os.path.join(args.output_dir,
- "checkpoint-best")
- os.makedirs(output_dir, exist_ok=True)
+                # Only evaluate on a single GPU, otherwise metrics may not average well
+ results = evaluate(model, eval_dataloader, logger)
+ if results['f1'] >= best_metirc['f1']:
+ best_metirc = results
+ output_dir = os.path.join(args.output_dir, "best_model")
+ os.makedirs(output_dir, exist_ok=True)
+ if distributed:
+ model._layers.save_pretrained(output_dir)
+ else:
model.save_pretrained(output_dir)
- tokenizer.save_pretrained(output_dir)
- paddle.save(args,
- os.path.join(output_dir,
- "training_args.bin"))
- logger.info("Saving model checkpoint to {}".format(
- output_dir))
- logger.info("eval results: {}".format(results))
- logger.info("best_metirc: {}".format(best_metirc))
-
- if (paddle.distributed.get_rank() == 0 and args.save_steps > 0 and
- global_step % args.save_steps == 0):
- # Save model checkpoint
- output_dir = os.path.join(args.output_dir, "checkpoint-latest")
- os.makedirs(output_dir, exist_ok=True)
- if paddle.distributed.get_rank() == 0:
- model.save_pretrained(output_dir)
tokenizer.save_pretrained(output_dir)
paddle.save(args,
os.path.join(output_dir, "training_args.bin"))
logger.info("Saving model checkpoint to {}".format(
output_dir))
+ logger.info("eval results: {}".format(results))
+ logger.info("best_metirc: {}".format(best_metirc))
+ reader_start = time.time()
+
+ if rank == 0:
+ # Save model checkpoint
+ output_dir = os.path.join(args.output_dir, "latest_model")
+ os.makedirs(output_dir, exist_ok=True)
+ if distributed:
+ model._layers.save_pretrained(output_dir)
+ else:
+ model.save_pretrained(output_dir)
+ tokenizer.save_pretrained(output_dir)
+ paddle.save(args, os.path.join(output_dir, "training_args.bin"))
+ logger.info("Saving model checkpoint to {}".format(output_dir))
logger.info("best_metirc: {}".format(best_metirc))
diff --git a/ppstructure/vqa/train_ser.py b/ppstructure/vqa/train_ser.py
index d3144e7167c59b5883047a948abaedfd21ba9b1c..2670ef9eeaf75cc1c9bb7d8f41d6f76d4300e597 100644
--- a/ppstructure/vqa/train_ser.py
+++ b/ppstructure/vqa/train_ser.py
@@ -20,6 +20,7 @@ sys.path.append(__dir__)
sys.path.append(os.path.abspath(os.path.join(__dir__, '../..')))
import random
+import time
import copy
import logging
@@ -28,39 +29,52 @@ import paddle
import numpy as np
from seqeval.metrics import classification_report, f1_score, precision_score, recall_score
from paddlenlp.transformers import LayoutXLMModel, LayoutXLMTokenizer, LayoutXLMForTokenClassification
-from xfun import XFUNDataset
-from utils import parse_args
-from utils import get_bio_label_maps
-from utils import print_arguments
+from paddlenlp.transformers import LayoutLMModel, LayoutLMTokenizer, LayoutLMForTokenClassification
+from xfun import XFUNDataset
+from utils import parse_args, get_bio_label_maps, print_arguments, set_seed
+from eval_ser import evaluate
+from losses import SERLoss
from ppocr.utils.logging import get_logger
-
-def set_seed(args):
- random.seed(args.seed)
- np.random.seed(args.seed)
- paddle.seed(args.seed)
+MODELS = {
+ 'LayoutXLM':
+ (LayoutXLMTokenizer, LayoutXLMModel, LayoutXLMForTokenClassification),
+ 'LayoutLM':
+ (LayoutLMTokenizer, LayoutLMModel, LayoutLMForTokenClassification)
+}
def train(args):
os.makedirs(args.output_dir, exist_ok=True)
+ rank = paddle.distributed.get_rank()
+ distributed = paddle.distributed.get_world_size() > 1
+
logger = get_logger(log_file=os.path.join(args.output_dir, "train.log"))
print_arguments(args, logger)
label2id_map, id2label_map = get_bio_label_maps(args.label_map_path)
- pad_token_label_id = paddle.nn.CrossEntropyLoss().ignore_index
+ loss_class = SERLoss(len(label2id_map))
+
+ pad_token_label_id = loss_class.ignore_index
# dist mode
- if paddle.distributed.get_world_size() > 1:
+ if distributed:
paddle.distributed.init_parallel_env()
- tokenizer = LayoutXLMTokenizer.from_pretrained(args.model_name_or_path)
- base_model = LayoutXLMModel.from_pretrained(args.model_name_or_path)
- model = LayoutXLMForTokenClassification(
- base_model, num_classes=len(label2id_map), dropout=None)
+ tokenizer_class, base_model_class, model_class = MODELS[args.ser_model_type]
+ tokenizer = tokenizer_class.from_pretrained(args.model_name_or_path)
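+    # with --resume, load a complete task checkpoint instead of initializing from the base model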
+ if not args.resume:
+ base_model = base_model_class.from_pretrained(args.model_name_or_path)
+ model = model_class(
+ base_model, num_classes=len(label2id_map), dropout=None)
+ logger.info('train from scratch')
+ else:
+ logger.info('resume from {}'.format(args.model_name_or_path))
+ model = model_class.from_pretrained(args.model_name_or_path)
# dist mode
- if paddle.distributed.get_world_size() > 1:
+ if distributed:
model = paddle.DataParallel(model)
train_dataset = XFUNDataset(
@@ -74,17 +88,32 @@ def train(args):
add_special_ids=False,
return_attention_mask=True,
load_mode='all')
+ eval_dataset = XFUNDataset(
+ tokenizer,
+ data_dir=args.eval_data_dir,
+ label_path=args.eval_label_path,
+ label2id_map=label2id_map,
+ img_size=(224, 224),
+ pad_token_label_id=pad_token_label_id,
+ contains_re=False,
+ add_special_ids=False,
+ return_attention_mask=True,
+ load_mode='all')
train_sampler = paddle.io.DistributedBatchSampler(
train_dataset, batch_size=args.per_gpu_train_batch_size, shuffle=True)
- args.train_batch_size = args.per_gpu_train_batch_size * max(
- 1, paddle.distributed.get_world_size())
-
train_dataloader = paddle.io.DataLoader(
train_dataset,
batch_sampler=train_sampler,
- num_workers=0,
+ num_workers=args.num_workers,
+ use_shared_memory=True,
+ collate_fn=None, )
+
+ eval_dataloader = paddle.io.DataLoader(
+ eval_dataset,
+ batch_size=args.per_gpu_eval_batch_size,
+ num_workers=args.num_workers,
use_shared_memory=True,
collate_fn=None, )
@@ -117,182 +146,103 @@ def train(args):
args.per_gpu_train_batch_size)
logger.info(
" Total train batch size (w. parallel, distributed) = %d",
- args.train_batch_size * paddle.distributed.get_world_size(), )
+ args.per_gpu_train_batch_size * paddle.distributed.get_world_size(), )
logger.info(" Total optimization steps = %d", t_total)
global_step = 0
tr_loss = 0.0
- set_seed(args)
+ set_seed(args.seed)
best_metrics = None
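+    # accumulators for data-loading and model execution time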
+ train_reader_cost = 0.0
+ train_run_cost = 0.0
+ total_samples = 0
+ reader_start = time.time()
+
+ print_step = 1
+ model.train()
for epoch_id in range(args.num_train_epochs):
for step, batch in enumerate(train_dataloader):
- model.train()
+ train_reader_cost += time.time() - reader_start
+
+ if args.ser_model_type == 'LayoutLM':
+ if 'image' in batch:
+ batch.pop('image')
+ labels = batch.pop('labels')
+
+ train_start = time.time()
outputs = model(**batch)
+ train_run_cost += time.time() - train_start
+ if args.ser_model_type == 'LayoutXLM':
+ outputs = outputs[0]
+ loss = loss_class(labels, outputs, batch['attention_mask'])
+
            # model outputs are always a tuple in ppnlp (see doc)
- loss = outputs[0]
loss = loss.mean()
- logger.info(
- "epoch: [{}/{}], iter: [{}/{}], global_step:{}, train loss: {}, lr: {}".
- format(epoch_id, args.num_train_epochs, step,
- len(train_dataloader), global_step,
- loss.numpy()[0], lr_scheduler.get_lr()))
-
loss.backward()
tr_loss += loss.item()
optimizer.step()
lr_scheduler.step() # Update learning rate schedule
optimizer.clear_grad()
global_step += 1
-
- if (paddle.distributed.get_rank() == 0 and args.eval_steps > 0 and
- global_step % args.eval_steps == 0):
+ total_samples += batch['input_ids'].shape[0]
+
+ if rank == 0 and step % print_step == 0:
+ logger.info(
+ "epoch: [{}/{}], iter: [{}/{}], global_step:{}, train loss: {:.6f}, lr: {:.6f}, avg_reader_cost: {:.5f} sec, avg_batch_cost: {:.5f} sec, avg_samples: {:.5f}, ips: {:.5f} images/sec".
+ format(epoch_id, args.num_train_epochs, step,
+ len(train_dataloader), global_step,
+ loss.numpy()[0],
+ lr_scheduler.get_lr(), train_reader_cost /
+ print_step, (train_reader_cost + train_run_cost) /
+ print_step, total_samples / print_step, total_samples
+ / (train_reader_cost + train_run_cost)))
+
+ train_reader_cost = 0.0
+ train_run_cost = 0.0
+ total_samples = 0
+
+ if rank == 0 and args.eval_steps > 0 and global_step % args.eval_steps == 0 and args.evaluate_during_training:
# Log metrics
                # Only evaluate on a single GPU, otherwise metrics may not average well
- if paddle.distributed.get_rank(
- ) == 0 and args.evaluate_during_training:
- results, _ = evaluate(args, model, tokenizer, label2id_map,
- id2label_map, pad_token_label_id,
- logger)
-
- if best_metrics is None or results["f1"] >= best_metrics[
- "f1"]:
- best_metrics = copy.deepcopy(results)
- output_dir = os.path.join(args.output_dir, "best_model")
- os.makedirs(output_dir, exist_ok=True)
- if paddle.distributed.get_rank() == 0:
- model.save_pretrained(output_dir)
- tokenizer.save_pretrained(output_dir)
- paddle.save(
- args,
- os.path.join(output_dir, "training_args.bin"))
- logger.info("Saving model checkpoint to %s",
- output_dir)
-
- logger.info("[epoch {}/{}][iter: {}/{}] results: {}".format(
- epoch_id, args.num_train_epochs, step,
- len(train_dataloader), results))
- if best_metrics is not None:
- logger.info("best metrics: {}".format(best_metrics))
-
- if paddle.distributed.get_rank(
- ) == 0 and args.save_steps > 0 and global_step % args.save_steps == 0:
- # Save model checkpoint
- output_dir = os.path.join(args.output_dir,
- "checkpoint-{}".format(global_step))
- os.makedirs(output_dir, exist_ok=True)
- if paddle.distributed.get_rank() == 0:
- model.save_pretrained(output_dir)
+ results, _ = evaluate(args, model, tokenizer, loss_class,
+ eval_dataloader, label2id_map,
+ id2label_map, pad_token_label_id, logger)
+
+ if best_metrics is None or results["f1"] >= best_metrics["f1"]:
+ best_metrics = copy.deepcopy(results)
+ output_dir = os.path.join(args.output_dir, "best_model")
+ os.makedirs(output_dir, exist_ok=True)
+ if distributed:
+ model._layers.save_pretrained(output_dir)
+ else:
+ model.save_pretrained(output_dir)
tokenizer.save_pretrained(output_dir)
paddle.save(args,
os.path.join(output_dir, "training_args.bin"))
- logger.info("Saving model checkpoint to %s", output_dir)
-
+ logger.info("Saving model checkpoint to {}".format(
+ output_dir))
+
+ logger.info("[epoch {}/{}][iter: {}/{}] results: {}".format(
+ epoch_id, args.num_train_epochs, step,
+ len(train_dataloader), results))
+ if best_metrics is not None:
+ logger.info("best metrics: {}".format(best_metrics))
+ reader_start = time.time()
+ if rank == 0:
+ # Save model checkpoint
+ output_dir = os.path.join(args.output_dir, "latest_model")
+ os.makedirs(output_dir, exist_ok=True)
+ if distributed:
+ model._layers.save_pretrained(output_dir)
+ else:
+ model.save_pretrained(output_dir)
+ tokenizer.save_pretrained(output_dir)
+ paddle.save(args, os.path.join(output_dir, "training_args.bin"))
+ logger.info("Saving model checkpoint to {}".format(output_dir))
return global_step, tr_loss / global_step
-def evaluate(args,
- model,
- tokenizer,
- label2id_map,
- id2label_map,
- pad_token_label_id,
- logger,
- prefix=""):
- eval_dataset = XFUNDataset(
- tokenizer,
- data_dir=args.eval_data_dir,
- label_path=args.eval_label_path,
- label2id_map=label2id_map,
- img_size=(224, 224),
- pad_token_label_id=pad_token_label_id,
- contains_re=False,
- add_special_ids=False,
- return_attention_mask=True,
- load_mode='all')
-
- args.eval_batch_size = args.per_gpu_eval_batch_size * max(
- 1, paddle.distributed.get_world_size())
-
- eval_dataloader = paddle.io.DataLoader(
- eval_dataset,
- batch_size=args.eval_batch_size,
- num_workers=0,
- use_shared_memory=True,
- collate_fn=None, )
-
- # Eval!
- logger.info("***** Running evaluation %s *****", prefix)
- logger.info(" Num examples = %d", len(eval_dataset))
- logger.info(" Batch size = %d", args.eval_batch_size)
- eval_loss = 0.0
- nb_eval_steps = 0
- preds = None
- out_label_ids = None
- model.eval()
- for idx, batch in enumerate(eval_dataloader):
- with paddle.no_grad():
- outputs = model(**batch)
- tmp_eval_loss, logits = outputs[:2]
-
- tmp_eval_loss = tmp_eval_loss.mean()
-
- if paddle.distributed.get_rank() == 0:
- logger.info("[Eval]process: {}/{}, loss: {:.5f}".format(
- idx, len(eval_dataloader), tmp_eval_loss.numpy()[0]))
-
- eval_loss += tmp_eval_loss.item()
- nb_eval_steps += 1
- if preds is None:
- preds = logits.numpy()
- out_label_ids = batch["labels"].numpy()
- else:
- preds = np.append(preds, logits.numpy(), axis=0)
- out_label_ids = np.append(
- out_label_ids, batch["labels"].numpy(), axis=0)
-
- eval_loss = eval_loss / nb_eval_steps
- preds = np.argmax(preds, axis=2)
-
- # label_map = {i: label.upper() for i, label in enumerate(labels)}
-
- out_label_list = [[] for _ in range(out_label_ids.shape[0])]
- preds_list = [[] for _ in range(out_label_ids.shape[0])]
-
- for i in range(out_label_ids.shape[0]):
- for j in range(out_label_ids.shape[1]):
- if out_label_ids[i, j] != pad_token_label_id:
- out_label_list[i].append(id2label_map[out_label_ids[i][j]])
- preds_list[i].append(id2label_map[preds[i][j]])
-
- results = {
- "loss": eval_loss,
- "precision": precision_score(out_label_list, preds_list),
- "recall": recall_score(out_label_list, preds_list),
- "f1": f1_score(out_label_list, preds_list),
- }
-
- with open(os.path.join(args.output_dir, "test_gt.txt"), "w") as fout:
- for lbl in out_label_list:
- for l in lbl:
- fout.write(l + "\t")
- fout.write("\n")
- with open(os.path.join(args.output_dir, "test_pred.txt"), "w") as fout:
- for lbl in preds_list:
- for l in lbl:
- fout.write(l + "\t")
- fout.write("\n")
-
- report = classification_report(out_label_list, preds_list)
- logger.info("\n" + report)
-
- logger.info("***** Eval results %s *****", prefix)
- for key in sorted(results.keys()):
- logger.info(" %s = %s", key, str(results[key]))
-
- return results, preds_list
-
-
if __name__ == "__main__":
args = parse_args()
train(args)
diff --git a/ppstructure/vqa/utils.py b/ppstructure/vqa/utils.py
index 0af180ada2eae740c042378c73b884239ddbf7b9..b9f2edc860b1ce48c22bf602cef48466c357834f 100644
--- a/ppstructure/vqa/utils.py
+++ b/ppstructure/vqa/utils.py
@@ -25,8 +25,14 @@ import paddle
from PIL import Image, ImageDraw, ImageFont
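+# seed Python, NumPy, and Paddle RNGs for reproducible runs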
+def set_seed(seed):
+ random.seed(seed)
+ np.random.seed(seed)
+ paddle.seed(seed)
+
+
def get_bio_label_maps(label_map_path):
- with open(label_map_path, "r") as fin:
+ with open(label_map_path, "r", encoding='utf-8') as fin:
lines = fin.readlines()
lines = [line.strip() for line in lines]
if "O" not in lines:
@@ -344,6 +350,8 @@ def parse_args():
# yapf: disable
parser.add_argument("--model_name_or_path",
default=None, type=str, required=True,)
+ parser.add_argument("--ser_model_type",
+ default='LayoutXLM', type=str)
parser.add_argument("--re_model_name_or_path",
default=None, type=str, required=False,)
parser.add_argument("--train_data_dir", default=None,
@@ -357,6 +365,7 @@ def parse_args():
parser.add_argument("--output_dir", default=None, type=str, required=True,)
parser.add_argument("--max_seq_length", default=512, type=int,)
parser.add_argument("--evaluate_during_training", action="store_true",)
+ parser.add_argument("--num_workers", default=8, type=int,)
parser.add_argument("--per_gpu_train_batch_size", default=8,
type=int, help="Batch size per GPU/CPU for training.",)
parser.add_argument("--per_gpu_eval_batch_size", default=8,
@@ -375,16 +384,15 @@ def parse_args():
help="Linear warmup over warmup_steps.",)
parser.add_argument("--eval_steps", type=int, default=10,
help="eval every X updates steps.",)
- parser.add_argument("--save_steps", type=int, default=50,
- help="Save checkpoint every X updates steps.",)
parser.add_argument("--seed", type=int, default=2048,
help="random seed for initialization",)
- parser.add_argument("--ocr_rec_model_dir", default=None, type=str, )
- parser.add_argument("--ocr_det_model_dir", default=None, type=str, )
+ parser.add_argument("--rec_model_dir", default=None, type=str, )
+ parser.add_argument("--det_model_dir", default=None, type=str, )
parser.add_argument(
"--label_map_path", default="./labels/labels_ser.txt", type=str, required=False, )
parser.add_argument("--infer_imgs", default=None, type=str, required=False)
+ parser.add_argument("--resume", action='store_true')
parser.add_argument("--ocr_json_path", default=None,
type=str, required=False, help="ocr prediction results")
# yapf: enable
diff --git a/ppstructure/vqa/xfun.py b/ppstructure/vqa/xfun.py
index d62cdb5da5514280b62687d80d345ede9484ee90..f5dbe507e8f6d22087d7913241f7365cbede9bdf 100644
--- a/ppstructure/vqa/xfun.py
+++ b/ppstructure/vqa/xfun.py
@@ -79,14 +79,36 @@ class XFUNDataset(Dataset):
self.entities_labels = {'HEADER': 0, 'QUESTION': 1, 'ANSWER': 2}
self.return_keys = {
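+            # for each returned key: its conversion type and, for numpy arrays, the target dtype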
- 'bbox': 'np',
- 'input_ids': 'np',
- 'labels': 'np',
- 'attention_mask': 'np',
- 'image': 'np',
- 'token_type_ids': 'np',
- 'entities': 'dict',
- 'relations': 'dict',
+ 'bbox': {
+ 'type': 'np',
+ 'dtype': 'int64'
+ },
+ 'input_ids': {
+ 'type': 'np',
+ 'dtype': 'int64'
+ },
+ 'labels': {
+ 'type': 'np',
+ 'dtype': 'int64'
+ },
+ 'attention_mask': {
+ 'type': 'np',
+ 'dtype': 'int64'
+ },
+ 'image': {
+ 'type': 'np',
+ 'dtype': 'float32'
+ },
+ 'token_type_ids': {
+ 'type': 'np',
+ 'dtype': 'int64'
+ },
+ 'entities': {
+ 'type': 'dict'
+ },
+ 'relations': {
+ 'type': 'dict'
+ }
}
if load_mode == "all":
@@ -103,7 +125,7 @@ class XFUNDataset(Dataset):
return_special_tokens_mask=False):
# Padding
needs_to_be_padded = pad_to_max_seq_len and \
- max_seq_len and len(encoded_inputs["input_ids"]) < max_seq_len
+ max_seq_len and len(encoded_inputs["input_ids"]) < max_seq_len
if needs_to_be_padded:
difference = max_seq_len - len(encoded_inputs["input_ids"])
@@ -162,7 +184,7 @@ class XFUNDataset(Dataset):
return encoded_inputs
def read_all_lines(self, ):
- with open(self.label_path, "r") as fin:
+ with open(self.label_path, "r", encoding='utf-8') as fin:
lines = fin.readlines()
return lines
@@ -412,8 +434,8 @@ class XFUNDataset(Dataset):
return_data = {}
for k, v in data.items():
if k in self.return_keys:
- if self.return_keys[k] == 'np':
- v = np.array(v)
+ if self.return_keys[k]['type'] == 'np':
+ v = np.array(v, dtype=self.return_keys[k]['dtype'])
return_data[k] = v
return return_data
diff --git a/requirements.txt b/requirements.txt
index 0c87c5c95069a2699f5a3a50320c883c6118ffe7..9900588b25df99e0853ec4521f0632578c55f530 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -12,4 +12,5 @@ cython
lxml
premailer
openpyxl
-fasttext==0.9.1
\ No newline at end of file
+fasttext==0.9.1
+
diff --git a/test_tipc/configs/ch_PP-OCRv2_det/train_infer_python.txt b/test_tipc/configs/ch_PP-OCRv2_det/train_infer_python.txt
index 9520ede3acd33b0e12300ee2de1b715605c9a0eb..b8db0ff19287c6db3d48758b22602252b5b2c6cc 100644
--- a/test_tipc/configs/ch_PP-OCRv2_det/train_infer_python.txt
+++ b/test_tipc/configs/ch_PP-OCRv2_det/train_infer_python.txt
@@ -12,9 +12,9 @@ train_model_name:latest
train_infer_img_dir:./train_data/icdar2015/text_localization/ch4_test_images/
null:null
##
-trainer:norm_train|pact_train
+trainer:norm_train
norm_train:tools/train.py -c configs/det/ch_PP-OCRv2/ch_PP-OCRv2_det_cml.yml -o
-pact_train:deploy/slim/quantization/quant.py -c configs/det/ch_PP-OCRv2/ch_PP-OCRv2_det_cml.yml -o
+pact_train:null
fpgm_train:null
distill_train:null
null:null
@@ -26,9 +26,9 @@ null:null
##
===========================infer_params===========================
Global.save_inference_dir:./output/
-Global.pretrained_model:
+Global.checkpoints:
norm_export:tools/export_model.py -c configs/det/ch_PP-OCRv2/ch_PP-OCRv2_det_cml.yml -o
-quant_export:deploy/slim/quantization/export_model.py -c configs/det/ch_PP-OCRv2/ch_PP-OCRv2_det_cml.yml -o
+quant_export:null
fpgm_export:
distill_export:null
export1:null
diff --git a/test_tipc/configs/ch_PP-OCRv2_det_PACT/train_infer_python.txt b/test_tipc/configs/ch_PP-OCRv2_det_PACT/train_infer_python.txt
index 1246e380c1c113e3c96e2b2962f28fd865a8717d..70292f49c960c14cf390d0168a510f3f20a5631f 100644
--- a/test_tipc/configs/ch_PP-OCRv2_det_PACT/train_infer_python.txt
+++ b/test_tipc/configs/ch_PP-OCRv2_det_PACT/train_infer_python.txt
@@ -1,5 +1,5 @@
===========================train_params===========================
-model_name:PPOCRv2_det
+model_name:ch_PPOCRv2_det_PACT
python:python3.7
gpu_list:0|0,1
Global.use_gpu:True|True
diff --git a/test_tipc/configs/ch_PP-OCRv2_rec/train_infer_python.txt b/test_tipc/configs/ch_PP-OCRv2_rec/train_infer_python.txt
index 4607b0a7f5d2ffb082ecb84d80b3534d75e14f5f..b2de2a5e52f75071dc0d3b8e8f26d8b87cfecfd7 100644
--- a/test_tipc/configs/ch_PP-OCRv2_rec/train_infer_python.txt
+++ b/test_tipc/configs/ch_PP-OCRv2_rec/train_infer_python.txt
@@ -26,7 +26,7 @@ null:null
##
===========================infer_params===========================
Global.save_inference_dir:./output/
-Global.pretrained_model:
+Global.checkpoints:
norm_export:tools/export_model.py -c test_tipc/configs/ch_PP-OCRv2_rec/ch_PP-OCRv2_rec_distillation.yml -o
quant_export:
fpgm_export:
@@ -43,7 +43,7 @@ inference:tools/infer/predict_rec.py
--cpu_threads:1|6
--rec_batch_num:1|6
--use_tensorrt:False|True
---precision:fp32|fp16|int8
+--precision:fp32|int8
--rec_model_dir:
--image_dir:./inference/rec_inference
null:null
diff --git a/test_tipc/configs/ch_PP-OCRv2_rec_PACT/train_infer_python.txt b/test_tipc/configs/ch_PP-OCRv2_rec_PACT/train_infer_python.txt
index 6127896ae29dc5f4d2813e84824cda5fa0bac7ca..9102382fd314101753fdc895d3219329b42263f9 100644
--- a/test_tipc/configs/ch_PP-OCRv2_rec_PACT/train_infer_python.txt
+++ b/test_tipc/configs/ch_PP-OCRv2_rec_PACT/train_infer_python.txt
@@ -26,7 +26,7 @@ null:null
##
===========================infer_params===========================
Global.save_inference_dir:./output/
-Global.pretrained_model:
+Global.checkpoints:
norm_export:null
quant_export:deploy/slim/quantization/export_model.py -c test_tipc/configs/ch_PP-OCRv2_rec/ch_PP-OCRv2_rec_distillation.yml -o
fpgm_export: null
@@ -34,7 +34,7 @@ distill_export:null
export1:null
export2:null
inference_dir:Student
-infer_model:./inference/ch_PP-OCRv2_rec_infer
+infer_model:./inference/ch_PP-OCRv2_rec_slim_quant_infer
infer_export:null
infer_quant:True
inference:tools/infer/predict_rec.py
@@ -43,7 +43,7 @@ inference:tools/infer/predict_rec.py
--cpu_threads:1|6
--rec_batch_num:1|6
--use_tensorrt:False|True
---precision:fp32|fp16|int8
+--precision:fp32|int8
--rec_model_dir:
--image_dir:./inference/rec_inference
null:null
diff --git a/test_tipc/configs/ch_ppocr_mobile_v2.0_det_PACT/train_infer_python.txt b/test_tipc/configs/ch_ppocr_mobile_v2.0_det_PACT/train_infer_python.txt
index 05cde05467d75769965ee23bce2cebfc20408251..372b8ad4137cc19a8c1dfc59b99a00d525ae466f 100644
--- a/test_tipc/configs/ch_ppocr_mobile_v2.0_det_PACT/train_infer_python.txt
+++ b/test_tipc/configs/ch_ppocr_mobile_v2.0_det_PACT/train_infer_python.txt
@@ -1,5 +1,5 @@
===========================train_params===========================
-model_name:ocr_det
+model_name:ch_ppocr_mobile_v2.0_det_PACT
python:python3.7
gpu_list:0|0,1
Global.use_gpu:True|True
@@ -34,7 +34,7 @@ distill_export:null
export1:null
export2:null
inference_dir:null
-train_model:null
+train_model:./inference/ch_ppocr_mobile_v2.0_det_prune_infer/
infer_export:null
infer_quant:False
inference:tools/infer/predict_det.py
diff --git a/test_tipc/configs/ch_ppocr_mobile_v2.0_rec/train_infer_python.txt b/test_tipc/configs/ch_ppocr_mobile_v2.0_rec/train_infer_python.txt
index c93b83e5dcab1aab56ea5fa1a178e3dc7ec3c2e4..63d1f0583ea114ed89b7f2cdc6e2299e6bc8f2a4 100644
--- a/test_tipc/configs/ch_ppocr_mobile_v2.0_rec/train_infer_python.txt
+++ b/test_tipc/configs/ch_ppocr_mobile_v2.0_rec/train_infer_python.txt
@@ -4,9 +4,9 @@ python:python3.7
gpu_list:0|0,1
Global.use_gpu:True|True
Global.auto_cast:null
-Global.epoch_num:lite_train_infer=2|whole_train_infer=300
+Global.epoch_num:lite_train_lite_infer=2|whole_train_whole_infer=300
Global.save_model_dir:./output/
-Train.loader.batch_size_per_card:lite_train_infer=128|whole_train_infer=128
+Train.loader.batch_size_per_card:lite_train_lite_infer=128|whole_train_whole_infer=128
Global.pretrained_model:null
train_model_name:latest
train_infer_img_dir:./inference/rec_inference
@@ -26,7 +26,7 @@ null:null
##
===========================infer_params===========================
Global.save_inference_dir:./output/
-Global.pretrained_model:
+Global.checkpoints:
norm_export:tools/export_model.py -c configs/rec/rec_icdar15_train.yml -o
quant_export:null
fpgm_export:null
@@ -34,16 +34,16 @@ distill_export:null
export1:null
export2:null
##
-infer_model:null
+train_model:./inference/ch_ppocr_mobile_v2.0_rec_train/best_accuracy
infer_export:tools/export_model.py -c configs/rec/rec_icdar15_train.yml -o
infer_quant:False
-inference:tools/infer/predict_rec.py --rec_char_dict_path=./ppocr/utils/ic15_dict.txt --rec_image_shape="3,32,100" --rec_algorithm="RARE"
+inference:tools/infer/predict_rec.py
--use_gpu:True|False
--enable_mkldnn:True|False
--cpu_threads:1|6
--rec_batch_num:1|6
--use_tensorrt:True|False
---precision:fp32|fp16|int8
+--precision:fp32|int8
--rec_model_dir:
--image_dir:./inference/rec_inference
--save_log_path:./test/output/
diff --git a/test_tipc/configs/ch_ppocr_mobile_v2.0_rec_PACT/train_infer_python.txt b/test_tipc/configs/ch_ppocr_mobile_v2.0_rec_PACT/train_infer_python.txt
index 56b9e1896c2a1e9a7ab002884cfbc5de86997535..afbf2ef5e19344e8144e1cea81e3671fdd44559d 100644
--- a/test_tipc/configs/ch_ppocr_mobile_v2.0_rec_PACT/train_infer_python.txt
+++ b/test_tipc/configs/ch_ppocr_mobile_v2.0_rec_PACT/train_infer_python.txt
@@ -34,7 +34,7 @@ distill_export:null
export1:null
export2:null
inference_dir:null
-train_model:null
+infer_model:./inference/ch_ppocr_mobile_v2.0_rec_slim_infer/
infer_export:null
infer_quant:False
inference:tools/infer/predict_rec.py --rec_char_dict_path=./ppocr/utils/ppocr_keys_v1.txt --rec_image_shape="3,32,100"
@@ -43,7 +43,7 @@ inference:tools/infer/predict_rec.py --rec_char_dict_path=./ppocr/utils/ppocr_ke
--cpu_threads:1|6
--rec_batch_num:1|6
--use_tensorrt:False|True
---precision:fp32|fp16|int8
+--precision:fp32|int8
--rec_model_dir:
--image_dir:./inference/rec_inference
--save_log_path:./test/output/
diff --git a/test_tipc/configs/ch_ppocr_server_v2.0_rec/train_infer_python.txt b/test_tipc/configs/ch_ppocr_server_v2.0_rec/train_infer_python.txt
index f10985a91902716968660af5188473e4f1a7ae3d..c42edbee4dd2a26afff94f6028ca7a8f4170648e 100644
--- a/test_tipc/configs/ch_ppocr_server_v2.0_rec/train_infer_python.txt
+++ b/test_tipc/configs/ch_ppocr_server_v2.0_rec/train_infer_python.txt
@@ -26,7 +26,7 @@ null:null
##
===========================infer_params===========================
Global.save_inference_dir:./output/
-Global.pretrained_model:
+Global.checkpoints:
norm_export:tools/export_model.py -c test_tipc/configs/ch_ppocr_server_v2.0_rec/rec_icdar15_train.yml -o
quant_export:null
fpgm_export:null
@@ -34,7 +34,7 @@ distill_export:null
export1:null
export2:null
##
-infer_model:null
+train_model:./inference/ch_ppocr_server_v2.0_rec_train/best_accuracy
infer_export:tools/export_model.py -c test_tipc/configs/ch_ppocr_server_v2.0_rec/rec_icdar15_train.yml -o
infer_quant:False
inference:tools/infer/predict_rec.py
diff --git a/test_tipc/configs/det_mv3_east_v2.0/train_infer_python.txt b/test_tipc/configs/det_mv3_east_v2.0/train_infer_python.txt
index 7a3aced57aaf31bb54075d8ba3119d1626a2c58a..230a799f2e6d49b6bc5816fd53724259e1b881c3 100644
--- a/test_tipc/configs/det_mv3_east_v2.0/train_infer_python.txt
+++ b/test_tipc/configs/det_mv3_east_v2.0/train_infer_python.txt
@@ -34,15 +34,15 @@ distill_export:null
export1:null
export2:null
##
-train_model:./inference/det_mv3_east/best_accuracy
-infer_export:tools/export_model.py -c test_tipc/cconfigs/det_mv3_east_v2.0/det_mv3_east.yml -o
+train_model:./inference/det_mv3_east_v2.0_train/best_accuracy
+infer_export:tools/export_model.py -c test_tipc/configs/det_mv3_east_v2.0/det_mv3_east.yml -o
infer_quant:False
inference:tools/infer/predict_det.py
--use_gpu:True|False
--enable_mkldnn:True|False
--cpu_threads:1|6
--rec_batch_num:1
---use_tensorrt:False|True
+--use_tensorrt:False
--precision:fp32|fp16|int8
--det_model_dir:
--image_dir:./inference/ch_det_data_50/all-sum-510/
diff --git a/test_tipc/configs/det_mv3_pse_v2.0/train_infer_python.txt b/test_tipc/configs/det_mv3_pse_v2.0/train_infer_python.txt
index f9909027f10d9e9f96d65f9f5a1c5f3fd5c9e1c6..0171a97ae6c88dd13e74d85eb59bb019dad954f7 100644
--- a/test_tipc/configs/det_mv3_pse_v2.0/train_infer_python.txt
+++ b/test_tipc/configs/det_mv3_pse_v2.0/train_infer_python.txt
@@ -34,16 +34,16 @@ distill_export:null
export1:null
export2:null
##
-train_model:./inference/det_mv3_pse/best_accuracy
-infer_export:tools/export_model.py -c test_tipc/cconfigs/det_mv3_pse_v2.0/det_mv3_pse.yml -o
+train_model:./inference/det_mv3_pse_v2.0_train/best_accuracy
+infer_export:tools/export_model.py -c test_tipc/configs/det_mv3_pse_v2.0/det_mv3_pse.yml -o
infer_quant:False
inference:tools/infer/predict_det.py
--use_gpu:True|False
--enable_mkldnn:True|False
--cpu_threads:1|6
--rec_batch_num:1
---use_tensorrt:False|True
---precision:fp32|fp16|int8
+--use_tensorrt:False
+--precision:fp32|fp16
--det_model_dir:
--image_dir:./inference/ch_det_data_50/all-sum-510/
--save_log_path:null
diff --git a/test_tipc/configs/det_r50_vd_east_v2.0/train_infer_python.txt b/test_tipc/configs/det_r50_vd_east_v2.0/train_infer_python.txt
index dfb376237ee35c277fcd86a88328c562d5c0429a..45023ae3eeebc925d61e1686e0c18c75085b2ab4 100644
--- a/test_tipc/configs/det_r50_vd_east_v2.0/train_infer_python.txt
+++ b/test_tipc/configs/det_r50_vd_east_v2.0/train_infer_python.txt
@@ -34,15 +34,15 @@ distill_export:null
export1:null
export2:null
##
-train_model:./inference/det_r50_vd_east/best_accuracy
-infer_export:tools/export_model.py -c test_tipc/cconfigs/det_r50_vd_east_v2.0/det_r50_vd_east.yml -o
+train_model:./inference/det_r50_vd_east_v2.0_train/best_accuracy
+infer_export:tools/export_model.py -c test_tipc/configs/det_r50_vd_east_v2.0/det_r50_vd_east.yml -o
infer_quant:False
inference:tools/infer/predict_det.py
--use_gpu:True|False
--enable_mkldnn:True|False
--cpu_threads:1|6
--rec_batch_num:1
---use_tensorrt:False|True
+--use_tensorrt:False
--precision:fp32|fp16|int8
--det_model_dir:
--image_dir:./inference/ch_det_data_50/all-sum-510/
diff --git a/test_tipc/configs/det_r50_vd_pse_v2.0/train_infer_python.txt b/test_tipc/configs/det_r50_vd_pse_v2.0/train_infer_python.txt
index c60f4263ebc734acf3136a6542bb9e882658af2b..d81542ea2e11fcddfc403fae686bbfab419de254 100644
--- a/test_tipc/configs/det_r50_vd_pse_v2.0/train_infer_python.txt
+++ b/test_tipc/configs/det_r50_vd_pse_v2.0/train_infer_python.txt
@@ -34,7 +34,7 @@ distill_export:null
export1:null
export2:null
##
-train_model:./inference/det_r50_vd_pse/best_accuracy
+train_model:./inference/det_r50_vd_pse_v2.0_train/best_accuracy
infer_export:tools/export_model.py -c test_tipc/configs/det_r50_vd_pse_v2.0/det_r50_vd_pse.yml -o
infer_quant:False
inference:tools/infer/predict_det.py
@@ -42,7 +42,7 @@ inference:tools/infer/predict_det.py
--enable_mkldnn:True|False
--cpu_threads:1|6
--rec_batch_num:1
---use_tensorrt:False|True
+--use_tensorrt:False
--precision:fp32|fp16|int8
--det_model_dir:
--image_dir:./inference/ch_det_data_50/all-sum-510/
diff --git a/test_tipc/configs/det_r50_vd_sast_icdar15_v2.0/train_infer_python.txt b/test_tipc/configs/det_r50_vd_sast_icdar15_v2.0/train_infer_python.txt
index e6fb2ca5b459d26cd4b099c17f81bb47cc59bc71..f6ff061ff5a1e0ba914bbe69684a1fa60cdfff5d 100644
--- a/test_tipc/configs/det_r50_vd_sast_icdar15_v2.0/train_infer_python.txt
+++ b/test_tipc/configs/det_r50_vd_sast_icdar15_v2.0/train_infer_python.txt
@@ -42,7 +42,7 @@ inference:tools/infer/predict_det.py
--enable_mkldnn:True|False
--cpu_threads:1|6
--rec_batch_num:1
---use_tensorrt:False|True
+--use_tensorrt:False
--precision:fp32|fp16|int8
--det_model_dir:
--image_dir:./inference/ch_det_data_50/all-sum-510/
diff --git a/test_tipc/configs/det_r50_vd_sast_totaltext_v2.0/train_infer_python.txt b/test_tipc/configs/det_r50_vd_sast_totaltext_v2.0/train_infer_python.txt
index 2387ba7b5e9bac09b4c85fa5273d0c6ba5bebcb5..54921cb1a8d361cdaba7c7c5154cb2730ef0ec77 100644
--- a/test_tipc/configs/det_r50_vd_sast_totaltext_v2.0/train_infer_python.txt
+++ b/test_tipc/configs/det_r50_vd_sast_totaltext_v2.0/train_infer_python.txt
@@ -34,7 +34,7 @@ distill_export:null
export1:null
export2:null
inference_dir:null
-train_model:./inference/det_r50_vd_sast_totaltext_v2.0/best_accuracy
+train_model:./inference/det_r50_vd_sast_totaltext_v2.0_train/best_accuracy
infer_export:tools/export_model.py -c test_tipc/configs/det_r50_vd_sast_totaltext_v2.0/det_r50_vd_sast_totaltext.yml -o
infer_quant:False
inference:tools/infer/predict_det.py
@@ -42,7 +42,7 @@ inference:tools/infer/predict_det.py
--enable_mkldnn:True|False
--cpu_threads:1|6
--rec_batch_num:1
---use_tensorrt:False|True
+--use_tensorrt:False
--precision:fp32|fp16|int8
--det_model_dir:
--image_dir:./inference/ch_det_data_50/all-sum-510/
diff --git a/test_tipc/configs/rec_mtb_nrtr/train_infer_python.txt b/test_tipc/configs/rec_mtb_nrtr/train_infer_python.txt
index 695fc8a42ef0f6b79901e8b62ce09d72e3500793..2adca464a63d548f2b218ed1de91692ed25da89a 100644
--- a/test_tipc/configs/rec_mtb_nrtr/train_infer_python.txt
+++ b/test_tipc/configs/rec_mtb_nrtr/train_infer_python.txt
@@ -26,7 +26,7 @@ null:null
##
===========================infer_params===========================
Global.save_inference_dir:./output/
-Global.pretrained_model:
+Global.checkpoints:
norm_export:tools/export_model.py -c test_tipc/configs/rec_mtb_nrtr/rec_mtb_nrtr.yml -o
quant_export:null
fpgm_export:null
@@ -34,7 +34,7 @@ distill_export:null
export1:null
export2:null
##
-infer_model:null
+train_model:./inference/rec_mtb_nrtr_train/best_accuracy
infer_export:tools/export_model.py -c test_tipc/configs/rec_mtb_nrtr/rec_mtb_nrtr.yml -o
infer_quant:False
inference:tools/infer/predict_rec.py --rec_char_dict_path=./ppocr/utils/EN_symbol_dict.txt --rec_image_shape="1,32,100" --rec_algorithm="NRTR"
diff --git a/test_tipc/configs/rec_mv3_none_bilstm_ctc_v2.0/train_infer_python.txt b/test_tipc/configs/rec_mv3_none_bilstm_ctc_v2.0/train_infer_python.txt
index fdc39f1b851a4e05735744f878917a3dfcc1d405..ac565d8c55b1924e7a39fd8e36456a74fbbce042 100644
--- a/test_tipc/configs/rec_mv3_none_bilstm_ctc_v2.0/train_infer_python.txt
+++ b/test_tipc/configs/rec_mv3_none_bilstm_ctc_v2.0/train_infer_python.txt
@@ -26,7 +26,7 @@ null:null
##
===========================infer_params===========================
Global.save_inference_dir:./output/
-Global.pretrained_model:
+Global.checkpoints:
norm_export:tools/export_model.py -c test_tipc/configs/rec_mv3_none_bilstm_ctc_v2.0/rec_icdar15_train.yml -o
quant_export:null
fpgm_export:null
@@ -34,7 +34,7 @@ distill_export:null
export1:null
export2:null
##
-infer_model:null
+train_model:./inference/rec_mv3_none_bilstm_ctc_v2.0_train/best_accuracy
infer_export:tools/export_model.py -c test_tipc/configs/rec_mv3_none_bilstm_ctc_v2.0/rec_icdar15_train.yml -o
infer_quant:False
inference:tools/infer/predict_rec.py --rec_char_dict_path=./ppocr/utils/ic15_dict.txt --rec_image_shape="3,32,100"
diff --git a/test_tipc/configs/rec_mv3_none_none_ctc_v2.0/train_infer_python.txt b/test_tipc/configs/rec_mv3_none_none_ctc_v2.0/train_infer_python.txt
index 9810689679903b4cedff253834a1a999c4e8a5f8..947399a83cedc1f4262374e2c5ba5f3221561f0d 100644
--- a/test_tipc/configs/rec_mv3_none_none_ctc_v2.0/train_infer_python.txt
+++ b/test_tipc/configs/rec_mv3_none_none_ctc_v2.0/train_infer_python.txt
@@ -26,7 +26,7 @@ null:null
##
===========================infer_params===========================
Global.save_inference_dir:./output/
-Global.pretrained_model:
+Global.checkpoints:
norm_export:tools/export_model.py -c test_tipc/configs/rec_mv3_none_none_ctc_v2.0/rec_icdar15_train.yml -o
quant_export:null
fpgm_export:null
@@ -34,7 +34,7 @@ distill_export:null
export1:null
export2:null
##
-infer_model:null
+train_model:./inference/rec_mv3_none_none_ctc_v2.0_train/best_accuracy
infer_export:tools/export_model.py -c test_tipc/configs/rec_mv3_none_none_ctc_v2.0/rec_icdar15_train.yml -o
infer_quant:False
inference:tools/infer/predict_rec.py --rec_char_dict_path=./ppocr/utils/ic15_dict.txt --rec_image_shape="3,32,100"
diff --git a/test_tipc/configs/rec_mv3_tps_bilstm_att_v2.0/train_infer_python.txt b/test_tipc/configs/rec_mv3_tps_bilstm_att_v2.0/train_infer_python.txt
index 18504d068740deeec42cf9620c2d9e816d88c5cc..5fcfeee5e1835504d08cf24b0180a5af105be092 100644
--- a/test_tipc/configs/rec_mv3_tps_bilstm_att_v2.0/train_infer_python.txt
+++ b/test_tipc/configs/rec_mv3_tps_bilstm_att_v2.0/train_infer_python.txt
@@ -26,7 +26,7 @@ null:null
##
===========================infer_params===========================
Global.save_inference_dir:./output/
-Global.pretrained_model:
+Global.checkpoints:
norm_export:tools/export_model.py -c test_tipc/configs/rec_mv3_tps_bilstm_att_v2.0/rec_mv3_tps_bilstm_att.yml -o
quant_export:null
fpgm_export:null
@@ -34,7 +34,7 @@ distill_export:null
export1:null
export2:null
##
-infer_model:null
+train_model:./inference/rec_mv3_tps_bilstm_att_v2.0_train/best_accuracy
infer_export:tools/export_model.py -c test_tipc/configs/rec_mv3_tps_bilstm_att_v2.0/rec_mv3_tps_bilstm_att.yml -o
infer_quant:False
inference:tools/infer/predict_rec.py --rec_char_dict_path=./ppocr/utils/ic15_dict.txt --rec_image_shape="3,32,100" --rec_algorithm="RARE"
diff --git a/test_tipc/configs/rec_mv3_tps_bilstm_ctc_v2.0/train_infer_python.txt b/test_tipc/configs/rec_mv3_tps_bilstm_ctc_v2.0/train_infer_python.txt
index 3bec644ced183fff4329ff08991a137c45bacfc9..ac3fce6141ccbf96169d862b8b92f59af597db56 100644
--- a/test_tipc/configs/rec_mv3_tps_bilstm_ctc_v2.0/train_infer_python.txt
+++ b/test_tipc/configs/rec_mv3_tps_bilstm_ctc_v2.0/train_infer_python.txt
@@ -26,7 +26,7 @@ null:null
##
===========================infer_params===========================
Global.save_inference_dir:./output/
-Global.pretrained_model:
+Global.checkpoints:
norm_export:tools/export_model.py -c test_tipc/configs/rec_mv3_tps_bilstm_ctc_v2.0/rec_icdar15_train.yml -o
quant_export:null
fpgm_export:null
@@ -34,7 +34,7 @@ distill_export:null
export1:null
export2:null
##
-infer_model:null
+train_model:./inference/rec_mv3_tps_bilstm_ctc_v2.0_train/best_accuracy
infer_export:tools/export_model.py -c test_tipc/configs/rec_mv3_tps_bilstm_ctc_v2.0/rec_icdar15_train.yml -o
infer_quant:False
inference:tools/infer/predict_rec.py --rec_char_dict_path=./ppocr/utils/ic15_dict.txt --rec_image_shape="3,32,100" --rec_algorithm="StarNet"
diff --git a/test_tipc/configs/rec_r31_sar/train_infer_python.txt b/test_tipc/configs/rec_r31_sar/train_infer_python.txt
index 42dfc6b0275c05aef358682d031275488893e5fb..e4d7825243378709b965f59740c0360f11bdb957 100644
--- a/test_tipc/configs/rec_r31_sar/train_infer_python.txt
+++ b/test_tipc/configs/rec_r31_sar/train_infer_python.txt
@@ -26,7 +26,7 @@ null:null
##
===========================infer_params===========================
Global.save_inference_dir:./output/
-Global.pretrained_model:
+Global.checkpoints:
norm_export:tools/export_model.py -c test_tipc/configs/rec_r31_sar/rec_r31_sar.yml -o
quant_export:null
fpgm_export:null
@@ -34,7 +34,7 @@ distill_export:null
export1:null
export2:null
##
-infer_model:null
+train_model:./inference/rec_r31_sar_train/best_accuracy
infer_export:tools/export_model.py -c test_tipc/configs/rec_r31_sar/rec_r31_sar.yml -o
infer_quant:False
inference:tools/infer/predict_rec.py --rec_char_dict_path=./ppocr/utils/dict90.txt --rec_image_shape="3,48,48,160" --rec_algorithm="SAR"
@@ -43,7 +43,7 @@ inference:tools/infer/predict_rec.py --rec_char_dict_path=./ppocr/utils/dict90.t
--cpu_threads:1|6
--rec_batch_num:1|6
--use_tensorrt:True|False
---precision:fp32|fp16|int8
+--precision:fp32|int8
--rec_model_dir:
--image_dir:./inference/rec_inference
--save_log_path:./test/output/
diff --git a/test_tipc/configs/rec_r34_vd_none_bilstm_ctc_v2.0/train_infer_python.txt b/test_tipc/configs/rec_r34_vd_none_bilstm_ctc_v2.0/train_infer_python.txt
index 857a212fe6f5e0bd9612b55841e748e6b4409061..99f86872574bc300d3447efc0e4c83eaa88aab6c 100644
--- a/test_tipc/configs/rec_r34_vd_none_bilstm_ctc_v2.0/train_infer_python.txt
+++ b/test_tipc/configs/rec_r34_vd_none_bilstm_ctc_v2.0/train_infer_python.txt
@@ -26,7 +26,7 @@ null:null
##
===========================infer_params===========================
Global.save_inference_dir:./output/
-Global.pretrained_model:
+Global.checkpoints:
norm_export:tools/export_model.py -c test_tipc/configs/rec_r34_vd_none_bilstm_ctc_v2.0/rec_icdar15_train.yml -o
quant_export:null
fpgm_export:null
@@ -34,7 +34,7 @@ distill_export:null
export1:null
export2:null
##
-infer_model:null
+train_model:./inference/rec_r34_vd_none_bilstm_ctc_v2.0_train/best_accuracy
infer_export:tools/export_model.py -c test_tipc/configs/rec_r34_vd_none_bilstm_ctc_v2.0/rec_icdar15_train.yml -o
infer_quant:False
inference:tools/infer/predict_rec.py --rec_char_dict_path=./ppocr/utils/ic15_dict.txt --rec_image_shape="3,32,100"
diff --git a/test_tipc/configs/rec_r34_vd_none_none_ctc_v2.0/train_infer_python.txt b/test_tipc/configs/rec_r34_vd_none_none_ctc_v2.0/train_infer_python.txt
index 85804b7453729dad6d2e87d0efd1a053dd9a0aac..fb1ece49f71338307bfdf30714cd68cb382ea5e2 100644
--- a/test_tipc/configs/rec_r34_vd_none_none_ctc_v2.0/train_infer_python.txt
+++ b/test_tipc/configs/rec_r34_vd_none_none_ctc_v2.0/train_infer_python.txt
@@ -26,7 +26,7 @@ null:null
##
===========================infer_params===========================
Global.save_inference_dir:./output/
-Global.pretrained_model:
+Global.checkpoints:
norm_export:tools/export_model.py -c test_tipc/configs/rec_r34_vd_none_none_ctc_v2.0/rec_icdar15_train.yml -o
quant_export:null
fpgm_export:null
@@ -34,7 +34,7 @@ distill_export:null
export1:null
export2:null
##
-infer_model:null
+train_model:./inference/rec_r34_vd_none_none_ctc_v2.0_train/best_accuracy
infer_export:tools/export_model.py -c test_tipc/configs/rec_r34_vd_none_none_ctc_v2.0/rec_icdar15_train.yml -o
infer_quant:False
inference:tools/infer/predict_rec.py --rec_char_dict_path=./ppocr/utils/ic15_dict.txt --rec_image_shape="3,32,100"
diff --git a/test_tipc/configs/rec_r34_vd_tps_bilstm_att_v2.0/train_infer_python.txt b/test_tipc/configs/rec_r34_vd_tps_bilstm_att_v2.0/train_infer_python.txt
index 84bda52480118f84ec5efbc1d4831950b1cdee68..acc9749f08b42f7fa2200da7ef865f710afc77c3 100644
--- a/test_tipc/configs/rec_r34_vd_tps_bilstm_att_v2.0/train_infer_python.txt
+++ b/test_tipc/configs/rec_r34_vd_tps_bilstm_att_v2.0/train_infer_python.txt
@@ -26,7 +26,7 @@ null:null
##
===========================infer_params===========================
Global.save_inference_dir:./output/
-Global.pretrained_model:
+Global.checkpoints:
norm_export:tools/export_model.py -c test_tipc/configs/rec_r34_vd_tps_bilstm_att_v2.0/rec_r34_vd_tps_bilstm_att.yml -o
quant_export:null
fpgm_export:null
@@ -34,7 +34,7 @@ distill_export:null
export1:null
export2:null
##
-infer_model:null
+train_model:./inference/rec_r34_vd_tps_bilstm_att_v2.0_train/best_accuracy
infer_export:tools/export_model.py -c test_tipc/configs/rec_r34_vd_tps_bilstm_att_v2.0/rec_r34_vd_tps_bilstm_att.yml -o
infer_quant:False
inference:tools/infer/predict_rec.py --rec_char_dict_path=./ppocr/utils/ic15_dict.txt --rec_image_shape="3,32,100" --rec_algorithm="RARE"
diff --git a/test_tipc/configs/rec_r34_vd_tps_bilstm_ctc_v2.0/train_infer_python.txt b/test_tipc/configs/rec_r34_vd_tps_bilstm_ctc_v2.0/train_infer_python.txt
index ac43bd9703d7744220af40fa36b29adf64e89334..d11850528604074e9bb3d3d92b58ec709238b24b 100644
--- a/test_tipc/configs/rec_r34_vd_tps_bilstm_ctc_v2.0/train_infer_python.txt
+++ b/test_tipc/configs/rec_r34_vd_tps_bilstm_ctc_v2.0/train_infer_python.txt
@@ -26,7 +26,7 @@ null:null
##
===========================infer_params===========================
Global.save_inference_dir:./output/
-Global.pretrained_model:
+Global.checkpoints:
norm_export:tools/export_model.py -c test_tipc/configs/rec_r34_vd_tps_bilstm_ctc_v2.0/rec_icdar15_train.yml -o
quant_export:null
fpgm_export:null
@@ -34,7 +34,7 @@ distill_export:null
export1:null
export2:null
##
-infer_model:null
+train_model:./inference/rec_r34_vd_tps_bilstm_ctc_v2.0_train/best_accuracy
infer_export:tools/export_model.py -c test_tipc/configs/rec_r34_vd_tps_bilstm_ctc_v2.0/rec_icdar15_train.yml -o
infer_quant:False
inference:tools/infer/predict_rec.py --rec_char_dict_path=./ppocr/utils/ic15_dict.txt --rec_image_shape="3,32,100" --rec_algorithm="StarNet"
diff --git a/test_tipc/configs/rec_r50_fpn_vd_none_srn/train_infer_python.txt b/test_tipc/configs/rec_r50_fpn_vd_none_srn/train_infer_python.txt
index 55b25122e3d934ae66051595cc0bdc75aa3386fc..fb135df60b7716fd46a48482c0d7e8a3faca579a 100644
--- a/test_tipc/configs/rec_r50_fpn_vd_none_srn/train_infer_python.txt
+++ b/test_tipc/configs/rec_r50_fpn_vd_none_srn/train_infer_python.txt
@@ -26,7 +26,7 @@ null:null
##
===========================infer_params===========================
Global.save_inference_dir:./output/
-Global.pretrained_model:
+Global.checkpoints:
norm_export:tools/export_model.py -c test_tipc/configs/rec_r50_fpn_vd_none_srn/rec_r50_fpn_srn.yml -o
quant_export:null
fpgm_export:null
@@ -34,7 +34,7 @@ distill_export:null
export1:null
export2:null
##
-infer_model:null
+train_model:./inference/rec_r50_vd_srn_train/best_accuracy
infer_export:tools/export_model.py -c test_tipc/configs/rec_r50_fpn_vd_none_srn/rec_r50_fpn_srn.yml -o
infer_quant:False
inference:tools/infer/predict_rec.py --rec_char_dict_path=./ppocr/utils/ic15_dict.txt --rec_image_shape="1,64,256" --rec_algorithm="SRN" --use_space_char=False
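Across the recognition configs above the pattern is the same: the infer_params key switches from `Global.pretrained_model` to `Global.checkpoints`, and `infer_model:null` becomes a concrete `train_model` path matching the `*_train` archives that prepare.sh downloads. At whole_infer time the driver stitches these fields into an export command; the following is a rough sketch of that composition under those assumptions (`build_export_cmd` is illustrative, not the actual driver code):

```python
# Rough sketch: combine the infer_export and train_model fields into
# the export command the TIPC whole_infer stage runs, overriding the
# checkpoint and output directory via "-o" options.
def build_export_cmd(infer_export, train_model, save_dir="./output/"):
    return ("python {} Global.checkpoints={} "
            "Global.save_inference_dir={}").format(infer_export, train_model,
                                                   save_dir)


cmd = build_export_cmd(
    "tools/export_model.py -c test_tipc/configs/rec_r31_sar/rec_r31_sar.yml -o",
    "./inference/rec_r31_sar_train/best_accuracy")
print(cmd)
```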
diff --git a/test_tipc/prepare.sh b/test_tipc/prepare.sh
index 71d4010f4b2c3abe698e22b7e1e8f33e9ef9d45f..6b9d3abb7e8e22ce45811bcdd7ea8791ebeb8312 100644
--- a/test_tipc/prepare.sh
+++ b/test_tipc/prepare.sh
@@ -104,13 +104,17 @@ elif [ ${MODE} = "lite_train_whole_infer" ];then
elif [ ${MODE} = "whole_infer" ];then
wget -nc -P ./inference https://paddleocr.bj.bcebos.com/dygraph_v2.0/test/ch_det_data_50.tar --no-check-certificate
wget -nc -P ./inference/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/test/rec_inference.tar --no-check-certificate
- cd ./inference && tar xf rec_inference.tar && cd ../
+ cd ./inference && tar xf rec_inference.tar && tar xf ch_det_data_50.tar && cd ../
if [ ${model_name} = "ch_ppocr_mobile_v2.0_det" ]; then
eval_model_name="ch_ppocr_mobile_v2.0_det_train"
rm -rf ./train_data/icdar2015
wget -nc -P ./inference https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_det_train.tar --no-check-certificate
wget -nc -P ./inference https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_det_infer.tar --no-check-certificate
cd ./inference && tar xf ${eval_model_name}.tar && tar xf ch_det_data_50.tar && tar xf ch_ppocr_mobile_v2.0_det_infer.tar && cd ../
+ elif [ ${model_name} = "ch_ppocr_mobile_v2.0_det_PACT" ]; then
+ eval_model_name="ch_ppocr_mobile_v2.0_det_prune_infer"
+ wget -nc -P ./inference https://paddleocr.bj.bcebos.com/dygraph_v2.0/slim/ch_ppocr_mobile_v2.0_det_prune_infer.tar --no-check-certificate
+ cd ./inference && tar xf ${eval_model_name}.tar && tar xf ch_det_data_50.tar && cd ../
elif [ ${model_name} = "ch_ppocr_server_v2.0_det" ]; then
wget -nc -P ./inference https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_server_v2.0_det_train.tar --no-check-certificate
cd ./inference && tar xf ch_ppocr_server_v2.0_det_train.tar && tar xf ch_det_data_50.tar && cd ../
@@ -122,21 +126,13 @@ elif [ ${MODE} = "whole_infer" ];then
wget -nc -P ./inference https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_server_v2.0_det_infer.tar --no-check-certificate
wget -nc -P ./inference https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_server_v2.0_rec_infer.tar --no-check-certificate
cd ./inference && tar xf ch_ppocr_server_v2.0_det_infer.tar && tar xf ch_ppocr_server_v2.0_rec_infer.tar && tar xf ch_det_data_50.tar && cd ../
- elif [ ${model_name} = "ch_ppocr_mobile_v2.0_rec" ]; then
- eval_model_name="ch_ppocr_mobile_v2.0_rec_infer"
- wget -nc -P ./inference https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_rec_infer.tar --no-check-certificate
- cd ./inference && tar xf ${eval_model_name}.tar && cd ../
- elif [ ${model_name} = "ch_ppocr_server_v2.0_rec" ]; then
- eval_model_name="ch_ppocr_server_v2.0_rec_infer"
- wget -nc -P ./inference https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_server_v2.0_rec_infer.tar --no-check-certificate
- cd ./inference && tar xf ${eval_model_name}.tar && cd ../
elif [ ${model_name} = "ch_ppocr_mobile_v2.0_rec_PACT" ]; then
- eval_model_name="ch_PP-OCRv2_rec_slim_quant_train"
- wget -nc -P ./inference https://paddleocr.bj.bcebos.com/PP-OCRv2/chinese/ch_PP-OCRv2_rec_slim_quant_train.tar --no-check-certificate
+ eval_model_name="ch_ppocr_mobile_v2.0_rec_slim_infer"
+ wget -nc -P ./inference https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_rec_slim_infer.tar --no-check-certificate
cd ./inference && tar xf ${eval_model_name}.tar && cd ../
elif [ ${model_name} = "ch_ppocr_mobile_v2.0_rec_FPGM" ]; then
- eval_model_name="ch_PP-OCRv2_rec_train"
- wget -nc -P ./inference https://paddleocr.bj.bcebos.com/PP-OCRv2/chinese/ch_PP-OCRv2_rec_train.tar --no-check-certificate
+ eval_model_name="ch_PP-OCRv2_rec_infer"
+ wget -nc -P ./inference https://paddleocr.bj.bcebos.com/PP-OCRv2/chinese/ch_PP-OCRv2_rec_infer.tar --no-check-certificate
cd ./inference && tar xf ${eval_model_name}.tar && cd ../
fi
if [[ ${model_name} =~ "ch_PPOCRv2_det" ]]; then
@@ -147,7 +143,8 @@ elif [ ${MODE} = "whole_infer" ];then
if [[ ${model_name} =~ "PPOCRv2_ocr_rec" ]]; then
eval_model_name="ch_PP-OCRv2_rec_infer"
wget -nc -P ./inference/ https://paddleocr.bj.bcebos.com/PP-OCRv2/chinese/ch_PP-OCRv2_rec_infer.tar --no-check-certificate
- cd ./inference && tar xf ${eval_model_name}.tar && cd ../
+ wget -nc -P ./inference/ https://paddleocr.bj.bcebos.com/PP-OCRv2/chinese/ch_PP-OCRv2_rec_slim_quant_infer.tar --no-check-certificate
+ cd ./inference && tar xf ${eval_model_name}.tar && tar xf ch_PP-OCRv2_rec_slim_quant_infer.tar && cd ../
fi
if [ ${model_name} == "en_server_pgnetA" ]; then
wget -nc -P ./inference/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/pgnet/en_server_pgnetA.tar --no-check-certificate
@@ -157,6 +154,63 @@ elif [ ${MODE} = "whole_infer" ];then
wget -nc -P ./inference/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/det_r50_vd_sast_icdar15_v2.0_train.tar --no-check-certificate
cd ./inference/ && tar xf det_r50_vd_sast_icdar15_v2.0_train.tar && tar xf ch_det_data_50.tar && cd ../
fi
+ if [ ${model_name} == "rec_mv3_none_none_ctc_v2.0" ]; then
+ wget -nc -P ./inference/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_mv3_none_none_ctc_v2.0_train.tar --no-check-certificate
+ cd ./inference/ && tar xf rec_mv3_none_none_ctc_v2.0_train.tar && cd ../
+ fi
+ if [ ${model_name} == "rec_r34_vd_none_none_ctc_v2.0" ]; then
+ wget -nc -P ./inference/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_r34_vd_none_none_ctc_v2.0_train.tar --no-check-certificate
+ cd ./inference/ && tar xf rec_r34_vd_none_none_ctc_v2.0_train.tar && cd ../
+ fi
+ if [ ${model_name} == "rec_mv3_none_bilstm_ctc_v2.0" ]; then
+ wget -nc -P ./inference/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_mv3_none_bilstm_ctc_v2.0_train.tar --no-check-certificate
+ cd ./inference/ && tar xf rec_mv3_none_bilstm_ctc_v2.0_train.tar && cd ../
+ fi
+ if [ ${model_name} == "rec_r34_vd_none_bilstm_ctc_v2.0" ]; then
+ wget -nc -P ./inference/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_r34_vd_none_bilstm_ctc_v2.0_train.tar --no-check-certificate
+ cd ./inference/ && tar xf rec_r34_vd_none_bilstm_ctc_v2.0_train.tar && cd ../
+ fi
+ if [ ${model_name} == "rec_mv3_tps_bilstm_ctc_v2.0" ]; then
+ wget -nc -P ./inference/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_mv3_tps_bilstm_ctc_v2.0_train.tar --no-check-certificate
+ cd ./inference/ && tar xf rec_mv3_tps_bilstm_ctc_v2.0_train.tar && cd ../
+ fi
+ if [ ${model_name} == "rec_r34_vd_tps_bilstm_ctc_v2.0" ]; then
+ wget -nc -P ./inference/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_r34_vd_tps_bilstm_ctc_v2.0_train.tar --no-check-certificate
+ cd ./inference/ && tar xf rec_r34_vd_tps_bilstm_ctc_v2.0_train.tar && cd ../
+ fi
+ if [ ${model_name} == "ch_ppocr_server_v2.0_rec" ]; then
+ wget -nc -P ./inference/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/ch_ppocr_server_v2.0_rec_train.tar --no-check-certificate
+ cd ./inference/ && tar xf ch_ppocr_server_v2.0_rec_train.tar && cd ../
+ fi
+ if [ ${model_name} == "ch_ppocr_mobile_v2.0_rec" ]; then
+ wget -nc -P ./inference/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_rec_train.tar --no-check-certificate
+ cd ./inference/ && tar xf ch_ppocr_mobile_v2.0_rec_train.tar && cd ../
+ fi
+ if [ ${model_name} == "rec_mtb_nrtr" ]; then
+ wget -nc -P ./inference/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_mtb_nrtr_train.tar --no-check-certificate
+ cd ./inference/ && tar xf rec_mtb_nrtr_train.tar && cd ../
+ fi
+ if [ ${model_name} == "rec_mv3_tps_bilstm_att_v2.0" ]; then
+ wget -nc -P ./inference/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_mv3_tps_bilstm_att_v2.0_train.tar --no-check-certificate
+ cd ./inference/ && tar xf rec_mv3_tps_bilstm_att_v2.0_train.tar && cd ../
+ fi
+ if [ ${model_name} == "rec_r34_vd_tps_bilstm_att_v2.0" ]; then
+ wget -nc -P ./inference/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_r34_vd_tps_bilstm_att_v2.0_train.tar --no-check-certificate
+ cd ./inference/ && tar xf rec_r34_vd_tps_bilstm_att_v2.0_train.tar && cd ../
+ fi
+ if [ ${model_name} == "rec_r31_sar" ]; then
+ wget -nc -P ./inference/ https://paddleocr.bj.bcebos.com/dygraph_v2.1/rec/rec_r31_sar_train.tar --no-check-certificate
+ cd ./inference/ && tar xf rec_r31_sar_train.tar && cd ../
+ fi
+ if [ ${model_name} == "rec_r50_fpn_vd_none_srn" ]; then
+ wget -nc -P ./inference/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_r50_vd_srn_train.tar --no-check-certificate
+ cd ./inference/ && tar xf rec_r50_vd_srn_train.tar && cd ../
+ fi
+
+ if [ ${model_name} == "det_r50_vd_sast_totaltext_v2.0" ]; then
+ wget -nc -P ./inference/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/det_r50_vd_sast_totaltext_v2.0_train.tar --no-check-certificate
+ cd ./inference/ && tar xf det_r50_vd_sast_totaltext_v2.0_train.tar && cd ../
+ fi
if [ ${model_name} == "det_mv3_db_v2.0" ]; then
wget -nc -P ./inference/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/det_mv3_db_v2.0_train.tar --no-check-certificate
cd ./inference/ && tar xf det_mv3_db_v2.0_train.tar && tar xf ch_det_data_50.tar && cd ../
@@ -165,7 +219,24 @@ elif [ ${MODE} = "whole_infer" ];then
wget -nc -P ./inference/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/det_r50_vd_db_v2.0_train.tar --no-check-certificate
cd ./inference/ && tar xf det_r50_vd_db_v2.0_train.tar && tar xf ch_det_data_50.tar && cd ../
fi
+ if [ ${model_name} == "det_mv3_pse_v2.0" ]; then
+ wget -nc -P ./inference/ https://paddleocr.bj.bcebos.com/dygraph_v2.1/en_det/det_mv3_pse_v2.0_train.tar --no-check-certificate
+        cd ./inference/ && tar xf det_mv3_pse_v2.0_train.tar && cd ../
+ fi
+ if [ ${model_name} == "det_r50_vd_pse_v2.0" ]; then
+ wget -nc -P ./inference/ https://paddleocr.bj.bcebos.com/dygraph_v2.1/en_det/det_r50_vd_pse_v2.0_train.tar --no-check-certificate
+        cd ./inference/ && tar xf det_r50_vd_pse_v2.0_train.tar && cd ../
+ fi
+ if [ ${model_name} == "det_mv3_east_v2.0" ]; then
+ wget -nc -P ./inference/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/det_mv3_east_v2.0_train.tar --no-check-certificate
+        cd ./inference/ && tar xf det_mv3_east_v2.0_train.tar && cd ../
+ fi
+ if [ ${model_name} == "det_r50_vd_east_v2.0" ]; then
+ wget -nc -P ./inference/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/det_r50_vd_east_v2.0_train.tar --no-check-certificate
+        cd ./inference/ && tar xf det_r50_vd_east_v2.0_train.tar && cd ../
+ fi
fi
+
if [ ${MODE} = "klquant_whole_infer" ]; then
wget -nc -P ./train_data/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/test/icdar2015_lite.tar --no-check-certificate
cd ./train_data/ && tar xf icdar2015_lite.tar
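One fix worth flagging in the prepare.sh hunks above: the four new PSE/EAST download blocks originally chained `tar` and `cd` with a single `&`, which backgrounds the extraction and changes directory before it finishes; they are corrected to `&&` so each step runs only after the previous one succeeds. For reference, here is a rough Python equivalent of the script's download-and-extract idiom (`fetch_and_extract` is a hypothetical helper; the URL is one of the archives the script actually pulls):

```python
import os
import tarfile
import urllib.request


def fetch_and_extract(url, dest="./inference"):
    # Mirrors "wget -nc -P ./inference <url> && tar xf <archive>":
    # skip the download if the archive is already present, then
    # unpack it into the same directory.
    os.makedirs(dest, exist_ok=True)
    archive = os.path.join(dest, os.path.basename(url))
    if not os.path.exists(archive):
        urllib.request.urlretrieve(url, archive)
    with tarfile.open(archive) as tar:
        tar.extractall(dest)


if __name__ == "__main__":
    fetch_and_extract("https://paddleocr.bj.bcebos.com/dygraph_v2.1/"
                      "en_det/det_mv3_pse_v2.0_train.tar")
```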
diff --git a/tools/eval.py b/tools/eval.py
index c85490a316772e9dfdfe3267087ea3946a2a3b72..13a4a0882f5a20b47e8999042713e1623b32ff5a 100755
--- a/tools/eval.py
+++ b/tools/eval.py
@@ -54,7 +54,8 @@ def main():
config['Architecture']["Head"]['out_channels'] = char_num
model = build_model(config['Architecture'])
- extra_input = config['Architecture']['algorithm'] in ["SRN", "SAR"]
+    extra_input = config['Architecture']['algorithm'] in [
+        "SRN", "NRTR", "SAR", "SEED"]
if "model_type" in config['Architecture'].keys():
model_type = config['Architecture']['model_type']
else:
@@ -68,7 +69,6 @@ def main():
# build metric
eval_class = build_metric(config['Metric'])
-
# start eval
metric = program.eval(model, valid_dataloader, post_process_class,
eval_class, model_type, extra_input)
diff --git a/tools/infer/predict_cls.py b/tools/infer/predict_cls.py
index a25cac2600e67667badc76c648c1fcda12981a0f..ab3f4b04f0c306aaf7e26eb98e781938b7528275 100755
--- a/tools/infer/predict_cls.py
+++ b/tools/infer/predict_cls.py
@@ -145,8 +145,6 @@ def main(args):
for ino in range(len(img_list)):
logger.info("Predicts of {}:{}".format(valid_image_file_list[ino],
cls_res[ino]))
- logger.info(
- "The predict time about text angle classify module is as follows: ")
if __name__ == "__main__":
diff --git a/tools/infer_det.py b/tools/infer_det.py
index 1c679e0faf0d3ebdb6ca7ed4c317ce3eecfa910f..9d2daf13ad6ad3de396ea1587c3d25cccb126eac 100755
--- a/tools/infer_det.py
+++ b/tools/infer_det.py
@@ -126,9 +126,6 @@ def main():
otstr = file + "\t" + json.dumps(dt_boxes_json) + "\n"
fout.write(otstr.encode())
- save_det_path = os.path.dirname(config['Global'][
- 'save_res_path']) + "/det_results/"
- draw_det_res(boxes, config, src_img, file, save_det_path)
logger.info("success!")
diff --git a/tools/infer_kie.py b/tools/infer_kie.py
new file mode 100755
index 0000000000000000000000000000000000000000..16294e59cc51727f39af77d16255ef4d0f2a1bd8
--- /dev/null
+++ b/tools/infer_kie.py
@@ -0,0 +1,153 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+import paddle.nn.functional as F
+
+import os
+import sys
+
+__dir__ = os.path.dirname(os.path.abspath(__file__))
+sys.path.append(__dir__)
+sys.path.append(os.path.abspath(os.path.join(__dir__, '..')))
+
+os.environ["FLAGS_allocator_strategy"] = 'auto_growth'
+
+import cv2
+import paddle
+
+from ppocr.data import create_operators, transform
+from ppocr.modeling.architectures import build_model
+from ppocr.utils.save_load import load_model
+import tools.program as program
+import time
+
+
+def read_class_list(filepath):
+    class_dict = {}
+    with open(filepath, "r") as f:
+        lines = f.readlines()
+        for line in lines:
+            key, value = line.split(" ")
+            class_dict[key] = value.rstrip()
+    return class_dict
+
+
+def draw_kie_result(batch, node, idx_to_cls, count):
+ img = batch[6].copy()
+ boxes = batch[7]
+ h, w = img.shape[:2]
+ pred_img = np.ones((h, w * 2, 3), dtype=np.uint8) * 255
+ max_value, max_idx = paddle.max(node, -1), paddle.argmax(node, -1)
+ node_pred_label = max_idx.numpy().tolist()
+ node_pred_score = max_value.numpy().tolist()
+
+ for i, box in enumerate(boxes):
+ if i >= len(node_pred_label):
+ break
+ new_box = [[box[0], box[1]], [box[2], box[1]], [box[2], box[3]],
+ [box[0], box[3]]]
+        pts = np.array([new_box], np.int32)
+        cv2.polylines(
+            img, [pts.reshape((-1, 1, 2))],
+            True,
+            color=(255, 255, 0),
+            thickness=1)
+ x_min = int(min([point[0] for point in new_box]))
+ y_min = int(min([point[1] for point in new_box]))
+
+ pred_label = str(node_pred_label[i])
+ if pred_label in idx_to_cls:
+ pred_label = idx_to_cls[pred_label]
+ pred_score = '{:.2f}'.format(node_pred_score[i])
+ text = pred_label + '(' + pred_score + ')'
+ cv2.putText(pred_img, text, (x_min * 2, y_min),
+ cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 0, 0), 1)
+ vis_img = np.ones((h, w * 3, 3), dtype=np.uint8) * 255
+ vis_img[:, :w] = img
+ vis_img[:, w:] = pred_img
+ save_kie_path = os.path.dirname(config['Global'][
+ 'save_res_path']) + "/kie_results/"
+ if not os.path.exists(save_kie_path):
+ os.makedirs(save_kie_path)
+ save_path = os.path.join(save_kie_path, str(count) + ".png")
+ cv2.imwrite(save_path, vis_img)
+ logger.info("The Kie Image saved in {}".format(save_path))
+
+
+def main():
+ global_config = config['Global']
+
+ # build model
+ model = build_model(config['Architecture'])
+ load_model(config, model)
+
+ # create data ops
+ transforms = []
+ for op in config['Eval']['dataset']['transforms']:
+ transforms.append(op)
+
+ data_dir = config['Eval']['dataset']['data_dir']
+
+ ops = create_operators(transforms, global_config)
+
+ save_res_path = config['Global']['save_res_path']
+ class_path = config['Global']['class_path']
+ idx_to_cls = read_class_list(class_path)
+ if not os.path.exists(os.path.dirname(save_res_path)):
+ os.makedirs(os.path.dirname(save_res_path))
+
+ model.eval()
+
+ warmup_times = 0
+ count_t = []
+ with open(save_res_path, "wb") as fout:
+ with open(config['Global']['infer_img'], "rb") as f:
+ lines = f.readlines()
+ for index, data_line in enumerate(lines):
+ if index == 10:
+ warmup_t = time.time()
+ data_line = data_line.decode('utf-8')
+ substr = data_line.strip("\n").split("\t")
+ img_path, label = data_dir + "/" + substr[0], substr[1]
+ data = {'img_path': img_path, 'label': label}
+ with open(data['img_path'], 'rb') as f:
+ img = f.read()
+ data['image'] = img
+ st = time.time()
+ batch = transform(data, ops)
+ batch_pred = [0] * len(batch)
+ for i in range(len(batch)):
+ batch_pred[i] = paddle.to_tensor(
+ np.expand_dims(
+ batch[i], axis=0))
+ st = time.time()
+ node, edge = model(batch_pred)
+ node = F.softmax(node, -1)
+ count_t.append(time.time() - st)
+ draw_kie_result(batch, node, idx_to_cls, index)
+ logger.info("success!")
+ logger.info("It took {} s for predict {} images.".format(
+ np.sum(count_t), len(count_t)))
+ ips = len(count_t[warmup_times:]) / np.sum(count_t[warmup_times:])
+ logger.info("The ips is {} images/s".format(ips))
+
+
+if __name__ == '__main__':
+ config, device, logger, vdl_writer = program.preprocess()
+ main()
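A note on the new `tools/infer_kie.py`: `read_class_list` expects the file behind `Global.class_path` to hold one space-separated `index label` pair per line, mapping SDMGR's predicted class indices to readable names for `draw_kie_result`. A short usage sketch follows; the file name and labels are illustrative examples, not shipped data:

```python
# Illustrative class_path content for read_class_list: each line is
# "key value", where key is the class index predicted by the KIE head
# and value is its human-readable label.
sample = "0 Ignore\n1 Store_name_key\n2 Store_name_value\n"
with open("class_list_demo.txt", "w") as f:
    f.write(sample)

idx_to_cls = {}
with open("class_list_demo.txt", "r") as f:
    for line in f.readlines():
        key, value = line.split(" ")
        idx_to_cls[key] = value.rstrip()
print(idx_to_cls)  # {'0': 'Ignore', '1': 'Store_name_key', ...}
```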
diff --git a/tools/program.py b/tools/program.py
index d110f70704028948dff2bc889e07d128e0bc94ea..333e8ed9770cad08ba5e9aa47edec850a74a1808 100755
--- a/tools/program.py
+++ b/tools/program.py
@@ -239,6 +239,8 @@ def train(config,
else:
if model_type == 'table' or extra_input:
preds = model(images, data=batch[1:])
+ elif model_type == "kie":
+ preds = model(batch)
else:
preds = model(images)
loss = loss_class(preds, batch)
@@ -266,7 +268,7 @@ def train(config,
if cal_metric_during_train: # only rec and cls need
batch = [item.numpy() for item in batch]
- if model_type == 'table':
+ if model_type in ['table', 'kie']:
eval_class(preds, batch)
else:
post_result = post_process_class(preds, batch[1])
@@ -399,17 +401,20 @@ def eval(model,
start = time.time()
if model_type == 'table' or extra_input:
preds = model(images, data=batch[1:])
+ elif model_type == "kie":
+ preds = model(batch)
else:
preds = model(images)
batch = [item.numpy() for item in batch]
# Obtain usable results from post-processing methods
total_time += time.time() - start
# Evaluate the results of the current batch
- if model_type == 'table':
+ if model_type in ['table', 'kie']:
eval_class(preds, batch)
else:
post_result = post_process_class(preds, batch[1])
eval_class(post_result, batch)
+
pbar.update(1)
total_frame += len(images)
# Get final metric,eg. acc or hmean
@@ -498,8 +503,13 @@ def preprocess(is_train=False):
assert alg in [
'EAST', 'DB', 'SAST', 'Rosetta', 'CRNN', 'STARNet', 'RARE', 'SRN',
'CLS', 'PGNet', 'Distillation', 'NRTR', 'TableAttn', 'SAR', 'PSE',
- 'SEED'
+ 'SEED', 'SDMGR'
]
+    windows_not_support_list = ['PSE']
+    if platform.system() == "Windows" and alg in windows_not_support_list:
+        logger.warning(
+            '{} is not supported on Windows for now'.format(alg))
+        sys.exit()
device = 'gpu:{}'.format(dist.ParallelEnv().dev_id) if use_gpu else 'cpu'
device = paddle.set_device(device)
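The net effect of the program.py changes is a third forward-dispatch branch: table models and extra-input algorithms (now including NRTR and SEED) receive auxiliary inputs, while KIE models consume the whole batch, since SDMGR needs the boxes and relations alongside the image. A condensed sketch of that branching, assuming `model_type` and `extra_input` as computed earlier in program.py:

```python
# Condensed sketch (not the full train/eval loop) of the forward
# dispatch after this change.
def forward(model, batch, model_type, extra_input):
    images = batch[0]
    if model_type == "table" or extra_input:
        # table/SRN-style models take auxiliary tensors alongside images
        return model(images, data=batch[1:])
    elif model_type == "kie":
        # KIE (SDMGR) models need the whole batch: image, boxes, relations
        return model(batch)
    return model(images)
```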