提交 06194524 编写于 作者: 文幕地方's avatar 文幕地方

add re predict

上级 0a59848d
...@@ -68,6 +68,7 @@ Train: ...@@ -68,6 +68,7 @@ Train:
- VQAReTokenRelation: - VQAReTokenRelation:
- VQAReTokenChunk: - VQAReTokenChunk:
max_seq_len: *max_seq_len max_seq_len: *max_seq_len
- TensorizeEntitiesRelations:
- Resize: - Resize:
size: [224,224] size: [224,224]
- NormalizeImage: - NormalizeImage:
...@@ -83,7 +84,6 @@ Train: ...@@ -83,7 +84,6 @@ Train:
drop_last: False drop_last: False
batch_size_per_card: 2 batch_size_per_card: 2
num_workers: 8 num_workers: 8
collate_fn: ListCollator
Eval: Eval:
dataset: dataset:
...@@ -105,6 +105,7 @@ Eval: ...@@ -105,6 +105,7 @@ Eval:
- VQAReTokenRelation: - VQAReTokenRelation:
- VQAReTokenChunk: - VQAReTokenChunk:
max_seq_len: *max_seq_len max_seq_len: *max_seq_len
- TensorizeEntitiesRelations:
- Resize: - Resize:
size: [224,224] size: [224,224]
- NormalizeImage: - NormalizeImage:
...@@ -120,4 +121,3 @@ Eval: ...@@ -120,4 +121,3 @@ Eval:
drop_last: False drop_last: False
batch_size_per_card: 8 batch_size_per_card: 8
num_workers: 8 num_workers: 8
collate_fn: ListCollator
...@@ -73,6 +73,7 @@ Train: ...@@ -73,6 +73,7 @@ Train:
- VQAReTokenRelation: - VQAReTokenRelation:
- VQAReTokenChunk: - VQAReTokenChunk:
max_seq_len: *max_seq_len max_seq_len: *max_seq_len
- TensorizeEntitiesRelations:
- Resize: - Resize:
size: [224,224] size: [224,224]
- NormalizeImage: - NormalizeImage:
...@@ -88,7 +89,6 @@ Train: ...@@ -88,7 +89,6 @@ Train:
drop_last: False drop_last: False
batch_size_per_card: 2 batch_size_per_card: 2
num_workers: 4 num_workers: 4
collate_fn: ListCollator
Eval: Eval:
dataset: dataset:
...@@ -112,6 +112,7 @@ Eval: ...@@ -112,6 +112,7 @@ Eval:
- VQAReTokenRelation: - VQAReTokenRelation:
- VQAReTokenChunk: - VQAReTokenChunk:
max_seq_len: *max_seq_len max_seq_len: *max_seq_len
- TensorizeEntitiesRelations:
- Resize: - Resize:
size: [224,224] size: [224,224]
- NormalizeImage: - NormalizeImage:
...@@ -127,5 +128,3 @@ Eval: ...@@ -127,5 +128,3 @@ Eval:
drop_last: False drop_last: False
batch_size_per_card: 8 batch_size_per_card: 8
num_workers: 8 num_workers: 8
collate_fn: ListCollator
...@@ -116,6 +116,7 @@ Train: ...@@ -116,6 +116,7 @@ Train:
- VQAReTokenRelation: - VQAReTokenRelation:
- VQAReTokenChunk: - VQAReTokenChunk:
max_seq_len: *max_seq_len max_seq_len: *max_seq_len
- TensorizeEntitiesRelations:
- Resize: - Resize:
size: [224,224] size: [224,224]
- NormalizeImage: - NormalizeImage:
...@@ -155,6 +156,7 @@ Eval: ...@@ -155,6 +156,7 @@ Eval:
- VQAReTokenRelation: - VQAReTokenRelation:
- VQAReTokenChunk: - VQAReTokenChunk:
max_seq_len: *max_seq_len max_seq_len: *max_seq_len
- TensorizeEntitiesRelations:
- Resize: - Resize:
size: [224,224] size: [224,224]
- NormalizeImage: - NormalizeImage:
......
...@@ -30,7 +30,7 @@ ...@@ -30,7 +30,7 @@
|模型|骨干网络|任务|配置文件|hmean|下载链接| |模型|骨干网络|任务|配置文件|hmean|下载链接|
| --- | --- |--|--- | --- | --- | | --- | --- |--|--- | --- | --- |
|LayoutXLM|LayoutXLM-base|SER |[ser_layoutxlm_xfund_zh.yml](../../configs/kie/layoutlm_series/ser_layoutxlm_xfund_zh.yml)|90.38%|[训练模型](https://paddleocr.bj.bcebos.com/pplayout/ser_LayoutXLM_xfun_zh.tar)/[推理模型](https://paddleocr.bj.bcebos.com/pplayout/ser_LayoutXLM_xfun_zh_infer.tar)| |LayoutXLM|LayoutXLM-base|SER |[ser_layoutxlm_xfund_zh.yml](../../configs/kie/layoutlm_series/ser_layoutxlm_xfund_zh.yml)|90.38%|[训练模型](https://paddleocr.bj.bcebos.com/pplayout/ser_LayoutXLM_xfun_zh.tar)/[推理模型](https://paddleocr.bj.bcebos.com/pplayout/ser_LayoutXLM_xfun_zh_infer.tar)|
|LayoutXLM|LayoutXLM-base|RE | [re_layoutxlm_xfund_zh.yml](../../configs/kie/layoutlm_series/re_layoutxlm_xfund_zh.yml)|74.83%|[训练模型](https://paddleocr.bj.bcebos.com/pplayout/re_LayoutXLM_xfun_zh.tar)/[推理模型(coming soon)]()| |LayoutXLM|LayoutXLM-base|RE | [re_layoutxlm_xfund_zh.yml](../../configs/kie/layoutlm_series/re_layoutxlm_xfund_zh.yml)|74.83%|[训练模型](https://paddleocr.bj.bcebos.com/pplayout/re_LayoutXLM_xfun_zh.tar)/[推理模型](https://paddleocr.bj.bcebos.com/pplayout/re_LayoutXLM_xfun_zh_infer.tar)|
<a name="2"></a> <a name="2"></a>
...@@ -52,14 +52,14 @@ ...@@ -52,14 +52,14 @@
### 4.1 Python推理 ### 4.1 Python推理
**注:** 目前RE任务推理过程仍在适配中,下面以SER任务为例,介绍基于LayoutXLM模型的关键信息抽取过程。 - SER
首先将训练得到的模型转换成inference model。LayoutXLM模型在XFUND_zh数据集上训练的模型为例([模型下载地址](https://paddleocr.bj.bcebos.com/pplayout/ser_LayoutXLM_xfun_zh.tar)),可以使用下面的命令进行转换。 首先将训练得到的模型转换成inference model。LayoutXLM模型在XFUND_zh数据集上训练的模型为例([模型下载地址](https://paddleocr.bj.bcebos.com/pplayout/ser_LayoutXLM_xfun_zh.tar)),可以使用下面的命令进行转换。
``` bash ``` bash
wget https://paddleocr.bj.bcebos.com/pplayout/ser_LayoutXLM_xfun_zh.tar wget https://paddleocr.bj.bcebos.com/pplayout/ser_LayoutXLM_xfun_zh.tar
tar -xf ser_LayoutXLM_xfun_zh.tar tar -xf ser_LayoutXLM_xfun_zh.tar
python3 tools/export_model.py -c configs/kie/layoutlm_series/ser_layoutxlm_xfund_zh.yml -o Architecture.Backbone.checkpoints=./ser_LayoutXLM_xfun_zh/best_accuracy Global.save_inference_dir=./inference/ser_layoutxlm python3 tools/export_model.py -c configs/kie/layoutlm_series/ser_layoutxlm_xfund_zh.yml -o Architecture.Backbone.checkpoints=./ser_LayoutXLM_xfun_zh Global.save_inference_dir=./inference/ser_layoutxlm_infer
``` ```
LayoutXLM模型基于SER任务进行推理,可以执行如下命令: LayoutXLM模型基于SER任务进行推理,可以执行如下命令:
...@@ -80,6 +80,34 @@ SER可视化结果默认保存到`./output`文件夹里面,结果示例如下 ...@@ -80,6 +80,34 @@ SER可视化结果默认保存到`./output`文件夹里面,结果示例如下
<img src="../../ppstructure/docs/kie/result_ser/zh_val_42_ser.jpg" width="800"> <img src="../../ppstructure/docs/kie/result_ser/zh_val_42_ser.jpg" width="800">
</div> </div>
- RE
首先将训练得到的模型转换成inference model。LayoutXLM模型在XFUND_zh数据集上训练的模型为例([模型下载地址](https://paddleocr.bj.bcebos.com/pplayout/re_LayoutXLM_xfun_zh.tar)),可以使用下面的命令进行转换。
``` bash
wget https://paddleocr.bj.bcebos.com/pplayout/re_LayoutXLM_xfun_zh.tar
tar -xf re_LayoutXLM_xfun_zh.tar
python3 tools/export_model.py -c configs/kie/layoutlm_series/re_layoutxlm_xfund_zh.yml -o Architecture.Backbone.checkpoints=./re_LayoutXLM_xfun_zh Global.save_inference_dir=./inference/ser_layoutxlm_infer
```
LayoutXLM模型基于RE任务进行推理,可以执行如下命令:
```bash
cd ppstructure
python3 kie/predict_kie_token_ser_re.py \
--kie_algorithm=LayoutXLM \
--re_model_dir=../inference/re_layoutxlm_infer \
--ser_model_dir=../inference/ser_layoutxlm_infer \
--image_dir=./docs/kie/input/zh_val_42.jpg \
--ser_dict_path=../train_data/XFUND/class_list_xfun.txt \
--vis_font_path=../doc/fonts/simfang.ttf
```
RE可视化结果默认保存到`./output`文件夹里面,结果示例如下:
<div align="center">
<img src="../../ppstructure/docs/kie/result_re/zh_val_42_re.jpg" width="800">
</div>
<a name="4-2"></a> <a name="4-2"></a>
### 4.2 C++推理部署 ### 4.2 C++推理部署
......
...@@ -23,7 +23,7 @@ VI-LayoutXLM基于LayoutXLM进行改进,在下游任务训练过程中,去 ...@@ -23,7 +23,7 @@ VI-LayoutXLM基于LayoutXLM进行改进,在下游任务训练过程中,去
|模型|骨干网络|任务|配置文件|hmean|下载链接| |模型|骨干网络|任务|配置文件|hmean|下载链接|
| --- | --- |---| --- | --- | --- | | --- | --- |---| --- | --- | --- |
|VI-LayoutXLM |VI-LayoutXLM-base | SER |[ser_vi_layoutxlm_xfund_zh_udml.yml](../../configs/kie/vi_layoutxlm/ser_vi_layoutxlm_xfund_zh_udml.yml)|93.19%|[训练模型](https://paddleocr.bj.bcebos.com/ppstructure/models/vi_layoutxlm/ser_vi_layoutxlm_xfund_pretrained.tar)/[推理模型](https://paddleocr.bj.bcebos.com/ppstructure/models/vi_layoutxlm/ser_vi_layoutxlm_xfund_infer.tar)| |VI-LayoutXLM |VI-LayoutXLM-base | SER |[ser_vi_layoutxlm_xfund_zh_udml.yml](../../configs/kie/vi_layoutxlm/ser_vi_layoutxlm_xfund_zh_udml.yml)|93.19%|[训练模型](https://paddleocr.bj.bcebos.com/ppstructure/models/vi_layoutxlm/ser_vi_layoutxlm_xfund_pretrained.tar)/[推理模型](https://paddleocr.bj.bcebos.com/ppstructure/models/vi_layoutxlm/ser_vi_layoutxlm_xfund_infer.tar)|
|VI-LayoutXLM |VI-LayoutXLM-base |RE | [re_vi_layoutxlm_xfund_zh_udml.yml](../../configs/kie/vi_layoutxlm/re_vi_layoutxlm_xfund_zh_udml.yml)|83.92%|[训练模型](https://paddleocr.bj.bcebos.com/ppstructure/models/vi_layoutxlm/re_vi_layoutxlm_xfund_pretrained.tar)/[推理模型(coming soon)]()| |VI-LayoutXLM |VI-LayoutXLM-base |RE | [re_vi_layoutxlm_xfund_zh_udml.yml](../../configs/kie/vi_layoutxlm/re_vi_layoutxlm_xfund_zh_udml.yml)|83.92%|[训练模型](https://paddleocr.bj.bcebos.com/ppstructure/models/vi_layoutxlm/re_vi_layoutxlm_xfund_pretrained.tar)/[推理模型](https://paddleocr.bj.bcebos.com/ppstructure/models/vi_layoutxlm/re_vi_layoutxlm_xfund_infer.tar)|
<a name="2"></a> <a name="2"></a>
...@@ -45,7 +45,7 @@ VI-LayoutXLM基于LayoutXLM进行改进,在下游任务训练过程中,去 ...@@ -45,7 +45,7 @@ VI-LayoutXLM基于LayoutXLM进行改进,在下游任务训练过程中,去
### 4.1 Python推理 ### 4.1 Python推理
**注:** 目前RE任务推理过程仍在适配中,下面以SER任务为例,介绍基于VI-LayoutXLM模型的关键信息抽取过程。 -SER
首先将训练得到的模型转换成inference model。以VI-LayoutXLM模型在XFUND_zh数据集上训练的模型为例([模型下载地址](https://paddleocr.bj.bcebos.com/ppstructure/models/vi_layoutxlm/ser_vi_layoutxlm_xfund_pretrained.tar)),可以使用下面的命令进行转换。 首先将训练得到的模型转换成inference model。以VI-LayoutXLM模型在XFUND_zh数据集上训练的模型为例([模型下载地址](https://paddleocr.bj.bcebos.com/ppstructure/models/vi_layoutxlm/ser_vi_layoutxlm_xfund_pretrained.tar)),可以使用下面的命令进行转换。
...@@ -74,6 +74,36 @@ SER可视化结果默认保存到`./output`文件夹里面,结果示例如下 ...@@ -74,6 +74,36 @@ SER可视化结果默认保存到`./output`文件夹里面,结果示例如下
<img src="../../ppstructure/docs/kie/result_ser/zh_val_42_ser.jpg" width="800"> <img src="../../ppstructure/docs/kie/result_ser/zh_val_42_ser.jpg" width="800">
</div> </div>
-RE
首先将训练得到的模型转换成inference model。以VI-LayoutXLM模型在XFUND_zh数据集上训练的模型为例([模型下载地址](https://paddleocr.bj.bcebos.com/ppstructure/models/vi_layoutxlm/re_vi_layoutxlm_xfund_pretrained.tar)),可以使用下面的命令进行转换。
``` bash
wget https://paddleocr.bj.bcebos.com/ppstructure/models/vi_layoutxlm/re_vi_layoutxlm_xfund_pretrained.tar
tar -xf re_vi_layoutxlm_xfund_pretrained.tar
python3 tools/export_model.py -c configs/kie/vi_layoutxlm/re_vi_layoutxlm_xfund_zh.yml -o Architecture.Backbone.checkpoints=./re_vi_layoutxlm_xfund_pretrained/best_accuracy Global.save_inference_dir=./inference/re_vi_layoutxlm_infer
```
VI-LayoutXLM模型基于RE任务进行推理,可以执行如下命令:
```bash
cd ppstructure
python3 kie/predict_kie_token_ser_re.py \
--kie_algorithm=LayoutXLM \
--re_model_dir=../inference/re_vi_layoutxlm_infer \
--ser_model_dir=../inference/ser_vi_layoutxlm_infer \
--use_visual_backbone=False \
--image_dir=./docs/kie/input/zh_val_42.jpg \
--ser_dict_path=../train_data/XFUND/class_list_xfun.txt \
--vis_font_path=../doc/fonts/simfang.ttf \
--ocr_order_method="tb-yx"
```
RE可视化结果默认保存到`./output`文件夹里面,结果示例如下:
<div align="center">
<img src="../../ppstructure/docs/kie/result_re/zh_val_42_re.jpg" width="800">
</div>
<a name="4-2"></a> <a name="4-2"></a>
### 4.2 C++推理部署 ### 4.2 C++推理部署
......
...@@ -28,7 +28,7 @@ On XFUND_zh dataset, the algorithm reproduction Hmean is as follows. ...@@ -28,7 +28,7 @@ On XFUND_zh dataset, the algorithm reproduction Hmean is as follows.
|Model|Backbone|Task |Cnnfig|Hmean|Download link| |Model|Backbone|Task |Cnnfig|Hmean|Download link|
| --- | --- |--|--- | --- | --- | | --- | --- |--|--- | --- | --- |
|LayoutXLM|LayoutXLM-base|SER |[ser_layoutxlm_xfund_zh.yml](../../configs/kie/layoutlm_series/ser_layoutxlm_xfund_zh.yml)|90.38%|[trained model](https://paddleocr.bj.bcebos.com/pplayout/ser_LayoutXLM_xfun_zh.tar)/[inference model](https://paddleocr.bj.bcebos.com/pplayout/ser_LayoutXLM_xfun_zh_infer.tar)| |LayoutXLM|LayoutXLM-base|SER |[ser_layoutxlm_xfund_zh.yml](../../configs/kie/layoutlm_series/ser_layoutxlm_xfund_zh.yml)|90.38%|[trained model](https://paddleocr.bj.bcebos.com/pplayout/ser_LayoutXLM_xfun_zh.tar)/[inference model](https://paddleocr.bj.bcebos.com/pplayout/ser_LayoutXLM_xfun_zh_infer.tar)|
|LayoutXLM|LayoutXLM-base|RE | [re_layoutxlm_xfund_zh.yml](../../configs/kie/layoutlm_series/re_layoutxlm_xfund_zh.yml)|74.83%|[trained model](https://paddleocr.bj.bcebos.com/pplayout/re_LayoutXLM_xfun_zh.tar)/[inference model(coming soon)]()| |LayoutXLM|LayoutXLM-base|RE | [re_layoutxlm_xfund_zh.yml](../../configs/kie/layoutlm_series/re_layoutxlm_xfund_zh.yml)|74.83%|[trained model](https://paddleocr.bj.bcebos.com/pplayout/re_LayoutXLM_xfun_zh.tar)/[inference model](https://paddleocr.bj.bcebos.com/pplayout/re_LayoutXLM_xfun_zh_infer.tar)|
## 2. Environment ## 2. Environment
...@@ -46,7 +46,7 @@ Please refer to [KIE tutorial](./kie_en.md)。PaddleOCR has modularized the code ...@@ -46,7 +46,7 @@ Please refer to [KIE tutorial](./kie_en.md)。PaddleOCR has modularized the code
### 4.1 Python Inference ### 4.1 Python Inference
**Note:** Currently, the RE model inference process is still in the process of adaptation. We take SER model as an example to introduce the KIE process based on LayoutXLM model. - SER
First, we need to export the trained model into inference model. Take LayoutXLM model trained on XFUND_zh as an example ([trained model download link](https://paddleocr.bj.bcebos.com/pplayout/ser_LayoutXLM_xfun_zh.tar)). Use the following command to export. First, we need to export the trained model into inference model. Take LayoutXLM model trained on XFUND_zh as an example ([trained model download link](https://paddleocr.bj.bcebos.com/pplayout/ser_LayoutXLM_xfun_zh.tar)). Use the following command to export.
...@@ -54,7 +54,7 @@ First, we need to export the trained model into inference model. Take LayoutXLM ...@@ -54,7 +54,7 @@ First, we need to export the trained model into inference model. Take LayoutXLM
``` bash ``` bash
wget https://paddleocr.bj.bcebos.com/pplayout/ser_LayoutXLM_xfun_zh.tar wget https://paddleocr.bj.bcebos.com/pplayout/ser_LayoutXLM_xfun_zh.tar
tar -xf ser_LayoutXLM_xfun_zh.tar tar -xf ser_LayoutXLM_xfun_zh.tar
python3 tools/export_model.py -c configs/kie/layoutlm_series/ser_layoutxlm_xfund_zh.yml -o Architecture.Backbone.checkpoints=./ser_LayoutXLM_xfun_zh/best_accuracy Global.save_inference_dir=./inference/ser_layoutxlm python3 tools/export_model.py -c configs/kie/layoutlm_series/ser_layoutxlm_xfund_zh.yml -o Architecture.Backbone.checkpoints=./ser_LayoutXLM_xfun_zh Global.save_inference_dir=./inference/ser_layoutxlm_infer
``` ```
Use the following command to infer using LayoutXLM SER model. Use the following command to infer using LayoutXLM SER model.
...@@ -77,6 +77,38 @@ The SER visualization results are saved in the `./output` directory by default. ...@@ -77,6 +77,38 @@ The SER visualization results are saved in the `./output` directory by default.
</div> </div>
- RE
First, we need to export the trained model into inference model. Take LayoutXLM model trained on XFUND_zh as an example ([trained model download link](https://paddleocr.bj.bcebos.com/pplayout/re_LayoutXLM_xfun_zh.tar)). Use the following command to export.
``` bash
wget https://paddleocr.bj.bcebos.com/pplayout/re_LayoutXLM_xfun_zh.tar
tar -xf re_LayoutXLM_xfun_zh.tar
python3 tools/export_model.py -c configs/kie/layoutlm_series/re_layoutxlm_xfund_zh.yml -o Architecture.Backbone.checkpoints=./re_LayoutXLM_xfun_zh Global.save_inference_dir=./inference/re_layoutxlm_infer
```
Use the following command to infer using LayoutXLM RE model.
```bash
cd ppstructure
python3 kie/predict_kie_token_ser_re.py \
--kie_algorithm=LayoutXLM \
--re_model_dir=../inference/re_layoutxlm_infer \
--ser_model_dir=../inference/ser_layoutxlm_infer \
--image_dir=./docs/kie/input/zh_val_42.jpg \
--ser_dict_path=../train_data/XFUND/class_list_xfun.txt \
--vis_font_path=../doc/fonts/simfang.ttf
```
The RE visualization results are saved in the `./output` directory by default. The results are as follows.
<div align="center">
<img src="../../ppstructure/docs/kie/result_re/zh_val_42_re.jpg" width="800">
</div>
### 4.2 C++ Inference ### 4.2 C++ Inference
Not supported Not supported
......
...@@ -22,7 +22,7 @@ On XFUND_zh dataset, the algorithm reproduction Hmean is as follows. ...@@ -22,7 +22,7 @@ On XFUND_zh dataset, the algorithm reproduction Hmean is as follows.
|Model|Backbone|Task |Cnnfig|Hmean|Download link| |Model|Backbone|Task |Cnnfig|Hmean|Download link|
| --- | --- |---| --- | --- | --- | | --- | --- |---| --- | --- | --- |
|VI-LayoutXLM |VI-LayoutXLM-base | SER |[ser_vi_layoutxlm_xfund_zh_udml.yml](../../configs/kie/vi_layoutxlm/ser_vi_layoutxlm_xfund_zh_udml.yml)|93.19%|[trained model](https://paddleocr.bj.bcebos.com/ppstructure/models/vi_layoutxlm/ser_vi_layoutxlm_xfund_pretrained.tar)/[inference model](https://paddleocr.bj.bcebos.com/ppstructure/models/vi_layoutxlm/ser_vi_layoutxlm_xfund_infer.tar)| |VI-LayoutXLM |VI-LayoutXLM-base | SER |[ser_vi_layoutxlm_xfund_zh_udml.yml](../../configs/kie/vi_layoutxlm/ser_vi_layoutxlm_xfund_zh_udml.yml)|93.19%|[trained model](https://paddleocr.bj.bcebos.com/ppstructure/models/vi_layoutxlm/ser_vi_layoutxlm_xfund_pretrained.tar)/[inference model](https://paddleocr.bj.bcebos.com/ppstructure/models/vi_layoutxlm/ser_vi_layoutxlm_xfund_infer.tar)|
|VI-LayoutXLM |VI-LayoutXLM-base |RE | [re_vi_layoutxlm_xfund_zh_udml.yml](../../configs/kie/vi_layoutxlm/re_vi_layoutxlm_xfund_zh_udml.yml)|83.92%|[trained model](https://paddleocr.bj.bcebos.com/ppstructure/models/vi_layoutxlm/re_vi_layoutxlm_xfund_pretrained.tar)/[inference model(coming soon)]()| |VI-LayoutXLM |VI-LayoutXLM-base |RE | [re_vi_layoutxlm_xfund_zh_udml.yml](../../configs/kie/vi_layoutxlm/re_vi_layoutxlm_xfund_zh_udml.yml)|83.92%|[trained model](https://paddleocr.bj.bcebos.com/ppstructure/models/vi_layoutxlm/re_vi_layoutxlm_xfund_pretrained.tar)/[inference model](https://paddleocr.bj.bcebos.com/ppstructure/models/vi_layoutxlm/re_vi_layoutxlm_xfund_infer.tar)|
Please refer to ["Environment Preparation"](./environment_en.md) to configure the PaddleOCR environment, and refer to ["Project Clone"](./clone_en.md) to clone the project code. Please refer to ["Environment Preparation"](./environment_en.md) to configure the PaddleOCR environment, and refer to ["Project Clone"](./clone_en.md) to clone the project code.
...@@ -37,7 +37,7 @@ Please refer to [KIE tutorial](./kie_en.md)。PaddleOCR has modularized the code ...@@ -37,7 +37,7 @@ Please refer to [KIE tutorial](./kie_en.md)。PaddleOCR has modularized the code
### 4.1 Python Inference ### 4.1 Python Inference
**Note:** Currently, the RE model inference process is still in the process of adaptation. We take SER model as an example to introduce the KIE process based on VI-LayoutXLM model. -SER
First, we need to export the trained model into inference model. Take VI-LayoutXLM model trained on XFUND_zh as an example ([trained model download link](https://paddleocr.bj.bcebos.com/ppstructure/models/vi_layoutxlm/ser_vi_layoutxlm_xfund_pretrained.tar)). Use the following command to export. First, we need to export the trained model into inference model. Take VI-LayoutXLM model trained on XFUND_zh as an example ([trained model download link](https://paddleocr.bj.bcebos.com/ppstructure/models/vi_layoutxlm/ser_vi_layoutxlm_xfund_pretrained.tar)). Use the following command to export.
...@@ -70,6 +70,41 @@ The SER visualization results are saved in the `./output` folder by default. The ...@@ -70,6 +70,41 @@ The SER visualization results are saved in the `./output` folder by default. The
</div> </div>
-RE
First, we need to export the trained model into inference model. Take VI-LayoutXLM model trained on XFUND_zh as an example ([trained model download link](https://paddleocr.bj.bcebos.com/ppstructure/models/vi_layoutxlm/re_vi_layoutxlm_xfund_pretrained.tar)). Use the following command to export.
``` bash
wget https://paddleocr.bj.bcebos.com/ppstructure/models/vi_layoutxlm/re_vi_layoutxlm_xfund_pretrained.tar
tar -xf re_vi_layoutxlm_xfund_pretrained.tar
python3 tools/export_model.py -c configs/kie/vi_layoutxlm/re_vi_layoutxlm_xfund_zh.yml -o Architecture.Backbone.checkpoints=./re_vi_layoutxlm_xfund_pretrained/best_accuracy Global.save_inference_dir=./inference/re_vi_layoutxlm_infer
```
Use the following command to infer using VI-LayoutXLM RE model.
```bash
cd ppstructure
python3 kie/predict_kie_token_ser_re.py \
--kie_algorithm=LayoutXLM \
--re_model_dir=../inference/re_vi_layoutxlm_infer \
--ser_model_dir=../inference/ser_vi_layoutxlm_infer \
--use_visual_backbone=False \
--image_dir=./docs/kie/input/zh_val_42.jpg \
--ser_dict_path=../train_data/XFUND/class_list_xfun.txt \
--vis_font_path=../doc/fonts/simfang.ttf \
--ocr_order_method="tb-yx"
```
The RE visualization results are saved in the `./output` folder by default. The results are as follows.
<div align="center">
<img src="../../ppstructure/docs/kie/result_re/zh_val_42_re.jpg" width="800">
</div>
### 4.2 C++ Inference ### 4.2 C++ Inference
Not supported Not supported
......
...@@ -12,11 +12,9 @@ ...@@ -12,11 +12,9 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from .token import VQATokenPad, VQASerTokenChunk, VQAReTokenChunk, VQAReTokenRelation from .token import VQATokenPad, VQASerTokenChunk, VQAReTokenChunk, VQAReTokenRelation, TensorizeEntitiesRelations
__all__ = [ __all__ = [
'VQATokenPad', 'VQATokenPad', 'VQASerTokenChunk', 'VQAReTokenChunk', 'VQAReTokenRelation',
'VQASerTokenChunk', 'TensorizeEntitiesRelations'
'VQAReTokenChunk',
'VQAReTokenRelation',
] ]
...@@ -15,3 +15,4 @@ ...@@ -15,3 +15,4 @@
from .vqa_token_chunk import VQASerTokenChunk, VQAReTokenChunk from .vqa_token_chunk import VQASerTokenChunk, VQAReTokenChunk
from .vqa_token_pad import VQATokenPad from .vqa_token_pad import VQATokenPad
from .vqa_token_relation import VQAReTokenRelation from .vqa_token_relation import VQAReTokenRelation
from .vqa_re_convert import TensorizeEntitiesRelations
\ No newline at end of file
# copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
class TensorizeEntitiesRelations(object):
def __init__(self, max_seq_len=512, infer_mode=False, **kwargs):
self.max_seq_len = max_seq_len
self.infer_mode = infer_mode
def __call__(self, data):
entities = data['entities']
relations = data['relations']
entities_new = np.full(
shape=[self.max_seq_len + 1, 3], fill_value=-1, dtype='int64')
entities_new[0, 0] = len(entities['start'])
entities_new[0, 1] = len(entities['end'])
entities_new[0, 2] = len(entities['label'])
entities_new[1:len(entities['start']) + 1, 0] = np.array(entities[
'start'])
entities_new[1:len(entities['end']) + 1, 1] = np.array(entities['end'])
entities_new[1:len(entities['label']) + 1, 2] = np.array(entities[
'label'])
relations_new = np.full(
shape=[self.max_seq_len * self.max_seq_len + 1, 2],
fill_value=-1,
dtype='int64')
relations_new[0, 0] = len(relations['head'])
relations_new[0, 1] = len(relations['tail'])
relations_new[1:len(relations['head']) + 1, 0] = np.array(relations[
'head'])
relations_new[1:len(relations['tail']) + 1, 1] = np.array(relations[
'tail'])
data['entities'] = entities_new
data['relations'] = relations_new
return data
...@@ -37,23 +37,25 @@ class VQAReTokenMetric(object): ...@@ -37,23 +37,25 @@ class VQAReTokenMetric(object):
gt_relations = [] gt_relations = []
for b in range(len(self.relations_list)): for b in range(len(self.relations_list)):
rel_sent = [] rel_sent = []
if "head" in self.relations_list[b]: relation_list = self.relations_list[b]
for head, tail in zip(self.relations_list[b]["head"], entitie_list = self.entities_list[b]
self.relations_list[b]["tail"]): head_len = relation_list[0, 0]
if head_len > 0:
entitie_start_list = entitie_list[1:entitie_list[0, 0] + 1, 0]
entitie_end_list = entitie_list[1:entitie_list[0, 1] + 1, 1]
entitie_label_list = entitie_list[1:entitie_list[0, 2] + 1, 2]
for head, tail in zip(relation_list[1:head_len + 1, 0],
relation_list[1:head_len + 1, 1]):
rel = {} rel = {}
rel["head_id"] = head rel["head_id"] = head
rel["head"] = ( rel["head"] = (entitie_start_list[head],
self.entities_list[b]["start"][rel["head_id"]], entitie_end_list[head])
self.entities_list[b]["end"][rel["head_id"]]) rel["head_type"] = entitie_label_list[head]
rel["head_type"] = self.entities_list[b]["label"][rel[
"head_id"]]
rel["tail_id"] = tail rel["tail_id"] = tail
rel["tail"] = ( rel["tail"] = (entitie_start_list[tail],
self.entities_list[b]["start"][rel["tail_id"]], entitie_end_list[tail])
self.entities_list[b]["end"][rel["tail_id"]]) rel["tail_type"] = entitie_label_list[tail]
rel["tail_type"] = self.entities_list[b]["label"][rel[
"tail_id"]]
rel["type"] = 1 rel["type"] = 1
rel_sent.append(rel) rel_sent.append(rel)
......
...@@ -218,8 +218,12 @@ class LayoutXLMForRe(NLPBaseModel): ...@@ -218,8 +218,12 @@ class LayoutXLMForRe(NLPBaseModel):
def forward(self, x): def forward(self, x):
if self.use_visual_backbone is True: if self.use_visual_backbone is True:
image = x[4] image = x[4]
entities = x[5]
relations = x[6]
else: else:
image = None image = None
entities = x[4]
relations = x[5]
x = self.model( x = self.model(
input_ids=x[0], input_ids=x[0],
bbox=x[1], bbox=x[1],
...@@ -229,6 +233,6 @@ class LayoutXLMForRe(NLPBaseModel): ...@@ -229,6 +233,6 @@ class LayoutXLMForRe(NLPBaseModel):
position_ids=None, position_ids=None,
head_mask=None, head_mask=None,
labels=None, labels=None,
entities=x[5], entities=entities,
relations=x[6]) relations=relations)
return x return x
...@@ -21,18 +21,22 @@ class VQAReTokenLayoutLMPostProcess(object): ...@@ -21,18 +21,22 @@ class VQAReTokenLayoutLMPostProcess(object):
super(VQAReTokenLayoutLMPostProcess, self).__init__() super(VQAReTokenLayoutLMPostProcess, self).__init__()
def __call__(self, preds, label=None, *args, **kwargs): def __call__(self, preds, label=None, *args, **kwargs):
pred_relations = preds['pred_relations']
if isinstance(preds['pred_relations'], paddle.Tensor):
pred_relations = pred_relations.numpy()
pred_relations = self.decode_pred(pred_relations)
if label is not None: if label is not None:
return self._metric(preds, label) return self._metric(pred_relations, label)
else: else:
return self._infer(preds, *args, **kwargs) return self._infer(pred_relations, *args, **kwargs)
def _metric(self, preds, label): def _metric(self, pred_relations, label):
return preds['pred_relations'], label[6], label[5] return pred_relations, label[6], label[5]
def _infer(self, preds, *args, **kwargs): def _infer(self, pred_relations, *args, **kwargs):
ser_results = kwargs['ser_results'] ser_results = kwargs['ser_results']
entity_idx_dict_batch = kwargs['entity_idx_dict_batch'] entity_idx_dict_batch = kwargs['entity_idx_dict_batch']
pred_relations = preds['pred_relations']
# merge relations and ocr info # merge relations and ocr info
results = [] results = []
...@@ -50,6 +54,24 @@ class VQAReTokenLayoutLMPostProcess(object): ...@@ -50,6 +54,24 @@ class VQAReTokenLayoutLMPostProcess(object):
results.append(result) results.append(result)
return results return results
def decode_pred(self, pred_relations):
pred_relations_new = []
for pred_relation in pred_relations:
pred_relation_new = []
pred_relation = pred_relation[1:pred_relation[0, 0, 0] + 1]
for relation in pred_relation:
relation_new = dict()
relation_new['head_id'] = relation[0, 0]
relation_new['head'] = tuple(relation[1])
relation_new['head_type'] = relation[2, 0]
relation_new['tail_id'] = relation[3, 0]
relation_new['tail'] = tuple(relation[4])
relation_new['tail_type'] = relation[5, 0]
relation_new['type'] = relation[6, 0]
pred_relation_new.append(relation_new)
pred_relations_new.append(pred_relation_new)
return pred_relations_new
class DistillationRePostProcess(VQAReTokenLayoutLMPostProcess): class DistillationRePostProcess(VQAReTokenLayoutLMPostProcess):
""" """
......
...@@ -51,9 +51,9 @@ ...@@ -51,9 +51,9 @@
|模型名称|模型简介 | 推理模型大小| 精度(hmean) | 预测耗时(ms) | 下载地址| |模型名称|模型简介 | 推理模型大小| 精度(hmean) | 预测耗时(ms) | 下载地址|
| --- | --- | --- |--- |--- | --- | | --- | --- | --- |--- |--- | --- |
|ser_VI-LayoutXLM_xfund_zh|基于VI-LayoutXLM在xfund中文数据集上训练的SER模型|1.1G| 93.19% | 15.49 | [推理模型](https://paddleocr.bj.bcebos.com/ppstructure/models/vi_layoutxlm/ser_vi_layoutxlm_xfund_infer.tar) / [训练模型](https://paddleocr.bj.bcebos.com/ppstructure/models/vi_layoutxlm/ser_vi_layoutxlm_xfund_pretrained.tar) | |ser_VI-LayoutXLM_xfund_zh|基于VI-LayoutXLM在xfund中文数据集上训练的SER模型|1.1G| 93.19% | 15.49 | [推理模型](https://paddleocr.bj.bcebos.com/ppstructure/models/vi_layoutxlm/ser_vi_layoutxlm_xfund_infer.tar) / [训练模型](https://paddleocr.bj.bcebos.com/ppstructure/models/vi_layoutxlm/ser_vi_layoutxlm_xfund_pretrained.tar) |
|re_VI-LayoutXLM_xfund_zh|基于VI-LayoutXLM在xfund中文数据集上训练的RE模型|1.1G| 83.92% | 15.49 |[推理模型 coming soon]() / [训练模型](https://paddleocr.bj.bcebos.com/ppstructure/models/vi_layoutxlm/re_vi_layoutxlm_xfund_pretrained.tar) | |re_VI-LayoutXLM_xfund_zh|基于VI-LayoutXLM在xfund中文数据集上训练的RE模型|1.1G| 83.92% | 15.49 |[推理模型](https://paddleocr.bj.bcebos.com/ppstructure/models/vi_layoutxlm/re_vi_layoutxlm_xfund_infer.tar) / [训练模型](https://paddleocr.bj.bcebos.com/ppstructure/models/vi_layoutxlm/re_vi_layoutxlm_xfund_pretrained.tar) |
|ser_LayoutXLM_xfund_zh|基于LayoutXLM在xfund中文数据集上训练的SER模型|1.4G| 90.38% | 19.49 |[推理模型](https://paddleocr.bj.bcebos.com/pplayout/ser_LayoutXLM_xfun_zh_infer.tar) / [训练模型](https://paddleocr.bj.bcebos.com/pplayout/ser_LayoutXLM_xfun_zh.tar) | |ser_LayoutXLM_xfund_zh|基于LayoutXLM在xfund中文数据集上训练的SER模型|1.4G| 90.38% | 19.49 |[推理模型](https://paddleocr.bj.bcebos.com/pplayout/ser_LayoutXLM_xfun_zh_infer.tar) / [训练模型](https://paddleocr.bj.bcebos.com/pplayout/ser_LayoutXLM_xfun_zh.tar) |
|re_LayoutXLM_xfund_zh|基于LayoutXLM在xfund中文数据集上训练的RE模型|1.4G| 74.83% | 19.49 |[推理模型 coming soon]() / [训练模型](https://paddleocr.bj.bcebos.com/pplayout/re_LayoutXLM_xfun_zh.tar) | |re_LayoutXLM_xfund_zh|基于LayoutXLM在xfund中文数据集上训练的RE模型|1.4G| 74.83% | 19.49 |[推理模型](https://paddleocr.bj.bcebos.com/pplayout/re_LayoutXLM_xfun_zh_infer.tar) / [训练模型](https://paddleocr.bj.bcebos.com/pplayout/re_LayoutXLM_xfun_zh.tar) |
|ser_LayoutLMv2_xfund_zh|基于LayoutLMv2在xfund中文数据集上训练的SER模型|778M| 85.44% | 31.46 |[推理模型](https://paddleocr.bj.bcebos.com/pplayout/ser_LayoutLMv2_xfun_zh_infer.tar) / [训练模型](https://paddleocr.bj.bcebos.com/pplayout/ser_LayoutLMv2_xfun_zh.tar) | |ser_LayoutLMv2_xfund_zh|基于LayoutLMv2在xfund中文数据集上训练的SER模型|778M| 85.44% | 31.46 |[推理模型](https://paddleocr.bj.bcebos.com/pplayout/ser_LayoutLMv2_xfun_zh_infer.tar) / [训练模型](https://paddleocr.bj.bcebos.com/pplayout/ser_LayoutLMv2_xfun_zh.tar) |
|re_LayoutLMv2_xfund_zh|基于LayoutLMv2在xfun中文数据集上训练的RE模型|765M| 67.77% | 31.46 |[推理模型 coming soon]() / [训练模型](https://paddleocr.bj.bcebos.com/pplayout/re_LayoutLMv2_xfun_zh.tar) | |re_LayoutLMv2_xfund_zh|基于LayoutLMv2在xfun中文数据集上训练的RE模型|765M| 67.77% | 31.46 |[推理模型 coming soon]() / [训练模型](https://paddleocr.bj.bcebos.com/pplayout/re_LayoutLMv2_xfun_zh.tar) |
|ser_LayoutLM_xfund_zh|基于LayoutLM在xfund中文数据集上训练的SER模型|430M| 77.31% | - |[推理模型](https://paddleocr.bj.bcebos.com/pplayout/ser_LayoutLM_xfun_zh_infer.tar) / [训练模型](https://paddleocr.bj.bcebos.com/pplayout/ser_LayoutLM_xfun_zh.tar) | |ser_LayoutLM_xfund_zh|基于LayoutLM在xfund中文数据集上训练的SER模型|430M| 77.31% | - |[推理模型](https://paddleocr.bj.bcebos.com/pplayout/ser_LayoutLM_xfun_zh_infer.tar) / [训练模型](https://paddleocr.bj.bcebos.com/pplayout/ser_LayoutLM_xfun_zh.tar) |
......
...@@ -209,17 +209,18 @@ python3 ./tools/infer_kie_token_ser_re.py \ ...@@ -209,17 +209,18 @@ python3 ./tools/infer_kie_token_ser_re.py \
#### 4.2.3 Inference using PaddleInference #### 4.2.3 Inference using PaddleInference
At present, only SER model supports inference using PaddleInference.
Firstly, download the inference SER inference model. Firstly, download the inference SER inference model.
```bash ```bash
mkdir inference mkdir inference
cd inference cd inference
wget https://paddleocr.bj.bcebos.com/ppstructure/models/vi_layoutxlm/ser_vi_layoutxlm_xfund_infer.tar && tar -xf ser_vi_layoutxlm_xfund_infer.tar wget https://paddleocr.bj.bcebos.com/ppstructure/models/vi_layoutxlm/ser_vi_layoutxlm_xfund_infer.tar && tar -xf ser_vi_layoutxlm_xfund_infer.tar
wget https://paddleocr.bj.bcebos.com/ppstructure/models/vi_layoutxlm/re_vi_layoutxlm_xfund_infer.tar && tar -xf re_vi_layoutxlm_xfund_infer.tar
cd ..
``` ```
-SER
Use the following command for inference. Use the following command for inference.
...@@ -236,6 +237,26 @@ python3 kie/predict_kie_token_ser.py \ ...@@ -236,6 +237,26 @@ python3 kie/predict_kie_token_ser.py \
The visual results and text file will be saved in directory `output`. The visual results and text file will be saved in directory `output`.
-RE
Use the following command for inference.
```bash
cd ppstructure
python3 kie/predict_kie_token_ser_re.py \
--kie_algorithm=LayoutXLM \
--re_model_dir=../inference/re_vi_layoutxlm_xfund_infer \
--ser_model_dir=../inference/ser_vi_layoutxlm_xfund_infer \
--use_visual_backbone=False \
--image_dir=./docs/kie/input/zh_val_42.jpg \
--ser_dict_path=../train_data/XFUND/class_list_xfun.txt \
--vis_font_path=../doc/fonts/simfang.ttf \
--ocr_order_method="tb-yx"
```
The visual results and text file will be saved in directory `output`.
### 4.3 More ### 4.3 More
......
...@@ -193,17 +193,18 @@ python3 ./tools/infer_kie_token_ser_re.py \ ...@@ -193,17 +193,18 @@ python3 ./tools/infer_kie_token_ser_re.py \
#### 4.2.3 基于PaddleInference的预测 #### 4.2.3 基于PaddleInference的预测
目前仅SER模型支持PaddleInference推理。 首先下载SER和RE的推理模型。
首先下载SER的推理模型。
```bash ```bash
mkdir inference mkdir inference
cd inference cd inference
wget https://paddleocr.bj.bcebos.com/ppstructure/models/vi_layoutxlm/ser_vi_layoutxlm_xfund_infer.tar && tar -xf ser_vi_layoutxlm_xfund_infer.tar wget https://paddleocr.bj.bcebos.com/ppstructure/models/vi_layoutxlm/ser_vi_layoutxlm_xfund_infer.tar && tar -xf ser_vi_layoutxlm_xfund_infer.tar
wget https://paddleocr.bj.bcebos.com/ppstructure/models/vi_layoutxlm/re_vi_layoutxlm_xfund_infer.tar && tar -xf re_vi_layoutxlm_xfund_infer.tar
cd ..
``` ```
-SER
执行下面的命令进行预测。 执行下面的命令进行预测。
```bash ```bash
...@@ -219,6 +220,26 @@ python3 kie/predict_kie_token_ser.py \ ...@@ -219,6 +220,26 @@ python3 kie/predict_kie_token_ser.py \
可视化结果保存在`output`目录下。 可视化结果保存在`output`目录下。
-RE
执行下面的命令进行预测。
```bash
cd ppstructure
python3 kie/predict_kie_token_ser_re.py \
--kie_algorithm=LayoutXLM \
--re_model_dir=../inference/re_vi_layoutxlm_xfund_infer \
--ser_model_dir=../inference/ser_vi_layoutxlm_xfund_infer \
--use_visual_backbone=False \
--image_dir=./docs/kie/input/zh_val_42.jpg \
--ser_dict_path=../train_data/XFUND/class_list_xfun.txt \
--vis_font_path=../doc/fonts/simfang.ttf \
--ocr_order_method="tb-yx"
```
可视化结果保存在`output`目录下。
### 4.3 更多 ### 4.3 更多
关于KIE模型的训练评估与推理,请参考:[关键信息抽取教程](../../doc/doc_ch/kie.md) 关于KIE模型的训练评估与推理,请参考:[关键信息抽取教程](../../doc/doc_ch/kie.md)
......
...@@ -102,16 +102,18 @@ class SerPredictor(object): ...@@ -102,16 +102,18 @@ class SerPredictor(object):
ori_im = img.copy() ori_im = img.copy()
data = {'image': img} data = {'image': img}
data = transform(data, self.preprocess_op) data = transform(data, self.preprocess_op)
img = data[0] if data[0] is None:
if img is None:
return None, 0 return None, 0
img = np.expand_dims(img, axis=0)
img = img.copy()
starttime = time.time() starttime = time.time()
for idx in range(len(data)):
if isinstance(data[idx], np.ndarray):
data[idx] = np.expand_dims(data[idx], axis=0)
else:
data[idx] = [data[idx]]
for idx in range(len(self.input_tensor)): for idx in range(len(self.input_tensor)):
expand_input = np.expand_dims(data[idx], axis=0) self.input_tensor[idx].copy_from_cpu(data[idx])
self.input_tensor[idx].copy_from_cpu(expand_input)
self.predictor.run() self.predictor.run()
...@@ -122,9 +124,9 @@ class SerPredictor(object): ...@@ -122,9 +124,9 @@ class SerPredictor(object):
preds = outputs[0] preds = outputs[0]
post_result = self.postprocess_op( post_result = self.postprocess_op(
preds, segment_offset_ids=[data[6]], ocr_infos=[data[7]]) preds, segment_offset_ids=data[6], ocr_infos=data[7])
elapse = time.time() - starttime elapse = time.time() - starttime
return post_result, elapse return post_result, data, elapse
def main(args): def main(args):
...@@ -145,7 +147,7 @@ def main(args): ...@@ -145,7 +147,7 @@ def main(args):
if img is None: if img is None:
logger.info("error in loading image:{}".format(image_file)) logger.info("error in loading image:{}".format(image_file))
continue continue
ser_res, elapse = ser_predictor(img) ser_res, _, elapse = ser_predictor(img)
ser_res = ser_res[0] ser_res = ser_res[0]
res_str = '{}\t{}\n'.format( res_str = '{}\t{}\n'.format(
......
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import sys
__dir__ = os.path.dirname(os.path.abspath(__file__))
sys.path.append(__dir__)
sys.path.insert(0, os.path.abspath(os.path.join(__dir__, '../..')))
os.environ["FLAGS_allocator_strategy"] = 'auto_growth'
import cv2
import json
import numpy as np
import time
import tools.infer.utility as utility
from tools.infer_kie_token_ser_re import make_input
from ppocr.postprocess import build_post_process
from ppocr.utils.logging import get_logger
from ppocr.utils.visual import draw_re_results
from ppocr.utils.utility import get_image_file_list, check_and_read
from ppstructure.utility import parse_args
from ppstructure.kie.predict_kie_token_ser import SerPredictor
from paddleocr import PaddleOCR
logger = get_logger()
class SerRePredictor(object):
def __init__(self, args):
self.use_visual_backbone = args.use_visual_backbone
self.ser_engine = SerPredictor(args)
postprocess_params = {'name': 'VQAReTokenLayoutLMPostProcess'}
self.postprocess_op = build_post_process(postprocess_params)
self.predictor, self.input_tensor, self.output_tensors, self.config = \
utility.create_predictor(args, 're', logger)
def __call__(self, img):
ori_im = img.copy()
starttime = time.time()
ser_results, ser_inputs, _ = self.ser_engine(img)
re_input, entity_idx_dict_batch = make_input(ser_inputs, ser_results)
if self.use_visual_backbone == False:
re_input.pop(4)
for idx in range(len(self.input_tensor)):
self.input_tensor[idx].copy_from_cpu(re_input[idx])
self.predictor.run()
outputs = []
for output_tensor in self.output_tensors:
output = output_tensor.copy_to_cpu()
outputs.append(output)
preds = dict(loss=outputs[0], pred_relations=outputs[1])
post_result = self.postprocess_op(
preds,
ser_results=ser_results,
entity_idx_dict_batch=entity_idx_dict_batch)
elapse = time.time() - starttime
return post_result, elapse
def main(args):
image_file_list = get_image_file_list(args.image_dir)
ser_predictor = SerRePredictor(args)
count = 0
total_time = 0
os.makedirs(args.output, exist_ok=True)
with open(
os.path.join(args.output, 'infer.txt'), mode='w',
encoding='utf-8') as f_w:
for image_file in image_file_list:
img, flag, _ = check_and_read(image_file)
if not flag:
img = cv2.imread(image_file)
img = img[:, :, ::-1]
if img is None:
logger.info("error in loading image:{}".format(image_file))
continue
re_res, elapse = ser_predictor(img)
re_res = re_res[0]
res_str = '{}\t{}\n'.format(
image_file,
json.dumps(
{
"ocr_info": re_res,
}, ensure_ascii=False))
f_w.write(res_str)
img_res = draw_re_results(
image_file, re_res, font_path=args.vis_font_path)
img_save_path = os.path.join(
args.output,
os.path.splitext(os.path.basename(image_file))[0] +
"_ser_re.jpg")
cv2.imwrite(img_save_path, img_res)
logger.info("save vis result to {}".format(img_save_path))
if count > 0:
total_time += elapse
count += 1
logger.info("Predict time of {}: {}".format(image_file, elapse))
if __name__ == "__main__":
main(parse_args())
...@@ -52,6 +52,8 @@ def init_args(): ...@@ -52,6 +52,8 @@ def init_args():
# params for kie # params for kie
parser.add_argument("--kie_algorithm", type=str, default='LayoutXLM') parser.add_argument("--kie_algorithm", type=str, default='LayoutXLM')
parser.add_argument("--ser_model_dir", type=str) parser.add_argument("--ser_model_dir", type=str)
parser.add_argument("--re_model_dir", type=str)
parser.add_argument("--use_visual_backbone", type=str2bool, default=True)
parser.add_argument( parser.add_argument(
"--ser_dict_path", "--ser_dict_path",
type=str, type=str,
......
...@@ -115,15 +115,11 @@ def export_single_model(model, ...@@ -115,15 +115,11 @@ def export_single_model(model,
max_text_length = arch_config["Head"]["max_text_length"] max_text_length = arch_config["Head"]["max_text_length"]
other_shape = [ other_shape = [
paddle.static.InputSpec( paddle.static.InputSpec(
shape=[None, 3, 48, 160], dtype="float32"), shape=[None, 3, 48, 160], dtype="float32"), [
[
paddle.static.InputSpec( paddle.static.InputSpec(
shape=[None, ], shape=[None, ], dtype="float32"),
dtype="float32"),
paddle.static.InputSpec( paddle.static.InputSpec(
shape=[None, max_text_length], shape=[None, max_text_length], dtype="int64")
dtype="int64")
] ]
] ]
model = to_static(model, input_spec=other_shape) model = to_static(model, input_spec=other_shape)
...@@ -140,6 +136,13 @@ def export_single_model(model, ...@@ -140,6 +136,13 @@ def export_single_model(model,
paddle.static.InputSpec( paddle.static.InputSpec(
shape=[None, 3, 224, 224], dtype="int64"), # image shape=[None, 3, 224, 224], dtype="int64"), # image
] ]
if 'Re' in arch_config['Backbone']['name']:
input_spec.extend([
paddle.static.InputSpec(
shape=[None, 512, 3], dtype="int64"), # entities
paddle.static.InputSpec(
shape=[None, None, 2], dtype="int64"), # relations
])
if model.backbone.use_visual_backbone is False: if model.backbone.use_visual_backbone is False:
input_spec.pop(4) input_spec.pop(4)
model = to_static(model, input_spec=[input_spec]) model = to_static(model, input_spec=[input_spec])
......
...@@ -162,6 +162,8 @@ def create_predictor(args, mode, logger): ...@@ -162,6 +162,8 @@ def create_predictor(args, mode, logger):
model_dir = args.table_model_dir model_dir = args.table_model_dir
elif mode == 'ser': elif mode == 'ser':
model_dir = args.ser_model_dir model_dir = args.ser_model_dir
elif mode == 're':
model_dir = args.re_model_dir
elif mode == "sr": elif mode == "sr":
model_dir = args.sr_model_dir model_dir = args.sr_model_dir
elif mode == 'layout': elif mode == 'layout':
...@@ -227,7 +229,8 @@ def create_predictor(args, mode, logger): ...@@ -227,7 +229,8 @@ def create_predictor(args, mode, logger):
use_calib_mode=False) use_calib_mode=False)
# collect shape # collect shape
trt_shape_f = os.path.join(model_dir, f"{mode}_trt_dynamic_shape.txt") trt_shape_f = os.path.join(model_dir,
f"{mode}_trt_dynamic_shape.txt")
if not os.path.exists(trt_shape_f): if not os.path.exists(trt_shape_f):
config.collect_shape_range_info(trt_shape_f) config.collect_shape_range_info(trt_shape_f)
...@@ -262,6 +265,8 @@ def create_predictor(args, mode, logger): ...@@ -262,6 +265,8 @@ def create_predictor(args, mode, logger):
config.disable_glog_info() config.disable_glog_info()
config.delete_pass("conv_transpose_eltwiseadd_bn_fuse_pass") config.delete_pass("conv_transpose_eltwiseadd_bn_fuse_pass")
config.delete_pass("matmul_transpose_reshape_fuse_pass") config.delete_pass("matmul_transpose_reshape_fuse_pass")
if mode == 're':
config.delete_pass("simplify_with_basic_ops_pass")
if mode == 'table': if mode == 'table':
config.delete_pass("fc_fuse_pass") # not supported for table config.delete_pass("fc_fuse_pass") # not supported for table
config.switch_use_feed_fetch_ops(False) config.switch_use_feed_fetch_ops(False)
......
...@@ -63,7 +63,7 @@ class ReArgsParser(ArgsParser): ...@@ -63,7 +63,7 @@ class ReArgsParser(ArgsParser):
def make_input(ser_inputs, ser_results): def make_input(ser_inputs, ser_results):
entities_labels = {'HEADER': 0, 'QUESTION': 1, 'ANSWER': 2} entities_labels = {'HEADER': 0, 'QUESTION': 1, 'ANSWER': 2}
batch_size, max_seq_len = ser_inputs[0].shape[:2]
entities = ser_inputs[8][0] entities = ser_inputs[8][0]
ser_results = ser_results[0] ser_results = ser_results[0]
assert len(entities) == len(ser_results) assert len(entities) == len(ser_results)
...@@ -80,34 +80,44 @@ def make_input(ser_inputs, ser_results): ...@@ -80,34 +80,44 @@ def make_input(ser_inputs, ser_results):
start.append(entity['start']) start.append(entity['start'])
end.append(entity['end']) end.append(entity['end'])
label.append(entities_labels[res['pred']]) label.append(entities_labels[res['pred']])
entities = dict(start=start, end=end, label=label)
entities = np.full([max_seq_len + 1, 3], fill_value=-1)
entities[0, 0] = len(start)
entities[1:len(start) + 1, 0] = start
entities[0, 1] = len(end)
entities[1:len(end) + 1, 1] = end
entities[0, 2] = len(label)
entities[1:len(label) + 1, 2] = label
# relations # relations
head = [] head = []
tail = [] tail = []
for i in range(len(entities["label"])): for i in range(len(label)):
for j in range(len(entities["label"])): for j in range(len(label)):
if entities["label"][i] == 1 and entities["label"][j] == 2: if label[i] == 1 and label[j] == 2:
head.append(i) head.append(i)
tail.append(j) tail.append(j)
relations = dict(head=head, tail=tail) relations = np.full([len(head) + 1, 2], fill_value=-1)
relations[0, 0] = len(head)
relations[1:len(head) + 1, 0] = head
relations[0, 1] = len(tail)
relations[1:len(tail) + 1, 1] = tail
entities = np.expand_dims(entities, axis=0)
entities = np.repeat(entities, batch_size, axis=0)
relations = np.expand_dims(relations, axis=0)
relations = np.repeat(relations, batch_size, axis=0)
# remove ocr_info segment_offset_id and label in ser input
if isinstance(ser_inputs[0], paddle.Tensor):
entities = paddle.to_tensor(entities)
relations = paddle.to_tensor(relations)
ser_inputs = ser_inputs[:5] + [entities, relations]
batch_size = ser_inputs[0].shape[0]
entities_batch = []
relations_batch = []
entity_idx_dict_batch = [] entity_idx_dict_batch = []
for b in range(batch_size): for b in range(batch_size):
entities_batch.append(entities)
relations_batch.append(relations)
entity_idx_dict_batch.append(entity_idx_dict) entity_idx_dict_batch.append(entity_idx_dict)
ser_inputs[8] = entities_batch
ser_inputs.append(relations_batch)
# remove ocr_info segment_offset_id and label in ser input
ser_inputs.pop(7)
ser_inputs.pop(6)
ser_inputs.pop(5)
return ser_inputs, entity_idx_dict_batch return ser_inputs, entity_idx_dict_batch
...@@ -136,6 +146,8 @@ class SerRePredictor(object): ...@@ -136,6 +146,8 @@ class SerRePredictor(object):
def __call__(self, data): def __call__(self, data):
ser_results, ser_inputs = self.ser_engine(data) ser_results, ser_inputs = self.ser_engine(data)
re_input, entity_idx_dict_batch = make_input(ser_inputs, ser_results) re_input, entity_idx_dict_batch = make_input(ser_inputs, ser_results)
if self.model.backbone.use_visual_backbone is False:
re_input.pop(4)
preds = self.model(re_input) preds = self.model(re_input)
post_result = self.post_process_class( post_result = self.post_process_class(
preds, preds,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册