diff --git a/PPOCRLabel/README_ch.md b/PPOCRLabel/README_ch.md
index d745c9734636e2637f5019a5dd5097dd88cf2e56..e50e7b7daf5ac203129cf80b5a926a988929af5c 100644
--- a/PPOCRLabel/README_ch.md
+++ b/PPOCRLabel/README_ch.md
@@ -71,6 +71,8 @@ pip3 install opencv-contrib-python-headless==4.2.0.32 # 如果下载过慢请添
PPOCRLabel --lang ch # 启动
```
+> 如果上述安装出现问题,可以参考 3.6 节“错误提示”进行排查
+
#### 1.2.2 本地构建whl包并安装
```bash
diff --git a/doc/doc_ch/code_and_doc.md b/doc/doc_ch/code_and_doc.md
index b1d8b4b36bd45fc1574b5049ce9af808a00b7574..7a4c64efaff22e99b6d95151ec3675c50a5a0910 100644
--- a/doc/doc_ch/code_and_doc.md
+++ b/doc/doc_ch/code_and_doc.md
@@ -139,7 +139,7 @@ PaddleOCR欢迎大家向repo中积极贡献代码,下面给出一些贡献代
- 在PaddleOCR的 [GitHub首页](https://github.com/PaddlePaddle/PaddleOCR),点击左上角 `Fork` 按钮,在你的个人目录下创建 `远程仓库`,比如`https://github.com/{your_name}/PaddleOCR`。
-![banner](/Users/zhulingfeng01/OCR/PaddleOCR/doc/banner.png)
+![banner](../banner.png)
- 将 `远程仓库` Clone到本地
@@ -230,7 +230,7 @@ pre-commit
重复上述步骤,直到pre-commit格式检查不报错。如下所示。
-[![img](https://github.com/PaddlePaddle/PaddleClas/raw/release/2.3/docs/images/quick_start/community/003_precommit_pass.png)](https://github.com/PaddlePaddle/PaddleClas/blob/release/2.3/docs/images/quick_start/community/003_precommit_pass.png)
+![img](../precommit_pass.png)
使用下面的命令完成提交。
@@ -258,7 +258,7 @@ git push origin new_branch
点击new pull request,选择本地分支和目标分支,如下图所示。在PR的描述说明中,填写该PR所完成的功能。接下来等待review,如果有需要修改的地方,参照上述步骤更新 origin 中的对应分支即可。
-![banner](/Users/zhulingfeng01/OCR/PaddleOCR/doc/pr.png)
+![banner](../pr.png)
#### 3.2.8 签署CLA协议和通过单元测试
diff --git a/doc/doc_ch/inference.md b/doc/doc_ch/inference.md
index c964d23117d022531d1181455a7b1c6c1d08ccae..c02da14af495cd807668dca6d7f3823d1de6820d 100755
--- a/doc/doc_ch/inference.md
+++ b/doc/doc_ch/inference.md
@@ -34,7 +34,7 @@ inference 模型(`paddle.jit.save`保存的模型)
- [1. 超轻量中文OCR模型推理](#超轻量中文OCR模型推理)
- [2. 其他模型推理](#其他模型推理)
-- [六、参数解释](参数解释)
+- [六、参数解释](#参数解释)
@@ -504,7 +504,7 @@ PSE算法相关参数如下
| e2e_model_dir | str | 无,如果使用端到端模型,该项是必填项 | 端到端模型inference模型路径 |
| e2e_limit_side_len | int | 768 | 端到端的输入图像边长限制 |
| e2e_limit_type | str | "max" | 端到端的边长限制类型,目前支持`min`, `max`,`min`表示保证图像最短边不小于`e2e_limit_side_len`,`max`表示保证图像最长边不大于`e2e_limit_side_len` |
-| e2e_pgnet_score_thresh | float | xx | xx |
+| e2e_pgnet_score_thresh | float | 0.5 | 端到端得分阈值,小于该阈值的结果会被丢弃 |
| e2e_char_dict_path | str | "./ppocr/utils/ic15_dict.txt" | 识别的字典文件路径 |
| e2e_pgnet_valid_set | str | "totaltext" | 验证集名称,目前支持`totaltext`, `partvgg`,不同数据集对应的后处理方式不同,与训练过程保持一致即可 |
| e2e_pgnet_mode | str | "fast" | PGNet的检测结果得分计算方法,支持`fast`和`slow`,`fast`是根据polygon的外接矩形边框内的所有像素计算平均得分,`slow`是根据原始polygon内的所有像素计算平均得分,计算速度相对较慢一些,但是更加准确一些。 |
diff --git a/doc/doc_ch/models_list.md b/doc/doc_ch/models_list.md
index 8f1a53bccacde8e478e67c7eae5df3c818bb4004..6843ffdc19d5bde205124c30f1d0a5fc2144ce99 100644
--- a/doc/doc_ch/models_list.md
+++ b/doc/doc_ch/models_list.md
@@ -1,4 +1,4 @@
-# OCR模型列表(V2.1,2021年9月6日更新)
+# PP-OCR系列模型列表(V2.1,2021年9月6日更新)
> **说明**
> 1. 2.1版模型相比2.0版模型,2.1的模型在模型精度上做了提升
diff --git a/doc/doc_ch/pgnet.md b/doc/doc_ch/pgnet.md
index 9aa7f255e54ce8dec3a20d475cccb71847d95cc7..0aee58ec1aca24d06305c47569fdf156df6ee874 100644
--- a/doc/doc_ch/pgnet.md
+++ b/doc/doc_ch/pgnet.md
@@ -66,13 +66,13 @@ wget https://paddleocr.bj.bcebos.com/dygraph_v2.0/pgnet/e2e_server_pgnetA_infer.
### 单张图像或者图像集合预测
```bash
# 预测image_dir指定的单张图像
-python3 tools/infer/predict_e2e.py --e2e_algorithm="PGNet" --image_dir="./doc/imgs_en/img623.jpg" --e2e_model_dir="./inference/e2e_server_pgnetA_infer/" --e2e_pgnet_polygon=True
+python3 tools/infer/predict_e2e.py --e2e_algorithm="PGNet" --image_dir="./doc/imgs_en/img623.jpg" --e2e_model_dir="./inference/e2e_server_pgnetA_infer/" --e2e_pgnet_valid_set="totaltext"
# 预测image_dir指定的图像集合
-python3 tools/infer/predict_e2e.py --e2e_algorithm="PGNet" --image_dir="./doc/imgs_en/" --e2e_model_dir="./inference/e2e_server_pgnetA_infer/" --e2e_pgnet_polygon=True
+python3 tools/infer/predict_e2e.py --e2e_algorithm="PGNet" --image_dir="./doc/imgs_en/" --e2e_model_dir="./inference/e2e_server_pgnetA_infer/" --e2e_pgnet_valid_set="totaltext"
# 如果想使用CPU进行预测,需设置use_gpu参数为False
-python3 tools/infer/predict_e2e.py --e2e_algorithm="PGNet" --image_dir="./doc/imgs_en/img623.jpg" --e2e_model_dir="./inference/e2e_server_pgnetA_infer/" --e2e_pgnet_polygon=True --use_gpu=False
+python3 tools/infer/predict_e2e.py --e2e_algorithm="PGNet" --image_dir="./doc/imgs_en/img623.jpg" --e2e_model_dir="./inference/e2e_server_pgnetA_infer/" --e2e_pgnet_valid_set="totaltext" --use_gpu=False
```
### 可视化结果
可视化文本检测结果默认保存到./inference_results文件夹里面,结果文件的名称前缀为'e2e_res'。结果示例如下:
@@ -167,9 +167,9 @@ python3 tools/infer_e2e.py -c configs/e2e/e2e_r50_vd_pg.yml -o Global.infer_img=
wget https://paddleocr.bj.bcebos.com/dygraph_v2.0/pgnet/en_server_pgnetA.tar && tar xf en_server_pgnetA.tar
python3 tools/export_model.py -c configs/e2e/e2e_r50_vd_pg.yml -o Global.pretrained_model=./en_server_pgnetA/best_accuracy Global.load_static_weights=False Global.save_inference_dir=./inference/e2e
```
-**PGNet端到端模型推理,需要设置参数`--e2e_algorithm="PGNet"`**,可以执行如下命令:
+**PGNet端到端模型推理,需要设置参数`--e2e_algorithm="PGNet"` 和 `--e2e_pgnet_valid_set="partvgg"`**,可以执行如下命令:
```
-python3 tools/infer/predict_e2e.py --e2e_algorithm="PGNet" --image_dir="./doc/imgs_en/img_10.jpg" --e2e_model_dir="./inference/e2e/" --e2e_pgnet_polygon=False
+python3 tools/infer/predict_e2e.py --e2e_algorithm="PGNet" --image_dir="./doc/imgs_en/img_10.jpg" --e2e_model_dir="./inference/e2e/" --e2e_pgnet_valid_set="partvgg"
```
可视化文本检测结果默认保存到`./inference_results`文件夹里面,结果文件的名称前缀为'e2e_res'。结果示例如下:
@@ -178,9 +178,9 @@ python3 tools/infer/predict_e2e.py --e2e_algorithm="PGNet" --image_dir="./doc/im
#### (2). 弯曲文本检测模型(Total-Text)
对于弯曲文本样例
-**PGNet端到端模型推理,需要设置参数`--e2e_algorithm="PGNet"`,同时,还需要增加参数`--e2e_pgnet_polygon=True`,**可以执行如下命令:
+**PGNet端到端模型推理,需要设置参数`--e2e_algorithm="PGNet"`,同时,还需要增加参数`--e2e_pgnet_valid_set="totaltext"`,**可以执行如下命令:
```
-python3 tools/infer/predict_e2e.py --e2e_algorithm="PGNet" --image_dir="./doc/imgs_en/img623.jpg" --e2e_model_dir="./inference/e2e/" --e2e_pgnet_polygon=True
+python3 tools/infer/predict_e2e.py --e2e_algorithm="PGNet" --image_dir="./doc/imgs_en/img623.jpg" --e2e_model_dir="./inference/e2e/" --e2e_pgnet_valid_set="totaltext"
```
可视化文本端到端结果默认保存到`./inference_results`文件夹里面,结果文件的名称前缀为'e2e_res'。结果示例如下:
diff --git a/doc/doc_ch/thirdparty.md b/doc/doc_ch/thirdparty.md
index d01f4b09c01d2c090c829bbb9c58c43557566118..b83b8fee8dbbf867d95c4cd0e087ebfde5f4bfc1 100644
--- a/doc/doc_ch/thirdparty.md
+++ b/doc/doc_ch/thirdparty.md
@@ -12,30 +12,37 @@ PaddleOCR希望可以通过AI的力量助力任何一位有梦想的开发者实
## 1. 社区贡献
-### 1.1 为PaddleOCR新增功能
+### 1.1 基于PaddleOCR的社区贡献
+
+- 【最新】 [FastOCRLabel](https://gitee.com/BaoJianQiang/FastOCRLabel):完整的C#版本标注工具 (@ [包建强](https://gitee.com/BaoJianQiang) )
+
+#### 1.1.1 通用工具
+
+- [DangoOCR离线版](https://github.com/PantsuDango/DangoOCR):通用型桌面级即时翻译工具 (@ [PantsuDango](https://github.com/PantsuDango))
+- [scr2txt](https://github.com/lstwzd/scr2txt):截屏转文字工具 (@ [lstwzd](https://github.com/lstwzd))
+- [AI Studio项目](https://aistudio.baidu.com/aistudio/projectdetail/1054614?channelType=0&channel=0):英文视频自动生成字幕( @ [叶月水狐](https://aistudio.baidu.com/aistudio/personalcenter/thirdview/322052))
+
+#### 1.1.2 垂类场景工具
+
+- [id_card_ocr](https://github.com/baseli/id_card_ocr):身份证复印件识别(@ [baseli](https://github.com/baseli))
+- [Paddle_Table_Image_Reader](https://github.com/thunder95/Paddle_Table_Image_Reader):能看懂表格图片的数据助手(@ [thunder95](https://github.com/thunder95))
+
+#### 1.1.3 前后处理
+
+- [paddleOCRCorrectOutputs](https://github.com/yuranusduke/paddleOCRCorrectOutputs):获取OCR识别结果的key-value(@ [yuranusduke](https://github.com/yuranusduke))
+
+### 1.2 为PaddleOCR新增功能
- 非常感谢 [authorfu](https://github.com/authorfu) 贡献Android([#340](https://github.com/PaddlePaddle/PaddleOCR/pull/340))和[xiadeye](https://github.com/xiadeye) 贡献IOS的demo代码([#325](https://github.com/PaddlePaddle/PaddleOCR/pull/325))
- 非常感谢 [tangmq](https://gitee.com/tangmq) 给PaddleOCR增加Docker化部署服务,支持快速发布可调用的Restful API服务([#507](https://github.com/PaddlePaddle/PaddleOCR/pull/507))。
- 非常感谢 [lijinhan](https://github.com/lijinhan) 给PaddleOCR增加java SpringBoot 调用OCR Hubserving接口完成对OCR服务化部署的使用([#1027](https://github.com/PaddlePaddle/PaddleOCR/pull/1027))。
- 非常感谢 [Evezerest](https://github.com/Evezerest), [ninetailskim](https://github.com/ninetailskim), [edencfc](https://github.com/edencfc), [BeyondYourself](https://github.com/BeyondYourself), [1084667371](https://github.com/1084667371) 贡献了[PPOCRLabel](https://github.com/PaddlePaddle/PaddleOCR/blob/release/2.3/PPOCRLabel/README_ch.md) 的完整代码。
-### 1.2 基于PaddleOCR的社区贡献
-
-- 【最新】完整的C#版本标注工具 [FastOCRLabel](https://gitee.com/BaoJianQiang/FastOCRLabel) (@ [包建强](https://gitee.com/BaoJianQiang) )
-- 通用型桌面级即时翻译工具 [DangoOCR离线版](https://github.com/PantsuDango/DangoOCR) (@ [PantsuDango](https://github.com/PantsuDango))
-- 获取OCR识别结果的key-value [paddleOCRCorrectOutputs](https://github.com/yuranusduke/paddleOCRCorrectOutputs) (@ [yuranusduke](https://github.com/yuranusduke))
-- 截屏转文字工具 [scr2txt](https://github.com/lstwzd/scr2txt) (@ [lstwzd](https://github.com/lstwzd))
-- 身份证复印件识别 [id_card_ocr](https://github.com/baseli/id_card_ocr)(@ [baseli](https://github.com/baseli))
-- 能看懂表格图片的数据助手:[Paddle_Table_Image_Reader](https://github.com/thunder95/Paddle_Table_Image_Reader) (@ [thunder95][https://github.com/thunder95])
-- 英文视频自动生成字幕 [AI Studio项目](https://aistudio.baidu.com/aistudio/projectdetail/1054614?channelType=0&channel=0)( @ [叶月水狐](https://aistudio.baidu.com/aistudio/personalcenter/thirdview/322052))
-
### 1.3 代码与文档优化
-
- 非常感谢 [zhangxin](https://github.com/ZhangXinNan)([Blog](https://blog.csdn.net/sdlypyzq)) 贡献新的可视化方式、添加.gitgnore、处理手动设置PYTHONPATH环境变量的问题([#210](https://github.com/PaddlePaddle/PaddleOCR/pull/210))。
- 非常感谢 [lyl120117](https://github.com/lyl120117) 贡献打印网络结构的代码([#304](https://github.com/PaddlePaddle/PaddleOCR/pull/304))。
- 非常感谢 [BeyondYourself](https://github.com/BeyondYourself) 给PaddleOCR提了很多非常棒的建议,并简化了PaddleOCR的部分代码风格([so many commits)](https://github.com/PaddlePaddle/PaddleOCR/commits?author=BeyondYourself)。
-
- 非常感谢 [Khanh Tran](https://github.com/xxxpsyduck) 和 [Karl Horky](https://github.com/karlhorky) 贡献修改英文文档。
### 1.4 多语言语料
diff --git a/doc/doc_en/pgnet_en.md b/doc/doc_en/pgnet_en.md
index d2c6b30248ebad920c41ca53ee38cce828dddb8c..e176a1260c734974e2dad843faeb3e5532176629 100644
--- a/doc/doc_en/pgnet_en.md
+++ b/doc/doc_en/pgnet_en.md
@@ -59,13 +59,13 @@ After decompression, there should be the following file structure:
### Single image or image set prediction
```bash
# Prediction single image specified by image_dir
-python3 tools/infer/predict_e2e.py --e2e_algorithm="PGNet" --image_dir="./doc/imgs_en/img623.jpg" --e2e_model_dir="./inference/e2e_server_pgnetA_infer/" --e2e_pgnet_polygon=True
+python3 tools/infer/predict_e2e.py --e2e_algorithm="PGNet" --image_dir="./doc/imgs_en/img623.jpg" --e2e_model_dir="./inference/e2e_server_pgnetA_infer/" --e2e_pgnet_valid_set="totaltext"
# Prediction the collection of images specified by image_dir
-python3 tools/infer/predict_e2e.py --e2e_algorithm="PGNet" --image_dir="./doc/imgs_en/" --e2e_model_dir="./inference/e2e_server_pgnetA_infer/" --e2e_pgnet_polygon=True
+python3 tools/infer/predict_e2e.py --e2e_algorithm="PGNet" --image_dir="./doc/imgs_en/" --e2e_model_dir="./inference/e2e_server_pgnetA_infer/" --e2e_pgnet_valid_set="totaltext"
# If you want to use CPU for prediction, you need to set use_gpu parameter is false
-python3 tools/infer/predict_e2e.py --e2e_algorithm="PGNet" --image_dir="./doc/imgs_en/img623.jpg" --e2e_model_dir="./inference/e2e_server_pgnetA_infer/" --e2e_pgnet_polygon=True --use_gpu=False
+python3 tools/infer/predict_e2e.py --e2e_algorithm="PGNet" --image_dir="./doc/imgs_en/img623.jpg" --e2e_model_dir="./inference/e2e_server_pgnetA_infer/" --use_gpu=False --e2e_pgnet_valid_set="totaltext"
```
### Visualization results
The visualized end-to-end results are saved to the `./inference_results` folder by default, and the name of the result file is prefixed with 'e2e_res'. Examples of results are as follows:
@@ -166,9 +166,9 @@ First, convert the model saved in the PGNet end-to-end training process into an
wget https://paddleocr.bj.bcebos.com/dygraph_v2.0/pgnet/en_server_pgnetA.tar && tar xf en_server_pgnetA.tar
python3 tools/export_model.py -c configs/e2e/e2e_r50_vd_pg.yml -o Global.pretrained_model=./en_server_pgnetA/best_accuracy Global.load_static_weights=False Global.save_inference_dir=./inference/e2e
```
-**For PGNet quadrangle end-to-end model inference, you need to set the parameter `--e2e_algorithm="PGNet"`**, run the following command:
+**For PGNet quadrangle end-to-end model inference, you need to set the parameter `--e2e_algorithm="PGNet"` and `--e2e_pgnet_valid_set="partvgg"`**, run the following command:
```
-python3 tools/infer/predict_e2e.py --e2e_algorithm="PGNet" --image_dir="./doc/imgs_en/img_10.jpg" --e2e_model_dir="./inference/e2e/" --e2e_pgnet_polygon=False
+python3 tools/infer/predict_e2e.py --e2e_algorithm="PGNet" --image_dir="./doc/imgs_en/img_10.jpg" --e2e_model_dir="./inference/e2e/" --e2e_pgnet_valid_set="partvgg"
```
The visualized text detection results are saved to the `./inference_results` folder by default, and the name of the result file is prefixed with 'e2e_res'. Examples of results are as follows:
@@ -176,9 +176,9 @@ The visualized text detection results are saved to the `./inference_results` fol
#### (2). Curved text detection model (Total-Text)
For the curved text example, we use the same model as the quadrilateral
-**For PGNet end-to-end curved text detection model inference, you need to set the parameter `--e2e_algorithm="PGNet"` and `--e2e_pgnet_polygon=True`**, run the following command:
+**For PGNet end-to-end curved text detection model inference, you need to set the parameter `--e2e_algorithm="PGNet"` and `--e2e_pgnet_valid_set="totaltext"`**, run the following command:
```
-python3 tools/infer/predict_e2e.py --e2e_algorithm="PGNet" --image_dir="./doc/imgs_en/img623.jpg" --e2e_model_dir="./inference/e2e/" --e2e_pgnet_polygon=True
+python3 tools/infer/predict_e2e.py --e2e_algorithm="PGNet" --image_dir="./doc/imgs_en/img623.jpg" --e2e_model_dir="./inference/e2e/" --e2e_pgnet_valid_set="totaltext"
```
The visualized text detection results are saved to the `./inference_results` folder by default, and the name of the result file is prefixed with 'e2e_res'. Examples of results are as follows:
diff --git a/doc/joinus.PNG b/doc/joinus.PNG
index 99964b62d0e8a5867d5eb7a29640f0414c7af3b2..a7c5e29b0cf1355092a848f14dbf90b08919a7f8 100644
Binary files a/doc/joinus.PNG and b/doc/joinus.PNG differ
diff --git a/doc/precommit_pass.png b/doc/precommit_pass.png
new file mode 100644
index 0000000000000000000000000000000000000000..067fb75ddb222ab0b9c71a46619c3fe7b239bc26
Binary files /dev/null and b/doc/precommit_pass.png differ
diff --git a/ppstructure/vqa/README.md b/ppstructure/vqa/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..8d117fdeb16e1c0e90bf6ec89924e414fc764249
--- /dev/null
+++ b/ppstructure/vqa/README.md
@@ -0,0 +1,182 @@
+# 视觉问答(VQA)
+
+VQA主要特性如下:
+
+- 集成[LayoutXLM](https://arxiv.org/pdf/2104.08836.pdf)模型以及PP-OCR预测引擎。
+- 支持基于多模态方法的语义实体识别 (Semantic Entity Recognition, SER) 以及关系抽取 (Relation Extraction, RE) 任务。基于 SER 任务,可以完成对图像中的文本识别与分类;基于 RE 任务,可以完成对图像中文本内容的关系提取(比如判断问题-答案对)
+- 支持SER任务与OCR引擎联合的端到端系统预测与评估。
+- 支持SER任务和RE任务的自定义训练
+
+
+本项目是 [LayoutXLM: Multimodal Pre-training for Multilingual Visually-rich Document Understanding](https://arxiv.org/pdf/2104.08836.pdf) 在 Paddle 2.2上的开源实现,
+包含了在 [XFUND数据集](https://github.com/doc-analysis/XFUND) 上的微调代码。
+
+## 1. 效果演示
+
+**注意:** 测试图片来源于XFUN数据集。
+
+### 1.1 SER
+
+![](images/result_ser/zh_val_0_ser.jpg)
+
+![](images/result_ser/zh_val_42_ser.jpg)
+
+其中不同颜色的框表示不同的类别,对于XFUN数据集,有`QUESTION`, `ANSWER`, `HEADER` 3种类别,在OCR检测框的左上方也标出了对应的类别和OCR识别结果。
+
+
+### 1.2 RE
+
+* Coming soon!
+
+
+
+## 2. 安装
+
+### 2.1 安装依赖
+
+- **(1) 安装PaddlePaddle**
+
+```bash
+pip3 install --upgrade pip
+
+# GPU安装
+python3 -m pip install paddlepaddle-gpu==2.2 -i https://mirror.baidu.com/pypi/simple
+
+# CPU安装
+python3 -m pip install paddlepaddle==2.2 -i https://mirror.baidu.com/pypi/simple
+
+```
+更多需求,请参照[安装文档](https://www.paddlepaddle.org.cn/install/quick)中的说明进行操作。
+
+
+### 2.2 安装PaddleOCR(包含 PP-OCR 和 VQA )
+
+- **(1)pip快速安装PaddleOCR whl包(仅预测)**
+
+```bash
+pip install "paddleocr>=2.2" # 推荐使用2.2+版本
+```
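+
+安装完成后,可以先用whl包验证PP-OCR预测引擎是否可用。下面是一个最小示例(图片路径仅为示意,请替换为本地图片):
+
+```python
+from paddleocr import PaddleOCR
+
+# VQA的SER串联预测中不使用方向分类器,这里保持一致
+ocr = PaddleOCR(use_angle_cls=False)
+# 返回结果为 [文本框, (文本, 置信度)] 的列表
+result = ocr.ocr("doc/imgs/11.jpg", cls=False)
+print(result)
+```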
+
+- **(2)下载VQA源码(预测+训练)**
+
+```bash
+# 【推荐】
+git clone https://github.com/PaddlePaddle/PaddleOCR
+
+# 如果因为网络问题无法pull成功,也可选择使用码云上的托管:
+git clone https://gitee.com/paddlepaddle/PaddleOCR
+
+# 注:码云托管代码可能无法实时同步本github项目更新,存在3~5天延时,请优先使用推荐方式。
+```
+
+- **(3)安装PaddleNLP**
+
+```bash
+# 需要使用PaddleNLP最新的代码版本进行安装
+git clone https://github.com/PaddlePaddle/PaddleNLP -b develop
+cd PaddleNLP
+pip install -e .
+```
+
+
+- **(4)安装VQA的`requirements`**
+
+```bash
+pip install -r requirements.txt
+```
+
+## 3. 使用
+
+
+### 3.1 数据和预训练模型准备
+
+处理好的XFUN中文数据集下载地址:[https://paddleocr.bj.bcebos.com/dataset/XFUND.tar](https://paddleocr.bj.bcebos.com/dataset/XFUND.tar)。
+
+
+下载并解压该数据集,解压后将数据集放置在当前目录下。
+
+```shell
+wget https://paddleocr.bj.bcebos.com/dataset/XFUND.tar
+tar -xf XFUND.tar
+```
+
+如果希望转换XFUN中其他语言的数据集,可以参考[XFUN数据转换脚本](helper/trans_xfun_data.py)。
+
+如果希望直接体验预测过程,可以下载我们提供的SER预训练模型,跳过训练过程,直接预测即可。
+
+* SER任务预训练模型下载链接:[链接](https://paddleocr.bj.bcebos.com/pplayout/PP-Layout_v1.0_ser_pretrained.tar),下载与解压命令可参考下面的示例
+* RE任务预训练模型下载链接:coming soon!
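+
+以SER任务预训练模型为例,可以参考下面的命令下载并解压(解压后的目录即为预测命令中`--model_name_or_path`所需的模型路径,目录名以实际解压结果为准):
+
+```shell
+wget https://paddleocr.bj.bcebos.com/pplayout/PP-Layout_v1.0_ser_pretrained.tar
+tar -xf PP-Layout_v1.0_ser_pretrained.tar
+```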
+
+
+### 3.2 SER任务
+
+* 启动训练
+
+```shell
+python train_ser.py \
+ --model_name_or_path "layoutxlm-base-uncased" \
+ --train_data_dir "XFUND/zh_train/image" \
+ --train_label_path "XFUND/zh_train/xfun_normalize_train.json" \
+ --eval_data_dir "XFUND/zh_val/image" \
+ --eval_label_path "XFUND/zh_val/xfun_normalize_val.json" \
+ --num_train_epochs 200 \
+ --eval_steps 10 \
+ --save_steps 500 \
+ --output_dir "./output/ser/" \
+ --learning_rate 5e-5 \
+ --warmup_steps 50 \
+ --evaluate_during_training \
+ --seed 2048
+```
+
+最终会打印出`precision`, `recall`, `f1`等指标,如下所示。
+
+```
+best metrics: {'loss': 1.066644651549203, 'precision': 0.8770182068017863, 'recall': 0.9361936193619362, 'f1': 0.9056402979780063}
+```
+
+模型和训练日志会保存在`./output/ser/`文件夹中。
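+
+结合`train_ser.py`中的保存逻辑,输出目录结构大致如下(以`--save_steps 500`为例,目录名仅为示意):
+
+```
+output/ser/
+├── best_model/        # 评估f1最优时保存的模型和tokenizer
+├── checkpoint-500/    # 每隔save_steps步保存的checkpoint
+├── test_gt.txt        # 评估时导出的标签序列
+├── test_pred.txt      # 评估时导出的预测序列
+└── train.log          # 训练日志
+```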
+
+* 使用评估集合中提供的OCR识别结果进行预测
+
+```shell
+export CUDA_VISIBLE_DEVICES=0
+python3.7 infer_ser.py \
+ --model_name_or_path "./PP-Layout_v1.0_ser_pretrained/" \
+ --output_dir "output_res/" \
+ --infer_imgs "XFUND/zh_val/image/" \
+ --ocr_json_path "XFUND/zh_val/xfun_normalize_val.json"
+```
+
+最终会在`output_res`目录下保存预测结果可视化图像以及预测结果文本文件,文件名为`infer_results.txt`。
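+
+`infer_results.txt`中每行对应一张图像,格式为`图像路径\t预测结果JSON`,每个文本框会在原有字段基础上增加`pred_id`与`pred`两个SER预测字段,示意如下(文本与坐标均为示例,且仅列出部分字段):
+
+```
+XFUND/zh_val/image/zh_val_0.jpg    {"ocr_info": [{"text": "姓名", "label": "QUESTION", "bbox": [100, 200, 180, 240], "pred_id": 1, "pred": "QUESTION"}, ...]}
+```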
+
+* 使用`OCR引擎 + SER`串联结果
+
+```shell
+export CUDA_VISIBLE_DEVICES=0
+python3.7 infer_ser_e2e.py \
+ --model_name_or_path "./output/PP-Layout_v1.0_ser_pretrained/" \
+ --max_seq_length 512 \
+ --output_dir "output_res_e2e/"
+```
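+
+`infer_ser_e2e.py`会通过`utils.build_ocr_engine`构建PP-OCR检测与识别引擎;如需替换为自己的OCR inference模型,可在上述命令中追加`--ocr_det_model_dir`与`--ocr_rec_model_dir`参数(这两个参数名是根据`utils.py`中`args.ocr_det_model_dir`/`args.ocr_rec_model_dir`的用法推断的,具体以`parse_args`中的定义为准)。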
+
+* 对`OCR引擎 + SER`预测系统进行端到端评估
+
+```shell
+export CUDA_VISIBLE_DEVICES=0
+python helper/eval_with_label_end2end.py --gt_json_path XFUND/zh_val/xfun_normalize_val.json --pred_json_path output_res/infer_results.txt
+```
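+
+该脚本会按IoU阈值(`--iou_thres`,默认0.5)匹配预测框与标注框,并打印`precision`、`recall`、`fmeasure`以及`character_acc`、`avg_edit_dist_field`、`avg_edit_dist_img`等端到端指标;`--ignore_space`、`--ignore_case`、`--ignore_ser_prediction`等参数可用于调整评估口径。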
+
+
+### 3.3 RE任务
+
+coming soon!
+
+
+## 参考链接
+
+- LayoutXLM: Multimodal Pre-training for Multilingual Visually-rich Document Understanding, https://arxiv.org/pdf/2104.08836.pdf
+- microsoft/unilm/layoutxlm, https://github.com/microsoft/unilm/tree/master/layoutxlm
+- XFUND dataset, https://github.com/doc-analysis/XFUND
diff --git a/ppstructure/vqa/helper/eval_with_label_end2end.py b/ppstructure/vqa/helper/eval_with_label_end2end.py
new file mode 100644
index 0000000000000000000000000000000000000000..c8dd3e0ad437e51e21ebc53daeec9fdf9aa76b63
--- /dev/null
+++ b/ppstructure/vqa/helper/eval_with_label_end2end.py
@@ -0,0 +1,262 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import re
+import sys
+# import Polygon
+import shapely
+from shapely.geometry import Polygon
+import numpy as np
+from collections import defaultdict
+import operator
+import editdistance
+import argparse
+import json
+import copy
+
+
+def parse_ser_results_fp(fp, fp_type="gt", ignore_background=True):
+ # img/zh_val_0.jpg {
+ # "height": 3508,
+ # "width": 2480,
+ # "ocr_info": [
+ # {"text": "Maribyrnong", "label": "other", "bbox": [1958, 144, 2184, 198]},
+ # {"text": "CITYCOUNCIL", "label": "other", "bbox": [2052, 183, 2171, 214]},
+    #     ]
+    # }
+ assert fp_type in ["gt", "pred"]
+ key = "label" if fp_type == "gt" else "pred"
+ res_dict = dict()
+ with open(fp, "r") as fin:
+ lines = fin.readlines()
+
+ for _, line in enumerate(lines):
+ img_path, info = line.strip().split("\t")
+ # get key
+ image_name = os.path.basename(img_path)
+ res_dict[image_name] = []
+ # get infos
+ json_info = json.loads(info)
+ for single_ocr_info in json_info["ocr_info"]:
+ label = single_ocr_info[key].upper()
+ if label in ["O", "OTHERS", "OTHER"]:
+ label = "O"
+ if ignore_background and label == "O":
+ continue
+ single_ocr_info["label"] = label
+ res_dict[image_name].append(copy.deepcopy(single_ocr_info))
+ return res_dict
+
+
+def polygon_from_str(polygon_points):
+ """
+ Create a shapely polygon object from gt or dt line.
+ """
+ polygon_points = np.array(polygon_points).reshape(4, 2)
+ polygon = Polygon(polygon_points).convex_hull
+ return polygon
+
+
+def polygon_iou(poly1, poly2):
+ """
+ Intersection over union between two shapely polygons.
+ """
+ if not poly1.intersects(
+ poly2): # this test is fast and can accelerate calculation
+ iou = 0
+ else:
+ try:
+ inter_area = poly1.intersection(poly2).area
+ union_area = poly1.area + poly2.area - inter_area
+ iou = float(inter_area) / union_area
+ except shapely.geos.TopologicalError:
+ # except Exception as e:
+ # print(e)
+            print('shapely.geos.TopologicalError occurred, iou set to 0')
+ iou = 0
+ return iou
+
+
+def ed(args, str1, str2):
+ if args.ignore_space:
+ str1 = str1.replace(" ", "")
+ str2 = str2.replace(" ", "")
+ if args.ignore_case:
+ str1 = str1.lower()
+ str2 = str2.lower()
+ return editdistance.eval(str1, str2)
+
+
+def convert_bbox_to_polygon(bbox):
+ """
+ bbox : [x1, y1, x2, y2]
+ output: [[x1, y1], [x2, y2], [x3, y3], [x4, y4]]
+ """
+ xmin, ymin, xmax, ymax = bbox
+ poly = [[xmin, ymin], [xmax, ymin], [xmax, ymax], [xmin, ymax]]
+ return poly
+
+
+def eval_e2e(args):
+ # gt
+ gt_results = parse_ser_results_fp(args.gt_json_path, "gt",
+ args.ignore_background)
+ # pred
+ dt_results = parse_ser_results_fp(args.pred_json_path, "pred",
+ args.ignore_background)
+ assert set(gt_results.keys()) == set(dt_results.keys())
+
+ iou_thresh = args.iou_thres
+ num_gt_chars = 0
+ gt_count = 0
+ dt_count = 0
+ hit = 0
+ ed_sum = 0
+
+ for img_name in gt_results:
+ gt_info = gt_results[img_name]
+ gt_count += len(gt_info)
+
+ dt_info = dt_results[img_name]
+ dt_count += len(dt_info)
+
+ dt_match = [False] * len(dt_info)
+ gt_match = [False] * len(gt_info)
+
+ all_ious = defaultdict(tuple)
+ # gt: {text, label, bbox or poly}
+ for index_gt, gt in enumerate(gt_info):
+ if "poly" not in gt:
+ gt["poly"] = convert_bbox_to_polygon(gt["bbox"])
+ gt_poly = polygon_from_str(gt["poly"])
+ for index_dt, dt in enumerate(dt_info):
+ if "poly" not in dt:
+ dt["poly"] = convert_bbox_to_polygon(dt["bbox"])
+ dt_poly = polygon_from_str(dt["poly"])
+ iou = polygon_iou(dt_poly, gt_poly)
+ if iou >= iou_thresh:
+ all_ious[(index_gt, index_dt)] = iou
+ sorted_ious = sorted(
+ all_ious.items(), key=operator.itemgetter(1), reverse=True)
+ sorted_gt_dt_pairs = [item[0] for item in sorted_ious]
+
+ # matched gt and dt
+ for gt_dt_pair in sorted_gt_dt_pairs:
+ index_gt, index_dt = gt_dt_pair
+ if gt_match[index_gt] == False and dt_match[index_dt] == False:
+ gt_match[index_gt] = True
+ dt_match[index_dt] = True
+ # ocr rec results
+ gt_text = gt_info[index_gt]["text"]
+ dt_text = dt_info[index_dt]["text"]
+
+ # ser results
+ gt_label = gt_info[index_gt]["label"]
+ dt_label = dt_info[index_dt]["pred"]
+
+ if True: # ignore_masks[index_gt] == '0':
+ ed_sum += ed(args, gt_text, dt_text)
+ num_gt_chars += len(gt_text)
+ if gt_text == dt_text:
+ if args.ignore_ser_prediction or gt_label == dt_label:
+ hit += 1
+
+        # unmatched dt
+        for tindex, dt_match_flag in enumerate(dt_match):
+            if dt_match_flag == False:
+                dt_text = dt_info[tindex]["text"]
+                gt_text = ""
+                ed_sum += ed(args, dt_text, gt_text)
+
+        # unmatched gt
+        for tindex, gt_match_flag in enumerate(gt_match):
+            if gt_match_flag == False:
+                dt_text = ""
+                gt_text = gt_info[tindex]["text"]
+                ed_sum += ed(args, gt_text, dt_text)
+                num_gt_chars += len(gt_text)
+
+ eps = 1e-9
+ print("config: ", args)
+ print('hit, dt_count, gt_count', hit, dt_count, gt_count)
+ precision = hit / (dt_count + eps)
+ recall = hit / (gt_count + eps)
+ fmeasure = 2.0 * precision * recall / (precision + recall + eps)
+ avg_edit_dist_img = ed_sum / len(gt_results)
+ avg_edit_dist_field = ed_sum / (gt_count + eps)
+ character_acc = 1 - ed_sum / (num_gt_chars + eps)
+
+ print('character_acc: %.2f' % (character_acc * 100) + "%")
+ print('avg_edit_dist_field: %.2f' % (avg_edit_dist_field))
+ print('avg_edit_dist_img: %.2f' % (avg_edit_dist_img))
+ print('precision: %.2f' % (precision * 100) + "%")
+ print('recall: %.2f' % (recall * 100) + "%")
+ print('fmeasure: %.2f' % (fmeasure * 100) + "%")
+
+ return
+
+
+def parse_args():
+    """
+    Parse command-line arguments for the end-to-end SER evaluation.
+    """
+
+ def str2bool(v):
+ return v.lower() in ("true", "t", "1")
+
+ parser = argparse.ArgumentParser()
+ ## Required parameters
+ parser.add_argument(
+ "--gt_json_path",
+ default=None,
+ type=str,
+ required=True, )
+ parser.add_argument(
+ "--pred_json_path",
+ default=None,
+ type=str,
+ required=True, )
+
+ parser.add_argument("--iou_thres", default=0.5, type=float)
+
+ parser.add_argument(
+ "--ignore_case",
+ default=False,
+ type=str2bool,
+ help="whether to do lower case for the strs")
+
+ parser.add_argument(
+ "--ignore_space",
+ default=True,
+ type=str2bool,
+ help="whether to ignore space")
+
+ parser.add_argument(
+ "--ignore_background",
+ default=True,
+ type=str2bool,
+ help="whether to ignore other label")
+
+ parser.add_argument(
+ "--ignore_ser_prediction",
+ default=False,
+ type=str2bool,
+ help="whether to ignore ocr pred results")
+
+ args = parser.parse_args()
+ return args
+
+
+if __name__ == '__main__':
+ args = parse_args()
+ eval_e2e(args)
diff --git a/ppstructure/vqa/helper/trans_xfun_data.py b/ppstructure/vqa/helper/trans_xfun_data.py
new file mode 100644
index 0000000000000000000000000000000000000000..b5ebd5dfbd8addda0701a7cfd2387133f7a8776b
--- /dev/null
+++ b/ppstructure/vqa/helper/trans_xfun_data.py
@@ -0,0 +1,52 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import json
+
+
+def transfer_xfun_data(json_path=None, output_file=None):
+ with open(json_path, "r") as fin:
+ lines = fin.readlines()
+
+ json_info = json.loads(lines[0])
+ documents = json_info["documents"]
+ label_info = {}
+ with open(output_file, "w") as fout:
+ for idx, document in enumerate(documents):
+ img_info = document["img"]
+ document = document["document"]
+ image_path = img_info["fname"]
+
+ label_info["height"] = img_info["height"]
+ label_info["width"] = img_info["width"]
+
+ label_info["ocr_info"] = []
+
+ for doc in document:
+ label_info["ocr_info"].append({
+ "text": doc["text"],
+ "label": doc["label"],
+ "bbox": doc["box"],
+ "id": doc["id"],
+ "linking": doc["linking"],
+ "words": doc["words"]
+ })
+
+ fout.write(image_path + "\t" + json.dumps(
+ label_info, ensure_ascii=False) + "\n")
+
+ print("===ok====")
+
+
+if __name__ == "__main__":
+    transfer_xfun_data("./xfun/zh.val.json", "./xfun_normalize_val.json")
diff --git a/ppstructure/vqa/images/input/zh_val_0.jpg b/ppstructure/vqa/images/input/zh_val_0.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..479b60bcd3a859b187ce5325dfc381c1b87ee27f
Binary files /dev/null and b/ppstructure/vqa/images/input/zh_val_0.jpg differ
diff --git a/ppstructure/vqa/images/input/zh_val_42.jpg b/ppstructure/vqa/images/input/zh_val_42.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..42151bdd94929ede9da1a63ce8d9339971094a46
Binary files /dev/null and b/ppstructure/vqa/images/input/zh_val_42.jpg differ
diff --git a/ppstructure/vqa/images/result_ser/zh_val_0_ser.jpg b/ppstructure/vqa/images/result_ser/zh_val_0_ser.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..22ba9a6f1b7652ca9ce6848093c7a39affb4886b
Binary files /dev/null and b/ppstructure/vqa/images/result_ser/zh_val_0_ser.jpg differ
diff --git a/ppstructure/vqa/images/result_ser/zh_val_42_ser.jpg b/ppstructure/vqa/images/result_ser/zh_val_42_ser.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..951864e5f35a987ff241f276c8da523d8c8eeaf3
Binary files /dev/null and b/ppstructure/vqa/images/result_ser/zh_val_42_ser.jpg differ
diff --git a/ppstructure/vqa/infer_ser.py b/ppstructure/vqa/infer_ser.py
new file mode 100644
index 0000000000000000000000000000000000000000..4ad220094a26b330555fbe9122a46fb56e64fe1e
--- /dev/null
+++ b/ppstructure/vqa/infer_ser.py
@@ -0,0 +1,279 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import sys
+import json
+import cv2
+import numpy as np
+from copy import deepcopy
+
+import paddle
+
+# relative reference
+from utils import parse_args, get_image_file_list, draw_ser_results, get_bio_label_maps
+from paddlenlp.transformers import LayoutXLMModel, LayoutXLMTokenizer, LayoutXLMForTokenClassification
+
+
+def pad_sentences(tokenizer,
+ encoded_inputs,
+ max_seq_len=512,
+ pad_to_max_seq_len=True,
+ return_attention_mask=True,
+ return_token_type_ids=True,
+ return_overflowing_tokens=False,
+ return_special_tokens_mask=False):
+    # Pad to a multiple of max_seq_len so the sequence can later be reshaped into pages (see split_page)
+ max_seq_len = (
+ len(encoded_inputs["input_ids"]) // max_seq_len + 1) * max_seq_len
+
+ needs_to_be_padded = pad_to_max_seq_len and \
+ max_seq_len and len(encoded_inputs["input_ids"]) < max_seq_len
+
+ if needs_to_be_padded:
+ difference = max_seq_len - len(encoded_inputs["input_ids"])
+ if tokenizer.padding_side == 'right':
+ if return_attention_mask:
+ encoded_inputs["attention_mask"] = [1] * len(encoded_inputs[
+ "input_ids"]) + [0] * difference
+ if return_token_type_ids:
+ encoded_inputs["token_type_ids"] = (
+ encoded_inputs["token_type_ids"] +
+ [tokenizer.pad_token_type_id] * difference)
+ if return_special_tokens_mask:
+ encoded_inputs["special_tokens_mask"] = encoded_inputs[
+ "special_tokens_mask"] + [1] * difference
+ encoded_inputs["input_ids"] = encoded_inputs[
+ "input_ids"] + [tokenizer.pad_token_id] * difference
+ encoded_inputs["bbox"] = encoded_inputs["bbox"] + [[0, 0, 0, 0]
+ ] * difference
+ else:
+            assert False, f"padding_side of tokenizer only supports [\"right\"] but got {tokenizer.padding_side}"
+ else:
+ if return_attention_mask:
+ encoded_inputs["attention_mask"] = [1] * len(encoded_inputs[
+ "input_ids"])
+
+ return encoded_inputs
+
+
+def split_page(encoded_inputs, max_seq_len=512):
+ """
+    Split an over-long sequence into multiple pages of max_seq_len tokens; truncation is typically used during training instead.
+ """
+ for key in encoded_inputs:
+ encoded_inputs[key] = paddle.to_tensor(encoded_inputs[key])
+ if encoded_inputs[key].ndim <= 1: # for input_ids, att_mask and so on
+ encoded_inputs[key] = encoded_inputs[key].reshape([-1, max_seq_len])
+ else: # for bbox
+ encoded_inputs[key] = encoded_inputs[key].reshape(
+ [-1, max_seq_len, 4])
+ return encoded_inputs
+
+
+def preprocess(
+ tokenizer,
+ ori_img,
+ ocr_info,
+ img_size=(224, 224),
+ pad_token_label_id=-100,
+ max_seq_len=512,
+ add_special_ids=False,
+ return_attention_mask=True, ):
+ ocr_info = deepcopy(ocr_info)
+ height = ori_img.shape[0]
+ width = ori_img.shape[1]
+
+ img = cv2.resize(ori_img,
+ (224, 224)).transpose([2, 0, 1]).astype(np.float32)
+
+ segment_offset_id = []
+ words_list = []
+ bbox_list = []
+ input_ids_list = []
+ token_type_ids_list = []
+
+ for info in ocr_info:
+ # x1, y1, x2, y2
+ bbox = info["bbox"]
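+        # normalize the bbox coordinates to the 0~1000 range expected by LayoutXLM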
+ bbox[0] = int(bbox[0] * 1000.0 / width)
+ bbox[2] = int(bbox[2] * 1000.0 / width)
+ bbox[1] = int(bbox[1] * 1000.0 / height)
+ bbox[3] = int(bbox[3] * 1000.0 / height)
+
+ text = info["text"]
+ encode_res = tokenizer.encode(
+ text, pad_to_max_seq_len=False, return_attention_mask=True)
+
+ if not add_special_ids:
+ # TODO: use tok.all_special_ids to remove
+ encode_res["input_ids"] = encode_res["input_ids"][1:-1]
+ encode_res["token_type_ids"] = encode_res["token_type_ids"][1:-1]
+ encode_res["attention_mask"] = encode_res["attention_mask"][1:-1]
+
+ input_ids_list.extend(encode_res["input_ids"])
+ token_type_ids_list.extend(encode_res["token_type_ids"])
+ bbox_list.extend([bbox] * len(encode_res["input_ids"]))
+ words_list.append(text)
+ segment_offset_id.append(len(input_ids_list))
+
+ encoded_inputs = {
+ "input_ids": input_ids_list,
+ "token_type_ids": token_type_ids_list,
+ "bbox": bbox_list,
+ "attention_mask": [1] * len(input_ids_list),
+ }
+
+ encoded_inputs = pad_sentences(
+ tokenizer,
+ encoded_inputs,
+ max_seq_len=max_seq_len,
+ return_attention_mask=return_attention_mask)
+
+ encoded_inputs = split_page(encoded_inputs)
+
+ fake_bs = encoded_inputs["input_ids"].shape[0]
+
+ encoded_inputs["image"] = paddle.to_tensor(img).unsqueeze(0).expand(
+ [fake_bs] + list(img.shape))
+
+ encoded_inputs["segment_offset_id"] = segment_offset_id
+
+ return encoded_inputs
+
+
+def postprocess(attention_mask, preds, label_map_path):
+ if isinstance(preds, paddle.Tensor):
+ preds = preds.numpy()
+ preds = np.argmax(preds, axis=2)
+
+ _, label_map = get_bio_label_maps(label_map_path)
+
+ preds_list = [[] for _ in range(preds.shape[0])]
+
+ # keep batch info
+ for i in range(preds.shape[0]):
+ for j in range(preds.shape[1]):
+ if attention_mask[i][j] == 1:
+ preds_list[i].append(label_map[preds[i][j]])
+
+ return preds_list
+
+
+def merge_preds_list_with_ocr_info(label_map_path, ocr_info, segment_offset_id,
+ preds_list):
+ # must ensure the preds_list is generated from the same image
+ preds = [p for pred in preds_list for p in pred]
+ label2id_map, _ = get_bio_label_maps(label_map_path)
+ for key in label2id_map:
+ if key.startswith("I-"):
+ label2id_map[key] = label2id_map["B" + key[1:]]
+
+ id2label_map = dict()
+ for key in label2id_map:
+ val = label2id_map[key]
+ if key == "O":
+ id2label_map[val] = key
+        elif key.startswith("B-") or key.startswith("I-"):
+ id2label_map[val] = key[2:]
+ else:
+ id2label_map[val] = key
+
+ for idx in range(len(segment_offset_id)):
+ if idx == 0:
+ start_id = 0
+ else:
+ start_id = segment_offset_id[idx - 1]
+
+ end_id = segment_offset_id[idx]
+
+ curr_pred = preds[start_id:end_id]
+ curr_pred = [label2id_map[p] for p in curr_pred]
+
+ if len(curr_pred) <= 0:
+ pred_id = 0
+ else:
+ counts = np.bincount(curr_pred)
+ pred_id = np.argmax(counts)
+ ocr_info[idx]["pred_id"] = int(pred_id)
+ ocr_info[idx]["pred"] = id2label_map[pred_id]
+ return ocr_info
+
+
+@paddle.no_grad()
+def infer(args):
+ os.makedirs(args.output_dir, exist_ok=True)
+
+ # init token and model
+ tokenizer = LayoutXLMTokenizer.from_pretrained(args.model_name_or_path)
+ # model = LayoutXLMModel.from_pretrained(args.model_name_or_path)
+ model = LayoutXLMForTokenClassification.from_pretrained(
+ args.model_name_or_path)
+ model.eval()
+
+ # load ocr results json
+ ocr_results = dict()
+ with open(args.ocr_json_path, "r") as fin:
+ lines = fin.readlines()
+ for line in lines:
+ img_name, json_info = line.split("\t")
+ ocr_results[os.path.basename(img_name)] = json.loads(json_info)
+
+ # get infer img list
+ infer_imgs = get_image_file_list(args.infer_imgs)
+
+ # loop for infer
+ with open(os.path.join(args.output_dir, "infer_results.txt"), "w") as fout:
+ for idx, img_path in enumerate(infer_imgs):
+            print("process: [{}/{}], {}".format(idx, len(infer_imgs), img_path))
+
+ img = cv2.imread(img_path)
+
+ ocr_info = ocr_results[os.path.basename(img_path)]["ocr_info"]
+ inputs = preprocess(
+ tokenizer=tokenizer,
+ ori_img=img,
+ ocr_info=ocr_info,
+ max_seq_len=args.max_seq_length)
+
+ outputs = model(
+ input_ids=inputs["input_ids"],
+ bbox=inputs["bbox"],
+ image=inputs["image"],
+ token_type_ids=inputs["token_type_ids"],
+ attention_mask=inputs["attention_mask"])
+
+ preds = outputs[0]
+ preds = postprocess(inputs["attention_mask"], preds,
+ args.label_map_path)
+ ocr_info = merge_preds_list_with_ocr_info(
+ args.label_map_path, ocr_info, inputs["segment_offset_id"],
+ preds)
+
+ fout.write(img_path + "\t" + json.dumps(
+ {
+ "ocr_info": ocr_info,
+ }, ensure_ascii=False) + "\n")
+
+ img_res = draw_ser_results(img, ocr_info)
+ cv2.imwrite(
+ os.path.join(args.output_dir, os.path.basename(img_path)),
+ img_res)
+
+ return
+
+
+if __name__ == "__main__":
+ args = parse_args()
+ infer(args)
diff --git a/ppstructure/vqa/infer_ser_e2e.py b/ppstructure/vqa/infer_ser_e2e.py
new file mode 100644
index 0000000000000000000000000000000000000000..da027a140bdb4fa12a40d423998d94e438a7cd11
--- /dev/null
+++ b/ppstructure/vqa/infer_ser_e2e.py
@@ -0,0 +1,121 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import sys
+import json
+import cv2
+import numpy as np
+from copy import deepcopy
+from PIL import Image
+
+import paddle
+from paddlenlp.transformers import LayoutXLMModel, LayoutXLMTokenizer, LayoutXLMForTokenClassification
+
+# relative reference
+from utils import parse_args, get_image_file_list, draw_ser_results, get_bio_label_maps, build_ocr_engine
+
+from utils import pad_sentences, split_page, preprocess, postprocess, merge_preds_list_with_ocr_info
+
+
+def trans_poly_to_bbox(poly):
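+    """Convert a polygon (list of [x, y] points) to an axis-aligned bbox [x1, y1, x2, y2]."""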
+ x1 = np.min([p[0] for p in poly])
+ x2 = np.max([p[0] for p in poly])
+ y1 = np.min([p[1] for p in poly])
+ y2 = np.max([p[1] for p in poly])
+ return [x1, y1, x2, y2]
+
+
+def parse_ocr_info_for_ser(ocr_result):
+ ocr_info = []
+ for res in ocr_result:
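+        # each item returned by PaddleOCR is [polygon, (text, confidence)]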
+ ocr_info.append({
+ "text": res[1][0],
+ "bbox": trans_poly_to_bbox(res[0]),
+ "poly": res[0],
+ })
+ return ocr_info
+
+
+@paddle.no_grad()
+def infer(args):
+ os.makedirs(args.output_dir, exist_ok=True)
+
+ # init token and model
+ tokenizer = LayoutXLMTokenizer.from_pretrained(args.model_name_or_path)
+ model = LayoutXLMForTokenClassification.from_pretrained(
+ args.model_name_or_path)
+ model.eval()
+
+ label2id_map, id2label_map = get_bio_label_maps(args.label_map_path)
+ label2id_map_for_draw = dict()
+ for key in label2id_map:
+ if key.startswith("I-"):
+ label2id_map_for_draw[key] = label2id_map["B" + key[1:]]
+ else:
+ label2id_map_for_draw[key] = label2id_map[key]
+
+ # get infer img list
+ infer_imgs = get_image_file_list(args.infer_imgs)
+
+ ocr_engine = build_ocr_engine(args.ocr_rec_model_dir,
+ args.ocr_det_model_dir)
+
+ # loop for infer
+ with open(os.path.join(args.output_dir, "infer_results.txt"), "w") as fout:
+ for idx, img_path in enumerate(infer_imgs):
+            print("process: [{}/{}], {}".format(idx, len(infer_imgs), img_path))
+
+ img = cv2.imread(img_path)
+
+ ocr_result = ocr_engine.ocr(img_path, cls=False)
+
+ ocr_info = parse_ocr_info_for_ser(ocr_result)
+
+ inputs = preprocess(
+ tokenizer=tokenizer,
+ ori_img=img,
+ ocr_info=ocr_info,
+ max_seq_len=args.max_seq_length)
+
+ outputs = model(
+ input_ids=inputs["input_ids"],
+ bbox=inputs["bbox"],
+ image=inputs["image"],
+ token_type_ids=inputs["token_type_ids"],
+ attention_mask=inputs["attention_mask"])
+
+ preds = outputs[0]
+ preds = postprocess(inputs["attention_mask"], preds, id2label_map)
+ ocr_info = merge_preds_list_with_ocr_info(
+ ocr_info, inputs["segment_offset_id"], preds,
+ label2id_map_for_draw)
+
+ fout.write(img_path + "\t" + json.dumps(
+ {
+ "ocr_info": ocr_info,
+ }, ensure_ascii=False) + "\n")
+
+ img_res = draw_ser_results(img, ocr_info)
+ cv2.imwrite(
+ os.path.join(args.output_dir,
+ os.path.splitext(os.path.basename(img_path))[0] +
+ "_ser.jpg"), img_res)
+
+ return
+
+
+if __name__ == "__main__":
+ args = parse_args()
+ infer(args)
diff --git a/ppstructure/vqa/labels/labels_ser.txt b/ppstructure/vqa/labels/labels_ser.txt
new file mode 100644
index 0000000000000000000000000000000000000000..508e48112412f62538baf0c78bcf99ec8945196e
--- /dev/null
+++ b/ppstructure/vqa/labels/labels_ser.txt
@@ -0,0 +1,3 @@
+QUESTION
+ANSWER
+HEADER
diff --git a/ppstructure/vqa/requirements.txt b/ppstructure/vqa/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..c259fadc395335b336cb0ecdb5aa6bca48631987
--- /dev/null
+++ b/ppstructure/vqa/requirements.txt
@@ -0,0 +1,2 @@
+sentencepiece
+yacs
diff --git a/ppstructure/vqa/train_ser.py b/ppstructure/vqa/train_ser.py
new file mode 100644
index 0000000000000000000000000000000000000000..90ca69d93fd22983533fcacd639bbd64dc3e11ec
--- /dev/null
+++ b/ppstructure/vqa/train_ser.py
@@ -0,0 +1,313 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import sys
+import os
+import random
+import copy
+import logging
+
+import argparse
+import paddle
+import numpy as np
+from seqeval.metrics import classification_report, f1_score, precision_score, recall_score
+from paddlenlp.transformers import LayoutXLMModel, LayoutXLMTokenizer, LayoutXLMForTokenClassification
+from xfun import XFUNDataset
+from utils import parse_args
+from utils import get_bio_label_maps
+
+logger = logging.getLogger(__name__)
+
+
+def set_seed(args):
+ random.seed(args.seed)
+ np.random.seed(args.seed)
+ paddle.seed(args.seed)
+
+
+def train(args):
+ os.makedirs(args.output_dir, exist_ok=True)
+ logging.basicConfig(
+ filename=os.path.join(args.output_dir, "train.log")
+ if paddle.distributed.get_rank() == 0 else None,
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
+ datefmt="%m/%d/%Y %H:%M:%S",
+ level=logging.INFO
+ if paddle.distributed.get_rank() == 0 else logging.WARN, )
+
+ ch = logging.StreamHandler()
+ ch.setLevel(logging.DEBUG)
+ logger.addHandler(ch)
+
+ label2id_map, id2label_map = get_bio_label_maps(args.label_map_path)
+ pad_token_label_id = paddle.nn.CrossEntropyLoss().ignore_index
+
+ # dist mode
+ if paddle.distributed.get_world_size() > 1:
+ paddle.distributed.init_parallel_env()
+
+ tokenizer = LayoutXLMTokenizer.from_pretrained(args.model_name_or_path)
+ base_model = LayoutXLMModel.from_pretrained(args.model_name_or_path)
+ model = LayoutXLMForTokenClassification(
+ base_model, num_classes=len(label2id_map), dropout=None)
+
+ # dist mode
+ if paddle.distributed.get_world_size() > 1:
+ model = paddle.DataParallel(model)
+
+ train_dataset = XFUNDataset(
+ tokenizer,
+ data_dir=args.train_data_dir,
+ label_path=args.train_label_path,
+ label2id_map=label2id_map,
+ img_size=(224, 224),
+ pad_token_label_id=pad_token_label_id,
+ contains_re=False,
+ add_special_ids=False,
+ return_attention_mask=True,
+ load_mode='all')
+
+ train_sampler = paddle.io.DistributedBatchSampler(
+ train_dataset, batch_size=args.per_gpu_train_batch_size, shuffle=True)
+
+ args.train_batch_size = args.per_gpu_train_batch_size * max(
+ 1, paddle.distributed.get_world_size())
+
+ train_dataloader = paddle.io.DataLoader(
+ train_dataset,
+ batch_sampler=train_sampler,
+ num_workers=0,
+ use_shared_memory=True,
+ collate_fn=None, )
+
+ t_total = len(train_dataloader) * args.num_train_epochs
+
+ # build linear decay with warmup lr sch
+ lr_scheduler = paddle.optimizer.lr.PolynomialDecay(
+ learning_rate=args.learning_rate,
+ decay_steps=t_total,
+ end_lr=0.0,
+ power=1.0)
+ if args.warmup_steps > 0:
+ lr_scheduler = paddle.optimizer.lr.LinearWarmup(
+ lr_scheduler,
+ args.warmup_steps,
+ start_lr=0,
+ end_lr=args.learning_rate, )
+
+ optimizer = paddle.optimizer.AdamW(
+ learning_rate=lr_scheduler,
+ parameters=model.parameters(),
+ epsilon=args.adam_epsilon,
+ weight_decay=args.weight_decay)
+
+ # Train!
+ logger.info("***** Running training *****")
+ logger.info(" Num examples = %d", len(train_dataset))
+ logger.info(" Num Epochs = %d", args.num_train_epochs)
+ logger.info(" Instantaneous batch size per GPU = %d",
+ args.per_gpu_train_batch_size)
+ logger.info(
+ " Total train batch size (w. parallel, distributed) = %d",
+ args.train_batch_size * paddle.distributed.get_world_size(), )
+ logger.info(" Total optimization steps = %d", t_total)
+
+ global_step = 0
+ tr_loss = 0.0
+ set_seed(args)
+ best_metrics = None
+
+ for epoch_id in range(args.num_train_epochs):
+ for step, batch in enumerate(train_dataloader):
+ model.train()
+ outputs = model(**batch)
+ # model outputs are always tuple in ppnlp (see doc)
+ loss = outputs[0]
+ loss = loss.mean()
+ logger.info(
+ "[epoch {}/{}][iter: {}/{}] lr: {:.5f}, train loss: {:.5f}, ".
+ format(epoch_id, args.num_train_epochs, step,
+ len(train_dataloader),
+ lr_scheduler.get_lr(), loss.numpy()[0]))
+
+ loss.backward()
+ tr_loss += loss.item()
+ optimizer.step()
+ lr_scheduler.step() # Update learning rate schedule
+ optimizer.clear_grad()
+ global_step += 1
+
+ if (paddle.distributed.get_rank() == 0 and args.eval_steps > 0 and
+ global_step % args.eval_steps == 0):
+ # Log metrics
+ # Only evaluate when single GPU otherwise metrics may not average well
+ if paddle.distributed.get_rank(
+ ) == 0 and args.evaluate_during_training:
+ results, _ = evaluate(
+ args,
+ model,
+ tokenizer,
+ label2id_map,
+ id2label_map,
+ pad_token_label_id, )
+
+ if best_metrics is None or results["f1"] >= best_metrics[
+ "f1"]:
+ best_metrics = copy.deepcopy(results)
+ output_dir = os.path.join(args.output_dir, "best_model")
+ os.makedirs(output_dir, exist_ok=True)
+ if paddle.distributed.get_rank() == 0:
+ model.save_pretrained(output_dir)
+ tokenizer.save_pretrained(output_dir)
+ paddle.save(
+ args,
+ os.path.join(output_dir, "training_args.bin"))
+ logger.info("Saving model checkpoint to %s",
+ output_dir)
+
+ logger.info("[epoch {}/{}][iter: {}/{}] results: {}".format(
+ epoch_id, args.num_train_epochs, step,
+ len(train_dataloader), results))
+ if best_metrics is not None:
+ logger.info("best metrics: {}".format(best_metrics))
+
+ if paddle.distributed.get_rank(
+ ) == 0 and args.save_steps > 0 and global_step % args.save_steps == 0:
+ # Save model checkpoint
+ output_dir = os.path.join(args.output_dir,
+ "checkpoint-{}".format(global_step))
+ os.makedirs(output_dir, exist_ok=True)
+ if paddle.distributed.get_rank() == 0:
+ model.save_pretrained(output_dir)
+ tokenizer.save_pretrained(output_dir)
+ paddle.save(args,
+ os.path.join(output_dir, "training_args.bin"))
+ logger.info("Saving model checkpoint to %s", output_dir)
+
+ return global_step, tr_loss / global_step
+
+
+def evaluate(args,
+ model,
+ tokenizer,
+ label2id_map,
+ id2label_map,
+ pad_token_label_id,
+ prefix=""):
+ eval_dataset = XFUNDataset(
+ tokenizer,
+ data_dir=args.eval_data_dir,
+ label_path=args.eval_label_path,
+ label2id_map=label2id_map,
+ img_size=(224, 224),
+ pad_token_label_id=pad_token_label_id,
+ contains_re=False,
+ add_special_ids=False,
+ return_attention_mask=True,
+ load_mode='all')
+
+ args.eval_batch_size = args.per_gpu_eval_batch_size * max(
+ 1, paddle.distributed.get_world_size())
+
+ eval_dataloader = paddle.io.DataLoader(
+ eval_dataset,
+ batch_size=args.eval_batch_size,
+ num_workers=0,
+ use_shared_memory=True,
+ collate_fn=None, )
+
+ # Eval!
+ logger.info("***** Running evaluation %s *****", prefix)
+ logger.info(" Num examples = %d", len(eval_dataset))
+ logger.info(" Batch size = %d", args.eval_batch_size)
+ eval_loss = 0.0
+ nb_eval_steps = 0
+ preds = None
+ out_label_ids = None
+ model.eval()
+ for idx, batch in enumerate(eval_dataloader):
+ with paddle.no_grad():
+ outputs = model(**batch)
+ tmp_eval_loss, logits = outputs[:2]
+
+ tmp_eval_loss = tmp_eval_loss.mean()
+
+ if paddle.distributed.get_rank() == 0:
+ logger.info("[Eval]process: {}/{}, loss: {:.5f}".format(
+ idx, len(eval_dataloader), tmp_eval_loss.numpy()[0]))
+
+ eval_loss += tmp_eval_loss.item()
+ nb_eval_steps += 1
+ if preds is None:
+ preds = logits.numpy()
+ out_label_ids = batch["labels"].numpy()
+ else:
+ preds = np.append(preds, logits.numpy(), axis=0)
+ out_label_ids = np.append(
+ out_label_ids, batch["labels"].numpy(), axis=0)
+
+ eval_loss = eval_loss / nb_eval_steps
+ preds = np.argmax(preds, axis=2)
+
+ # label_map = {i: label.upper() for i, label in enumerate(labels)}
+
+ out_label_list = [[] for _ in range(out_label_ids.shape[0])]
+ preds_list = [[] for _ in range(out_label_ids.shape[0])]
+
+ for i in range(out_label_ids.shape[0]):
+ for j in range(out_label_ids.shape[1]):
+ if out_label_ids[i, j] != pad_token_label_id:
+ out_label_list[i].append(id2label_map[out_label_ids[i][j]])
+ preds_list[i].append(id2label_map[preds[i][j]])
+
+ results = {
+ "loss": eval_loss,
+ "precision": precision_score(out_label_list, preds_list),
+ "recall": recall_score(out_label_list, preds_list),
+ "f1": f1_score(out_label_list, preds_list),
+ }
+
+ with open(os.path.join(args.output_dir, "test_gt.txt"), "w") as fout:
+ for lbl in out_label_list:
+ for l in lbl:
+ fout.write(l + "\t")
+ fout.write("\n")
+ with open(os.path.join(args.output_dir, "test_pred.txt"), "w") as fout:
+ for lbl in preds_list:
+ for l in lbl:
+ fout.write(l + "\t")
+ fout.write("\n")
+
+ report = classification_report(out_label_list, preds_list)
+ logger.info("\n" + report)
+
+ logger.info("***** Eval results %s *****", prefix)
+ for key in sorted(results.keys()):
+ logger.info(" %s = %s", key, str(results[key]))
+
+ return results, preds_list
+
+
+def print_arguments(args):
+ """print arguments"""
+ print('----------- Configuration Arguments -----------')
+ for arg, value in sorted(vars(args).items()):
+ print('%s: %s' % (arg, value))
+ print('------------------------------------------------')
+
+
+if __name__ == "__main__":
+ args = parse_args()
+ print_arguments(args)
+ train(args)
diff --git a/ppstructure/vqa/utils.py b/ppstructure/vqa/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..a4ac1e77d37d0a662294480a393c2f67e7f4cc64
--- /dev/null
+++ b/ppstructure/vqa/utils.py
@@ -0,0 +1,328 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import argparse
+import cv2
+import random
+import numpy as np
+import imghdr
+from copy import deepcopy
+
+import paddle
+
+from PIL import Image, ImageDraw, ImageFont
+
+from paddleocr import PaddleOCR
+
+
+def get_bio_label_maps(label_map_path):
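+    """Build BIO label maps from a label file such as labels_ser.txt,
+    e.g. QUESTION/ANSWER/HEADER -> {"O": 0, "B-QUESTION": 1, "I-QUESTION": 2, ...},
+    returning both the label-to-id and id-to-label dicts."""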
+ with open(label_map_path, "r") as fin:
+ lines = fin.readlines()
+ lines = [line.strip() for line in lines]
+ if "O" not in lines:
+ lines.insert(0, "O")
+ labels = []
+ for line in lines:
+ if line == "O":
+ labels.append("O")
+ else:
+ labels.append("B-" + line)
+ labels.append("I-" + line)
+ label2id_map = {label: idx for idx, label in enumerate(labels)}
+ id2label_map = {idx: label for idx, label in enumerate(labels)}
+ return label2id_map, id2label_map
+
+
+def get_image_file_list(img_file):
+ imgs_lists = []
+ if img_file is None or not os.path.exists(img_file):
+ raise Exception("not found any img file in {}".format(img_file))
+
+ img_end = {'jpg', 'bmp', 'png', 'jpeg', 'rgb', 'tif', 'tiff', 'gif', 'GIF'}
+ if os.path.isfile(img_file) and imghdr.what(img_file) in img_end:
+ imgs_lists.append(img_file)
+ elif os.path.isdir(img_file):
+ for single_file in os.listdir(img_file):
+ file_path = os.path.join(img_file, single_file)
+ if os.path.isfile(file_path) and imghdr.what(file_path) in img_end:
+ imgs_lists.append(file_path)
+ if len(imgs_lists) == 0:
+ raise Exception("not found any img file in {}".format(img_file))
+ imgs_lists = sorted(imgs_lists)
+ return imgs_lists
+
+
+def draw_ser_results(image,
+ ocr_results,
+ font_path="../doc/fonts/simfang.ttf",
+ font_size=18):
+ np.random.seed(0)
+ color = (np.random.permutation(range(255)),
+ np.random.permutation(range(255)),
+ np.random.permutation(range(255)))
+ color_map = {
+ idx: (color[0][idx], color[1][idx], color[2][idx])
+ for idx in range(1, 255)
+ }
+ if isinstance(image, np.ndarray):
+ image = Image.fromarray(image)
+ img_new = image.copy()
+ draw = ImageDraw.Draw(img_new)
+
+ font = ImageFont.truetype(font_path, font_size, encoding="utf-8")
+
+ for ocr_info in ocr_results:
+ if ocr_info["pred_id"] not in color_map:
+ continue
+ color = color_map[ocr_info["pred_id"]]
+
+ # draw ocr results outline
+ bbox = ocr_info["bbox"]
+ bbox = ((bbox[0], bbox[1]), (bbox[2], bbox[3]))
+ draw.rectangle(bbox, fill=color)
+
+ # draw ocr results
+ text = "{}: {}".format(ocr_info["pred"], ocr_info["text"])
+ start_y = max(0, bbox[0][1] - font_size)
+ tw = font.getsize(text)[0]
+ draw.rectangle(
+ [(bbox[0][0] + 1, start_y), (bbox[0][0] + tw + 1,
+ start_y + font_size)],
+ fill=(0, 0, 255))
+ draw.text(
+ (bbox[0][0] + 1, start_y), text, fill=(255, 255, 255), font=font)
+
+ img_new = Image.blend(image, img_new, 0.5)
+ return np.array(img_new)
+
+
+def build_ocr_engine(rec_model_dir, det_model_dir):
+ ocr_engine = PaddleOCR(
+ rec_model_dir=rec_model_dir,
+ det_model_dir=det_model_dir,
+ use_angle_cls=False)
+ return ocr_engine
+
+
+# pad sentences
+def pad_sentences(tokenizer,
+ encoded_inputs,
+ max_seq_len=512,
+ pad_to_max_seq_len=True,
+ return_attention_mask=True,
+ return_token_type_ids=True,
+ return_overflowing_tokens=False,
+ return_special_tokens_mask=False):
+    # Pad to a multiple of max_seq_len so the sequence can later be reshaped into pages (see split_page)
+ max_seq_len = (
+ len(encoded_inputs["input_ids"]) // max_seq_len + 1) * max_seq_len
+
+ needs_to_be_padded = pad_to_max_seq_len and \
+ max_seq_len and len(encoded_inputs["input_ids"]) < max_seq_len
+
+ if needs_to_be_padded:
+ difference = max_seq_len - len(encoded_inputs["input_ids"])
+ if tokenizer.padding_side == 'right':
+ if return_attention_mask:
+ encoded_inputs["attention_mask"] = [1] * len(encoded_inputs[
+ "input_ids"]) + [0] * difference
+ if return_token_type_ids:
+ encoded_inputs["token_type_ids"] = (
+ encoded_inputs["token_type_ids"] +
+ [tokenizer.pad_token_type_id] * difference)
+ if return_special_tokens_mask:
+ encoded_inputs["special_tokens_mask"] = encoded_inputs[
+ "special_tokens_mask"] + [1] * difference
+ encoded_inputs["input_ids"] = encoded_inputs[
+ "input_ids"] + [tokenizer.pad_token_id] * difference
+ encoded_inputs["bbox"] = encoded_inputs["bbox"] + [[0, 0, 0, 0]
+ ] * difference
+ else:
+ if return_attention_mask:
+ encoded_inputs["attention_mask"] = [1] * len(encoded_inputs[
+ "input_ids"])
+
+ return encoded_inputs
+
+
+def split_page(encoded_inputs, max_seq_len=512):
+ """
+ truncate is often used in training process
+ """
+ for key in encoded_inputs:
+ encoded_inputs[key] = paddle.to_tensor(encoded_inputs[key])
+ if encoded_inputs[key].ndim <= 1: # for input_ids, att_mask and so on
+ encoded_inputs[key] = encoded_inputs[key].reshape([-1, max_seq_len])
+ else: # for bbox
+ encoded_inputs[key] = encoded_inputs[key].reshape(
+ [-1, max_seq_len, 4])
+ return encoded_inputs
+
+
+def preprocess(
+ tokenizer,
+ ori_img,
+ ocr_info,
+ img_size=(224, 224),
+ pad_token_label_id=-100,
+ max_seq_len=512,
+ add_special_ids=False,
+ return_attention_mask=True, ):
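+ """Build model inputs from OCR results: scale every bbox to a 0-1000
+ coordinate range, tokenize the text, pad and split the sequence into
+ pages, and attach the resized image tensor to each page."""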
+ ocr_info = deepcopy(ocr_info)
+ height = ori_img.shape[0]
+ width = ori_img.shape[1]
+
+ img = cv2.resize(ori_img,
+ img_size).transpose([2, 0, 1]).astype(np.float32)
+
+ segment_offset_id = []
+ words_list = []
+ bbox_list = []
+ input_ids_list = []
+ token_type_ids_list = []
+
+ for info in ocr_info:
+ # x1, y1, x2, y2
+ bbox = info["bbox"]
+ bbox[0] = int(bbox[0] * 1000.0 / width)
+ bbox[2] = int(bbox[2] * 1000.0 / width)
+ bbox[1] = int(bbox[1] * 1000.0 / height)
+ bbox[3] = int(bbox[3] * 1000.0 / height)
+
+ text = info["text"]
+ encode_res = tokenizer.encode(
+ text, pad_to_max_seq_len=False, return_attention_mask=True)
+
+ if not add_special_ids:
+ # TODO: use tokenizer.all_special_ids to strip the special tokens
+ encode_res["input_ids"] = encode_res["input_ids"][1:-1]
+ encode_res["token_type_ids"] = encode_res["token_type_ids"][1:-1]
+ encode_res["attention_mask"] = encode_res["attention_mask"][1:-1]
+
+ input_ids_list.extend(encode_res["input_ids"])
+ token_type_ids_list.extend(encode_res["token_type_ids"])
+ bbox_list.extend([bbox] * len(encode_res["input_ids"]))
+ words_list.append(text)
+ segment_offset_id.append(len(input_ids_list))
+
+ encoded_inputs = {
+ "input_ids": input_ids_list,
+ "token_type_ids": token_type_ids_list,
+ "bbox": bbox_list,
+ "attention_mask": [1] * len(input_ids_list),
+ }
+
+ encoded_inputs = pad_sentences(
+ tokenizer,
+ encoded_inputs,
+ max_seq_len=max_seq_len,
+ return_attention_mask=return_attention_mask)
+
+ encoded_inputs = split_page(encoded_inputs)
+
+ fake_bs = encoded_inputs["input_ids"].shape[0]
+
+ encoded_inputs["image"] = paddle.to_tensor(img).unsqueeze(0).expand(
+ [fake_bs] + list(img.shape))
+
+ encoded_inputs["segment_offset_id"] = segment_offset_id
+
+ return encoded_inputs
+
+
+def postprocess(attention_mask, preds, id2label_map):
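+ """Map token-level logits back to label strings, keeping only positions
+ whose attention_mask is 1 (padded tokens are dropped)."""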
+ if isinstance(preds, paddle.Tensor):
+ preds = preds.numpy()
+ preds = np.argmax(preds, axis=2)
+
+ preds_list = [[] for _ in range(preds.shape[0])]
+
+ # keep batch info
+ for i in range(preds.shape[0]):
+ for j in range(preds.shape[1]):
+ if attention_mask[i][j] == 1:
+ preds_list[i].append(id2label_map[preds[i][j]])
+
+ return preds_list
+
+
+def merge_preds_list_with_ocr_info(ocr_info, segment_offset_id, preds_list,
+ label2id_map_for_draw):
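+ """Assign each OCR segment the majority label of its tokens and write it
+ back into ocr_info as pred_id / pred."""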
+ # must ensure the preds_list is generated from the same image
+ preds = [p for pred in preds_list for p in pred]
+
+ id2label_map = dict()
+ for key in label2id_map_for_draw:
+ val = label2id_map_for_draw[key]
+ if key.startswith("B-") or key.startswith("I-"):
+ id2label_map[val] = key[2:]
+ else:
+ id2label_map[val] = key
+
+ for idx in range(len(segment_offset_id)):
+ if idx == 0:
+ start_id = 0
+ else:
+ start_id = segment_offset_id[idx - 1]
+
+ end_id = segment_offset_id[idx]
+
+ curr_pred = preds[start_id:end_id]
+ curr_pred = [label2id_map_for_draw[p] for p in curr_pred]
+
+ if len(curr_pred) <= 0:
+ pred_id = 0
+ else:
+ counts = np.bincount(curr_pred)
+ pred_id = np.argmax(counts)
+ ocr_info[idx]["pred_id"] = int(pred_id)
+ ocr_info[idx]["pred"] = id2label_map[int(pred_id)]
+ return ocr_info
+
+
+def parse_args():
+ parser = argparse.ArgumentParser()
+ # Required parameters
+ # yapf: disable
+ parser.add_argument("--model_name_or_path", default=None, type=str, required=True,)
+ parser.add_argument("--train_data_dir", default=None, type=str, required=False,)
+ parser.add_argument("--train_label_path", default=None, type=str, required=False,)
+ parser.add_argument("--eval_data_dir", default=None, type=str, required=False,)
+ parser.add_argument("--eval_label_path", default=None, type=str, required=False,)
+ parser.add_argument("--output_dir", default=None, type=str, required=True,)
+ parser.add_argument("--max_seq_length", default=512, type=int,)
+ parser.add_argument("--evaluate_during_training", action="store_true",)
+ parser.add_argument("--per_gpu_train_batch_size", default=8, type=int, help="Batch size per GPU/CPU for training.",)
+ parser.add_argument("--per_gpu_eval_batch_size", default=8, type=int, help="Batch size per GPU/CPU for eval.",)
+ parser.add_argument("--learning_rate", default=5e-5, type=float, help="The initial learning rate for Adam.",)
+ parser.add_argument("--weight_decay", default=0.0, type=float, help="Weight decay if we apply some.",)
+ parser.add_argument("--adam_epsilon", default=1e-8, type=float, help="Epsilon for Adam optimizer.",)
+ parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.",)
+ parser.add_argument("--num_train_epochs", default=3, type=int, help="Total number of training epochs to perform.",)
+ parser.add_argument("--warmup_steps", default=0, type=int, help="Linear warmup over warmup_steps.",)
+ parser.add_argument("--eval_steps", type=int, default=10, help="eval every X updates steps.",)
+ parser.add_argument("--save_steps", type=int, default=50, help="Save checkpoint every X updates steps.",)
+ parser.add_argument("--seed", type=int, default=2048, help="random seed for initialization",)
+
+ parser.add_argument("--ocr_rec_model_dir", default=None, type=str, )
+ parser.add_argument("--ocr_det_model_dir", default=None, type=str, )
+ parser.add_argument("--label_map_path", default="./labels/labels_ser.txt", type=str, required=False, )
+ parser.add_argument("--infer_imgs", default=None, type=str, required=False)
+ parser.add_argument("--ocr_json_path", default=None, type=str, required=False, help="ocr prediction results")
+ # yapf: enable
+ args = parser.parse_args()
+ return args
diff --git a/ppstructure/vqa/xfun.py b/ppstructure/vqa/xfun.py
new file mode 100644
index 0000000000000000000000000000000000000000..d62cdb5da5514280b62687d80d345ede9484ee90
--- /dev/null
+++ b/ppstructure/vqa/xfun.py
@@ -0,0 +1,442 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import json
+import os
+import cv2
+import numpy as np
+import paddle
+import copy
+from paddle.io import Dataset
+
+__all__ = ["XFUNDataset"]
+
+
+class XFUNDataset(Dataset):
+ """
+ Example:
+ print("=====begin to build dataset=====")
+ from paddlenlp.transformers import LayoutXLMTokenizer
+ tokenizer = LayoutXLMTokenizer.from_pretrained("/paddle/models/transformers/layoutxlm-base-paddle/")
+ tok_res = tokenizer.tokenize("Maribyrnong")
+ # res = tokenizer.convert_ids_to_tokens(val_data["input_ids"][0])
+ dataset = XFUNDataset(
+ tokenizer,
+ data_dir="./zh.val/",
+ label_path="zh.val/xfun_normalize_val.json",
+ img_size=(224,224))
+ print(len(dataset))
+
+ data = dataset[0]
+ print(data.keys())
+ print("input_ids: ", data["input_ids"])
+ print("labels: ", data["labels"])
+ print("token_type_ids: ", data["token_type_ids"])
+ print("words_list: ", data["words_list"])
+ print("image shape: ", data["image"].shape)
+ """
+
+ def __init__(self,
+ tokenizer,
+ data_dir,
+ label_path,
+ contains_re=False,
+ label2id_map=None,
+ img_size=(224, 224),
+ pad_token_label_id=None,
+ add_special_ids=False,
+ return_attention_mask=True,
+ load_mode='all',
+ max_seq_len=512):
+ super().__init__()
+ self.tokenizer = tokenizer
+ self.data_dir = data_dir
+ self.label_path = label_path
+ self.contains_re = contains_re
+ self.label2id_map = label2id_map
+ self.img_size = img_size
+ self.pad_token_label_id = pad_token_label_id
+ self.add_special_ids = add_special_ids
+ self.return_attention_mask = return_attention_mask
+ self.load_mode = load_mode
+ self.max_seq_len = max_seq_len
+
+ if self.pad_token_label_id is None:
+ self.pad_token_label_id = paddle.nn.CrossEntropyLoss().ignore_index
+
+ self.all_lines = self.read_all_lines()
+
+ self.entities_labels = {'HEADER': 0, 'QUESTION': 1, 'ANSWER': 2}
+ self.return_keys = {
+ 'bbox': 'np',
+ 'input_ids': 'np',
+ 'labels': 'np',
+ 'attention_mask': 'np',
+ 'image': 'np',
+ 'token_type_ids': 'np',
+ 'entities': 'dict',
+ 'relations': 'dict',
+ }
+
+ if load_mode == "all":
+ self.encoded_inputs_all = self._parse_label_file_all()
+
+ def pad_sentences(self,
+ encoded_inputs,
+ max_seq_len=512,
+ pad_to_max_seq_len=True,
+ return_attention_mask=True,
+ return_token_type_ids=True,
+ truncation_strategy="longest_first",
+ return_overflowing_tokens=False,
+ return_special_tokens_mask=False):
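+ """Pad input_ids, labels, bbox, token_type_ids and attention_mask up to
+ max_seq_len on the side given by tokenizer.padding_side."""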
+ # Padding
+ needs_to_be_padded = pad_to_max_seq_len and \
+ max_seq_len and len(encoded_inputs["input_ids"]) < max_seq_len
+
+ if needs_to_be_padded:
+ difference = max_seq_len - len(encoded_inputs["input_ids"])
+ if self.tokenizer.padding_side == 'right':
+ if return_attention_mask:
+ encoded_inputs["attention_mask"] = [1] * len(encoded_inputs[
+ "input_ids"]) + [0] * difference
+ if return_token_type_ids:
+ encoded_inputs["token_type_ids"] = (
+ encoded_inputs["token_type_ids"] +
+ [self.tokenizer.pad_token_type_id] * difference)
+ if return_special_tokens_mask:
+ encoded_inputs["special_tokens_mask"] = encoded_inputs[
+ "special_tokens_mask"] + [1] * difference
+ encoded_inputs["input_ids"] = encoded_inputs[
+ "input_ids"] + [self.tokenizer.pad_token_id] * difference
+ encoded_inputs["labels"] = encoded_inputs[
+ "labels"] + [self.pad_token_label_id] * difference
+ encoded_inputs["bbox"] = encoded_inputs[
+ "bbox"] + [[0, 0, 0, 0]] * difference
+ elif self.tokenizer.padding_side == 'left':
+ if return_attention_mask:
+ encoded_inputs["attention_mask"] = [0] * difference + [
+ 1
+ ] * len(encoded_inputs["input_ids"])
+ if return_token_type_ids:
+ encoded_inputs["token_type_ids"] = (
+ [self.tokenizer.pad_token_type_id] * difference +
+ encoded_inputs["token_type_ids"])
+ if return_special_tokens_mask:
+ encoded_inputs["special_tokens_mask"] = [
+ 1
+ ] * difference + encoded_inputs["special_tokens_mask"]
+ encoded_inputs["input_ids"] = [
+ self.tokenizer.pad_token_id
+ ] * difference + encoded_inputs["input_ids"]
+ encoded_inputs["labels"] = [
+ self.pad_token_label_id
+ ] * difference + encoded_inputs["labels"]
+ encoded_inputs["bbox"] = [
+ [0, 0, 0, 0]
+ ] * difference + encoded_inputs["bbox"]
+ else:
+ if return_attention_mask:
+ encoded_inputs["attention_mask"] = [1] * len(encoded_inputs[
+ "input_ids"])
+
+ return encoded_inputs
+
+ def truncate_inputs(self, encoded_inputs, max_seq_len=512):
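+ """Truncate every field except sample_id to at most max_seq_len items."""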
+ for key in encoded_inputs:
+ if key == "sample_id":
+ continue
+ length = min(len(encoded_inputs[key]), max_seq_len)
+ encoded_inputs[key] = encoded_inputs[key][:length]
+ return encoded_inputs
+
+ def read_all_lines(self, ):
+ with open(self.label_path, "r") as fin:
+ lines = fin.readlines()
+ return lines
+
+ def _parse_label_file_all(self):
+ """
+ parse all samples
+ """
+ encoded_inputs_all = []
+ for line in self.all_lines:
+ encoded_inputs_all.extend(self._parse_label_file(line))
+ return encoded_inputs_all
+
+ def _parse_label_file(self, line):
+ """
+ parse single sample
+ """
+
+ image_name, info_str = line.split("\t")
+ image_path = os.path.join(self.data_dir, image_name)
+
+ def add_image_path(x):
+ x['image_path'] = image_path
+ return x
+
+ encoded_inputs = self._read_encoded_inputs_sample(info_str)
+ if self.contains_re:
+ encoded_inputs = self._chunk_re(encoded_inputs)
+ else:
+ encoded_inputs = self._chunk_ser(encoded_inputs)
+ encoded_inputs = list(map(add_image_path, encoded_inputs))
+ return encoded_inputs
+
+ def _read_encoded_inputs_sample(self, info_str):
+ """
+ parse label info
+ """
+ # read text info
+ info_dict = json.loads(info_str)
+ height = info_dict["height"]
+ width = info_dict["width"]
+
+ words_list = []
+ bbox_list = []
+ input_ids_list = []
+ token_type_ids_list = []
+ gt_label_list = []
+
+ if self.contains_re:
+ # for re
+ entities = []
+ relations = []
+ id2label = {}
+ entity_id_to_index_map = {}
+ empty_entity = set()
+ for info in info_dict["ocr_info"]:
+ if self.contains_re:
+ # for re
+ if len(info["text"]) == 0:
+ empty_entity.add(info["id"])
+ continue
+ id2label[info["id"]] = info["label"]
+ relations.extend([tuple(sorted(l)) for l in info["linking"]])
+
+ # x1, y1, x2, y2
+ bbox = info["bbox"]
+ label = info["label"]
+ bbox[0] = int(bbox[0] * 1000.0 / width)
+ bbox[2] = int(bbox[2] * 1000.0 / width)
+ bbox[1] = int(bbox[1] * 1000.0 / height)
+ bbox[3] = int(bbox[3] * 1000.0 / height)
+
+ text = info["text"]
+ encode_res = self.tokenizer.encode(
+ text, pad_to_max_seq_len=False, return_attention_mask=True)
+
+ gt_label = []
+ if not self.add_special_ids:
+ # TODO: use tokenizer.all_special_ids to strip the special tokens
+ encode_res["input_ids"] = encode_res["input_ids"][1:-1]
+ encode_res["token_type_ids"] = encode_res["token_type_ids"][1:
+ -1]
+ encode_res["attention_mask"] = encode_res["attention_mask"][1:
+ -1]
+ if label.lower() == "other":
+ gt_label.extend([0] * len(encode_res["input_ids"]))
+ else:
+ gt_label.append(self.label2id_map[("b-" + label).upper()])
+ gt_label.extend([self.label2id_map[("i-" + label).upper()]] *
+ (len(encode_res["input_ids"]) - 1))
+ if self.contains_re:
+ if gt_label[0] != self.label2id_map["O"]:
+ entity_id_to_index_map[info["id"]] = len(entities)
+ entities.append({
+ "start": len(input_ids_list),
+ "end":
+ len(input_ids_list) + len(encode_res["input_ids"]),
+ "label": label.upper(),
+ })
+ input_ids_list.extend(encode_res["input_ids"])
+ token_type_ids_list.extend(encode_res["token_type_ids"])
+ bbox_list.extend([bbox] * len(encode_res["input_ids"]))
+ gt_label_list.extend(gt_label)
+ words_list.append(text)
+
+ encoded_inputs = {
+ "input_ids": input_ids_list,
+ "labels": gt_label_list,
+ "token_type_ids": token_type_ids_list,
+ "bbox": bbox_list,
+ "attention_mask": [1] * len(input_ids_list),
+ # "words_list": words_list,
+ }
+ encoded_inputs = self.pad_sentences(
+ encoded_inputs,
+ max_seq_len=self.max_seq_len,
+ return_attention_mask=self.return_attention_mask)
+ encoded_inputs = self.truncate_inputs(encoded_inputs)
+
+ if self.contains_re:
+ relations = self._relations(entities, relations, id2label,
+ empty_entity, entity_id_to_index_map)
+ encoded_inputs['relations'] = relations
+ encoded_inputs['entities'] = entities
+ return encoded_inputs
+
+ def _chunk_ser(self, encoded_inputs):
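+ """Split one encoded SER sample into consecutive 512-token chunks."""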
+ encoded_inputs_all = []
+ seq_len = len(encoded_inputs['input_ids'])
+ chunk_size = 512
+ for chunk_id, index in enumerate(range(0, seq_len, chunk_size)):
+ chunk_beg = index
+ chunk_end = min(index + chunk_size, seq_len)
+ encoded_inputs_example = {}
+ for key in encoded_inputs:
+ encoded_inputs_example[key] = encoded_inputs[key][chunk_beg:
+ chunk_end]
+
+ encoded_inputs_all.append(encoded_inputs_example)
+ return encoded_inputs_all
+
+ def _chunk_re(self, encoded_inputs):
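+ """Split one encoded RE sample into 512-token chunks and remap the
+ entities and relations inside each chunk to chunk-local token indices."""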
+ # prepare data
+ entities = encoded_inputs.pop('entities')
+ relations = encoded_inputs.pop('relations')
+ encoded_inputs_all = []
+ chunk_size = 512
+ for chunk_id, index in enumerate(
+ range(0, len(encoded_inputs["input_ids"]), chunk_size)):
+ item = {}
+ for k in encoded_inputs:
+ item[k] = encoded_inputs[k][index:index + chunk_size]
+
+ # select entity in current chunk
+ entities_in_this_span = []
+ global_to_local_map = {}
+ for entity_id, entity in enumerate(entities):
+ if (index <= entity["start"] < index + chunk_size and
+ index <= entity["end"] < index + chunk_size):
+ entity["start"] = entity["start"] - index
+ entity["end"] = entity["end"] - index
+ global_to_local_map[entity_id] = len(entities_in_this_span)
+ entities_in_this_span.append(entity)
+
+ # select relations in current chunk
+ relations_in_this_span = []
+ for relation in relations:
+ if (index <= relation["start_index"] < index + chunk_size and
+ index <= relation["end_index"] < index + chunk_size):
+ relations_in_this_span.append({
+ "head": global_to_local_map[relation["head"]],
+ "tail": global_to_local_map[relation["tail"]],
+ "start_index": relation["start_index"] - index,
+ "end_index": relation["end_index"] - index,
+ })
+ item.update({
+ "entities": reformat(entities_in_this_span),
+ "relations": reformat(relations_in_this_span),
+ })
+ item['entities']['label'] = [
+ self.entities_labels[x] for x in item['entities']['label']
+ ]
+ encoded_inputs_all.append(item)
+ return encoded_inputs_all
+
+ def _relations(self, entities, relations, id2label, empty_entity,
+ entity_id_to_index_map):
+ """
+ build relations
+ """
+ relations = list(set(relations))
+ relations = [
+ rel for rel in relations
+ if rel[0] not in empty_entity and rel[1] not in empty_entity
+ ]
+ kv_relations = []
+ for rel in relations:
+ pair = [id2label[rel[0]], id2label[rel[1]]]
+ if pair == ["question", "answer"]:
+ kv_relations.append({
+ "head": entity_id_to_index_map[rel[0]],
+ "tail": entity_id_to_index_map[rel[1]]
+ })
+ elif pair == ["answer", "question"]:
+ kv_relations.append({
+ "head": entity_id_to_index_map[rel[1]],
+ "tail": entity_id_to_index_map[rel[0]]
+ })
+ else:
+ continue
+ relations = sorted(
+ [{
+ "head": rel["head"],
+ "tail": rel["tail"],
+ "start_index": get_relation_span(rel, entities)[0],
+ "end_index": get_relation_span(rel, entities)[1],
+ } for rel in kv_relations],
+ key=lambda x: x["head"], )
+ return relations
+
+ def load_img(self, image_path):
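+ """Read an image, resize it to self.img_size, normalize it with the
+ ImageNet mean/std and return a CHW float array."""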
+ # read img
+ img = cv2.imread(image_path)
+ img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
+ resize_h, resize_w = self.img_size
+ im_shape = img.shape[0:2]
+ im_scale_y = resize_h / im_shape[0]
+ im_scale_x = resize_w / im_shape[1]
+ img_new = cv2.resize(
+ img, None, None, fx=im_scale_x, fy=im_scale_y, interpolation=2)
+ mean = np.array([0.485, 0.456, 0.406])[np.newaxis, np.newaxis, :]
+ std = np.array([0.229, 0.224, 0.225])[np.newaxis, np.newaxis, :]
+ img_new = img_new / 255.0
+ img_new -= mean
+ img_new /= std
+ img = img_new.transpose((2, 0, 1))
+ return img
+
+ def __getitem__(self, idx):
+ if self.load_mode == "all":
+ data = copy.deepcopy(self.encoded_inputs_all[idx])
+ else:
+ data = self._parse_label_file(self.all_lines[idx])[0]
+
+ image_path = data.pop('image_path')
+ data["image"] = self.load_img(image_path)
+
+ return_data = {}
+ for k, v in data.items():
+ if k in self.return_keys:
+ if self.return_keys[k] == 'np':
+ v = np.array(v)
+ return_data[k] = v
+ return return_data
+
+ def __len__(self, ):
+ if self.load_mode == "all":
+ return len(self.encoded_inputs_all)
+ else:
+ return len(self.all_lines)
+
+
+def get_relation_span(rel, entities):
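+ """Return the (min start, max end) token span covered by the head and
+ tail entities of a relation."""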
+ bound = []
+ for entity_index in [rel["head"], rel["tail"]]:
+ bound.append(entities[entity_index]["start"])
+ bound.append(entities[entity_index]["end"])
+ return min(bound), max(bound)
+
+
+def reformat(data):
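+ """Convert a list of dicts into a dict of lists keyed by field name."""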
+ new_data = {}
+ for item in data:
+ for k, v in item.items():
+ if k not in new_data:
+ new_data[k] = []
+ new_data[k].append(v)
+ return new_data
diff --git a/test_tipc/configs/ch_ppocr_mobile_v2.0_det/model_linux_gpu_normal_normal_serving_python_linux_gpu_cpu.txt b/test_tipc/configs/ch_ppocr_mobile_v2.0_det/model_linux_gpu_normal_normal_serving_python_linux_gpu_cpu.txt
index a19c8ee3355b010b55d1dbf16aa0e21940ba546c..6e5cecf632a42294006cffdf4cf3a466a326260b 100644
--- a/test_tipc/configs/ch_ppocr_mobile_v2.0_det/model_linux_gpu_normal_normal_serving_python_linux_gpu_cpu.txt
+++ b/test_tipc/configs/ch_ppocr_mobile_v2.0_det/model_linux_gpu_normal_normal_serving_python_linux_gpu_cpu.txt
@@ -15,4 +15,4 @@ op.det.local_service_conf.thread_num:1|6
op.det.local_service_conf.use_trt:False|True
op.det.local_service_conf.precision:fp32|fp16|int8
pipline:pipeline_rpc_client.py|pipeline_http_client.py
---image_dir:../../doc/imgs
\ No newline at end of file
+--image_dir:../../doc/imgs
diff --git a/test_tipc/prepare.sh b/test_tipc/prepare.sh
index d152ef29d0a2983e656f9868147158a3b7e66aa5..8876157ef8f4b44b227c171d25bdfd1060007910 100644
--- a/test_tipc/prepare.sh
+++ b/test_tipc/prepare.sh
@@ -201,8 +201,11 @@ fi
if [ ${MODE} = "serving_infer" ];then
# prepare serving env
- python_name=$(func_parser_value "${lines[2]}")
- wget https://paddle-serving.bj.bcebos.com/chain/paddle_serving_server_gpu-0.0.0.post101-py3-none-any.whl
+ python_name_list=$(func_parser_value "${lines[2]}")
+ IFS='|'
+ array=(${python_name_list})
+ python_name=${array[0]}
+ wget -nc https://paddle-serving.bj.bcebos.com/chain/paddle_serving_server_gpu-0.0.0.post101-py3-none-any.whl
${python_name} -m pip install paddle_serving_server_gpu-0.0.0.post101-py3-none-any.whl
${python_name} -m pip install paddle_serving_client==0.6.1
${python_name} -m pip install paddle-serving-app==0.6.3
diff --git a/test_tipc/readme.md b/test_tipc/readme.md
index a188b675a90a651588fdda08694bc30ca9e0f301..8b2489f3445ddfa87c1e587d6da81992fdb90e64 100644
--- a/test_tipc/readme.md
+++ b/test_tipc/readme.md
@@ -1,9 +1,9 @@
-# 飞桨训推一体认证(TIPC)
+# 飞桨训推一体全流程(TIPC)
## 1. 简介
-飞桨除了基本的模型训练和预测,还提供了支持多端多平台的高性能推理部署工具。本文档提供了PaddleOCR中所有模型的飞桨训推一体认证 (Training and Inference Pipeline Certification(TIPC)) 信息和测试工具,方便用户查阅每种模型的训练推理部署打通情况,并可以进行一键测试。
+飞桨除了基本的模型训练和预测,还提供了支持多端多平台的高性能推理部署工具。本文档提供了PaddleOCR中所有模型的飞桨训推一体全流程(Training and Inference Pipeline Criterion(TIPC))信息和测试工具,方便用户查阅每种模型的训练推理部署打通情况,并可以进行一键测试。
diff --git a/test_tipc/test_serving.sh b/test_tipc/test_serving.sh
index c36935a60fecacea672fd932773a8fb0bdcd619b..1318d012d401c4f4e8540a5d0d227ea75f677004 100644
--- a/test_tipc/test_serving.sh
+++ b/test_tipc/test_serving.sh
@@ -10,7 +10,7 @@ lines=(${dataline})
# parser serving
model_name=$(func_parser_value "${lines[1]}")
-python=$(func_parser_value "${lines[2]}")
+python_list=$(func_parser_value "${lines[2]}")
trans_model_py=$(func_parser_value "${lines[3]}")
infer_model_dir_key=$(func_parser_key "${lines[4]}")
infer_model_dir_value=$(func_parser_value "${lines[4]}")
@@ -54,14 +54,15 @@ function func_serving(){
set_serving_server=$(func_set_params "${serving_server_key}" "${serving_server_value}")
set_serving_client=$(func_set_params "${serving_client_key}" "${serving_client_value}")
set_image_dir=$(func_set_params "${image_dir_key}" "${image_dir_value}")
- trans_model_cmd="${python} ${trans_model_py} ${set_dirname} ${set_model_filename} ${set_params_filename} ${set_serving_server} ${set_serving_client}"
+ python_list=(${python_list})
+ trans_model_cmd="${python_list[0]} ${trans_model_py} ${set_dirname} ${set_model_filename} ${set_params_filename} ${set_serving_server} ${set_serving_client}"
eval $trans_model_cmd
cd ${serving_dir_value}
echo $PWD
unset https_proxy
unset http_proxy
- for python in ${python[*]}; do
- if [ ${python} = "cpp"]; then
+ for python in ${python_list[*]}; do
+ if [ ${python} = "cpp" ]; then
for use_gpu in ${web_use_gpu_list[*]}; do
if [ ${use_gpu} = "null" ]; then
web_service_cpp_cmd="${python} -m paddle_serving_server.serve --model ppocr_det_mobile_2.0_serving/ ppocr_rec_mobile_2.0_serving/ --port 9293"
@@ -91,9 +92,6 @@ function func_serving(){
echo ${use_gpu}
if [ ${use_gpu} = "null" ]; then
for use_mkldnn in ${web_use_mkldnn_list[*]}; do
- if [ ${use_mkldnn} = "False" ]; then
- continue
- fi
for threads in ${web_cpu_threads_list[*]}; do
set_cpu_threads=$(func_set_params "${web_cpu_threads_key}" "${threads}")
web_service_cmd="${python} ${web_service_py} ${web_use_gpu_key}=${use_gpu} ${web_use_mkldnn_key}=${use_mkldnn} ${set_cpu_threads} &"
@@ -124,6 +122,9 @@ function func_serving(){
continue
fi
set_tensorrt=$(func_set_params "${web_use_trt_key}" "${use_trt}")
+ if [ ${use_trt} = True ]; then
+ device_type=2
+ fi
set_precision=$(func_set_params "${web_precision_key}" "${precision}")
web_service_cmd="${python} ${web_service_py} ${web_use_gpu_key}=${use_gpu} ${set_tensorrt} ${set_precision} & "
eval $web_service_cmd
diff --git a/tools/infer/utility.py b/tools/infer/utility.py
index f437056ec7b10e28e626d2028b6401cebc647bb1..d7e058c2b7c0eaf6bd40dd197a3cb1417bc7bb7d 100644
--- a/tools/infer/utility.py
+++ b/tools/infer/utility.py
@@ -271,8 +271,13 @@ def create_predictor(args, mode, logger):
min_input_shape = {"x": [1, 3, 10, 10]}
max_input_shape = {"x": [1, 3, 512, 512]}
opt_input_shape = {"x": [1, 3, 256, 256]}
- config.set_trt_dynamic_shape_info(min_input_shape, max_input_shape,
- opt_input_shape)
+ if mode == "rec":
+ if args.rec_algorithm == "CRNN":
+ config.set_trt_dynamic_shape_info(
+ min_input_shape, max_input_shape, opt_input_shape)
+ else:
+ config.set_trt_dynamic_shape_info(
+ min_input_shape, max_input_shape, opt_input_shape)
else:
config.disable_gpu()
@@ -311,7 +316,10 @@ def create_predictor(args, mode, logger):
def get_infer_gpuid():
- cmd = "env | grep CUDA_VISIBLE_DEVICES"
+ if not paddle.fluid.core.is_compiled_with_rocm():
+ cmd = "env | grep CUDA_VISIBLE_DEVICES"
+ else:
+ cmd = "env | grep HIP_VISIBLE_DEVICES"
env_cuda = os.popen(cmd).readlines()
if len(env_cuda) == 0:
return 0