diff --git a/PPOCRLabel/libs/canvas.py b/PPOCRLabel/libs/canvas.py
index d5662ac79a85c07c79ed2b7df315f338a229535c..6ac1f28b85e65c3776d310136352b70c45628db6 100644
--- a/PPOCRLabel/libs/canvas.py
+++ b/PPOCRLabel/libs/canvas.py
@@ -704,8 +704,9 @@ class Canvas(QWidget):
def keyPressEvent(self, ev):
key = ev.key()
- shapesBackup = []
shapesBackup = copy.deepcopy(self.shapes)
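+        # nothing to snapshot or undo when the canvas has no shapes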
+ if len(shapesBackup) == 0:
+ return
self.shapesBackups.pop()
self.shapesBackups.append(shapesBackup)
if key == Qt.Key_Escape and self.current:
diff --git a/configs/det/ch_PP-OCRv2/ch_PP-OCRv2_det_cml.yml b/configs/det/ch_PP-OCRv2/ch_PP-OCRv2_det_cml.yml
index ab484a44833a405513d7f2b4079a4da4c2e403c8..bb6a196864b6e9e7525f2b5217f0c90ea2ca05a4 100644
--- a/configs/det/ch_PP-OCRv2/ch_PP-OCRv2_det_cml.yml
+++ b/configs/det/ch_PP-OCRv2/ch_PP-OCRv2_det_cml.yml
@@ -18,6 +18,7 @@ Global:
Architecture:
name: DistillationModel
algorithm: Distillation
+ model_type: det
Models:
Teacher:
freeze_params: true
diff --git a/deploy/slim/quantization/export_model.py b/deploy/slim/quantization/export_model.py
index dddae923de223178665e3bfb55a2e7a8c0d5ba17..0cb86108d2275dc6ee1a74e118c27b94131975d3 100755
--- a/deploy/slim/quantization/export_model.py
+++ b/deploy/slim/quantization/export_model.py
@@ -111,7 +111,7 @@ def main():
valid_dataloader = build_dataloader(config, 'Eval', device, logger)
use_srn = config['Architecture']['algorithm'] == "SRN"
- model_type = config['Architecture']['model_type']
+ model_type = config['Architecture'].get('model_type', None)
# start eval
metric = program.eval(model, valid_dataloader, post_process_class,
eval_class, model_type, use_srn)
@@ -120,8 +120,7 @@ def main():
for k, v in metric.items():
logger.info('{}:{}'.format(k, v))
- infer_shape = [3, 32, 100] if config['Architecture'][
- 'model_type'] != "det" else [3, 640, 640]
+ infer_shape = [3, 32, 100] if model_type == "rec" else [3, 640, 640]
save_path = config["Global"]["save_inference_dir"]
diff --git a/doc/datasets/ch_doc2.jpg b/doc/datasets/ch_doc2.jpg
deleted file mode 100644
index 23343b8dedbae7be025552e3a45f9b7af7cf49ee..0000000000000000000000000000000000000000
Binary files a/doc/datasets/ch_doc2.jpg and /dev/null differ
diff --git a/doc/doc_ch/datasets.md b/doc/doc_ch/datasets.md
index 6d84dbbe484be1e2b19a4dedced90f61b7085148..d365fd711aff2dffcd30dd06028734cc707d5df0 100644
--- a/doc/doc_ch/datasets.md
+++ b/doc/doc_ch/datasets.md
@@ -49,7 +49,6 @@ https://aistudio.baidu.com/aistudio/datasetdetail/8429
- Each sample is fixed at 10 characters, randomly clipped from sentences in the corpus
- Image resolution is uniformly 280x32
![](../datasets/ch_doc1.jpg)
- ![](../datasets/ch_doc2.jpg)
![](../datasets/ch_doc3.jpg)
- **Download link**: https://pan.baidu.com/s/1QkI7kjah8SPHwOQ40rS1Pw (Password: lu7m)
diff --git a/doc/doc_ch/distributed_training.md b/doc/doc_ch/distributed_training.md
index 411ce5ba6aea26755cc65c405be6e0f0d5fd4738..e0251b21ea1157084e4e1b1d77429264d452aa20 100644
--- a/doc/doc_ch/distributed_training.md
+++ b/doc/doc_ch/distributed_training.md
@@ -13,7 +13,7 @@
```shell
python3 -m paddle.distributed.launch \
--log_dir=./log/ \
- --gpus '0,1,2,3,4,5,6,7' \
+ --gpus "0,1,2,3,4,5,6,7" \
tools/train.py \
-c configs/rec/rec_mv3_none_bilstm_ctc.yml
```
diff --git a/doc/doc_en/datasets_en.md b/doc/doc_en/datasets_en.md
index 61d2033b4fe8f0077ad66fb9ae2cd559ce29fd65..0e6b6f381e9d008add802c5f8a30d5498a4f94b2 100644
--- a/doc/doc_en/datasets_en.md
+++ b/doc/doc_en/datasets_en.md
@@ -50,7 +50,6 @@ https://aistudio.baidu.com/aistudio/datasetdetail/8429
- Each sample is fixed at 10 characters, which are randomly clipped from sentences in the corpus
- Image resolution is 280x32
![](../datasets/ch_doc1.jpg)
- ![](../datasets/ch_doc2.jpg)
![](../datasets/ch_doc3.jpg)
- **Download link**:https://pan.baidu.com/s/1QkI7kjah8SPHwOQ40rS1Pw (Password: lu7m)
diff --git a/doc/doc_en/distributed_training.md b/doc/doc_en/distributed_training.md
index 7a8b71ce308837568c84bf56292f78e9979d3907..519a42f0dc4b9bd4fa18f3f65019e4235282df92 100644
--- a/doc/doc_en/distributed_training.md
+++ b/doc/doc_en/distributed_training.md
@@ -13,7 +13,7 @@ Take recognition as an example. After the data is prepared locally, start the tr
```shell
python3 -m paddle.distributed.launch \
--log_dir=./log/ \
- --gpus '0,1,2,3,4,5,6,7' \
+ --gpus "0,1,2,3,4,5,6,7" \
tools/train.py \
-c configs/rec/rec_mv3_none_bilstm_ctc.yml
```
diff --git a/ppocr/data/imaug/copy_paste.py b/ppocr/data/imaug/copy_paste.py
index bbf62e2a3d813671551efa1a76c03754b1b764f5..0b3386c896792bd670cd2bfc757eb3b80f22bac4 100644
--- a/ppocr/data/imaug/copy_paste.py
+++ b/ppocr/data/imaug/copy_paste.py
@@ -32,6 +32,7 @@ class CopyPaste(object):
self.aug = IaaAugment(augmenter_args)
def __call__(self, data):
+ point_num = data['polys'].shape[1]
src_img = data['image']
src_polys = data['polys'].tolist()
src_ignores = data['ignore_tags'].tolist()
@@ -57,6 +58,9 @@ class CopyPaste(object):
src_img, box = self.paste_img(src_img, box_img, src_polys)
if box is not None:
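+                # pad the pasted box by repeating its last point so that it
+                # has point_num points, matching the shape of data['polys']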
+ box = box.tolist()
+ for _ in range(len(box), point_num):
+ box.append(box[-1])
src_polys.append(box)
src_ignores.append(tag)
src_img = cv2.cvtColor(np.array(src_img), cv2.COLOR_RGB2BGR)
diff --git a/ppocr/data/simple_dataset.py b/ppocr/data/simple_dataset.py
index 6a33e1342506f26ccaa4a146f3f02fadfbd741a2..ee8571b8c452bbd834fc5dbcf01ce390562163d6 100644
--- a/ppocr/data/simple_dataset.py
+++ b/ppocr/data/simple_dataset.py
@@ -14,6 +14,7 @@
import numpy as np
import os
import random
+import traceback
from paddle.io import Dataset
from .imaug import transform, create_operators
@@ -93,7 +94,8 @@ class SimpleDataSet(Dataset):
img = f.read()
data['image'] = img
data = transform(data, load_data_ops)
- if data is None:
+
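+            # keep only samples whose polygons are 4-point boxes, so that
+            # ext_data stays consistent for augmentations such as CopyPaste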
+            if data is None or data['polys'].shape[1] != 4:
continue
ext_data.append(data)
return ext_data
@@ -115,10 +117,10 @@ class SimpleDataSet(Dataset):
data['image'] = img
data['ext_data'] = self.get_ext_data()
outs = transform(data, self.ops)
- except Exception as e:
+        except Exception:
self.logger.error(
"When parsing line {}, error happened with msg: {}".format(
- data_line, e))
+ data_line, traceback.format_exc()))
outs = None
if outs is None:
# during evaluation, we should fix the idx to get same results for many times of evaluation.
diff --git a/ppocr/modeling/backbones/det_resnet_vd.py b/ppocr/modeling/backbones/det_resnet_vd.py
index 3bb4a0d50501860d5e9df2971e93fba66c152187..a29cf1b5e1ff56e59984bc91226ef7e6b65d0da1 100644
--- a/ppocr/modeling/backbones/det_resnet_vd.py
+++ b/ppocr/modeling/backbones/det_resnet_vd.py
@@ -25,16 +25,14 @@ __all__ = ["ResNet"]
class ConvBNLayer(nn.Layer):
- def __init__(
- self,
- in_channels,
- out_channels,
- kernel_size,
- stride=1,
- groups=1,
- is_vd_mode=False,
- act=None,
- name=None, ):
+ def __init__(self,
+ in_channels,
+ out_channels,
+ kernel_size,
+ stride=1,
+ groups=1,
+ is_vd_mode=False,
+ act=None):
super(ConvBNLayer, self).__init__()
self.is_vd_mode = is_vd_mode
@@ -47,19 +45,8 @@ class ConvBNLayer(nn.Layer):
stride=stride,
padding=(kernel_size - 1) // 2,
groups=groups,
- weight_attr=ParamAttr(name=name + "_weights"),
bias_attr=False)
- if name == "conv1":
- bn_name = "bn_" + name
- else:
- bn_name = "bn" + name[3:]
- self._batch_norm = nn.BatchNorm(
- out_channels,
- act=act,
- param_attr=ParamAttr(name=bn_name + '_scale'),
- bias_attr=ParamAttr(bn_name + '_offset'),
- moving_mean_name=bn_name + '_mean',
- moving_variance_name=bn_name + '_variance')
+ self._batch_norm = nn.BatchNorm(out_channels, act=act)
def forward(self, inputs):
if self.is_vd_mode:
@@ -75,29 +62,25 @@ class BottleneckBlock(nn.Layer):
out_channels,
stride,
shortcut=True,
- if_first=False,
- name=None):
+ if_first=False):
super(BottleneckBlock, self).__init__()
self.conv0 = ConvBNLayer(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=1,
- act='relu',
- name=name + "_branch2a")
+ act='relu')
self.conv1 = ConvBNLayer(
in_channels=out_channels,
out_channels=out_channels,
kernel_size=3,
stride=stride,
- act='relu',
- name=name + "_branch2b")
+ act='relu')
self.conv2 = ConvBNLayer(
in_channels=out_channels,
out_channels=out_channels * 4,
kernel_size=1,
- act=None,
- name=name + "_branch2c")
+ act=None)
if not shortcut:
self.short = ConvBNLayer(
@@ -105,8 +88,7 @@ class BottleneckBlock(nn.Layer):
out_channels=out_channels * 4,
kernel_size=1,
stride=1,
- is_vd_mode=False if if_first else True,
- name=name + "_branch1")
+ is_vd_mode=False if if_first else True)
self.shortcut = shortcut
@@ -125,13 +107,13 @@ class BottleneckBlock(nn.Layer):
class BasicBlock(nn.Layer):
- def __init__(self,
- in_channels,
- out_channels,
- stride,
- shortcut=True,
- if_first=False,
- name=None):
+ def __init__(
+ self,
+ in_channels,
+ out_channels,
+ stride,
+ shortcut=True,
+ if_first=False, ):
super(BasicBlock, self).__init__()
self.stride = stride
self.conv0 = ConvBNLayer(
@@ -139,14 +121,12 @@ class BasicBlock(nn.Layer):
out_channels=out_channels,
kernel_size=3,
stride=stride,
- act='relu',
- name=name + "_branch2a")
+ act='relu')
self.conv1 = ConvBNLayer(
in_channels=out_channels,
out_channels=out_channels,
kernel_size=3,
- act=None,
- name=name + "_branch2b")
+ act=None)
if not shortcut:
self.short = ConvBNLayer(
@@ -154,8 +134,7 @@ class BasicBlock(nn.Layer):
out_channels=out_channels,
kernel_size=1,
stride=1,
- is_vd_mode=False if if_first else True,
- name=name + "_branch1")
+ is_vd_mode=False if if_first else True)
self.shortcut = shortcut
@@ -201,22 +180,19 @@ class ResNet(nn.Layer):
out_channels=32,
kernel_size=3,
stride=2,
- act='relu',
- name="conv1_1")
+ act='relu')
self.conv1_2 = ConvBNLayer(
in_channels=32,
out_channels=32,
kernel_size=3,
stride=1,
- act='relu',
- name="conv1_2")
+ act='relu')
self.conv1_3 = ConvBNLayer(
in_channels=32,
out_channels=64,
kernel_size=3,
stride=1,
- act='relu',
- name="conv1_3")
+ act='relu')
self.pool2d_max = nn.MaxPool2D(kernel_size=3, stride=2, padding=1)
self.stages = []
@@ -226,13 +202,6 @@ class ResNet(nn.Layer):
block_list = []
shortcut = False
for i in range(depth[block]):
- if layers in [101, 152] and block == 2:
- if i == 0:
- conv_name = "res" + str(block + 2) + "a"
- else:
- conv_name = "res" + str(block + 2) + "b" + str(i)
- else:
- conv_name = "res" + str(block + 2) + chr(97 + i)
bottleneck_block = self.add_sublayer(
'bb_%d_%d' % (block, i),
BottleneckBlock(
@@ -241,8 +210,7 @@ class ResNet(nn.Layer):
out_channels=num_filters[block],
stride=2 if i == 0 and block != 0 else 1,
shortcut=shortcut,
- if_first=block == i == 0,
- name=conv_name))
+ if_first=block == i == 0))
shortcut = True
block_list.append(bottleneck_block)
self.out_channels.append(num_filters[block] * 4)
@@ -252,7 +220,6 @@ class ResNet(nn.Layer):
block_list = []
shortcut = False
for i in range(depth[block]):
- conv_name = "res" + str(block + 2) + chr(97 + i)
basic_block = self.add_sublayer(
'bb_%d_%d' % (block, i),
BasicBlock(
@@ -261,8 +228,7 @@ class ResNet(nn.Layer):
out_channels=num_filters[block],
stride=2 if i == 0 and block != 0 else 1,
shortcut=shortcut,
- if_first=block == i == 0,
- name=conv_name))
+ if_first=block == i == 0))
shortcut = True
block_list.append(basic_block)
self.out_channels.append(num_filters[block])
diff --git a/ppstructure/vqa/README.md b/ppstructure/vqa/README.md
index 8d117fdeb16e1c0e90bf6ec89924e414fc764249..23fe28f8494ce84e774c3dd21811003f772c41f8 100644
--- a/ppstructure/vqa/README.md
+++ b/ppstructure/vqa/README.md
@@ -1,42 +1,62 @@
-# Visual Question Answering (VQA)
+# Document Visual Question Answering (DOC-VQA)
-The main features of VQA are as follows:
+VQA refers to visual question answering, which poses and answers questions about image content. DOC-VQA is a type of VQA task in which the questions target the textual content of document images.
+
+The DOC-VQA algorithms in PP-Structure are developed on top of the PaddleNLP natural language processing library.
+
+The main features are as follows:
- Integrates the [LayoutXLM](https://arxiv.org/pdf/2104.08836.pdf) model and the PP-OCR inference engine.
-- Supports semantic entity recognition (SER) and relation extraction (RE) tasks based on multimodal methods. With SER, text in an image can be recognized and classified; with RE, relations between text contents in an image can be extracted (e.g. judging question-answer pairs)
-- Supports end-to-end system prediction and evaluation combining the SER task with the OCR engine.
-- Supports custom training for the SER task and the RE task
+- Supports semantic entity recognition (SER) and relation extraction (RE) tasks based on multimodal methods. With SER, text in an image can be recognized and classified; with RE, relations between text contents in an image can be extracted, e.g. judging question-answer pairs.
+- Supports custom training for the SER task and the RE task.
+- Supports end-to-end system prediction and evaluation of OCR+SER.
+- Supports end-to-end system prediction of OCR+SER+RE (a sketch of the output appears below).
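+
+As an illustrative sketch (not actual output), the end-to-end `OCR engine + SER + RE` system returns `(question, answer)` pairs of OCR entries; each entry carries roughly the fields used by the prediction scripts:
+
+```python
+# one predicted relation: the head is a QUESTION box, the tail its ANSWER box
+[({"pred": "QUESTION", "text": "姓名", "bbox": [92, 119, 167, 138]},
+  {"pred": "ANSWER", "text": "张三", "bbox": [203, 119, 255, 138]})]
+```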
This project is an open-source implementation of [LayoutXLM: Multimodal Pre-training for Multilingual Visually-rich Document Understanding](https://arxiv.org/pdf/2104.08836.pdf) on Paddle 2.2,
including the fine-tuning code on the [XFUND dataset](https://github.com/doc-analysis/XFUND).
-## 1. Demonstration
+## 1. Performance
+
+We evaluated the algorithms on the [XFUN](https://github.com/doc-analysis/XFUND) evaluation set; the performance is as follows.
+
+|Task| F1 | Model download link|
+|:---:|:---:| :---:|
+|SER|0.9056| [link](https://paddleocr.bj.bcebos.com/pplayout/PP-Layout_v1.0_ser_pretrained.tar)|
+|RE|0.7113| [link](https://paddleocr.bj.bcebos.com/pplayout/PP-Layout_v1.0_re_pretrained.tar)|
+
+
+
+## 2. Demonstration
**Note:** The test images are from the XFUN dataset.
-### 1.1 SER
+### 2.1 SER
-
-
-
+![](./images/result_ser/zh_val_0_ser.jpg) | ![](./images/result_ser/zh_val_42_ser.jpg)
+---|---
-
-
-
+Boxes of different colors in the figure indicate different categories. For the XFUN dataset there are 3 categories: `QUESTION`, `ANSWER`, and `HEADER`:
-Boxes of different colors indicate different categories. For the XFUN dataset there are 3 categories: `QUESTION`, `ANSWER`, and `HEADER`. The corresponding category and the OCR result are also marked at the top left of each OCR detection box.
+* Dark purple: HEADER
+* Light purple: QUESTION
+* Army green: ANSWER
The corresponding category and the OCR result are also marked at the top left of each OCR detection box.
-### 1.2 RE
-* Coming soon!
+### 2.2 RE
+![](./images/result_re/zh_val_21_re.jpg) | ![](./images/result_re/zh_val_40_re.jpg)
+---|---
-## 2. Installation
In the figure, red boxes are questions and blue boxes are answers; a green line connects each question with its answer. The corresponding category and the OCR result are also marked at the top left of each OCR detection box.
-### 2.1 Install dependencies
+
+## 3. Installation
+
+### 3.1 Install dependencies
- **(1) Install PaddlePaddle**
@@ -53,12 +73,12 @@ python3 -m pip install paddlepaddle==2.2 -i https://mirror.baidu.com/pypi/simple
For other requirements, please follow the instructions in the [installation documentation](https://www.paddlepaddle.org.cn/install/quick).
-### 2.2 Install PaddleOCR (including PP-OCR and VQA)
+### 3.2 Install PaddleOCR (including PP-OCR and VQA)
- **(1) Quick installation of the PaddleOCR whl package via pip (inference only)**
```bash
-pip install "paddleocr>=2.2" # version 2.2+ is recommended
+pip install paddleocr
```
- **(2) Download the VQA source code (inference + training)**
@@ -85,13 +105,14 @@ pip install -e .
- **(4) Install the VQA `requirements`**
```bash
+cd ppstructure/vqa
pip install -r requirements.txt
```
-## 3. Usage
+## 4. Usage
-### 3.1 Prepare data and pretrained models
+### 4.1 Prepare data and pretrained models
Download link for the processed Chinese XFUN dataset: [https://paddleocr.bj.bcebos.com/dataset/XFUND.tar](https://paddleocr.bj.bcebos.com/dataset/XFUND.tar).
@@ -104,18 +125,15 @@ wget https://paddleocr.bj.bcebos.com/dataset/XFUND.tar
To convert XFUN datasets in other languages, refer to the [XFUN data conversion script](helper/trans_xfun_data.py).
-If you want to experience the prediction process directly, you can download the SER pretrained model we provide and skip the training step.
-
-* Download link of the pretrained model for the SER task: [link](https://paddleocr.bj.bcebos.com/pplayout/PP-Layout_v1.0_ser_pretrained.tar)
-* Download link of the pretrained model for the RE task: coming soon!
+If you want to experience the prediction process directly, you can download the pretrained models we provide and skip the training step.
-### 3.2 SER task
+### 4.2 SER task
* Start training
```shell
-python train_ser.py \
+python3.7 train_ser.py \
--model_name_or_path "layoutxlm-base-uncased" \
--train_data_dir "XFUND/zh_train/image" \
--train_label_path "XFUND/zh_train/xfun_normalize_train.json" \
@@ -131,13 +149,7 @@ python train_ser.py \
--seed 2048
```
-Finally, metrics such as `precision`, `recall`, and `f1` will be printed, as shown below.
-
-```
-best metrics: {'loss': 1.066644651549203, 'precision': 0.8770182068017863, 'recall': 0.9361936193619362, 'f1': 0.9056402979780063}
-```
-
-The model and training log will be saved in the `./output/ser/` folder.
+Finally, metrics such as `precision`, `recall`, and `f1` will be printed, and the model and training log will be saved in the `./output/ser/` folder.
* Predict using the OCR results provided in the evaluation set
@@ -159,21 +171,73 @@ export CUDA_VISIBLE_DEVICES=0
python3.7 infer_ser_e2e.py \
--model_name_or_path "./output/PP-Layout_v1.0_ser_pretrained/" \
--max_seq_length 512 \
- --output_dir "output_res_e2e/"
+ --output_dir "output_res_e2e/" \
+ --infer_imgs "images/input/zh_val_0.jpg"
```
* End-to-end evaluation of the `OCR engine + SER` prediction system
```shell
export CUDA_VISIBLE_DEVICES=0
-python helper/eval_with_label_end2end.py --gt_json_path XFUND/zh_val/xfun_normalize_val.json --pred_json_path output_res/infer_results.txt
+python3.7 helper/eval_with_label_end2end.py --gt_json_path XFUND/zh_val/xfun_normalize_val.json --pred_json_path output_res/infer_results.txt
```
-3.3 RE task
+### 4.3 RE task
-coming soon!
+* Start training
+```shell
+python3 train_re.py \
+ --model_name_or_path "layoutxlm-base-uncased" \
+ --train_data_dir "XFUND/zh_train/image" \
+ --train_label_path "XFUND/zh_train/xfun_normalize_train.json" \
+ --eval_data_dir "XFUND/zh_val/image" \
+ --eval_label_path "XFUND/zh_val/xfun_normalize_val.json" \
+ --label_map_path 'labels/labels_ser.txt' \
+ --num_train_epochs 2 \
+ --eval_steps 10 \
+ --save_steps 500 \
+ --output_dir "output/re/" \
+ --learning_rate 5e-5 \
+ --warmup_steps 50 \
+ --per_gpu_train_batch_size 8 \
+ --per_gpu_eval_batch_size 8 \
+ --evaluate_during_training \
+ --seed 2048
+```
+
+Finally, metrics such as `precision`, `recall`, and `f1` will be printed, and the model and training log will be saved in the `./output/re/` folder.
+
+* Predict using the OCR results provided in the evaluation set
+
+```shell
+export CUDA_VISIBLE_DEVICES=0
+python3 infer_re.py \
+ --model_name_or_path "./PP-Layout_v1.0_re_pretrained/" \
+ --max_seq_length 512 \
+ --eval_data_dir "XFUND/zh_val/image" \
+ --eval_label_path "XFUND/zh_val/xfun_normalize_val.json" \
+ --label_map_path 'labels/labels_ser.txt' \
+ --output_dir "output_res" \
+ --per_gpu_eval_batch_size 1 \
+ --seed 2048
+```
+
+Finally, the visualized prediction images and a text file of the prediction results, named `infer_results.txt`, will be saved in the `output_res` directory.
+
+* Cascaded `OCR engine + SER + RE` prediction
+
+```shell
+export CUDA_VISIBLE_DEVICES=0
+python3.7 infer_ser_re_e2e.py \
+ --model_name_or_path "./PP-Layout_v1.0_ser_pretrained/" \
+ --re_model_name_or_path "./PP-Layout_v1.0_re_pretrained/" \
+ --max_seq_length 512 \
+ --output_dir "output_ser_re_e2e_train/" \
+ --infer_imgs "images/input/zh_val_21.jpg"
+```
## References
diff --git a/ppstructure/vqa/data_collator.py b/ppstructure/vqa/data_collator.py
new file mode 100644
index 0000000000000000000000000000000000000000..a969935b487e3d22ea5c4a3527028aa2cfe1a797
--- /dev/null
+++ b/ppstructure/vqa/data_collator.py
@@ -0,0 +1,38 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import paddle
+import numbers
+import numpy as np
+
+
+class DataCollator:
+    """
+    Collate a list of sample dicts into a single batch dict: values are
+    grouped by key, and numeric / ndarray / Tensor fields are converted
+    into paddle Tensors.
+    """
+
+ def __call__(self, batch):
+ data_dict = {}
+ to_tensor_keys = []
+ for sample in batch:
+ for k, v in sample.items():
+ if k not in data_dict:
+ data_dict[k] = []
+ if isinstance(v, (np.ndarray, paddle.Tensor, numbers.Number)):
+ if k not in to_tensor_keys:
+ to_tensor_keys.append(k)
+ data_dict[k].append(v)
+ for k in to_tensor_keys:
+ data_dict[k] = paddle.to_tensor(data_dict[k])
+ return data_dict
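+
+
+# Minimal usage sketch (assuming `dataset` is an XFUNDataset instance):
+#     loader = paddle.io.DataLoader(
+#         dataset, batch_size=8, shuffle=False, collate_fn=DataCollator())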
diff --git a/ppstructure/vqa/images/input/zh_val_21.jpg b/ppstructure/vqa/images/input/zh_val_21.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..35b572d7dd6a6b42cf43a8a4b33567c0af527d30
Binary files /dev/null and b/ppstructure/vqa/images/input/zh_val_21.jpg differ
diff --git a/ppstructure/vqa/images/input/zh_val_40.jpg b/ppstructure/vqa/images/input/zh_val_40.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..2a858cc33d54831335c209146853b6c302c734f8
Binary files /dev/null and b/ppstructure/vqa/images/input/zh_val_40.jpg differ
diff --git a/ppstructure/vqa/images/result_re/zh_val_21_re.jpg b/ppstructure/vqa/images/result_re/zh_val_21_re.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..7bf248dd0e69057c4775ff9c205317044e94ee65
Binary files /dev/null and b/ppstructure/vqa/images/result_re/zh_val_21_re.jpg differ
diff --git a/ppstructure/vqa/images/result_re/zh_val_40_re.jpg b/ppstructure/vqa/images/result_re/zh_val_40_re.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..242f9d6e80be39c595d98b57d59d48673ce62f20
Binary files /dev/null and b/ppstructure/vqa/images/result_re/zh_val_40_re.jpg differ
diff --git a/ppstructure/vqa/images/result_ser/zh_val_0_ser.jpg b/ppstructure/vqa/images/result_ser/zh_val_0_ser.jpg
index 22ba9a6f1b7652ca9ce6848093c7a39affb4886b..4605c3a7f395e9868ba55cd31a99367694c78f5c 100644
Binary files a/ppstructure/vqa/images/result_ser/zh_val_0_ser.jpg and b/ppstructure/vqa/images/result_ser/zh_val_0_ser.jpg differ
diff --git a/ppstructure/vqa/images/result_ser/zh_val_42_ser.jpg b/ppstructure/vqa/images/result_ser/zh_val_42_ser.jpg
index 951864e5f35a987ff241f276c8da523d8c8eeaf3..13bc7272e49a03115085d4a7420a7acfb92d3260 100644
Binary files a/ppstructure/vqa/images/result_ser/zh_val_42_ser.jpg and b/ppstructure/vqa/images/result_ser/zh_val_42_ser.jpg differ
diff --git a/ppstructure/vqa/infer_re.py b/ppstructure/vqa/infer_re.py
new file mode 100644
index 0000000000000000000000000000000000000000..ae2f52550294b072179c3bdba28c3572369e11a3
--- /dev/null
+++ b/ppstructure/vqa/infer_re.py
@@ -0,0 +1,162 @@
+import os
+import sys
+
+__dir__ = os.path.dirname(os.path.abspath(__file__))
+sys.path.append(__dir__)
+sys.path.append(os.path.abspath(os.path.join(__dir__, '../..')))
+
+import random
+
+import cv2
+import matplotlib.pyplot as plt
+import numpy as np
+import paddle
+
+from paddlenlp.transformers import LayoutXLMTokenizer, LayoutXLMModel, LayoutXLMForRelationExtraction
+
+from xfun import XFUNDataset
+from utils import parse_args, get_bio_label_maps, draw_re_results
+from data_collator import DataCollator
+
+from ppocr.utils.logging import get_logger
+
+
+def infer(args):
+ os.makedirs(args.output_dir, exist_ok=True)
+ logger = get_logger()
+ label2id_map, id2label_map = get_bio_label_maps(args.label_map_path)
+ pad_token_label_id = paddle.nn.CrossEntropyLoss().ignore_index
+
+ tokenizer = LayoutXLMTokenizer.from_pretrained(args.model_name_or_path)
+
+ model = LayoutXLMForRelationExtraction.from_pretrained(
+ args.model_name_or_path)
+
+ eval_dataset = XFUNDataset(
+ tokenizer,
+ data_dir=args.eval_data_dir,
+ label_path=args.eval_label_path,
+ label2id_map=label2id_map,
+ img_size=(224, 224),
+ max_seq_len=args.max_seq_length,
+ pad_token_label_id=pad_token_label_id,
+ contains_re=True,
+ add_special_ids=False,
+ return_attention_mask=True,
+ load_mode='all')
+
+ eval_dataloader = paddle.io.DataLoader(
+ eval_dataset,
+ batch_size=args.per_gpu_eval_batch_size,
+ num_workers=8,
+ shuffle=False,
+ collate_fn=DataCollator())
+
+    # load the ground-truth OCR info
+ ocr_info_list = load_ocr(args.eval_data_dir, args.eval_label_path)
+
+ for idx, batch in enumerate(eval_dataloader):
+ logger.info("[Infer] process: {}/{}".format(idx, len(eval_dataloader)))
+ with paddle.no_grad():
+ outputs = model(**batch)
+ pred_relations = outputs['pred_relations']
+
+ ocr_info = ocr_info_list[idx]
+ image_path = ocr_info['image_path']
+ ocr_info = ocr_info['ocr_info']
+
+        # decode tokens from the entity spans and filter out unmatched ocr_info entries
+ ocr_info = filter_bg_by_txt(ocr_info, batch, tokenizer)
+
+        # convert predicted relations into pairs of OCR info entries
+ result = []
+ used_tail_id = []
+ for relations in pred_relations:
+ for relation in relations:
+ if relation['tail_id'] in used_tail_id:
+ continue
+ if relation['head_id'] not in ocr_info or relation[
+ 'tail_id'] not in ocr_info:
+ continue
+ used_tail_id.append(relation['tail_id'])
+ ocr_info_head = ocr_info[relation['head_id']]
+ ocr_info_tail = ocr_info[relation['tail_id']]
+ result.append((ocr_info_head, ocr_info_tail))
+
+ img = cv2.imread(image_path)
+ img_show = draw_re_results(img, result)
+ save_path = os.path.join(args.output_dir, os.path.basename(image_path))
+ cv2.imwrite(save_path, img_show)
+
+
+def load_ocr(img_folder, json_path):
+ import json
+ d = []
+ with open(json_path, "r") as fin:
+ lines = fin.readlines()
+ for line in lines:
+ image_name, info_str = line.split("\t")
+ info_dict = json.loads(info_str)
+ info_dict['image_path'] = os.path.join(img_folder, image_name)
+ d.append(info_dict)
+ return d
+
+
+def filter_bg_by_txt(ocr_info, batch, tokenizer):
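+    # keep only the ocr_info entries whose text matches a decoded entity span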
+ entities = batch['entities'][0]
+ input_ids = batch['input_ids'][0]
+
+ new_info_dict = {}
+    for i in range(len(entities['start'])):
+        entity_start = entities['start'][i]
+        entity_end = entities['end'][i]
+        word_input_ids = input_ids[entity_start:entity_end].numpy().tolist()
+        txt = tokenizer.convert_ids_to_tokens(word_input_ids)
+        txt = tokenizer.convert_tokens_to_string(txt)
+
+        for j, info in enumerate(ocr_info):
+            if info['text'] == txt:
+                new_info_dict[j] = info
+ return new_info_dict
+
+
+def post_process(pred_relations, ocr_info, img):
+ result = []
+ for relations in pred_relations:
+ for relation in relations:
+ ocr_info_head = ocr_info[relation['head_id']]
+ ocr_info_tail = ocr_info[relation['tail_id']]
+ result.append((ocr_info_head, ocr_info_tail))
+ return result
+
+
+def draw_re(result, image_path, output_folder):
+ img = cv2.imread(image_path)
+
+ from matplotlib import pyplot as plt
+ for ocr_info_head, ocr_info_tail in result:
+ cv2.rectangle(
+ img,
+ tuple(ocr_info_head['bbox'][:2]),
+ tuple(ocr_info_head['bbox'][2:]), (255, 0, 0),
+ thickness=2)
+ cv2.rectangle(
+ img,
+ tuple(ocr_info_tail['bbox'][:2]),
+ tuple(ocr_info_tail['bbox'][2:]), (0, 0, 255),
+ thickness=2)
+ center_p1 = [(ocr_info_head['bbox'][0] + ocr_info_head['bbox'][2]) // 2,
+ (ocr_info_head['bbox'][1] + ocr_info_head['bbox'][3]) // 2]
+ center_p2 = [(ocr_info_tail['bbox'][0] + ocr_info_tail['bbox'][2]) // 2,
+ (ocr_info_tail['bbox'][1] + ocr_info_tail['bbox'][3]) // 2]
+ cv2.line(
+ img, tuple(center_p1), tuple(center_p2), (0, 255, 0), thickness=2)
+ plt.imshow(img)
+ plt.savefig(
+ os.path.join(output_folder, os.path.basename(image_path)), dpi=600)
+ # plt.show()
+
+
+if __name__ == "__main__":
+ args = parse_args()
+ infer(args)
diff --git a/ppstructure/vqa/infer_ser_e2e.py b/ppstructure/vqa/infer_ser_e2e.py
index da027a140bdb4fa12a40d423998d94e438a7cd11..1638e78a11105feb1cb037a545005b2384672eb8 100644
--- a/ppstructure/vqa/infer_ser_e2e.py
+++ b/ppstructure/vqa/infer_ser_e2e.py
@@ -23,8 +23,10 @@ from PIL import Image
import paddle
from paddlenlp.transformers import LayoutXLMModel, LayoutXLMTokenizer, LayoutXLMForTokenClassification
+from paddleocr import PaddleOCR
+
# relative reference
-from utils import parse_args, get_image_file_list, draw_ser_results, get_bio_label_maps, build_ocr_engine
+from utils import parse_args, get_image_file_list, draw_ser_results, get_bio_label_maps
from utils import pad_sentences, split_page, preprocess, postprocess, merge_preds_list_with_ocr_info
@@ -48,74 +50,82 @@ def parse_ocr_info_for_ser(ocr_result):
return ocr_info
-@paddle.no_grad()
-def infer(args):
- os.makedirs(args.output_dir, exist_ok=True)
+class SerPredictor(object):
+ def __init__(self, args):
+ self.max_seq_length = args.max_seq_length
+
+ # init ser token and model
+ self.tokenizer = LayoutXLMTokenizer.from_pretrained(
+ args.model_name_or_path)
+ self.model = LayoutXLMForTokenClassification.from_pretrained(
+ args.model_name_or_path)
+ self.model.eval()
+
+ # init ocr_engine
+ self.ocr_engine = PaddleOCR(
+ rec_model_dir=args.ocr_rec_model_dir,
+ det_model_dir=args.ocr_det_model_dir,
+ use_angle_cls=False,
+ show_log=False)
+ # init dict
+ label2id_map, self.id2label_map = get_bio_label_maps(
+ args.label_map_path)
+ self.label2id_map_for_draw = dict()
+ for key in label2id_map:
+ if key.startswith("I-"):
+ self.label2id_map_for_draw[key] = label2id_map["B" + key[1:]]
+ else:
+ self.label2id_map_for_draw[key] = label2id_map[key]
+
+ def __call__(self, img):
+ ocr_result = self.ocr_engine.ocr(img, cls=False)
+
+ ocr_info = parse_ocr_info_for_ser(ocr_result)
+
+ inputs = preprocess(
+ tokenizer=self.tokenizer,
+ ori_img=img,
+ ocr_info=ocr_info,
+ max_seq_len=self.max_seq_length)
+
+ outputs = self.model(
+ input_ids=inputs["input_ids"],
+ bbox=inputs["bbox"],
+ image=inputs["image"],
+ token_type_ids=inputs["token_type_ids"],
+ attention_mask=inputs["attention_mask"])
+
+ preds = outputs[0]
+ preds = postprocess(inputs["attention_mask"], preds, self.id2label_map)
+ ocr_info = merge_preds_list_with_ocr_info(
+ ocr_info, inputs["segment_offset_id"], preds,
+ self.label2id_map_for_draw)
+ return ocr_info, inputs
- # init token and model
- tokenizer = LayoutXLMTokenizer.from_pretrained(args.model_name_or_path)
- model = LayoutXLMForTokenClassification.from_pretrained(
- args.model_name_or_path)
- model.eval()
- label2id_map, id2label_map = get_bio_label_maps(args.label_map_path)
- label2id_map_for_draw = dict()
- for key in label2id_map:
- if key.startswith("I-"):
- label2id_map_for_draw[key] = label2id_map["B" + key[1:]]
- else:
- label2id_map_for_draw[key] = label2id_map[key]
+if __name__ == "__main__":
+ args = parse_args()
+ os.makedirs(args.output_dir, exist_ok=True)
# get infer img list
infer_imgs = get_image_file_list(args.infer_imgs)
- ocr_engine = build_ocr_engine(args.ocr_rec_model_dir,
- args.ocr_det_model_dir)
-
# loop for infer
+ ser_engine = SerPredictor(args)
with open(os.path.join(args.output_dir, "infer_results.txt"), "w") as fout:
for idx, img_path in enumerate(infer_imgs):
- print("process: [{}/{}]".format(idx, len(infer_imgs), img_path))
+ print("process: [{}/{}], {}".format(idx, len(infer_imgs), img_path))
img = cv2.imread(img_path)
- ocr_result = ocr_engine.ocr(img_path, cls=False)
-
- ocr_info = parse_ocr_info_for_ser(ocr_result)
-
- inputs = preprocess(
- tokenizer=tokenizer,
- ori_img=img,
- ocr_info=ocr_info,
- max_seq_len=args.max_seq_length)
-
- outputs = model(
- input_ids=inputs["input_ids"],
- bbox=inputs["bbox"],
- image=inputs["image"],
- token_type_ids=inputs["token_type_ids"],
- attention_mask=inputs["attention_mask"])
-
- preds = outputs[0]
- preds = postprocess(inputs["attention_mask"], preds, id2label_map)
- ocr_info = merge_preds_list_with_ocr_info(
- ocr_info, inputs["segment_offset_id"], preds,
- label2id_map_for_draw)
-
+ result, _ = ser_engine(img)
fout.write(img_path + "\t" + json.dumps(
{
- "ocr_info": ocr_info,
+                    "ser_result": result,
}, ensure_ascii=False) + "\n")
- img_res = draw_ser_results(img, ocr_info)
+ img_res = draw_ser_results(img, result)
cv2.imwrite(
os.path.join(args.output_dir,
os.path.splitext(os.path.basename(img_path))[0] +
"_ser.jpg"), img_res)
-
- return
-
-
-if __name__ == "__main__":
- args = parse_args()
- infer(args)
diff --git a/ppstructure/vqa/infer_ser_re_e2e.py b/ppstructure/vqa/infer_ser_re_e2e.py
new file mode 100644
index 0000000000000000000000000000000000000000..a1d0f52eeecbc6c2ceba5964355008f638f371dd
--- /dev/null
+++ b/ppstructure/vqa/infer_ser_re_e2e.py
@@ -0,0 +1,131 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import sys
+import json
+import cv2
+import numpy as np
+from copy import deepcopy
+from PIL import Image
+
+import paddle
+from paddlenlp.transformers import LayoutXLMModel, LayoutXLMTokenizer, LayoutXLMForRelationExtraction
+
+# relative reference
+from utils import parse_args, get_image_file_list, draw_re_results
+from infer_ser_e2e import SerPredictor
+
+
+def make_input(ser_input, ser_result, max_seq_len=512):
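+    # Convert the SER inputs/results into the entity/relation format expected
+    # by LayoutXLMForRelationExtraction: every QUESTION entity is paired with
+    # every ANSWER entity as a candidate relation.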
+ entities_labels = {'HEADER': 0, 'QUESTION': 1, 'ANSWER': 2}
+
+ entities = ser_input['entities'][0]
+ assert len(entities) == len(ser_result)
+
+ # entities
+ start = []
+ end = []
+ label = []
+ entity_idx_dict = {}
+ for i, (res, entity) in enumerate(zip(ser_result, entities)):
+ if res['pred'] == 'O':
+ continue
+ entity_idx_dict[len(start)] = i
+ start.append(entity['start'])
+ end.append(entity['end'])
+ label.append(entities_labels[res['pred']])
+ entities = dict(start=start, end=end, label=label)
+
+ # relations
+ head = []
+ tail = []
+ for i in range(len(entities["label"])):
+ for j in range(len(entities["label"])):
+ if entities["label"][i] == 1 and entities["label"][j] == 2:
+ head.append(i)
+ tail.append(j)
+
+ relations = dict(head=head, tail=tail)
+
+ batch_size = ser_input["input_ids"].shape[0]
+ entities_batch = []
+ relations_batch = []
+ for b in range(batch_size):
+ entities_batch.append(entities)
+ relations_batch.append(relations)
+
+ ser_input['entities'] = entities_batch
+ ser_input['relations'] = relations_batch
+
+ ser_input.pop('segment_offset_id')
+ return ser_input, entity_idx_dict
+
+
+class SerReSystem(object):
+ def __init__(self, args):
+ self.ser_engine = SerPredictor(args)
+ self.tokenizer = LayoutXLMTokenizer.from_pretrained(
+ args.re_model_name_or_path)
+ self.model = LayoutXLMForRelationExtraction.from_pretrained(
+ args.re_model_name_or_path)
+ self.model.eval()
+
+ def __call__(self, img):
+ ser_result, ser_inputs = self.ser_engine(img)
+ re_input, entity_idx_dict = make_input(ser_inputs, ser_result)
+
+ re_result = self.model(**re_input)
+
+ pred_relations = re_result['pred_relations'][0]
+        # convert predicted relations into pairs of OCR info entries
+ result = []
+ used_tail_id = []
+ for relation in pred_relations:
+ if relation['tail_id'] in used_tail_id:
+ continue
+ used_tail_id.append(relation['tail_id'])
+ ocr_info_head = ser_result[entity_idx_dict[relation['head_id']]]
+ ocr_info_tail = ser_result[entity_idx_dict[relation['tail_id']]]
+ result.append((ocr_info_head, ocr_info_tail))
+
+ return result
+
+
+if __name__ == "__main__":
+ args = parse_args()
+ os.makedirs(args.output_dir, exist_ok=True)
+
+ # get infer img list
+ infer_imgs = get_image_file_list(args.infer_imgs)
+
+ # loop for infer
+ ser_re_engine = SerReSystem(args)
+ with open(os.path.join(args.output_dir, "infer_results.txt"), "w") as fout:
+ for idx, img_path in enumerate(infer_imgs):
+ print("process: [{}/{}], {}".format(idx, len(infer_imgs), img_path))
+
+ img = cv2.imread(img_path)
+
+ result = ser_re_engine(img)
+ fout.write(img_path + "\t" + json.dumps(
+ {
+ "result": result,
+ }, ensure_ascii=False) + "\n")
+
+ img_res = draw_re_results(img, result)
+ cv2.imwrite(
+ os.path.join(args.output_dir,
+ os.path.splitext(os.path.basename(img_path))[0] +
+ "_re.jpg"), img_res)
diff --git a/ppstructure/vqa/metric.py b/ppstructure/vqa/metric.py
new file mode 100644
index 0000000000000000000000000000000000000000..cb58370521296886670486982caf1202cf99a489
--- /dev/null
+++ b/ppstructure/vqa/metric.py
@@ -0,0 +1,175 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import re
+
+import numpy as np
+
+import logging
+
+logger = logging.getLogger(__name__)
+
+PREFIX_CHECKPOINT_DIR = "checkpoint"
+_re_checkpoint = re.compile(r"^" + PREFIX_CHECKPOINT_DIR + r"\-(\d+)$")
+
+
+def get_last_checkpoint(folder):
+ content = os.listdir(folder)
+ checkpoints = [
+ path for path in content
+ if _re_checkpoint.search(path) is not None and os.path.isdir(
+ os.path.join(folder, path))
+ ]
+ if len(checkpoints) == 0:
+ return
+ return os.path.join(
+ folder,
+ max(checkpoints,
+ key=lambda x: int(_re_checkpoint.search(x).groups()[0])))
+
+
+def re_score(pred_relations, gt_relations, mode="strict"):
+ """Evaluate RE predictions
+
+ Args:
+ pred_relations (list) : list of list of predicted relations (several relations in each sentence)
+ gt_relations (list) : list of list of ground truth relations
+
+ rel = { "head": (start_idx (inclusive), end_idx (exclusive)),
+ "tail": (start_idx (inclusive), end_idx (exclusive)),
+ "head_type": ent_type,
+ "tail_type": ent_type,
+ "type": rel_type}
+
+ vocab (Vocab) : dataset vocabulary
+ mode (str) : in 'strict' or 'boundaries'"""
+
+ assert mode in ["strict", "boundaries"]
+
+    relation_types = [1]  # XFUN uses a single positive relation type
+ scores = {
+ rel: {
+ "tp": 0,
+ "fp": 0,
+ "fn": 0
+ }
+ for rel in relation_types + ["ALL"]
+ }
+
+ # Count GT relations and Predicted relations
+ n_sents = len(gt_relations)
+ n_rels = sum([len([rel for rel in sent]) for sent in gt_relations])
+ n_found = sum([len([rel for rel in sent]) for sent in pred_relations])
+
+ # Count TP, FP and FN per type
+ for pred_sent, gt_sent in zip(pred_relations, gt_relations):
+ for rel_type in relation_types:
+ # strict mode takes argument types into account
+ if mode == "strict":
+ pred_rels = {(rel["head"], rel["head_type"], rel["tail"],
+ rel["tail_type"])
+ for rel in pred_sent if rel["type"] == rel_type}
+ gt_rels = {(rel["head"], rel["head_type"], rel["tail"],
+ rel["tail_type"])
+ for rel in gt_sent if rel["type"] == rel_type}
+
+ # boundaries mode only takes argument spans into account
+ elif mode == "boundaries":
+ pred_rels = {(rel["head"], rel["tail"])
+ for rel in pred_sent if rel["type"] == rel_type}
+ gt_rels = {(rel["head"], rel["tail"])
+ for rel in gt_sent if rel["type"] == rel_type}
+
+ scores[rel_type]["tp"] += len(pred_rels & gt_rels)
+ scores[rel_type]["fp"] += len(pred_rels - gt_rels)
+ scores[rel_type]["fn"] += len(gt_rels - pred_rels)
+
+ # Compute per entity Precision / Recall / F1
+ for rel_type in scores.keys():
+ if scores[rel_type]["tp"]:
+ scores[rel_type]["p"] = scores[rel_type]["tp"] / (
+ scores[rel_type]["fp"] + scores[rel_type]["tp"])
+ scores[rel_type]["r"] = scores[rel_type]["tp"] / (
+ scores[rel_type]["fn"] + scores[rel_type]["tp"])
+ else:
+ scores[rel_type]["p"], scores[rel_type]["r"] = 0, 0
+
+ if not scores[rel_type]["p"] + scores[rel_type]["r"] == 0:
+ scores[rel_type]["f1"] = (
+ 2 * scores[rel_type]["p"] * scores[rel_type]["r"] /
+ (scores[rel_type]["p"] + scores[rel_type]["r"]))
+ else:
+ scores[rel_type]["f1"] = 0
+
+ # Compute micro F1 Scores
+ tp = sum([scores[rel_type]["tp"] for rel_type in relation_types])
+ fp = sum([scores[rel_type]["fp"] for rel_type in relation_types])
+ fn = sum([scores[rel_type]["fn"] for rel_type in relation_types])
+
+ if tp:
+ precision = tp / (tp + fp)
+ recall = tp / (tp + fn)
+ f1 = 2 * precision * recall / (precision + recall)
+
+ else:
+ precision, recall, f1 = 0, 0, 0
+
+ scores["ALL"]["p"] = precision
+ scores["ALL"]["r"] = recall
+ scores["ALL"]["f1"] = f1
+ scores["ALL"]["tp"] = tp
+ scores["ALL"]["fp"] = fp
+ scores["ALL"]["fn"] = fn
+
+ # Compute Macro F1 Scores
+ scores["ALL"]["Macro_f1"] = np.mean(
+ [scores[ent_type]["f1"] for ent_type in relation_types])
+ scores["ALL"]["Macro_p"] = np.mean(
+ [scores[ent_type]["p"] for ent_type in relation_types])
+ scores["ALL"]["Macro_r"] = np.mean(
+ [scores[ent_type]["r"] for ent_type in relation_types])
+
+ return scores
diff --git a/ppstructure/vqa/train_re.py b/ppstructure/vqa/train_re.py
new file mode 100644
index 0000000000000000000000000000000000000000..ed19646cf57e69ac99e417ae27568655a4e00039
--- /dev/null
+++ b/ppstructure/vqa/train_re.py
@@ -0,0 +1,261 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import sys
+
+__dir__ = os.path.dirname(os.path.abspath(__file__))
+sys.path.append(__dir__)
+sys.path.append(os.path.abspath(os.path.join(__dir__, '../..')))
+
+import random
+import numpy as np
+import paddle
+
+from paddlenlp.transformers import LayoutXLMTokenizer, LayoutXLMModel, LayoutXLMForRelationExtraction
+
+from xfun import XFUNDataset
+from utils import parse_args, get_bio_label_maps, print_arguments
+from data_collator import DataCollator
+from metric import re_score
+
+from ppocr.utils.logging import get_logger
+
+
+def set_seed(seed):
+ random.seed(seed)
+ np.random.seed(seed)
+ paddle.seed(seed)
+
+
+def cal_metric(re_preds, re_labels, entities):
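+    # rebuild the ground-truth relation spans from the entity annotations so
+    # that re_score can compare them against the predicted relations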
+ gt_relations = []
+ for b in range(len(re_labels)):
+ rel_sent = []
+ for head, tail in zip(re_labels[b]["head"], re_labels[b]["tail"]):
+ rel = {}
+ rel["head_id"] = head
+ rel["head"] = (entities[b]["start"][rel["head_id"]],
+ entities[b]["end"][rel["head_id"]])
+ rel["head_type"] = entities[b]["label"][rel["head_id"]]
+
+ rel["tail_id"] = tail
+ rel["tail"] = (entities[b]["start"][rel["tail_id"]],
+ entities[b]["end"][rel["tail_id"]])
+ rel["tail_type"] = entities[b]["label"][rel["tail_id"]]
+
+ rel["type"] = 1
+ rel_sent.append(rel)
+ gt_relations.append(rel_sent)
+ re_metrics = re_score(re_preds, gt_relations, mode="boundaries")
+ return re_metrics
+
+
+def evaluate(model, eval_dataloader, logger, prefix=""):
+ # Eval!
+ logger.info("***** Running evaluation {} *****".format(prefix))
+ logger.info(" Num examples = {}".format(len(eval_dataloader.dataset)))
+
+ re_preds = []
+ re_labels = []
+ entities = []
+ eval_loss = 0.0
+ model.eval()
+ for idx, batch in enumerate(eval_dataloader):
+ with paddle.no_grad():
+ outputs = model(**batch)
+ loss = outputs['loss'].mean().item()
+ if paddle.distributed.get_rank() == 0:
+ logger.info("[Eval] process: {}/{}, loss: {:.5f}".format(
+ idx, len(eval_dataloader), loss))
+
+ eval_loss += loss
+ re_preds.extend(outputs['pred_relations'])
+ re_labels.extend(batch['relations'])
+ entities.extend(batch['entities'])
+ re_metrics = cal_metric(re_preds, re_labels, entities)
+ re_metrics = {
+ "precision": re_metrics["ALL"]["p"],
+ "recall": re_metrics["ALL"]["r"],
+ "f1": re_metrics["ALL"]["f1"],
+ }
+ model.train()
+ return re_metrics
+
+
+def train(args):
+ logger = get_logger(log_file=os.path.join(args.output_dir, "train.log"))
+ print_arguments(args, logger)
+
+ # Added here for reproducibility (even between python 2 and 3)
+ set_seed(args.seed)
+
+ label2id_map, id2label_map = get_bio_label_maps(args.label_map_path)
+ pad_token_label_id = paddle.nn.CrossEntropyLoss().ignore_index
+
+ # dist mode
+ if paddle.distributed.get_world_size() > 1:
+ paddle.distributed.init_parallel_env()
+
+ tokenizer = LayoutXLMTokenizer.from_pretrained(args.model_name_or_path)
+
+ model = LayoutXLMModel.from_pretrained(args.model_name_or_path)
+ model = LayoutXLMForRelationExtraction(model, dropout=None)
+
+ # dist mode
+ if paddle.distributed.get_world_size() > 1:
+ model = paddle.distributed.DataParallel(model)
+
+ train_dataset = XFUNDataset(
+ tokenizer,
+ data_dir=args.train_data_dir,
+ label_path=args.train_label_path,
+ label2id_map=label2id_map,
+ img_size=(224, 224),
+ max_seq_len=args.max_seq_length,
+ pad_token_label_id=pad_token_label_id,
+ contains_re=True,
+ add_special_ids=False,
+ return_attention_mask=True,
+ load_mode='all')
+
+ eval_dataset = XFUNDataset(
+ tokenizer,
+ data_dir=args.eval_data_dir,
+ label_path=args.eval_label_path,
+ label2id_map=label2id_map,
+ img_size=(224, 224),
+ max_seq_len=args.max_seq_length,
+ pad_token_label_id=pad_token_label_id,
+ contains_re=True,
+ add_special_ids=False,
+ return_attention_mask=True,
+ load_mode='all')
+
+ train_sampler = paddle.io.DistributedBatchSampler(
+ train_dataset, batch_size=args.per_gpu_train_batch_size, shuffle=True)
+ args.train_batch_size = args.per_gpu_train_batch_size * \
+ max(1, paddle.distributed.get_world_size())
+ train_dataloader = paddle.io.DataLoader(
+ train_dataset,
+ batch_sampler=train_sampler,
+ num_workers=8,
+ use_shared_memory=True,
+ collate_fn=DataCollator())
+
+ eval_dataloader = paddle.io.DataLoader(
+ eval_dataset,
+ batch_size=args.per_gpu_eval_batch_size,
+ num_workers=8,
+ shuffle=False,
+ collate_fn=DataCollator())
+
+ t_total = len(train_dataloader) * args.num_train_epochs
+
+ # build linear decay with warmup lr sch
+ lr_scheduler = paddle.optimizer.lr.PolynomialDecay(
+ learning_rate=args.learning_rate,
+ decay_steps=t_total,
+ end_lr=0.0,
+ power=1.0)
+ if args.warmup_steps > 0:
+ lr_scheduler = paddle.optimizer.lr.LinearWarmup(
+ lr_scheduler,
+ args.warmup_steps,
+ start_lr=0,
+ end_lr=args.learning_rate, )
+ grad_clip = paddle.nn.ClipGradByNorm(clip_norm=10)
+ optimizer = paddle.optimizer.Adam(
+ learning_rate=args.learning_rate,
+ parameters=model.parameters(),
+ epsilon=args.adam_epsilon,
+ grad_clip=grad_clip,
+ weight_decay=args.weight_decay)
+
+ # Train!
+ logger.info("***** Running training *****")
+ logger.info(" Num examples = {}".format(len(train_dataset)))
+ logger.info(" Num Epochs = {}".format(args.num_train_epochs))
+ logger.info(" Instantaneous batch size per GPU = {}".format(
+ args.per_gpu_train_batch_size))
+ logger.info(
+ " Total train batch size (w. parallel, distributed & accumulation) = {}".
+ format(args.train_batch_size * paddle.distributed.get_world_size()))
+ logger.info(" Total optimization steps = {}".format(t_total))
+
+ global_step = 0
+ model.clear_gradients()
+ train_dataloader_len = len(train_dataloader)
+    best_metric = {'f1': 0}
+ model.train()
+
+ for epoch in range(int(args.num_train_epochs)):
+ for step, batch in enumerate(train_dataloader):
+ outputs = model(**batch)
+ # model outputs are always tuple in ppnlp (see doc)
+ loss = outputs['loss']
+ loss = loss.mean()
+
+ logger.info(
+ "epoch: [{}/{}], iter: [{}/{}], global_step:{}, train loss: {}, lr: {}".
+ format(epoch, args.num_train_epochs, step, train_dataloader_len,
+ global_step, np.mean(loss.numpy()), optimizer.get_lr()))
+
+ loss.backward()
+ optimizer.step()
+ optimizer.clear_grad()
+ # lr_scheduler.step() # Update learning rate schedule
+
+ global_step += 1
+
+ if (paddle.distributed.get_rank() == 0 and args.eval_steps > 0 and
+ global_step % args.eval_steps == 0):
+ # Log metrics
+                # Only evaluate on rank 0; otherwise metrics may not average well
+                if (paddle.distributed.get_rank() == 0 and
+                        args.evaluate_during_training):
+ results = evaluate(model, eval_dataloader, logger)
+                    if results['f1'] > best_metric['f1']:
+                        best_metric = results
+ output_dir = os.path.join(args.output_dir,
+ "checkpoint-best")
+ os.makedirs(output_dir, exist_ok=True)
+ model.save_pretrained(output_dir)
+ tokenizer.save_pretrained(output_dir)
+ paddle.save(args,
+ os.path.join(output_dir,
+ "training_args.bin"))
+ logger.info("Saving model checkpoint to {}".format(
+ output_dir))
+ logger.info("eval results: {}".format(results))
+                    logger.info("best_metric: {}".format(best_metric))
+
+ if (paddle.distributed.get_rank() == 0 and args.save_steps > 0 and
+ global_step % args.save_steps == 0):
+ # Save model checkpoint
+ output_dir = os.path.join(args.output_dir, "checkpoint-latest")
+ os.makedirs(output_dir, exist_ok=True)
+ if paddle.distributed.get_rank() == 0:
+ model.save_pretrained(output_dir)
+ tokenizer.save_pretrained(output_dir)
+ paddle.save(args,
+ os.path.join(output_dir, "training_args.bin"))
+ logger.info("Saving model checkpoint to {}".format(
+ output_dir))
+                logger.info("best_metric: {}".format(best_metric))
+
+
+if __name__ == "__main__":
+ args = parse_args()
+ os.makedirs(args.output_dir, exist_ok=True)
+ train(args)
diff --git a/ppstructure/vqa/train_ser.py b/ppstructure/vqa/train_ser.py
index 90ca69d93fd22983533fcacd639bbd64dc3e11ec..d3144e7167c59b5883047a948abaedfd21ba9b1c 100644
--- a/ppstructure/vqa/train_ser.py
+++ b/ppstructure/vqa/train_ser.py
@@ -12,8 +12,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-import sys
import os
+import sys
+
+__dir__ = os.path.dirname(os.path.abspath(__file__))
+sys.path.append(__dir__)
+sys.path.append(os.path.abspath(os.path.join(__dir__, '../..')))
+
import random
import copy
import logging
@@ -26,8 +31,9 @@ from paddlenlp.transformers import LayoutXLMModel, LayoutXLMTokenizer, LayoutXLM
from xfun import XFUNDataset
from utils import parse_args
from utils import get_bio_label_maps
+from utils import print_arguments
-logger = logging.getLogger(__name__)
+from ppocr.utils.logging import get_logger
def set_seed(args):
@@ -38,17 +44,8 @@ def set_seed(args):
def train(args):
os.makedirs(args.output_dir, exist_ok=True)
- logging.basicConfig(
- filename=os.path.join(args.output_dir, "train.log")
- if paddle.distributed.get_rank() == 0 else None,
- format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
- datefmt="%m/%d/%Y %H:%M:%S",
- level=logging.INFO
- if paddle.distributed.get_rank() == 0 else logging.WARN, )
-
- ch = logging.StreamHandler()
- ch.setLevel(logging.DEBUG)
- logger.addHandler(ch)
+ logger = get_logger(log_file=os.path.join(args.output_dir, "train.log"))
+ print_arguments(args, logger)
label2id_map, id2label_map = get_bio_label_maps(args.label_map_path)
pad_token_label_id = paddle.nn.CrossEntropyLoss().ignore_index
@@ -136,10 +133,10 @@ def train(args):
loss = outputs[0]
loss = loss.mean()
logger.info(
- "[epoch {}/{}][iter: {}/{}] lr: {:.5f}, train loss: {:.5f}, ".
+ "epoch: [{}/{}], iter: [{}/{}], global_step:{}, train loss: {}, lr: {}".
format(epoch_id, args.num_train_epochs, step,
- len(train_dataloader),
- lr_scheduler.get_lr(), loss.numpy()[0]))
+ len(train_dataloader), global_step,
+ loss.numpy()[0], lr_scheduler.get_lr()))
loss.backward()
tr_loss += loss.item()
@@ -154,13 +151,9 @@ def train(args):
# Only evaluate when single GPU otherwise metrics may not average well
if paddle.distributed.get_rank(
) == 0 and args.evaluate_during_training:
- results, _ = evaluate(
- args,
- model,
- tokenizer,
- label2id_map,
- id2label_map,
- pad_token_label_id, )
+ results, _ = evaluate(args, model, tokenizer, label2id_map,
+ id2label_map, pad_token_label_id,
+ logger)
if best_metrics is None or results["f1"] >= best_metrics[
"f1"]:
@@ -204,6 +197,7 @@ def evaluate(args,
label2id_map,
id2label_map,
pad_token_label_id,
+ logger,
prefix=""):
eval_dataset = XFUNDataset(
tokenizer,
@@ -299,15 +293,6 @@ def evaluate(args,
return results, preds_list
-def print_arguments(args):
- """print arguments"""
- print('----------- Configuration Arguments -----------')
- for arg, value in sorted(vars(args).items()):
- print('%s: %s' % (arg, value))
- print('------------------------------------------------')
-
-
if __name__ == "__main__":
args = parse_args()
- print_arguments(args)
train(args)
diff --git a/ppstructure/vqa/utils.py b/ppstructure/vqa/utils.py
index a4ac1e77d37d0a662294480a393c2f67e7f4cc64..0af180ada2eae740c042378c73b884239ddbf7b9 100644
--- a/ppstructure/vqa/utils.py
+++ b/ppstructure/vqa/utils.py
@@ -24,8 +24,6 @@ import paddle
from PIL import Image, ImageDraw, ImageFont
-from paddleocr import PaddleOCR
-
def get_bio_label_maps(label_map_path):
with open(label_map_path, "r") as fin:
@@ -66,9 +64,9 @@ def get_image_file_list(img_file):
def draw_ser_results(image,
ocr_results,
- font_path="../doc/fonts/simfang.ttf",
+ font_path="../../doc/fonts/simfang.ttf",
font_size=18):
- np.random.seed(0)
+ np.random.seed(2021)
color = (np.random.permutation(range(255)),
np.random.permutation(range(255)),
np.random.permutation(range(255)))
@@ -82,38 +80,64 @@ def draw_ser_results(image,
draw = ImageDraw.Draw(img_new)
font = ImageFont.truetype(font_path, font_size, encoding="utf-8")
-
for ocr_info in ocr_results:
if ocr_info["pred_id"] not in color_map:
continue
color = color_map[ocr_info["pred_id"]]
-
- # draw ocr results outline
- bbox = ocr_info["bbox"]
- bbox = ((bbox[0], bbox[1]), (bbox[2], bbox[3]))
- draw.rectangle(bbox, fill=color)
-
- # draw ocr results
text = "{}: {}".format(ocr_info["pred"], ocr_info["text"])
- start_y = max(0, bbox[0][1] - font_size)
- tw = font.getsize(text)[0]
- draw.rectangle(
- [(bbox[0][0] + 1, start_y), (bbox[0][0] + tw + 1,
- start_y + font_size)],
- fill=(0, 0, 255))
- draw.text(
- (bbox[0][0] + 1, start_y), text, fill=(255, 255, 255), font=font)
+
+ draw_box_txt(ocr_info["bbox"], text, draw, font, font_size, color)
img_new = Image.blend(image, img_new, 0.5)
return np.array(img_new)
-def build_ocr_engine(rec_model_dir, det_model_dir):
- ocr_engine = PaddleOCR(
- rec_model_dir=rec_model_dir,
- det_model_dir=det_model_dir,
- use_angle_cls=False)
- return ocr_engine
+def draw_box_txt(bbox, text, draw, font, font_size, color):
+ # draw ocr results outline
+ bbox = ((bbox[0], bbox[1]), (bbox[2], bbox[3]))
+ draw.rectangle(bbox, fill=color)
+
+ # draw ocr results
+ start_y = max(0, bbox[0][1] - font_size)
+ tw = font.getsize(text)[0]
+ draw.rectangle(
+ [(bbox[0][0] + 1, start_y), (bbox[0][0] + tw + 1, start_y + font_size)],
+ fill=(0, 0, 255))
+ draw.text((bbox[0][0] + 1, start_y), text, fill=(255, 255, 255), font=font)
+
+
+def draw_re_results(image,
+ result,
+ font_path="../../doc/fonts/simfang.ttf",
+ font_size=18):
+ np.random.seed(0)
+ if isinstance(image, np.ndarray):
+ image = Image.fromarray(image)
+ img_new = image.copy()
+ draw = ImageDraw.Draw(img_new)
+
+ font = ImageFont.truetype(font_path, font_size, encoding="utf-8")
+ color_head = (0, 0, 255)
+ color_tail = (255, 0, 0)
+ color_line = (0, 255, 0)
+
+ for ocr_info_head, ocr_info_tail in result:
+ draw_box_txt(ocr_info_head["bbox"], ocr_info_head["text"], draw, font,
+ font_size, color_head)
+ draw_box_txt(ocr_info_tail["bbox"], ocr_info_tail["text"], draw, font,
+ font_size, color_tail)
+
+ center_head = (
+ (ocr_info_head['bbox'][0] + ocr_info_head['bbox'][2]) // 2,
+ (ocr_info_head['bbox'][1] + ocr_info_head['bbox'][3]) // 2)
+ center_tail = (
+ (ocr_info_tail['bbox'][0] + ocr_info_tail['bbox'][2]) // 2,
+ (ocr_info_tail['bbox'][1] + ocr_info_tail['bbox'][3]) // 2)
+
+ draw.line([center_head, center_tail], fill=color_line, width=5)
+
+ img_new = Image.blend(image, img_new, 0.5)
+ return np.array(img_new)
# pad sentences
@@ -130,7 +154,7 @@ def pad_sentences(tokenizer,
len(encoded_inputs["input_ids"]) // max_seq_len + 1) * max_seq_len
needs_to_be_padded = pad_to_max_seq_len and \
- max_seq_len and len(encoded_inputs["input_ids"]) < max_seq_len
+ max_seq_len and len(encoded_inputs["input_ids"]) < max_seq_len
if needs_to_be_padded:
difference = max_seq_len - len(encoded_inputs["input_ids"])
@@ -162,6 +186,9 @@ def split_page(encoded_inputs, max_seq_len=512):
truncate is often used in training process
"""
for key in encoded_inputs:
+ if key == 'entities':
+ encoded_inputs[key] = [encoded_inputs[key]]
+ continue
encoded_inputs[key] = paddle.to_tensor(encoded_inputs[key])
if encoded_inputs[key].ndim <= 1: # for input_ids, att_mask and so on
encoded_inputs[key] = encoded_inputs[key].reshape([-1, max_seq_len])
@@ -184,14 +211,14 @@ def preprocess(
height = ori_img.shape[0]
width = ori_img.shape[1]
- img = cv2.resize(ori_img,
- (224, 224)).transpose([2, 0, 1]).astype(np.float32)
+ img = cv2.resize(ori_img, img_size).transpose([2, 0, 1]).astype(np.float32)
segment_offset_id = []
words_list = []
bbox_list = []
input_ids_list = []
token_type_ids_list = []
+ entities = []
for info in ocr_info:
# x1, y1, x2, y2
@@ -211,6 +238,13 @@ def preprocess(
encode_res["token_type_ids"] = encode_res["token_type_ids"][1:-1]
encode_res["attention_mask"] = encode_res["attention_mask"][1:-1]
+ # for re
+ entities.append({
+ "start": len(input_ids_list),
+ "end": len(input_ids_list) + len(encode_res["input_ids"]),
+ "label": "O",
+ })
+
input_ids_list.extend(encode_res["input_ids"])
token_type_ids_list.extend(encode_res["token_type_ids"])
bbox_list.extend([bbox] * len(encode_res["input_ids"]))
@@ -222,6 +256,7 @@ def preprocess(
"token_type_ids": token_type_ids_list,
"bbox": bbox_list,
"attention_mask": [1] * len(input_ids_list),
+ "entities": entities
}
encoded_inputs = pad_sentences(
@@ -294,35 +329,64 @@ def merge_preds_list_with_ocr_info(ocr_info, segment_offset_id, preds_list,
return ocr_info
+def print_arguments(args, logger=None):
+ """print arguments"""
+ print_func = logger.info if logger is not None else print
+ print_func('----------- Configuration Arguments -----------')
+ for arg, value in sorted(vars(args).items()):
+ print_func('%s: %s' % (arg, value))
+ print_func('------------------------------------------------')
+
+
def parse_args():
parser = argparse.ArgumentParser()
# Required parameters
# yapf: disable
- parser.add_argument("--model_name_or_path", default=None, type=str, required=True,)
- parser.add_argument("--train_data_dir", default=None, type=str, required=False,)
- parser.add_argument("--train_label_path", default=None, type=str, required=False,)
- parser.add_argument("--eval_data_dir", default=None, type=str, required=False,)
- parser.add_argument("--eval_label_path", default=None, type=str, required=False,)
+ parser.add_argument("--model_name_or_path",
+ default=None, type=str, required=True,)
+ parser.add_argument("--re_model_name_or_path",
+ default=None, type=str, required=False,)
+ parser.add_argument("--train_data_dir", default=None,
+ type=str, required=False,)
+ parser.add_argument("--train_label_path", default=None,
+ type=str, required=False,)
+ parser.add_argument("--eval_data_dir", default=None,
+ type=str, required=False,)
+ parser.add_argument("--eval_label_path", default=None,
+ type=str, required=False,)
parser.add_argument("--output_dir", default=None, type=str, required=True,)
parser.add_argument("--max_seq_length", default=512, type=int,)
parser.add_argument("--evaluate_during_training", action="store_true",)
- parser.add_argument("--per_gpu_train_batch_size", default=8, type=int, help="Batch size per GPU/CPU for training.",)
- parser.add_argument("--per_gpu_eval_batch_size", default=8, type=int, help="Batch size per GPU/CPU for eval.",)
- parser.add_argument("--learning_rate", default=5e-5, type=float, help="The initial learning rate for Adam.",)
- parser.add_argument("--weight_decay", default=0.0, type=float, help="Weight decay if we apply some.",)
- parser.add_argument("--adam_epsilon", default=1e-8, type=float, help="Epsilon for Adam optimizer.",)
- parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.",)
- parser.add_argument("--num_train_epochs", default=3, type=int, help="Total number of training epochs to perform.",)
- parser.add_argument("--warmup_steps", default=0, type=int, help="Linear warmup over warmup_steps.",)
- parser.add_argument("--eval_steps", type=int, default=10, help="eval every X updates steps.",)
- parser.add_argument("--save_steps", type=int, default=50, help="Save checkpoint every X updates steps.",)
- parser.add_argument("--seed", type=int, default=2048, help="random seed for initialization",)
+ parser.add_argument("--per_gpu_train_batch_size", default=8,
+ type=int, help="Batch size per GPU/CPU for training.",)
+ parser.add_argument("--per_gpu_eval_batch_size", default=8,
+ type=int, help="Batch size per GPU/CPU for eval.",)
+ parser.add_argument("--learning_rate", default=5e-5,
+ type=float, help="The initial learning rate for Adam.",)
+ parser.add_argument("--weight_decay", default=0.0,
+ type=float, help="Weight decay if we apply some.",)
+ parser.add_argument("--adam_epsilon", default=1e-8,
+ type=float, help="Epsilon for Adam optimizer.",)
+ parser.add_argument("--max_grad_norm", default=1.0,
+ type=float, help="Max gradient norm.",)
+ parser.add_argument("--num_train_epochs", default=3, type=int,
+ help="Total number of training epochs to perform.",)
+ parser.add_argument("--warmup_steps", default=0, type=int,
+ help="Linear warmup over warmup_steps.",)
+ parser.add_argument("--eval_steps", type=int, default=10,
+ help="eval every X updates steps.",)
+ parser.add_argument("--save_steps", type=int, default=50,
+ help="Save checkpoint every X updates steps.",)
+ parser.add_argument("--seed", type=int, default=2048,
+ help="random seed for initialization",)
parser.add_argument("--ocr_rec_model_dir", default=None, type=str, )
parser.add_argument("--ocr_det_model_dir", default=None, type=str, )
- parser.add_argument("--label_map_path", default="./labels/labels_ser.txt", type=str, required=False, )
+ parser.add_argument(
+ "--label_map_path", default="./labels/labels_ser.txt", type=str, required=False, )
parser.add_argument("--infer_imgs", default=None, type=str, required=False)
- parser.add_argument("--ocr_json_path", default=None, type=str, required=False, help="ocr prediction results")
+ parser.add_argument("--ocr_json_path", default=None,
+ type=str, required=False, help="ocr prediction results")
# yapf: enable
args = parser.parse_args()
return args
diff --git a/test_tipc/configs/ch_PP-OCRv2_det_PACT/train_infer_python.txt b/test_tipc/configs/ch_PP-OCRv2_det_PACT/train_infer_python.txt
index b567c08185e084384c3883f1d602cec3f312ea53..1246e380c1c113e3c96e2b2962f28fd865a8717d 100644
--- a/test_tipc/configs/ch_PP-OCRv2_det_PACT/train_infer_python.txt
+++ b/test_tipc/configs/ch_PP-OCRv2_det_PACT/train_infer_python.txt
@@ -1,5 +1,5 @@
===========================train_params===========================
-model_name:PPOCRv2_ocr_det
+model_name:PPOCRv2_det
python:python3.7
gpu_list:0|0,1
Global.use_gpu:True|True
@@ -26,7 +26,7 @@ null:null
##
===========================infer_params===========================
Global.save_inference_dir:./output/
-Global.pretrained_model:
+Global.checkpoints:
norm_export:null
quant_export:deploy/slim/quantization/export_model.py -c configs/det/ch_PP-OCRv2/ch_PP-OCRv2_det_cml.yml -o
fpgm_export:
diff --git a/test_tipc/configs/ch_PP-OCRv2_rec/train_infer_python.txt b/test_tipc/configs/ch_PP-OCRv2_rec/train_infer_python.txt
index b61dc8bbe36ac5b21ec5f3561d39997f992d6c58..4607b0a7f5d2ffb082ecb84d80b3534d75e14f5f 100644
--- a/test_tipc/configs/ch_PP-OCRv2_rec/train_infer_python.txt
+++ b/test_tipc/configs/ch_PP-OCRv2_rec/train_infer_python.txt
@@ -6,7 +6,7 @@ Global.use_gpu:True|True
Global.auto_cast:fp32
Global.epoch_num:lite_train_lite_infer=3|whole_train_whole_infer=300
Global.save_model_dir:./output/
-Train.loader.batch_size_per_card:lite_train_lite_infer=128|whole_train_whole_infer=128
+Train.loader.batch_size_per_card:lite_train_lite_infer=16|whole_train_whole_infer=128
Global.pretrained_model:null
train_model_name:latest
train_infer_img_dir:./inference/rec_inference
@@ -34,7 +34,7 @@ distill_export:null
export1:null
export2:null
inference_dir:Student
-infer_model:./inference/ch_PP-OCRv2_rec_infer/
+infer_model:./inference/ch_PP-OCRv2_rec_infer
infer_export:null
infer_quant:False
inference:tools/infer/predict_rec.py
@@ -45,7 +45,7 @@ inference:tools/infer/predict_rec.py
--use_tensorrt:False|True
--precision:fp32|fp16|int8
--rec_model_dir:
---image_dir:/inference/rec_inference
+--image_dir:./inference/rec_inference
null:null
--benchmark:True
null:null
diff --git a/test_tipc/configs/ch_PP-OCRv2_rec_PACT/train_infer_python.txt b/test_tipc/configs/ch_PP-OCRv2_rec_PACT/train_infer_python.txt
index 914c1bc7575dfee3309493b9110afe8b9cb7e59b..6127896ae29dc5f4d2813e84824cda5fa0bac7ca 100644
--- a/test_tipc/configs/ch_PP-OCRv2_rec_PACT/train_infer_python.txt
+++ b/test_tipc/configs/ch_PP-OCRv2_rec_PACT/train_infer_python.txt
@@ -6,15 +6,15 @@ Global.use_gpu:True|True
Global.auto_cast:fp32
Global.epoch_num:lite_train_lite_infer=3|whole_train_whole_infer=300
Global.save_model_dir:./output/
-Train.loader.batch_size_per_card:lite_train_lite_infer=128|whole_train_whole_infer=128
+Train.loader.batch_size_per_card:lite_train_lite_infer=16|whole_train_whole_infer=128
Global.pretrained_model:null
train_model_name:latest
train_infer_img_dir:./inference/rec_inference
null:null
##
trainer:pact_train
-norm_train:deploy/slim/quantization/quant.py -c test_tipc/configs/ch_PP-OCRv2_rec/ch_PP-OCRv2_rec_distillation.yml -o
-pact_train:null
+norm_train:null
+pact_train:deploy/slim/quantization/quant.py -c test_tipc/configs/ch_PP-OCRv2_rec/ch_PP-OCRv2_rec_distillation.yml -o
fpgm_train:null
distill_train:null
null:null
@@ -27,14 +27,14 @@ null:null
===========================infer_params===========================
Global.save_inference_dir:./output/
Global.pretrained_model:
-norm_export:deploy/slim/quantization/export_model.py -c test_tipc/configs/ch_PP-OCRv2_rec/ch_PP-OCRv2_rec_distillation.yml -o
-quant_export:
-fpgm_export:
+norm_export:null
+quant_export:deploy/slim/quantization/export_model.py -c test_tipc/configs/ch_PP-OCRv2_rec/ch_PP-OCRv2_rec_distillation.yml -o
+fpgm_export:null
distill_export:null
export1:null
export2:null
inference_dir:Student
-infer_model:./inference/ch_PP-OCRv2_rec_infer/
+infer_model:./inference/ch_PP-OCRv2_rec_infer
infer_export:null
infer_quant:True
inference:tools/infer/predict_rec.py
@@ -45,7 +45,7 @@ inference:tools/infer/predict_rec.py
--use_tensorrt:False|True
--precision:fp32|fp16|int8
--rec_model_dir:
---image_dir:/inference/rec_inference
+--image_dir:./inference/rec_inference
null:null
--benchmark:True
null:null
diff --git a/test_tipc/configs/ch_ppocr_mobile_v2.0_det/train_infer_python.txt b/test_tipc/configs/ch_ppocr_mobile_v2.0_det/train_infer_python.txt
index 977312f2a49e76d92e4edc11f8f0d3ecf866999a..9a5dd76437b236389f9880fdc1726e18e2cafee4 100644
--- a/test_tipc/configs/ch_ppocr_mobile_v2.0_det/train_infer_python.txt
+++ b/test_tipc/configs/ch_ppocr_mobile_v2.0_det/train_infer_python.txt
@@ -4,7 +4,7 @@ python:python3.7
gpu_list:0|0,1
Global.use_gpu:True|True
Global.auto_cast:null
-Global.epoch_num:lite_train_lite_infer=5|whole_train_whole_infer=300
+Global.epoch_num:lite_train_lite_infer=100|whole_train_whole_infer=300
Global.save_model_dir:./output/
Train.loader.batch_size_per_card:lite_train_lite_infer=2|whole_train_whole_infer=4
Global.pretrained_model:null
diff --git a/test_tipc/configs/ch_ppocr_mobile_v2.0_det_PACT/train_infer_python.txt b/test_tipc/configs/ch_ppocr_mobile_v2.0_det_PACT/train_infer_python.txt
index 8a6c6568584250d269acfe63aef43ef66410fd99..05cde05467d75769965ee23bce2cebfc20408251 100644
--- a/test_tipc/configs/ch_ppocr_mobile_v2.0_det_PACT/train_infer_python.txt
+++ b/test_tipc/configs/ch_ppocr_mobile_v2.0_det_PACT/train_infer_python.txt
@@ -4,7 +4,7 @@ python:python3.7
gpu_list:0|0,1
Global.use_gpu:True|True
Global.auto_cast:null
-Global.epoch_num:lite_train_lite_infer=5|whole_train_whole_infer=300
+Global.epoch_num:lite_train_lite_infer=20|whole_train_whole_infer=300
Global.save_model_dir:./output/
Train.loader.batch_size_per_card:lite_train_lite_infer=2|whole_train_whole_infer=4
Global.pretrained_model:null
@@ -26,7 +26,7 @@ null:null
##
===========================infer_params===========================
Global.save_inference_dir:./output/
-Global.pretrained_model:
+Global.checkpoints:
norm_export:null
quant_export:deploy/slim/quantization/export_model.py -c configs/det/ch_ppocr_v2.0/ch_det_mv3_db_v2.0.yml -o
fpgm_export:null
@@ -48,4 +48,4 @@ inference:tools/infer/predict_det.py
--image_dir:./inference/ch_det_data_50/all-sum-510/
null:null
--benchmark:True
-null:null
\ No newline at end of file
+null:null
diff --git a/test_tipc/configs/ch_ppocr_mobile_v2.0_rec_PACT/train_infer_python.txt b/test_tipc/configs/ch_ppocr_mobile_v2.0_rec_PACT/train_infer_python.txt
index 7bbdd58ae13eca00623123cf2ca39d3b76daa72a..56b9e1896c2a1e9a7ab002884cfbc5de86997535 100644
--- a/test_tipc/configs/ch_ppocr_mobile_v2.0_rec_PACT/train_infer_python.txt
+++ b/test_tipc/configs/ch_ppocr_mobile_v2.0_rec_PACT/train_infer_python.txt
@@ -28,7 +28,7 @@ null:null
Global.save_inference_dir:./output/
Global.checkpoints:
norm_export:null
-quant_export:deploy/slim/quantization/export_model.py -ctest_tipc/configs/ch_ppocr_mobile_v2.0_rec_PACT/rec_chinese_lite_train_v2.0.yml -o
+quant_export:deploy/slim/quantization/export_model.py -c test_tipc/configs/ch_ppocr_mobile_v2.0_rec_PACT/rec_chinese_lite_train_v2.0.yml -o
fpgm_export:null
distill_export:null
export1:null
diff --git a/test_tipc/configs/ch_ppocr_server_v2.0_det/train_infer_python.txt b/test_tipc/configs/ch_ppocr_server_v2.0_det/train_infer_python.txt
index bea918a7f366548056d7d62a5785353a4e689d01..ca52eeb1bc6a1853fa7015478fb9028d8dec71c3 100644
--- a/test_tipc/configs/ch_ppocr_server_v2.0_det/train_infer_python.txt
+++ b/test_tipc/configs/ch_ppocr_server_v2.0_det/train_infer_python.txt
@@ -12,22 +12,22 @@ train_model_name:latest
train_infer_img_dir:./train_data/icdar2015/text_localization/ch4_test_images/
null:null
##
-trainer:norm_train|pact_train|fpgm_export
-norm_train:tools/train.py -c test_tipc/configs/ppocr_det_server/det_r50_vd_db.yml -o
-quant_export:deploy/slim/quantization/export_model.py -c test_tipc/configs/ppocr_det_server/det_r50_vd_db.yml -o
-fpgm_export:deploy/slim/prune/export_prune_model.py -c test_tipc/configs/ppocr_det_server/det_r50_vd_db.yml -o
+trainer:norm_train
+norm_train:tools/train.py -c test_tipc/configs/ch_ppocr_server_v2.0_det/det_r50_vd_db.yml -o
+quant_train:null
+fpgm_train:null
distill_train:null
null:null
null:null
##
===========================eval_params===========================
-eval:tools/eval.py -c test_tipc/configs/ppocr_det_server/det_r50_vd_db.yml -o
+eval:tools/eval.py -c test_tipc/configs/ch_ppocr_server_v2.0_det/det_r50_vd_db.yml -o
null:null
##
===========================infer_params===========================
Global.save_inference_dir:./output/
Global.pretrained_model:
-norm_export:tools/export_model.py -c test_tipc/configs/ppocr_det_server/det_r50_vd_db.yml -o
+norm_export:tools/export_model.py -c test_tipc/configs/ch_ppocr_server_v2.0_det/det_r50_vd_db.yml -o
quant_export:null
fpgm_export:null
distill_export:null
diff --git a/test_tipc/configs/det_r50_vd_pse_v2.0/train_infer_python.txt b/test_tipc/configs/det_r50_vd_pse_v2.0/train_infer_python.txt
index 5ab6d45d7c1eb5e3c17fd53a8c8c504812c1012c..c60f4263ebc734acf3136a6542bb9e882658af2b 100644
--- a/test_tipc/configs/det_r50_vd_pse_v2.0/train_infer_python.txt
+++ b/test_tipc/configs/det_r50_vd_pse_v2.0/train_infer_python.txt
@@ -35,7 +35,7 @@ export1:null
export2:null
##
train_model:./inference/det_r50_vd_pse/best_accuracy
-infer_export:tools/export_model.py -c test_tipc/cconfigs/det_r50_vd_pse_v2.0/det_r50_vd_pse.yml -o
+infer_export:tools/export_model.py -c test_tipc/configs/det_r50_vd_pse_v2.0/det_r50_vd_pse.yml -o
infer_quant:False
inference:tools/infer/predict_det.py
--use_gpu:True|False
diff --git a/test_tipc/configs/det_r50_vd_sast_icdar15_v2.0/det_r50_vd_sast_icdar2015.yml b/test_tipc/configs/det_r50_vd_sast_icdar15_v2.0/det_r50_vd_sast_icdar2015.yml
index 8e9315d2488ad187eb12708d094c5be57cb48eac..4b7340ac59851aa54effa49f73196ad863d02a95 100644
--- a/test_tipc/configs/det_r50_vd_sast_icdar15_v2.0/det_r50_vd_sast_icdar2015.yml
+++ b/test_tipc/configs/det_r50_vd_sast_icdar15_v2.0/det_r50_vd_sast_icdar2015.yml
@@ -62,7 +62,7 @@ Train:
data_dir: ./train_data/icdar2015/text_localization/
label_file_list:
- ./train_data/icdar2015/text_localization/train_icdar2015_label.txt
- ratio_list: [0.1, 0.45, 0.3, 0.15]
+ ratio_list: [1.0]
transforms:
- DecodeImage: # load image
img_mode: BGR
diff --git a/test_tipc/configs/det_r50_vd_sast_icdar15_v2.0/train_infer_python.txt b/test_tipc/configs/det_r50_vd_sast_icdar15_v2.0/train_infer_python.txt
index d9f15dded4b920cb93b2180aeb9e14e93ebab5cc..e6fb2ca5b459d26cd4b099c17f81bb47cc59bc71 100644
--- a/test_tipc/configs/det_r50_vd_sast_icdar15_v2.0/train_infer_python.txt
+++ b/test_tipc/configs/det_r50_vd_sast_icdar15_v2.0/train_infer_python.txt
@@ -48,4 +48,4 @@ inference:tools/infer/predict_det.py
--image_dir:./inference/ch_det_data_50/all-sum-510/
null:null
--benchmark:True
-null:null
+--det_algorithm:SAST
diff --git a/test_tipc/configs/det_r50_vd_sast_totaltext_v2.0/train_infer_python.txt b/test_tipc/configs/det_r50_vd_sast_totaltext_v2.0/train_infer_python.txt
index 602254f2f3b7eb6f5b1fc72fbaf212fbea43ca49..2387ba7b5e9bac09b4c85fa5273d0c6ba5bebcb5 100644
--- a/test_tipc/configs/det_r50_vd_sast_totaltext_v2.0/train_infer_python.txt
+++ b/test_tipc/configs/det_r50_vd_sast_totaltext_v2.0/train_infer_python.txt
@@ -48,4 +48,4 @@ inference:tools/infer/predict_det.py
--image_dir:./inference/ch_det_data_50/all-sum-510/
null:null
--benchmark:True
-null:null
+--det_algorithm:SAST
diff --git a/test_tipc/configs/en_server_pgnetA/train_infer_python.txt b/test_tipc/configs/en_server_pgnetA/train_infer_python.txt
index d70776998c4e326905920586e90f2833fe42e89b..1a25eccb3a192823d58af1c6cf089ea15b6d394c 100644
--- a/test_tipc/configs/en_server_pgnetA/train_infer_python.txt
+++ b/test_tipc/configs/en_server_pgnetA/train_infer_python.txt
@@ -4,7 +4,7 @@ python:python3.7
gpu_list:0|0,1
Global.use_gpu:True|True
Global.auto_cast:null
-Global.epoch_num:lite_train_lite_infer=1|whole_train_whole_infer=500
+Global.epoch_num:lite_train_lite_infer=5|whole_train_whole_infer=500
Global.save_model_dir:./output/
Train.loader.batch_size_per_card:lite_train_lite_infer=2|whole_train_whole_infer=14
Global.pretrained_model:null
@@ -42,7 +42,7 @@ inference:tools/infer/predict_e2e.py
--enable_mkldnn:True|False
--cpu_threads:1|6
--rec_batch_num:1
---use_tensorrt:False|True
+--use_tensorrt:False
--precision:fp32|fp16|int8
--e2e_model_dir:
--image_dir:./inference/ch_det_data_50/all-sum-510/
diff --git a/test_tipc/configs/rec_mtb_nrtr/train_infer_python.txt b/test_tipc/configs/rec_mtb_nrtr/train_infer_python.txt
index 67630d858c7633daf8e1800b1ab10adb86e6c3bc..695fc8a42ef0f6b79901e8b62ce09d72e3500793 100644
--- a/test_tipc/configs/rec_mtb_nrtr/train_infer_python.txt
+++ b/test_tipc/configs/rec_mtb_nrtr/train_infer_python.txt
@@ -6,7 +6,7 @@ Global.use_gpu:True|True
Global.auto_cast:null
Global.epoch_num:lite_train_lite_infer=2|whole_train_whole_infer=300
Global.save_model_dir:./output/
-Train.loader.batch_size_per_card:lite_train_lite_infer=128|whole_train_whole_infer=128
+Train.loader.batch_size_per_card:lite_train_lite_infer=16|whole_train_whole_infer=64
Global.pretrained_model:null
train_model_name:latest
train_infer_img_dir:./inference/rec_inference
diff --git a/test_tipc/configs/rec_mv3_tps_bilstm_att_v2.0/train_infer_python.txt b/test_tipc/configs/rec_mv3_tps_bilstm_att_v2.0/train_infer_python.txt
index 3791aa17b2b5a16565ab3456932e43fd77254472..18504d068740deeec42cf9620c2d9e816d88c5cc 100644
--- a/test_tipc/configs/rec_mv3_tps_bilstm_att_v2.0/train_infer_python.txt
+++ b/test_tipc/configs/rec_mv3_tps_bilstm_att_v2.0/train_infer_python.txt
@@ -6,7 +6,7 @@ Global.use_gpu:True|True
Global.auto_cast:null
Global.epoch_num:lite_train_lite_infer=2|whole_train_whole_infer=300
Global.save_model_dir:./output/
-Train.loader.batch_size_per_card:lite_train_lite_infer=128|whole_train_whole_infer=128
+Train.loader.batch_size_per_card:lite_train_lite_infer=16|whole_train_whole_infer=64
Global.pretrained_model:null
train_model_name:latest
train_infer_img_dir:./inference/rec_inference
diff --git a/test_tipc/configs/rec_mv3_tps_bilstm_ctc_v2.0/train_infer_python.txt b/test_tipc/configs/rec_mv3_tps_bilstm_ctc_v2.0/train_infer_python.txt
index 33700ad696394ad9404a5424cddf93608220917a..3bec644ced183fff4329ff08991a137c45bacfc9 100644
--- a/test_tipc/configs/rec_mv3_tps_bilstm_ctc_v2.0/train_infer_python.txt
+++ b/test_tipc/configs/rec_mv3_tps_bilstm_ctc_v2.0/train_infer_python.txt
@@ -37,7 +37,7 @@ export2:null
infer_model:null
infer_export:tools/export_model.py -c test_tipc/configs/rec_mv3_tps_bilstm_ctc_v2.0/rec_icdar15_train.yml -o
infer_quant:False
-inference:tools/infer/predict_rec.py --rec_char_dict_path=./ppocr/utils/ic15_dict.txt --rec_image_shape="3,32,100"
+inference:tools/infer/predict_rec.py --rec_char_dict_path=./ppocr/utils/ic15_dict.txt --rec_image_shape="3,32,100" --rec_algorithm="StarNet"
--use_gpu:True|False
--enable_mkldnn:True|False
--cpu_threads:1|6
diff --git a/test_tipc/configs/rec_r31_sar/train_infer_python.txt b/test_tipc/configs/rec_r31_sar/train_infer_python.txt
index 5cc31b7b8b793e7c82f6676f1fec9a5e8b2393f4..42dfc6b0275c05aef358682d031275488893e5fb 100644
--- a/test_tipc/configs/rec_r31_sar/train_infer_python.txt
+++ b/test_tipc/configs/rec_r31_sar/train_infer_python.txt
@@ -6,7 +6,7 @@ Global.use_gpu:True|True
Global.auto_cast:null
Global.epoch_num:lite_train_lite_infer=2|whole_train_whole_infer=300
Global.save_model_dir:./output/
-Train.loader.batch_size_per_card:lite_train_lite_infer=128|whole_train_whole_infer=128
+Train.loader.batch_size_per_card:lite_train_lite_infer=16|whole_train_whole_infer=64
Global.pretrained_model:null
train_model_name:latest
train_infer_img_dir:./inference/rec_inference
diff --git a/test_tipc/configs/rec_r34_vd_tps_bilstm_att_v2.0/train_infer_python.txt b/test_tipc/configs/rec_r34_vd_tps_bilstm_att_v2.0/train_infer_python.txt
index e816868f33de7ca8794068e8498f6f7845df0324..84bda52480118f84ec5efbc1d4831950b1cdee68 100644
--- a/test_tipc/configs/rec_r34_vd_tps_bilstm_att_v2.0/train_infer_python.txt
+++ b/test_tipc/configs/rec_r34_vd_tps_bilstm_att_v2.0/train_infer_python.txt
@@ -6,7 +6,7 @@ Global.use_gpu:True|True
Global.auto_cast:null
Global.epoch_num:lite_train_lite_infer=2|whole_train_whole_infer=300
Global.save_model_dir:./output/
-Train.loader.batch_size_per_card:lite_train_lite_infer=128|whole_train_whole_infer=128
+Train.loader.batch_size_per_card:lite_train_lite_infer=16|whole_train_whole_infer=64
Global.pretrained_model:null
train_model_name:latest
train_infer_img_dir:./inference/rec_inference
diff --git a/test_tipc/configs/rec_r34_vd_tps_bilstm_ctc_v2.0/train_infer_python.txt b/test_tipc/configs/rec_r34_vd_tps_bilstm_ctc_v2.0/train_infer_python.txt
index bb49ae5977208b2921f4a825b62afa7935f572f1..ac43bd9703d7744220af40fa36b29adf64e89334 100644
--- a/test_tipc/configs/rec_r34_vd_tps_bilstm_ctc_v2.0/train_infer_python.txt
+++ b/test_tipc/configs/rec_r34_vd_tps_bilstm_ctc_v2.0/train_infer_python.txt
@@ -37,7 +37,7 @@ export2:null
infer_model:null
infer_export:tools/export_model.py -c test_tipc/configs/rec_r34_vd_tps_bilstm_ctc_v2.0/rec_icdar15_train.yml -o
infer_quant:False
-inference:tools/infer/predict_rec.py --rec_char_dict_path=./ppocr/utils/ic15_dict.txt --rec_image_shape="3,32,100"
+inference:tools/infer/predict_rec.py --rec_char_dict_path=./ppocr/utils/ic15_dict.txt --rec_image_shape="3,32,100" --rec_algorithm="StarNet"
--use_gpu:True|False
--enable_mkldnn:True|False
--cpu_threads:1|6
diff --git a/test_tipc/configs/rec_r50_fpn_vd_none_srn/train_infer_python.txt b/test_tipc/configs/rec_r50_fpn_vd_none_srn/train_infer_python.txt
index b3549c635f267cdb0b494341e9f250669cd74bfe..55b25122e3d934ae66051595cc0bdc75aa3386fc 100644
--- a/test_tipc/configs/rec_r50_fpn_vd_none_srn/train_infer_python.txt
+++ b/test_tipc/configs/rec_r50_fpn_vd_none_srn/train_infer_python.txt
@@ -6,7 +6,7 @@ Global.use_gpu:True|True
Global.auto_cast:null
Global.epoch_num:lite_train_lite_infer=2|whole_train_whole_infer=300
Global.save_model_dir:./output/
-Train.loader.batch_size_per_card:lite_train_lite_infer=128|whole_train_whole_infer=128
+Train.loader.batch_size_per_card:lite_train_lite_infer=16|whole_train_whole_infer=64
Global.pretrained_model:null
train_model_name:latest
train_infer_img_dir:./inference/rec_inference
diff --git a/test_tipc/prepare.sh b/test_tipc/prepare.sh
index 8876157ef8f4b44b227c171d25bdfd1060007910..71d4010f4b2c3abe698e22b7e1e8f33e9ef9d45f 100644
--- a/test_tipc/prepare.sh
+++ b/test_tipc/prepare.sh
@@ -25,7 +25,7 @@ if [ ${MODE} = "lite_train_lite_infer" ];then
# pretrain lite train data
wget -nc -P ./pretrain_models/ https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/MobileNetV3_large_x0_5_pretrained.pdparams --no-check-certificate
wget -nc -P ./pretrain_models/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/det_mv3_db_v2.0_train.tar --no-check-certificate
- if [ ${model_name} == "ch_PPOCRv2_det" ]; then
+ if [[ ${model_name} =~ "PPOCRv2_det" ]];then
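+ # substring match so the PACT det variant shares the same pretrain weights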
wget -nc -P ./pretrain_models/ https://paddleocr.bj.bcebos.com/PP-OCRv2/chinese/ch_PP-OCRv2_det_distill_train.tar --no-check-certificate
cd ./pretrain_models/ && tar xf ch_PP-OCRv2_det_distill_train.tar && cd ../
fi
@@ -49,8 +49,8 @@ if [ ${MODE} = "lite_train_lite_infer" ];then
fi
if [ ${model_name} == "det_r50_vd_sast_icdar15_v2.0" ] || [ ${model_name} == "det_r50_vd_sast_totaltext_v2.0" ]; then
wget -nc -P ./pretrain_models/ https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/ResNet50_vd_ssld_pretrained.pdparams --no-check-certificate
- wget -nc -P ./train_data/ wget -nc -P ./train_data/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/test/total_text_lite.tar --no-check-certificate
- cd ./train_data && tar xf total_text_lite.tar && ln -s total_text && cd ../
+ wget -nc -P ./train_data/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/test/total_text_lite.tar --no-check-certificate
+ cd ./train_data && tar xf total_text_lite.tar && ln -s total_text_lite total_text && cd ../
fi
if [ ${model_name} == "det_mv3_db_v2.0" ]; then
wget -nc -P ./inference/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/det_mv3_db_v2.0_train.tar --no-check-certificate
@@ -78,15 +78,15 @@ elif [ ${MODE} = "whole_train_whole_infer" ];then
cd ./pretrain_models/ && tar xf ch_PP-OCRv2_det_distill_train.tar && cd ../
fi
if [ ${model_name} == "en_server_pgnetA" ]; then
- wget -nc -P ./train_data/ https://paddleocr.bj.bcebos.com/dataset/total_text.tar --no-check-certificate
+ wget -nc -P ./train_data/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/test/total_text_lite.tar --no-check-certificate
wget -nc -P ./pretrain_models/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/pgnet/en_server_pgnetA.tar --no-check-certificate
cd ./pretrain_models/ && tar xf en_server_pgnetA.tar && cd ../
- cd ./train_data && tar xf total_text.tar && ln -s total_text && cd ../
+ cd ./train_data && tar xf total_text_lite.tar && ln -s total_text_lite total_text && cd ../
fi
if [ ${model_name} == "det_r50_vd_sast_totaltext_v2.0" ]; then
wget -nc -P ./pretrain_models/ https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/ResNet50_vd_ssld_pretrained.pdparams --no-check-certificate
- wget -nc -P ./train_data/ https://paddleocr.bj.bcebos.com/dataset/total_text.tar --no-check-certificate
- cd ./train_data && tar xf total_text.tar && ln -s total_text && cd ../
+ wget -nc -P ./train_data/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/test/total_text_lite.tar --no-check-certificate
+ cd ./train_data && tar xf total_text_lite.tar && ln -s total_text_lite total_text && cd ../
fi
elif [ ${MODE} = "lite_train_whole_infer" ];then
wget -nc -P ./pretrain_models/ https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/MobileNetV3_large_x0_5_pretrained.pdparams --no-check-certificate
@@ -103,59 +103,67 @@ elif [ ${MODE} = "lite_train_whole_infer" ];then
fi
elif [ ${MODE} = "whole_infer" ];then
wget -nc -P ./inference https://paddleocr.bj.bcebos.com/dygraph_v2.0/test/ch_det_data_50.tar --no-check-certificate
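+ # rec_inference data is shared by the rec models below, so fetch it once here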
+ wget -nc -P ./inference/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/test/rec_inference.tar --no-check-certificate
+ cd ./inference && tar xf rec_inference.tar && cd ../
if [ ${model_name} = "ch_ppocr_mobile_v2.0_det" ]; then
eval_model_name="ch_ppocr_mobile_v2.0_det_train"
rm -rf ./train_data/icdar2015
- wget -nc -P ./inference https://paddleocr.bj.bcebos.com/dygraph_v2.0/test/ch_det_data_50.tar --no-check-certificate
wget -nc -P ./inference https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_det_train.tar --no-check-certificate
wget -nc -P ./inference https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_det_infer.tar --no-check-certificate
cd ./inference && tar xf ${eval_model_name}.tar && tar xf ch_det_data_50.tar && tar xf ch_ppocr_mobile_v2.0_det_infer.tar && cd ../
elif [ ${model_name} = "ch_ppocr_server_v2.0_det" ]; then
wget -nc -P ./inference https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_server_v2.0_det_train.tar --no-check-certificate
- wget -nc -P ./inference https://paddleocr.bj.bcebos.com/dygraph_v2.0/test/ch_det_data_50.tar --no-check-certificate
cd ./inference && tar xf ch_ppocr_server_v2.0_det_train.tar && tar xf ch_det_data_50.tar && cd ../
elif [ ${model_name} = "ch_ppocr_mobile_v2.0" ]; then
wget -nc -P ./inference https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_det_infer.tar --no-check-certificate
- wget -nc -P ./inference https://paddleocr.bj.bcebos.com/dygraph_v2.0/test/ch_det_data_50.tar --no-check-certificate
wget -nc -P ./inference https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_rec_infer.tar --no-check-certificate
cd ./inference && tar xf ch_ppocr_mobile_v2.0_det_infer.tar && tar xf ch_ppocr_mobile_v2.0_rec_infer.tar && tar xf ch_det_data_50.tar && cd ../
elif [ ${model_name} = "ch_ppocr_server_v2.0" ]; then
wget -nc -P ./inference https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_server_v2.0_det_infer.tar --no-check-certificate
- wget -nc -P ./inference https://paddleocr.bj.bcebos.com/dygraph_v2.0/test/ch_det_data_50.tar --no-check-certificate
wget -nc -P ./inference https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_server_v2.0_rec_infer.tar --no-check-certificate
cd ./inference && tar xf ch_ppocr_server_v2.0_det_infer.tar && tar xf ch_ppocr_server_v2.0_rec_infer.tar && tar xf ch_det_data_50.tar && cd ../
elif [ ${model_name} = "ch_ppocr_mobile_v2.0_rec" ]; then
eval_model_name="ch_ppocr_mobile_v2.0_rec_infer"
- wget -nc -P ./inference/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/test/rec_inference.tar --no-check-certificate
wget -nc -P ./inference https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_rec_infer.tar --no-check-certificate
- cd ./inference && tar xf ${eval_model_name}.tar && tar xf rec_inference.tar && cd ../
+ cd ./inference && tar xf ${eval_model_name}.tar && cd ../
elif [ ${model_name} = "ch_ppocr_server_v2.0_rec" ]; then
eval_model_name="ch_ppocr_server_v2.0_rec_infer"
- wget -nc -P ./inference/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/test/rec_inference.tar --no-check-certificate
wget -nc -P ./inference https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_server_v2.0_rec_infer.tar --no-check-certificate
- cd ./inference && tar xf ${eval_model_name}.tar && tar xf rec_inference.tar && cd ../
+ cd ./inference && tar xf ${eval_model_name}.tar && cd ../
+ elif [ ${model_name} = "ch_ppocr_mobile_v2.0_rec_PACT" ]; then
+ eval_model_name="ch_PP-OCRv2_rec_slim_quant_train"
+ wget -nc -P ./inference https://paddleocr.bj.bcebos.com/PP-OCRv2/chinese/ch_PP-OCRv2_rec_slim_quant_train.tar --no-check-certificate
+ cd ./inference && tar xf ${eval_model_name}.tar && cd ../
+ elif [ ${model_name} = "ch_ppocr_mobile_v2.0_rec_FPGM" ]; then
+ eval_model_name="ch_PP-OCRv2_rec_train"
+ wget -nc -P ./inference https://paddleocr.bj.bcebos.com/PP-OCRv2/chinese/ch_PP-OCRv2_rec_train.tar --no-check-certificate
+ cd ./inference && tar xf ${eval_model_name}.tar && cd ../
fi
- if [ ${model_name} = "ch_PPOCRv2_det" ]; then
+ if [[ ${model_name} =~ "ch_PPOCRv2_det" ]]; then
eval_model_name="ch_PP-OCRv2_det_infer"
- wget -nc -P ./inference https://paddleocr.bj.bcebos.com/dygraph_v2.0/test/ch_det_data_50.tar --no-check-certificate
wget -nc -P ./inference/ https://paddleocr.bj.bcebos.com/PP-OCRv2/chinese/ch_PP-OCRv2_det_infer.tar --no-check-certificate
cd ./inference && tar xf ${eval_model_name}.tar && tar xf ch_det_data_50.tar && cd ../
fi
+ if [[ ${model_name} =~ "PPOCRv2_ocr_rec" ]]; then
+ eval_model_name="ch_PP-OCRv2_rec_infer"
+ wget -nc -P ./inference/ https://paddleocr.bj.bcebos.com/PP-OCRv2/chinese/ch_PP-OCRv2_rec_infer.tar --no-check-certificate
+ cd ./inference && tar xf ${eval_model_name}.tar && cd ../
+ fi
if [ ${model_name} == "en_server_pgnetA" ]; then
wget -nc -P ./inference/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/pgnet/en_server_pgnetA.tar --no-check-certificate
- cd ./inference && tar xf en_server_pgnetA.tar && cd ../
+ cd ./inference && tar xf en_server_pgnetA.tar && tar xf ch_det_data_50.tar && cd ../
fi
if [ ${model_name} == "det_r50_vd_sast_icdar15_v2.0" ]; then
wget -nc -P ./inference/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/det_r50_vd_sast_icdar15_v2.0_train.tar --no-check-certificate
- cd ./inference/ && tar xf det_r50_vd_sast_icdar15_v2.0_train.tar && cd ../
+ cd ./inference/ && tar xf det_r50_vd_sast_icdar15_v2.0_train.tar && tar xf ch_det_data_50.tar && cd ../
fi
if [ ${model_name} == "det_mv3_db_v2.0" ]; then
wget -nc -P ./inference/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/det_mv3_db_v2.0_train.tar --no-check-certificate
- cd ./inference/ && tar xf det_mv3_db_v2.0_train.tar && cd ../
+ cd ./inference/ && tar xf det_mv3_db_v2.0_train.tar && tar xf ch_det_data_50.tar && cd ../
fi
if [ ${model_name} == "det_r50_db_v2.0" ]; then
wget -nc -P ./inference/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/det_r50_vd_db_v2.0_train.tar --no-check-certificate
- cd ./inference/ && tar xf det_r50_vd_db_v2.0_train.tar && cd ../
+ cd ./inference/ && tar xf det_r50_vd_db_v2.0_train.tar && tar xf ch_det_data_50.tar && cd ../
fi
fi
if [ ${MODE} = "klquant_whole_infer" ]; then
diff --git a/test_tipc/test_inference_cpp.sh b/test_tipc/test_inference_cpp.sh
index d26954353ef1e81ae49364b7f9d20357768cff85..4787f83093b0040ae3da6d9efb9028d0cc28de00 100644
--- a/test_tipc/test_inference_cpp.sh
+++ b/test_tipc/test_inference_cpp.sh
@@ -64,10 +64,11 @@ function func_cpp_inference(){
set_infer_data=$(func_set_params "${cpp_image_dir_key}" "${_img_dir}")
set_benchmark=$(func_set_params "${cpp_benchmark_key}" "${cpp_benchmark_value}")
set_batchsize=$(func_set_params "${cpp_batch_size_key}" "${batch_size}")
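+ # func_set_params emits nothing for a null key/value, so the mkldnn flag drops out cleanly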
+ set_mkldnn=$(func_set_params "${cpp_use_mkldnn_key}" "${use_mkldnn}")
set_cpu_threads=$(func_set_params "${cpp_cpu_threads_key}" "${threads}")
set_model_dir=$(func_set_params "${cpp_infer_model_key}" "${_model_dir}")
set_infer_params1=$(func_set_params "${cpp_infer_key1}" "${cpp_infer_value1}")
- command="${_script} ${cpp_use_gpu_key}=${use_gpu} ${cpp_use_mkldnn_key}=${use_mkldnn} ${set_cpu_threads} ${set_model_dir} ${set_batchsize} ${set_infer_data} ${set_benchmark} ${set_infer_params1} > ${_save_log_path} 2>&1 "
+ command="${_script} ${cpp_use_gpu_key}=${use_gpu} ${set_mkldnn} ${set_cpu_threads} ${set_model_dir} ${set_batchsize} ${set_infer_data} ${set_benchmark} ${set_infer_params1} > ${_save_log_path} 2>&1 "
eval $command
last_status=${PIPESTATUS[0]}
eval "cat ${_save_log_path}"
diff --git a/test_tipc/test_inference_python.sh b/test_tipc/test_inference_python.sh
index 72516e044ed8a23c660a4c4f486d19f22a584fb0..27276d55b95051e167432600308f42127d784ee6 100644
--- a/test_tipc/test_inference_python.sh
+++ b/test_tipc/test_inference_python.sh
@@ -79,11 +79,12 @@ function func_inference(){
set_infer_data=$(func_set_params "${image_dir_key}" "${_img_dir}")
set_benchmark=$(func_set_params "${benchmark_key}" "${benchmark_value}")
set_batchsize=$(func_set_params "${batch_size_key}" "${batch_size}")
+ set_mkldnn=$(func_set_params "${use_mkldnn_key}" "${use_mkldnn}")
set_cpu_threads=$(func_set_params "${cpu_threads_key}" "${threads}")
set_model_dir=$(func_set_params "${infer_model_key}" "${_model_dir}")
set_infer_params0=$(func_set_params "${rec_model_key}" "${rec_model_value}")
set_infer_params1=$(func_set_params "${infer_key1}" "${infer_value1}")
- command="${_python} ${_script} ${use_gpu_key}=${use_gpu} ${use_mkldnn_key}=${use_mkldnn} ${set_cpu_threads} ${set_model_dir} ${set_batchsize} ${set_infer_params0} ${set_infer_data} ${set_benchmark} ${set_precision} ${set_infer_params1} > ${_save_log_path} 2>&1 "
+ command="${_python} ${_script} ${use_gpu_key}=${use_gpu} ${set_mkldnn} ${set_cpu_threads} ${set_model_dir} ${set_batchsize} ${set_infer_params0} ${set_infer_data} ${set_benchmark} ${set_precision} ${set_infer_params1} > ${_save_log_path} 2>&1 "
eval $command
last_status=${PIPESTATUS[0]}
eval "cat ${_save_log_path}"
diff --git a/test_tipc/test_train_inference_python.sh b/test_tipc/test_train_inference_python.sh
index 0b0a4e4a75f5e978f64404b27a5f26594dbd484e..b69c0f278f2886eeb7c01847bab5d54ff7a18af6 100644
--- a/test_tipc/test_train_inference_python.sh
+++ b/test_tipc/test_train_inference_python.sh
@@ -160,11 +160,12 @@ function func_inference(){
set_infer_data=$(func_set_params "${image_dir_key}" "${_img_dir}")
set_benchmark=$(func_set_params "${benchmark_key}" "${benchmark_value}")
set_batchsize=$(func_set_params "${batch_size_key}" "${batch_size}")
+ set_mkldnn=$(func_set_params "${use_mkldnn_key}" "${use_mkldnn}")
set_cpu_threads=$(func_set_params "${cpu_threads_key}" "${threads}")
set_model_dir=$(func_set_params "${infer_model_key}" "${_model_dir}")
set_infer_params0=$(func_set_params "${save_log_key}" "${save_log_value}")
set_infer_params1=$(func_set_params "${infer_key1}" "${infer_value1}")
- command="${_python} ${_script} ${use_gpu_key}=${use_gpu} ${use_mkldnn_key}=${use_mkldnn} ${set_cpu_threads} ${set_model_dir} ${set_batchsize} ${set_infer_params0} ${set_infer_data} ${set_benchmark} ${set_precision} ${set_infer_params1} > ${_save_log_path} 2>&1 "
+ command="${_python} ${_script} ${use_gpu_key}=${use_gpu} ${set_mkldnn} ${set_cpu_threads} ${set_model_dir} ${set_batchsize} ${set_infer_params0} ${set_infer_data} ${set_benchmark} ${set_precision} ${set_infer_params1} > ${_save_log_path} 2>&1 "
eval $command
last_status=${PIPESTATUS[0]}
eval "cat ${_save_log_path}"
@@ -321,10 +322,6 @@ else
save_log="${LOG_PATH}/${trainer}_gpus_${gpu}_autocast_${autocast}_nodes_${nodes}"
fi
- # load pretrain from norm training if current trainer is pact or fpgm trainer
- if ([ ${trainer} = ${pact_key} ] || [ ${trainer} = ${fpgm_key} ]) && [ ${nodes} -le 1 ]; then
- set_pretrain="${load_norm_train_model}"
- fi
set_save_model=$(func_set_params "${save_model_key}" "${save_log}")
if [ ${#gpu} -le 2 ];then # train with cpu or single gpu
@@ -340,10 +337,7 @@ else
status_check $? "${cmd}" "${status_log}"
set_eval_pretrain=$(func_set_params "${pretrain_model_key}" "${save_log}/${train_model_name}")
- # save norm trained models to set pretrain for pact training and fpgm training
- if [ ${trainer} = ${trainer_norm} ] && [ ${nodes} -le 1 ]; then
- load_norm_train_model=${set_eval_pretrain}
- fi
+
# run eval
if [ ${eval_py} != "null" ]; then
set_eval_params1=$(func_set_params "${eval_key1}" "${eval_value1}")
diff --git a/tools/infer/utility.py b/tools/infer/utility.py
index af50c5e6a8cb39faf416dbe7adb516c4db05aef5..21bbee098ef19456d05165969a9ad400400f1264 100644
--- a/tools/infer/utility.py
+++ b/tools/infer/utility.py
@@ -195,6 +195,7 @@ def create_predictor(args, mode, logger):
max_batch_size=args.max_batch_size,
min_subgraph_size=args.min_subgraph_size)
# skip the minimum trt subgraph
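+ # assume trt dynamic shape is usable; the branches below opt out per mode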
+ use_dynamic_shape = True
if mode == "det":
min_input_shape = {
"x": [1, 3, 50, 50],
@@ -260,6 +261,8 @@ def create_predictor(args, mode, logger):
max_input_shape.update(max_pact_shape)
opt_input_shape.update(opt_pact_shape)
elif mode == "rec":
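+ # non-crnn rec algorithms fall back to static trt shapes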
+ if args.rec_algorithm != "CRNN":
+ use_dynamic_shape = False
min_input_shape = {"x": [1, 3, 32, 10]}
max_input_shape = {"x": [args.rec_batch_num, 3, 32, 1536]}
opt_input_shape = {"x": [args.rec_batch_num, 3, 32, 320]}
@@ -268,11 +271,10 @@ def create_predictor(args, mode, logger):
max_input_shape = {"x": [args.rec_batch_num, 3, 48, 1024]}
opt_input_shape = {"x": [args.rec_batch_num, 3, 48, 320]}
else:
- min_input_shape = {"x": [1, 3, 10, 10]}
- max_input_shape = {"x": [1, 3, 512, 512]}
- opt_input_shape = {"x": [1, 3, 256, 256]}
- config.set_trt_dynamic_shape_info(min_input_shape, max_input_shape,
- opt_input_shape)
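+ # remaining model types use static trt shapes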
+ use_dynamic_shape = False
+ if use_dynamic_shape:
+ config.set_trt_dynamic_shape_info(
+ min_input_shape, max_input_shape, opt_input_shape)
else:
config.disable_gpu()