diff --git a/PPOCRLabel/README_ch.md b/PPOCRLabel/README_ch.md
index d745c9734636e2637f5019a5dd5097dd88cf2e56..e50e7b7daf5ac203129cf80b5a926a988929af5c 100644
--- a/PPOCRLabel/README_ch.md
+++ b/PPOCRLabel/README_ch.md
@@ -71,6 +71,8 @@ pip3 install opencv-contrib-python-headless==4.2.0.32 # 如果下载过慢请添
PPOCRLabel --lang ch # 启动
```
+> 如果上述安装出现问题,可以参考3.6节 错误提示
+
#### 1.2.2 本地构建whl包并安装
```bash
diff --git a/PPOCRLabel/libs/canvas.py b/PPOCRLabel/libs/canvas.py
index d5662ac79a85c07c79ed2b7df315f338a229535c..6ac1f28b85e65c3776d310136352b70c45628db6 100644
--- a/PPOCRLabel/libs/canvas.py
+++ b/PPOCRLabel/libs/canvas.py
@@ -704,8 +704,9 @@ class Canvas(QWidget):
def keyPressEvent(self, ev):
key = ev.key()
- shapesBackup = []
shapesBackup = copy.deepcopy(self.shapes)
+ if len(shapesBackup) == 0:
+ return
self.shapesBackups.pop()
self.shapesBackups.append(shapesBackup)
if key == Qt.Key_Escape and self.current:
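This hunk stops the key handler from rotating the undo history when the canvas has no shapes yet (an empty deep copy would otherwise overwrite the last valid backup). A minimal standalone sketch of the same guard, using a hypothetical `UndoHistory` class rather than the PPOCRLabel API:

```python
import copy

class UndoHistory:
    """Keep the newest snapshot in sync with the current shapes; skip when empty."""

    def __init__(self):
        self.backups = [[]]  # stack of shape-list snapshots

    def refresh_latest(self, shapes):
        snapshot = copy.deepcopy(shapes)
        if len(snapshot) == 0:
            # Nothing to back up: leave the existing history untouched.
            return
        self.backups.pop()             # drop the stale top-of-stack snapshot
        self.backups.append(snapshot)  # replace it with the current state
```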
diff --git a/configs/det/ch_PP-OCRv2/ch_PP-OCRv2_det_cml.yml b/configs/det/ch_PP-OCRv2/ch_PP-OCRv2_det_cml.yml
index ab484a44833a405513d7f2b4079a4da4c2e403c8..bb6a196864b6e9e7525f2b5217f0c90ea2ca05a4 100644
--- a/configs/det/ch_PP-OCRv2/ch_PP-OCRv2_det_cml.yml
+++ b/configs/det/ch_PP-OCRv2/ch_PP-OCRv2_det_cml.yml
@@ -18,6 +18,7 @@ Global:
Architecture:
name: DistillationModel
algorithm: Distillation
+ model_type: det
Models:
Teacher:
freeze_params: true
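Declaring `model_type` at the `Architecture` level means tools that look the key up on the top-level config (such as the quantization exporter below) can now find it for the CML/distillation config as well. A quick check, assuming PyYAML is installed and the command is run from the repository root:

```python
import yaml

# Hypothetical sanity check of the updated config file.
with open("configs/det/ch_PP-OCRv2/ch_PP-OCRv2_det_cml.yml") as f:
    cfg = yaml.safe_load(f)

print(cfg["Architecture"].get("model_type"))  # expected: "det" after this change
```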
diff --git a/deploy/slim/quantization/export_model.py b/deploy/slim/quantization/export_model.py
index dddae923de223178665e3bfb55a2e7a8c0d5ba17..0cb86108d2275dc6ee1a74e118c27b94131975d3 100755
--- a/deploy/slim/quantization/export_model.py
+++ b/deploy/slim/quantization/export_model.py
@@ -111,7 +111,7 @@ def main():
valid_dataloader = build_dataloader(config, 'Eval', device, logger)
use_srn = config['Architecture']['algorithm'] == "SRN"
- model_type = config['Architecture']['model_type']
+ model_type = config['Architecture'].get('model_type', None)
# start eval
metric = program.eval(model, valid_dataloader, post_process_class,
eval_class, model_type, use_srn)
@@ -120,8 +120,7 @@ def main():
for k, v in metric.items():
logger.info('{}:{}'.format(k, v))
- infer_shape = [3, 32, 100] if config['Architecture'][
- 'model_type'] != "det" else [3, 640, 640]
+ infer_shape = [3, 32, 100] if model_type == "rec" else [3, 640, 640]
save_path = config["Global"]["save_inference_dir"]
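Switching to `dict.get` with a `None` default keeps the exporter from raising `KeyError` on configs that never declared `model_type`, and the inference shape is then chosen from the resolved value. A small sketch of the idea, not the exact exporter code:

```python
def pick_infer_shape(architecture_cfg):
    # A missing key now yields None instead of raising KeyError.
    model_type = architecture_cfg.get("model_type", None)
    # Recognition models use a short, wide input; everything else gets 640x640.
    return [3, 32, 100] if model_type == "rec" else [3, 640, 640]

assert pick_infer_shape({"model_type": "rec"}) == [3, 32, 100]
assert pick_infer_shape({"algorithm": "Distillation"}) == [3, 640, 640]  # no model_type key
```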
diff --git a/doc/datasets/ch_doc2.jpg b/doc/datasets/ch_doc2.jpg
deleted file mode 100644
index 23343b8dedbae7be025552e3a45f9b7af7cf49ee..0000000000000000000000000000000000000000
Binary files a/doc/datasets/ch_doc2.jpg and /dev/null differ
diff --git a/doc/doc_ch/code_and_doc.md b/doc/doc_ch/code_and_doc.md
index b1d8b4b36bd45fc1574b5049ce9af808a00b7574..7a4c64efaff22e99b6d95151ec3675c50a5a0910 100644
--- a/doc/doc_ch/code_and_doc.md
+++ b/doc/doc_ch/code_and_doc.md
@@ -139,7 +139,7 @@ PaddleOCR欢迎大家向repo中积极贡献代码,下面给出一些贡献代
- 在PaddleOCR的 [GitHub首页](https://github.com/PaddlePaddle/PaddleOCR),点击左上角 `Fork` 按钮,在你的个人目录下创建 `远程仓库`,比如`https://github.com/{your_name}/PaddleOCR`。
-![banner](/Users/zhulingfeng01/OCR/PaddleOCR/doc/banner.png)
+![banner](../banner.png)
- 将 `远程仓库` Clone到本地
@@ -230,7 +230,7 @@ pre-commit
重复上述步骤,直到pre-commit格式检查不报错。如下所示。
-[![img](https://github.com/PaddlePaddle/PaddleClas/raw/release/2.3/docs/images/quick_start/community/003_precommit_pass.png)](https://github.com/PaddlePaddle/PaddleClas/blob/release/2.3/docs/images/quick_start/community/003_precommit_pass.png)
+![img](../precommit_pass.png)
使用下面的命令完成提交。
@@ -258,7 +258,7 @@ git push origin new_branch
点击new pull request,选择本地分支和目标分支,如下图所示。在PR的描述说明中,填写该PR所完成的功能。接下来等待review,如果有需要修改的地方,参照上述步骤更新 origin 中的对应分支即可。
-![banner](/Users/zhulingfeng01/OCR/PaddleOCR/doc/pr.png)
+![banner](../pr.png)
#### 3.2.8 签署CLA协议和通过单元测试
diff --git a/doc/doc_ch/datasets.md b/doc/doc_ch/datasets.md
index 6d84dbbe484be1e2b19a4dedced90f61b7085148..d365fd711aff2dffcd30dd06028734cc707d5df0 100644
--- a/doc/doc_ch/datasets.md
+++ b/doc/doc_ch/datasets.md
@@ -49,7 +49,6 @@ https://aistudio.baidu.com/aistudio/datasetdetail/8429
- 每个样本固定10个字符,字符随机截取自语料库中的句子
- 图片分辨率统一为280x32
![](../datasets/ch_doc1.jpg)
- ![](../datasets/ch_doc2.jpg)
![](../datasets/ch_doc3.jpg)
- **下载地址**:https://pan.baidu.com/s/1QkI7kjah8SPHwOQ40rS1Pw (密码:lu7m)
diff --git a/doc/doc_ch/distributed_training.md b/doc/doc_ch/distributed_training.md
index 411ce5ba6aea26755cc65c405be6e0f0d5fd4738..e0251b21ea1157084e4e1b1d77429264d452aa20 100644
--- a/doc/doc_ch/distributed_training.md
+++ b/doc/doc_ch/distributed_training.md
@@ -13,7 +13,7 @@
```shell
python3 -m paddle.distributed.launch \
--log_dir=./log/ \
- --gpus '0,1,2,3,4,5,6,7' \
+ --gpus "0,1,2,3,4,5,6,7" \
tools/train.py \
-c configs/rec/rec_mv3_none_bilstm_ctc.yml
```
diff --git a/doc/doc_ch/models_list.md b/doc/doc_ch/models_list.md
index 8f1a53bccacde8e478e67c7eae5df3c818bb4004..6843ffdc19d5bde205124c30f1d0a5fc2144ce99 100644
--- a/doc/doc_ch/models_list.md
+++ b/doc/doc_ch/models_list.md
@@ -1,4 +1,4 @@
-# OCR模型列表(V2.1,2021年9月6日更新)
+# PP-OCR系列模型列表(V2.1,2021年9月6日更新)
> **说明**
> 1. 2.1版模型相比2.0版模型,2.1的模型在模型精度上做了提升
diff --git a/doc/doc_ch/thirdparty.md b/doc/doc_ch/thirdparty.md
index d01f4b09c01d2c090c829bbb9c58c43557566118..b83b8fee8dbbf867d95c4cd0e087ebfde5f4bfc1 100644
--- a/doc/doc_ch/thirdparty.md
+++ b/doc/doc_ch/thirdparty.md
@@ -12,30 +12,37 @@ PaddleOCR希望可以通过AI的力量助力任何一位有梦想的开发者实
## 1. 社区贡献
-### 1.1 为PaddleOCR新增功能
+### 1.1 基于PaddleOCR的社区贡献
+
+- 【最新】 [FastOCRLabel](https://gitee.com/BaoJianQiang/FastOCRLabel):完整的C#版本标注工具 (@ [包建强](https://gitee.com/BaoJianQiang) )
+
+#### 1.1.1 通用工具
+
+- [DangoOCR离线版](https://github.com/PantsuDango/DangoOCR):通用型桌面级即时翻译工具 (@ [PantsuDango](https://github.com/PantsuDango))
+- [scr2txt](https://github.com/lstwzd/scr2txt):截屏转文字工具 (@ [lstwzd](https://github.com/lstwzd))
+- [AI Studio项目](https://aistudio.baidu.com/aistudio/projectdetail/1054614?channelType=0&channel=0):英文视频自动生成字幕( @ [叶月水狐](https://aistudio.baidu.com/aistudio/personalcenter/thirdview/322052))
+
+#### 1.1.2 垂类场景工具
+
+- [id_card_ocr](https://github.com/baseli/id_card_ocr):身份证复印件识别(@ [baseli](https://github.com/baseli))
+- [Paddle_Table_Image_Reader](https://github.com/thunder95/Paddle_Table_Image_Reader):能看懂表格图片的数据助手(@ [thunder95](https://github.com/thunder95))
+
+#### 1.1.3 前后处理
+
+- [paddleOCRCorrectOutputs](https://github.com/yuranusduke/paddleOCRCorrectOutputs):获取OCR识别结果的key-value(@ [yuranusduke](https://github.com/yuranusduke))
+
+### 1.2 为PaddleOCR新增功能
- 非常感谢 [authorfu](https://github.com/authorfu) 贡献Android([#340](https://github.com/PaddlePaddle/PaddleOCR/pull/340))和[xiadeye](https://github.com/xiadeye) 贡献IOS的demo代码([#325](https://github.com/PaddlePaddle/PaddleOCR/pull/325))
- 非常感谢 [tangmq](https://gitee.com/tangmq) 给PaddleOCR增加Docker化部署服务,支持快速发布可调用的Restful API服务([#507](https://github.com/PaddlePaddle/PaddleOCR/pull/507))。
- 非常感谢 [lijinhan](https://github.com/lijinhan) 给PaddleOCR增加java SpringBoot 调用OCR Hubserving接口完成对OCR服务化部署的使用([#1027](https://github.com/PaddlePaddle/PaddleOCR/pull/1027))。
- 非常感谢 [Evezerest](https://github.com/Evezerest), [ninetailskim](https://github.com/ninetailskim), [edencfc](https://github.com/edencfc), [BeyondYourself](https://github.com/BeyondYourself), [1084667371](https://github.com/1084667371) 贡献了[PPOCRLabel](https://github.com/PaddlePaddle/PaddleOCR/blob/release/2.3/PPOCRLabel/README_ch.md) 的完整代码。
-### 1.2 基于PaddleOCR的社区贡献
-
-- 【最新】完整的C#版本标注工具 [FastOCRLabel](https://gitee.com/BaoJianQiang/FastOCRLabel) (@ [包建强](https://gitee.com/BaoJianQiang) )
-- 通用型桌面级即时翻译工具 [DangoOCR离线版](https://github.com/PantsuDango/DangoOCR) (@ [PantsuDango](https://github.com/PantsuDango))
-- 获取OCR识别结果的key-value [paddleOCRCorrectOutputs](https://github.com/yuranusduke/paddleOCRCorrectOutputs) (@ [yuranusduke](https://github.com/yuranusduke))
-- 截屏转文字工具 [scr2txt](https://github.com/lstwzd/scr2txt) (@ [lstwzd](https://github.com/lstwzd))
-- 身份证复印件识别 [id_card_ocr](https://github.com/baseli/id_card_ocr)(@ [baseli](https://github.com/baseli))
-- 能看懂表格图片的数据助手:[Paddle_Table_Image_Reader](https://github.com/thunder95/Paddle_Table_Image_Reader) (@ [thunder95][https://github.com/thunder95])
-- 英文视频自动生成字幕 [AI Studio项目](https://aistudio.baidu.com/aistudio/projectdetail/1054614?channelType=0&channel=0)( @ [叶月水狐](https://aistudio.baidu.com/aistudio/personalcenter/thirdview/322052))
-
### 1.3 代码与文档优化
-
- 非常感谢 [zhangxin](https://github.com/ZhangXinNan)([Blog](https://blog.csdn.net/sdlypyzq)) 贡献新的可视化方式、添加.gitgnore、处理手动设置PYTHONPATH环境变量的问题([#210](https://github.com/PaddlePaddle/PaddleOCR/pull/210))。
- 非常感谢 [lyl120117](https://github.com/lyl120117) 贡献打印网络结构的代码([#304](https://github.com/PaddlePaddle/PaddleOCR/pull/304))。
- 非常感谢 [BeyondYourself](https://github.com/BeyondYourself) 给PaddleOCR提了很多非常棒的建议,并简化了PaddleOCR的部分代码风格([so many commits)](https://github.com/PaddlePaddle/PaddleOCR/commits?author=BeyondYourself)。
-
- 非常感谢 [Khanh Tran](https://github.com/xxxpsyduck) 和 [Karl Horky](https://github.com/karlhorky) 贡献修改英文文档。
### 1.4 多语言语料
diff --git a/doc/doc_en/datasets_en.md b/doc/doc_en/datasets_en.md
index 61d2033b4fe8f0077ad66fb9ae2cd559ce29fd65..0e6b6f381e9d008add802c5f8a30d5498a4f94b2 100644
--- a/doc/doc_en/datasets_en.md
+++ b/doc/doc_en/datasets_en.md
@@ -50,7 +50,6 @@ https://aistudio.baidu.com/aistudio/datasetdetail/8429
- Each sample is fixed with 10 characters, and the characters are randomly intercepted from the sentences in the corpus
- Image resolution is 280x32
![](../datasets/ch_doc1.jpg)
- ![](../datasets/ch_doc2.jpg)
![](../datasets/ch_doc3.jpg)
- **Download link**:https://pan.baidu.com/s/1QkI7kjah8SPHwOQ40rS1Pw (Password: lu7m)
diff --git a/doc/doc_en/distributed_training.md b/doc/doc_en/distributed_training.md
index 7a8b71ce308837568c84bf56292f78e9979d3907..519a42f0dc4b9bd4fa18f3f65019e4235282df92 100644
--- a/doc/doc_en/distributed_training.md
+++ b/doc/doc_en/distributed_training.md
@@ -13,7 +13,7 @@ Take recognition as an example. After the data is prepared locally, start the tr
```shell
python3 -m paddle.distributed.launch \
--log_dir=./log/ \
- --gpus '0,1,2,3,4,5,6,7' \
+ --gpus "0,1,2,3,4,5,6,7" \
tools/train.py \
-c configs/rec/rec_mv3_none_bilstm_ctc.yml
```
diff --git a/doc/joinus.PNG b/doc/joinus.PNG
index 99964b62d0e8a5867d5eb7a29640f0414c7af3b2..e2dd99383de10b5263c1ec9d255a8a31815b50b6 100644
Binary files a/doc/joinus.PNG and b/doc/joinus.PNG differ
diff --git a/doc/precommit_pass.png b/doc/precommit_pass.png
new file mode 100644
index 0000000000000000000000000000000000000000..067fb75ddb222ab0b9c71a46619c3fe7b239bc26
Binary files /dev/null and b/doc/precommit_pass.png differ
diff --git a/ppocr/data/imaug/copy_paste.py b/ppocr/data/imaug/copy_paste.py
index bbf62e2a3d813671551efa1a76c03754b1b764f5..0b3386c896792bd670cd2bfc757eb3b80f22bac4 100644
--- a/ppocr/data/imaug/copy_paste.py
+++ b/ppocr/data/imaug/copy_paste.py
@@ -32,6 +32,7 @@ class CopyPaste(object):
self.aug = IaaAugment(augmenter_args)
def __call__(self, data):
+ point_num = data['polys'].shape[1]
src_img = data['image']
src_polys = data['polys'].tolist()
src_ignores = data['ignore_tags'].tolist()
@@ -57,6 +58,9 @@ class CopyPaste(object):
src_img, box = self.paste_img(src_img, box_img, src_polys)
if box is not None:
+ box = box.tolist()
+ for _ in range(len(box), point_num):
+ box.append(box[-1])
src_polys.append(box)
src_ignores.append(tag)
src_img = cv2.cvtColor(np.array(src_img), cv2.COLOR_RGB2BGR)
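Pasted boxes come back as 4-point quadrilaterals while the source polygons may carry more points, so the new code pads each pasted box by repeating its last point until the lengths match and everything can be stacked into one array again. A standalone sketch of that padding step:

```python
import numpy as np

def pad_poly(box, point_num):
    """Pad a polygon (list of [x, y] points) to point_num points by repeating the last point."""
    box = list(box)
    for _ in range(len(box), point_num):
        box.append(box[-1])
    return box

polys = np.zeros((2, 6, 2))                   # existing polygons, 6 points each
new_box = [[0, 0], [10, 0], [10, 5], [0, 5]]  # pasted 4-point box
padded = pad_poly(new_box, polys.shape[1])
print(np.array(padded).shape)                 # (6, 2) -> stackable with the others
```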
diff --git a/ppocr/data/simple_dataset.py b/ppocr/data/simple_dataset.py
index 6a33e1342506f26ccaa4a146f3f02fadfbd741a2..ee8571b8c452bbd834fc5dbcf01ce390562163d6 100644
--- a/ppocr/data/simple_dataset.py
+++ b/ppocr/data/simple_dataset.py
@@ -14,6 +14,7 @@
import numpy as np
import os
import random
+import traceback
from paddle.io import Dataset
from .imaug import transform, create_operators
@@ -93,7 +94,8 @@ class SimpleDataSet(Dataset):
img = f.read()
data['image'] = img
data = transform(data, load_data_ops)
- if data is None:
+
+ if data is None or data['polys'].shape[1]!=4:
continue
ext_data.append(data)
return ext_data
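The extra condition drops any external sample whose polygons are not plain 4-point boxes, so the copy-paste source data always matches what the augmentation above expects. Roughly:

```python
import numpy as np

def is_usable_ext_sample(data):
    # Reject samples that failed to load or whose polygons are not quadrilaterals.
    return data is not None and data["polys"].shape[1] == 4

print(is_usable_ext_sample({"polys": np.zeros((3, 4, 2))}))  # True: 3 quads
print(is_usable_ext_sample({"polys": np.zeros((3, 6, 2))}))  # False: 6-point polygons
print(is_usable_ext_sample(None))                            # False: transform failed
```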
@@ -115,10 +117,10 @@ class SimpleDataSet(Dataset):
data['image'] = img
data['ext_data'] = self.get_ext_data()
outs = transform(data, self.ops)
- except Exception as e:
+ except:
self.logger.error(
"When parsing line {}, error happened with msg: {}".format(
- data_line, e))
+ data_line, traceback.format_exc()))
outs = None
if outs is None:
# during evaluation, we should fix the idx to get same results for many times of evaluation.
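Logging `traceback.format_exc()` instead of the bare exception message records the full stack of whichever transform failed, which matters once the handler becomes a bare `except`. The pattern in isolation:

```python
import logging
import traceback

logger = logging.getLogger(__name__)

def parse_line(data_line):
    try:
        return {"label": data_line.split("\t")[1]}  # may raise IndexError, etc.
    except Exception:
        # The full traceback pinpoints the failing transform, not just its message.
        logger.error(
            "When parsing line {}, error happened with msg: {}".format(
                data_line, traceback.format_exc()))
        return None
```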
diff --git a/ppocr/modeling/backbones/det_resnet_vd.py b/ppocr/modeling/backbones/det_resnet_vd.py
index 3bb4a0d50501860d5e9df2971e93fba66c152187..a29cf1b5e1ff56e59984bc91226ef7e6b65d0da1 100644
--- a/ppocr/modeling/backbones/det_resnet_vd.py
+++ b/ppocr/modeling/backbones/det_resnet_vd.py
@@ -25,16 +25,14 @@ __all__ = ["ResNet"]
class ConvBNLayer(nn.Layer):
- def __init__(
- self,
- in_channels,
- out_channels,
- kernel_size,
- stride=1,
- groups=1,
- is_vd_mode=False,
- act=None,
- name=None, ):
+ def __init__(self,
+ in_channels,
+ out_channels,
+ kernel_size,
+ stride=1,
+ groups=1,
+ is_vd_mode=False,
+ act=None):
super(ConvBNLayer, self).__init__()
self.is_vd_mode = is_vd_mode
@@ -47,19 +45,8 @@ class ConvBNLayer(nn.Layer):
stride=stride,
padding=(kernel_size - 1) // 2,
groups=groups,
- weight_attr=ParamAttr(name=name + "_weights"),
bias_attr=False)
- if name == "conv1":
- bn_name = "bn_" + name
- else:
- bn_name = "bn" + name[3:]
- self._batch_norm = nn.BatchNorm(
- out_channels,
- act=act,
- param_attr=ParamAttr(name=bn_name + '_scale'),
- bias_attr=ParamAttr(bn_name + '_offset'),
- moving_mean_name=bn_name + '_mean',
- moving_variance_name=bn_name + '_variance')
+ self._batch_norm = nn.BatchNorm(out_channels, act=act)
def forward(self, inputs):
if self.is_vd_mode:
@@ -75,29 +62,25 @@ class BottleneckBlock(nn.Layer):
out_channels,
stride,
shortcut=True,
- if_first=False,
- name=None):
+ if_first=False):
super(BottleneckBlock, self).__init__()
self.conv0 = ConvBNLayer(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=1,
- act='relu',
- name=name + "_branch2a")
+ act='relu')
self.conv1 = ConvBNLayer(
in_channels=out_channels,
out_channels=out_channels,
kernel_size=3,
stride=stride,
- act='relu',
- name=name + "_branch2b")
+ act='relu')
self.conv2 = ConvBNLayer(
in_channels=out_channels,
out_channels=out_channels * 4,
kernel_size=1,
- act=None,
- name=name + "_branch2c")
+ act=None)
if not shortcut:
self.short = ConvBNLayer(
@@ -105,8 +88,7 @@ class BottleneckBlock(nn.Layer):
out_channels=out_channels * 4,
kernel_size=1,
stride=1,
- is_vd_mode=False if if_first else True,
- name=name + "_branch1")
+ is_vd_mode=False if if_first else True)
self.shortcut = shortcut
@@ -125,13 +107,13 @@ class BottleneckBlock(nn.Layer):
class BasicBlock(nn.Layer):
- def __init__(self,
- in_channels,
- out_channels,
- stride,
- shortcut=True,
- if_first=False,
- name=None):
+ def __init__(
+ self,
+ in_channels,
+ out_channels,
+ stride,
+ shortcut=True,
+ if_first=False, ):
super(BasicBlock, self).__init__()
self.stride = stride
self.conv0 = ConvBNLayer(
@@ -139,14 +121,12 @@ class BasicBlock(nn.Layer):
out_channels=out_channels,
kernel_size=3,
stride=stride,
- act='relu',
- name=name + "_branch2a")
+ act='relu')
self.conv1 = ConvBNLayer(
in_channels=out_channels,
out_channels=out_channels,
kernel_size=3,
- act=None,
- name=name + "_branch2b")
+ act=None)
if not shortcut:
self.short = ConvBNLayer(
@@ -154,8 +134,7 @@ class BasicBlock(nn.Layer):
out_channels=out_channels,
kernel_size=1,
stride=1,
- is_vd_mode=False if if_first else True,
- name=name + "_branch1")
+ is_vd_mode=False if if_first else True)
self.shortcut = shortcut
@@ -201,22 +180,19 @@ class ResNet(nn.Layer):
out_channels=32,
kernel_size=3,
stride=2,
- act='relu',
- name="conv1_1")
+ act='relu')
self.conv1_2 = ConvBNLayer(
in_channels=32,
out_channels=32,
kernel_size=3,
stride=1,
- act='relu',
- name="conv1_2")
+ act='relu')
self.conv1_3 = ConvBNLayer(
in_channels=32,
out_channels=64,
kernel_size=3,
stride=1,
- act='relu',
- name="conv1_3")
+ act='relu')
self.pool2d_max = nn.MaxPool2D(kernel_size=3, stride=2, padding=1)
self.stages = []
@@ -226,13 +202,6 @@ class ResNet(nn.Layer):
block_list = []
shortcut = False
for i in range(depth[block]):
- if layers in [101, 152] and block == 2:
- if i == 0:
- conv_name = "res" + str(block + 2) + "a"
- else:
- conv_name = "res" + str(block + 2) + "b" + str(i)
- else:
- conv_name = "res" + str(block + 2) + chr(97 + i)
bottleneck_block = self.add_sublayer(
'bb_%d_%d' % (block, i),
BottleneckBlock(
@@ -241,8 +210,7 @@ class ResNet(nn.Layer):
out_channels=num_filters[block],
stride=2 if i == 0 and block != 0 else 1,
shortcut=shortcut,
- if_first=block == i == 0,
- name=conv_name))
+ if_first=block == i == 0))
shortcut = True
block_list.append(bottleneck_block)
self.out_channels.append(num_filters[block] * 4)
@@ -252,7 +220,6 @@ class ResNet(nn.Layer):
block_list = []
shortcut = False
for i in range(depth[block]):
- conv_name = "res" + str(block + 2) + chr(97 + i)
basic_block = self.add_sublayer(
'bb_%d_%d' % (block, i),
BasicBlock(
@@ -261,8 +228,7 @@ class ResNet(nn.Layer):
out_channels=num_filters[block],
stride=2 if i == 0 and block != 0 else 1,
shortcut=shortcut,
- if_first=block == i == 0,
- name=conv_name))
+ if_first=block == i == 0))
shortcut = True
block_list.append(basic_block)
self.out_channels.append(num_filters[block])
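The backbone now relies on Paddle's automatically generated parameter names instead of the hand-written `conv1_1`/`res2a_branch2a` scheme, which removes a lot of `ParamAttr` plumbing without changing the network. A small check that auto-generated names stay unique (assumes PaddlePaddle 2.x is installed):

```python
import paddle.nn as nn

# Two BatchNorm layers created without explicit ParamAttr names still get
# distinct, auto-generated parameter names, so explicit naming is not needed
# for correctness when building the dynamic-graph backbone.
bn1, bn2 = nn.BatchNorm(8), nn.BatchNorm(8)
print([p.name for p in bn1.parameters()])
print([p.name for p in bn2.parameters()])  # different names from bn1
```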
diff --git a/ppstructure/README.md b/ppstructure/README.md
index 849c5c5667ff0532dfee35479715880192df0dc5..8994cdd46191a0fd4fb1beba2fcad91542e19b50 100644
--- a/ppstructure/README.md
+++ b/ppstructure/README.md
@@ -153,7 +153,7 @@ wget https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_rec_in
wget https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_table_structure_infer.tar && tar xf en_ppocr_mobile_v2.0_table_structure_infer.tar
cd ..
-python3 predict_system.py --det_model_dir=inference/ch_ppocr_mobile_v2.0_det_infer --rec_model_dir=inference/ch_ppocr_mobile_v2.0_rec_infer --table_model_dir=inference/en_ppocr_mobile_v2.0_table_structure_infer --image_dir=../doc/table/1.png --rec_char_dict_path=../ppocr/utils/ppocr_keys_v1.txt --table_char_dict_path=../ppocr/utils/dict/table_structure_dict.txt --rec_char_type=ch --output=../output/table --vis_font_path=../doc/fonts/simfang.ttf
+python3 predict_system.py --det_model_dir=inference/ch_ppocr_mobile_v2.0_det_infer --rec_model_dir=inference/ch_ppocr_mobile_v2.0_rec_infer --table_model_dir=inference/en_ppocr_mobile_v2.0_table_structure_infer --image_dir=../doc/table/1.png --rec_char_dict_path=../ppocr/utils/ppocr_keys_v1.txt --table_char_dict_path=../ppocr/utils/dict/table_structure_dict.txt --output=../output/table --vis_font_path=../doc/fonts/simfang.ttf
```
After running, each image will have a directory with the same name under the directory specified in the output field. Each table in the picture will be stored as an excel and figure area will be cropped and saved, the excel and image file name will be the coordinates of the table in the image.
diff --git a/ppstructure/README_ch.md b/ppstructure/README_ch.md
index 821a6c3e36361abefa4d754537fdbd694e844efe..607efac1bf6bfaa58f0e96ceef1a0ee344189e9c 100644
--- a/ppstructure/README_ch.md
+++ b/ppstructure/README_ch.md
@@ -1,6 +1,12 @@
[English](README.md) | 简体中文
-# PP-Structure
+## 简介
+PP-Structure是一个可用于复杂文档结构分析和处理的OCR工具包,旨在帮助开发者更好的完成文档理解相关任务。
+
+## 近期更新
+* 2021.12.07 新增VQA任务-SER和RE。
+
+## 特性
PP-Structure是一个可用于复杂文档结构分析和处理的OCR工具包,主要特性如下:
- 支持对图片形式的文档进行版面分析,可以划分**文字、标题、表格、图片以及列表**5类区域(与Layout-Parser联合使用)
@@ -8,181 +14,88 @@ PP-Structure是一个可用于复杂文档结构分析和处理的OCR工具包
- 支持表格区域进行结构化分析,最终结果输出Excel文件
- 支持python whl包和命令行两种方式,简单易用
- 支持版面分析和表格结构化两类任务自定义训练
+- 支持文档视觉问答(Document Visual Question Answering,DOC-VQA)任务-语义实体识别(Semantic Entity Recognition,SER)和关系抽取(Relation Extraction,RE)
-## 1. 效果展示
-
-
-
-
-
-## 2. 安装
-
-### 2.1 安装依赖
-
-- **(1) 安装PaddlePaddle**
-
-```bash
-pip3 install --upgrade pip
-
-# GPU安装
-python3 -m pip install paddlepaddle-gpu==2.1.1 -i https://mirror.baidu.com/pypi/simple
-
-# CPU安装
- python3 -m pip install paddlepaddle==2.1.1 -i https://mirror.baidu.com/pypi/simple
-
-```
-更多需求,请参照[安装文档](https://www.paddlepaddle.org.cn/install/quick)中的说明进行操作。
-
-- **(2) 安装 Layout-Parser**
-
-```bash
-pip3 install -U https://paddleocr.bj.bcebos.com/whl/layoutparser-0.0.0-py3-none-any.whl
-```
-
-### 2.2 安装PaddleOCR(包含PP-OCR和PP-Structure)
-
-- **(1) PIP快速安装PaddleOCR whl包(仅预测)**
-```bash
-pip install "paddleocr>=2.2" # 推荐使用2.2+版本
-```
-
-- **(2) 完整克隆PaddleOCR源码(预测+训练)**
-
-```bash
-【推荐】git clone https://github.com/PaddlePaddle/PaddleOCR
-
-#如果因为网络问题无法pull成功,也可选择使用码云上的托管:
-git clone https://gitee.com/paddlepaddle/PaddleOCR
-
-#注:码云托管代码可能无法实时同步本github项目更新,存在3~5天延时,请优先使用推荐方式。
-```
-
-
-## 3. PP-Structure 快速开始
-
-### 3.1 命令行使用(默认参数,极简)
-
-```bash
-paddleocr --image_dir=../doc/table/1.png --type=structure
-```
-
-### 3.2 Python脚本使用(自定义参数,灵活)
+## 1. 效果展示
-```python
-import os
-import cv2
-from paddleocr import PPStructure,draw_structure_result,save_structure_res
+### 1.1 版面分析和表格识别
-table_engine = PPStructure(show_log=True)
+
-save_folder = './output/table'
-img_path = '../doc/table/1.png'
-img = cv2.imread(img_path)
-result = table_engine(img)
-save_structure_res(result, save_folder,os.path.basename(img_path).split('.')[0])
+### 1.2 VQA
-for line in result:
- line.pop('img')
- print(line)
+* SER
-from PIL import Image
+![](./vqa/images/result_ser/zh_val_0_ser.jpg) | ![](./vqa/images/result_ser/zh_val_42_ser.jpg)
+---|---
-font_path = '../doc/fonts/simfang.ttf' # PaddleOCR下提供字体包
-image = Image.open(img_path).convert('RGB')
-im_show = draw_structure_result(image, result,font_path=font_path)
-im_show = Image.fromarray(im_show)
-im_show.save('result.jpg')
-```
+图中不同颜色的框表示不同的类别,对于XFUN数据集,有`QUESTION`, `ANSWER`, `HEADER` 3种类别
-### 3.3 返回结果说明
-PP-Structure的返回结果为一个dict组成的list,示例如下
+* 深紫色:HEADER
+* 浅紫色:QUESTION
+* 军绿色:ANSWER
-```shell
-[
- { 'type': 'Text',
- 'bbox': [34, 432, 345, 462],
- 'res': ([[36.0, 437.0, 341.0, 437.0, 341.0, 446.0, 36.0, 447.0], [41.0, 454.0, 125.0, 453.0, 125.0, 459.0, 41.0, 460.0]],
- [('Tigure-6. The performance of CNN and IPT models using difforen', 0.90060663), ('Tent ', 0.465441)])
- }
-]
-```
-dict 里各个字段说明如下
+在OCR检测框的左上方也标出了对应的类别和OCR识别结果。
-| 字段 | 说明 |
-| --------------- | -------------|
-|type|图片区域的类型|
-|bbox|图片区域的在原图的坐标,分别[左上角x,左上角y,右下角x,右下角y]|
-|res|图片区域的OCR或表格识别结果。<br> 表格: 表格的HTML字符串; <br> OCR: 一个包含各个单行文字的检测坐标和识别结果的元组|
+* RE
+![](./vqa/images/result_re/zh_val_21_re.jpg) | ![](./vqa/images/result_re/zh_val_40_re.jpg)
+---|---
-### 3.4 参数说明
-| 字段 | 说明 | 默认值 |
-| --------------- | ---------------------------------------- | ------------------------------------------- |
-| output | excel和识别结果保存的地址 | ./output/table |
-| table_max_len | 表格结构模型预测时,图像的长边resize尺度 | 488 |
-| table_model_dir | 表格结构模型 inference 模型地址 | None |
-| table_char_type | 表格结构模型所用字典地址 | ../ppocr/utils/dict/table_structure_dict.tx |
+图中红色框表示问题,蓝色框表示答案,问题和答案之间使用绿色线连接。在OCR检测框的左上方也标出了对应的类别和OCR识别结果。
-大部分参数和paddleocr whl包保持一致,见 [whl包文档](../doc/doc_ch/whl.md)
+## 2. 快速体验
-运行完成后,每张图片会在`output`字段指定的目录下有一个同名目录,图片里的每个表格会存储为一个excel,图片区域会被裁剪之后保存下来,excel文件和图片名名为表格在图片里的坐标。
+代码体验:从 [快速安装](./docs/quickstart.md) 开始
+## 3. PP-Structure Pipeline介绍
-## 4. PP-Structure Pipeline介绍
+### 3.1 版面分析+表格识别
![pipeline](../doc/table/pipeline.jpg)
在PP-Structure中,图片会先经由Layout-Parser进行版面分析,在版面分析中,会对图片里的区域进行分类,包括**文字、标题、图片、列表和表格**5类。对于前4类区域,直接使用PP-OCR完成对应区域文字检测与识别。对于表格类区域,经过表格结构化处理后,表格图片转换为相同表格样式的Excel文件。
-### 4.1 版面分析
+#### 3.1.1 版面分析
版面分析对文档数据进行区域分类,其中包括版面分析工具的Python脚本使用、提取指定类别检测框、性能指标以及自定义训练版面分析模型,详细内容可以参考[文档](layout/README_ch.md)。
-### 4.2 表格识别
+#### 3.1.2 表格识别
表格识别将表格图片转换为excel文档,其中包含对于表格文本的检测和识别以及对于表格结构和单元格坐标的预测,详细说明参考[文档](table/README_ch.md)
-## 5. 预测引擎推理(与whl包效果相同)
-使用如下命令即可完成预测引擎的推理
+### 3.2 VQA
-```python
-cd ppstructure
+coming soon
-# 下载模型
-mkdir inference && cd inference
-# 下载超轻量级中文OCR模型的检测模型并解压
-wget https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_det_infer.tar && tar xf ch_ppocr_mobile_v2.0_det_infer.tar
-# 下载超轻量级中文OCR模型的识别模型并解压
-wget https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_rec_infer.tar && tar xf ch_ppocr_mobile_v2.0_rec_infer.tar
-# 下载超轻量级英文表格英寸模型并解压
-wget https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_table_structure_infer.tar && tar xf en_ppocr_mobile_v2.0_table_structure_infer.tar
-cd ..
+## 4. 模型库
-python3 predict_system.py --det_model_dir=inference/ch_ppocr_mobile_v2.0_det_infer --rec_model_dir=inference/ch_ppocr_mobile_v2.0_rec_infer --table_model_dir=inference/en_ppocr_mobile_v2.0_table_structure_infer --image_dir=../doc/table/1.png --rec_char_dict_path=../ppocr/utils/ppocr_keys_v1.txt --table_char_dict_path=../ppocr/utils/dict/table_structure_dict.txt --rec_char_type=ch --output=../output/table --vis_font_path=../doc/fonts/simfang.ttf
-```
-运行完成后,每张图片会在`output`字段指定的目录下有一个同名目录,图片里的每个表格会存储为一个excel,图片区域会被裁剪之后保存下来,excel文件和图片名名为表格在图片里的坐标。
+PP-Structure系列模型列表(更新中)
-**Model List**
-
-LayoutParser 模型
+* LayoutParser 模型
|模型名称|模型简介|下载地址|
| --- | --- | --- |
| ppyolov2_r50vd_dcn_365e_publaynet | PubLayNet 数据集训练的版面分析模型,可以划分**文字、标题、表格、图片以及列表**5类区域 | [PubLayNet](https://paddle-model-ecology.bj.bcebos.com/model/layout-parser/ppyolov2_r50vd_dcn_365e_publaynet.tar) |
-| ppyolov2_r50vd_dcn_365e_tableBank_word | TableBank Word 数据集训练的版面分析模型,只能检测表格 | [TableBank Word](https://paddle-model-ecology.bj.bcebos.com/model/layout-parser/ppyolov2_r50vd_dcn_365e_tableBank_word.tar) |
-| ppyolov2_r50vd_dcn_365e_tableBank_latex | TableBank Latex 数据集训练的版面分析模型,只能检测表格 | [TableBank Latex](https://paddle-model-ecology.bj.bcebos.com/model/layout-parser/ppyolov2_r50vd_dcn_365e_tableBank_latex.tar) |
-OCR和表格识别模型
-|模型名称|模型简介|推理模型大小|下载地址|
+* OCR和表格识别模型
+
+|模型名称|模型简介|模型大小|下载地址|
| --- | --- | --- | --- |
|ch_ppocr_mobile_slim_v2.0_det|slim裁剪版超轻量模型,支持中英文、多语种文本检测|2.6M|[推理模型](https://paddleocr.bj.bcebos.com/dygraph_v2.0/slim/ch_ppocr_mobile_v2.0_det_prune_infer.tar) / [训练模型](https://paddleocr.bj.bcebos.com/dygraph_v2.0/slim/ch_ppocr_mobile_v2.0_det_prune_infer.tar) |
|ch_ppocr_mobile_slim_v2.0_rec|slim裁剪量化版超轻量模型,支持中英文、数字识别|6M|[推理模型](https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_rec_slim_infer.tar) / [训练模型](https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_rec_slim_train.tar) |
-|en_ppocr_mobile_v2.0_table_det|PubLayNet数据集训练的英文表格场景的文字检测|4.7M|[推理模型](https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_table_det_infer.tar) / [训练模型](https://paddleocr.bj.bcebos.com/dygraph_v2.1/table/en_ppocr_mobile_v2.0_table_det_train.tar) |
-|en_ppocr_mobile_v2.0_table_rec|PubLayNet数据集训练的英文表格场景的文字识别|6.9M|[推理模型](https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_table_rec_infer.tar) / [训练模型](https://paddleocr.bj.bcebos.com/dygraph_v2.1/table/en_ppocr_mobile_v2.0_table_rec_train.tar) |
|en_ppocr_mobile_v2.0_table_structure|PubLayNet数据集训练的英文表格场景的表格结构预测|18.6M|[推理模型](https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_table_structure_infer.tar) / [训练模型](https://paddleocr.bj.bcebos.com/dygraph_v2.1/table/en_ppocr_mobile_v2.0_table_structure_train.tar) |
-如需要使用其他模型,可以在 [model_list](../doc/doc_ch/models_list.md) 下载模型或者使用自己训练好的模型配置到`det_model_dir`,`rec_model_dir`,`table_model_dir`三个字段即可。
+* VQA模型
+
+|模型名称|模型简介|模型大小|下载地址|
+| --- | --- | --- | --- |
+|PP-Layout_v1.0_ser_pretrained|基于LayoutXLM在xfun中文数据集上训练的SER模型|1.4G|[推理模型 coming soon]() / [训练模型](https://paddleocr.bj.bcebos.com/pplayout/PP-Layout_v1.0_ser_pretrained.tar) |
+|PP-Layout_v1.0_re_pretrained|基于LayoutXLM在xfun中文数据集上训练的RE模型|1.4G|[推理模型 coming soon]() / [训练模型](https://paddleocr.bj.bcebos.com/pplayout/PP-Layout_v1.0_re_pretrained.tar) |
+
+
+更多模型下载,可以参考 [模型库](./docs/model_list.md)
diff --git a/ppstructure/docs/installation.md b/ppstructure/docs/installation.md
new file mode 100644
index 0000000000000000000000000000000000000000..30c25d5dc92f6ccdb0d93dafe9707f30eca0c0a9
--- /dev/null
+++ b/ppstructure/docs/installation.md
@@ -0,0 +1,28 @@
+# 快速安装
+
+## 1. PaddlePaddle 和 PaddleOCR
+
+可参考[PaddleOCR安装文档](../../doc/doc_ch/installation.md)
+
+## 2. 安装其他依赖
+
+### 2.1 版面分析所需 Layout-Parser
+
+Layout-Parser 可通过如下命令安装
+
+```bash
+pip3 install -U https://paddleocr.bj.bcebos.com/whl/layoutparser-0.0.0-py3-none-any.whl
+```
+### 2.2 VQA所需依赖
+* paddleocr
+
+```bash
+pip3 install paddleocr
+```
+
+* PaddleNLP
+```bash
+git clone https://github.com/PaddlePaddle/PaddleNLP -b develop
+cd PaddleNLP
+pip3 install -e .
+```
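安装完成后,可以用下面的方式做一个简单的自检(假设上述各项安装均已成功):

```python
# 验证版面分析与VQA所需依赖是否可以正常导入
import paddle
import paddlenlp
import layoutparser

print(paddle.__version__, paddlenlp.__version__, layoutparser.__version__)
```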
diff --git a/ppstructure/docs/model_list.md b/ppstructure/docs/model_list.md
new file mode 100644
index 0000000000000000000000000000000000000000..835d39a735462edb0d9f51493ec0529248aeadbf
--- /dev/null
+++ b/ppstructure/docs/model_list.md
@@ -0,0 +1,28 @@
+# Model List
+
+## 1. LayoutParser 模型
+
+|模型名称|模型简介|下载地址|
+| --- | --- | --- |
+| ppyolov2_r50vd_dcn_365e_publaynet | PubLayNet 数据集训练的版面分析模型,可以划分**文字、标题、表格、图片以及列表**5类区域 | [PubLayNet](https://paddle-model-ecology.bj.bcebos.com/model/layout-parser/ppyolov2_r50vd_dcn_365e_publaynet.tar) |
+| ppyolov2_r50vd_dcn_365e_tableBank_word | TableBank Word 数据集训练的版面分析模型,只能检测表格 | [TableBank Word](https://paddle-model-ecology.bj.bcebos.com/model/layout-parser/ppyolov2_r50vd_dcn_365e_tableBank_word.tar) |
+| ppyolov2_r50vd_dcn_365e_tableBank_latex | TableBank Latex 数据集训练的版面分析模型,只能检测表格 | [TableBank Latex](https://paddle-model-ecology.bj.bcebos.com/model/layout-parser/ppyolov2_r50vd_dcn_365e_tableBank_latex.tar) |
+
+## 2. OCR和表格识别模型
+
+|模型名称|模型简介|推理模型大小|下载地址|
+| --- | --- | --- | --- |
+|ch_ppocr_mobile_slim_v2.0_det|slim裁剪版超轻量模型,支持中英文、多语种文本检测|2.6M|[推理模型](https://paddleocr.bj.bcebos.com/dygraph_v2.0/slim/ch_ppocr_mobile_v2.0_det_prune_infer.tar) / [训练模型](https://paddleocr.bj.bcebos.com/dygraph_v2.0/slim/ch_ppocr_mobile_v2.0_det_prune_infer.tar) |
+|ch_ppocr_mobile_slim_v2.0_rec|slim裁剪量化版超轻量模型,支持中英文、数字识别|6M|[推理模型](https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_rec_slim_infer.tar) / [训练模型](https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_rec_slim_train.tar) |
+|en_ppocr_mobile_v2.0_table_det|PubLayNet数据集训练的英文表格场景的文字检测|4.7M|[推理模型](https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_table_det_infer.tar) / [训练模型](https://paddleocr.bj.bcebos.com/dygraph_v2.1/table/en_ppocr_mobile_v2.0_table_det_train.tar) |
+|en_ppocr_mobile_v2.0_table_rec|PubLayNet数据集训练的英文表格场景的文字识别|6.9M|[推理模型](https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_table_rec_infer.tar) / [训练模型](https://paddleocr.bj.bcebos.com/dygraph_v2.1/table/en_ppocr_mobile_v2.0_table_rec_train.tar) |
+|en_ppocr_mobile_v2.0_table_structure|PubLayNet数据集训练的英文表格场景的表格结构预测|18.6M|[推理模型](https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_table_structure_infer.tar) / [训练模型](https://paddleocr.bj.bcebos.com/dygraph_v2.1/table/en_ppocr_mobile_v2.0_table_structure_train.tar) |
+
+如需要使用其他OCR模型,可以在 [model_list](../../doc/doc_ch/models_list.md) 下载模型或者使用自己训练好的模型配置到`det_model_dir`,`rec_model_dir`两个字段即可。
+
+## 3. VQA模型
+
+|模型名称|模型简介|推理模型大小|下载地址|
+| --- | --- | --- | --- |
+|PP-Layout_v1.0_ser_pretrained|基于LayoutXLM在xfun中文数据集上训练的SER模型|1.4G|[推理模型 coming soon]() / [训练模型](https://paddleocr.bj.bcebos.com/pplayout/PP-Layout_v1.0_ser_pretrained.tar) |
+|PP-Layout_v1.0_re_pretrained|基于LayoutXLM在xfun中文数据集上训练的RE模型|1.4G|[推理模型 coming soon]() / [训练模型](https://paddleocr.bj.bcebos.com/pplayout/PP-Layout_v1.0_re_pretrained.tar) |
diff --git a/ppstructure/docs/quickstart.md b/ppstructure/docs/quickstart.md
new file mode 100644
index 0000000000000000000000000000000000000000..446c577ec39cf24dd4b8699558c633a1308fa444
--- /dev/null
+++ b/ppstructure/docs/quickstart.md
@@ -0,0 +1,171 @@
+# PP-Structure 快速开始
+
+* [1. 安装依赖包](#1)
+* [2. 便捷使用](#2)
+ + [2.1 命令行使用](#21)
+ + [2.2 Python脚本使用](#22)
+ + [2.3 返回结果说明](#23)
+ + [2.4 参数说明](#24)
+* [3. Python脚本使用](#3)
+
+
+
+
+## 1. 安装依赖包
+
+```bash
+pip install "paddleocr>=2.3.0.2" # 推荐使用2.3.0.2+版本
+pip3 install -U https://paddleocr.bj.bcebos.com/whl/layoutparser-0.0.0-py3-none-any.whl
+
+# 安装 PaddleNLP
+git clone https://github.com/PaddlePaddle/PaddleNLP -b develop
+cd PaddleNLP
+pip3 install -e .
+
+```
+
+
+
+## 2. 便捷使用
+
+
+
+### 2.1 命令行使用
+
+* 版面分析+表格识别
+```bash
+paddleocr --image_dir=../doc/table/1.png --type=structure
+```
+
+* VQA
+
+coming soon
+
+
+
+### 2.2 Python脚本使用
+
+* 版面分析+表格识别
+```python
+import os
+import cv2
+from paddleocr import PPStructure,draw_structure_result,save_structure_res
+
+table_engine = PPStructure(show_log=True)
+
+save_folder = './output/table'
+img_path = '../doc/table/1.png'
+img = cv2.imread(img_path)
+result = table_engine(img)
+save_structure_res(result, save_folder,os.path.basename(img_path).split('.')[0])
+
+for line in result:
+ line.pop('img')
+ print(line)
+
+from PIL import Image
+
+font_path = '../doc/fonts/simfang.ttf' # PaddleOCR下提供字体包
+image = Image.open(img_path).convert('RGB')
+im_show = draw_structure_result(image, result,font_path=font_path)
+im_show = Image.fromarray(im_show)
+im_show.save('result.jpg')
+```
+
+* VQA
+
+comming soon
+
+
+
+### 2.3 返回结果说明
+PP-Structure的返回结果为一个dict组成的list,示例如下
+
+* 版面分析+表格识别
+```shell
+[
+ { 'type': 'Text',
+ 'bbox': [34, 432, 345, 462],
+ 'res': ([[36.0, 437.0, 341.0, 437.0, 341.0, 446.0, 36.0, 447.0], [41.0, 454.0, 125.0, 453.0, 125.0, 459.0, 41.0, 460.0]],
+ [('Tigure-6. The performance of CNN and IPT models using difforen', 0.90060663), ('Tent ', 0.465441)])
+ }
+]
+```
+dict 里各个字段说明如下
+
+| 字段 | 说明 |
+| --------------- | -------------|
+|type|图片区域的类型|
+|bbox|图片区域在原图中的坐标,分别为[左上角x,左上角y,右下角x,右下角y]|
+|res|图片区域的OCR或表格识别结果。<br> 表格: 表格的HTML字符串; <br> OCR: 一个包含各个单行文字的检测坐标和识别结果的元组|
+
+* VQA
+
+comming soon
+
+
+
+### 2.4 参数说明
+
+| 字段 | 说明 | 默认值 |
+| --------------- | ---------------------------------------- | ------------------------------------------- |
+| output | excel和识别结果保存的地址 | ./output/table |
+| table_max_len | 表格结构模型预测时,图像的长边resize尺度 | 488 |
+| table_model_dir | 表格结构模型 inference 模型地址 | None |
+| table_char_type | 表格结构模型所用字典地址 | ../ppocr/utils/dict/table_structure_dict.txt |
+| model_name_or_path | VQA SER模型地址 | None |
+| max_seq_length | VQA SER模型最大支持token长度 | 512 |
+| label_map_path | VQA SER 标签文件地址 | ./vqa/labels/labels_ser.txt |
+| mode | pipeline预测模式,structure: 版面分析+表格识别; vqa: ser文档信息抽取 | structure |
+
+大部分参数和paddleocr whl包保持一致,见 [whl包文档](../../doc/doc_ch/whl.md)
+
+运行完成后,每张图片会在`output`字段指定的目录下有一个同名目录,图片里的每个表格会存储为一个excel,图片区域会被裁剪之后保存下来,excel文件和图片以表格在图片里的坐标命名。
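上表中的字段与 paddleocr whl 包的命令行参数保持一致,因此也可以在 Python 中以关键字参数的形式传给 `PPStructure`。下面是一个示意写法(路径仅为占位符,并假设关键字参数名与上述命令行字段一致):

```python
from paddleocr import PPStructure

# 示意:以关键字参数方式自定义配置,参数名假设与上表命令行字段一致,路径为占位符
table_engine = PPStructure(
    table_model_dir="./inference/en_ppocr_mobile_v2.0_table_structure_infer",
    table_max_len=488,
    show_log=False)
```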
+
+
+
+## 3. Python脚本使用
+
+* 版面分析+表格识别
+
+```bash
+cd ppstructure
+
+# 下载模型
+mkdir inference && cd inference
+# 下载超轻量级中文OCR模型的检测模型并解压
+wget https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_det_infer.tar && tar xf ch_ppocr_mobile_v2.0_det_infer.tar
+# 下载超轻量级中文OCR模型的识别模型并解压
+wget https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_rec_infer.tar && tar xf ch_ppocr_mobile_v2.0_rec_infer.tar
+# 下载超轻量级英文表格英寸模型并解压
+wget https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_table_structure_infer.tar && tar xf en_ppocr_mobile_v2.0_table_structure_infer.tar
+cd ..
+
+python3 predict_system.py --det_model_dir=inference/ch_ppocr_mobile_v2.0_det_infer \
+ --rec_model_dir=inference/ch_ppocr_mobile_v2.0_rec_infer \
+ --table_model_dir=inference/en_ppocr_mobile_v2.0_table_structure_infer \
+ --image_dir=../doc/table/1.png \
+ --rec_char_dict_path=../ppocr/utils/ppocr_keys_v1.txt \
+ --table_char_dict_path=../ppocr/utils/dict/table_structure_dict.txt \
+ --output=../output/table \
+ --vis_font_path=../doc/fonts/simfang.ttf
+```
+运行完成后,每张图片会在`output`字段指定的目录下的`structure`目录下有一个同名目录,图片里的每个表格会存储为一个excel,图片区域会被裁剪之后保存下来,excel文件和图片以表格在图片里的坐标命名。
+
+* VQA
+
+```bash
+cd ppstructure
+
+# 下载模型
+mkdir inference && cd inference
+# 下载SER xfun 模型并解压
+wget https://paddleocr.bj.bcebos.com/pplayout/PP-Layout_v1.0_ser_pretrained.tar && tar xf PP-Layout_v1.0_ser_pretrained.tar
+cd ..
+
+python3 predict_system.py --model_name_or_path=inference/PP-Layout_v1.0_ser_pretrained/ \
+ --mode=vqa \
+ --image_dir=vqa/images/input/zh_val_0.jpg \
+ --vis_font_path=../doc/fonts/simfang.ttf
+```
+运行完成后,每张图片会在`output`字段指定的目录下的`vqa`目录下存放可视化之后的图片,图片名和输入图片名一致。
diff --git a/ppstructure/predict_system.py b/ppstructure/predict_system.py
index b2de3d4de80b39f046cf6cbc8a9ebbc52bf69334..e87499ccc410ae67a170f63301e5a99ef948b161 100644
--- a/ppstructure/predict_system.py
+++ b/ppstructure/predict_system.py
@@ -30,6 +30,7 @@ from ppocr.utils.utility import get_image_file_list, check_and_read_gif
from ppocr.utils.logging import get_logger
from tools.infer.predict_system import TextSystem
from ppstructure.table.predict_table import TableSystem, to_excel
+from ppstructure.vqa.infer_ser_e2e import SerPredictor, draw_ser_results
from ppstructure.utility import parse_args, draw_structure_result
logger = get_logger()
@@ -37,53 +38,75 @@ logger = get_logger()
class OCRSystem(object):
def __init__(self, args):
- import layoutparser as lp
- # args.det_limit_type = 'resize_long'
- args.drop_score = 0
- if not args.show_log:
- logger.setLevel(logging.INFO)
- self.text_system = TextSystem(args)
- self.table_system = TableSystem(args, self.text_system.text_detector, self.text_system.text_recognizer)
-
- config_path = None
- model_path = None
- if os.path.isdir(args.layout_path_model):
- model_path = args.layout_path_model
- else:
- config_path = args.layout_path_model
- self.table_layout = lp.PaddleDetectionLayoutModel(config_path=config_path,
- model_path=model_path,
- threshold=0.5, enable_mkldnn=args.enable_mkldnn,
- enforce_cpu=not args.use_gpu, thread_num=args.cpu_threads)
- self.use_angle_cls = args.use_angle_cls
- self.drop_score = args.drop_score
+ self.mode = args.mode
+ if self.mode == 'structure':
+ import layoutparser as lp
+ # args.det_limit_type = 'resize_long'
+ args.drop_score = 0
+ if not args.show_log:
+ logger.setLevel(logging.INFO)
+ self.text_system = TextSystem(args)
+ self.table_system = TableSystem(args,
+ self.text_system.text_detector,
+ self.text_system.text_recognizer)
+
+ config_path = None
+ model_path = None
+ if os.path.isdir(args.layout_path_model):
+ model_path = args.layout_path_model
+ else:
+ config_path = args.layout_path_model
+ self.table_layout = lp.PaddleDetectionLayoutModel(
+ config_path=config_path,
+ model_path=model_path,
+ threshold=0.5,
+ enable_mkldnn=args.enable_mkldnn,
+ enforce_cpu=not args.use_gpu,
+ thread_num=args.cpu_threads)
+ self.use_angle_cls = args.use_angle_cls
+ self.drop_score = args.drop_score
+ elif self.mode == 'vqa':
+ self.vqa_engine = SerPredictor(args)
def __call__(self, img):
- ori_im = img.copy()
- layout_res = self.table_layout.detect(img[..., ::-1])
- res_list = []
- for region in layout_res:
- x1, y1, x2, y2 = region.coordinates
- x1, y1, x2, y2 = int(x1), int(y1), int(x2), int(y2)
- roi_img = ori_im[y1:y2, x1:x2, :]
- if region.type == 'Table':
- res = self.table_system(roi_img)
- else:
- filter_boxes, filter_rec_res = self.text_system(roi_img)
- filter_boxes = [x + [x1, y1] for x in filter_boxes]
- filter_boxes = [x.reshape(-1).tolist() for x in filter_boxes]
- # remove style char
- style_token = ['<strike>', '<strike>', '<sup>', '</sub>', '<b>', '</b>', '<sub>', '</sup>',
- '<overline>', '</overline>', '<underline>', '</underline>', '<i>', '</i>']
- filter_rec_res_tmp = []
- for rec_res in filter_rec_res:
- rec_str, rec_conf = rec_res
- for token in style_token:
- if token in rec_str:
- rec_str = rec_str.replace(token, '')
- filter_rec_res_tmp.append((rec_str, rec_conf))
- res = (filter_boxes, filter_rec_res_tmp)
- res_list.append({'type': region.type, 'bbox': [x1, y1, x2, y2], 'img': roi_img, 'res': res})
+ if self.mode == 'structure':
+ ori_im = img.copy()
+ layout_res = self.table_layout.detect(img[..., ::-1])
+ res_list = []
+ for region in layout_res:
+ x1, y1, x2, y2 = region.coordinates
+ x1, y1, x2, y2 = int(x1), int(y1), int(x2), int(y2)
+ roi_img = ori_im[y1:y2, x1:x2, :]
+ if region.type == 'Table':
+ res = self.table_system(roi_img)
+ else:
+ filter_boxes, filter_rec_res = self.text_system(roi_img)
+ filter_boxes = [x + [x1, y1] for x in filter_boxes]
+ filter_boxes = [
+ x.reshape(-1).tolist() for x in filter_boxes
+ ]
+ # remove style char
+ style_token = [
+ '<strike>', '<strike>', '<sup>', '</sub>', '<b>',
+ '</b>', '<sub>', '</sup>', '<overline>', '</overline>',
+ '<underline>', '</underline>', '<i>', '</i>'
+ ]
+ filter_rec_res_tmp = []
+ for rec_res in filter_rec_res:
+ rec_str, rec_conf = rec_res
+ for token in style_token:
+ if token in rec_str:
+ rec_str = rec_str.replace(token, '')
+ filter_rec_res_tmp.append((rec_str, rec_conf))
+ res = (filter_boxes, filter_rec_res_tmp)
+ res_list.append({
+ 'type': region.type,
+ 'bbox': [x1, y1, x2, y2],
+ 'img': roi_img,
+ 'res': res
+ })
+ elif self.mode == 'vqa':
+ res_list, _ = self.vqa_engine(img)
return res_list
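The system now constructs only the engine the selected mode needs and dispatches on `self.mode` per image. Stripped to its control flow, the change looks roughly like this (a simplified sketch, not the full implementation):

```python
class PipelineSketch:
    """Build one engine per mode and route every image through it."""

    def __init__(self, mode, build_structure_engine, build_vqa_engine):
        self.mode = mode
        if mode == "structure":
            self.engine = build_structure_engine()  # layout + table + OCR
        elif mode == "vqa":
            self.engine = build_vqa_engine()         # SER predictor only

    def __call__(self, img):
        if self.mode == "structure":
            return self.engine(img)                  # list of region dicts
        elif self.mode == "vqa":
            results, _ = self.engine(img)            # SerPredictor returns a 2-tuple
            return results
```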
@@ -91,29 +114,35 @@ def save_structure_res(res, save_folder, img_name):
excel_save_folder = os.path.join(save_folder, img_name)
os.makedirs(excel_save_folder, exist_ok=True)
# save res
- with open(os.path.join(excel_save_folder, 'res.txt'), 'w', encoding='utf8') as f:
+ with open(
+ os.path.join(excel_save_folder, 'res.txt'), 'w',
+ encoding='utf8') as f:
for region in res:
if region['type'] == 'Table':
- excel_path = os.path.join(excel_save_folder, '{}.xlsx'.format(region['bbox']))
+ excel_path = os.path.join(excel_save_folder,
+ '{}.xlsx'.format(region['bbox']))
to_excel(region['res'], excel_path)
if region['type'] == 'Figure':
roi_img = region['img']
- img_path = os.path.join(excel_save_folder, '{}.jpg'.format(region['bbox']))
+ img_path = os.path.join(excel_save_folder,
+ '{}.jpg'.format(region['bbox']))
cv2.imwrite(img_path, roi_img)
else:
for box, rec_res in zip(region['res'][0], region['res'][1]):
- f.write('{}\t{}\n'.format(np.array(box).reshape(-1).tolist(), rec_res))
+ f.write('{}\t{}\n'.format(
+ np.array(box).reshape(-1).tolist(), rec_res))
def main(args):
image_file_list = get_image_file_list(args.image_dir)
image_file_list = image_file_list
image_file_list = image_file_list[args.process_id::args.total_process_num]
- save_folder = args.output
- os.makedirs(save_folder, exist_ok=True)
structure_sys = OCRSystem(args)
img_num = len(image_file_list)
+ save_folder = os.path.join(args.output, structure_sys.mode)
+ os.makedirs(save_folder, exist_ok=True)
+
for i, image_file in enumerate(image_file_list):
logger.info("[{}/{}] {}".format(i, img_num, image_file))
img, flag = check_and_read_gif(image_file)
@@ -126,10 +155,16 @@ def main(args):
continue
starttime = time.time()
res = structure_sys(img)
- save_structure_res(res, save_folder, img_name)
- draw_img = draw_structure_result(img, res, args.vis_font_path)
- cv2.imwrite(os.path.join(save_folder, img_name, 'show.jpg'), draw_img)
- logger.info('result save to {}'.format(os.path.join(save_folder, img_name)))
+
+ if structure_sys.mode == 'structure':
+ save_structure_res(res, save_folder, img_name)
+ draw_img = draw_structure_result(img, res, args.vis_font_path)
+ img_save_path = os.path.join(save_folder, img_name, 'show.jpg')
+ elif structure_sys.mode == 'vqa':
+ draw_img = draw_ser_results(img, res, args.vis_font_path)
+ img_save_path = os.path.join(save_folder, img_name + '.jpg')
+ cv2.imwrite(img_save_path, draw_img)
+ logger.info('result save to {}'.format(img_save_path))
elapse = time.time() - starttime
logger.info("Predict time : {:.3f}s".format(elapse))
diff --git a/ppstructure/table/README.md b/ppstructure/table/README.md
index 67c4d8e26d5c615f4a930752005420ba1abcc834..30a11a20e5de90500d1408f671ba914f336a0b43 100644
--- a/ppstructure/table/README.md
+++ b/ppstructure/table/README.md
@@ -20,9 +20,9 @@ We evaluated the algorithm on the PubTabNet[1] eval dataset, and the
|Method|[TEDS(Tree-Edit-Distance-based Similarity)](https://github.com/ibm-aur-nlp/PubTabNet/tree/master/src)|
-| --- | --- |
-| EDD[2] | 88.3 |
-| Ours | 93.32 |
+| --- | --- |
+| EDD[2] | 88.3 |
+| Ours | 93.32 |
## 3. How to use
@@ -41,7 +41,7 @@ wget https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_tab
wget https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_table_structure_infer.tar && tar xf en_ppocr_mobile_v2.0_table_structure_infer.tar
cd ..
# run
-python3 table/predict_table.py --det_model_dir=inference/en_ppocr_mobile_v2.0_table_det_infer --rec_model_dir=inference/en_ppocr_mobile_v2.0_table_rec_infer --table_model_dir=inference/en_ppocr_mobile_v2.0_table_structure_infer --image_dir=../doc/table/table.jpg --rec_char_dict_path=../ppocr/utils/dict/table_dict.txt --table_char_dict_path=../ppocr/utils/dict/table_structure_dict.txt --rec_char_type=EN --det_limit_side_len=736 --det_limit_type=min --output ../output/table
+python3 table/predict_table.py --det_model_dir=inference/en_ppocr_mobile_v2.0_table_det_infer --rec_model_dir=inference/en_ppocr_mobile_v2.0_table_rec_infer --table_model_dir=inference/en_ppocr_mobile_v2.0_table_structure_infer --image_dir=../doc/table/table.jpg --rec_char_dict_path=../ppocr/utils/dict/table_dict.txt --table_char_dict_path=../ppocr/utils/dict/table_structure_dict.txt --det_limit_side_len=736 --det_limit_type=min --output ../output/table
```
Note: The above model is trained on the PubLayNet dataset and only supports English scanning scenarios. If you need to identify other scenarios, you need to train the model yourself and replace the three fields `det_model_dir`, `rec_model_dir`, `table_model_dir`.
@@ -82,8 +82,8 @@ python3 tools/train.py -c configs/table/table_mv3.yml -o Global.checkpoints=./yo
The table uses [TEDS(Tree-Edit-Distance-based Similarity)](https://github.com/ibm-aur-nlp/PubTabNet/tree/master/src) as the evaluation metric of the model. Before the model evaluation, the three models in the pipeline need to be exported as inference models (we have provided them), and the gt for evaluation needs to be prepared. Examples of gt are as follows:
```json
{"PMC4289340_004_00.png": [
- ["", "", "", "", "", "", " | ", "", " | ", "", " | ", "
", "", "", "", "", " | ", "", " | ", "", " | ", "
", "", "
", "", ""],
- [[1, 4, 29, 13], [137, 4, 161, 13], [215, 4, 236, 13], [1, 17, 30, 27], [137, 17, 147, 27], [215, 17, 225, 27]],
+ ["", "", "", "", "", "", " | ", "", " | ", "", " | ", "
", "", "", "", "", " | ", "", " | ", "", " | ", "
", "", "
", "", ""],
+ [[1, 4, 29, 13], [137, 4, 161, 13], [215, 4, 236, 13], [1, 17, 30, 27], [137, 17, 147, 27], [215, 17, 225, 27]],
[["", "F", "e", "a", "t", "u", "r", "e", ""], ["", "G", "b", "3", " ", "+", ""], ["", "G", "b", "3", " ", "-", ""], ["", "P", "a", "t", "i", "e", "n", "t", "s", ""], ["6", "2"], ["4", "5"]]
]}
```
@@ -95,7 +95,7 @@ In gt json, the key is the image name, the value is the corresponding gt, and gt
Use the following command to evaluate. After the evaluation is completed, the teds indicator will be output.
```python
cd PaddleOCR/ppstructure
-python3 table/eval_table.py --det_model_dir=path/to/det_model_dir --rec_model_dir=path/to/rec_model_dir --table_model_dir=path/to/table_model_dir --image_dir=../doc/table/1.png --rec_char_dict_path=../ppocr/utils/dict/table_dict.txt --table_char_dict_path=../ppocr/utils/dict/table_structure_dict.txt --rec_char_type=EN --det_limit_side_len=736 --det_limit_type=min --gt_path=path/to/gt.json
+python3 table/eval_table.py --det_model_dir=path/to/det_model_dir --rec_model_dir=path/to/rec_model_dir --table_model_dir=path/to/table_model_dir --image_dir=../doc/table/1.png --rec_char_dict_path=../ppocr/utils/dict/table_dict.txt --table_char_dict_path=../ppocr/utils/dict/table_structure_dict.txt --det_limit_side_len=736 --det_limit_type=min --gt_path=path/to/gt.json
```
If the PubLatNet eval dataset is used, it will be output
@@ -113,4 +113,4 @@ After running, the excel sheet of each picture will be saved in the directory sp
Reference
1. https://github.com/ibm-aur-nlp/PubTabNet
-2. https://arxiv.org/pdf/1911.10683
\ No newline at end of file
+2. https://arxiv.org/pdf/1911.10683
diff --git a/ppstructure/table/README_ch.md b/ppstructure/table/README_ch.md
index 2e90ad33423da347b5a51444f2be53ed2eb67a7a..33276b36e4973e83d7efa673b90013cf5727dfe2 100644
--- a/ppstructure/table/README_ch.md
+++ b/ppstructure/table/README_ch.md
@@ -34,9 +34,9 @@
|算法|[TEDS(Tree-Edit-Distance-based Similarity)](https://github.com/ibm-aur-nlp/PubTabNet/tree/master/src)|
-| --- | --- |
-| EDD[2] | 88.3 |
-| Ours | 93.32 |
+| --- | --- |
+| EDD[2] | 88.3 |
+| Ours | 93.32 |
## 3. 使用
@@ -56,7 +56,7 @@ wget https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_tab
wget https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_table_structure_infer.tar && tar xf en_ppocr_mobile_v2.0_table_structure_infer.tar
cd ..
# 执行预测
-python3 table/predict_table.py --det_model_dir=inference/en_ppocr_mobile_v2.0_table_det_infer --rec_model_dir=inference/en_ppocr_mobile_v2.0_table_rec_infer --table_model_dir=inference/en_ppocr_mobile_v2.0_table_structure_infer --image_dir=../doc/table/table.jpg --rec_char_dict_path=../ppocr/utils/dict/table_dict.txt --table_char_dict_path=../ppocr/utils/dict/table_structure_dict.txt --rec_char_type=EN --det_limit_side_len=736 --det_limit_type=min --output ../output/table
+python3 table/predict_table.py --det_model_dir=inference/en_ppocr_mobile_v2.0_table_det_infer --rec_model_dir=inference/en_ppocr_mobile_v2.0_table_rec_infer --table_model_dir=inference/en_ppocr_mobile_v2.0_table_structure_infer --image_dir=../doc/table/table.jpg --rec_char_dict_path=../ppocr/utils/dict/table_dict.txt --table_char_dict_path=../ppocr/utils/dict/table_structure_dict.txt --det_limit_side_len=736 --det_limit_type=min --output ../output/table
```
运行完成后,每张图片的excel表格会保存到output字段指定的目录下
@@ -94,8 +94,8 @@ python3 tools/train.py -c configs/table/table_mv3.yml -o Global.checkpoints=./yo
表格使用 [TEDS(Tree-Edit-Distance-based Similarity)](https://github.com/ibm-aur-nlp/PubTabNet/tree/master/src) 作为模型的评估指标。在进行模型评估之前,需要将pipeline中的三个模型分别导出为inference模型(我们已经提供好),还需要准备评估的gt, gt示例如下:
```json
{"PMC4289340_004_00.png": [
- ["", "", "", "", "", "", " | ", "", " | ", "", " | ", "
", "", "", "", "", " | ", "", " | ", "", " | ", "
", "", "
", "", ""],
- [[1, 4, 29, 13], [137, 4, 161, 13], [215, 4, 236, 13], [1, 17, 30, 27], [137, 17, 147, 27], [215, 17, 225, 27]],
+ ["", "", "", "", "", "", " | ", "", " | ", "", " | ", "
", "", "", "", "", " | ", "", " | ", "", " | ", "
", "", "
", "", ""],
+ [[1, 4, 29, 13], [137, 4, 161, 13], [215, 4, 236, 13], [1, 17, 30, 27], [137, 17, 147, 27], [215, 17, 225, 27]],
[["", "F", "e", "a", "t", "u", "r", "e", ""], ["", "G", "b", "3", " ", "+", ""], ["", "G", "b", "3", " ", "-", ""], ["", "P", "a", "t", "i", "e", "n", "t", "s", ""], ["6", "2"], ["4", "5"]]
]}
```
@@ -107,7 +107,7 @@ json 中,key为图片名,value为对应的gt,gt是一个由三个item组
准备完成后使用如下命令进行评估,评估完成后会输出teds指标。
```python
cd PaddleOCR/ppstructure
-python3 table/eval_table.py --det_model_dir=path/to/det_model_dir --rec_model_dir=path/to/rec_model_dir --table_model_dir=path/to/table_model_dir --image_dir=../doc/table/1.png --rec_char_dict_path=../ppocr/utils/dict/table_dict.txt --table_char_dict_path=../ppocr/utils/dict/table_structure_dict.txt --rec_char_type=EN --det_limit_side_len=736 --det_limit_type=min --gt_path=path/to/gt.json
+python3 table/eval_table.py --det_model_dir=path/to/det_model_dir --rec_model_dir=path/to/rec_model_dir --table_model_dir=path/to/table_model_dir --image_dir=../doc/table/1.png --rec_char_dict_path=../ppocr/utils/dict/table_dict.txt --table_char_dict_path=../ppocr/utils/dict/table_structure_dict.txt --det_limit_side_len=736 --det_limit_type=min --gt_path=path/to/gt.json
```
如使用PubLatNet评估数据集,将会输出
```bash
@@ -123,4 +123,4 @@ python3 table/predict_table.py --det_model_dir=path/to/det_model_dir --rec_model
Reference
1. https://github.com/ibm-aur-nlp/PubTabNet
-2. https://arxiv.org/pdf/1911.10683
\ No newline at end of file
+2. https://arxiv.org/pdf/1911.10683
diff --git a/ppstructure/utility.py b/ppstructure/utility.py
index 7d9fa76d0ada58e363243c114519d001de3fbf2a..ce7a801b1bb4094d3f4d2ba467332c6763ad6287 100644
--- a/ppstructure/utility.py
+++ b/ppstructure/utility.py
@@ -21,13 +21,31 @@ def init_args():
parser = infer_args()
# params for output
- parser.add_argument("--output", type=str, default='./output/table')
+ parser.add_argument("--output", type=str, default='./output')
# params for table structure
parser.add_argument("--table_max_len", type=int, default=488)
parser.add_argument("--table_model_dir", type=str)
parser.add_argument("--table_char_type", type=str, default='en')
- parser.add_argument("--table_char_dict_path", type=str, default="../ppocr/utils/dict/table_structure_dict.txt")
- parser.add_argument("--layout_path_model", type=str, default="lp://PubLayNet/ppyolov2_r50vd_dcn_365e_publaynet/config")
+ parser.add_argument(
+ "--table_char_dict_path",
+ type=str,
+ default="../ppocr/utils/dict/table_structure_dict.txt")
+ parser.add_argument(
+ "--layout_path_model",
+ type=str,
+ default="lp://PubLayNet/ppyolov2_r50vd_dcn_365e_publaynet/config")
+
+ # params for ser
+ parser.add_argument("--model_name_or_path", type=str)
+ parser.add_argument("--max_seq_length", type=int, default=512)
+ parser.add_argument(
+ "--label_map_path", type=str, default='./vqa/labels/labels_ser.txt')
+
+ parser.add_argument(
+ "--mode",
+ type=str,
+ default='structure',
+ help='structure and vqa is supported')
return parser
@@ -48,5 +66,6 @@ def draw_structure_result(image, result, font_path):
boxes.append(np.array(box).reshape(-1, 2))
txts.append(rec_res[0])
scores.append(rec_res[1])
- im_show = draw_ocr_box_txt(image, boxes, txts, scores, font_path=font_path,drop_score=0)
- return im_show
\ No newline at end of file
+ im_show = draw_ocr_box_txt(
+ image, boxes, txts, scores, font_path=font_path, drop_score=0)
+ return im_show
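The new SER/mode flags follow the standard argparse pattern; in isolation (defaults copied from the diff above):

```python
import argparse

parser = argparse.ArgumentParser()
# params for ser / pipeline mode, mirroring the additions to init_args()
parser.add_argument("--model_name_or_path", type=str)
parser.add_argument("--max_seq_length", type=int, default=512)
parser.add_argument("--label_map_path", type=str,
                    default="./vqa/labels/labels_ser.txt")
parser.add_argument("--mode", type=str, default="structure",
                    help="structure and vqa are supported")

args = parser.parse_args(["--mode", "vqa"])
print(args.mode)  # vqa
```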
diff --git a/ppstructure/vqa/README.md b/ppstructure/vqa/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..23fe28f8494ce84e774c3dd21811003f772c41f8
--- /dev/null
+++ b/ppstructure/vqa/README.md
@@ -0,0 +1,246 @@
+# 文档视觉问答(DOC-VQA)
+
+VQA指视觉问答,主要针对图像内容进行提问和回答,DOC-VQA是VQA任务中的一种,DOC-VQA主要针对文本图像的文字内容提出问题。
+
+PP-Structure 里的 DOC-VQA算法基于PaddleNLP自然语言处理算法库进行开发。
+
+主要特性如下:
+
+- 集成[LayoutXLM](https://arxiv.org/pdf/2104.08836.pdf)模型以及PP-OCR预测引擎。
+- 支持基于多模态方法的语义实体识别 (Semantic Entity Recognition, SER) 以及关系抽取 (Relation Extraction, RE) 任务。基于 SER 任务,可以完成对图像中的文本识别与分类;基于 RE 任务,可以完成对图像中的文本内容的关系提取,如判断问答对(pair)。
+- 支持SER任务和RE任务的自定义训练。
+- 支持OCR+SER的端到端系统预测与评估。
+- 支持OCR+SER+RE的端到端系统预测。
+
+
+本项目是 [LayoutXLM: Multimodal Pre-training for Multilingual Visually-rich Document Understanding](https://arxiv.org/pdf/2104.08836.pdf) 在 Paddle 2.2上的开源实现,
+包含了在 [XFUND数据集](https://github.com/doc-analysis/XFUND) 上的微调代码。
+
+## 1. 性能
+
+我们在 [XFUN](https://github.com/doc-analysis/XFUND) 评估数据集上对算法进行了评估,性能如下
+
+|任务| f1 | 模型下载地址|
+|:---:|:---:| :---:|
+|SER|0.9056| [链接](https://paddleocr.bj.bcebos.com/pplayout/PP-Layout_v1.0_ser_pretrained.tar)|
+|RE|0.7113| [链接](https://paddleocr.bj.bcebos.com/pplayout/PP-Layout_v1.0_re_pretrained.tar)|
+
+
+
+## 2. 效果演示
+
+**注意:** 测试图片来源于XFUN数据集。
+
+### 2.1 SER
+
+![](./images/result_ser/zh_val_0_ser.jpg) | ![](./images/result_ser/zh_val_42_ser.jpg)
+---|---
+
+图中不同颜色的框表示不同的类别,对于XFUN数据集,有`QUESTION`, `ANSWER`, `HEADER` 3种类别
+
+* 深紫色:HEADER
+* 浅紫色:QUESTION
+* 军绿色:ANSWER
+
+在OCR检测框的左上方也标出了对应的类别和OCR识别结果。
+
+
+### 2.2 RE
+
+![](./images/result_re/zh_val_21_re.jpg) | ![](./images/result_re/zh_val_40_re.jpg)
+---|---
+
+
+图中红色框表示问题,蓝色框表示答案,问题和答案之间使用绿色线连接。在OCR检测框的左上方也标出了对应的类别和OCR识别结果。
+
+
+## 3. 安装
+
+### 3.1 安装依赖
+
+- **(1) 安装PaddlePaddle**
+
+```bash
+pip3 install --upgrade pip
+
+# GPU安装
+python3 -m pip install paddlepaddle-gpu==2.2 -i https://mirror.baidu.com/pypi/simple
+
+# CPU安装
+python3 -m pip install paddlepaddle==2.2 -i https://mirror.baidu.com/pypi/simple
+
+```
+更多需求,请参照[安装文档](https://www.paddlepaddle.org.cn/install/quick)中的说明进行操作。
+
+
+### 3.2 安装PaddleOCR(包含 PP-OCR 和 VQA )
+
+- **(1)pip快速安装PaddleOCR whl包(仅预测)**
+
+```bash
+pip install paddleocr
+```
+
+- **(2)下载VQA源码(预测+训练)**
+
+```bash
+【推荐】git clone https://github.com/PaddlePaddle/PaddleOCR
+
+# 如果因为网络问题无法pull成功,也可选择使用码云上的托管:
+git clone https://gitee.com/paddlepaddle/PaddleOCR
+
+# 注:码云托管代码可能无法实时同步本github项目更新,存在3~5天延时,请优先使用推荐方式。
+```
+
+- **(3)安装PaddleNLP**
+
+```bash
+# 需要使用PaddleNLP最新的代码版本进行安装
+git clone https://github.com/PaddlePaddle/PaddleNLP -b develop
+cd PaddleNLP
+pip install -e .
+```
+
+
+- **(4)安装VQA的`requirements`**
+
+```bash
+cd ppstructure/vqa
+pip install -r requirements.txt
+```
+
+## 4. 使用
+
+
+### 4.1 数据和预训练模型准备
+
+处理好的XFUN中文数据集下载地址:[https://paddleocr.bj.bcebos.com/dataset/XFUND.tar](https://paddleocr.bj.bcebos.com/dataset/XFUND.tar)。
+
+
+下载并解压该数据集,解压后将数据集放置在当前目录下。
+
+```shell
+wget https://paddleocr.bj.bcebos.com/dataset/XFUND.tar
+tar -xf XFUND.tar
+```
+
+如果希望转换XFUN中其他语言的数据集,可以参考[XFUN数据转换脚本](helper/trans_xfun_data.py)。
+
+如果希望直接体验预测过程,可以下载我们提供的预训练模型,跳过训练过程,直接预测即可。
+
+
+### 4.2 SER任务
+
+* 启动训练
+
+```shell
+python3.7 train_ser.py \
+ --model_name_or_path "layoutxlm-base-uncased" \
+ --train_data_dir "XFUND/zh_train/image" \
+ --train_label_path "XFUND/zh_train/xfun_normalize_train.json" \
+ --eval_data_dir "XFUND/zh_val/image" \
+ --eval_label_path "XFUND/zh_val/xfun_normalize_val.json" \
+ --num_train_epochs 200 \
+ --eval_steps 10 \
+ --save_steps 500 \
+ --output_dir "./output/ser/" \
+ --learning_rate 5e-5 \
+ --warmup_steps 50 \
+ --evaluate_during_training \
+ --seed 2048
+```
+
+`precision`, `recall`, `f1` and other metrics will be printed at the end; the model and training logs are saved in the `./output/ser/` directory.
+
+* Run prediction using the OCR results provided with the evaluation set
+
+```shell
+export CUDA_VISIBLE_DEVICES=0
+python3.7 infer_ser.py \
+ --model_name_or_path "./PP-Layout_v1.0_ser_pretrained/" \
+ --output_dir "output_res/" \
+ --infer_imgs "XFUND/zh_val/image/" \
+ --ocr_json_path "XFUND/zh_val/xfun_normalize_val.json"
+```
+
+The visualized prediction images and a prediction text file named `infer_results.txt` are saved in the `output_res` directory.
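+
+Each line of `infer_results.txt` contains the image path, a tab, and a JSON string whose `ocr_info` entries carry the predicted category in the `pred` field. A minimal sketch for reading the results back:
+
+```python
+import json
+
+with open("output_res/infer_results.txt", "r") as fin:
+    for line in fin:
+        img_path, info = line.strip().split("\t")
+        for region in json.loads(info)["ocr_info"]:
+            # OCR text plus the predicted SER category (QUESTION / ANSWER / HEADER / O)
+            print(img_path, region["text"], region["pred"])
+```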
+
+* Run the cascaded `OCR engine + SER` system
+
+```shell
+export CUDA_VISIBLE_DEVICES=0
+python3.7 infer_ser_e2e.py \
+ --model_name_or_path "./output/PP-Layout_v1.0_ser_pretrained/" \
+ --max_seq_length 512 \
+ --output_dir "output_res_e2e/" \
+ --infer_imgs "images/input/zh_val_0.jpg"
+```
+
+* End-to-end evaluation of the `OCR engine + SER` prediction system
+
+```shell
+export CUDA_VISIBLE_DEVICES=0
+python3.7 helper/eval_with_label_end2end.py --gt_json_path XFUND/zh_val/xfun_normalize_val.json --pred_json_path output_res/infer_results.txt
+```
+
+
+### 4.3 RE task
+
+* Start training
+
+```shell
+python3 train_re.py \
+ --model_name_or_path "layoutxlm-base-uncased" \
+ --train_data_dir "XFUND/zh_train/image" \
+ --train_label_path "XFUND/zh_train/xfun_normalize_train.json" \
+ --eval_data_dir "XFUND/zh_val/image" \
+ --eval_label_path "XFUND/zh_val/xfun_normalize_val.json" \
+ --label_map_path 'labels/labels_ser.txt' \
+ --num_train_epochs 2 \
+ --eval_steps 10 \
+ --save_steps 500 \
+ --output_dir "output/re/" \
+ --learning_rate 5e-5 \
+ --warmup_steps 50 \
+ --per_gpu_train_batch_size 8 \
+ --per_gpu_eval_batch_size 8 \
+ --evaluate_during_training \
+ --seed 2048
+
+```
+
+`precision`, `recall`, `f1` and other metrics will be printed at the end; the model and training logs are saved in the `./output/re/` directory.
+
+* Run prediction using the OCR results provided with the evaluation set
+
+```shell
+export CUDA_VISIBLE_DEVICES=0
+python3 infer_re.py \
+ --model_name_or_path "./PP-Layout_v1.0_re_pretrained/" \
+ --max_seq_length 512 \
+ --eval_data_dir "XFUND/zh_val/image" \
+ --eval_label_path "XFUND/zh_val/xfun_normalize_val.json" \
+ --label_map_path 'labels/labels_ser.txt' \
+ --output_dir "output_res" \
+ --per_gpu_eval_batch_size 1 \
+ --seed 2048
+```
+
+The visualized prediction images and a prediction text file named `infer_results.txt` are saved in the `output_res` directory.
+
+* Run the cascaded `OCR engine + SER + RE` system
+
+```shell
+export CUDA_VISIBLE_DEVICES=0
+python3.7 infer_ser_re_e2e.py \
+ --model_name_or_path "./PP-Layout_v1.0_ser_pretrained/" \
+ --re_model_name_or_path "./PP-Layout_v1.0_re_pretrained/" \
+ --max_seq_length 512 \
+ --output_dir "output_ser_re_e2e_train/" \
+ --infer_imgs "images/input/zh_val_21.jpg"
+```
+
+## References
+
+- LayoutXLM: Multimodal Pre-training for Multilingual Visually-rich Document Understanding, https://arxiv.org/pdf/2104.08836.pdf
+- microsoft/unilm/layoutxlm, https://github.com/microsoft/unilm/tree/master/layoutxlm
+- XFUND dataset, https://github.com/doc-analysis/XFUND
diff --git a/ppstructure/vqa/data_collator.py b/ppstructure/vqa/data_collator.py
new file mode 100644
index 0000000000000000000000000000000000000000..a969935b487e3d22ea5c4a3527028aa2cfe1a797
--- /dev/null
+++ b/ppstructure/vqa/data_collator.py
@@ -0,0 +1,38 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import paddle
+import numbers
+import numpy as np
+
+
+class DataCollator:
+ """
+ data batch
+ """
+
+ def __call__(self, batch):
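+        # `batch` is a list of per-sample dicts: gather the values of each key,
+        # then stack numeric / ndarray / tensor fields into batched paddle tensors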
+ data_dict = {}
+ to_tensor_keys = []
+ for sample in batch:
+ for k, v in sample.items():
+ if k not in data_dict:
+ data_dict[k] = []
+ if isinstance(v, (np.ndarray, paddle.Tensor, numbers.Number)):
+ if k not in to_tensor_keys:
+ to_tensor_keys.append(k)
+ data_dict[k].append(v)
+ for k in to_tensor_keys:
+ data_dict[k] = paddle.to_tensor(data_dict[k])
+ return data_dict
diff --git a/ppstructure/vqa/helper/eval_with_label_end2end.py b/ppstructure/vqa/helper/eval_with_label_end2end.py
new file mode 100644
index 0000000000000000000000000000000000000000..c8dd3e0ad437e51e21ebc53daeec9fdf9aa76b63
--- /dev/null
+++ b/ppstructure/vqa/helper/eval_with_label_end2end.py
@@ -0,0 +1,262 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import re
+import sys
+# import Polygon
+import shapely
+from shapely.geometry import Polygon
+import numpy as np
+from collections import defaultdict
+import operator
+import editdistance
+import argparse
+import json
+import copy
+
+
+def parse_ser_results_fp(fp, fp_type="gt", ignore_background=True):
+ # img/zh_val_0.jpg {
+ # "height": 3508,
+ # "width": 2480,
+ # "ocr_info": [
+ # {"text": "Maribyrnong", "label": "other", "bbox": [1958, 144, 2184, 198]},
+ # {"text": "CITYCOUNCIL", "label": "other", "bbox": [2052, 183, 2171, 214]},
+ # ]
+ assert fp_type in ["gt", "pred"]
+ key = "label" if fp_type == "gt" else "pred"
+ res_dict = dict()
+ with open(fp, "r") as fin:
+ lines = fin.readlines()
+
+ for _, line in enumerate(lines):
+ img_path, info = line.strip().split("\t")
+ # get key
+ image_name = os.path.basename(img_path)
+ res_dict[image_name] = []
+ # get infos
+ json_info = json.loads(info)
+ for single_ocr_info in json_info["ocr_info"]:
+ label = single_ocr_info[key].upper()
+ if label in ["O", "OTHERS", "OTHER"]:
+ label = "O"
+ if ignore_background and label == "O":
+ continue
+ single_ocr_info["label"] = label
+ res_dict[image_name].append(copy.deepcopy(single_ocr_info))
+ return res_dict
+
+
+def polygon_from_str(polygon_points):
+ """
+ Create a shapely polygon object from gt or dt line.
+ """
+ polygon_points = np.array(polygon_points).reshape(4, 2)
+ polygon = Polygon(polygon_points).convex_hull
+ return polygon
+
+
+def polygon_iou(poly1, poly2):
+ """
+ Intersection over union between two shapely polygons.
+ """
+ if not poly1.intersects(
+ poly2): # this test is fast and can accelerate calculation
+ iou = 0
+ else:
+ try:
+ inter_area = poly1.intersection(poly2).area
+ union_area = poly1.area + poly2.area - inter_area
+ iou = float(inter_area) / union_area
+ except shapely.geos.TopologicalError:
+ # except Exception as e:
+ # print(e)
+            print('shapely.geos.TopologicalError occurred, iou set to 0')
+ iou = 0
+ return iou
+
+
+def ed(args, str1, str2):
+ if args.ignore_space:
+ str1 = str1.replace(" ", "")
+ str2 = str2.replace(" ", "")
+ if args.ignore_case:
+ str1 = str1.lower()
+ str2 = str2.lower()
+ return editdistance.eval(str1, str2)
+
+
+def convert_bbox_to_polygon(bbox):
+ """
+ bbox : [x1, y1, x2, y2]
+ output: [[x1, y1], [x2, y2], [x3, y3], [x4, y4]]
+ """
+ xmin, ymin, xmax, ymax = bbox
+ poly = [[xmin, ymin], [xmax, ymin], [xmax, ymax], [xmin, ymax]]
+ return poly
+
+
+def eval_e2e(args):
+ # gt
+ gt_results = parse_ser_results_fp(args.gt_json_path, "gt",
+ args.ignore_background)
+ # pred
+ dt_results = parse_ser_results_fp(args.pred_json_path, "pred",
+ args.ignore_background)
+ assert set(gt_results.keys()) == set(dt_results.keys())
+
+ iou_thresh = args.iou_thres
+ num_gt_chars = 0
+ gt_count = 0
+ dt_count = 0
+ hit = 0
+ ed_sum = 0
+
+ for img_name in gt_results:
+ gt_info = gt_results[img_name]
+ gt_count += len(gt_info)
+
+ dt_info = dt_results[img_name]
+ dt_count += len(dt_info)
+
+ dt_match = [False] * len(dt_info)
+ gt_match = [False] * len(gt_info)
+
+ all_ious = defaultdict(tuple)
+ # gt: {text, label, bbox or poly}
+ for index_gt, gt in enumerate(gt_info):
+ if "poly" not in gt:
+ gt["poly"] = convert_bbox_to_polygon(gt["bbox"])
+ gt_poly = polygon_from_str(gt["poly"])
+ for index_dt, dt in enumerate(dt_info):
+ if "poly" not in dt:
+ dt["poly"] = convert_bbox_to_polygon(dt["bbox"])
+ dt_poly = polygon_from_str(dt["poly"])
+ iou = polygon_iou(dt_poly, gt_poly)
+ if iou >= iou_thresh:
+ all_ious[(index_gt, index_dt)] = iou
+ sorted_ious = sorted(
+ all_ious.items(), key=operator.itemgetter(1), reverse=True)
+ sorted_gt_dt_pairs = [item[0] for item in sorted_ious]
+
+ # matched gt and dt
+ for gt_dt_pair in sorted_gt_dt_pairs:
+ index_gt, index_dt = gt_dt_pair
+ if gt_match[index_gt] == False and dt_match[index_dt] == False:
+ gt_match[index_gt] = True
+ dt_match[index_dt] = True
+ # ocr rec results
+ gt_text = gt_info[index_gt]["text"]
+ dt_text = dt_info[index_dt]["text"]
+
+ # ser results
+ gt_label = gt_info[index_gt]["label"]
+ dt_label = dt_info[index_dt]["pred"]
+
+                if True: # ignore_masks[index_gt] == '0':
+ ed_sum += ed(args, gt_text, dt_text)
+ num_gt_chars += len(gt_text)
+ if gt_text == dt_text:
+ if args.ignore_ser_prediction or gt_label == dt_label:
+ hit += 1
+
+        # unmatched dt
+        for tindex, dt_match_flag in enumerate(dt_match):
+            if dt_match_flag == False:
+                dt_text = dt_info[tindex]["text"]
+                gt_text = ""
+                ed_sum += ed(args, dt_text, gt_text)
+
+        # unmatched gt
+        for tindex, gt_match_flag in enumerate(gt_match):
+            if gt_match_flag == False:
+                dt_text = ""
+                gt_text = gt_info[tindex]["text"]
+                ed_sum += ed(args, gt_text, dt_text)
+                num_gt_chars += len(gt_text)
+
+ eps = 1e-9
+ print("config: ", args)
+ print('hit, dt_count, gt_count', hit, dt_count, gt_count)
+ precision = hit / (dt_count + eps)
+ recall = hit / (gt_count + eps)
+ fmeasure = 2.0 * precision * recall / (precision + recall + eps)
+ avg_edit_dist_img = ed_sum / len(gt_results)
+ avg_edit_dist_field = ed_sum / (gt_count + eps)
+ character_acc = 1 - ed_sum / (num_gt_chars + eps)
+
+ print('character_acc: %.2f' % (character_acc * 100) + "%")
+ print('avg_edit_dist_field: %.2f' % (avg_edit_dist_field))
+ print('avg_edit_dist_img: %.2f' % (avg_edit_dist_img))
+ print('precision: %.2f' % (precision * 100) + "%")
+ print('recall: %.2f' % (recall * 100) + "%")
+ print('fmeasure: %.2f' % (fmeasure * 100) + "%")
+
+ return
+
+
+def parse_args():
+ """
+ """
+
+ def str2bool(v):
+ return v.lower() in ("true", "t", "1")
+
+ parser = argparse.ArgumentParser()
+ ## Required parameters
+ parser.add_argument(
+ "--gt_json_path",
+ default=None,
+ type=str,
+ required=True, )
+ parser.add_argument(
+ "--pred_json_path",
+ default=None,
+ type=str,
+ required=True, )
+
+ parser.add_argument("--iou_thres", default=0.5, type=float)
+
+ parser.add_argument(
+ "--ignore_case",
+ default=False,
+ type=str2bool,
+ help="whether to do lower case for the strs")
+
+ parser.add_argument(
+ "--ignore_space",
+ default=True,
+ type=str2bool,
+ help="whether to ignore space")
+
+ parser.add_argument(
+ "--ignore_background",
+ default=True,
+ type=str2bool,
+ help="whether to ignore other label")
+
+ parser.add_argument(
+ "--ignore_ser_prediction",
+ default=False,
+ type=str2bool,
+ help="whether to ignore ocr pred results")
+
+ args = parser.parse_args()
+ return args
+
+
+if __name__ == '__main__':
+ args = parse_args()
+ eval_e2e(args)
diff --git a/ppstructure/vqa/helper/trans_xfun_data.py b/ppstructure/vqa/helper/trans_xfun_data.py
new file mode 100644
index 0000000000000000000000000000000000000000..b5ebd5dfbd8addda0701a7cfd2387133f7a8776b
--- /dev/null
+++ b/ppstructure/vqa/helper/trans_xfun_data.py
@@ -0,0 +1,52 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import json
+
+
+def transfer_xfun_data(json_path=None, output_file=None):
+ with open(json_path, "r") as fin:
+ lines = fin.readlines()
+
+ json_info = json.loads(lines[0])
+ documents = json_info["documents"]
+ label_info = {}
+ with open(output_file, "w") as fout:
+ for idx, document in enumerate(documents):
+ img_info = document["img"]
+ document = document["document"]
+ image_path = img_info["fname"]
+
+ label_info["height"] = img_info["height"]
+ label_info["width"] = img_info["width"]
+
+ label_info["ocr_info"] = []
+
+ for doc in document:
+ label_info["ocr_info"].append({
+ "text": doc["text"],
+ "label": doc["label"],
+ "bbox": doc["box"],
+ "id": doc["id"],
+ "linking": doc["linking"],
+ "words": doc["words"]
+ })
+
+ fout.write(image_path + "\t" + json.dumps(
+ label_info, ensure_ascii=False) + "\n")
+
+ print("===ok====")
+
+
+if __name__ == "__main__":
+    transfer_xfun_data("./xfun/zh.val.json", "./xfun_normalize_val.json")
diff --git a/ppstructure/vqa/images/input/zh_val_0.jpg b/ppstructure/vqa/images/input/zh_val_0.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..479b60bcd3a859b187ce5325dfc381c1b87ee27f
Binary files /dev/null and b/ppstructure/vqa/images/input/zh_val_0.jpg differ
diff --git a/ppstructure/vqa/images/input/zh_val_21.jpg b/ppstructure/vqa/images/input/zh_val_21.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..35b572d7dd6a6b42cf43a8a4b33567c0af527d30
Binary files /dev/null and b/ppstructure/vqa/images/input/zh_val_21.jpg differ
diff --git a/ppstructure/vqa/images/input/zh_val_40.jpg b/ppstructure/vqa/images/input/zh_val_40.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..2a858cc33d54831335c209146853b6c302c734f8
Binary files /dev/null and b/ppstructure/vqa/images/input/zh_val_40.jpg differ
diff --git a/ppstructure/vqa/images/input/zh_val_42.jpg b/ppstructure/vqa/images/input/zh_val_42.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..42151bdd94929ede9da1a63ce8d9339971094a46
Binary files /dev/null and b/ppstructure/vqa/images/input/zh_val_42.jpg differ
diff --git a/ppstructure/vqa/images/result_re/zh_val_21_re.jpg b/ppstructure/vqa/images/result_re/zh_val_21_re.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..7bf248dd0e69057c4775ff9c205317044e94ee65
Binary files /dev/null and b/ppstructure/vqa/images/result_re/zh_val_21_re.jpg differ
diff --git a/ppstructure/vqa/images/result_re/zh_val_40_re.jpg b/ppstructure/vqa/images/result_re/zh_val_40_re.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..242f9d6e80be39c595d98b57d59d48673ce62f20
Binary files /dev/null and b/ppstructure/vqa/images/result_re/zh_val_40_re.jpg differ
diff --git a/ppstructure/vqa/images/result_ser/zh_val_0_ser.jpg b/ppstructure/vqa/images/result_ser/zh_val_0_ser.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..4605c3a7f395e9868ba55cd31a99367694c78f5c
Binary files /dev/null and b/ppstructure/vqa/images/result_ser/zh_val_0_ser.jpg differ
diff --git a/ppstructure/vqa/images/result_ser/zh_val_42_ser.jpg b/ppstructure/vqa/images/result_ser/zh_val_42_ser.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..13bc7272e49a03115085d4a7420a7acfb92d3260
Binary files /dev/null and b/ppstructure/vqa/images/result_ser/zh_val_42_ser.jpg differ
diff --git a/ppstructure/vqa/infer_re.py b/ppstructure/vqa/infer_re.py
new file mode 100644
index 0000000000000000000000000000000000000000..ae2f52550294b072179c3bdba28c3572369e11a3
--- /dev/null
+++ b/ppstructure/vqa/infer_re.py
@@ -0,0 +1,162 @@
+import os
+import sys
+
+__dir__ = os.path.dirname(os.path.abspath(__file__))
+sys.path.append(__dir__)
+sys.path.append(os.path.abspath(os.path.join(__dir__, '../..')))
+
+import random
+
+import cv2
+import matplotlib.pyplot as plt
+import numpy as np
+import paddle
+
+from paddlenlp.transformers import LayoutXLMTokenizer, LayoutXLMModel, LayoutXLMForRelationExtraction
+
+from xfun import XFUNDataset
+from utils import parse_args, get_bio_label_maps, draw_re_results
+from data_collator import DataCollator
+
+from ppocr.utils.logging import get_logger
+
+
+def infer(args):
+ os.makedirs(args.output_dir, exist_ok=True)
+ logger = get_logger()
+ label2id_map, id2label_map = get_bio_label_maps(args.label_map_path)
+ pad_token_label_id = paddle.nn.CrossEntropyLoss().ignore_index
+
+ tokenizer = LayoutXLMTokenizer.from_pretrained(args.model_name_or_path)
+
+ model = LayoutXLMForRelationExtraction.from_pretrained(
+ args.model_name_or_path)
+
+ eval_dataset = XFUNDataset(
+ tokenizer,
+ data_dir=args.eval_data_dir,
+ label_path=args.eval_label_path,
+ label2id_map=label2id_map,
+ img_size=(224, 224),
+ max_seq_len=args.max_seq_length,
+ pad_token_label_id=pad_token_label_id,
+ contains_re=True,
+ add_special_ids=False,
+ return_attention_mask=True,
+ load_mode='all')
+
+ eval_dataloader = paddle.io.DataLoader(
+ eval_dataset,
+ batch_size=args.per_gpu_eval_batch_size,
+ num_workers=8,
+ shuffle=False,
+ collate_fn=DataCollator())
+
+    # load the ground-truth OCR info from the label file
+ ocr_info_list = load_ocr(args.eval_data_dir, args.eval_label_path)
+
+ for idx, batch in enumerate(eval_dataloader):
+ logger.info("[Infer] process: {}/{}".format(idx, len(eval_dataloader)))
+ with paddle.no_grad():
+ outputs = model(**batch)
+ pred_relations = outputs['pred_relations']
+
+ ocr_info = ocr_info_list[idx]
+ image_path = ocr_info['image_path']
+ ocr_info = ocr_info['ocr_info']
+
+        # decode the tokens of each entity and filter out ocr_info entries that are not needed
+ ocr_info = filter_bg_by_txt(ocr_info, batch, tokenizer)
+
+        # convert the predicted relations back to OCR info
+ result = []
+ used_tail_id = []
+ for relations in pred_relations:
+ for relation in relations:
+ if relation['tail_id'] in used_tail_id:
+ continue
+ if relation['head_id'] not in ocr_info or relation[
+ 'tail_id'] not in ocr_info:
+ continue
+ used_tail_id.append(relation['tail_id'])
+ ocr_info_head = ocr_info[relation['head_id']]
+ ocr_info_tail = ocr_info[relation['tail_id']]
+ result.append((ocr_info_head, ocr_info_tail))
+
+ img = cv2.imread(image_path)
+ img_show = draw_re_results(img, result)
+ save_path = os.path.join(args.output_dir, os.path.basename(image_path))
+ cv2.imwrite(save_path, img_show)
+
+
+def load_ocr(img_folder, json_path):
+ import json
+ d = []
+ with open(json_path, "r") as fin:
+ lines = fin.readlines()
+ for line in lines:
+ image_name, info_str = line.split("\t")
+ info_dict = json.loads(info_str)
+ info_dict['image_path'] = os.path.join(img_folder, image_name)
+ d.append(info_dict)
+ return d
+
+
+def filter_bg_by_txt(ocr_info, batch, tokenizer):
+ entities = batch['entities'][0]
+ input_ids = batch['input_ids'][0]
+
+ new_info_dict = {}
+ for i in range(len(entities['start'])):
+ entitie_head = entities['start'][i]
+ entitie_tail = entities['end'][i]
+ word_input_ids = input_ids[entitie_head:entitie_tail].numpy().tolist()
+ txt = tokenizer.convert_ids_to_tokens(word_input_ids)
+ txt = tokenizer.convert_tokens_to_string(txt)
+
+        for j, info in enumerate(ocr_info):
+            if info['text'] == txt:
+                new_info_dict[j] = info
+ return new_info_dict
+
+
+def post_process(pred_relations, ocr_info, img):
+ result = []
+ for relations in pred_relations:
+ for relation in relations:
+ ocr_info_head = ocr_info[relation['head_id']]
+ ocr_info_tail = ocr_info[relation['tail_id']]
+ result.append((ocr_info_head, ocr_info_tail))
+ return result
+
+
+def draw_re(result, image_path, output_folder):
+ img = cv2.imread(image_path)
+
+ from matplotlib import pyplot as plt
+ for ocr_info_head, ocr_info_tail in result:
+ cv2.rectangle(
+ img,
+ tuple(ocr_info_head['bbox'][:2]),
+ tuple(ocr_info_head['bbox'][2:]), (255, 0, 0),
+ thickness=2)
+ cv2.rectangle(
+ img,
+ tuple(ocr_info_tail['bbox'][:2]),
+ tuple(ocr_info_tail['bbox'][2:]), (0, 0, 255),
+ thickness=2)
+ center_p1 = [(ocr_info_head['bbox'][0] + ocr_info_head['bbox'][2]) // 2,
+ (ocr_info_head['bbox'][1] + ocr_info_head['bbox'][3]) // 2]
+ center_p2 = [(ocr_info_tail['bbox'][0] + ocr_info_tail['bbox'][2]) // 2,
+ (ocr_info_tail['bbox'][1] + ocr_info_tail['bbox'][3]) // 2]
+ cv2.line(
+ img, tuple(center_p1), tuple(center_p2), (0, 255, 0), thickness=2)
+ plt.imshow(img)
+ plt.savefig(
+ os.path.join(output_folder, os.path.basename(image_path)), dpi=600)
+ # plt.show()
+
+
+if __name__ == "__main__":
+ args = parse_args()
+ infer(args)
diff --git a/ppstructure/vqa/infer_ser.py b/ppstructure/vqa/infer_ser.py
new file mode 100644
index 0000000000000000000000000000000000000000..4ad220094a26b330555fbe9122a46fb56e64fe1e
--- /dev/null
+++ b/ppstructure/vqa/infer_ser.py
@@ -0,0 +1,279 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import sys
+import json
+import cv2
+import numpy as np
+from copy import deepcopy
+
+import paddle
+
+# relative reference
+from utils import parse_args, get_image_file_list, draw_ser_results, get_bio_label_maps
+from paddlenlp.transformers import LayoutXLMModel, LayoutXLMTokenizer, LayoutXLMForTokenClassification
+
+
+def pad_sentences(tokenizer,
+ encoded_inputs,
+ max_seq_len=512,
+ pad_to_max_seq_len=True,
+ return_attention_mask=True,
+ return_token_type_ids=True,
+ return_overflowing_tokens=False,
+ return_special_tokens_mask=False):
+    # pad the sequence length up to the next multiple of max_seq_len so that
+    # it can later be reshaped into pages of max_seq_len tokens
+ max_seq_len = (
+ len(encoded_inputs["input_ids"]) // max_seq_len + 1) * max_seq_len
+
+ needs_to_be_padded = pad_to_max_seq_len and \
+ max_seq_len and len(encoded_inputs["input_ids"]) < max_seq_len
+
+ if needs_to_be_padded:
+ difference = max_seq_len - len(encoded_inputs["input_ids"])
+ if tokenizer.padding_side == 'right':
+ if return_attention_mask:
+ encoded_inputs["attention_mask"] = [1] * len(encoded_inputs[
+ "input_ids"]) + [0] * difference
+ if return_token_type_ids:
+ encoded_inputs["token_type_ids"] = (
+ encoded_inputs["token_type_ids"] +
+ [tokenizer.pad_token_type_id] * difference)
+ if return_special_tokens_mask:
+ encoded_inputs["special_tokens_mask"] = encoded_inputs[
+ "special_tokens_mask"] + [1] * difference
+ encoded_inputs["input_ids"] = encoded_inputs[
+ "input_ids"] + [tokenizer.pad_token_id] * difference
+ encoded_inputs["bbox"] = encoded_inputs["bbox"] + [[0, 0, 0, 0]
+ ] * difference
+ else:
+ assert False, f"padding_side of tokenizer just supports [\"right\"] but got {tokenizer.padding_side}"
+ else:
+ if return_attention_mask:
+ encoded_inputs["attention_mask"] = [1] * len(encoded_inputs[
+ "input_ids"])
+
+ return encoded_inputs
+
+
+def split_page(encoded_inputs, max_seq_len=512):
+ """
+ truncate is often used in training process
+ """
+ for key in encoded_inputs:
+ encoded_inputs[key] = paddle.to_tensor(encoded_inputs[key])
+ if encoded_inputs[key].ndim <= 1: # for input_ids, att_mask and so on
+ encoded_inputs[key] = encoded_inputs[key].reshape([-1, max_seq_len])
+ else: # for bbox
+ encoded_inputs[key] = encoded_inputs[key].reshape(
+ [-1, max_seq_len, 4])
+ return encoded_inputs
+
+
+def preprocess(
+ tokenizer,
+ ori_img,
+ ocr_info,
+ img_size=(224, 224),
+ pad_token_label_id=-100,
+ max_seq_len=512,
+ add_special_ids=False,
+ return_attention_mask=True, ):
+ ocr_info = deepcopy(ocr_info)
+ height = ori_img.shape[0]
+ width = ori_img.shape[1]
+
+ img = cv2.resize(ori_img,
+ (224, 224)).transpose([2, 0, 1]).astype(np.float32)
+
+ segment_offset_id = []
+ words_list = []
+ bbox_list = []
+ input_ids_list = []
+ token_type_ids_list = []
+
+ for info in ocr_info:
+ # x1, y1, x2, y2
+ bbox = info["bbox"]
+ bbox[0] = int(bbox[0] * 1000.0 / width)
+ bbox[2] = int(bbox[2] * 1000.0 / width)
+ bbox[1] = int(bbox[1] * 1000.0 / height)
+ bbox[3] = int(bbox[3] * 1000.0 / height)
+
+ text = info["text"]
+ encode_res = tokenizer.encode(
+ text, pad_to_max_seq_len=False, return_attention_mask=True)
+
+ if not add_special_ids:
+ # TODO: use tok.all_special_ids to remove
+ encode_res["input_ids"] = encode_res["input_ids"][1:-1]
+ encode_res["token_type_ids"] = encode_res["token_type_ids"][1:-1]
+ encode_res["attention_mask"] = encode_res["attention_mask"][1:-1]
+
+ input_ids_list.extend(encode_res["input_ids"])
+ token_type_ids_list.extend(encode_res["token_type_ids"])
+ bbox_list.extend([bbox] * len(encode_res["input_ids"]))
+ words_list.append(text)
+ segment_offset_id.append(len(input_ids_list))
+
+ encoded_inputs = {
+ "input_ids": input_ids_list,
+ "token_type_ids": token_type_ids_list,
+ "bbox": bbox_list,
+ "attention_mask": [1] * len(input_ids_list),
+ }
+
+ encoded_inputs = pad_sentences(
+ tokenizer,
+ encoded_inputs,
+ max_seq_len=max_seq_len,
+ return_attention_mask=return_attention_mask)
+
+ encoded_inputs = split_page(encoded_inputs)
+
+ fake_bs = encoded_inputs["input_ids"].shape[0]
+
+ encoded_inputs["image"] = paddle.to_tensor(img).unsqueeze(0).expand(
+ [fake_bs] + list(img.shape))
+
+ encoded_inputs["segment_offset_id"] = segment_offset_id
+
+ return encoded_inputs
+
+
+def postprocess(attention_mask, preds, label_map_path):
+ if isinstance(preds, paddle.Tensor):
+ preds = preds.numpy()
+ preds = np.argmax(preds, axis=2)
+
+ _, label_map = get_bio_label_maps(label_map_path)
+
+ preds_list = [[] for _ in range(preds.shape[0])]
+
+ # keep batch info
+ for i in range(preds.shape[0]):
+ for j in range(preds.shape[1]):
+ if attention_mask[i][j] == 1:
+ preds_list[i].append(label_map[preds[i][j]])
+
+ return preds_list
+
+
+def merge_preds_list_with_ocr_info(label_map_path, ocr_info, segment_offset_id,
+ preds_list):
+ # must ensure the preds_list is generated from the same image
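+    # token-level predictions are grouped per OCR segment via segment_offset_id;
+    # the most frequent label id inside a segment becomes that segment's SER label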
+ preds = [p for pred in preds_list for p in pred]
+ label2id_map, _ = get_bio_label_maps(label_map_path)
+ for key in label2id_map:
+ if key.startswith("I-"):
+ label2id_map[key] = label2id_map["B" + key[1:]]
+
+ id2label_map = dict()
+ for key in label2id_map:
+ val = label2id_map[key]
+ if key == "O":
+ id2label_map[val] = key
+ if key.startswith("B-") or key.startswith("I-"):
+ id2label_map[val] = key[2:]
+ else:
+ id2label_map[val] = key
+
+ for idx in range(len(segment_offset_id)):
+ if idx == 0:
+ start_id = 0
+ else:
+ start_id = segment_offset_id[idx - 1]
+
+ end_id = segment_offset_id[idx]
+
+ curr_pred = preds[start_id:end_id]
+ curr_pred = [label2id_map[p] for p in curr_pred]
+
+ if len(curr_pred) <= 0:
+ pred_id = 0
+ else:
+ counts = np.bincount(curr_pred)
+ pred_id = np.argmax(counts)
+ ocr_info[idx]["pred_id"] = int(pred_id)
+ ocr_info[idx]["pred"] = id2label_map[pred_id]
+ return ocr_info
+
+
+@paddle.no_grad()
+def infer(args):
+ os.makedirs(args.output_dir, exist_ok=True)
+
+ # init token and model
+ tokenizer = LayoutXLMTokenizer.from_pretrained(args.model_name_or_path)
+ # model = LayoutXLMModel.from_pretrained(args.model_name_or_path)
+ model = LayoutXLMForTokenClassification.from_pretrained(
+ args.model_name_or_path)
+ model.eval()
+
+ # load ocr results json
+ ocr_results = dict()
+ with open(args.ocr_json_path, "r") as fin:
+ lines = fin.readlines()
+ for line in lines:
+ img_name, json_info = line.split("\t")
+ ocr_results[os.path.basename(img_name)] = json.loads(json_info)
+
+ # get infer img list
+ infer_imgs = get_image_file_list(args.infer_imgs)
+
+ # loop for infer
+ with open(os.path.join(args.output_dir, "infer_results.txt"), "w") as fout:
+ for idx, img_path in enumerate(infer_imgs):
+ print("process: [{}/{}]".format(idx, len(infer_imgs), img_path))
+
+ img = cv2.imread(img_path)
+
+ ocr_info = ocr_results[os.path.basename(img_path)]["ocr_info"]
+ inputs = preprocess(
+ tokenizer=tokenizer,
+ ori_img=img,
+ ocr_info=ocr_info,
+ max_seq_len=args.max_seq_length)
+
+ outputs = model(
+ input_ids=inputs["input_ids"],
+ bbox=inputs["bbox"],
+ image=inputs["image"],
+ token_type_ids=inputs["token_type_ids"],
+ attention_mask=inputs["attention_mask"])
+
+ preds = outputs[0]
+ preds = postprocess(inputs["attention_mask"], preds,
+ args.label_map_path)
+ ocr_info = merge_preds_list_with_ocr_info(
+ args.label_map_path, ocr_info, inputs["segment_offset_id"],
+ preds)
+
+ fout.write(img_path + "\t" + json.dumps(
+ {
+ "ocr_info": ocr_info,
+ }, ensure_ascii=False) + "\n")
+
+ img_res = draw_ser_results(img, ocr_info)
+ cv2.imwrite(
+ os.path.join(args.output_dir, os.path.basename(img_path)),
+ img_res)
+
+ return
+
+
+if __name__ == "__main__":
+ args = parse_args()
+ infer(args)
diff --git a/ppstructure/vqa/infer_ser_e2e.py b/ppstructure/vqa/infer_ser_e2e.py
new file mode 100644
index 0000000000000000000000000000000000000000..3ebb350fd9ce90fa5a5688c34f041f67105fcf86
--- /dev/null
+++ b/ppstructure/vqa/infer_ser_e2e.py
@@ -0,0 +1,132 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import sys
+import json
+import cv2
+import numpy as np
+from copy import deepcopy
+from PIL import Image
+
+import paddle
+from paddlenlp.transformers import LayoutXLMModel, LayoutXLMTokenizer, LayoutXLMForTokenClassification
+
+# local imports (this script is run directly, so relative imports would fail)
+from utils import parse_args, get_image_file_list, draw_ser_results, get_bio_label_maps
+from utils import pad_sentences, split_page, preprocess, postprocess, merge_preds_list_with_ocr_info
+
+
+def trans_poly_to_bbox(poly):
+ x1 = np.min([p[0] for p in poly])
+ x2 = np.max([p[0] for p in poly])
+ y1 = np.min([p[1] for p in poly])
+ y2 = np.max([p[1] for p in poly])
+ return [x1, y1, x2, y2]
+
+
+def parse_ocr_info_for_ser(ocr_result):
+ ocr_info = []
+ for res in ocr_result:
+ ocr_info.append({
+ "text": res[1][0],
+ "bbox": trans_poly_to_bbox(res[0]),
+ "poly": res[0],
+ })
+ return ocr_info
+
+
+class SerPredictor(object):
+ def __init__(self, args):
+
+ self.max_seq_length = args.max_seq_length
+
+ # init ser token and model
+ self.tokenizer = LayoutXLMTokenizer.from_pretrained(
+ args.model_name_or_path)
+ self.model = LayoutXLMForTokenClassification.from_pretrained(
+ args.model_name_or_path)
+ self.model.eval()
+
+ # init ocr_engine
+ from paddleocr import PaddleOCR
+
+ self.ocr_engine = PaddleOCR(
+ rec_model_dir=args.rec_model_dir,
+ det_model_dir=args.det_model_dir,
+ use_angle_cls=False,
+ show_log=False)
+ # init dict
+ label2id_map, self.id2label_map = get_bio_label_maps(
+ args.label_map_path)
+ self.label2id_map_for_draw = dict()
+ for key in label2id_map:
+ if key.startswith("I-"):
+ self.label2id_map_for_draw[key] = label2id_map["B" + key[1:]]
+ else:
+ self.label2id_map_for_draw[key] = label2id_map[key]
+
+ def __call__(self, img):
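+        # pipeline: PP-OCR detection/recognition -> preprocess into LayoutXLM inputs
+        # -> token classification -> merge token-level predictions back onto OCR regions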
+ ocr_result = self.ocr_engine.ocr(img, cls=False)
+
+ ocr_info = parse_ocr_info_for_ser(ocr_result)
+
+ inputs = preprocess(
+ tokenizer=self.tokenizer,
+ ori_img=img,
+ ocr_info=ocr_info,
+ max_seq_len=self.max_seq_length)
+
+ outputs = self.model(
+ input_ids=inputs["input_ids"],
+ bbox=inputs["bbox"],
+ image=inputs["image"],
+ token_type_ids=inputs["token_type_ids"],
+ attention_mask=inputs["attention_mask"])
+
+ preds = outputs[0]
+ preds = postprocess(inputs["attention_mask"], preds, self.id2label_map)
+ ocr_info = merge_preds_list_with_ocr_info(
+ ocr_info, inputs["segment_offset_id"], preds,
+ self.label2id_map_for_draw)
+ return ocr_info, inputs
+
+
+if __name__ == "__main__":
+ args = parse_args()
+ os.makedirs(args.output_dir, exist_ok=True)
+
+ # get infer img list
+ infer_imgs = get_image_file_list(args.infer_imgs)
+
+ # loop for infer
+ ser_engine = SerPredictor(args)
+ with open(os.path.join(args.output_dir, "infer_results.txt"), "w") as fout:
+ for idx, img_path in enumerate(infer_imgs):
+ print("process: [{}/{}], {}".format(idx, len(infer_imgs), img_path))
+
+ img = cv2.imread(img_path)
+
+ result, _ = ser_engine(img)
+ fout.write(img_path + "\t" + json.dumps(
+ {
+ "ser_resule": result,
+ }, ensure_ascii=False) + "\n")
+
+ img_res = draw_ser_results(img, result)
+ cv2.imwrite(
+ os.path.join(args.output_dir,
+ os.path.splitext(os.path.basename(img_path))[0] +
+ "_ser.jpg"), img_res)
diff --git a/ppstructure/vqa/infer_ser_re_e2e.py b/ppstructure/vqa/infer_ser_re_e2e.py
new file mode 100644
index 0000000000000000000000000000000000000000..a1d0f52eeecbc6c2ceba5964355008f638f371dd
--- /dev/null
+++ b/ppstructure/vqa/infer_ser_re_e2e.py
@@ -0,0 +1,131 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import sys
+import json
+import cv2
+import numpy as np
+from copy import deepcopy
+from PIL import Image
+
+import paddle
+from paddlenlp.transformers import LayoutXLMModel, LayoutXLMTokenizer, LayoutXLMForRelationExtraction
+
+# relative reference
+from utils import parse_args, get_image_file_list, draw_re_results
+from infer_ser_e2e import SerPredictor
+
+
+def make_input(ser_input, ser_result, max_seq_len=512):
+ entities_labels = {'HEADER': 0, 'QUESTION': 1, 'ANSWER': 2}
+
+ entities = ser_input['entities'][0]
+ assert len(entities) == len(ser_result)
+
+ # entities
+ start = []
+ end = []
+ label = []
+ entity_idx_dict = {}
+ for i, (res, entity) in enumerate(zip(ser_result, entities)):
+ if res['pred'] == 'O':
+ continue
+ entity_idx_dict[len(start)] = i
+ start.append(entity['start'])
+ end.append(entity['end'])
+ label.append(entities_labels[res['pred']])
+ entities = dict(start=start, end=end, label=label)
+
+ # relations
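+    # candidate relations: every QUESTION entity (label 1) is paired with
+    # every ANSWER entity (label 2) before being scored by the RE model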
+ head = []
+ tail = []
+ for i in range(len(entities["label"])):
+ for j in range(len(entities["label"])):
+ if entities["label"][i] == 1 and entities["label"][j] == 2:
+ head.append(i)
+ tail.append(j)
+
+ relations = dict(head=head, tail=tail)
+
+ batch_size = ser_input["input_ids"].shape[0]
+ entities_batch = []
+ relations_batch = []
+ for b in range(batch_size):
+ entities_batch.append(entities)
+ relations_batch.append(relations)
+
+ ser_input['entities'] = entities_batch
+ ser_input['relations'] = relations_batch
+
+ ser_input.pop('segment_offset_id')
+ return ser_input, entity_idx_dict
+
+
+class SerReSystem(object):
+ def __init__(self, args):
+ self.ser_engine = SerPredictor(args)
+ self.tokenizer = LayoutXLMTokenizer.from_pretrained(
+ args.re_model_name_or_path)
+ self.model = LayoutXLMForRelationExtraction.from_pretrained(
+ args.re_model_name_or_path)
+ self.model.eval()
+
+ def __call__(self, img):
+ ser_result, ser_inputs = self.ser_engine(img)
+ re_input, entity_idx_dict = make_input(ser_inputs, ser_result)
+
+ re_result = self.model(**re_input)
+
+ pred_relations = re_result['pred_relations'][0]
+        # convert the predicted relations back to OCR info
+ result = []
+ used_tail_id = []
+ for relation in pred_relations:
+ if relation['tail_id'] in used_tail_id:
+ continue
+ used_tail_id.append(relation['tail_id'])
+ ocr_info_head = ser_result[entity_idx_dict[relation['head_id']]]
+ ocr_info_tail = ser_result[entity_idx_dict[relation['tail_id']]]
+ result.append((ocr_info_head, ocr_info_tail))
+
+ return result
+
+
+if __name__ == "__main__":
+ args = parse_args()
+ os.makedirs(args.output_dir, exist_ok=True)
+
+ # get infer img list
+ infer_imgs = get_image_file_list(args.infer_imgs)
+
+ # loop for infer
+ ser_re_engine = SerReSystem(args)
+ with open(os.path.join(args.output_dir, "infer_results.txt"), "w") as fout:
+ for idx, img_path in enumerate(infer_imgs):
+ print("process: [{}/{}], {}".format(idx, len(infer_imgs), img_path))
+
+ img = cv2.imread(img_path)
+
+ result = ser_re_engine(img)
+ fout.write(img_path + "\t" + json.dumps(
+ {
+ "result": result,
+ }, ensure_ascii=False) + "\n")
+
+ img_res = draw_re_results(img, result)
+ cv2.imwrite(
+ os.path.join(args.output_dir,
+ os.path.splitext(os.path.basename(img_path))[0] +
+ "_re.jpg"), img_res)
diff --git a/ppstructure/vqa/labels/labels_ser.txt b/ppstructure/vqa/labels/labels_ser.txt
new file mode 100644
index 0000000000000000000000000000000000000000..508e48112412f62538baf0c78bcf99ec8945196e
--- /dev/null
+++ b/ppstructure/vqa/labels/labels_ser.txt
@@ -0,0 +1,3 @@
+QUESTION
+ANSWER
+HEADER
diff --git a/ppstructure/vqa/metric.py b/ppstructure/vqa/metric.py
new file mode 100644
index 0000000000000000000000000000000000000000..cb58370521296886670486982caf1202cf99a489
--- /dev/null
+++ b/ppstructure/vqa/metric.py
@@ -0,0 +1,175 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import re
+
+import numpy as np
+
+import logging
+
+logger = logging.getLogger(__name__)
+
+PREFIX_CHECKPOINT_DIR = "checkpoint"
+_re_checkpoint = re.compile(r"^" + PREFIX_CHECKPOINT_DIR + r"\-(\d+)$")
+
+
+def get_last_checkpoint(folder):
+ content = os.listdir(folder)
+ checkpoints = [
+ path for path in content
+ if _re_checkpoint.search(path) is not None and os.path.isdir(
+ os.path.join(folder, path))
+ ]
+ if len(checkpoints) == 0:
+ return
+ return os.path.join(
+ folder,
+ max(checkpoints,
+ key=lambda x: int(_re_checkpoint.search(x).groups()[0])))
+
+
+def re_score(pred_relations, gt_relations, mode="strict"):
+ """Evaluate RE predictions
+
+ Args:
+ pred_relations (list) : list of list of predicted relations (several relations in each sentence)
+ gt_relations (list) : list of list of ground truth relations
+
+ rel = { "head": (start_idx (inclusive), end_idx (exclusive)),
+ "tail": (start_idx (inclusive), end_idx (exclusive)),
+ "head_type": ent_type,
+ "tail_type": ent_type,
+ "type": rel_type}
+
+ vocab (Vocab) : dataset vocabulary
+ mode (str) : in 'strict' or 'boundaries'"""
+
+ assert mode in ["strict", "boundaries"]
+
+ relation_types = [v for v in [0, 1] if not v == 0]
+ scores = {
+ rel: {
+ "tp": 0,
+ "fp": 0,
+ "fn": 0
+ }
+ for rel in relation_types + ["ALL"]
+ }
+
+ # Count GT relations and Predicted relations
+ n_sents = len(gt_relations)
+ n_rels = sum([len([rel for rel in sent]) for sent in gt_relations])
+ n_found = sum([len([rel for rel in sent]) for sent in pred_relations])
+
+ # Count TP, FP and FN per type
+ for pred_sent, gt_sent in zip(pred_relations, gt_relations):
+ for rel_type in relation_types:
+ # strict mode takes argument types into account
+ if mode == "strict":
+ pred_rels = {(rel["head"], rel["head_type"], rel["tail"],
+ rel["tail_type"])
+ for rel in pred_sent if rel["type"] == rel_type}
+ gt_rels = {(rel["head"], rel["head_type"], rel["tail"],
+ rel["tail_type"])
+ for rel in gt_sent if rel["type"] == rel_type}
+
+ # boundaries mode only takes argument spans into account
+ elif mode == "boundaries":
+ pred_rels = {(rel["head"], rel["tail"])
+ for rel in pred_sent if rel["type"] == rel_type}
+ gt_rels = {(rel["head"], rel["tail"])
+ for rel in gt_sent if rel["type"] == rel_type}
+
+ scores[rel_type]["tp"] += len(pred_rels & gt_rels)
+ scores[rel_type]["fp"] += len(pred_rels - gt_rels)
+ scores[rel_type]["fn"] += len(gt_rels - pred_rels)
+
+ # Compute per entity Precision / Recall / F1
+ for rel_type in scores.keys():
+ if scores[rel_type]["tp"]:
+ scores[rel_type]["p"] = scores[rel_type]["tp"] / (
+ scores[rel_type]["fp"] + scores[rel_type]["tp"])
+ scores[rel_type]["r"] = scores[rel_type]["tp"] / (
+ scores[rel_type]["fn"] + scores[rel_type]["tp"])
+ else:
+ scores[rel_type]["p"], scores[rel_type]["r"] = 0, 0
+
+ if not scores[rel_type]["p"] + scores[rel_type]["r"] == 0:
+ scores[rel_type]["f1"] = (
+ 2 * scores[rel_type]["p"] * scores[rel_type]["r"] /
+ (scores[rel_type]["p"] + scores[rel_type]["r"]))
+ else:
+ scores[rel_type]["f1"] = 0
+
+ # Compute micro F1 Scores
+ tp = sum([scores[rel_type]["tp"] for rel_type in relation_types])
+ fp = sum([scores[rel_type]["fp"] for rel_type in relation_types])
+ fn = sum([scores[rel_type]["fn"] for rel_type in relation_types])
+
+ if tp:
+ precision = tp / (tp + fp)
+ recall = tp / (tp + fn)
+ f1 = 2 * precision * recall / (precision + recall)
+
+ else:
+ precision, recall, f1 = 0, 0, 0
+
+ scores["ALL"]["p"] = precision
+ scores["ALL"]["r"] = recall
+ scores["ALL"]["f1"] = f1
+ scores["ALL"]["tp"] = tp
+ scores["ALL"]["fp"] = fp
+ scores["ALL"]["fn"] = fn
+
+ # Compute Macro F1 Scores
+ scores["ALL"]["Macro_f1"] = np.mean(
+ [scores[ent_type]["f1"] for ent_type in relation_types])
+ scores["ALL"]["Macro_p"] = np.mean(
+ [scores[ent_type]["p"] for ent_type in relation_types])
+ scores["ALL"]["Macro_r"] = np.mean(
+ [scores[ent_type]["r"] for ent_type in relation_types])
+
+ # logger.info(f"RE Evaluation in *** {mode.upper()} *** mode")
+
+ # logger.info(
+ # "processed {} sentences with {} relations; found: {} relations; correct: {}.".format(
+ # n_sents, n_rels, n_found, tp
+ # )
+ # )
+ # logger.info(
+ # "\tALL\t TP: {};\tFP: {};\tFN: {}".format(scores["ALL"]["tp"], scores["ALL"]["fp"], scores["ALL"]["fn"])
+ # )
+ # logger.info("\t\t(m avg): precision: {:.2f};\trecall: {:.2f};\tf1: {:.2f} (micro)".format(precision, recall, f1))
+ # logger.info(
+ # "\t\t(M avg): precision: {:.2f};\trecall: {:.2f};\tf1: {:.2f} (Macro)\n".format(
+ # scores["ALL"]["Macro_p"], scores["ALL"]["Macro_r"], scores["ALL"]["Macro_f1"]
+ # )
+ # )
+
+ # for rel_type in relation_types:
+ # logger.info(
+ # "\t{}: \tTP: {};\tFP: {};\tFN: {};\tprecision: {:.2f};\trecall: {:.2f};\tf1: {:.2f};\t{}".format(
+ # rel_type,
+ # scores[rel_type]["tp"],
+ # scores[rel_type]["fp"],
+ # scores[rel_type]["fn"],
+ # scores[rel_type]["p"],
+ # scores[rel_type]["r"],
+ # scores[rel_type]["f1"],
+ # scores[rel_type]["tp"] + scores[rel_type]["fp"],
+ # )
+ # )
+
+ return scores
diff --git a/ppstructure/vqa/requirements.txt b/ppstructure/vqa/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..c259fadc395335b336cb0ecdb5aa6bca48631987
--- /dev/null
+++ b/ppstructure/vqa/requirements.txt
@@ -0,0 +1,2 @@
+sentencepiece
+yacs
diff --git a/ppstructure/vqa/train_re.py b/ppstructure/vqa/train_re.py
new file mode 100644
index 0000000000000000000000000000000000000000..ed19646cf57e69ac99e417ae27568655a4e00039
--- /dev/null
+++ b/ppstructure/vqa/train_re.py
@@ -0,0 +1,261 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import sys
+
+__dir__ = os.path.dirname(os.path.abspath(__file__))
+sys.path.append(__dir__)
+sys.path.append(os.path.abspath(os.path.join(__dir__, '../..')))
+
+import random
+import numpy as np
+import paddle
+
+from paddlenlp.transformers import LayoutXLMTokenizer, LayoutXLMModel, LayoutXLMForRelationExtraction
+
+from xfun import XFUNDataset
+from utils import parse_args, get_bio_label_maps, print_arguments
+from data_collator import DataCollator
+from metric import re_score
+
+from ppocr.utils.logging import get_logger
+
+
+def set_seed(seed):
+ random.seed(seed)
+ np.random.seed(seed)
+ paddle.seed(seed)
+
+
+def cal_metric(re_preds, re_labels, entities):
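+    # rebuild the ground-truth relations (head/tail spans and entity types) from the
+    # batch labels so that they can be compared with pred_relations by re_score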
+ gt_relations = []
+ for b in range(len(re_labels)):
+ rel_sent = []
+ for head, tail in zip(re_labels[b]["head"], re_labels[b]["tail"]):
+ rel = {}
+ rel["head_id"] = head
+ rel["head"] = (entities[b]["start"][rel["head_id"]],
+ entities[b]["end"][rel["head_id"]])
+ rel["head_type"] = entities[b]["label"][rel["head_id"]]
+
+ rel["tail_id"] = tail
+ rel["tail"] = (entities[b]["start"][rel["tail_id"]],
+ entities[b]["end"][rel["tail_id"]])
+ rel["tail_type"] = entities[b]["label"][rel["tail_id"]]
+
+ rel["type"] = 1
+ rel_sent.append(rel)
+ gt_relations.append(rel_sent)
+ re_metrics = re_score(re_preds, gt_relations, mode="boundaries")
+ return re_metrics
+
+
+def evaluate(model, eval_dataloader, logger, prefix=""):
+ # Eval!
+ logger.info("***** Running evaluation {} *****".format(prefix))
+ logger.info(" Num examples = {}".format(len(eval_dataloader.dataset)))
+
+ re_preds = []
+ re_labels = []
+ entities = []
+ eval_loss = 0.0
+ model.eval()
+ for idx, batch in enumerate(eval_dataloader):
+ with paddle.no_grad():
+ outputs = model(**batch)
+ loss = outputs['loss'].mean().item()
+ if paddle.distributed.get_rank() == 0:
+ logger.info("[Eval] process: {}/{}, loss: {:.5f}".format(
+ idx, len(eval_dataloader), loss))
+
+ eval_loss += loss
+ re_preds.extend(outputs['pred_relations'])
+ re_labels.extend(batch['relations'])
+ entities.extend(batch['entities'])
+ re_metrics = cal_metric(re_preds, re_labels, entities)
+ re_metrics = {
+ "precision": re_metrics["ALL"]["p"],
+ "recall": re_metrics["ALL"]["r"],
+ "f1": re_metrics["ALL"]["f1"],
+ }
+ model.train()
+ return re_metrics
+
+
+def train(args):
+ logger = get_logger(log_file=os.path.join(args.output_dir, "train.log"))
+ print_arguments(args, logger)
+
+ # Added here for reproducibility (even between python 2 and 3)
+ set_seed(args.seed)
+
+ label2id_map, id2label_map = get_bio_label_maps(args.label_map_path)
+ pad_token_label_id = paddle.nn.CrossEntropyLoss().ignore_index
+
+ # dist mode
+ if paddle.distributed.get_world_size() > 1:
+ paddle.distributed.init_parallel_env()
+
+ tokenizer = LayoutXLMTokenizer.from_pretrained(args.model_name_or_path)
+
+ model = LayoutXLMModel.from_pretrained(args.model_name_or_path)
+ model = LayoutXLMForRelationExtraction(model, dropout=None)
+
+ # dist mode
+ if paddle.distributed.get_world_size() > 1:
+ model = paddle.distributed.DataParallel(model)
+
+ train_dataset = XFUNDataset(
+ tokenizer,
+ data_dir=args.train_data_dir,
+ label_path=args.train_label_path,
+ label2id_map=label2id_map,
+ img_size=(224, 224),
+ max_seq_len=args.max_seq_length,
+ pad_token_label_id=pad_token_label_id,
+ contains_re=True,
+ add_special_ids=False,
+ return_attention_mask=True,
+ load_mode='all')
+
+ eval_dataset = XFUNDataset(
+ tokenizer,
+ data_dir=args.eval_data_dir,
+ label_path=args.eval_label_path,
+ label2id_map=label2id_map,
+ img_size=(224, 224),
+ max_seq_len=args.max_seq_length,
+ pad_token_label_id=pad_token_label_id,
+ contains_re=True,
+ add_special_ids=False,
+ return_attention_mask=True,
+ load_mode='all')
+
+ train_sampler = paddle.io.DistributedBatchSampler(
+ train_dataset, batch_size=args.per_gpu_train_batch_size, shuffle=True)
+ args.train_batch_size = args.per_gpu_train_batch_size * \
+ max(1, paddle.distributed.get_world_size())
+ train_dataloader = paddle.io.DataLoader(
+ train_dataset,
+ batch_sampler=train_sampler,
+ num_workers=8,
+ use_shared_memory=True,
+ collate_fn=DataCollator())
+
+ eval_dataloader = paddle.io.DataLoader(
+ eval_dataset,
+ batch_size=args.per_gpu_eval_batch_size,
+ num_workers=8,
+ shuffle=False,
+ collate_fn=DataCollator())
+
+ t_total = len(train_dataloader) * args.num_train_epochs
+
+ # build linear decay with warmup lr sch
+ lr_scheduler = paddle.optimizer.lr.PolynomialDecay(
+ learning_rate=args.learning_rate,
+ decay_steps=t_total,
+ end_lr=0.0,
+ power=1.0)
+ if args.warmup_steps > 0:
+ lr_scheduler = paddle.optimizer.lr.LinearWarmup(
+ lr_scheduler,
+ args.warmup_steps,
+ start_lr=0,
+ end_lr=args.learning_rate, )
+ grad_clip = paddle.nn.ClipGradByNorm(clip_norm=10)
+ optimizer = paddle.optimizer.Adam(
+ learning_rate=args.learning_rate,
+ parameters=model.parameters(),
+ epsilon=args.adam_epsilon,
+ grad_clip=grad_clip,
+ weight_decay=args.weight_decay)
+
+ # Train!
+ logger.info("***** Running training *****")
+ logger.info(" Num examples = {}".format(len(train_dataset)))
+ logger.info(" Num Epochs = {}".format(args.num_train_epochs))
+ logger.info(" Instantaneous batch size per GPU = {}".format(
+ args.per_gpu_train_batch_size))
+ logger.info(
+ " Total train batch size (w. parallel, distributed & accumulation) = {}".
+ format(args.train_batch_size * paddle.distributed.get_world_size()))
+ logger.info(" Total optimization steps = {}".format(t_total))
+
+ global_step = 0
+ model.clear_gradients()
+ train_dataloader_len = len(train_dataloader)
+    best_metric = {'f1': 0}
+ model.train()
+
+ for epoch in range(int(args.num_train_epochs)):
+ for step, batch in enumerate(train_dataloader):
+ outputs = model(**batch)
+ # model outputs are always tuple in ppnlp (see doc)
+ loss = outputs['loss']
+ loss = loss.mean()
+
+ logger.info(
+ "epoch: [{}/{}], iter: [{}/{}], global_step:{}, train loss: {}, lr: {}".
+ format(epoch, args.num_train_epochs, step, train_dataloader_len,
+ global_step, np.mean(loss.numpy()), optimizer.get_lr()))
+
+ loss.backward()
+ optimizer.step()
+ optimizer.clear_grad()
+ # lr_scheduler.step() # Update learning rate schedule
+
+ global_step += 1
+
+ if (paddle.distributed.get_rank() == 0 and args.eval_steps > 0 and
+ global_step % args.eval_steps == 0):
+ # Log metrics
+ if (paddle.distributed.get_rank() == 0 and args.
+ evaluate_during_training): # Only evaluate when single GPU otherwise metrics may not average well
+ results = evaluate(model, eval_dataloader, logger)
+                    if results['f1'] > best_metric['f1']:
+                        best_metric = results
+ output_dir = os.path.join(args.output_dir,
+ "checkpoint-best")
+ os.makedirs(output_dir, exist_ok=True)
+ model.save_pretrained(output_dir)
+ tokenizer.save_pretrained(output_dir)
+ paddle.save(args,
+ os.path.join(output_dir,
+ "training_args.bin"))
+ logger.info("Saving model checkpoint to {}".format(
+ output_dir))
+ logger.info("eval results: {}".format(results))
+ logger.info("best_metirc: {}".format(best_metirc))
+
+ if (paddle.distributed.get_rank() == 0 and args.save_steps > 0 and
+ global_step % args.save_steps == 0):
+ # Save model checkpoint
+ output_dir = os.path.join(args.output_dir, "checkpoint-latest")
+ os.makedirs(output_dir, exist_ok=True)
+ if paddle.distributed.get_rank() == 0:
+ model.save_pretrained(output_dir)
+ tokenizer.save_pretrained(output_dir)
+ paddle.save(args,
+ os.path.join(output_dir, "training_args.bin"))
+ logger.info("Saving model checkpoint to {}".format(
+ output_dir))
+ logger.info("best_metirc: {}".format(best_metirc))
+
+
+if __name__ == "__main__":
+ args = parse_args()
+ os.makedirs(args.output_dir, exist_ok=True)
+ train(args)
diff --git a/ppstructure/vqa/train_ser.py b/ppstructure/vqa/train_ser.py
new file mode 100644
index 0000000000000000000000000000000000000000..d3144e7167c59b5883047a948abaedfd21ba9b1c
--- /dev/null
+++ b/ppstructure/vqa/train_ser.py
@@ -0,0 +1,298 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import sys
+
+__dir__ = os.path.dirname(os.path.abspath(__file__))
+sys.path.append(__dir__)
+sys.path.append(os.path.abspath(os.path.join(__dir__, '../..')))
+
+import random
+import copy
+import logging
+
+import argparse
+import paddle
+import numpy as np
+from seqeval.metrics import classification_report, f1_score, precision_score, recall_score
+from paddlenlp.transformers import LayoutXLMModel, LayoutXLMTokenizer, LayoutXLMForTokenClassification
+from xfun import XFUNDataset
+from utils import parse_args
+from utils import get_bio_label_maps
+from utils import print_arguments
+
+from ppocr.utils.logging import get_logger
+
+
+def set_seed(args):
+ random.seed(args.seed)
+ np.random.seed(args.seed)
+ paddle.seed(args.seed)
+
+
+def train(args):
+ os.makedirs(args.output_dir, exist_ok=True)
+ logger = get_logger(log_file=os.path.join(args.output_dir, "train.log"))
+ print_arguments(args, logger)
+
+ label2id_map, id2label_map = get_bio_label_maps(args.label_map_path)
+ pad_token_label_id = paddle.nn.CrossEntropyLoss().ignore_index
+
+ # dist mode
+ if paddle.distributed.get_world_size() > 1:
+ paddle.distributed.init_parallel_env()
+
+ tokenizer = LayoutXLMTokenizer.from_pretrained(args.model_name_or_path)
+ base_model = LayoutXLMModel.from_pretrained(args.model_name_or_path)
+ model = LayoutXLMForTokenClassification(
+ base_model, num_classes=len(label2id_map), dropout=None)
+
+ # dist mode
+ if paddle.distributed.get_world_size() > 1:
+ model = paddle.DataParallel(model)
+
+ train_dataset = XFUNDataset(
+ tokenizer,
+ data_dir=args.train_data_dir,
+ label_path=args.train_label_path,
+ label2id_map=label2id_map,
+ img_size=(224, 224),
+ pad_token_label_id=pad_token_label_id,
+ contains_re=False,
+ add_special_ids=False,
+ return_attention_mask=True,
+ load_mode='all')
+
+ train_sampler = paddle.io.DistributedBatchSampler(
+ train_dataset, batch_size=args.per_gpu_train_batch_size, shuffle=True)
+
+ args.train_batch_size = args.per_gpu_train_batch_size * max(
+ 1, paddle.distributed.get_world_size())
+
+ train_dataloader = paddle.io.DataLoader(
+ train_dataset,
+ batch_sampler=train_sampler,
+ num_workers=0,
+ use_shared_memory=True,
+ collate_fn=None, )
+
+ t_total = len(train_dataloader) * args.num_train_epochs
+
+ # build linear decay with warmup lr sch
+ lr_scheduler = paddle.optimizer.lr.PolynomialDecay(
+ learning_rate=args.learning_rate,
+ decay_steps=t_total,
+ end_lr=0.0,
+ power=1.0)
+ if args.warmup_steps > 0:
+ lr_scheduler = paddle.optimizer.lr.LinearWarmup(
+ lr_scheduler,
+ args.warmup_steps,
+ start_lr=0,
+ end_lr=args.learning_rate, )
+
+ optimizer = paddle.optimizer.AdamW(
+ learning_rate=lr_scheduler,
+ parameters=model.parameters(),
+ epsilon=args.adam_epsilon,
+ weight_decay=args.weight_decay)
+
+ # Train!
+ logger.info("***** Running training *****")
+ logger.info(" Num examples = %d", len(train_dataset))
+ logger.info(" Num Epochs = %d", args.num_train_epochs)
+ logger.info(" Instantaneous batch size per GPU = %d",
+ args.per_gpu_train_batch_size)
+ logger.info(
+ " Total train batch size (w. parallel, distributed) = %d",
+ args.train_batch_size * paddle.distributed.get_world_size(), )
+ logger.info(" Total optimization steps = %d", t_total)
+
+ global_step = 0
+ tr_loss = 0.0
+ set_seed(args)
+ best_metrics = None
+
+ for epoch_id in range(args.num_train_epochs):
+ for step, batch in enumerate(train_dataloader):
+ model.train()
+ outputs = model(**batch)
+ # model outputs are always tuple in ppnlp (see doc)
+ loss = outputs[0]
+ loss = loss.mean()
+ logger.info(
+ "epoch: [{}/{}], iter: [{}/{}], global_step:{}, train loss: {}, lr: {}".
+ format(epoch_id, args.num_train_epochs, step,
+ len(train_dataloader), global_step,
+ loss.numpy()[0], lr_scheduler.get_lr()))
+
+ loss.backward()
+ tr_loss += loss.item()
+ optimizer.step()
+ lr_scheduler.step() # Update learning rate schedule
+ optimizer.clear_grad()
+ global_step += 1
+
+ if (paddle.distributed.get_rank() == 0 and args.eval_steps > 0 and
+ global_step % args.eval_steps == 0):
+ # Log metrics
+ # Only evaluate when single GPU otherwise metrics may not average well
+ if paddle.distributed.get_rank(
+ ) == 0 and args.evaluate_during_training:
+ results, _ = evaluate(args, model, tokenizer, label2id_map,
+ id2label_map, pad_token_label_id,
+ logger)
+
+ if best_metrics is None or results["f1"] >= best_metrics[
+ "f1"]:
+ best_metrics = copy.deepcopy(results)
+ output_dir = os.path.join(args.output_dir, "best_model")
+ os.makedirs(output_dir, exist_ok=True)
+ if paddle.distributed.get_rank() == 0:
+ model.save_pretrained(output_dir)
+ tokenizer.save_pretrained(output_dir)
+ paddle.save(
+ args,
+ os.path.join(output_dir, "training_args.bin"))
+ logger.info("Saving model checkpoint to %s",
+ output_dir)
+
+ logger.info("[epoch {}/{}][iter: {}/{}] results: {}".format(
+ epoch_id, args.num_train_epochs, step,
+ len(train_dataloader), results))
+ if best_metrics is not None:
+ logger.info("best metrics: {}".format(best_metrics))
+
+ if paddle.distributed.get_rank(
+ ) == 0 and args.save_steps > 0 and global_step % args.save_steps == 0:
+ # Save model checkpoint
+ output_dir = os.path.join(args.output_dir,
+ "checkpoint-{}".format(global_step))
+ os.makedirs(output_dir, exist_ok=True)
+ if paddle.distributed.get_rank() == 0:
+ model.save_pretrained(output_dir)
+ tokenizer.save_pretrained(output_dir)
+ paddle.save(args,
+ os.path.join(output_dir, "training_args.bin"))
+ logger.info("Saving model checkpoint to %s", output_dir)
+
+ return global_step, tr_loss / global_step
+
+
+def evaluate(args,
+ model,
+ tokenizer,
+ label2id_map,
+ id2label_map,
+ pad_token_label_id,
+ logger,
+ prefix=""):
+ eval_dataset = XFUNDataset(
+ tokenizer,
+ data_dir=args.eval_data_dir,
+ label_path=args.eval_label_path,
+ label2id_map=label2id_map,
+ img_size=(224, 224),
+ pad_token_label_id=pad_token_label_id,
+ contains_re=False,
+ add_special_ids=False,
+ return_attention_mask=True,
+ load_mode='all')
+
+ args.eval_batch_size = args.per_gpu_eval_batch_size * max(
+ 1, paddle.distributed.get_world_size())
+
+ eval_dataloader = paddle.io.DataLoader(
+ eval_dataset,
+ batch_size=args.eval_batch_size,
+ num_workers=0,
+ use_shared_memory=True,
+ collate_fn=None, )
+
+ # Eval!
+ logger.info("***** Running evaluation %s *****", prefix)
+ logger.info(" Num examples = %d", len(eval_dataset))
+ logger.info(" Batch size = %d", args.eval_batch_size)
+ eval_loss = 0.0
+ nb_eval_steps = 0
+ preds = None
+ out_label_ids = None
+ model.eval()
+ for idx, batch in enumerate(eval_dataloader):
+ with paddle.no_grad():
+ outputs = model(**batch)
+ tmp_eval_loss, logits = outputs[:2]
+
+ tmp_eval_loss = tmp_eval_loss.mean()
+
+ if paddle.distributed.get_rank() == 0:
+ logger.info("[Eval]process: {}/{}, loss: {:.5f}".format(
+ idx, len(eval_dataloader), tmp_eval_loss.numpy()[0]))
+
+ eval_loss += tmp_eval_loss.item()
+ nb_eval_steps += 1
+ if preds is None:
+ preds = logits.numpy()
+ out_label_ids = batch["labels"].numpy()
+ else:
+ preds = np.append(preds, logits.numpy(), axis=0)
+ out_label_ids = np.append(
+ out_label_ids, batch["labels"].numpy(), axis=0)
+
+ eval_loss = eval_loss / nb_eval_steps
+ preds = np.argmax(preds, axis=2)
+
+ # label_map = {i: label.upper() for i, label in enumerate(labels)}
+
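+    # drop padded positions and map ids back to BIO tags so seqeval can score them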
+ out_label_list = [[] for _ in range(out_label_ids.shape[0])]
+ preds_list = [[] for _ in range(out_label_ids.shape[0])]
+
+ for i in range(out_label_ids.shape[0]):
+ for j in range(out_label_ids.shape[1]):
+ if out_label_ids[i, j] != pad_token_label_id:
+ out_label_list[i].append(id2label_map[out_label_ids[i][j]])
+ preds_list[i].append(id2label_map[preds[i][j]])
+
+ results = {
+ "loss": eval_loss,
+ "precision": precision_score(out_label_list, preds_list),
+ "recall": recall_score(out_label_list, preds_list),
+ "f1": f1_score(out_label_list, preds_list),
+ }
+
+ with open(os.path.join(args.output_dir, "test_gt.txt"), "w") as fout:
+ for lbl in out_label_list:
+ for l in lbl:
+ fout.write(l + "\t")
+ fout.write("\n")
+ with open(os.path.join(args.output_dir, "test_pred.txt"), "w") as fout:
+ for lbl in preds_list:
+ for l in lbl:
+ fout.write(l + "\t")
+ fout.write("\n")
+
+ report = classification_report(out_label_list, preds_list)
+ logger.info("\n" + report)
+
+ logger.info("***** Eval results %s *****", prefix)
+ for key in sorted(results.keys()):
+ logger.info(" %s = %s", key, str(results[key]))
+
+ return results, preds_list
+
+
+if __name__ == "__main__":
+ args = parse_args()
+ train(args)
diff --git a/ppstructure/vqa/utils.py b/ppstructure/vqa/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..f4db20d5cbcb6cf510bb794bb0e7d836da028b2f
--- /dev/null
+++ b/ppstructure/vqa/utils.py
@@ -0,0 +1,392 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import argparse
+import cv2
+import random
+import numpy as np
+import imghdr
+from copy import deepcopy
+
+import paddle
+
+from PIL import Image, ImageDraw, ImageFont
+
+
+def get_bio_label_maps(label_map_path):
+ with open(label_map_path, "r") as fin:
+ lines = fin.readlines()
+ lines = [line.strip() for line in lines]
+ if "O" not in lines:
+ lines.insert(0, "O")
+ labels = []
+ for line in lines:
+ if line == "O":
+ labels.append("O")
+ else:
+ labels.append("B-" + line)
+ labels.append("I-" + line)
+ label2id_map = {label: idx for idx, label in enumerate(labels)}
+ id2label_map = {idx: label for idx, label in enumerate(labels)}
+ return label2id_map, id2label_map
+
+
+def get_image_file_list(img_file):
+ imgs_lists = []
+ if img_file is None or not os.path.exists(img_file):
+        raise Exception("no image file found in {}".format(img_file))
+
+ img_end = {'jpg', 'bmp', 'png', 'jpeg', 'rgb', 'tif', 'tiff', 'gif', 'GIF'}
+ if os.path.isfile(img_file) and imghdr.what(img_file) in img_end:
+ imgs_lists.append(img_file)
+ elif os.path.isdir(img_file):
+ for single_file in os.listdir(img_file):
+ file_path = os.path.join(img_file, single_file)
+ if os.path.isfile(file_path) and imghdr.what(file_path) in img_end:
+ imgs_lists.append(file_path)
+ if len(imgs_lists) == 0:
+        raise Exception("no image file found in {}".format(img_file))
+ imgs_lists = sorted(imgs_lists)
+ return imgs_lists
+
+
+def draw_ser_results(image,
+ ocr_results,
+ font_path="../../doc/fonts/simfang.ttf",
+ font_size=18):
+ np.random.seed(2021)
+ color = (np.random.permutation(range(255)),
+ np.random.permutation(range(255)),
+ np.random.permutation(range(255)))
+ color_map = {
+ idx: (color[0][idx], color[1][idx], color[2][idx])
+ for idx in range(1, 255)
+ }
+ if isinstance(image, np.ndarray):
+ image = Image.fromarray(image)
+ img_new = image.copy()
+ draw = ImageDraw.Draw(img_new)
+
+ font = ImageFont.truetype(font_path, font_size, encoding="utf-8")
+ for ocr_info in ocr_results:
+ if ocr_info["pred_id"] not in color_map:
+ continue
+ color = color_map[ocr_info["pred_id"]]
+ text = "{}: {}".format(ocr_info["pred"], ocr_info["text"])
+
+ draw_box_txt(ocr_info["bbox"], text, draw, font, font_size, color)
+
+ img_new = Image.blend(image, img_new, 0.5)
+ return np.array(img_new)
+
+
+def draw_box_txt(bbox, text, draw, font, font_size, color):
+ # draw ocr results outline
+ bbox = ((bbox[0], bbox[1]), (bbox[2], bbox[3]))
+ draw.rectangle(bbox, fill=color)
+
+ # draw ocr results
+ start_y = max(0, bbox[0][1] - font_size)
+ tw = font.getsize(text)[0]
+ draw.rectangle(
+ [(bbox[0][0] + 1, start_y), (bbox[0][0] + tw + 1, start_y + font_size)],
+ fill=(0, 0, 255))
+ draw.text((bbox[0][0] + 1, start_y), text, fill=(255, 255, 255), font=font)
+
+
+def draw_re_results(image,
+ result,
+ font_path="../../doc/fonts/simfang.ttf",
+ font_size=18):
+ np.random.seed(0)
+ if isinstance(image, np.ndarray):
+ image = Image.fromarray(image)
+ img_new = image.copy()
+ draw = ImageDraw.Draw(img_new)
+
+ font = ImageFont.truetype(font_path, font_size, encoding="utf-8")
+ color_head = (0, 0, 255)
+ color_tail = (255, 0, 0)
+ color_line = (0, 255, 0)
+
+ for ocr_info_head, ocr_info_tail in result:
+ draw_box_txt(ocr_info_head["bbox"], ocr_info_head["text"], draw, font,
+ font_size, color_head)
+ draw_box_txt(ocr_info_tail["bbox"], ocr_info_tail["text"], draw, font,
+ font_size, color_tail)
+
+ center_head = (
+ (ocr_info_head['bbox'][0] + ocr_info_head['bbox'][2]) // 2,
+ (ocr_info_head['bbox'][1] + ocr_info_head['bbox'][3]) // 2)
+ center_tail = (
+ (ocr_info_tail['bbox'][0] + ocr_info_tail['bbox'][2]) // 2,
+ (ocr_info_tail['bbox'][1] + ocr_info_tail['bbox'][3]) // 2)
+
+ draw.line([center_head, center_tail], fill=color_line, width=5)
+
+ img_new = Image.blend(image, img_new, 0.5)
+ return np.array(img_new)
+
+
+# pad sentences
+def pad_sentences(tokenizer,
+ encoded_inputs,
+ max_seq_len=512,
+ pad_to_max_seq_len=True,
+ return_attention_mask=True,
+ return_token_type_ids=True,
+ return_overflowing_tokens=False,
+ return_special_tokens_mask=False):
+    # pad the sequence length up to the next multiple of max_seq_len so it can later be reshaped into pages
+ max_seq_len = (
+ len(encoded_inputs["input_ids"]) // max_seq_len + 1) * max_seq_len
+
+ needs_to_be_padded = pad_to_max_seq_len and \
+ max_seq_len and len(encoded_inputs["input_ids"]) < max_seq_len
+
+ if needs_to_be_padded:
+ difference = max_seq_len - len(encoded_inputs["input_ids"])
+ if tokenizer.padding_side == 'right':
+ if return_attention_mask:
+ encoded_inputs["attention_mask"] = [1] * len(encoded_inputs[
+ "input_ids"]) + [0] * difference
+ if return_token_type_ids:
+ encoded_inputs["token_type_ids"] = (
+ encoded_inputs["token_type_ids"] +
+ [tokenizer.pad_token_type_id] * difference)
+ if return_special_tokens_mask:
+ encoded_inputs["special_tokens_mask"] = encoded_inputs[
+ "special_tokens_mask"] + [1] * difference
+ encoded_inputs["input_ids"] = encoded_inputs[
+ "input_ids"] + [tokenizer.pad_token_id] * difference
+ encoded_inputs["bbox"] = encoded_inputs["bbox"] + [[0, 0, 0, 0]
+ ] * difference
+ else:
+ if return_attention_mask:
+ encoded_inputs["attention_mask"] = [1] * len(encoded_inputs[
+ "input_ids"])
+
+ return encoded_inputs
+
+
+def split_page(encoded_inputs, max_seq_len=512):
+ """
+ truncate is often used in training process
+ """
+ for key in encoded_inputs:
+ if key == 'entities':
+ encoded_inputs[key] = [encoded_inputs[key]]
+ continue
+ encoded_inputs[key] = paddle.to_tensor(encoded_inputs[key])
+ if encoded_inputs[key].ndim <= 1: # for input_ids, att_mask and so on
+ encoded_inputs[key] = encoded_inputs[key].reshape([-1, max_seq_len])
+ else: # for bbox
+ encoded_inputs[key] = encoded_inputs[key].reshape(
+ [-1, max_seq_len, 4])
+ return encoded_inputs
+
+
+def preprocess(
+ tokenizer,
+ ori_img,
+ ocr_info,
+ img_size=(224, 224),
+ pad_token_label_id=-100,
+ max_seq_len=512,
+ add_special_ids=False,
+ return_attention_mask=True, ):
+ ocr_info = deepcopy(ocr_info)
+ height = ori_img.shape[0]
+ width = ori_img.shape[1]
+
+ img = cv2.resize(ori_img, img_size).transpose([2, 0, 1]).astype(np.float32)
+
+ segment_offset_id = []
+ words_list = []
+ bbox_list = []
+ input_ids_list = []
+ token_type_ids_list = []
+ entities = []
+
+ for info in ocr_info:
+ # x1, y1, x2, y2
+ bbox = info["bbox"]
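+        # scale the box into the 0-1000 coordinate space expected by the model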
+ bbox[0] = int(bbox[0] * 1000.0 / width)
+ bbox[2] = int(bbox[2] * 1000.0 / width)
+ bbox[1] = int(bbox[1] * 1000.0 / height)
+ bbox[3] = int(bbox[3] * 1000.0 / height)
+
+ text = info["text"]
+ encode_res = tokenizer.encode(
+ text, pad_to_max_seq_len=False, return_attention_mask=True)
+
+ if not add_special_ids:
+ # TODO: use tok.all_special_ids to remove
+ encode_res["input_ids"] = encode_res["input_ids"][1:-1]
+ encode_res["token_type_ids"] = encode_res["token_type_ids"][1:-1]
+ encode_res["attention_mask"] = encode_res["attention_mask"][1:-1]
+
+ # for re
+ entities.append({
+ "start": len(input_ids_list),
+ "end": len(input_ids_list) + len(encode_res["input_ids"]),
+ "label": "O",
+ })
+
+ input_ids_list.extend(encode_res["input_ids"])
+ token_type_ids_list.extend(encode_res["token_type_ids"])
+ bbox_list.extend([bbox] * len(encode_res["input_ids"]))
+ words_list.append(text)
+ segment_offset_id.append(len(input_ids_list))
+
+ encoded_inputs = {
+ "input_ids": input_ids_list,
+ "token_type_ids": token_type_ids_list,
+ "bbox": bbox_list,
+ "attention_mask": [1] * len(input_ids_list),
+ "entities": entities
+ }
+
+ encoded_inputs = pad_sentences(
+ tokenizer,
+ encoded_inputs,
+ max_seq_len=max_seq_len,
+ return_attention_mask=return_attention_mask)
+
+ encoded_inputs = split_page(encoded_inputs)
+
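+    # split_page may return several "pages"; replicate the image so every page
+    # in this pseudo batch shares the same visual input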
+ fake_bs = encoded_inputs["input_ids"].shape[0]
+
+ encoded_inputs["image"] = paddle.to_tensor(img).unsqueeze(0).expand(
+ [fake_bs] + list(img.shape))
+
+ encoded_inputs["segment_offset_id"] = segment_offset_id
+
+ return encoded_inputs
+
+
+def postprocess(attention_mask, preds, id2label_map):
+ if isinstance(preds, paddle.Tensor):
+ preds = preds.numpy()
+ preds = np.argmax(preds, axis=2)
+
+ preds_list = [[] for _ in range(preds.shape[0])]
+
+ # keep batch info
+ for i in range(preds.shape[0]):
+ for j in range(preds.shape[1]):
+ if attention_mask[i][j] == 1:
+ preds_list[i].append(id2label_map[preds[i][j]])
+
+ return preds_list
+
+
+def merge_preds_list_with_ocr_info(ocr_info, segment_offset_id, preds_list,
+ label2id_map_for_draw):
+ # must ensure the preds_list is generated from the same image
+ preds = [p for pred in preds_list for p in pred]
+
+ id2label_map = dict()
+ for key in label2id_map_for_draw:
+ val = label2id_map_for_draw[key]
+        if key.startswith("B-") or key.startswith("I-"):
+            id2label_map[val] = key[2:]
+        else:
+            id2label_map[val] = key
+
+ for idx in range(len(segment_offset_id)):
+ if idx == 0:
+ start_id = 0
+ else:
+ start_id = segment_offset_id[idx - 1]
+
+ end_id = segment_offset_id[idx]
+
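+        # majority vote over the token-level predictions that belong to this OCR segment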
+ curr_pred = preds[start_id:end_id]
+ curr_pred = [label2id_map_for_draw[p] for p in curr_pred]
+
+ if len(curr_pred) <= 0:
+ pred_id = 0
+ else:
+ counts = np.bincount(curr_pred)
+ pred_id = np.argmax(counts)
+ ocr_info[idx]["pred_id"] = int(pred_id)
+ ocr_info[idx]["pred"] = id2label_map[int(pred_id)]
+ return ocr_info
+
+
+def print_arguments(args, logger=None):
+    """print arguments"""
+    print_func = logger.info if logger is not None else print
+ print_func('----------- Configuration Arguments -----------')
+ for arg, value in sorted(vars(args).items()):
+ print_func('%s: %s' % (arg, value))
+ print_func('------------------------------------------------')
+
+
+def parse_args():
+ parser = argparse.ArgumentParser()
+ # Required parameters
+ # yapf: disable
+ parser.add_argument("--model_name_or_path",
+ default=None, type=str, required=True,)
+ parser.add_argument("--re_model_name_or_path",
+ default=None, type=str, required=False,)
+ parser.add_argument("--train_data_dir", default=None,
+ type=str, required=False,)
+ parser.add_argument("--train_label_path", default=None,
+ type=str, required=False,)
+ parser.add_argument("--eval_data_dir", default=None,
+ type=str, required=False,)
+ parser.add_argument("--eval_label_path", default=None,
+ type=str, required=False,)
+ parser.add_argument("--output_dir", default=None, type=str, required=True,)
+ parser.add_argument("--max_seq_length", default=512, type=int,)
+ parser.add_argument("--evaluate_during_training", action="store_true",)
+ parser.add_argument("--per_gpu_train_batch_size", default=8,
+ type=int, help="Batch size per GPU/CPU for training.",)
+ parser.add_argument("--per_gpu_eval_batch_size", default=8,
+ type=int, help="Batch size per GPU/CPU for eval.",)
+ parser.add_argument("--learning_rate", default=5e-5,
+ type=float, help="The initial learning rate for Adam.",)
+ parser.add_argument("--weight_decay", default=0.0,
+ type=float, help="Weight decay if we apply some.",)
+ parser.add_argument("--adam_epsilon", default=1e-8,
+ type=float, help="Epsilon for Adam optimizer.",)
+ parser.add_argument("--max_grad_norm", default=1.0,
+ type=float, help="Max gradient norm.",)
+ parser.add_argument("--num_train_epochs", default=3, type=int,
+ help="Total number of training epochs to perform.",)
+ parser.add_argument("--warmup_steps", default=0, type=int,
+ help="Linear warmup over warmup_steps.",)
+ parser.add_argument("--eval_steps", type=int, default=10,
+ help="eval every X updates steps.",)
+ parser.add_argument("--save_steps", type=int, default=50,
+ help="Save checkpoint every X updates steps.",)
+ parser.add_argument("--seed", type=int, default=2048,
+ help="random seed for initialization",)
+
+ parser.add_argument("--rec_model_dir", default=None, type=str, )
+ parser.add_argument("--det_model_dir", default=None, type=str, )
+ parser.add_argument(
+ "--label_map_path", default="./labels/labels_ser.txt", type=str, required=False, )
+ parser.add_argument("--infer_imgs", default=None, type=str, required=False)
+ parser.add_argument("--ocr_json_path", default=None,
+ type=str, required=False, help="ocr prediction results")
+ # yapf: enable
+ args = parser.parse_args()
+ return args
diff --git a/ppstructure/vqa/xfun.py b/ppstructure/vqa/xfun.py
new file mode 100644
index 0000000000000000000000000000000000000000..d62cdb5da5514280b62687d80d345ede9484ee90
--- /dev/null
+++ b/ppstructure/vqa/xfun.py
@@ -0,0 +1,442 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import json
+import os
+import cv2
+import numpy as np
+import paddle
+import copy
+from paddle.io import Dataset
+
+__all__ = ["XFUNDataset"]
+
+
+class XFUNDataset(Dataset):
+ """
+ Example:
+ print("=====begin to build dataset=====")
+ from paddlenlp.transformers import LayoutXLMTokenizer
+ tokenizer = LayoutXLMTokenizer.from_pretrained("/paddle/models/transformers/layoutxlm-base-paddle/")
+ tok_res = tokenizer.tokenize("Maribyrnong")
+ # res = tokenizer.convert_ids_to_tokens(val_data["input_ids"][0])
+    dataset = XFUNDataset(
+ tokenizer,
+ data_dir="./zh.val/",
+ label_path="zh.val/xfun_normalize_val.json",
+ img_size=(224,224))
+ print(len(dataset))
+
+ data = dataset[0]
+ print(data.keys())
+ print("input_ids: ", data["input_ids"])
+ print("labels: ", data["labels"])
+ print("token_type_ids: ", data["token_type_ids"])
+ print("words_list: ", data["words_list"])
+ print("image shape: ", data["image"].shape)
+ """
+
+ def __init__(self,
+ tokenizer,
+ data_dir,
+ label_path,
+ contains_re=False,
+ label2id_map=None,
+ img_size=(224, 224),
+ pad_token_label_id=None,
+ add_special_ids=False,
+ return_attention_mask=True,
+ load_mode='all',
+ max_seq_len=512):
+ super().__init__()
+ self.tokenizer = tokenizer
+ self.data_dir = data_dir
+ self.label_path = label_path
+ self.contains_re = contains_re
+ self.label2id_map = label2id_map
+ self.img_size = img_size
+ self.pad_token_label_id = pad_token_label_id
+ self.add_special_ids = add_special_ids
+ self.return_attention_mask = return_attention_mask
+ self.load_mode = load_mode
+ self.max_seq_len = max_seq_len
+
+ if self.pad_token_label_id is None:
+ self.pad_token_label_id = paddle.nn.CrossEntropyLoss().ignore_index
+
+ self.all_lines = self.read_all_lines()
+
+ self.entities_labels = {'HEADER': 0, 'QUESTION': 1, 'ANSWER': 2}
+ self.return_keys = {
+ 'bbox': 'np',
+ 'input_ids': 'np',
+ 'labels': 'np',
+ 'attention_mask': 'np',
+ 'image': 'np',
+ 'token_type_ids': 'np',
+ 'entities': 'dict',
+ 'relations': 'dict',
+ }
+
+ if load_mode == "all":
+ self.encoded_inputs_all = self._parse_label_file_all()
+
+ def pad_sentences(self,
+ encoded_inputs,
+ max_seq_len=512,
+ pad_to_max_seq_len=True,
+ return_attention_mask=True,
+ return_token_type_ids=True,
+ truncation_strategy="longest_first",
+ return_overflowing_tokens=False,
+ return_special_tokens_mask=False):
+ # Padding
+ needs_to_be_padded = pad_to_max_seq_len and \
+ max_seq_len and len(encoded_inputs["input_ids"]) < max_seq_len
+
+ if needs_to_be_padded:
+ difference = max_seq_len - len(encoded_inputs["input_ids"])
+ if self.tokenizer.padding_side == 'right':
+ if return_attention_mask:
+ encoded_inputs["attention_mask"] = [1] * len(encoded_inputs[
+ "input_ids"]) + [0] * difference
+ if return_token_type_ids:
+ encoded_inputs["token_type_ids"] = (
+ encoded_inputs["token_type_ids"] +
+ [self.tokenizer.pad_token_type_id] * difference)
+ if return_special_tokens_mask:
+ encoded_inputs["special_tokens_mask"] = encoded_inputs[
+ "special_tokens_mask"] + [1] * difference
+ encoded_inputs["input_ids"] = encoded_inputs[
+ "input_ids"] + [self.tokenizer.pad_token_id] * difference
+ encoded_inputs["labels"] = encoded_inputs[
+ "labels"] + [self.pad_token_label_id] * difference
+ encoded_inputs["bbox"] = encoded_inputs[
+ "bbox"] + [[0, 0, 0, 0]] * difference
+ elif self.tokenizer.padding_side == 'left':
+ if return_attention_mask:
+ encoded_inputs["attention_mask"] = [0] * difference + [
+ 1
+ ] * len(encoded_inputs["input_ids"])
+ if return_token_type_ids:
+ encoded_inputs["token_type_ids"] = (
+ [self.tokenizer.pad_token_type_id] * difference +
+ encoded_inputs["token_type_ids"])
+ if return_special_tokens_mask:
+ encoded_inputs["special_tokens_mask"] = [
+ 1
+ ] * difference + encoded_inputs["special_tokens_mask"]
+ encoded_inputs["input_ids"] = [
+ self.tokenizer.pad_token_id
+ ] * difference + encoded_inputs["input_ids"]
+ encoded_inputs["labels"] = [
+ self.pad_token_label_id
+ ] * difference + encoded_inputs["labels"]
+ encoded_inputs["bbox"] = [
+ [0, 0, 0, 0]
+ ] * difference + encoded_inputs["bbox"]
+ else:
+ if return_attention_mask:
+ encoded_inputs["attention_mask"] = [1] * len(encoded_inputs[
+ "input_ids"])
+
+ return encoded_inputs
+
+ def truncate_inputs(self, encoded_inputs, max_seq_len=512):
+ for key in encoded_inputs:
+ if key == "sample_id":
+ continue
+ length = min(len(encoded_inputs[key]), max_seq_len)
+ encoded_inputs[key] = encoded_inputs[key][:length]
+ return encoded_inputs
+
+ def read_all_lines(self, ):
+ with open(self.label_path, "r") as fin:
+ lines = fin.readlines()
+ return lines
+
+ def _parse_label_file_all(self):
+ """
+ parse all samples
+ """
+ encoded_inputs_all = []
+ for line in self.all_lines:
+ encoded_inputs_all.extend(self._parse_label_file(line))
+ return encoded_inputs_all
+
+ def _parse_label_file(self, line):
+ """
+ parse single sample
+ """
+
+ image_name, info_str = line.split("\t")
+ image_path = os.path.join(self.data_dir, image_name)
+
+        def add_image_path(x):
+ x['image_path'] = image_path
+ return x
+
+ encoded_inputs = self._read_encoded_inputs_sample(info_str)
+ if self.contains_re:
+ encoded_inputs = self._chunk_re(encoded_inputs)
+ else:
+ encoded_inputs = self._chunk_ser(encoded_inputs)
+        encoded_inputs = list(map(add_image_path, encoded_inputs))
+ return encoded_inputs
+
+ def _read_encoded_inputs_sample(self, info_str):
+ """
+ parse label info
+ """
+ # read text info
+ info_dict = json.loads(info_str)
+ height = info_dict["height"]
+ width = info_dict["width"]
+
+ words_list = []
+ bbox_list = []
+ input_ids_list = []
+ token_type_ids_list = []
+ gt_label_list = []
+
+ if self.contains_re:
+ # for re
+ entities = []
+ relations = []
+ id2label = {}
+ entity_id_to_index_map = {}
+ empty_entity = set()
+ for info in info_dict["ocr_info"]:
+ if self.contains_re:
+ # for re
+ if len(info["text"]) == 0:
+ empty_entity.add(info["id"])
+ continue
+ id2label[info["id"]] = info["label"]
+ relations.extend([tuple(sorted(l)) for l in info["linking"]])
+
+ # x1, y1, x2, y2
+ bbox = info["bbox"]
+ label = info["label"]
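+            # scale the box into the 0-1000 coordinate space expected by the model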
+ bbox[0] = int(bbox[0] * 1000.0 / width)
+ bbox[2] = int(bbox[2] * 1000.0 / width)
+ bbox[1] = int(bbox[1] * 1000.0 / height)
+ bbox[3] = int(bbox[3] * 1000.0 / height)
+
+ text = info["text"]
+ encode_res = self.tokenizer.encode(
+ text, pad_to_max_seq_len=False, return_attention_mask=True)
+
+ gt_label = []
+ if not self.add_special_ids:
+ # TODO: use tok.all_special_ids to remove
+ encode_res["input_ids"] = encode_res["input_ids"][1:-1]
+ encode_res["token_type_ids"] = encode_res["token_type_ids"][1:
+ -1]
+ encode_res["attention_mask"] = encode_res["attention_mask"][1:
+ -1]
+ if label.lower() == "other":
+ gt_label.extend([0] * len(encode_res["input_ids"]))
+ else:
+ gt_label.append(self.label2id_map[("b-" + label).upper()])
+ gt_label.extend([self.label2id_map[("i-" + label).upper()]] *
+ (len(encode_res["input_ids"]) - 1))
+ if self.contains_re:
+ if gt_label[0] != self.label2id_map["O"]:
+ entity_id_to_index_map[info["id"]] = len(entities)
+ entities.append({
+ "start": len(input_ids_list),
+ "end":
+ len(input_ids_list) + len(encode_res["input_ids"]),
+ "label": label.upper(),
+ })
+ input_ids_list.extend(encode_res["input_ids"])
+ token_type_ids_list.extend(encode_res["token_type_ids"])
+ bbox_list.extend([bbox] * len(encode_res["input_ids"]))
+ gt_label_list.extend(gt_label)
+ words_list.append(text)
+
+ encoded_inputs = {
+ "input_ids": input_ids_list,
+ "labels": gt_label_list,
+ "token_type_ids": token_type_ids_list,
+ "bbox": bbox_list,
+ "attention_mask": [1] * len(input_ids_list),
+ # "words_list": words_list,
+ }
+ encoded_inputs = self.pad_sentences(
+ encoded_inputs,
+ max_seq_len=self.max_seq_len,
+ return_attention_mask=self.return_attention_mask)
+ encoded_inputs = self.truncate_inputs(encoded_inputs)
+
+ if self.contains_re:
+ relations = self._relations(entities, relations, id2label,
+ empty_entity, entity_id_to_index_map)
+ encoded_inputs['relations'] = relations
+ encoded_inputs['entities'] = entities
+ return encoded_inputs
+
+ def _chunk_ser(self, encoded_inputs):
+ encoded_inputs_all = []
+ seq_len = len(encoded_inputs['input_ids'])
+ chunk_size = 512
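+        # slice the flat token sequence into fixed-length samples of chunk_size tokens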
+ for chunk_id, index in enumerate(range(0, seq_len, chunk_size)):
+ chunk_beg = index
+ chunk_end = min(index + chunk_size, seq_len)
+ encoded_inputs_example = {}
+ for key in encoded_inputs:
+ encoded_inputs_example[key] = encoded_inputs[key][chunk_beg:
+ chunk_end]
+
+ encoded_inputs_all.append(encoded_inputs_example)
+ return encoded_inputs_all
+
+ def _chunk_re(self, encoded_inputs):
+ # prepare data
+ entities = encoded_inputs.pop('entities')
+ relations = encoded_inputs.pop('relations')
+ encoded_inputs_all = []
+ chunk_size = 512
+ for chunk_id, index in enumerate(
+ range(0, len(encoded_inputs["input_ids"]), chunk_size)):
+ item = {}
+ for k in encoded_inputs:
+ item[k] = encoded_inputs[k][index:index + chunk_size]
+
+ # select entity in current chunk
+ entities_in_this_span = []
+            global_to_local_map = {}  # maps entity index in the full sequence to its index within this chunk
+ for entity_id, entity in enumerate(entities):
+ if (index <= entity["start"] < index + chunk_size and
+ index <= entity["end"] < index + chunk_size):
+ entity["start"] = entity["start"] - index
+ entity["end"] = entity["end"] - index
+ global_to_local_map[entity_id] = len(entities_in_this_span)
+ entities_in_this_span.append(entity)
+
+ # select relations in current chunk
+ relations_in_this_span = []
+ for relation in relations:
+ if (index <= relation["start_index"] < index + chunk_size and
+ index <= relation["end_index"] < index + chunk_size):
+ relations_in_this_span.append({
+ "head": global_to_local_map[relation["head"]],
+ "tail": global_to_local_map[relation["tail"]],
+ "start_index": relation["start_index"] - index,
+ "end_index": relation["end_index"] - index,
+ })
+ item.update({
+ "entities": reformat(entities_in_this_span),
+ "relations": reformat(relations_in_this_span),
+ })
+ item['entities']['label'] = [
+ self.entities_labels[x] for x in item['entities']['label']
+ ]
+ encoded_inputs_all.append(item)
+ return encoded_inputs_all
+
+ def _relations(self, entities, relations, id2label, empty_entity,
+ entity_id_to_index_map):
+ """
+ build relations
+ """
+ relations = list(set(relations))
+ relations = [
+ rel for rel in relations
+ if rel[0] not in empty_entity and rel[1] not in empty_entity
+ ]
+ kv_relations = []
+ for rel in relations:
+ pair = [id2label[rel[0]], id2label[rel[1]]]
+ if pair == ["question", "answer"]:
+ kv_relations.append({
+ "head": entity_id_to_index_map[rel[0]],
+ "tail": entity_id_to_index_map[rel[1]]
+ })
+ elif pair == ["answer", "question"]:
+ kv_relations.append({
+ "head": entity_id_to_index_map[rel[1]],
+ "tail": entity_id_to_index_map[rel[0]]
+ })
+ else:
+ continue
+ relations = sorted(
+ [{
+ "head": rel["head"],
+ "tail": rel["tail"],
+ "start_index": get_relation_span(rel, entities)[0],
+ "end_index": get_relation_span(rel, entities)[1],
+ } for rel in kv_relations],
+ key=lambda x: x["head"], )
+ return relations
+
+ def load_img(self, image_path):
+ # read img
+ img = cv2.imread(image_path)
+ img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
+ resize_h, resize_w = self.img_size
+ im_shape = img.shape[0:2]
+ im_scale_y = resize_h / im_shape[0]
+ im_scale_x = resize_w / im_shape[1]
+ img_new = cv2.resize(
+ img, None, None, fx=im_scale_x, fy=im_scale_y, interpolation=2)
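+        # normalize with the ImageNet channel mean/std (image is already in RGB order)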
+ mean = np.array([0.485, 0.456, 0.406])[np.newaxis, np.newaxis, :]
+ std = np.array([0.229, 0.224, 0.225])[np.newaxis, np.newaxis, :]
+ img_new = img_new / 255.0
+ img_new -= mean
+ img_new /= std
+ img = img_new.transpose((2, 0, 1))
+ return img
+
+ def __getitem__(self, idx):
+ if self.load_mode == "all":
+ data = copy.deepcopy(self.encoded_inputs_all[idx])
+ else:
+ data = self._parse_label_file(self.all_lines[idx])[0]
+
+ image_path = data.pop('image_path')
+ data["image"] = self.load_img(image_path)
+
+ return_data = {}
+ for k, v in data.items():
+ if k in self.return_keys:
+ if self.return_keys[k] == 'np':
+ v = np.array(v)
+ return_data[k] = v
+ return return_data
+
+ def __len__(self, ):
+ if self.load_mode == "all":
+ return len(self.encoded_inputs_all)
+ else:
+ return len(self.all_lines)
+
+
+def get_relation_span(rel, entities):
+ bound = []
+ for entity_index in [rel["head"], rel["tail"]]:
+ bound.append(entities[entity_index]["start"])
+ bound.append(entities[entity_index]["end"])
+ return min(bound), max(bound)
+
+
+def reformat(data):
+ new_data = {}
+ for item in data:
+ for k, v in item.items():
+ if k not in new_data:
+ new_data[k] = []
+ new_data[k].append(v)
+ return new_data
diff --git a/test_tipc/configs/ch_PP-OCRv2_det_PACT/train_infer_python.txt b/test_tipc/configs/ch_PP-OCRv2_det_PACT/train_infer_python.txt
index b567c08185e084384c3883f1d602cec3f312ea53..1246e380c1c113e3c96e2b2962f28fd865a8717d 100644
--- a/test_tipc/configs/ch_PP-OCRv2_det_PACT/train_infer_python.txt
+++ b/test_tipc/configs/ch_PP-OCRv2_det_PACT/train_infer_python.txt
@@ -1,5 +1,5 @@
===========================train_params===========================
-model_name:PPOCRv2_ocr_det
+model_name:PPOCRv2_det
python:python3.7
gpu_list:0|0,1
Global.use_gpu:True|True
@@ -26,7 +26,7 @@ null:null
##
===========================infer_params===========================
Global.save_inference_dir:./output/
-Global.pretrained_model:
+Global.checkpoints:
norm_export:null
quant_export:deploy/slim/quantization/export_model.py -c configs/det/ch_PP-OCRv2/ch_PP-OCRv2_det_cml.yml -o
fpgm_export:
diff --git a/test_tipc/configs/ch_PP-OCRv2_rec/train_infer_python.txt b/test_tipc/configs/ch_PP-OCRv2_rec/train_infer_python.txt
index b61dc8bbe36ac5b21ec5f3561d39997f992d6c58..4607b0a7f5d2ffb082ecb84d80b3534d75e14f5f 100644
--- a/test_tipc/configs/ch_PP-OCRv2_rec/train_infer_python.txt
+++ b/test_tipc/configs/ch_PP-OCRv2_rec/train_infer_python.txt
@@ -6,7 +6,7 @@ Global.use_gpu:True|True
Global.auto_cast:fp32
Global.epoch_num:lite_train_lite_infer=3|whole_train_whole_infer=300
Global.save_model_dir:./output/
-Train.loader.batch_size_per_card:lite_train_lite_infer=128|whole_train_whole_infer=128
+Train.loader.batch_size_per_card:lite_train_lite_infer=16|whole_train_whole_infer=128
Global.pretrained_model:null
train_model_name:latest
train_infer_img_dir:./inference/rec_inference
@@ -34,7 +34,7 @@ distill_export:null
export1:null
export2:null
inference_dir:Student
-infer_model:./inference/ch_PP-OCRv2_rec_infer/
+infer_model:./inference/ch_PP-OCRv2_rec_infer
infer_export:null
infer_quant:False
inference:tools/infer/predict_rec.py
@@ -45,7 +45,7 @@ inference:tools/infer/predict_rec.py
--use_tensorrt:False|True
--precision:fp32|fp16|int8
--rec_model_dir:
---image_dir:/inference/rec_inference
+--image_dir:./inference/rec_inference
null:null
--benchmark:True
null:null
diff --git a/test_tipc/configs/ch_PP-OCRv2_rec_PACT/train_infer_python.txt b/test_tipc/configs/ch_PP-OCRv2_rec_PACT/train_infer_python.txt
index 914c1bc7575dfee3309493b9110afe8b9cb7e59b..6127896ae29dc5f4d2813e84824cda5fa0bac7ca 100644
--- a/test_tipc/configs/ch_PP-OCRv2_rec_PACT/train_infer_python.txt
+++ b/test_tipc/configs/ch_PP-OCRv2_rec_PACT/train_infer_python.txt
@@ -6,15 +6,15 @@ Global.use_gpu:True|True
Global.auto_cast:fp32
Global.epoch_num:lite_train_lite_infer=3|whole_train_whole_infer=300
Global.save_model_dir:./output/
-Train.loader.batch_size_per_card:lite_train_lite_infer=128|whole_train_whole_infer=128
+Train.loader.batch_size_per_card:lite_train_lite_infer=16|whole_train_whole_infer=128
Global.pretrained_model:null
train_model_name:latest
train_infer_img_dir:./inference/rec_inference
null:null
##
trainer:pact_train
-norm_train:deploy/slim/quantization/quant.py -c test_tipc/configs/ch_PP-OCRv2_rec/ch_PP-OCRv2_rec_distillation.yml -o
-pact_train:null
+norm_train:null
+pact_train:deploy/slim/quantization/quant.py -c test_tipc/configs/ch_PP-OCRv2_rec/ch_PP-OCRv2_rec_distillation.yml -o
fpgm_train:null
distill_train:null
null:null
@@ -27,14 +27,14 @@ null:null
===========================infer_params===========================
Global.save_inference_dir:./output/
Global.pretrained_model:
-norm_export:deploy/slim/quantization/export_model.py -c test_tipc/configs/ch_PP-OCRv2_rec/ch_PP-OCRv2_rec_distillation.yml -o
-quant_export:
-fpgm_export:
+norm_export:null
+quant_export:deploy/slim/quantization/export_model.py -c test_tipc/configs/ch_PP-OCRv2_rec/ch_PP-OCRv2_rec_distillation.yml -o
+fpgm_export:null
distill_export:null
export1:null
export2:null
inference_dir:Student
-infer_model:./inference/ch_PP-OCRv2_rec_infer/
+infer_model:./inference/ch_PP-OCRv2_rec_infer
infer_export:null
infer_quant:True
inference:tools/infer/predict_rec.py
@@ -45,7 +45,7 @@ inference:tools/infer/predict_rec.py
--use_tensorrt:False|True
--precision:fp32|fp16|int8
--rec_model_dir:
---image_dir:/inference/rec_inference
+--image_dir:./inference/rec_inference
null:null
--benchmark:True
null:null
diff --git a/test_tipc/configs/ch_ppocr_mobile_v2.0_det/train_infer_python.txt b/test_tipc/configs/ch_ppocr_mobile_v2.0_det/train_infer_python.txt
index 977312f2a49e76d92e4edc11f8f0d3ecf866999a..9a5dd76437b236389f9880fdc1726e18e2cafee4 100644
--- a/test_tipc/configs/ch_ppocr_mobile_v2.0_det/train_infer_python.txt
+++ b/test_tipc/configs/ch_ppocr_mobile_v2.0_det/train_infer_python.txt
@@ -4,7 +4,7 @@ python:python3.7
gpu_list:0|0,1
Global.use_gpu:True|True
Global.auto_cast:null
-Global.epoch_num:lite_train_lite_infer=5|whole_train_whole_infer=300
+Global.epoch_num:lite_train_lite_infer=100|whole_train_whole_infer=300
Global.save_model_dir:./output/
Train.loader.batch_size_per_card:lite_train_lite_infer=2|whole_train_whole_infer=4
Global.pretrained_model:null
diff --git a/test_tipc/configs/ch_ppocr_mobile_v2.0_det_PACT/train_infer_python.txt b/test_tipc/configs/ch_ppocr_mobile_v2.0_det_PACT/train_infer_python.txt
index 8a6c6568584250d269acfe63aef43ef66410fd99..05cde05467d75769965ee23bce2cebfc20408251 100644
--- a/test_tipc/configs/ch_ppocr_mobile_v2.0_det_PACT/train_infer_python.txt
+++ b/test_tipc/configs/ch_ppocr_mobile_v2.0_det_PACT/train_infer_python.txt
@@ -4,7 +4,7 @@ python:python3.7
gpu_list:0|0,1
Global.use_gpu:True|True
Global.auto_cast:null
-Global.epoch_num:lite_train_lite_infer=5|whole_train_whole_infer=300
+Global.epoch_num:lite_train_lite_infer=20|whole_train_whole_infer=300
Global.save_model_dir:./output/
Train.loader.batch_size_per_card:lite_train_lite_infer=2|whole_train_whole_infer=4
Global.pretrained_model:null
@@ -26,7 +26,7 @@ null:null
##
===========================infer_params===========================
Global.save_inference_dir:./output/
-Global.pretrained_model:
+Global.checkpoints:
norm_export:null
quant_export:deploy/slim/quantization/export_model.py -c configs/det/ch_ppocr_v2.0/ch_det_mv3_db_v2.0.yml -o
fpgm_export:null
@@ -48,4 +48,4 @@ inference:tools/infer/predict_det.py
--image_dir:./inference/ch_det_data_50/all-sum-510/
null:null
--benchmark:True
-null:null
\ No newline at end of file
+null:null
diff --git a/test_tipc/configs/ch_ppocr_mobile_v2.0_rec_PACT/train_infer_python.txt b/test_tipc/configs/ch_ppocr_mobile_v2.0_rec_PACT/train_infer_python.txt
index 7bbdd58ae13eca00623123cf2ca39d3b76daa72a..56b9e1896c2a1e9a7ab002884cfbc5de86997535 100644
--- a/test_tipc/configs/ch_ppocr_mobile_v2.0_rec_PACT/train_infer_python.txt
+++ b/test_tipc/configs/ch_ppocr_mobile_v2.0_rec_PACT/train_infer_python.txt
@@ -28,7 +28,7 @@ null:null
Global.save_inference_dir:./output/
Global.checkpoints:
norm_export:null
-quant_export:deploy/slim/quantization/export_model.py -ctest_tipc/configs/ch_ppocr_mobile_v2.0_rec_PACT/rec_chinese_lite_train_v2.0.yml -o
+quant_export:deploy/slim/quantization/export_model.py -c test_tipc/configs/ch_ppocr_mobile_v2.0_rec_PACT/rec_chinese_lite_train_v2.0.yml -o
fpgm_export:null
distill_export:null
export1:null
diff --git a/test_tipc/configs/ch_ppocr_server_v2.0_det/train_infer_python.txt b/test_tipc/configs/ch_ppocr_server_v2.0_det/train_infer_python.txt
index bea918a7f366548056d7d62a5785353a4e689d01..ca52eeb1bc6a1853fa7015478fb9028d8dec71c3 100644
--- a/test_tipc/configs/ch_ppocr_server_v2.0_det/train_infer_python.txt
+++ b/test_tipc/configs/ch_ppocr_server_v2.0_det/train_infer_python.txt
@@ -12,22 +12,22 @@ train_model_name:latest
train_infer_img_dir:./train_data/icdar2015/text_localization/ch4_test_images/
null:null
##
-trainer:norm_train|pact_train|fpgm_export
-norm_train:tools/train.py -c test_tipc/configs/ppocr_det_server/det_r50_vd_db.yml -o
-quant_export:deploy/slim/quantization/export_model.py -c test_tipc/configs/ppocr_det_server/det_r50_vd_db.yml -o
-fpgm_export:deploy/slim/prune/export_prune_model.py -c test_tipc/configs/ppocr_det_server/det_r50_vd_db.yml -o
+trainer:norm_train
+norm_train:tools/train.py -c test_tipc/configs/ch_ppocr_server_v2.0_det/det_r50_vd_db.yml -o
+quant_train:null
+fpgm_train:null
distill_train:null
null:null
null:null
##
===========================eval_params===========================
-eval:tools/eval.py -c test_tipc/configs/ppocr_det_server/det_r50_vd_db.yml -o
+eval:tools/eval.py -c test_tipc/configs/ch_ppocr_server_v2.0_det/det_r50_vd_db.yml -o
null:null
##
===========================infer_params===========================
Global.save_inference_dir:./output/
Global.pretrained_model:
-norm_export:tools/export_model.py -c test_tipc/configs/ppocr_det_server/det_r50_vd_db.yml -o
+norm_export:tools/export_model.py -c test_tipc/configs/ch_ppocr_server_v2.0_det/det_r50_vd_db.yml -o
quant_export:null
fpgm_export:null
distill_export:null
diff --git a/test_tipc/configs/det_r50_vd_pse_v2.0/train_infer_python.txt b/test_tipc/configs/det_r50_vd_pse_v2.0/train_infer_python.txt
index 5ab6d45d7c1eb5e3c17fd53a8c8c504812c1012c..c60f4263ebc734acf3136a6542bb9e882658af2b 100644
--- a/test_tipc/configs/det_r50_vd_pse_v2.0/train_infer_python.txt
+++ b/test_tipc/configs/det_r50_vd_pse_v2.0/train_infer_python.txt
@@ -35,7 +35,7 @@ export1:null
export2:null
##
train_model:./inference/det_r50_vd_pse/best_accuracy
-infer_export:tools/export_model.py -c test_tipc/cconfigs/det_r50_vd_pse_v2.0/det_r50_vd_pse.yml -o
+infer_export:tools/export_model.py -c test_tipc/configs/det_r50_vd_pse_v2.0/det_r50_vd_pse.yml -o
infer_quant:False
inference:tools/infer/predict_det.py
--use_gpu:True|False
diff --git a/test_tipc/configs/det_r50_vd_sast_icdar15_v2.0/det_r50_vd_sast_icdar2015.yml b/test_tipc/configs/det_r50_vd_sast_icdar15_v2.0/det_r50_vd_sast_icdar2015.yml
index 8e9315d2488ad187eb12708d094c5be57cb48eac..4b7340ac59851aa54effa49f73196ad863d02a95 100644
--- a/test_tipc/configs/det_r50_vd_sast_icdar15_v2.0/det_r50_vd_sast_icdar2015.yml
+++ b/test_tipc/configs/det_r50_vd_sast_icdar15_v2.0/det_r50_vd_sast_icdar2015.yml
@@ -62,7 +62,7 @@ Train:
data_dir: ./train_data/icdar2015/text_localization/
label_file_list:
- ./train_data/icdar2015/text_localization/train_icdar2015_label.txt
- ratio_list: [0.1, 0.45, 0.3, 0.15]
+ ratio_list: [1.0]
transforms:
- DecodeImage: # load image
img_mode: BGR
diff --git a/test_tipc/configs/det_r50_vd_sast_icdar15_v2.0/train_infer_python.txt b/test_tipc/configs/det_r50_vd_sast_icdar15_v2.0/train_infer_python.txt
index d9f15dded4b920cb93b2180aeb9e14e93ebab5cc..e6fb2ca5b459d26cd4b099c17f81bb47cc59bc71 100644
--- a/test_tipc/configs/det_r50_vd_sast_icdar15_v2.0/train_infer_python.txt
+++ b/test_tipc/configs/det_r50_vd_sast_icdar15_v2.0/train_infer_python.txt
@@ -48,4 +48,4 @@ inference:tools/infer/predict_det.py
--image_dir:./inference/ch_det_data_50/all-sum-510/
null:null
--benchmark:True
-null:null
+--det_algorithm:SAST
diff --git a/test_tipc/configs/det_r50_vd_sast_totaltext_v2.0/train_infer_python.txt b/test_tipc/configs/det_r50_vd_sast_totaltext_v2.0/train_infer_python.txt
index 602254f2f3b7eb6f5b1fc72fbaf212fbea43ca49..2387ba7b5e9bac09b4c85fa5273d0c6ba5bebcb5 100644
--- a/test_tipc/configs/det_r50_vd_sast_totaltext_v2.0/train_infer_python.txt
+++ b/test_tipc/configs/det_r50_vd_sast_totaltext_v2.0/train_infer_python.txt
@@ -48,4 +48,4 @@ inference:tools/infer/predict_det.py
--image_dir:./inference/ch_det_data_50/all-sum-510/
null:null
--benchmark:True
-null:null
+--det_algorithm:SAST
diff --git a/test_tipc/configs/en_server_pgnetA/train_infer_python.txt b/test_tipc/configs/en_server_pgnetA/train_infer_python.txt
index bacd9c7a0b0d7e0dd85ed6cf249025354da71c71..1a25eccb3a192823d58af1c6cf089ea15b6d394c 100644
--- a/test_tipc/configs/en_server_pgnetA/train_infer_python.txt
+++ b/test_tipc/configs/en_server_pgnetA/train_infer_python.txt
@@ -42,7 +42,7 @@ inference:tools/infer/predict_e2e.py
--enable_mkldnn:True|False
--cpu_threads:1|6
--rec_batch_num:1
---use_tensorrt:False|True
+--use_tensorrt:False
--precision:fp32|fp16|int8
--e2e_model_dir:
--image_dir:./inference/ch_det_data_50/all-sum-510/
diff --git a/test_tipc/configs/rec_mtb_nrtr/train_infer_python.txt b/test_tipc/configs/rec_mtb_nrtr/train_infer_python.txt
index 67630d858c7633daf8e1800b1ab10adb86e6c3bc..695fc8a42ef0f6b79901e8b62ce09d72e3500793 100644
--- a/test_tipc/configs/rec_mtb_nrtr/train_infer_python.txt
+++ b/test_tipc/configs/rec_mtb_nrtr/train_infer_python.txt
@@ -6,7 +6,7 @@ Global.use_gpu:True|True
Global.auto_cast:null
Global.epoch_num:lite_train_lite_infer=2|whole_train_whole_infer=300
Global.save_model_dir:./output/
-Train.loader.batch_size_per_card:lite_train_lite_infer=128|whole_train_whole_infer=128
+Train.loader.batch_size_per_card:lite_train_lite_infer=16|whole_train_whole_infer=64
Global.pretrained_model:null
train_model_name:latest
train_infer_img_dir:./inference/rec_inference
diff --git a/test_tipc/configs/rec_mv3_tps_bilstm_att_v2.0/train_infer_python.txt b/test_tipc/configs/rec_mv3_tps_bilstm_att_v2.0/train_infer_python.txt
index 3791aa17b2b5a16565ab3456932e43fd77254472..18504d068740deeec42cf9620c2d9e816d88c5cc 100644
--- a/test_tipc/configs/rec_mv3_tps_bilstm_att_v2.0/train_infer_python.txt
+++ b/test_tipc/configs/rec_mv3_tps_bilstm_att_v2.0/train_infer_python.txt
@@ -6,7 +6,7 @@ Global.use_gpu:True|True
Global.auto_cast:null
Global.epoch_num:lite_train_lite_infer=2|whole_train_whole_infer=300
Global.save_model_dir:./output/
-Train.loader.batch_size_per_card:lite_train_lite_infer=128|whole_train_whole_infer=128
+Train.loader.batch_size_per_card:lite_train_lite_infer=16|whole_train_whole_infer=64
Global.pretrained_model:null
train_model_name:latest
train_infer_img_dir:./inference/rec_inference
diff --git a/test_tipc/configs/rec_mv3_tps_bilstm_ctc_v2.0/train_infer_python.txt b/test_tipc/configs/rec_mv3_tps_bilstm_ctc_v2.0/train_infer_python.txt
index 33700ad696394ad9404a5424cddf93608220917a..3bec644ced183fff4329ff08991a137c45bacfc9 100644
--- a/test_tipc/configs/rec_mv3_tps_bilstm_ctc_v2.0/train_infer_python.txt
+++ b/test_tipc/configs/rec_mv3_tps_bilstm_ctc_v2.0/train_infer_python.txt
@@ -37,7 +37,7 @@ export2:null
infer_model:null
infer_export:tools/export_model.py -c test_tipc/configs/rec_mv3_tps_bilstm_ctc_v2.0/rec_icdar15_train.yml -o
infer_quant:False
-inference:tools/infer/predict_rec.py --rec_char_dict_path=./ppocr/utils/ic15_dict.txt --rec_image_shape="3,32,100"
+inference:tools/infer/predict_rec.py --rec_char_dict_path=./ppocr/utils/ic15_dict.txt --rec_image_shape="3,32,100" --rec_algorithm="StarNet"
--use_gpu:True|False
--enable_mkldnn:True|False
--cpu_threads:1|6
diff --git a/test_tipc/configs/rec_r31_sar/train_infer_python.txt b/test_tipc/configs/rec_r31_sar/train_infer_python.txt
index 5cc31b7b8b793e7c82f6676f1fec9a5e8b2393f4..42dfc6b0275c05aef358682d031275488893e5fb 100644
--- a/test_tipc/configs/rec_r31_sar/train_infer_python.txt
+++ b/test_tipc/configs/rec_r31_sar/train_infer_python.txt
@@ -6,7 +6,7 @@ Global.use_gpu:True|True
Global.auto_cast:null
Global.epoch_num:lite_train_lite_infer=2|whole_train_whole_infer=300
Global.save_model_dir:./output/
-Train.loader.batch_size_per_card:lite_train_lite_infer=128|whole_train_whole_infer=128
+Train.loader.batch_size_per_card:lite_train_lite_infer=16|whole_train_whole_infer=64
Global.pretrained_model:null
train_model_name:latest
train_infer_img_dir:./inference/rec_inference
diff --git a/test_tipc/configs/rec_r34_vd_tps_bilstm_att_v2.0/train_infer_python.txt b/test_tipc/configs/rec_r34_vd_tps_bilstm_att_v2.0/train_infer_python.txt
index e816868f33de7ca8794068e8498f6f7845df0324..84bda52480118f84ec5efbc1d4831950b1cdee68 100644
--- a/test_tipc/configs/rec_r34_vd_tps_bilstm_att_v2.0/train_infer_python.txt
+++ b/test_tipc/configs/rec_r34_vd_tps_bilstm_att_v2.0/train_infer_python.txt
@@ -6,7 +6,7 @@ Global.use_gpu:True|True
Global.auto_cast:null
Global.epoch_num:lite_train_lite_infer=2|whole_train_whole_infer=300
Global.save_model_dir:./output/
-Train.loader.batch_size_per_card:lite_train_lite_infer=128|whole_train_whole_infer=128
+Train.loader.batch_size_per_card:lite_train_lite_infer=16|whole_train_whole_infer=64
Global.pretrained_model:null
train_model_name:latest
train_infer_img_dir:./inference/rec_inference
diff --git a/test_tipc/configs/rec_r34_vd_tps_bilstm_ctc_v2.0/train_infer_python.txt b/test_tipc/configs/rec_r34_vd_tps_bilstm_ctc_v2.0/train_infer_python.txt
index bb49ae5977208b2921f4a825b62afa7935f572f1..ac43bd9703d7744220af40fa36b29adf64e89334 100644
--- a/test_tipc/configs/rec_r34_vd_tps_bilstm_ctc_v2.0/train_infer_python.txt
+++ b/test_tipc/configs/rec_r34_vd_tps_bilstm_ctc_v2.0/train_infer_python.txt
@@ -37,7 +37,7 @@ export2:null
infer_model:null
infer_export:tools/export_model.py -c test_tipc/configs/rec_r34_vd_tps_bilstm_ctc_v2.0/rec_icdar15_train.yml -o
infer_quant:False
-inference:tools/infer/predict_rec.py --rec_char_dict_path=./ppocr/utils/ic15_dict.txt --rec_image_shape="3,32,100"
+inference:tools/infer/predict_rec.py --rec_char_dict_path=./ppocr/utils/ic15_dict.txt --rec_image_shape="3,32,100" --rec_algorithm="StarNet"
--use_gpu:True|False
--enable_mkldnn:True|False
--cpu_threads:1|6
diff --git a/test_tipc/configs/rec_r50_fpn_vd_none_srn/train_infer_python.txt b/test_tipc/configs/rec_r50_fpn_vd_none_srn/train_infer_python.txt
index b3549c635f267cdb0b494341e9f250669cd74bfe..55b25122e3d934ae66051595cc0bdc75aa3386fc 100644
--- a/test_tipc/configs/rec_r50_fpn_vd_none_srn/train_infer_python.txt
+++ b/test_tipc/configs/rec_r50_fpn_vd_none_srn/train_infer_python.txt
@@ -6,7 +6,7 @@ Global.use_gpu:True|True
Global.auto_cast:null
Global.epoch_num:lite_train_lite_infer=2|whole_train_whole_infer=300
Global.save_model_dir:./output/
-Train.loader.batch_size_per_card:lite_train_lite_infer=128|whole_train_whole_infer=128
+Train.loader.batch_size_per_card:lite_train_lite_infer=16|whole_train_whole_infer=64
Global.pretrained_model:null
train_model_name:latest
train_infer_img_dir:./inference/rec_inference
diff --git a/test_tipc/prepare.sh b/test_tipc/prepare.sh
index 8876157ef8f4b44b227c171d25bdfd1060007910..71d4010f4b2c3abe698e22b7e1e8f33e9ef9d45f 100644
--- a/test_tipc/prepare.sh
+++ b/test_tipc/prepare.sh
@@ -25,7 +25,7 @@ if [ ${MODE} = "lite_train_lite_infer" ];then
# pretrain lite train data
wget -nc -P ./pretrain_models/ https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/MobileNetV3_large_x0_5_pretrained.pdparams --no-check-certificate
wget -nc -P ./pretrain_models/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/det_mv3_db_v2.0_train.tar --no-check-certificate
- if [ ${model_name} == "ch_PPOCRv2_det" ]; then
+ if [[ ${model_name} =~ "PPOCRv2_det" ]];then
wget -nc -P ./pretrain_models/ https://paddleocr.bj.bcebos.com/PP-OCRv2/chinese/ch_PP-OCRv2_det_distill_train.tar --no-check-certificate
cd ./pretrain_models/ && tar xf ch_PP-OCRv2_det_distill_train.tar && cd ../
fi
@@ -49,8 +49,8 @@ if [ ${MODE} = "lite_train_lite_infer" ];then
fi
if [ ${model_name} == "det_r50_vd_sast_icdar15_v2.0" ] || [ ${model_name} == "det_r50_vd_sast_totaltext_v2.0" ]; then
wget -nc -P ./pretrain_models/ https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/ResNet50_vd_ssld_pretrained.pdparams --no-check-certificate
- wget -nc -P ./train_data/ wget -nc -P ./train_data/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/test/total_text_lite.tar --no-check-certificate
- cd ./train_data && tar xf total_text_lite.tar && ln -s total_text && cd ../
+ wget -nc -P ./train_data/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/test/total_text_lite.tar --no-check-certificate
+ cd ./train_data && tar xf total_text_lite.tar && ln -s total_text_lite total_text && cd ../
fi
if [ ${model_name} == "det_mv3_db_v2.0" ]; then
wget -nc -P ./inference/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/det_mv3_db_v2.0_train.tar --no-check-certificate
@@ -78,15 +78,15 @@ elif [ ${MODE} = "whole_train_whole_infer" ];then
cd ./pretrain_models/ && tar xf ch_PP-OCRv2_det_distill_train.tar && cd ../
fi
if [ ${model_name} == "en_server_pgnetA" ]; then
- wget -nc -P ./train_data/ https://paddleocr.bj.bcebos.com/dataset/total_text.tar --no-check-certificate
+ wget -nc -P ./train_data/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/test/total_text_lite.tar --no-check-certificate
wget -nc -P ./pretrain_models/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/pgnet/en_server_pgnetA.tar --no-check-certificate
cd ./pretrain_models/ && tar xf en_server_pgnetA.tar && cd ../
- cd ./train_data && tar xf total_text.tar && ln -s total_text && cd ../
+        cd ./train_data && tar xf total_text_lite.tar && ln -s total_text_lite total_text && cd ../
fi
if [ ${model_name} == "det_r50_vd_sast_totaltext_v2.0" ]; then
wget -nc -P ./pretrain_models/ https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/ResNet50_vd_ssld_pretrained.pdparams --no-check-certificate
- wget -nc -P ./train_data/ https://paddleocr.bj.bcebos.com/dataset/total_text.tar --no-check-certificate
- cd ./train_data && tar xf total_text.tar && ln -s total_text && cd ../
+ wget -nc -P ./train_data/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/test/total_text_lite.tar --no-check-certificate
+        cd ./train_data && tar xf total_text_lite.tar && ln -s total_text_lite total_text && cd ../
fi
elif [ ${MODE} = "lite_train_whole_infer" ];then
wget -nc -P ./pretrain_models/ https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/MobileNetV3_large_x0_5_pretrained.pdparams --no-check-certificate
@@ -103,59 +103,67 @@ elif [ ${MODE} = "lite_train_whole_infer" ];then
fi
elif [ ${MODE} = "whole_infer" ];then
wget -nc -P ./inference https://paddleocr.bj.bcebos.com/dygraph_v2.0/test/ch_det_data_50.tar --no-check-certificate
+ wget -nc -P ./inference/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/test/rec_inference.tar --no-check-certificate
+ cd ./inference && tar xf rec_inference.tar && cd ../
if [ ${model_name} = "ch_ppocr_mobile_v2.0_det" ]; then
eval_model_name="ch_ppocr_mobile_v2.0_det_train"
rm -rf ./train_data/icdar2015
- wget -nc -P ./inference https://paddleocr.bj.bcebos.com/dygraph_v2.0/test/ch_det_data_50.tar --no-check-certificate
wget -nc -P ./inference https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_det_train.tar --no-check-certificate
wget -nc -P ./inference https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_det_infer.tar --no-check-certificate
cd ./inference && tar xf ${eval_model_name}.tar && tar xf ch_det_data_50.tar && tar xf ch_ppocr_mobile_v2.0_det_infer.tar && cd ../
elif [ ${model_name} = "ch_ppocr_server_v2.0_det" ]; then
wget -nc -P ./inference https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_server_v2.0_det_train.tar --no-check-certificate
- wget -nc -P ./inference https://paddleocr.bj.bcebos.com/dygraph_v2.0/test/ch_det_data_50.tar --no-check-certificate
cd ./inference && tar xf ch_ppocr_server_v2.0_det_train.tar && tar xf ch_det_data_50.tar && cd ../
elif [ ${model_name} = "ch_ppocr_mobile_v2.0" ]; then
wget -nc -P ./inference https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_det_infer.tar --no-check-certificate
- wget -nc -P ./inference https://paddleocr.bj.bcebos.com/dygraph_v2.0/test/ch_det_data_50.tar --no-check-certificate
wget -nc -P ./inference https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_rec_infer.tar --no-check-certificate
cd ./inference && tar xf ch_ppocr_mobile_v2.0_det_infer.tar && tar xf ch_ppocr_mobile_v2.0_rec_infer.tar && tar xf ch_det_data_50.tar && cd ../
elif [ ${model_name} = "ch_ppocr_server_v2.0" ]; then
wget -nc -P ./inference https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_server_v2.0_det_infer.tar --no-check-certificate
- wget -nc -P ./inference https://paddleocr.bj.bcebos.com/dygraph_v2.0/test/ch_det_data_50.tar --no-check-certificate
wget -nc -P ./inference https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_server_v2.0_rec_infer.tar --no-check-certificate
cd ./inference && tar xf ch_ppocr_server_v2.0_det_infer.tar && tar xf ch_ppocr_server_v2.0_rec_infer.tar && tar xf ch_det_data_50.tar && cd ../
elif [ ${model_name} = "ch_ppocr_mobile_v2.0_rec" ]; then
eval_model_name="ch_ppocr_mobile_v2.0_rec_infer"
- wget -nc -P ./inference/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/test/rec_inference.tar --no-check-certificate
wget -nc -P ./inference https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_rec_infer.tar --no-check-certificate
- cd ./inference && tar xf ${eval_model_name}.tar && tar xf rec_inference.tar && cd ../
+ cd ./inference && tar xf ${eval_model_name}.tar && cd ../
elif [ ${model_name} = "ch_ppocr_server_v2.0_rec" ]; then
eval_model_name="ch_ppocr_server_v2.0_rec_infer"
- wget -nc -P ./inference/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/test/rec_inference.tar --no-check-certificate
wget -nc -P ./inference https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_server_v2.0_rec_infer.tar --no-check-certificate
- cd ./inference && tar xf ${eval_model_name}.tar && tar xf rec_inference.tar && cd ../
+ cd ./inference && tar xf ${eval_model_name}.tar && cd ../
+ elif [ ${model_name} = "ch_ppocr_mobile_v2.0_rec_PACT" ]; then
+ eval_model_name="ch_PP-OCRv2_rec_slim_quant_train"
+ wget -nc -P ./inference https://paddleocr.bj.bcebos.com/PP-OCRv2/chinese/ch_PP-OCRv2_rec_slim_quant_train.tar --no-check-certificate
+ cd ./inference && tar xf ${eval_model_name}.tar && cd ../
+ elif [ ${model_name} = "ch_ppocr_mobile_v2.0_rec_FPGM" ]; then
+ eval_model_name="ch_PP-OCRv2_rec_train"
+ wget -nc -P ./inference https://paddleocr.bj.bcebos.com/PP-OCRv2/chinese/ch_PP-OCRv2_rec_train.tar --no-check-certificate
+ cd ./inference && tar xf ${eval_model_name}.tar && cd ../
fi
- if [ ${model_name} = "ch_PPOCRv2_det" ]; then
+ if [[ ${model_name} =~ "ch_PPOCRv2_det" ]]; then
eval_model_name="ch_PP-OCRv2_det_infer"
- wget -nc -P ./inference https://paddleocr.bj.bcebos.com/dygraph_v2.0/test/ch_det_data_50.tar --no-check-certificate
wget -nc -P ./inference/ https://paddleocr.bj.bcebos.com/PP-OCRv2/chinese/ch_PP-OCRv2_det_infer.tar --no-check-certificate
cd ./inference && tar xf ${eval_model_name}.tar && tar xf ch_det_data_50.tar && cd ../
fi
+ if [[ ${model_name} =~ "PPOCRv2_ocr_rec" ]]; then
+ eval_model_name="ch_PP-OCRv2_rec_infer"
+ wget -nc -P ./inference/ https://paddleocr.bj.bcebos.com/PP-OCRv2/chinese/ch_PP-OCRv2_rec_infer.tar --no-check-certificate
+ cd ./inference && tar xf ${eval_model_name}.tar && cd ../
+ fi
if [ ${model_name} == "en_server_pgnetA" ]; then
wget -nc -P ./inference/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/pgnet/en_server_pgnetA.tar --no-check-certificate
- cd ./inference && tar xf en_server_pgnetA.tar && cd ../
+ cd ./inference && tar xf en_server_pgnetA.tar && tar xf ch_det_data_50.tar && cd ../
fi
if [ ${model_name} == "det_r50_vd_sast_icdar15_v2.0" ]; then
wget -nc -P ./inference/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/det_r50_vd_sast_icdar15_v2.0_train.tar --no-check-certificate
- cd ./inference/ && tar xf det_r50_vd_sast_icdar15_v2.0_train.tar && cd ../
+ cd ./inference/ && tar xf det_r50_vd_sast_icdar15_v2.0_train.tar && tar xf ch_det_data_50.tar && cd ../
fi
if [ ${model_name} == "det_mv3_db_v2.0" ]; then
wget -nc -P ./inference/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/det_mv3_db_v2.0_train.tar --no-check-certificate
- cd ./inference/ && tar xf det_mv3_db_v2.0_train.tar && cd ../
+ cd ./inference/ && tar xf det_mv3_db_v2.0_train.tar && tar xf ch_det_data_50.tar && cd ../
fi
if [ ${model_name} == "det_r50_db_v2.0" ]; then
wget -nc -P ./inference/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/det_r50_vd_db_v2.0_train.tar --no-check-certificate
- cd ./inference/ && tar xf det_r50_vd_db_v2.0_train.tar && cd ../
+ cd ./inference/ && tar xf det_r50_vd_db_v2.0_train.tar && tar xf ch_det_data_50.tar && cd ../
fi
fi
if [ ${MODE} = "klquant_whole_infer" ]; then
diff --git a/test_tipc/test_inference_cpp.sh b/test_tipc/test_inference_cpp.sh
index d26954353ef1e81ae49364b7f9d20357768cff85..4787f83093b0040ae3da6d9efb9028d0cc28de00 100644
--- a/test_tipc/test_inference_cpp.sh
+++ b/test_tipc/test_inference_cpp.sh
@@ -64,10 +64,11 @@ function func_cpp_inference(){
set_infer_data=$(func_set_params "${cpp_image_dir_key}" "${_img_dir}")
set_benchmark=$(func_set_params "${cpp_benchmark_key}" "${cpp_benchmark_value}")
set_batchsize=$(func_set_params "${cpp_batch_size_key}" "${batch_size}")
+ set_mkldnn=$(func_set_params "${cpp_use_mkldnn_key}" "${use_mkldnn}")
set_cpu_threads=$(func_set_params "${cpp_cpu_threads_key}" "${threads}")
set_model_dir=$(func_set_params "${cpp_infer_model_key}" "${_model_dir}")
set_infer_params1=$(func_set_params "${cpp_infer_key1}" "${cpp_infer_value1}")
- command="${_script} ${cpp_use_gpu_key}=${use_gpu} ${cpp_use_mkldnn_key}=${use_mkldnn} ${set_cpu_threads} ${set_model_dir} ${set_batchsize} ${set_infer_data} ${set_benchmark} ${set_infer_params1} > ${_save_log_path} 2>&1 "
+ command="${_script} ${cpp_use_gpu_key}=${use_gpu} ${set_mkldnn} ${set_cpu_threads} ${set_model_dir} ${set_batchsize} ${set_infer_data} ${set_benchmark} ${set_infer_params1} > ${_save_log_path} 2>&1 "
eval $command
last_status=${PIPESTATUS[0]}
eval "cat ${_save_log_path}"
diff --git a/test_tipc/test_inference_python.sh b/test_tipc/test_inference_python.sh
index 72516e044ed8a23c660a4c4f486d19f22a584fb0..27276d55b95051e167432600308f42127d784ee6 100644
--- a/test_tipc/test_inference_python.sh
+++ b/test_tipc/test_inference_python.sh
@@ -79,11 +79,12 @@ function func_inference(){
set_infer_data=$(func_set_params "${image_dir_key}" "${_img_dir}")
set_benchmark=$(func_set_params "${benchmark_key}" "${benchmark_value}")
set_batchsize=$(func_set_params "${batch_size_key}" "${batch_size}")
+ set_mkldnn=$(func_set_params "${use_mkldnn_key}" "${use_mkldnn}")
set_cpu_threads=$(func_set_params "${cpu_threads_key}" "${threads}")
set_model_dir=$(func_set_params "${infer_model_key}" "${_model_dir}")
set_infer_params0=$(func_set_params "${rec_model_key}" "${rec_model_value}")
set_infer_params1=$(func_set_params "${infer_key1}" "${infer_value1}")
- command="${_python} ${_script} ${use_gpu_key}=${use_gpu} ${use_mkldnn_key}=${use_mkldnn} ${set_cpu_threads} ${set_model_dir} ${set_batchsize} ${set_infer_params0} ${set_infer_data} ${set_benchmark} ${set_precision} ${set_infer_params1} > ${_save_log_path} 2>&1 "
+ command="${_python} ${_script} ${use_gpu_key}=${use_gpu} ${set_mkldnn} ${set_cpu_threads} ${set_model_dir} ${set_batchsize} ${set_infer_params0} ${set_infer_data} ${set_benchmark} ${set_precision} ${set_infer_params1} > ${_save_log_path} 2>&1 "
eval $command
last_status=${PIPESTATUS[0]}
eval "cat ${_save_log_path}"
diff --git a/test_tipc/test_train_inference_python.sh b/test_tipc/test_train_inference_python.sh
index 0b0a4e4a75f5e978f64404b27a5f26594dbd484e..b69c0f278f2886eeb7c01847bab5d54ff7a18af6 100644
--- a/test_tipc/test_train_inference_python.sh
+++ b/test_tipc/test_train_inference_python.sh
@@ -160,11 +160,12 @@ function func_inference(){
set_infer_data=$(func_set_params "${image_dir_key}" "${_img_dir}")
set_benchmark=$(func_set_params "${benchmark_key}" "${benchmark_value}")
set_batchsize=$(func_set_params "${batch_size_key}" "${batch_size}")
+ set_mkldnn=$(func_set_params "${use_mkldnn_key}" "${use_mkldnn}")
set_cpu_threads=$(func_set_params "${cpu_threads_key}" "${threads}")
set_model_dir=$(func_set_params "${infer_model_key}" "${_model_dir}")
set_infer_params0=$(func_set_params "${save_log_key}" "${save_log_value}")
set_infer_params1=$(func_set_params "${infer_key1}" "${infer_value1}")
- command="${_python} ${_script} ${use_gpu_key}=${use_gpu} ${use_mkldnn_key}=${use_mkldnn} ${set_cpu_threads} ${set_model_dir} ${set_batchsize} ${set_infer_params0} ${set_infer_data} ${set_benchmark} ${set_precision} ${set_infer_params1} > ${_save_log_path} 2>&1 "
+ command="${_python} ${_script} ${use_gpu_key}=${use_gpu} ${set_mkldnn} ${set_cpu_threads} ${set_model_dir} ${set_batchsize} ${set_infer_params0} ${set_infer_data} ${set_benchmark} ${set_precision} ${set_infer_params1} > ${_save_log_path} 2>&1 "
eval $command
last_status=${PIPESTATUS[0]}
eval "cat ${_save_log_path}"
@@ -321,10 +322,6 @@ else
save_log="${LOG_PATH}/${trainer}_gpus_${gpu}_autocast_${autocast}_nodes_${nodes}"
fi
- # load pretrain from norm training if current trainer is pact or fpgm trainer
- if ([ ${trainer} = ${pact_key} ] || [ ${trainer} = ${fpgm_key} ]) && [ ${nodes} -le 1 ]; then
- set_pretrain="${load_norm_train_model}"
- fi
set_save_model=$(func_set_params "${save_model_key}" "${save_log}")
if [ ${#gpu} -le 2 ];then # train with cpu or single gpu
@@ -340,10 +337,7 @@ else
status_check $? "${cmd}" "${status_log}"
set_eval_pretrain=$(func_set_params "${pretrain_model_key}" "${save_log}/${train_model_name}")
- # save norm trained models to set pretrain for pact training and fpgm training
- if [ ${trainer} = ${trainer_norm} ] && [ ${nodes} -le 1 ]; then
- load_norm_train_model=${set_eval_pretrain}
- fi
+
# run eval
if [ ${eval_py} != "null" ]; then
set_eval_params1=$(func_set_params "${eval_key1}" "${eval_value1}")
diff --git a/tools/infer/utility.py b/tools/infer/utility.py
index f437056ec7b10e28e626d2028b6401cebc647bb1..21bbee098ef19456d05165969a9ad400400f1264 100644
--- a/tools/infer/utility.py
+++ b/tools/infer/utility.py
@@ -195,6 +195,7 @@ def create_predictor(args, mode, logger):
max_batch_size=args.max_batch_size,
min_subgraph_size=args.min_subgraph_size)
# skip the minimum trt subgraph
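+ # default to registering TRT dynamic shape info; disabled below for unsupported modes/algorithms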
+ use_dynamic_shape = True
if mode == "det":
min_input_shape = {
"x": [1, 3, 50, 50],
@@ -260,6 +261,8 @@ def create_predictor(args, mode, logger):
max_input_shape.update(max_pact_shape)
opt_input_shape.update(opt_pact_shape)
elif mode == "rec":
+ if args.rec_algorithm != "CRNN":
+ use_dynamic_shape = False
min_input_shape = {"x": [1, 3, 32, 10]}
max_input_shape = {"x": [args.rec_batch_num, 3, 32, 1536]}
opt_input_shape = {"x": [args.rec_batch_num, 3, 32, 320]}
@@ -268,11 +271,10 @@ def create_predictor(args, mode, logger):
max_input_shape = {"x": [args.rec_batch_num, 3, 48, 1024]}
opt_input_shape = {"x": [args.rec_batch_num, 3, 48, 320]}
else:
- min_input_shape = {"x": [1, 3, 10, 10]}
- max_input_shape = {"x": [1, 3, 512, 512]}
- opt_input_shape = {"x": [1, 3, 256, 256]}
- config.set_trt_dynamic_shape_info(min_input_shape, max_input_shape,
- opt_input_shape)
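+ # remaining modes keep static shapes, so no TRT dynamic shape info is registered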
+ use_dynamic_shape = False
+ if use_dynamic_shape:
+ config.set_trt_dynamic_shape_info(
+ min_input_shape, max_input_shape, opt_input_shape)
else:
config.disable_gpu()
@@ -311,7 +313,10 @@ def create_predictor(args, mode, logger):
def get_infer_gpuid():
- cmd = "env | grep CUDA_VISIBLE_DEVICES"
+ if not paddle.fluid.core.is_compiled_with_rocm():
+ cmd = "env | grep CUDA_VISIBLE_DEVICES"
+ else:
+ cmd = "env | grep HIP_VISIBLE_DEVICES"
env_cuda = os.popen(cmd).readlines()
if len(env_cuda) == 0:
return 0