diff --git a/configs/kie/layoutlm_series/re_layoutxlm_xfund_zh.yml b/configs/kie/layoutlm_series/re_layoutxlm_xfund_zh.yml index 2401cf317987c5614a476065191e750587bc09b5..99dc771d150b15847486c096529a2828b9c0c05a 100644 --- a/configs/kie/layoutlm_series/re_layoutxlm_xfund_zh.yml +++ b/configs/kie/layoutlm_series/re_layoutxlm_xfund_zh.yml @@ -68,6 +68,7 @@ Train: - VQAReTokenRelation: - VQAReTokenChunk: max_seq_len: *max_seq_len + - TensorizeEntitiesRelations: - Resize: size: [224,224] - NormalizeImage: @@ -83,7 +84,6 @@ Train: drop_last: False batch_size_per_card: 2 num_workers: 8 - collate_fn: ListCollator Eval: dataset: @@ -105,6 +105,7 @@ Eval: - VQAReTokenRelation: - VQAReTokenChunk: max_seq_len: *max_seq_len + - TensorizeEntitiesRelations: - Resize: size: [224,224] - NormalizeImage: @@ -120,4 +121,3 @@ Eval: drop_last: False batch_size_per_card: 8 num_workers: 8 - collate_fn: ListCollator diff --git a/configs/kie/vi_layoutxlm/re_vi_layoutxlm_xfund_zh.yml b/configs/kie/vi_layoutxlm/re_vi_layoutxlm_xfund_zh.yml index ea9f50ef56ec8b169333263c1d5e96586f9472b3..811c7d2d6f16344a3d6ad060fec1a1966241d81b 100644 --- a/configs/kie/vi_layoutxlm/re_vi_layoutxlm_xfund_zh.yml +++ b/configs/kie/vi_layoutxlm/re_vi_layoutxlm_xfund_zh.yml @@ -73,6 +73,7 @@ Train: - VQAReTokenRelation: - VQAReTokenChunk: max_seq_len: *max_seq_len + - TensorizeEntitiesRelations: - Resize: size: [224,224] - NormalizeImage: @@ -88,7 +89,6 @@ Train: drop_last: False batch_size_per_card: 2 num_workers: 4 - collate_fn: ListCollator Eval: dataset: @@ -112,6 +112,7 @@ Eval: - VQAReTokenRelation: - VQAReTokenChunk: max_seq_len: *max_seq_len + - TensorizeEntitiesRelations: - Resize: size: [224,224] - NormalizeImage: @@ -127,5 +128,3 @@ Eval: drop_last: False batch_size_per_card: 8 num_workers: 8 - collate_fn: ListCollator - diff --git a/configs/kie/vi_layoutxlm/re_vi_layoutxlm_xfund_zh_udml.yml b/configs/kie/vi_layoutxlm/re_vi_layoutxlm_xfund_zh_udml.yml index b96528d2738e7cfb2575feca4146af1eed0c5d2f..0bd42901ebd3a37eb29ce854b1e434dc356d9643 100644 --- a/configs/kie/vi_layoutxlm/re_vi_layoutxlm_xfund_zh_udml.yml +++ b/configs/kie/vi_layoutxlm/re_vi_layoutxlm_xfund_zh_udml.yml @@ -116,6 +116,7 @@ Train: - VQAReTokenRelation: - VQAReTokenChunk: max_seq_len: *max_seq_len + - TensorizeEntitiesRelations: - Resize: size: [224,224] - NormalizeImage: @@ -155,6 +156,7 @@ Eval: - VQAReTokenRelation: - VQAReTokenChunk: max_seq_len: *max_seq_len + - TensorizeEntitiesRelations: - Resize: size: [224,224] - NormalizeImage: diff --git a/doc/doc_ch/algorithm_kie_layoutxlm.md b/doc/doc_ch/algorithm_kie_layoutxlm.md index e693be49b7bc89e04b169fe74cf76525b2494948..0cbcad25016974207382a044e211c704082f6467 100644 --- a/doc/doc_ch/algorithm_kie_layoutxlm.md +++ b/doc/doc_ch/algorithm_kie_layoutxlm.md @@ -30,7 +30,7 @@ |模型|骨干网络|任务|配置文件|hmean|下载链接| | --- | --- |--|--- | --- | --- | |LayoutXLM|LayoutXLM-base|SER |[ser_layoutxlm_xfund_zh.yml](../../configs/kie/layoutlm_series/ser_layoutxlm_xfund_zh.yml)|90.38%|[训练模型](https://paddleocr.bj.bcebos.com/pplayout/ser_LayoutXLM_xfun_zh.tar)/[推理模型](https://paddleocr.bj.bcebos.com/pplayout/ser_LayoutXLM_xfun_zh_infer.tar)| -|LayoutXLM|LayoutXLM-base|RE | [re_layoutxlm_xfund_zh.yml](../../configs/kie/layoutlm_series/re_layoutxlm_xfund_zh.yml)|74.83%|[训练模型](https://paddleocr.bj.bcebos.com/pplayout/re_LayoutXLM_xfun_zh.tar)/[推理模型(coming soon)]()| +|LayoutXLM|LayoutXLM-base|RE | [re_layoutxlm_xfund_zh.yml](../../configs/kie/layoutlm_series/re_layoutxlm_xfund_zh.yml)|74.83%|[训练模型](https://paddleocr.bj.bcebos.com/pplayout/re_LayoutXLM_xfun_zh.tar)/[推理模型](https://paddleocr.bj.bcebos.com/pplayout/re_LayoutXLM_xfun_zh_infer.tar)| @@ -52,14 +52,14 @@ ### 4.1 Python推理 -**注:** 目前RE任务推理过程仍在适配中,下面以SER任务为例,介绍基于LayoutXLM模型的关键信息抽取过程。 +- SER 首先将训练得到的模型转换成inference model。LayoutXLM模型在XFUND_zh数据集上训练的模型为例([模型下载地址](https://paddleocr.bj.bcebos.com/pplayout/ser_LayoutXLM_xfun_zh.tar)),可以使用下面的命令进行转换。 ``` bash wget https://paddleocr.bj.bcebos.com/pplayout/ser_LayoutXLM_xfun_zh.tar tar -xf ser_LayoutXLM_xfun_zh.tar -python3 tools/export_model.py -c configs/kie/layoutlm_series/ser_layoutxlm_xfund_zh.yml -o Architecture.Backbone.checkpoints=./ser_LayoutXLM_xfun_zh/best_accuracy Global.save_inference_dir=./inference/ser_layoutxlm +python3 tools/export_model.py -c configs/kie/layoutlm_series/ser_layoutxlm_xfund_zh.yml -o Architecture.Backbone.checkpoints=./ser_LayoutXLM_xfun_zh Global.save_inference_dir=./inference/ser_layoutxlm_infer ``` LayoutXLM模型基于SER任务进行推理,可以执行如下命令: @@ -80,6 +80,34 @@ SER可视化结果默认保存到`./output`文件夹里面,结果示例如下 +- RE + +首先将训练得到的模型转换成inference model。LayoutXLM模型在XFUND_zh数据集上训练的模型为例([模型下载地址](https://paddleocr.bj.bcebos.com/pplayout/re_LayoutXLM_xfun_zh.tar)),可以使用下面的命令进行转换。 + +``` bash +wget https://paddleocr.bj.bcebos.com/pplayout/re_LayoutXLM_xfun_zh.tar +tar -xf re_LayoutXLM_xfun_zh.tar +python3 tools/export_model.py -c configs/kie/layoutlm_series/re_layoutxlm_xfund_zh.yml -o Architecture.Backbone.checkpoints=./re_LayoutXLM_xfun_zh Global.save_inference_dir=./inference/ser_layoutxlm_infer +``` + +LayoutXLM模型基于RE任务进行推理,可以执行如下命令: + +```bash +cd ppstructure +python3 kie/predict_kie_token_ser_re.py \ + --kie_algorithm=LayoutXLM \ + --re_model_dir=../inference/re_layoutxlm_infer \ + --ser_model_dir=../inference/ser_layoutxlm_infer \ + --image_dir=./docs/kie/input/zh_val_42.jpg \ + --ser_dict_path=../train_data/XFUND/class_list_xfun.txt \ + --vis_font_path=../doc/fonts/simfang.ttf +``` + +RE可视化结果默认保存到`./output`文件夹里面,结果示例如下: + +
+ +
### 4.2 C++推理部署 diff --git a/doc/doc_ch/algorithm_kie_vi_layoutxlm.md b/doc/doc_ch/algorithm_kie_vi_layoutxlm.md index f1bb4b1e62736e88594196819dcc41980f1716bf..6c69230e020d8ead1ca3f917676730afaf0a19f8 100644 --- a/doc/doc_ch/algorithm_kie_vi_layoutxlm.md +++ b/doc/doc_ch/algorithm_kie_vi_layoutxlm.md @@ -23,7 +23,7 @@ VI-LayoutXLM基于LayoutXLM进行改进,在下游任务训练过程中,去 |模型|骨干网络|任务|配置文件|hmean|下载链接| | --- | --- |---| --- | --- | --- | |VI-LayoutXLM |VI-LayoutXLM-base | SER |[ser_vi_layoutxlm_xfund_zh_udml.yml](../../configs/kie/vi_layoutxlm/ser_vi_layoutxlm_xfund_zh_udml.yml)|93.19%|[训练模型](https://paddleocr.bj.bcebos.com/ppstructure/models/vi_layoutxlm/ser_vi_layoutxlm_xfund_pretrained.tar)/[推理模型](https://paddleocr.bj.bcebos.com/ppstructure/models/vi_layoutxlm/ser_vi_layoutxlm_xfund_infer.tar)| -|VI-LayoutXLM |VI-LayoutXLM-base |RE | [re_vi_layoutxlm_xfund_zh_udml.yml](../../configs/kie/vi_layoutxlm/re_vi_layoutxlm_xfund_zh_udml.yml)|83.92%|[训练模型](https://paddleocr.bj.bcebos.com/ppstructure/models/vi_layoutxlm/re_vi_layoutxlm_xfund_pretrained.tar)/[推理模型(coming soon)]()| +|VI-LayoutXLM |VI-LayoutXLM-base |RE | [re_vi_layoutxlm_xfund_zh_udml.yml](../../configs/kie/vi_layoutxlm/re_vi_layoutxlm_xfund_zh_udml.yml)|83.92%|[训练模型](https://paddleocr.bj.bcebos.com/ppstructure/models/vi_layoutxlm/re_vi_layoutxlm_xfund_pretrained.tar)/[推理模型](https://paddleocr.bj.bcebos.com/ppstructure/models/vi_layoutxlm/re_vi_layoutxlm_xfund_infer.tar)| @@ -45,7 +45,7 @@ VI-LayoutXLM基于LayoutXLM进行改进,在下游任务训练过程中,去 ### 4.1 Python推理 -**注:** 目前RE任务推理过程仍在适配中,下面以SER任务为例,介绍基于VI-LayoutXLM模型的关键信息抽取过程。 +-SER 首先将训练得到的模型转换成inference model。以VI-LayoutXLM模型在XFUND_zh数据集上训练的模型为例([模型下载地址](https://paddleocr.bj.bcebos.com/ppstructure/models/vi_layoutxlm/ser_vi_layoutxlm_xfund_pretrained.tar)),可以使用下面的命令进行转换。 @@ -74,6 +74,36 @@ SER可视化结果默认保存到`./output`文件夹里面,结果示例如下 +-RE + +首先将训练得到的模型转换成inference model。以VI-LayoutXLM模型在XFUND_zh数据集上训练的模型为例([模型下载地址](https://paddleocr.bj.bcebos.com/ppstructure/models/vi_layoutxlm/re_vi_layoutxlm_xfund_pretrained.tar)),可以使用下面的命令进行转换。 + +``` bash +wget https://paddleocr.bj.bcebos.com/ppstructure/models/vi_layoutxlm/re_vi_layoutxlm_xfund_pretrained.tar +tar -xf re_vi_layoutxlm_xfund_pretrained.tar +python3 tools/export_model.py -c configs/kie/vi_layoutxlm/re_vi_layoutxlm_xfund_zh.yml -o Architecture.Backbone.checkpoints=./re_vi_layoutxlm_xfund_pretrained/best_accuracy Global.save_inference_dir=./inference/re_vi_layoutxlm_infer +``` + +VI-LayoutXLM模型基于RE任务进行推理,可以执行如下命令: + +```bash +cd ppstructure +python3 kie/predict_kie_token_ser_re.py \ + --kie_algorithm=LayoutXLM \ + --re_model_dir=../inference/re_vi_layoutxlm_infer \ + --ser_model_dir=../inference/ser_vi_layoutxlm_infer \ + --use_visual_backbone=False \ + --image_dir=./docs/kie/input/zh_val_42.jpg \ + --ser_dict_path=../train_data/XFUND/class_list_xfun.txt \ + --vis_font_path=../doc/fonts/simfang.ttf \ + --ocr_order_method="tb-yx" +``` + +RE可视化结果默认保存到`./output`文件夹里面,结果示例如下: + +
+ +
### 4.2 C++推理部署 diff --git a/doc/doc_en/algorithm_kie_layoutxlm_en.md b/doc/doc_en/algorithm_kie_layoutxlm_en.md index 910c1f4d497a6e503f0a7a5ec26dbeceb2d321a1..0c82b0423b2cdc11b1817e7575629a659e599374 100644 --- a/doc/doc_en/algorithm_kie_layoutxlm_en.md +++ b/doc/doc_en/algorithm_kie_layoutxlm_en.md @@ -28,7 +28,7 @@ On XFUND_zh dataset, the algorithm reproduction Hmean is as follows. |Model|Backbone|Task |Cnnfig|Hmean|Download link| | --- | --- |--|--- | --- | --- | |LayoutXLM|LayoutXLM-base|SER |[ser_layoutxlm_xfund_zh.yml](../../configs/kie/layoutlm_series/ser_layoutxlm_xfund_zh.yml)|90.38%|[trained model](https://paddleocr.bj.bcebos.com/pplayout/ser_LayoutXLM_xfun_zh.tar)/[inference model](https://paddleocr.bj.bcebos.com/pplayout/ser_LayoutXLM_xfun_zh_infer.tar)| -|LayoutXLM|LayoutXLM-base|RE | [re_layoutxlm_xfund_zh.yml](../../configs/kie/layoutlm_series/re_layoutxlm_xfund_zh.yml)|74.83%|[trained model](https://paddleocr.bj.bcebos.com/pplayout/re_LayoutXLM_xfun_zh.tar)/[inference model(coming soon)]()| +|LayoutXLM|LayoutXLM-base|RE | [re_layoutxlm_xfund_zh.yml](../../configs/kie/layoutlm_series/re_layoutxlm_xfund_zh.yml)|74.83%|[trained model](https://paddleocr.bj.bcebos.com/pplayout/re_LayoutXLM_xfun_zh.tar)/[inference model](https://paddleocr.bj.bcebos.com/pplayout/re_LayoutXLM_xfun_zh_infer.tar)| ## 2. Environment @@ -46,7 +46,7 @@ Please refer to [KIE tutorial](./kie_en.md)。PaddleOCR has modularized the code ### 4.1 Python Inference -**Note:** Currently, the RE model inference process is still in the process of adaptation. We take SER model as an example to introduce the KIE process based on LayoutXLM model. +- SER First, we need to export the trained model into inference model. Take LayoutXLM model trained on XFUND_zh as an example ([trained model download link](https://paddleocr.bj.bcebos.com/pplayout/ser_LayoutXLM_xfun_zh.tar)). Use the following command to export. @@ -54,7 +54,7 @@ First, we need to export the trained model into inference model. Take LayoutXLM ``` bash wget https://paddleocr.bj.bcebos.com/pplayout/ser_LayoutXLM_xfun_zh.tar tar -xf ser_LayoutXLM_xfun_zh.tar -python3 tools/export_model.py -c configs/kie/layoutlm_series/ser_layoutxlm_xfund_zh.yml -o Architecture.Backbone.checkpoints=./ser_LayoutXLM_xfun_zh/best_accuracy Global.save_inference_dir=./inference/ser_layoutxlm +python3 tools/export_model.py -c configs/kie/layoutlm_series/ser_layoutxlm_xfund_zh.yml -o Architecture.Backbone.checkpoints=./ser_LayoutXLM_xfun_zh Global.save_inference_dir=./inference/ser_layoutxlm_infer ``` Use the following command to infer using LayoutXLM SER model. @@ -77,6 +77,38 @@ The SER visualization results are saved in the `./output` directory by default. +- RE + +First, we need to export the trained model into inference model. Take LayoutXLM model trained on XFUND_zh as an example ([trained model download link](https://paddleocr.bj.bcebos.com/pplayout/re_LayoutXLM_xfun_zh.tar)). Use the following command to export. + + +``` bash +wget https://paddleocr.bj.bcebos.com/pplayout/re_LayoutXLM_xfun_zh.tar +tar -xf re_LayoutXLM_xfun_zh.tar +python3 tools/export_model.py -c configs/kie/layoutlm_series/re_layoutxlm_xfund_zh.yml -o Architecture.Backbone.checkpoints=./re_LayoutXLM_xfun_zh Global.save_inference_dir=./inference/re_layoutxlm_infer +``` + +Use the following command to infer using LayoutXLM RE model. + + +```bash +cd ppstructure +python3 kie/predict_kie_token_ser_re.py \ + --kie_algorithm=LayoutXLM \ + --re_model_dir=../inference/re_layoutxlm_infer \ + --ser_model_dir=../inference/ser_layoutxlm_infer \ + --image_dir=./docs/kie/input/zh_val_42.jpg \ + --ser_dict_path=../train_data/XFUND/class_list_xfun.txt \ + --vis_font_path=../doc/fonts/simfang.ttf +``` +The RE visualization results are saved in the `./output` directory by default. The results are as follows. + + +
+ +
+ + ### 4.2 C++ Inference Not supported diff --git a/doc/doc_en/algorithm_kie_vi_layoutxlm_en.md b/doc/doc_en/algorithm_kie_vi_layoutxlm_en.md index 12b6e1bddbd03b820ce33ba86de3d430a44f8987..fab761f5a957c2bfdd42e09d22da27410a2ad423 100644 --- a/doc/doc_en/algorithm_kie_vi_layoutxlm_en.md +++ b/doc/doc_en/algorithm_kie_vi_layoutxlm_en.md @@ -22,7 +22,7 @@ On XFUND_zh dataset, the algorithm reproduction Hmean is as follows. |Model|Backbone|Task |Cnnfig|Hmean|Download link| | --- | --- |---| --- | --- | --- | |VI-LayoutXLM |VI-LayoutXLM-base | SER |[ser_vi_layoutxlm_xfund_zh_udml.yml](../../configs/kie/vi_layoutxlm/ser_vi_layoutxlm_xfund_zh_udml.yml)|93.19%|[trained model](https://paddleocr.bj.bcebos.com/ppstructure/models/vi_layoutxlm/ser_vi_layoutxlm_xfund_pretrained.tar)/[inference model](https://paddleocr.bj.bcebos.com/ppstructure/models/vi_layoutxlm/ser_vi_layoutxlm_xfund_infer.tar)| -|VI-LayoutXLM |VI-LayoutXLM-base |RE | [re_vi_layoutxlm_xfund_zh_udml.yml](../../configs/kie/vi_layoutxlm/re_vi_layoutxlm_xfund_zh_udml.yml)|83.92%|[trained model](https://paddleocr.bj.bcebos.com/ppstructure/models/vi_layoutxlm/re_vi_layoutxlm_xfund_pretrained.tar)/[inference model(coming soon)]()| +|VI-LayoutXLM |VI-LayoutXLM-base |RE | [re_vi_layoutxlm_xfund_zh_udml.yml](../../configs/kie/vi_layoutxlm/re_vi_layoutxlm_xfund_zh_udml.yml)|83.92%|[trained model](https://paddleocr.bj.bcebos.com/ppstructure/models/vi_layoutxlm/re_vi_layoutxlm_xfund_pretrained.tar)/[inference model](https://paddleocr.bj.bcebos.com/ppstructure/models/vi_layoutxlm/re_vi_layoutxlm_xfund_infer.tar)| Please refer to ["Environment Preparation"](./environment_en.md) to configure the PaddleOCR environment, and refer to ["Project Clone"](./clone_en.md) to clone the project code. @@ -37,7 +37,7 @@ Please refer to [KIE tutorial](./kie_en.md)。PaddleOCR has modularized the code ### 4.1 Python Inference -**Note:** Currently, the RE model inference process is still in the process of adaptation. We take SER model as an example to introduce the KIE process based on VI-LayoutXLM model. +-SER First, we need to export the trained model into inference model. Take VI-LayoutXLM model trained on XFUND_zh as an example ([trained model download link](https://paddleocr.bj.bcebos.com/ppstructure/models/vi_layoutxlm/ser_vi_layoutxlm_xfund_pretrained.tar)). Use the following command to export. @@ -70,6 +70,41 @@ The SER visualization results are saved in the `./output` folder by default. The +-RE + +First, we need to export the trained model into inference model. Take VI-LayoutXLM model trained on XFUND_zh as an example ([trained model download link](https://paddleocr.bj.bcebos.com/ppstructure/models/vi_layoutxlm/re_vi_layoutxlm_xfund_pretrained.tar)). Use the following command to export. + + +``` bash +wget https://paddleocr.bj.bcebos.com/ppstructure/models/vi_layoutxlm/re_vi_layoutxlm_xfund_pretrained.tar +tar -xf re_vi_layoutxlm_xfund_pretrained.tar +python3 tools/export_model.py -c configs/kie/vi_layoutxlm/re_vi_layoutxlm_xfund_zh.yml -o Architecture.Backbone.checkpoints=./re_vi_layoutxlm_xfund_pretrained/best_accuracy Global.save_inference_dir=./inference/re_vi_layoutxlm_infer +``` + +Use the following command to infer using VI-LayoutXLM RE model. + + +```bash +cd ppstructure +python3 kie/predict_kie_token_ser_re.py \ + --kie_algorithm=LayoutXLM \ + --re_model_dir=../inference/re_vi_layoutxlm_infer \ + --ser_model_dir=../inference/ser_vi_layoutxlm_infer \ + --use_visual_backbone=False \ + --image_dir=./docs/kie/input/zh_val_42.jpg \ + --ser_dict_path=../train_data/XFUND/class_list_xfun.txt \ + --vis_font_path=../doc/fonts/simfang.ttf \ + --ocr_order_method="tb-yx" +``` + +The RE visualization results are saved in the `./output` folder by default. The results are as follows. + + +
+ +
+ + ### 4.2 C++ Inference Not supported diff --git a/ppocr/data/imaug/vqa/__init__.py b/ppocr/data/imaug/vqa/__init__.py index 34189bcefb17a0776bd62a19c58081286882b5a5..73f7dcdf712f6db0ff4354b1b01134d1277ff078 100644 --- a/ppocr/data/imaug/vqa/__init__.py +++ b/ppocr/data/imaug/vqa/__init__.py @@ -12,11 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. -from .token import VQATokenPad, VQASerTokenChunk, VQAReTokenChunk, VQAReTokenRelation +from .token import VQATokenPad, VQASerTokenChunk, VQAReTokenChunk, VQAReTokenRelation, TensorizeEntitiesRelations __all__ = [ - 'VQATokenPad', - 'VQASerTokenChunk', - 'VQAReTokenChunk', - 'VQAReTokenRelation', + 'VQATokenPad', 'VQASerTokenChunk', 'VQAReTokenChunk', 'VQAReTokenRelation', + 'TensorizeEntitiesRelations' ] diff --git a/ppocr/data/imaug/vqa/token/__init__.py b/ppocr/data/imaug/vqa/token/__init__.py index 7c115661753cd031b16ec34697157e2fcdcf2dec..5fbaa43db9e7182cfa3efa4f3dc0d9e54c17c822 100644 --- a/ppocr/data/imaug/vqa/token/__init__.py +++ b/ppocr/data/imaug/vqa/token/__init__.py @@ -15,3 +15,4 @@ from .vqa_token_chunk import VQASerTokenChunk, VQAReTokenChunk from .vqa_token_pad import VQATokenPad from .vqa_token_relation import VQAReTokenRelation +from .vqa_re_convert import TensorizeEntitiesRelations \ No newline at end of file diff --git a/ppocr/data/imaug/vqa/token/vqa_re_convert.py b/ppocr/data/imaug/vqa/token/vqa_re_convert.py new file mode 100644 index 0000000000000000000000000000000000000000..86962f2590b57f38640d76ef5d8b74ead5e854e0 --- /dev/null +++ b/ppocr/data/imaug/vqa/token/vqa_re_convert.py @@ -0,0 +1,51 @@ +# copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import numpy as np + + +class TensorizeEntitiesRelations(object): + def __init__(self, max_seq_len=512, infer_mode=False, **kwargs): + self.max_seq_len = max_seq_len + self.infer_mode = infer_mode + + def __call__(self, data): + entities = data['entities'] + relations = data['relations'] + + entities_new = np.full( + shape=[self.max_seq_len + 1, 3], fill_value=-1, dtype='int64') + entities_new[0, 0] = len(entities['start']) + entities_new[0, 1] = len(entities['end']) + entities_new[0, 2] = len(entities['label']) + entities_new[1:len(entities['start']) + 1, 0] = np.array(entities[ + 'start']) + entities_new[1:len(entities['end']) + 1, 1] = np.array(entities['end']) + entities_new[1:len(entities['label']) + 1, 2] = np.array(entities[ + 'label']) + + relations_new = np.full( + shape=[self.max_seq_len * self.max_seq_len + 1, 2], + fill_value=-1, + dtype='int64') + relations_new[0, 0] = len(relations['head']) + relations_new[0, 1] = len(relations['tail']) + relations_new[1:len(relations['head']) + 1, 0] = np.array(relations[ + 'head']) + relations_new[1:len(relations['tail']) + 1, 1] = np.array(relations[ + 'tail']) + + data['entities'] = entities_new + data['relations'] = relations_new + return data diff --git a/ppocr/metrics/vqa_token_re_metric.py b/ppocr/metrics/vqa_token_re_metric.py index f84387d8beb729bcc4b420ceea24a5e9b2993c64..0509984f7e7e85fc1dae859761fedb7356a02477 100644 --- a/ppocr/metrics/vqa_token_re_metric.py +++ b/ppocr/metrics/vqa_token_re_metric.py @@ -37,23 +37,25 @@ class VQAReTokenMetric(object): gt_relations = [] for b in range(len(self.relations_list)): rel_sent = [] - if "head" in self.relations_list[b]: - for head, tail in zip(self.relations_list[b]["head"], - self.relations_list[b]["tail"]): + relation_list = self.relations_list[b] + entitie_list = self.entities_list[b] + head_len = relation_list[0, 0] + if head_len > 0: + entitie_start_list = entitie_list[1:entitie_list[0, 0] + 1, 0] + entitie_end_list = entitie_list[1:entitie_list[0, 1] + 1, 1] + entitie_label_list = entitie_list[1:entitie_list[0, 2] + 1, 2] + for head, tail in zip(relation_list[1:head_len + 1, 0], + relation_list[1:head_len + 1, 1]): rel = {} rel["head_id"] = head - rel["head"] = ( - self.entities_list[b]["start"][rel["head_id"]], - self.entities_list[b]["end"][rel["head_id"]]) - rel["head_type"] = self.entities_list[b]["label"][rel[ - "head_id"]] + rel["head"] = (entitie_start_list[head], + entitie_end_list[head]) + rel["head_type"] = entitie_label_list[head] rel["tail_id"] = tail - rel["tail"] = ( - self.entities_list[b]["start"][rel["tail_id"]], - self.entities_list[b]["end"][rel["tail_id"]]) - rel["tail_type"] = self.entities_list[b]["label"][rel[ - "tail_id"]] + rel["tail"] = (entitie_start_list[tail], + entitie_end_list[tail]) + rel["tail_type"] = entitie_label_list[tail] rel["type"] = 1 rel_sent.append(rel) diff --git a/ppocr/modeling/backbones/vqa_layoutlm.py b/ppocr/modeling/backbones/vqa_layoutlm.py index 8e10ed7b48e9aff344b71e5a04970d1a5dab8a71..acb1315cc0a588396549e5b8928bd2e4d3c769be 100644 --- a/ppocr/modeling/backbones/vqa_layoutlm.py +++ b/ppocr/modeling/backbones/vqa_layoutlm.py @@ -218,8 +218,12 @@ class LayoutXLMForRe(NLPBaseModel): def forward(self, x): if self.use_visual_backbone is True: image = x[4] + entities = x[5] + relations = x[6] else: image = None + entities = x[4] + relations = x[5] x = self.model( input_ids=x[0], bbox=x[1], @@ -229,6 +233,6 @@ class LayoutXLMForRe(NLPBaseModel): position_ids=None, head_mask=None, labels=None, - entities=x[5], - relations=x[6]) + entities=entities, + relations=relations) return x diff --git a/ppocr/postprocess/vqa_token_re_layoutlm_postprocess.py b/ppocr/postprocess/vqa_token_re_layoutlm_postprocess.py index 96c25d9aac01066f7a3841fe61aa7b0fe05041bd..a6011acf8d68f65adfc84e134c9cc0e733dd68ea 100644 --- a/ppocr/postprocess/vqa_token_re_layoutlm_postprocess.py +++ b/ppocr/postprocess/vqa_token_re_layoutlm_postprocess.py @@ -21,18 +21,22 @@ class VQAReTokenLayoutLMPostProcess(object): super(VQAReTokenLayoutLMPostProcess, self).__init__() def __call__(self, preds, label=None, *args, **kwargs): + pred_relations = preds['pred_relations'] + if isinstance(preds['pred_relations'], paddle.Tensor): + pred_relations = pred_relations.numpy() + pred_relations = self.decode_pred(pred_relations) + if label is not None: - return self._metric(preds, label) + return self._metric(pred_relations, label) else: - return self._infer(preds, *args, **kwargs) + return self._infer(pred_relations, *args, **kwargs) - def _metric(self, preds, label): - return preds['pred_relations'], label[6], label[5] + def _metric(self, pred_relations, label): + return pred_relations, label[6], label[5] - def _infer(self, preds, *args, **kwargs): + def _infer(self, pred_relations, *args, **kwargs): ser_results = kwargs['ser_results'] entity_idx_dict_batch = kwargs['entity_idx_dict_batch'] - pred_relations = preds['pred_relations'] # merge relations and ocr info results = [] @@ -50,6 +54,24 @@ class VQAReTokenLayoutLMPostProcess(object): results.append(result) return results + def decode_pred(self, pred_relations): + pred_relations_new = [] + for pred_relation in pred_relations: + pred_relation_new = [] + pred_relation = pred_relation[1:pred_relation[0, 0, 0] + 1] + for relation in pred_relation: + relation_new = dict() + relation_new['head_id'] = relation[0, 0] + relation_new['head'] = tuple(relation[1]) + relation_new['head_type'] = relation[2, 0] + relation_new['tail_id'] = relation[3, 0] + relation_new['tail'] = tuple(relation[4]) + relation_new['tail_type'] = relation[5, 0] + relation_new['type'] = relation[6, 0] + pred_relation_new.append(relation_new) + pred_relations_new.append(pred_relation_new) + return pred_relations_new + class DistillationRePostProcess(VQAReTokenLayoutLMPostProcess): """ diff --git a/ppstructure/docs/models_list.md b/ppstructure/docs/models_list.md index 935d12d756eec467574f9ae32d48c70a3ea054c3..afed95600f0858b1423a105c4f5bcd3e092211ab 100644 --- a/ppstructure/docs/models_list.md +++ b/ppstructure/docs/models_list.md @@ -51,9 +51,9 @@ |模型名称|模型简介 | 推理模型大小| 精度(hmean) | 预测耗时(ms) | 下载地址| | --- | --- | --- |--- |--- | --- | |ser_VI-LayoutXLM_xfund_zh|基于VI-LayoutXLM在xfund中文数据集上训练的SER模型|1.1G| 93.19% | 15.49 | [推理模型](https://paddleocr.bj.bcebos.com/ppstructure/models/vi_layoutxlm/ser_vi_layoutxlm_xfund_infer.tar) / [训练模型](https://paddleocr.bj.bcebos.com/ppstructure/models/vi_layoutxlm/ser_vi_layoutxlm_xfund_pretrained.tar) | -|re_VI-LayoutXLM_xfund_zh|基于VI-LayoutXLM在xfund中文数据集上训练的RE模型|1.1G| 83.92% | 15.49 |[推理模型 coming soon]() / [训练模型](https://paddleocr.bj.bcebos.com/ppstructure/models/vi_layoutxlm/re_vi_layoutxlm_xfund_pretrained.tar) | +|re_VI-LayoutXLM_xfund_zh|基于VI-LayoutXLM在xfund中文数据集上训练的RE模型|1.1G| 83.92% | 15.49 |[推理模型](https://paddleocr.bj.bcebos.com/ppstructure/models/vi_layoutxlm/re_vi_layoutxlm_xfund_infer.tar) / [训练模型](https://paddleocr.bj.bcebos.com/ppstructure/models/vi_layoutxlm/re_vi_layoutxlm_xfund_pretrained.tar) | |ser_LayoutXLM_xfund_zh|基于LayoutXLM在xfund中文数据集上训练的SER模型|1.4G| 90.38% | 19.49 |[推理模型](https://paddleocr.bj.bcebos.com/pplayout/ser_LayoutXLM_xfun_zh_infer.tar) / [训练模型](https://paddleocr.bj.bcebos.com/pplayout/ser_LayoutXLM_xfun_zh.tar) | -|re_LayoutXLM_xfund_zh|基于LayoutXLM在xfund中文数据集上训练的RE模型|1.4G| 74.83% | 19.49 |[推理模型 coming soon]() / [训练模型](https://paddleocr.bj.bcebos.com/pplayout/re_LayoutXLM_xfun_zh.tar) | +|re_LayoutXLM_xfund_zh|基于LayoutXLM在xfund中文数据集上训练的RE模型|1.4G| 74.83% | 19.49 |[推理模型](https://paddleocr.bj.bcebos.com/pplayout/re_LayoutXLM_xfun_zh_infer.tar) / [训练模型](https://paddleocr.bj.bcebos.com/pplayout/re_LayoutXLM_xfun_zh.tar) | |ser_LayoutLMv2_xfund_zh|基于LayoutLMv2在xfund中文数据集上训练的SER模型|778M| 85.44% | 31.46 |[推理模型](https://paddleocr.bj.bcebos.com/pplayout/ser_LayoutLMv2_xfun_zh_infer.tar) / [训练模型](https://paddleocr.bj.bcebos.com/pplayout/ser_LayoutLMv2_xfun_zh.tar) | |re_LayoutLMv2_xfund_zh|基于LayoutLMv2在xfun中文数据集上训练的RE模型|765M| 67.77% | 31.46 |[推理模型 coming soon]() / [训练模型](https://paddleocr.bj.bcebos.com/pplayout/re_LayoutLMv2_xfun_zh.tar) | |ser_LayoutLM_xfund_zh|基于LayoutLM在xfund中文数据集上训练的SER模型|430M| 77.31% | - |[推理模型](https://paddleocr.bj.bcebos.com/pplayout/ser_LayoutLM_xfun_zh_infer.tar) / [训练模型](https://paddleocr.bj.bcebos.com/pplayout/ser_LayoutLM_xfun_zh.tar) | diff --git a/ppstructure/kie/README.md b/ppstructure/kie/README.md index b3b4d47d86d0cf2871ff96951afa0007306a572b..ae39d51dd320ca99de962a9a8b8a1ed0dab9d105 100644 --- a/ppstructure/kie/README.md +++ b/ppstructure/kie/README.md @@ -209,17 +209,18 @@ python3 ./tools/infer_kie_token_ser_re.py \ #### 4.2.3 Inference using PaddleInference -At present, only SER model supports inference using PaddleInference. - Firstly, download the inference SER inference model. - ```bash mkdir inference cd inference wget https://paddleocr.bj.bcebos.com/ppstructure/models/vi_layoutxlm/ser_vi_layoutxlm_xfund_infer.tar && tar -xf ser_vi_layoutxlm_xfund_infer.tar +wget https://paddleocr.bj.bcebos.com/ppstructure/models/vi_layoutxlm/re_vi_layoutxlm_xfund_infer.tar && tar -xf re_vi_layoutxlm_xfund_infer.tar +cd .. ``` +-SER + Use the following command for inference. @@ -236,6 +237,26 @@ python3 kie/predict_kie_token_ser.py \ The visual results and text file will be saved in directory `output`. +-RE + +Use the following command for inference. + + +```bash +cd ppstructure +python3 kie/predict_kie_token_ser_re.py \ + --kie_algorithm=LayoutXLM \ + --re_model_dir=../inference/re_vi_layoutxlm_xfund_infer \ + --ser_model_dir=../inference/ser_vi_layoutxlm_xfund_infer \ + --use_visual_backbone=False \ + --image_dir=./docs/kie/input/zh_val_42.jpg \ + --ser_dict_path=../train_data/XFUND/class_list_xfun.txt \ + --vis_font_path=../doc/fonts/simfang.ttf \ + --ocr_order_method="tb-yx" +``` + +The visual results and text file will be saved in directory `output`. + ### 4.3 More diff --git a/ppstructure/kie/README_ch.md b/ppstructure/kie/README_ch.md index cc8c60009f4cb83d349c45573a9fa03832665374..15de1507fa16d9cf696a485ae89a5c95188b2f92 100644 --- a/ppstructure/kie/README_ch.md +++ b/ppstructure/kie/README_ch.md @@ -193,17 +193,18 @@ python3 ./tools/infer_kie_token_ser_re.py \ #### 4.2.3 基于PaddleInference的预测 -目前仅SER模型支持PaddleInference推理。 - -首先下载SER的推理模型。 - +首先下载SER和RE的推理模型。 ```bash mkdir inference cd inference wget https://paddleocr.bj.bcebos.com/ppstructure/models/vi_layoutxlm/ser_vi_layoutxlm_xfund_infer.tar && tar -xf ser_vi_layoutxlm_xfund_infer.tar +wget https://paddleocr.bj.bcebos.com/ppstructure/models/vi_layoutxlm/re_vi_layoutxlm_xfund_infer.tar && tar -xf re_vi_layoutxlm_xfund_infer.tar +cd .. ``` +-SER + 执行下面的命令进行预测。 ```bash @@ -219,6 +220,26 @@ python3 kie/predict_kie_token_ser.py \ 可视化结果保存在`output`目录下。 +-RE + +执行下面的命令进行预测。 + +```bash +cd ppstructure +python3 kie/predict_kie_token_ser_re.py \ + --kie_algorithm=LayoutXLM \ + --re_model_dir=../inference/re_vi_layoutxlm_xfund_infer \ + --ser_model_dir=../inference/ser_vi_layoutxlm_xfund_infer \ + --use_visual_backbone=False \ + --image_dir=./docs/kie/input/zh_val_42.jpg \ + --ser_dict_path=../train_data/XFUND/class_list_xfun.txt \ + --vis_font_path=../doc/fonts/simfang.ttf \ + --ocr_order_method="tb-yx" +``` + +可视化结果保存在`output`目录下。 + + ### 4.3 更多 关于KIE模型的训练评估与推理,请参考:[关键信息抽取教程](../../doc/doc_ch/kie.md)。 diff --git a/ppstructure/kie/predict_kie_token_ser.py b/ppstructure/kie/predict_kie_token_ser.py index 48cfc528a28e0a2bdfb51d3a537f26e891ae3286..e570979bcb419edbc2e09e190ae36ec1458c1826 100644 --- a/ppstructure/kie/predict_kie_token_ser.py +++ b/ppstructure/kie/predict_kie_token_ser.py @@ -102,16 +102,18 @@ class SerPredictor(object): ori_im = img.copy() data = {'image': img} data = transform(data, self.preprocess_op) - img = data[0] - if img is None: + if data[0] is None: return None, 0 - img = np.expand_dims(img, axis=0) - img = img.copy() starttime = time.time() + for idx in range(len(data)): + if isinstance(data[idx], np.ndarray): + data[idx] = np.expand_dims(data[idx], axis=0) + else: + data[idx] = [data[idx]] + for idx in range(len(self.input_tensor)): - expand_input = np.expand_dims(data[idx], axis=0) - self.input_tensor[idx].copy_from_cpu(expand_input) + self.input_tensor[idx].copy_from_cpu(data[idx]) self.predictor.run() @@ -122,9 +124,9 @@ class SerPredictor(object): preds = outputs[0] post_result = self.postprocess_op( - preds, segment_offset_ids=[data[6]], ocr_infos=[data[7]]) + preds, segment_offset_ids=data[6], ocr_infos=data[7]) elapse = time.time() - starttime - return post_result, elapse + return post_result, data, elapse def main(args): @@ -145,7 +147,7 @@ def main(args): if img is None: logger.info("error in loading image:{}".format(image_file)) continue - ser_res, elapse = ser_predictor(img) + ser_res, _, elapse = ser_predictor(img) ser_res = ser_res[0] res_str = '{}\t{}\n'.format( diff --git a/ppstructure/kie/predict_kie_token_ser_re.py b/ppstructure/kie/predict_kie_token_ser_re.py new file mode 100644 index 0000000000000000000000000000000000000000..b4eace4b5ee15ccf64a03e96dafcb1cfb021e656 --- /dev/null +++ b/ppstructure/kie/predict_kie_token_ser_re.py @@ -0,0 +1,124 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os +import sys + +__dir__ = os.path.dirname(os.path.abspath(__file__)) +sys.path.append(__dir__) +sys.path.insert(0, os.path.abspath(os.path.join(__dir__, '../..'))) + +os.environ["FLAGS_allocator_strategy"] = 'auto_growth' + +import cv2 +import json +import numpy as np +import time + +import tools.infer.utility as utility +from tools.infer_kie_token_ser_re import make_input +from ppocr.postprocess import build_post_process +from ppocr.utils.logging import get_logger +from ppocr.utils.visual import draw_re_results +from ppocr.utils.utility import get_image_file_list, check_and_read +from ppstructure.utility import parse_args +from ppstructure.kie.predict_kie_token_ser import SerPredictor + +from paddleocr import PaddleOCR + +logger = get_logger() + + +class SerRePredictor(object): + def __init__(self, args): + self.use_visual_backbone = args.use_visual_backbone + self.ser_engine = SerPredictor(args) + + postprocess_params = {'name': 'VQAReTokenLayoutLMPostProcess'} + self.postprocess_op = build_post_process(postprocess_params) + self.predictor, self.input_tensor, self.output_tensors, self.config = \ + utility.create_predictor(args, 're', logger) + + def __call__(self, img): + ori_im = img.copy() + starttime = time.time() + ser_results, ser_inputs, _ = self.ser_engine(img) + re_input, entity_idx_dict_batch = make_input(ser_inputs, ser_results) + if self.use_visual_backbone == False: + re_input.pop(4) + for idx in range(len(self.input_tensor)): + self.input_tensor[idx].copy_from_cpu(re_input[idx]) + + self.predictor.run() + outputs = [] + for output_tensor in self.output_tensors: + output = output_tensor.copy_to_cpu() + outputs.append(output) + preds = dict(loss=outputs[0], pred_relations=outputs[1]) + + post_result = self.postprocess_op( + preds, + ser_results=ser_results, + entity_idx_dict_batch=entity_idx_dict_batch) + + elapse = time.time() - starttime + return post_result, elapse + + +def main(args): + image_file_list = get_image_file_list(args.image_dir) + ser_predictor = SerRePredictor(args) + count = 0 + total_time = 0 + + os.makedirs(args.output, exist_ok=True) + with open( + os.path.join(args.output, 'infer.txt'), mode='w', + encoding='utf-8') as f_w: + for image_file in image_file_list: + img, flag, _ = check_and_read(image_file) + if not flag: + img = cv2.imread(image_file) + img = img[:, :, ::-1] + if img is None: + logger.info("error in loading image:{}".format(image_file)) + continue + re_res, elapse = ser_predictor(img) + re_res = re_res[0] + + res_str = '{}\t{}\n'.format( + image_file, + json.dumps( + { + "ocr_info": re_res, + }, ensure_ascii=False)) + f_w.write(res_str) + + img_res = draw_re_results( + image_file, re_res, font_path=args.vis_font_path) + + img_save_path = os.path.join( + args.output, + os.path.splitext(os.path.basename(image_file))[0] + + "_ser_re.jpg") + + cv2.imwrite(img_save_path, img_res) + logger.info("save vis result to {}".format(img_save_path)) + if count > 0: + total_time += elapse + count += 1 + logger.info("Predict time of {}: {}".format(image_file, elapse)) + + +if __name__ == "__main__": + main(parse_args()) diff --git a/ppstructure/utility.py b/ppstructure/utility.py index 97b6d6fec0d70fe3014b0b2105dbbef6a292e4d7..9f1a46705fc129e089c4cdcb5cdd79c784b56fce 100644 --- a/ppstructure/utility.py +++ b/ppstructure/utility.py @@ -52,6 +52,8 @@ def init_args(): # params for kie parser.add_argument("--kie_algorithm", type=str, default='LayoutXLM') parser.add_argument("--ser_model_dir", type=str) + parser.add_argument("--re_model_dir", type=str) + parser.add_argument("--use_visual_backbone", type=str2bool, default=True) parser.add_argument( "--ser_dict_path", type=str, diff --git a/tools/export_model.py b/tools/export_model.py index 193988cc1b62a6c4536a8d2ec640e3e5fc81a79c..8610df83ef08926c245872e711cd1c828eb46765 100755 --- a/tools/export_model.py +++ b/tools/export_model.py @@ -115,16 +115,12 @@ def export_single_model(model, max_text_length = arch_config["Head"]["max_text_length"] other_shape = [ paddle.static.InputSpec( - shape=[None, 3, 48, 160], dtype="float32"), - - [ - paddle.static.InputSpec( - shape=[None, ], - dtype="float32"), - paddle.static.InputSpec( - shape=[None, max_text_length], - dtype="int64") - ] + shape=[None, 3, 48, 160], dtype="float32"), [ + paddle.static.InputSpec( + shape=[None, ], dtype="float32"), + paddle.static.InputSpec( + shape=[None, max_text_length], dtype="int64") + ] ] model = to_static(model, input_spec=other_shape) elif arch_config["algorithm"] in ["LayoutLM", "LayoutLMv2", "LayoutXLM"]: @@ -140,6 +136,13 @@ def export_single_model(model, paddle.static.InputSpec( shape=[None, 3, 224, 224], dtype="int64"), # image ] + if 'Re' in arch_config['Backbone']['name']: + input_spec.extend([ + paddle.static.InputSpec( + shape=[None, 512, 3], dtype="int64"), # entities + paddle.static.InputSpec( + shape=[None, None, 2], dtype="int64"), # relations + ]) if model.backbone.use_visual_backbone is False: input_spec.pop(4) model = to_static(model, input_spec=[input_spec]) diff --git a/tools/infer/utility.py b/tools/infer/utility.py index dafbfbeaf3e25fdd402027190c92ab45cbe352b4..b9c9490bdb99f3bee67cb9460a9975b93b0d6366 100644 --- a/tools/infer/utility.py +++ b/tools/infer/utility.py @@ -162,6 +162,8 @@ def create_predictor(args, mode, logger): model_dir = args.table_model_dir elif mode == 'ser': model_dir = args.ser_model_dir + elif mode == 're': + model_dir = args.re_model_dir elif mode == "sr": model_dir = args.sr_model_dir elif mode == 'layout': @@ -227,7 +229,8 @@ def create_predictor(args, mode, logger): use_calib_mode=False) # collect shape - trt_shape_f = os.path.join(model_dir, f"{mode}_trt_dynamic_shape.txt") + trt_shape_f = os.path.join(model_dir, + f"{mode}_trt_dynamic_shape.txt") if not os.path.exists(trt_shape_f): config.collect_shape_range_info(trt_shape_f) @@ -262,6 +265,8 @@ def create_predictor(args, mode, logger): config.disable_glog_info() config.delete_pass("conv_transpose_eltwiseadd_bn_fuse_pass") config.delete_pass("matmul_transpose_reshape_fuse_pass") + if mode == 're': + config.delete_pass("simplify_with_basic_ops_pass") if mode == 'table': config.delete_pass("fc_fuse_pass") # not supported for table config.switch_use_feed_fetch_ops(False) diff --git a/tools/infer_kie_token_ser_re.py b/tools/infer_kie_token_ser_re.py index 3ee696f28470a16205be628b3aeb586ef7a9c6a6..c4fa2c927ab93cfa9082e51f08f8d6e1c35fe29e 100755 --- a/tools/infer_kie_token_ser_re.py +++ b/tools/infer_kie_token_ser_re.py @@ -63,7 +63,7 @@ class ReArgsParser(ArgsParser): def make_input(ser_inputs, ser_results): entities_labels = {'HEADER': 0, 'QUESTION': 1, 'ANSWER': 2} - + batch_size, max_seq_len = ser_inputs[0].shape[:2] entities = ser_inputs[8][0] ser_results = ser_results[0] assert len(entities) == len(ser_results) @@ -80,34 +80,44 @@ def make_input(ser_inputs, ser_results): start.append(entity['start']) end.append(entity['end']) label.append(entities_labels[res['pred']]) - entities = dict(start=start, end=end, label=label) + + entities = np.full([max_seq_len + 1, 3], fill_value=-1) + entities[0, 0] = len(start) + entities[1:len(start) + 1, 0] = start + entities[0, 1] = len(end) + entities[1:len(end) + 1, 1] = end + entities[0, 2] = len(label) + entities[1:len(label) + 1, 2] = label # relations head = [] tail = [] - for i in range(len(entities["label"])): - for j in range(len(entities["label"])): - if entities["label"][i] == 1 and entities["label"][j] == 2: + for i in range(len(label)): + for j in range(len(label)): + if label[i] == 1 and label[j] == 2: head.append(i) tail.append(j) - relations = dict(head=head, tail=tail) + relations = np.full([len(head) + 1, 2], fill_value=-1) + relations[0, 0] = len(head) + relations[1:len(head) + 1, 0] = head + relations[0, 1] = len(tail) + relations[1:len(tail) + 1, 1] = tail + + entities = np.expand_dims(entities, axis=0) + entities = np.repeat(entities, batch_size, axis=0) + relations = np.expand_dims(relations, axis=0) + relations = np.repeat(relations, batch_size, axis=0) + + # remove ocr_info segment_offset_id and label in ser input + if isinstance(ser_inputs[0], paddle.Tensor): + entities = paddle.to_tensor(entities) + relations = paddle.to_tensor(relations) + ser_inputs = ser_inputs[:5] + [entities, relations] - batch_size = ser_inputs[0].shape[0] - entities_batch = [] - relations_batch = [] entity_idx_dict_batch = [] for b in range(batch_size): - entities_batch.append(entities) - relations_batch.append(relations) entity_idx_dict_batch.append(entity_idx_dict) - - ser_inputs[8] = entities_batch - ser_inputs.append(relations_batch) - # remove ocr_info segment_offset_id and label in ser input - ser_inputs.pop(7) - ser_inputs.pop(6) - ser_inputs.pop(5) return ser_inputs, entity_idx_dict_batch @@ -136,6 +146,8 @@ class SerRePredictor(object): def __call__(self, data): ser_results, ser_inputs = self.ser_engine(data) re_input, entity_idx_dict_batch = make_input(ser_inputs, ser_results) + if self.model.backbone.use_visual_backbone is False: + re_input.pop(4) preds = self.model(re_input) post_result = self.post_process_class( preds,