diff --git a/configs/kie/layoutlm_series/re_layoutxlm_xfund_zh.yml b/configs/kie/layoutlm_series/re_layoutxlm_xfund_zh.yml
index 2401cf317987c5614a476065191e750587bc09b5..99dc771d150b15847486c096529a2828b9c0c05a 100644
--- a/configs/kie/layoutlm_series/re_layoutxlm_xfund_zh.yml
+++ b/configs/kie/layoutlm_series/re_layoutxlm_xfund_zh.yml
@@ -68,6 +68,7 @@ Train:
- VQAReTokenRelation:
- VQAReTokenChunk:
max_seq_len: *max_seq_len
+ - TensorizeEntitiesRelations:
- Resize:
size: [224,224]
- NormalizeImage:
@@ -83,7 +84,6 @@ Train:
drop_last: False
batch_size_per_card: 2
num_workers: 8
- collate_fn: ListCollator
Eval:
dataset:
@@ -105,6 +105,7 @@ Eval:
- VQAReTokenRelation:
- VQAReTokenChunk:
max_seq_len: *max_seq_len
+ - TensorizeEntitiesRelations:
- Resize:
size: [224,224]
- NormalizeImage:
@@ -120,4 +121,3 @@ Eval:
drop_last: False
batch_size_per_card: 8
num_workers: 8
- collate_fn: ListCollator
diff --git a/configs/kie/vi_layoutxlm/re_vi_layoutxlm_xfund_zh.yml b/configs/kie/vi_layoutxlm/re_vi_layoutxlm_xfund_zh.yml
index ea9f50ef56ec8b169333263c1d5e96586f9472b3..811c7d2d6f16344a3d6ad060fec1a1966241d81b 100644
--- a/configs/kie/vi_layoutxlm/re_vi_layoutxlm_xfund_zh.yml
+++ b/configs/kie/vi_layoutxlm/re_vi_layoutxlm_xfund_zh.yml
@@ -73,6 +73,7 @@ Train:
- VQAReTokenRelation:
- VQAReTokenChunk:
max_seq_len: *max_seq_len
+ - TensorizeEntitiesRelations:
- Resize:
size: [224,224]
- NormalizeImage:
@@ -88,7 +89,6 @@ Train:
drop_last: False
batch_size_per_card: 2
num_workers: 4
- collate_fn: ListCollator
Eval:
dataset:
@@ -112,6 +112,7 @@ Eval:
- VQAReTokenRelation:
- VQAReTokenChunk:
max_seq_len: *max_seq_len
+ - TensorizeEntitiesRelations:
- Resize:
size: [224,224]
- NormalizeImage:
@@ -127,5 +128,3 @@ Eval:
drop_last: False
batch_size_per_card: 8
num_workers: 8
- collate_fn: ListCollator
-
diff --git a/configs/kie/vi_layoutxlm/re_vi_layoutxlm_xfund_zh_udml.yml b/configs/kie/vi_layoutxlm/re_vi_layoutxlm_xfund_zh_udml.yml
index b96528d2738e7cfb2575feca4146af1eed0c5d2f..0bd42901ebd3a37eb29ce854b1e434dc356d9643 100644
--- a/configs/kie/vi_layoutxlm/re_vi_layoutxlm_xfund_zh_udml.yml
+++ b/configs/kie/vi_layoutxlm/re_vi_layoutxlm_xfund_zh_udml.yml
@@ -116,6 +116,7 @@ Train:
- VQAReTokenRelation:
- VQAReTokenChunk:
max_seq_len: *max_seq_len
+ - TensorizeEntitiesRelations:
- Resize:
size: [224,224]
- NormalizeImage:
@@ -155,6 +156,7 @@ Eval:
- VQAReTokenRelation:
- VQAReTokenChunk:
max_seq_len: *max_seq_len
+ - TensorizeEntitiesRelations:
- Resize:
size: [224,224]
- NormalizeImage:
diff --git a/doc/doc_ch/algorithm_kie_layoutxlm.md b/doc/doc_ch/algorithm_kie_layoutxlm.md
index e693be49b7bc89e04b169fe74cf76525b2494948..0cbcad25016974207382a044e211c704082f6467 100644
--- a/doc/doc_ch/algorithm_kie_layoutxlm.md
+++ b/doc/doc_ch/algorithm_kie_layoutxlm.md
@@ -30,7 +30,7 @@
|模型|骨干网络|任务|配置文件|hmean|下载链接|
| --- | --- |--|--- | --- | --- |
|LayoutXLM|LayoutXLM-base|SER |[ser_layoutxlm_xfund_zh.yml](../../configs/kie/layoutlm_series/ser_layoutxlm_xfund_zh.yml)|90.38%|[训练模型](https://paddleocr.bj.bcebos.com/pplayout/ser_LayoutXLM_xfun_zh.tar)/[推理模型](https://paddleocr.bj.bcebos.com/pplayout/ser_LayoutXLM_xfun_zh_infer.tar)|
-|LayoutXLM|LayoutXLM-base|RE | [re_layoutxlm_xfund_zh.yml](../../configs/kie/layoutlm_series/re_layoutxlm_xfund_zh.yml)|74.83%|[训练模型](https://paddleocr.bj.bcebos.com/pplayout/re_LayoutXLM_xfun_zh.tar)/[推理模型(coming soon)]()|
+|LayoutXLM|LayoutXLM-base|RE | [re_layoutxlm_xfund_zh.yml](../../configs/kie/layoutlm_series/re_layoutxlm_xfund_zh.yml)|74.83%|[训练模型](https://paddleocr.bj.bcebos.com/pplayout/re_LayoutXLM_xfun_zh.tar)/[推理模型](https://paddleocr.bj.bcebos.com/pplayout/re_LayoutXLM_xfun_zh_infer.tar)|
@@ -52,14 +52,14 @@
### 4.1 Python推理
-**注:** 目前RE任务推理过程仍在适配中,下面以SER任务为例,介绍基于LayoutXLM模型的关键信息抽取过程。
+- SER
首先将训练得到的模型转换成inference model。LayoutXLM模型在XFUND_zh数据集上训练的模型为例([模型下载地址](https://paddleocr.bj.bcebos.com/pplayout/ser_LayoutXLM_xfun_zh.tar)),可以使用下面的命令进行转换。
``` bash
wget https://paddleocr.bj.bcebos.com/pplayout/ser_LayoutXLM_xfun_zh.tar
tar -xf ser_LayoutXLM_xfun_zh.tar
-python3 tools/export_model.py -c configs/kie/layoutlm_series/ser_layoutxlm_xfund_zh.yml -o Architecture.Backbone.checkpoints=./ser_LayoutXLM_xfun_zh/best_accuracy Global.save_inference_dir=./inference/ser_layoutxlm
+python3 tools/export_model.py -c configs/kie/layoutlm_series/ser_layoutxlm_xfund_zh.yml -o Architecture.Backbone.checkpoints=./ser_LayoutXLM_xfun_zh Global.save_inference_dir=./inference/ser_layoutxlm_infer
```
LayoutXLM模型基于SER任务进行推理,可以执行如下命令:
@@ -80,6 +80,34 @@ SER可视化结果默认保存到`./output`文件夹里面,结果示例如下
+- RE
+
+首先将训练得到的模型转换成inference model。LayoutXLM模型在XFUND_zh数据集上训练的模型为例([模型下载地址](https://paddleocr.bj.bcebos.com/pplayout/re_LayoutXLM_xfun_zh.tar)),可以使用下面的命令进行转换。
+
+``` bash
+wget https://paddleocr.bj.bcebos.com/pplayout/re_LayoutXLM_xfun_zh.tar
+tar -xf re_LayoutXLM_xfun_zh.tar
+python3 tools/export_model.py -c configs/kie/layoutlm_series/re_layoutxlm_xfund_zh.yml -o Architecture.Backbone.checkpoints=./re_LayoutXLM_xfun_zh Global.save_inference_dir=./inference/ser_layoutxlm_infer
+```
+
+LayoutXLM模型基于RE任务进行推理,可以执行如下命令:
+
+```bash
+cd ppstructure
+python3 kie/predict_kie_token_ser_re.py \
+ --kie_algorithm=LayoutXLM \
+ --re_model_dir=../inference/re_layoutxlm_infer \
+ --ser_model_dir=../inference/ser_layoutxlm_infer \
+ --image_dir=./docs/kie/input/zh_val_42.jpg \
+ --ser_dict_path=../train_data/XFUND/class_list_xfun.txt \
+ --vis_font_path=../doc/fonts/simfang.ttf
+```
+
+RE可视化结果默认保存到`./output`文件夹里面,结果示例如下:
+
+
+
+
### 4.2 C++推理部署
diff --git a/doc/doc_ch/algorithm_kie_vi_layoutxlm.md b/doc/doc_ch/algorithm_kie_vi_layoutxlm.md
index f1bb4b1e62736e88594196819dcc41980f1716bf..6c69230e020d8ead1ca3f917676730afaf0a19f8 100644
--- a/doc/doc_ch/algorithm_kie_vi_layoutxlm.md
+++ b/doc/doc_ch/algorithm_kie_vi_layoutxlm.md
@@ -23,7 +23,7 @@ VI-LayoutXLM基于LayoutXLM进行改进,在下游任务训练过程中,去
|模型|骨干网络|任务|配置文件|hmean|下载链接|
| --- | --- |---| --- | --- | --- |
|VI-LayoutXLM |VI-LayoutXLM-base | SER |[ser_vi_layoutxlm_xfund_zh_udml.yml](../../configs/kie/vi_layoutxlm/ser_vi_layoutxlm_xfund_zh_udml.yml)|93.19%|[训练模型](https://paddleocr.bj.bcebos.com/ppstructure/models/vi_layoutxlm/ser_vi_layoutxlm_xfund_pretrained.tar)/[推理模型](https://paddleocr.bj.bcebos.com/ppstructure/models/vi_layoutxlm/ser_vi_layoutxlm_xfund_infer.tar)|
-|VI-LayoutXLM |VI-LayoutXLM-base |RE | [re_vi_layoutxlm_xfund_zh_udml.yml](../../configs/kie/vi_layoutxlm/re_vi_layoutxlm_xfund_zh_udml.yml)|83.92%|[训练模型](https://paddleocr.bj.bcebos.com/ppstructure/models/vi_layoutxlm/re_vi_layoutxlm_xfund_pretrained.tar)/[推理模型(coming soon)]()|
+|VI-LayoutXLM |VI-LayoutXLM-base |RE | [re_vi_layoutxlm_xfund_zh_udml.yml](../../configs/kie/vi_layoutxlm/re_vi_layoutxlm_xfund_zh_udml.yml)|83.92%|[训练模型](https://paddleocr.bj.bcebos.com/ppstructure/models/vi_layoutxlm/re_vi_layoutxlm_xfund_pretrained.tar)/[推理模型](https://paddleocr.bj.bcebos.com/ppstructure/models/vi_layoutxlm/re_vi_layoutxlm_xfund_infer.tar)|
@@ -45,7 +45,7 @@ VI-LayoutXLM基于LayoutXLM进行改进,在下游任务训练过程中,去
### 4.1 Python推理
-**注:** 目前RE任务推理过程仍在适配中,下面以SER任务为例,介绍基于VI-LayoutXLM模型的关键信息抽取过程。
+-SER
首先将训练得到的模型转换成inference model。以VI-LayoutXLM模型在XFUND_zh数据集上训练的模型为例([模型下载地址](https://paddleocr.bj.bcebos.com/ppstructure/models/vi_layoutxlm/ser_vi_layoutxlm_xfund_pretrained.tar)),可以使用下面的命令进行转换。
@@ -74,6 +74,36 @@ SER可视化结果默认保存到`./output`文件夹里面,结果示例如下
+-RE
+
+首先将训练得到的模型转换成inference model。以VI-LayoutXLM模型在XFUND_zh数据集上训练的模型为例([模型下载地址](https://paddleocr.bj.bcebos.com/ppstructure/models/vi_layoutxlm/re_vi_layoutxlm_xfund_pretrained.tar)),可以使用下面的命令进行转换。
+
+``` bash
+wget https://paddleocr.bj.bcebos.com/ppstructure/models/vi_layoutxlm/re_vi_layoutxlm_xfund_pretrained.tar
+tar -xf re_vi_layoutxlm_xfund_pretrained.tar
+python3 tools/export_model.py -c configs/kie/vi_layoutxlm/re_vi_layoutxlm_xfund_zh.yml -o Architecture.Backbone.checkpoints=./re_vi_layoutxlm_xfund_pretrained/best_accuracy Global.save_inference_dir=./inference/re_vi_layoutxlm_infer
+```
+
+VI-LayoutXLM模型基于RE任务进行推理,可以执行如下命令:
+
+```bash
+cd ppstructure
+python3 kie/predict_kie_token_ser_re.py \
+ --kie_algorithm=LayoutXLM \
+ --re_model_dir=../inference/re_vi_layoutxlm_infer \
+ --ser_model_dir=../inference/ser_vi_layoutxlm_infer \
+ --use_visual_backbone=False \
+ --image_dir=./docs/kie/input/zh_val_42.jpg \
+ --ser_dict_path=../train_data/XFUND/class_list_xfun.txt \
+ --vis_font_path=../doc/fonts/simfang.ttf \
+ --ocr_order_method="tb-yx"
+```
+
+RE可视化结果默认保存到`./output`文件夹里面,结果示例如下:
+
+
+
+
### 4.2 C++推理部署
diff --git a/doc/doc_en/algorithm_kie_layoutxlm_en.md b/doc/doc_en/algorithm_kie_layoutxlm_en.md
index 910c1f4d497a6e503f0a7a5ec26dbeceb2d321a1..0c82b0423b2cdc11b1817e7575629a659e599374 100644
--- a/doc/doc_en/algorithm_kie_layoutxlm_en.md
+++ b/doc/doc_en/algorithm_kie_layoutxlm_en.md
@@ -28,7 +28,7 @@ On XFUND_zh dataset, the algorithm reproduction Hmean is as follows.
|Model|Backbone|Task |Cnnfig|Hmean|Download link|
| --- | --- |--|--- | --- | --- |
|LayoutXLM|LayoutXLM-base|SER |[ser_layoutxlm_xfund_zh.yml](../../configs/kie/layoutlm_series/ser_layoutxlm_xfund_zh.yml)|90.38%|[trained model](https://paddleocr.bj.bcebos.com/pplayout/ser_LayoutXLM_xfun_zh.tar)/[inference model](https://paddleocr.bj.bcebos.com/pplayout/ser_LayoutXLM_xfun_zh_infer.tar)|
-|LayoutXLM|LayoutXLM-base|RE | [re_layoutxlm_xfund_zh.yml](../../configs/kie/layoutlm_series/re_layoutxlm_xfund_zh.yml)|74.83%|[trained model](https://paddleocr.bj.bcebos.com/pplayout/re_LayoutXLM_xfun_zh.tar)/[inference model(coming soon)]()|
+|LayoutXLM|LayoutXLM-base|RE | [re_layoutxlm_xfund_zh.yml](../../configs/kie/layoutlm_series/re_layoutxlm_xfund_zh.yml)|74.83%|[trained model](https://paddleocr.bj.bcebos.com/pplayout/re_LayoutXLM_xfun_zh.tar)/[inference model](https://paddleocr.bj.bcebos.com/pplayout/re_LayoutXLM_xfun_zh_infer.tar)|
## 2. Environment
@@ -46,7 +46,7 @@ Please refer to [KIE tutorial](./kie_en.md)。PaddleOCR has modularized the code
### 4.1 Python Inference
-**Note:** Currently, the RE model inference process is still in the process of adaptation. We take SER model as an example to introduce the KIE process based on LayoutXLM model.
+- SER
First, we need to export the trained model into inference model. Take LayoutXLM model trained on XFUND_zh as an example ([trained model download link](https://paddleocr.bj.bcebos.com/pplayout/ser_LayoutXLM_xfun_zh.tar)). Use the following command to export.
@@ -54,7 +54,7 @@ First, we need to export the trained model into inference model. Take LayoutXLM
``` bash
wget https://paddleocr.bj.bcebos.com/pplayout/ser_LayoutXLM_xfun_zh.tar
tar -xf ser_LayoutXLM_xfun_zh.tar
-python3 tools/export_model.py -c configs/kie/layoutlm_series/ser_layoutxlm_xfund_zh.yml -o Architecture.Backbone.checkpoints=./ser_LayoutXLM_xfun_zh/best_accuracy Global.save_inference_dir=./inference/ser_layoutxlm
+python3 tools/export_model.py -c configs/kie/layoutlm_series/ser_layoutxlm_xfund_zh.yml -o Architecture.Backbone.checkpoints=./ser_LayoutXLM_xfun_zh Global.save_inference_dir=./inference/ser_layoutxlm_infer
```
Use the following command to infer using LayoutXLM SER model.
@@ -77,6 +77,38 @@ The SER visualization results are saved in the `./output` directory by default.
+- RE
+
+First, we need to export the trained model into inference model. Take LayoutXLM model trained on XFUND_zh as an example ([trained model download link](https://paddleocr.bj.bcebos.com/pplayout/re_LayoutXLM_xfun_zh.tar)). Use the following command to export.
+
+
+``` bash
+wget https://paddleocr.bj.bcebos.com/pplayout/re_LayoutXLM_xfun_zh.tar
+tar -xf re_LayoutXLM_xfun_zh.tar
+python3 tools/export_model.py -c configs/kie/layoutlm_series/re_layoutxlm_xfund_zh.yml -o Architecture.Backbone.checkpoints=./re_LayoutXLM_xfun_zh Global.save_inference_dir=./inference/re_layoutxlm_infer
+```
+
+Use the following command to infer using LayoutXLM RE model.
+
+
+```bash
+cd ppstructure
+python3 kie/predict_kie_token_ser_re.py \
+ --kie_algorithm=LayoutXLM \
+ --re_model_dir=../inference/re_layoutxlm_infer \
+ --ser_model_dir=../inference/ser_layoutxlm_infer \
+ --image_dir=./docs/kie/input/zh_val_42.jpg \
+ --ser_dict_path=../train_data/XFUND/class_list_xfun.txt \
+ --vis_font_path=../doc/fonts/simfang.ttf
+```
+The RE visualization results are saved in the `./output` directory by default. The results are as follows.
+
+
+
+
+
+
+
### 4.2 C++ Inference
Not supported
diff --git a/doc/doc_en/algorithm_kie_vi_layoutxlm_en.md b/doc/doc_en/algorithm_kie_vi_layoutxlm_en.md
index 12b6e1bddbd03b820ce33ba86de3d430a44f8987..fab761f5a957c2bfdd42e09d22da27410a2ad423 100644
--- a/doc/doc_en/algorithm_kie_vi_layoutxlm_en.md
+++ b/doc/doc_en/algorithm_kie_vi_layoutxlm_en.md
@@ -22,7 +22,7 @@ On XFUND_zh dataset, the algorithm reproduction Hmean is as follows.
|Model|Backbone|Task |Cnnfig|Hmean|Download link|
| --- | --- |---| --- | --- | --- |
|VI-LayoutXLM |VI-LayoutXLM-base | SER |[ser_vi_layoutxlm_xfund_zh_udml.yml](../../configs/kie/vi_layoutxlm/ser_vi_layoutxlm_xfund_zh_udml.yml)|93.19%|[trained model](https://paddleocr.bj.bcebos.com/ppstructure/models/vi_layoutxlm/ser_vi_layoutxlm_xfund_pretrained.tar)/[inference model](https://paddleocr.bj.bcebos.com/ppstructure/models/vi_layoutxlm/ser_vi_layoutxlm_xfund_infer.tar)|
-|VI-LayoutXLM |VI-LayoutXLM-base |RE | [re_vi_layoutxlm_xfund_zh_udml.yml](../../configs/kie/vi_layoutxlm/re_vi_layoutxlm_xfund_zh_udml.yml)|83.92%|[trained model](https://paddleocr.bj.bcebos.com/ppstructure/models/vi_layoutxlm/re_vi_layoutxlm_xfund_pretrained.tar)/[inference model(coming soon)]()|
+|VI-LayoutXLM |VI-LayoutXLM-base |RE | [re_vi_layoutxlm_xfund_zh_udml.yml](../../configs/kie/vi_layoutxlm/re_vi_layoutxlm_xfund_zh_udml.yml)|83.92%|[trained model](https://paddleocr.bj.bcebos.com/ppstructure/models/vi_layoutxlm/re_vi_layoutxlm_xfund_pretrained.tar)/[inference model](https://paddleocr.bj.bcebos.com/ppstructure/models/vi_layoutxlm/re_vi_layoutxlm_xfund_infer.tar)|
Please refer to ["Environment Preparation"](./environment_en.md) to configure the PaddleOCR environment, and refer to ["Project Clone"](./clone_en.md) to clone the project code.
@@ -37,7 +37,7 @@ Please refer to [KIE tutorial](./kie_en.md)。PaddleOCR has modularized the code
### 4.1 Python Inference
-**Note:** Currently, the RE model inference process is still in the process of adaptation. We take SER model as an example to introduce the KIE process based on VI-LayoutXLM model.
+-SER
First, we need to export the trained model into inference model. Take VI-LayoutXLM model trained on XFUND_zh as an example ([trained model download link](https://paddleocr.bj.bcebos.com/ppstructure/models/vi_layoutxlm/ser_vi_layoutxlm_xfund_pretrained.tar)). Use the following command to export.
@@ -70,6 +70,41 @@ The SER visualization results are saved in the `./output` folder by default. The
+-RE
+
+First, we need to export the trained model into inference model. Take VI-LayoutXLM model trained on XFUND_zh as an example ([trained model download link](https://paddleocr.bj.bcebos.com/ppstructure/models/vi_layoutxlm/re_vi_layoutxlm_xfund_pretrained.tar)). Use the following command to export.
+
+
+``` bash
+wget https://paddleocr.bj.bcebos.com/ppstructure/models/vi_layoutxlm/re_vi_layoutxlm_xfund_pretrained.tar
+tar -xf re_vi_layoutxlm_xfund_pretrained.tar
+python3 tools/export_model.py -c configs/kie/vi_layoutxlm/re_vi_layoutxlm_xfund_zh.yml -o Architecture.Backbone.checkpoints=./re_vi_layoutxlm_xfund_pretrained/best_accuracy Global.save_inference_dir=./inference/re_vi_layoutxlm_infer
+```
+
+Use the following command to infer using VI-LayoutXLM RE model.
+
+
+```bash
+cd ppstructure
+python3 kie/predict_kie_token_ser_re.py \
+ --kie_algorithm=LayoutXLM \
+ --re_model_dir=../inference/re_vi_layoutxlm_infer \
+ --ser_model_dir=../inference/ser_vi_layoutxlm_infer \
+ --use_visual_backbone=False \
+ --image_dir=./docs/kie/input/zh_val_42.jpg \
+ --ser_dict_path=../train_data/XFUND/class_list_xfun.txt \
+ --vis_font_path=../doc/fonts/simfang.ttf \
+ --ocr_order_method="tb-yx"
+```
+
+The RE visualization results are saved in the `./output` folder by default. The results are as follows.
+
+
+
+
+
+
+
### 4.2 C++ Inference
Not supported
diff --git a/ppocr/data/imaug/vqa/__init__.py b/ppocr/data/imaug/vqa/__init__.py
index 34189bcefb17a0776bd62a19c58081286882b5a5..73f7dcdf712f6db0ff4354b1b01134d1277ff078 100644
--- a/ppocr/data/imaug/vqa/__init__.py
+++ b/ppocr/data/imaug/vqa/__init__.py
@@ -12,11 +12,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from .token import VQATokenPad, VQASerTokenChunk, VQAReTokenChunk, VQAReTokenRelation
+from .token import VQATokenPad, VQASerTokenChunk, VQAReTokenChunk, VQAReTokenRelation, TensorizeEntitiesRelations
__all__ = [
- 'VQATokenPad',
- 'VQASerTokenChunk',
- 'VQAReTokenChunk',
- 'VQAReTokenRelation',
+ 'VQATokenPad', 'VQASerTokenChunk', 'VQAReTokenChunk', 'VQAReTokenRelation',
+ 'TensorizeEntitiesRelations'
]
diff --git a/ppocr/data/imaug/vqa/token/__init__.py b/ppocr/data/imaug/vqa/token/__init__.py
index 7c115661753cd031b16ec34697157e2fcdcf2dec..5fbaa43db9e7182cfa3efa4f3dc0d9e54c17c822 100644
--- a/ppocr/data/imaug/vqa/token/__init__.py
+++ b/ppocr/data/imaug/vqa/token/__init__.py
@@ -15,3 +15,4 @@
from .vqa_token_chunk import VQASerTokenChunk, VQAReTokenChunk
from .vqa_token_pad import VQATokenPad
from .vqa_token_relation import VQAReTokenRelation
+from .vqa_re_convert import TensorizeEntitiesRelations
\ No newline at end of file
diff --git a/ppocr/data/imaug/vqa/token/vqa_re_convert.py b/ppocr/data/imaug/vqa/token/vqa_re_convert.py
new file mode 100644
index 0000000000000000000000000000000000000000..86962f2590b57f38640d76ef5d8b74ead5e854e0
--- /dev/null
+++ b/ppocr/data/imaug/vqa/token/vqa_re_convert.py
@@ -0,0 +1,51 @@
+# copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import numpy as np
+
+
+class TensorizeEntitiesRelations(object):
+ def __init__(self, max_seq_len=512, infer_mode=False, **kwargs):
+ self.max_seq_len = max_seq_len
+ self.infer_mode = infer_mode
+
+ def __call__(self, data):
+ entities = data['entities']
+ relations = data['relations']
+
+ entities_new = np.full(
+ shape=[self.max_seq_len + 1, 3], fill_value=-1, dtype='int64')
+ entities_new[0, 0] = len(entities['start'])
+ entities_new[0, 1] = len(entities['end'])
+ entities_new[0, 2] = len(entities['label'])
+ entities_new[1:len(entities['start']) + 1, 0] = np.array(entities[
+ 'start'])
+ entities_new[1:len(entities['end']) + 1, 1] = np.array(entities['end'])
+ entities_new[1:len(entities['label']) + 1, 2] = np.array(entities[
+ 'label'])
+
+ relations_new = np.full(
+ shape=[self.max_seq_len * self.max_seq_len + 1, 2],
+ fill_value=-1,
+ dtype='int64')
+ relations_new[0, 0] = len(relations['head'])
+ relations_new[0, 1] = len(relations['tail'])
+ relations_new[1:len(relations['head']) + 1, 0] = np.array(relations[
+ 'head'])
+ relations_new[1:len(relations['tail']) + 1, 1] = np.array(relations[
+ 'tail'])
+
+ data['entities'] = entities_new
+ data['relations'] = relations_new
+ return data
diff --git a/ppocr/metrics/vqa_token_re_metric.py b/ppocr/metrics/vqa_token_re_metric.py
index f84387d8beb729bcc4b420ceea24a5e9b2993c64..0509984f7e7e85fc1dae859761fedb7356a02477 100644
--- a/ppocr/metrics/vqa_token_re_metric.py
+++ b/ppocr/metrics/vqa_token_re_metric.py
@@ -37,23 +37,25 @@ class VQAReTokenMetric(object):
gt_relations = []
for b in range(len(self.relations_list)):
rel_sent = []
- if "head" in self.relations_list[b]:
- for head, tail in zip(self.relations_list[b]["head"],
- self.relations_list[b]["tail"]):
+ relation_list = self.relations_list[b]
+ entitie_list = self.entities_list[b]
+ head_len = relation_list[0, 0]
+ if head_len > 0:
+ entitie_start_list = entitie_list[1:entitie_list[0, 0] + 1, 0]
+ entitie_end_list = entitie_list[1:entitie_list[0, 1] + 1, 1]
+ entitie_label_list = entitie_list[1:entitie_list[0, 2] + 1, 2]
+ for head, tail in zip(relation_list[1:head_len + 1, 0],
+ relation_list[1:head_len + 1, 1]):
rel = {}
rel["head_id"] = head
- rel["head"] = (
- self.entities_list[b]["start"][rel["head_id"]],
- self.entities_list[b]["end"][rel["head_id"]])
- rel["head_type"] = self.entities_list[b]["label"][rel[
- "head_id"]]
+ rel["head"] = (entitie_start_list[head],
+ entitie_end_list[head])
+ rel["head_type"] = entitie_label_list[head]
rel["tail_id"] = tail
- rel["tail"] = (
- self.entities_list[b]["start"][rel["tail_id"]],
- self.entities_list[b]["end"][rel["tail_id"]])
- rel["tail_type"] = self.entities_list[b]["label"][rel[
- "tail_id"]]
+ rel["tail"] = (entitie_start_list[tail],
+ entitie_end_list[tail])
+ rel["tail_type"] = entitie_label_list[tail]
rel["type"] = 1
rel_sent.append(rel)
diff --git a/ppocr/modeling/backbones/vqa_layoutlm.py b/ppocr/modeling/backbones/vqa_layoutlm.py
index 8e10ed7b48e9aff344b71e5a04970d1a5dab8a71..acb1315cc0a588396549e5b8928bd2e4d3c769be 100644
--- a/ppocr/modeling/backbones/vqa_layoutlm.py
+++ b/ppocr/modeling/backbones/vqa_layoutlm.py
@@ -218,8 +218,12 @@ class LayoutXLMForRe(NLPBaseModel):
def forward(self, x):
if self.use_visual_backbone is True:
image = x[4]
+ entities = x[5]
+ relations = x[6]
else:
image = None
+ entities = x[4]
+ relations = x[5]
x = self.model(
input_ids=x[0],
bbox=x[1],
@@ -229,6 +233,6 @@ class LayoutXLMForRe(NLPBaseModel):
position_ids=None,
head_mask=None,
labels=None,
- entities=x[5],
- relations=x[6])
+ entities=entities,
+ relations=relations)
return x
diff --git a/ppocr/postprocess/vqa_token_re_layoutlm_postprocess.py b/ppocr/postprocess/vqa_token_re_layoutlm_postprocess.py
index 96c25d9aac01066f7a3841fe61aa7b0fe05041bd..a6011acf8d68f65adfc84e134c9cc0e733dd68ea 100644
--- a/ppocr/postprocess/vqa_token_re_layoutlm_postprocess.py
+++ b/ppocr/postprocess/vqa_token_re_layoutlm_postprocess.py
@@ -21,18 +21,22 @@ class VQAReTokenLayoutLMPostProcess(object):
super(VQAReTokenLayoutLMPostProcess, self).__init__()
def __call__(self, preds, label=None, *args, **kwargs):
+ pred_relations = preds['pred_relations']
+ if isinstance(preds['pred_relations'], paddle.Tensor):
+ pred_relations = pred_relations.numpy()
+ pred_relations = self.decode_pred(pred_relations)
+
if label is not None:
- return self._metric(preds, label)
+ return self._metric(pred_relations, label)
else:
- return self._infer(preds, *args, **kwargs)
+ return self._infer(pred_relations, *args, **kwargs)
- def _metric(self, preds, label):
- return preds['pred_relations'], label[6], label[5]
+ def _metric(self, pred_relations, label):
+ return pred_relations, label[6], label[5]
- def _infer(self, preds, *args, **kwargs):
+ def _infer(self, pred_relations, *args, **kwargs):
ser_results = kwargs['ser_results']
entity_idx_dict_batch = kwargs['entity_idx_dict_batch']
- pred_relations = preds['pred_relations']
# merge relations and ocr info
results = []
@@ -50,6 +54,24 @@ class VQAReTokenLayoutLMPostProcess(object):
results.append(result)
return results
+ def decode_pred(self, pred_relations):
+ pred_relations_new = []
+ for pred_relation in pred_relations:
+ pred_relation_new = []
+ pred_relation = pred_relation[1:pred_relation[0, 0, 0] + 1]
+ for relation in pred_relation:
+ relation_new = dict()
+ relation_new['head_id'] = relation[0, 0]
+ relation_new['head'] = tuple(relation[1])
+ relation_new['head_type'] = relation[2, 0]
+ relation_new['tail_id'] = relation[3, 0]
+ relation_new['tail'] = tuple(relation[4])
+ relation_new['tail_type'] = relation[5, 0]
+ relation_new['type'] = relation[6, 0]
+ pred_relation_new.append(relation_new)
+ pred_relations_new.append(pred_relation_new)
+ return pred_relations_new
+
class DistillationRePostProcess(VQAReTokenLayoutLMPostProcess):
"""
diff --git a/ppstructure/docs/models_list.md b/ppstructure/docs/models_list.md
index 935d12d756eec467574f9ae32d48c70a3ea054c3..afed95600f0858b1423a105c4f5bcd3e092211ab 100644
--- a/ppstructure/docs/models_list.md
+++ b/ppstructure/docs/models_list.md
@@ -51,9 +51,9 @@
|模型名称|模型简介 | 推理模型大小| 精度(hmean) | 预测耗时(ms) | 下载地址|
| --- | --- | --- |--- |--- | --- |
|ser_VI-LayoutXLM_xfund_zh|基于VI-LayoutXLM在xfund中文数据集上训练的SER模型|1.1G| 93.19% | 15.49 | [推理模型](https://paddleocr.bj.bcebos.com/ppstructure/models/vi_layoutxlm/ser_vi_layoutxlm_xfund_infer.tar) / [训练模型](https://paddleocr.bj.bcebos.com/ppstructure/models/vi_layoutxlm/ser_vi_layoutxlm_xfund_pretrained.tar) |
-|re_VI-LayoutXLM_xfund_zh|基于VI-LayoutXLM在xfund中文数据集上训练的RE模型|1.1G| 83.92% | 15.49 |[推理模型 coming soon]() / [训练模型](https://paddleocr.bj.bcebos.com/ppstructure/models/vi_layoutxlm/re_vi_layoutxlm_xfund_pretrained.tar) |
+|re_VI-LayoutXLM_xfund_zh|基于VI-LayoutXLM在xfund中文数据集上训练的RE模型|1.1G| 83.92% | 15.49 |[推理模型](https://paddleocr.bj.bcebos.com/ppstructure/models/vi_layoutxlm/re_vi_layoutxlm_xfund_infer.tar) / [训练模型](https://paddleocr.bj.bcebos.com/ppstructure/models/vi_layoutxlm/re_vi_layoutxlm_xfund_pretrained.tar) |
|ser_LayoutXLM_xfund_zh|基于LayoutXLM在xfund中文数据集上训练的SER模型|1.4G| 90.38% | 19.49 |[推理模型](https://paddleocr.bj.bcebos.com/pplayout/ser_LayoutXLM_xfun_zh_infer.tar) / [训练模型](https://paddleocr.bj.bcebos.com/pplayout/ser_LayoutXLM_xfun_zh.tar) |
-|re_LayoutXLM_xfund_zh|基于LayoutXLM在xfund中文数据集上训练的RE模型|1.4G| 74.83% | 19.49 |[推理模型 coming soon]() / [训练模型](https://paddleocr.bj.bcebos.com/pplayout/re_LayoutXLM_xfun_zh.tar) |
+|re_LayoutXLM_xfund_zh|基于LayoutXLM在xfund中文数据集上训练的RE模型|1.4G| 74.83% | 19.49 |[推理模型](https://paddleocr.bj.bcebos.com/pplayout/re_LayoutXLM_xfun_zh_infer.tar) / [训练模型](https://paddleocr.bj.bcebos.com/pplayout/re_LayoutXLM_xfun_zh.tar) |
|ser_LayoutLMv2_xfund_zh|基于LayoutLMv2在xfund中文数据集上训练的SER模型|778M| 85.44% | 31.46 |[推理模型](https://paddleocr.bj.bcebos.com/pplayout/ser_LayoutLMv2_xfun_zh_infer.tar) / [训练模型](https://paddleocr.bj.bcebos.com/pplayout/ser_LayoutLMv2_xfun_zh.tar) |
|re_LayoutLMv2_xfund_zh|基于LayoutLMv2在xfun中文数据集上训练的RE模型|765M| 67.77% | 31.46 |[推理模型 coming soon]() / [训练模型](https://paddleocr.bj.bcebos.com/pplayout/re_LayoutLMv2_xfun_zh.tar) |
|ser_LayoutLM_xfund_zh|基于LayoutLM在xfund中文数据集上训练的SER模型|430M| 77.31% | - |[推理模型](https://paddleocr.bj.bcebos.com/pplayout/ser_LayoutLM_xfun_zh_infer.tar) / [训练模型](https://paddleocr.bj.bcebos.com/pplayout/ser_LayoutLM_xfun_zh.tar) |
diff --git a/ppstructure/kie/README.md b/ppstructure/kie/README.md
index b3b4d47d86d0cf2871ff96951afa0007306a572b..ae39d51dd320ca99de962a9a8b8a1ed0dab9d105 100644
--- a/ppstructure/kie/README.md
+++ b/ppstructure/kie/README.md
@@ -209,17 +209,18 @@ python3 ./tools/infer_kie_token_ser_re.py \
#### 4.2.3 Inference using PaddleInference
-At present, only SER model supports inference using PaddleInference.
-
Firstly, download the inference SER inference model.
-
```bash
mkdir inference
cd inference
wget https://paddleocr.bj.bcebos.com/ppstructure/models/vi_layoutxlm/ser_vi_layoutxlm_xfund_infer.tar && tar -xf ser_vi_layoutxlm_xfund_infer.tar
+wget https://paddleocr.bj.bcebos.com/ppstructure/models/vi_layoutxlm/re_vi_layoutxlm_xfund_infer.tar && tar -xf re_vi_layoutxlm_xfund_infer.tar
+cd ..
```
+-SER
+
Use the following command for inference.
@@ -236,6 +237,26 @@ python3 kie/predict_kie_token_ser.py \
The visual results and text file will be saved in directory `output`.
+-RE
+
+Use the following command for inference.
+
+
+```bash
+cd ppstructure
+python3 kie/predict_kie_token_ser_re.py \
+ --kie_algorithm=LayoutXLM \
+ --re_model_dir=../inference/re_vi_layoutxlm_xfund_infer \
+ --ser_model_dir=../inference/ser_vi_layoutxlm_xfund_infer \
+ --use_visual_backbone=False \
+ --image_dir=./docs/kie/input/zh_val_42.jpg \
+ --ser_dict_path=../train_data/XFUND/class_list_xfun.txt \
+ --vis_font_path=../doc/fonts/simfang.ttf \
+ --ocr_order_method="tb-yx"
+```
+
+The visual results and text file will be saved in directory `output`.
+
### 4.3 More
diff --git a/ppstructure/kie/README_ch.md b/ppstructure/kie/README_ch.md
index cc8c60009f4cb83d349c45573a9fa03832665374..15de1507fa16d9cf696a485ae89a5c95188b2f92 100644
--- a/ppstructure/kie/README_ch.md
+++ b/ppstructure/kie/README_ch.md
@@ -193,17 +193,18 @@ python3 ./tools/infer_kie_token_ser_re.py \
#### 4.2.3 基于PaddleInference的预测
-目前仅SER模型支持PaddleInference推理。
-
-首先下载SER的推理模型。
-
+首先下载SER和RE的推理模型。
```bash
mkdir inference
cd inference
wget https://paddleocr.bj.bcebos.com/ppstructure/models/vi_layoutxlm/ser_vi_layoutxlm_xfund_infer.tar && tar -xf ser_vi_layoutxlm_xfund_infer.tar
+wget https://paddleocr.bj.bcebos.com/ppstructure/models/vi_layoutxlm/re_vi_layoutxlm_xfund_infer.tar && tar -xf re_vi_layoutxlm_xfund_infer.tar
+cd ..
```
+-SER
+
执行下面的命令进行预测。
```bash
@@ -219,6 +220,26 @@ python3 kie/predict_kie_token_ser.py \
可视化结果保存在`output`目录下。
+-RE
+
+执行下面的命令进行预测。
+
+```bash
+cd ppstructure
+python3 kie/predict_kie_token_ser_re.py \
+ --kie_algorithm=LayoutXLM \
+ --re_model_dir=../inference/re_vi_layoutxlm_xfund_infer \
+ --ser_model_dir=../inference/ser_vi_layoutxlm_xfund_infer \
+ --use_visual_backbone=False \
+ --image_dir=./docs/kie/input/zh_val_42.jpg \
+ --ser_dict_path=../train_data/XFUND/class_list_xfun.txt \
+ --vis_font_path=../doc/fonts/simfang.ttf \
+ --ocr_order_method="tb-yx"
+```
+
+可视化结果保存在`output`目录下。
+
+
### 4.3 更多
关于KIE模型的训练评估与推理,请参考:[关键信息抽取教程](../../doc/doc_ch/kie.md)。
diff --git a/ppstructure/kie/predict_kie_token_ser.py b/ppstructure/kie/predict_kie_token_ser.py
index 48cfc528a28e0a2bdfb51d3a537f26e891ae3286..e570979bcb419edbc2e09e190ae36ec1458c1826 100644
--- a/ppstructure/kie/predict_kie_token_ser.py
+++ b/ppstructure/kie/predict_kie_token_ser.py
@@ -102,16 +102,18 @@ class SerPredictor(object):
ori_im = img.copy()
data = {'image': img}
data = transform(data, self.preprocess_op)
- img = data[0]
- if img is None:
+ if data[0] is None:
return None, 0
- img = np.expand_dims(img, axis=0)
- img = img.copy()
starttime = time.time()
+ for idx in range(len(data)):
+ if isinstance(data[idx], np.ndarray):
+ data[idx] = np.expand_dims(data[idx], axis=0)
+ else:
+ data[idx] = [data[idx]]
+
for idx in range(len(self.input_tensor)):
- expand_input = np.expand_dims(data[idx], axis=0)
- self.input_tensor[idx].copy_from_cpu(expand_input)
+ self.input_tensor[idx].copy_from_cpu(data[idx])
self.predictor.run()
@@ -122,9 +124,9 @@ class SerPredictor(object):
preds = outputs[0]
post_result = self.postprocess_op(
- preds, segment_offset_ids=[data[6]], ocr_infos=[data[7]])
+ preds, segment_offset_ids=data[6], ocr_infos=data[7])
elapse = time.time() - starttime
- return post_result, elapse
+ return post_result, data, elapse
def main(args):
@@ -145,7 +147,7 @@ def main(args):
if img is None:
logger.info("error in loading image:{}".format(image_file))
continue
- ser_res, elapse = ser_predictor(img)
+ ser_res, _, elapse = ser_predictor(img)
ser_res = ser_res[0]
res_str = '{}\t{}\n'.format(
diff --git a/ppstructure/kie/predict_kie_token_ser_re.py b/ppstructure/kie/predict_kie_token_ser_re.py
new file mode 100644
index 0000000000000000000000000000000000000000..b4eace4b5ee15ccf64a03e96dafcb1cfb021e656
--- /dev/null
+++ b/ppstructure/kie/predict_kie_token_ser_re.py
@@ -0,0 +1,124 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import os
+import sys
+
+__dir__ = os.path.dirname(os.path.abspath(__file__))
+sys.path.append(__dir__)
+sys.path.insert(0, os.path.abspath(os.path.join(__dir__, '../..')))
+
+os.environ["FLAGS_allocator_strategy"] = 'auto_growth'
+
+import cv2
+import json
+import numpy as np
+import time
+
+import tools.infer.utility as utility
+from tools.infer_kie_token_ser_re import make_input
+from ppocr.postprocess import build_post_process
+from ppocr.utils.logging import get_logger
+from ppocr.utils.visual import draw_re_results
+from ppocr.utils.utility import get_image_file_list, check_and_read
+from ppstructure.utility import parse_args
+from ppstructure.kie.predict_kie_token_ser import SerPredictor
+
+from paddleocr import PaddleOCR
+
+logger = get_logger()
+
+
+class SerRePredictor(object):
+ def __init__(self, args):
+ self.use_visual_backbone = args.use_visual_backbone
+ self.ser_engine = SerPredictor(args)
+
+ postprocess_params = {'name': 'VQAReTokenLayoutLMPostProcess'}
+ self.postprocess_op = build_post_process(postprocess_params)
+ self.predictor, self.input_tensor, self.output_tensors, self.config = \
+ utility.create_predictor(args, 're', logger)
+
+ def __call__(self, img):
+ ori_im = img.copy()
+ starttime = time.time()
+ ser_results, ser_inputs, _ = self.ser_engine(img)
+ re_input, entity_idx_dict_batch = make_input(ser_inputs, ser_results)
+ if self.use_visual_backbone == False:
+ re_input.pop(4)
+ for idx in range(len(self.input_tensor)):
+ self.input_tensor[idx].copy_from_cpu(re_input[idx])
+
+ self.predictor.run()
+ outputs = []
+ for output_tensor in self.output_tensors:
+ output = output_tensor.copy_to_cpu()
+ outputs.append(output)
+ preds = dict(loss=outputs[0], pred_relations=outputs[1])
+
+ post_result = self.postprocess_op(
+ preds,
+ ser_results=ser_results,
+ entity_idx_dict_batch=entity_idx_dict_batch)
+
+ elapse = time.time() - starttime
+ return post_result, elapse
+
+
+def main(args):
+ image_file_list = get_image_file_list(args.image_dir)
+ ser_predictor = SerRePredictor(args)
+ count = 0
+ total_time = 0
+
+ os.makedirs(args.output, exist_ok=True)
+ with open(
+ os.path.join(args.output, 'infer.txt'), mode='w',
+ encoding='utf-8') as f_w:
+ for image_file in image_file_list:
+ img, flag, _ = check_and_read(image_file)
+ if not flag:
+ img = cv2.imread(image_file)
+ img = img[:, :, ::-1]
+ if img is None:
+ logger.info("error in loading image:{}".format(image_file))
+ continue
+ re_res, elapse = ser_predictor(img)
+ re_res = re_res[0]
+
+ res_str = '{}\t{}\n'.format(
+ image_file,
+ json.dumps(
+ {
+ "ocr_info": re_res,
+ }, ensure_ascii=False))
+ f_w.write(res_str)
+
+ img_res = draw_re_results(
+ image_file, re_res, font_path=args.vis_font_path)
+
+ img_save_path = os.path.join(
+ args.output,
+ os.path.splitext(os.path.basename(image_file))[0] +
+ "_ser_re.jpg")
+
+ cv2.imwrite(img_save_path, img_res)
+ logger.info("save vis result to {}".format(img_save_path))
+ if count > 0:
+ total_time += elapse
+ count += 1
+ logger.info("Predict time of {}: {}".format(image_file, elapse))
+
+
+if __name__ == "__main__":
+ main(parse_args())
diff --git a/ppstructure/utility.py b/ppstructure/utility.py
index 97b6d6fec0d70fe3014b0b2105dbbef6a292e4d7..9f1a46705fc129e089c4cdcb5cdd79c784b56fce 100644
--- a/ppstructure/utility.py
+++ b/ppstructure/utility.py
@@ -52,6 +52,8 @@ def init_args():
# params for kie
parser.add_argument("--kie_algorithm", type=str, default='LayoutXLM')
parser.add_argument("--ser_model_dir", type=str)
+ parser.add_argument("--re_model_dir", type=str)
+ parser.add_argument("--use_visual_backbone", type=str2bool, default=True)
parser.add_argument(
"--ser_dict_path",
type=str,
diff --git a/tools/export_model.py b/tools/export_model.py
index 193988cc1b62a6c4536a8d2ec640e3e5fc81a79c..8610df83ef08926c245872e711cd1c828eb46765 100755
--- a/tools/export_model.py
+++ b/tools/export_model.py
@@ -115,16 +115,12 @@ def export_single_model(model,
max_text_length = arch_config["Head"]["max_text_length"]
other_shape = [
paddle.static.InputSpec(
- shape=[None, 3, 48, 160], dtype="float32"),
-
- [
- paddle.static.InputSpec(
- shape=[None, ],
- dtype="float32"),
- paddle.static.InputSpec(
- shape=[None, max_text_length],
- dtype="int64")
- ]
+ shape=[None, 3, 48, 160], dtype="float32"), [
+ paddle.static.InputSpec(
+ shape=[None, ], dtype="float32"),
+ paddle.static.InputSpec(
+ shape=[None, max_text_length], dtype="int64")
+ ]
]
model = to_static(model, input_spec=other_shape)
elif arch_config["algorithm"] in ["LayoutLM", "LayoutLMv2", "LayoutXLM"]:
@@ -140,6 +136,13 @@ def export_single_model(model,
paddle.static.InputSpec(
shape=[None, 3, 224, 224], dtype="int64"), # image
]
+ if 'Re' in arch_config['Backbone']['name']:
+ input_spec.extend([
+ paddle.static.InputSpec(
+ shape=[None, 512, 3], dtype="int64"), # entities
+ paddle.static.InputSpec(
+ shape=[None, None, 2], dtype="int64"), # relations
+ ])
if model.backbone.use_visual_backbone is False:
input_spec.pop(4)
model = to_static(model, input_spec=[input_spec])
diff --git a/tools/infer/utility.py b/tools/infer/utility.py
index dafbfbeaf3e25fdd402027190c92ab45cbe352b4..b9c9490bdb99f3bee67cb9460a9975b93b0d6366 100644
--- a/tools/infer/utility.py
+++ b/tools/infer/utility.py
@@ -162,6 +162,8 @@ def create_predictor(args, mode, logger):
model_dir = args.table_model_dir
elif mode == 'ser':
model_dir = args.ser_model_dir
+ elif mode == 're':
+ model_dir = args.re_model_dir
elif mode == "sr":
model_dir = args.sr_model_dir
elif mode == 'layout':
@@ -227,7 +229,8 @@ def create_predictor(args, mode, logger):
use_calib_mode=False)
# collect shape
- trt_shape_f = os.path.join(model_dir, f"{mode}_trt_dynamic_shape.txt")
+ trt_shape_f = os.path.join(model_dir,
+ f"{mode}_trt_dynamic_shape.txt")
if not os.path.exists(trt_shape_f):
config.collect_shape_range_info(trt_shape_f)
@@ -262,6 +265,8 @@ def create_predictor(args, mode, logger):
config.disable_glog_info()
config.delete_pass("conv_transpose_eltwiseadd_bn_fuse_pass")
config.delete_pass("matmul_transpose_reshape_fuse_pass")
+ if mode == 're':
+ config.delete_pass("simplify_with_basic_ops_pass")
if mode == 'table':
config.delete_pass("fc_fuse_pass") # not supported for table
config.switch_use_feed_fetch_ops(False)
diff --git a/tools/infer_kie_token_ser_re.py b/tools/infer_kie_token_ser_re.py
index 3ee696f28470a16205be628b3aeb586ef7a9c6a6..c4fa2c927ab93cfa9082e51f08f8d6e1c35fe29e 100755
--- a/tools/infer_kie_token_ser_re.py
+++ b/tools/infer_kie_token_ser_re.py
@@ -63,7 +63,7 @@ class ReArgsParser(ArgsParser):
def make_input(ser_inputs, ser_results):
entities_labels = {'HEADER': 0, 'QUESTION': 1, 'ANSWER': 2}
-
+ batch_size, max_seq_len = ser_inputs[0].shape[:2]
entities = ser_inputs[8][0]
ser_results = ser_results[0]
assert len(entities) == len(ser_results)
@@ -80,34 +80,44 @@ def make_input(ser_inputs, ser_results):
start.append(entity['start'])
end.append(entity['end'])
label.append(entities_labels[res['pred']])
- entities = dict(start=start, end=end, label=label)
+
+ entities = np.full([max_seq_len + 1, 3], fill_value=-1)
+ entities[0, 0] = len(start)
+ entities[1:len(start) + 1, 0] = start
+ entities[0, 1] = len(end)
+ entities[1:len(end) + 1, 1] = end
+ entities[0, 2] = len(label)
+ entities[1:len(label) + 1, 2] = label
# relations
head = []
tail = []
- for i in range(len(entities["label"])):
- for j in range(len(entities["label"])):
- if entities["label"][i] == 1 and entities["label"][j] == 2:
+ for i in range(len(label)):
+ for j in range(len(label)):
+ if label[i] == 1 and label[j] == 2:
head.append(i)
tail.append(j)
- relations = dict(head=head, tail=tail)
+ relations = np.full([len(head) + 1, 2], fill_value=-1)
+ relations[0, 0] = len(head)
+ relations[1:len(head) + 1, 0] = head
+ relations[0, 1] = len(tail)
+ relations[1:len(tail) + 1, 1] = tail
+
+ entities = np.expand_dims(entities, axis=0)
+ entities = np.repeat(entities, batch_size, axis=0)
+ relations = np.expand_dims(relations, axis=0)
+ relations = np.repeat(relations, batch_size, axis=0)
+
+ # remove ocr_info segment_offset_id and label in ser input
+ if isinstance(ser_inputs[0], paddle.Tensor):
+ entities = paddle.to_tensor(entities)
+ relations = paddle.to_tensor(relations)
+ ser_inputs = ser_inputs[:5] + [entities, relations]
- batch_size = ser_inputs[0].shape[0]
- entities_batch = []
- relations_batch = []
entity_idx_dict_batch = []
for b in range(batch_size):
- entities_batch.append(entities)
- relations_batch.append(relations)
entity_idx_dict_batch.append(entity_idx_dict)
-
- ser_inputs[8] = entities_batch
- ser_inputs.append(relations_batch)
- # remove ocr_info segment_offset_id and label in ser input
- ser_inputs.pop(7)
- ser_inputs.pop(6)
- ser_inputs.pop(5)
return ser_inputs, entity_idx_dict_batch
@@ -136,6 +146,8 @@ class SerRePredictor(object):
def __call__(self, data):
ser_results, ser_inputs = self.ser_engine(data)
re_input, entity_idx_dict_batch = make_input(ser_inputs, ser_results)
+ if self.model.backbone.use_visual_backbone is False:
+ re_input.pop(4)
preds = self.model(re_input)
post_result = self.post_process_class(
preds,