Commit 4f8b5113, authored by qq_25193841

Merge remote-tracking branch 'origin/dygraph' into dygraph

@@ -704,8 +704,9 @@ class Canvas(QWidget):
     def keyPressEvent(self, ev):
         key = ev.key()
-        shapesBackup = []
         shapesBackup = copy.deepcopy(self.shapes)
+        if len(shapesBackup) == 0:
+            return
         self.shapesBackups.pop()
         self.shapesBackups.append(shapesBackup)
         if key == Qt.Key_Escape and self.current:
...
@@ -18,6 +18,7 @@ Global:
Architecture:
  name: DistillationModel
  algorithm: Distillation
+  model_type: det
  Models:
    Teacher:
      freeze_params: true
...
@@ -111,7 +111,7 @@ def main():
     valid_dataloader = build_dataloader(config, 'Eval', device, logger)
     use_srn = config['Architecture']['algorithm'] == "SRN"
-    model_type = config['Architecture']['model_type']
+    model_type = config['Architecture'].get('model_type', None)
     # start eval
     metric = program.eval(model, valid_dataloader, post_process_class,
                           eval_class, model_type, use_srn)
@@ -120,8 +120,7 @@ def main():
     for k, v in metric.items():
         logger.info('{}:{}'.format(k, v))
-    infer_shape = [3, 32, 100] if config['Architecture'][
-        'model_type'] != "det" else [3, 640, 640]
+    infer_shape = [3, 32, 100] if model_type == "rec" else [3, 640, 640]
     save_path = config["Global"]["save_inference_dir"]
...
@@ -49,7 +49,6 @@ https://aistudio.baidu.com/aistudio/datasetdetail/8429
- Each sample is fixed with 10 characters, and the characters are randomly intercepted from the sentences in the corpus
- Image resolution is 280x32
![](../datasets/ch_doc1.jpg)
-![](../datasets/ch_doc2.jpg)
![](../datasets/ch_doc3.jpg)
- **Download link**: https://pan.baidu.com/s/1QkI7kjah8SPHwOQ40rS1Pw (Password: lu7m)
...
@@ -13,7 +13,7 @@
```shell
python3 -m paddle.distributed.launch \
    --log_dir=./log/ \
-    --gpus '0,1,2,3,4,5,6,7' \
+    --gpus "0,1,2,3,4,5,6,7" \
    tools/train.py \
    -c configs/rec/rec_mv3_none_bilstm_ctc.yml
```
...
@@ -50,7 +50,6 @@ https://aistudio.baidu.com/aistudio/datasetdetail/8429
- Each sample is fixed with 10 characters, and the characters are randomly intercepted from the sentences in the corpus
- Image resolution is 280x32
![](../datasets/ch_doc1.jpg)
-![](../datasets/ch_doc2.jpg)
![](../datasets/ch_doc3.jpg)
- **Download link**: https://pan.baidu.com/s/1QkI7kjah8SPHwOQ40rS1Pw (Password: lu7m)
...
@@ -13,7 +13,7 @@ Take recognition as an example. After the data is prepared locally, start the tr
```shell
python3 -m paddle.distributed.launch \
    --log_dir=./log/ \
-    --gpus '0,1,2,3,4,5,6,7' \
+    --gpus "0,1,2,3,4,5,6,7" \
    tools/train.py \
    -c configs/rec/rec_mv3_none_bilstm_ctc.yml
```
...
@@ -32,6 +32,7 @@ class CopyPaste(object):
         self.aug = IaaAugment(augmenter_args)

     def __call__(self, data):
+        point_num = data['polys'].shape[1]
         src_img = data['image']
         src_polys = data['polys'].tolist()
         src_ignores = data['ignore_tags'].tolist()
@@ -57,6 +58,9 @@ class CopyPaste(object):
             src_img, box = self.paste_img(src_img, box_img, src_polys)
             if box is not None:
+                box = box.tolist()
+                for _ in range(len(box), point_num):
+                    box.append(box[-1])
                 src_polys.append(box)
                 src_ignores.append(tag)
         src_img = cv2.cvtColor(np.array(src_img), cv2.COLOR_RGB2BGR)
...
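Editor's note on the hunk above: the three added lines repeat the last vertex of a pasted box until it has `point_num` points, so every polygon in `data['polys']` keeps the same vertex count. A standalone sketch of that padding idea (the values are hypothetical, not from the commit):

```python
import numpy as np

def pad_poly(box, point_num):
    # repeat the last vertex until the polygon has point_num points,
    # mirroring the padding added in CopyPaste.__call__ above
    box = list(box)
    for _ in range(len(box), point_num):
        box.append(box[-1])
    return box

polys = np.zeros((2, 4, 2))                  # dataset polygons, 4 points each
pasted_box = [[10, 10], [20, 10], [20, 20]]  # a pasted box with only 3 vertices
print(pad_poly(pasted_box, polys.shape[1]))
# [[10, 10], [20, 10], [20, 20], [20, 20]]
```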
@@ -14,6 +14,7 @@
import numpy as np
import os
import random
+import traceback
from paddle.io import Dataset
from .imaug import transform, create_operators
@@ -93,7 +94,8 @@ class SimpleDataSet(Dataset):
                 img = f.read()
                 data['image'] = img
             data = transform(data, load_data_ops)
-            if data is None:
+            if data is None or data['polys'].shape[1] != 4:
                 continue
             ext_data.append(data)
         return ext_data
@@ -115,10 +117,10 @@ class SimpleDataSet(Dataset):
             data['image'] = img
             data['ext_data'] = self.get_ext_data()
             outs = transform(data, self.ops)
-        except Exception as e:
+        except:
             self.logger.error(
                 "When parsing line {}, error happened with msg: {}".format(
-                    data_line, e))
+                    data_line, traceback.format_exc()))
             outs = None
         if outs is None:
             # during evaluation, we should fix the idx to get same results for many times of evaluation.
...
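Editor's note: the hunk above swaps the logged exception object for `traceback.format_exc()`, which records the full stack trace of the failing transform rather than only its message. A minimal illustration (the failing function is hypothetical):

```python
import traceback

def broken_transform(data):
    raise ValueError("polygon has wrong shape")  # hypothetical failure

try:
    broken_transform({})
except Exception as e:
    print(e)                       # only: polygon has wrong shape
    print(traceback.format_exc())  # full stack trace with file and line numbers
```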
@@ -25,16 +25,14 @@ __all__ = ["ResNet"]

class ConvBNLayer(nn.Layer):
-    def __init__(
-            self,
-            in_channels,
-            out_channels,
-            kernel_size,
-            stride=1,
-            groups=1,
-            is_vd_mode=False,
-            act=None,
-            name=None, ):
+    def __init__(self,
+                 in_channels,
+                 out_channels,
+                 kernel_size,
+                 stride=1,
+                 groups=1,
+                 is_vd_mode=False,
+                 act=None):
        super(ConvBNLayer, self).__init__()

        self.is_vd_mode = is_vd_mode
@@ -47,19 +45,8 @@ class ConvBNLayer(nn.Layer):
            stride=stride,
            padding=(kernel_size - 1) // 2,
            groups=groups,
-            weight_attr=ParamAttr(name=name + "_weights"),
            bias_attr=False)
-        if name == "conv1":
-            bn_name = "bn_" + name
-        else:
-            bn_name = "bn" + name[3:]
-        self._batch_norm = nn.BatchNorm(
-            out_channels,
-            act=act,
-            param_attr=ParamAttr(name=bn_name + '_scale'),
-            bias_attr=ParamAttr(bn_name + '_offset'),
-            moving_mean_name=bn_name + '_mean',
-            moving_variance_name=bn_name + '_variance')
+        self._batch_norm = nn.BatchNorm(out_channels, act=act)

    def forward(self, inputs):
        if self.is_vd_mode:
@@ -75,29 +62,25 @@ class BottleneckBlock(nn.Layer):
                 out_channels,
                 stride,
                 shortcut=True,
-                 if_first=False,
-                 name=None):
+                 if_first=False):
        super(BottleneckBlock, self).__init__()

        self.conv0 = ConvBNLayer(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=1,
-            act='relu',
-            name=name + "_branch2a")
+            act='relu')
        self.conv1 = ConvBNLayer(
            in_channels=out_channels,
            out_channels=out_channels,
            kernel_size=3,
            stride=stride,
-            act='relu',
-            name=name + "_branch2b")
+            act='relu')
        self.conv2 = ConvBNLayer(
            in_channels=out_channels,
            out_channels=out_channels * 4,
            kernel_size=1,
-            act=None,
-            name=name + "_branch2c")
+            act=None)

        if not shortcut:
            self.short = ConvBNLayer(
@@ -105,8 +88,7 @@ class BottleneckBlock(nn.Layer):
                out_channels=out_channels * 4,
                kernel_size=1,
                stride=1,
-                is_vd_mode=False if if_first else True,
-                name=name + "_branch1")
+                is_vd_mode=False if if_first else True)

        self.shortcut = shortcut
@@ -125,13 +107,13 @@ class BasicBlock(nn.Layer):

class BasicBlock(nn.Layer):
-    def __init__(self,
-                 in_channels,
-                 out_channels,
-                 stride,
-                 shortcut=True,
-                 if_first=False,
-                 name=None):
+    def __init__(
+            self,
+            in_channels,
+            out_channels,
+            stride,
+            shortcut=True,
+            if_first=False, ):
        super(BasicBlock, self).__init__()
        self.stride = stride
        self.conv0 = ConvBNLayer(
@@ -139,14 +121,12 @@ class BasicBlock(nn.Layer):
            out_channels=out_channels,
            kernel_size=3,
            stride=stride,
-            act='relu',
-            name=name + "_branch2a")
+            act='relu')
        self.conv1 = ConvBNLayer(
            in_channels=out_channels,
            out_channels=out_channels,
            kernel_size=3,
-            act=None,
-            name=name + "_branch2b")
+            act=None)

        if not shortcut:
            self.short = ConvBNLayer(
@@ -154,8 +134,7 @@ class BasicBlock(nn.Layer):
                out_channels=out_channels,
                kernel_size=1,
                stride=1,
-                is_vd_mode=False if if_first else True,
-                name=name + "_branch1")
+                is_vd_mode=False if if_first else True)

        self.shortcut = shortcut
@@ -201,22 +180,19 @@ class ResNet(nn.Layer):
            out_channels=32,
            kernel_size=3,
            stride=2,
-            act='relu',
-            name="conv1_1")
+            act='relu')
        self.conv1_2 = ConvBNLayer(
            in_channels=32,
            out_channels=32,
            kernel_size=3,
            stride=1,
-            act='relu',
-            name="conv1_2")
+            act='relu')
        self.conv1_3 = ConvBNLayer(
            in_channels=32,
            out_channels=64,
            kernel_size=3,
            stride=1,
-            act='relu',
-            name="conv1_3")
+            act='relu')
        self.pool2d_max = nn.MaxPool2D(kernel_size=3, stride=2, padding=1)

        self.stages = []
@@ -226,13 +202,6 @@ class ResNet(nn.Layer):
                block_list = []
                shortcut = False
                for i in range(depth[block]):
-                    if layers in [101, 152] and block == 2:
-                        if i == 0:
-                            conv_name = "res" + str(block + 2) + "a"
-                        else:
-                            conv_name = "res" + str(block + 2) + "b" + str(i)
-                    else:
-                        conv_name = "res" + str(block + 2) + chr(97 + i)
                    bottleneck_block = self.add_sublayer(
                        'bb_%d_%d' % (block, i),
                        BottleneckBlock(
@@ -241,8 +210,7 @@ class ResNet(nn.Layer):
                            out_channels=num_filters[block],
                            stride=2 if i == 0 and block != 0 else 1,
                            shortcut=shortcut,
-                            if_first=block == i == 0,
-                            name=conv_name))
+                            if_first=block == i == 0))
                    shortcut = True
                    block_list.append(bottleneck_block)
                self.out_channels.append(num_filters[block] * 4)
@@ -252,7 +220,6 @@ class ResNet(nn.Layer):
                block_list = []
                shortcut = False
                for i in range(depth[block]):
-                    conv_name = "res" + str(block + 2) + chr(97 + i)
                    basic_block = self.add_sublayer(
                        'bb_%d_%d' % (block, i),
                        BasicBlock(
@@ -261,8 +228,7 @@ class ResNet(nn.Layer):
                            out_channels=num_filters[block],
                            stride=2 if i == 0 and block != 0 else 1,
                            shortcut=shortcut,
-                            if_first=block == i == 0,
-                            name=conv_name))
+                            if_first=block == i == 0))
                    shortcut = True
                    block_list.append(basic_block)
                self.out_channels.append(num_filters[block])
...
-# Visual Question Answering (VQA)
+# Document Visual Question Answering (DOC-VQA)

-The main features of VQA are as follows:
+VQA (Visual Question Answering) asks and answers questions about the content of an image. DOC-VQA is one kind of VQA task, in which the questions target the textual content of document images.
+
+The DOC-VQA algorithms in PP-Structure are developed on top of the PaddleNLP natural language processing library.
+
+The main features are as follows:

 - Integrates the [LayoutXLM](https://arxiv.org/pdf/2104.08836.pdf) model and the PP-OCR inference engine.
-- Supports multimodal semantic entity recognition (Semantic Entity Recognition, SER) and relation extraction (Relation Extraction, RE) tasks. The SER task recognizes and classifies the text in an image; the RE task extracts relations between the text contents in an image (e.g. identifying question pairs).
+- Supports multimodal semantic entity recognition (Semantic Entity Recognition, SER) and relation extraction (Relation Extraction, RE) tasks. The SER task recognizes and classifies the text in an image; the RE task extracts relations between the text contents in an image, e.g. identifying question-answer pairs.
-- Supports end-to-end system prediction and evaluation combining the SER task with the OCR engine.
-- Supports custom training for the SER and RE tasks
+- Supports custom training for the SER and RE tasks.
+- Supports end-to-end system prediction and evaluation with OCR+SER.
+- Supports end-to-end system prediction with OCR+SER+RE.
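Editor's note: for orientation, a minimal sketch of what the two tasks produce. The `pred`/`text`/`bbox` field names follow the inference scripts later in this commit; the values are made up:

```python
# SER assigns a category to every OCR-detected text box:
ser_result = [
    {"text": "Name", "pred": "QUESTION", "bbox": [60, 80, 140, 110]},
    {"text": "John", "pred": "ANSWER", "bbox": [150, 80, 230, 110]},
]
# RE then links the boxes into (question, answer) pairs:
re_result = [(ser_result[0], ser_result[1])]
```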
This project is an open-source implementation of [LayoutXLM: Multimodal Pre-training for Multilingual Visually-rich Document Understanding](https://arxiv.org/pdf/2104.08836.pdf) on Paddle 2.2,
and includes the fine-tuning code on the [XFUND dataset](https://github.com/doc-analysis/XFUND).

-## 1. Demo
+## 1 Performance
+
+We evaluated the algorithms on the [XFUN](https://github.com/doc-analysis/XFUND) evaluation set; the performance is as follows:
+
+|Task| f1 | Model download link|
+|:---:|:---:| :---:|
+|SER|0.9056| [link](https://paddleocr.bj.bcebos.com/pplayout/PP-Layout_v1.0_ser_pretrained.tar)|
+|RE|0.7113| [link](https://paddleocr.bj.bcebos.com/pplayout/PP-Layout_v1.0_re_pretrained.tar)|
+
+## 2. Demo

**Note:** the test images come from the XFUN dataset.

-### 1.1 SER
+### 2.1 SER

-<div align="center">
-<img src="./images/result_ser/zh_val_0_ser.jpg" width = "600" />
-</div>
-<div align="center">
-<img src="./images/result_ser/zh_val_42_ser.jpg" width = "600" />
-</div>
+![](./images/result_ser/zh_val_0_ser.jpg) | ![](./images/result_ser/zh_val_42_ser.jpg)
+---|---

-Boxes of different colors represent different categories. For the XFUN dataset there are 3 categories, `QUESTION`, `ANSWER` and `HEADER`; the corresponding category and the OCR recognition result are also marked at the top-left of each OCR detection box.
+Boxes of different colors in the figures represent different categories. For the XFUN dataset there are 3 categories, `QUESTION`, `ANSWER` and `HEADER`:
+
+* Dark purple: HEADER
+* Light purple: QUESTION
+* Army green: ANSWER
+
+The corresponding category and the OCR recognition result are also marked at the top-left of each OCR detection box.
-### 1.2 RE
+### 2.2 RE

-* Coming soon!
+![](./images/result_re/zh_val_21_re.jpg) | ![](./images/result_re/zh_val_40_re.jpg)
+---|---

-## 2. Installation
-### 2.1 Install dependencies
+Red boxes in the figures are questions and blue boxes are answers; each question is connected to its answer by a green line. The corresponding category and the OCR recognition result are also marked at the top-left of each OCR detection box.
+
+## 3. Installation
+### 3.1 Install dependencies

 - **(1) Install PaddlePaddle**
@@ -53,12 +73,12 @@ python3 -m pip install paddlepaddle==2.2 -i https://mirror.baidu.com/pypi/simple
For other requirements, follow the instructions in the [installation documentation](https://www.paddlepaddle.org.cn/install/quick).

-### 2.2 Install PaddleOCR (including PP-OCR and VQA)
+### 3.2 Install PaddleOCR (including PP-OCR and VQA)

- **(1) Quickly install the PaddleOCR whl package via pip (inference only)**

```bash
-pip install "paddleocr>=2.2" # version 2.2+ is recommended
+pip install paddleocr
```

- **(2) Download the VQA source code (inference + training)**
@@ -85,13 +105,14 @@ pip install -e .
- **(4) Install VQA's `requirements`**

```bash
+cd ppstructure/vqa
pip install -r requirements.txt
```

-## 3. Usage
+## 4. Usage

-### 3.1 Prepare the data and pretrained models
+### 4.1 Prepare the data and pretrained models

Download link for the processed XFUN Chinese dataset: [https://paddleocr.bj.bcebos.com/dataset/XFUND.tar](https://paddleocr.bj.bcebos.com/dataset/XFUND.tar)
@@ -104,18 +125,15 @@ wget https://paddleocr.bj.bcebos.com/dataset/XFUND.tar
To convert XFUN datasets in other languages, refer to the [XFUN data conversion script](helper/trans_xfun_data.py).

-To try inference directly, you can download the SER pretrained model we provide, skip training, and run prediction right away.
-
-* Download link for the SER pretrained model: [link](https://paddleocr.bj.bcebos.com/pplayout/PP-Layout_v1.0_ser_pretrained.tar)
-* Download link for the RE pretrained model: coming soon!
+To try inference directly, you can download the pretrained models we provide, skip training, and run prediction right away.

-### 3.2 SER task
+### 4.2 SER task

* Start training

```shell
-python train_ser.py \
+python3.7 train_ser.py \
    --model_name_or_path "layoutxlm-base-uncased" \
    --train_data_dir "XFUND/zh_train/image" \
    --train_label_path "XFUND/zh_train/xfun_normalize_train.json" \
@@ -131,13 +149,7 @@ python train_ser.py \
    --seed 2048
```

-Finally, metrics such as `precision`, `recall` and `f1` are printed, as shown below.
-
-```
-best metrics: {'loss': 1.066644651549203, 'precision': 0.8770182068017863, 'recall': 0.9361936193619362, 'f1': 0.9056402979780063}
-```
-
-The model and training log are saved in the `./output/ser/` folder.
+Finally, metrics such as `precision`, `recall` and `f1` are printed, and the model and training log are saved in the `./output/ser/` folder.

* Predict using the OCR results provided with the evaluation set

@@ -159,21 +171,73 @@ export CUDA_VISIBLE_DEVICES=0
python3.7 infer_ser_e2e.py \
    --model_name_or_path "./output/PP-Layout_v1.0_ser_pretrained/" \
    --max_seq_length 512 \
-    --output_dir "output_res_e2e/"
+    --output_dir "output_res_e2e/" \
+    --infer_imgs "images/input/zh_val_0.jpg"
```

* Run end-to-end evaluation of the `OCR engine + SER` prediction system

```shell
export CUDA_VISIBLE_DEVICES=0
-python helper/eval_with_label_end2end.py --gt_json_path XFUND/zh_val/xfun_normalize_val.json --pred_json_path output_res/infer_results.txt
+python3.7 helper/eval_with_label_end2end.py --gt_json_path XFUND/zh_val/xfun_normalize_val.json --pred_json_path output_res/infer_results.txt
```

-3.3 RE task
+### 3.3 RE task

-coming soon!
+* Start training
+
+```shell
+python3 train_re.py \
+    --model_name_or_path "layoutxlm-base-uncased" \
+    --train_data_dir "XFUND/zh_train/image" \
+    --train_label_path "XFUND/zh_train/xfun_normalize_train.json" \
+    --eval_data_dir "XFUND/zh_val/image" \
+    --eval_label_path "XFUND/zh_val/xfun_normalize_val.json" \
+    --label_map_path 'labels/labels_ser.txt' \
+    --num_train_epochs 2 \
+    --eval_steps 10 \
+    --save_steps 500 \
+    --output_dir "output/re/" \
+    --learning_rate 5e-5 \
+    --warmup_steps 50 \
+    --per_gpu_train_batch_size 8 \
+    --per_gpu_eval_batch_size 8 \
+    --evaluate_during_training \
+    --seed 2048
+```
+
+Finally, metrics such as `precision`, `recall` and `f1` are printed, and the model and training log are saved in the `./output/re/` folder.
+
+* Predict using the OCR results provided with the evaluation set
+
+```shell
+export CUDA_VISIBLE_DEVICES=0
+python3 infer_re.py \
+    --model_name_or_path "./PP-Layout_v1.0_re_pretrained/" \
+    --max_seq_length 512 \
+    --eval_data_dir "XFUND/zh_val/image" \
+    --eval_label_path "XFUND/zh_val/xfun_normalize_val.json" \
+    --label_map_path 'labels/labels_ser.txt' \
+    --output_dir "output_res" \
+    --per_gpu_eval_batch_size 1 \
+    --seed 2048
+```
+
+The prediction visualizations and a text file of the predictions named `infer_results.txt` are saved in the `output_res` directory.
+
+* Run the cascaded `OCR engine + SER + RE` pipeline
+
+```shell
+export CUDA_VISIBLE_DEVICES=0
+python3.7 infer_ser_re_e2e.py \
+    --model_name_or_path "./PP-Layout_v1.0_ser_pretrained/" \
+    --re_model_name_or_path "./PP-Layout_v1.0_re_pretrained/" \
+    --max_seq_length 512 \
+    --output_dir "output_ser_re_e2e_train/" \
+    --infer_imgs "images/input/zh_val_21.jpg"
+```

## References
...
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
import numbers
import numpy as np
class DataCollator:
"""
data batch
"""
def __call__(self, batch):
data_dict = {}
to_tensor_keys = []
for sample in batch:
for k, v in sample.items():
if k not in data_dict:
data_dict[k] = []
if isinstance(v, (np.ndarray, paddle.Tensor, numbers.Number)):
if k not in to_tensor_keys:
to_tensor_keys.append(k)
data_dict[k].append(v)
for k in to_tensor_keys:
data_dict[k] = paddle.to_tensor(data_dict[k])
return data_dict
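Editor's note: a small usage sketch of the `DataCollator` above (the keys and shapes are hypothetical). Array and scalar values are stacked into one `paddle.Tensor` per key, while non-numeric values stay as plain lists:

```python
import numpy as np

samples = [
    {"input_ids": np.array([1, 2, 3]), "image_path": "a.jpg"},
    {"input_ids": np.array([4, 5, 6]), "image_path": "b.jpg"},
]
batch = DataCollator()(samples)
print(batch["input_ids"].shape)  # [2, 3] paddle.Tensor
print(batch["image_path"])       # ['a.jpg', 'b.jpg'], left as a list
```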
import os
import sys
__dir__ = os.path.dirname(os.path.abspath(__file__))
sys.path.append(__dir__)
sys.path.append(os.path.abspath(os.path.join(__dir__, '../..')))
import random
import cv2
import matplotlib.pyplot as plt
import numpy as np
import paddle
from paddlenlp.transformers import LayoutXLMTokenizer, LayoutXLMModel, LayoutXLMForRelationExtraction
from xfun import XFUNDataset
from utils import parse_args, get_bio_label_maps, draw_re_results
from data_collator import DataCollator
from ppocr.utils.logging import get_logger
def infer(args):
os.makedirs(args.output_dir, exist_ok=True)
logger = get_logger()
label2id_map, id2label_map = get_bio_label_maps(args.label_map_path)
pad_token_label_id = paddle.nn.CrossEntropyLoss().ignore_index
tokenizer = LayoutXLMTokenizer.from_pretrained(args.model_name_or_path)
model = LayoutXLMForRelationExtraction.from_pretrained(
args.model_name_or_path)
eval_dataset = XFUNDataset(
tokenizer,
data_dir=args.eval_data_dir,
label_path=args.eval_label_path,
label2id_map=label2id_map,
img_size=(224, 224),
max_seq_len=args.max_seq_length,
pad_token_label_id=pad_token_label_id,
contains_re=True,
add_special_ids=False,
return_attention_mask=True,
load_mode='all')
eval_dataloader = paddle.io.DataLoader(
eval_dataset,
batch_size=args.per_gpu_eval_batch_size,
num_workers=8,
shuffle=False,
collate_fn=DataCollator())
# load the ground-truth OCR data
ocr_info_list = load_ocr(args.eval_data_dir, args.eval_label_path)
for idx, batch in enumerate(eval_dataloader):
logger.info("[Infer] process: {}/{}".format(idx, len(eval_dataloader)))
with paddle.no_grad():
outputs = model(**batch)
pred_relations = outputs['pred_relations']
ocr_info = ocr_info_list[idx]
image_path = ocr_info['image_path']
ocr_info = ocr_info['ocr_info']
# decode the tokens referenced by each entity and filter out unneeded ocr_info entries
ocr_info = filter_bg_by_txt(ocr_info, batch, tokenizer)
# convert the predicted relations to OCR info pairs
result = []
used_tail_id = []
for relations in pred_relations:
for relation in relations:
if relation['tail_id'] in used_tail_id:
continue
if relation['head_id'] not in ocr_info or relation[
'tail_id'] not in ocr_info:
continue
used_tail_id.append(relation['tail_id'])
ocr_info_head = ocr_info[relation['head_id']]
ocr_info_tail = ocr_info[relation['tail_id']]
result.append((ocr_info_head, ocr_info_tail))
img = cv2.imread(image_path)
img_show = draw_re_results(img, result)
save_path = os.path.join(args.output_dir, os.path.basename(image_path))
cv2.imwrite(save_path, img_show)
def load_ocr(img_folder, json_path):
import json
d = []
with open(json_path, "r") as fin:
lines = fin.readlines()
for line in lines:
image_name, info_str = line.split("\t")
info_dict = json.loads(info_str)
info_dict['image_path'] = os.path.join(img_folder, image_name)
d.append(info_dict)
return d
def filter_bg_by_txt(ocr_info, batch, tokenizer):
entities = batch['entities'][0]
input_ids = batch['input_ids'][0]
new_info_dict = {}
for i in range(len(entities['start'])):
entitie_head = entities['start'][i]
entitie_tail = entities['end'][i]
word_input_ids = input_ids[entitie_head:entitie_tail].numpy().tolist()
txt = tokenizer.convert_ids_to_tokens(word_input_ids)
txt = tokenizer.convert_tokens_to_string(txt)
for i, info in enumerate(ocr_info):
if info['text'] == txt:
new_info_dict[i] = info
return new_info_dict
def post_process(pred_relations, ocr_info, img):
result = []
for relations in pred_relations:
for relation in relations:
ocr_info_head = ocr_info[relation['head_id']]
ocr_info_tail = ocr_info[relation['tail_id']]
result.append((ocr_info_head, ocr_info_tail))
return result
def draw_re(result, image_path, output_folder):
img = cv2.imread(image_path)
from matplotlib import pyplot as plt
for ocr_info_head, ocr_info_tail in result:
cv2.rectangle(
img,
tuple(ocr_info_head['bbox'][:2]),
tuple(ocr_info_head['bbox'][2:]), (255, 0, 0),
thickness=2)
cv2.rectangle(
img,
tuple(ocr_info_tail['bbox'][:2]),
tuple(ocr_info_tail['bbox'][2:]), (0, 0, 255),
thickness=2)
center_p1 = [(ocr_info_head['bbox'][0] + ocr_info_head['bbox'][2]) // 2,
(ocr_info_head['bbox'][1] + ocr_info_head['bbox'][3]) // 2]
center_p2 = [(ocr_info_tail['bbox'][0] + ocr_info_tail['bbox'][2]) // 2,
(ocr_info_tail['bbox'][1] + ocr_info_tail['bbox'][3]) // 2]
cv2.line(
img, tuple(center_p1), tuple(center_p2), (0, 255, 0), thickness=2)
plt.imshow(img)
plt.savefig(
os.path.join(output_folder, os.path.basename(image_path)), dpi=600)
# plt.show()
if __name__ == "__main__":
args = parse_args()
infer(args)
@@ -23,8 +23,10 @@ from PIL import Image

import paddle
from paddlenlp.transformers import LayoutXLMModel, LayoutXLMTokenizer, LayoutXLMForTokenClassification
+from paddleocr import PaddleOCR

# relative reference
-from utils import parse_args, get_image_file_list, draw_ser_results, get_bio_label_maps, build_ocr_engine
+from utils import parse_args, get_image_file_list, draw_ser_results, get_bio_label_maps
from utils import pad_sentences, split_page, preprocess, postprocess, merge_preds_list_with_ocr_info
@@ -48,74 +50,82 @@ def parse_ocr_info_for_ser(ocr_result):
    return ocr_info

-@paddle.no_grad()
-def infer(args):
-    os.makedirs(args.output_dir, exist_ok=True)
+class SerPredictor(object):
+    def __init__(self, args):
+        self.max_seq_length = args.max_seq_length
+        # init ser token and model
+        self.tokenizer = LayoutXLMTokenizer.from_pretrained(
+            args.model_name_or_path)
+        self.model = LayoutXLMForTokenClassification.from_pretrained(
+            args.model_name_or_path)
+        self.model.eval()
+
+        # init ocr_engine
+        self.ocr_engine = PaddleOCR(
+            rec_model_dir=args.ocr_rec_model_dir,
+            det_model_dir=args.ocr_det_model_dir,
+            use_angle_cls=False,
+            show_log=False)
+
+        # init dict
+        label2id_map, self.id2label_map = get_bio_label_maps(
+            args.label_map_path)
+        self.label2id_map_for_draw = dict()
+        for key in label2id_map:
+            if key.startswith("I-"):
+                self.label2id_map_for_draw[key] = label2id_map["B" + key[1:]]
+            else:
+                self.label2id_map_for_draw[key] = label2id_map[key]
+
+    def __call__(self, img):
+        ocr_result = self.ocr_engine.ocr(img, cls=False)
+        ocr_info = parse_ocr_info_for_ser(ocr_result)
+
+        inputs = preprocess(
+            tokenizer=self.tokenizer,
+            ori_img=img,
+            ocr_info=ocr_info,
+            max_seq_len=self.max_seq_length)
+
+        outputs = self.model(
+            input_ids=inputs["input_ids"],
+            bbox=inputs["bbox"],
+            image=inputs["image"],
+            token_type_ids=inputs["token_type_ids"],
+            attention_mask=inputs["attention_mask"])
+
+        preds = outputs[0]
+        preds = postprocess(inputs["attention_mask"], preds, self.id2label_map)
+        ocr_info = merge_preds_list_with_ocr_info(
+            ocr_info, inputs["segment_offset_id"], preds,
+            self.label2id_map_for_draw)
+        return ocr_info, inputs

-    # init token and model
-    tokenizer = LayoutXLMTokenizer.from_pretrained(args.model_name_or_path)
-    model = LayoutXLMForTokenClassification.from_pretrained(
-        args.model_name_or_path)
-    model.eval()

-    label2id_map, id2label_map = get_bio_label_maps(args.label_map_path)
-    label2id_map_for_draw = dict()
-    for key in label2id_map:
-        if key.startswith("I-"):
-            label2id_map_for_draw[key] = label2id_map["B" + key[1:]]
-        else:
-            label2id_map_for_draw[key] = label2id_map[key]
+if __name__ == "__main__":
+    args = parse_args()
+    os.makedirs(args.output_dir, exist_ok=True)

    # get infer img list
    infer_imgs = get_image_file_list(args.infer_imgs)

-    ocr_engine = build_ocr_engine(args.ocr_rec_model_dir,
-                                  args.ocr_det_model_dir)
-
    # loop for infer
+    ser_engine = SerPredictor(args)
    with open(os.path.join(args.output_dir, "infer_results.txt"), "w") as fout:
        for idx, img_path in enumerate(infer_imgs):
-            print("process: [{}/{}]".format(idx, len(infer_imgs), img_path))
+            print("process: [{}/{}], {}".format(idx, len(infer_imgs), img_path))

            img = cv2.imread(img_path)

-            ocr_result = ocr_engine.ocr(img_path, cls=False)
-            ocr_info = parse_ocr_info_for_ser(ocr_result)
-
-            inputs = preprocess(
-                tokenizer=tokenizer,
-                ori_img=img,
-                ocr_info=ocr_info,
-                max_seq_len=args.max_seq_length)
-
-            outputs = model(
-                input_ids=inputs["input_ids"],
-                bbox=inputs["bbox"],
-                image=inputs["image"],
-                token_type_ids=inputs["token_type_ids"],
-                attention_mask=inputs["attention_mask"])
-
-            preds = outputs[0]
-            preds = postprocess(inputs["attention_mask"], preds, id2label_map)
-            ocr_info = merge_preds_list_with_ocr_info(
-                ocr_info, inputs["segment_offset_id"], preds,
-                label2id_map_for_draw)
+            result, _ = ser_engine(img)

            fout.write(img_path + "\t" + json.dumps(
                {
-                    "ocr_info": ocr_info,
+                    "ser_resule": result,
                }, ensure_ascii=False) + "\n")

-            img_res = draw_ser_results(img, ocr_info)
+            img_res = draw_ser_results(img, result)
            cv2.imwrite(
                os.path.join(args.output_dir,
                             os.path.splitext(os.path.basename(img_path))[0] +
                             "_ser.jpg"), img_res)
-
-    return
-
-
-if __name__ == "__main__":
-    args = parse_args()
-    infer(args)
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import sys
import json
import cv2
import numpy as np
from copy import deepcopy
from PIL import Image
import paddle
from paddlenlp.transformers import LayoutXLMModel, LayoutXLMTokenizer, LayoutXLMForRelationExtraction
# relative reference
from utils import parse_args, get_image_file_list, draw_re_results
from infer_ser_e2e import SerPredictor
def make_input(ser_input, ser_result, max_seq_len=512):
entities_labels = {'HEADER': 0, 'QUESTION': 1, 'ANSWER': 2}
entities = ser_input['entities'][0]
assert len(entities) == len(ser_result)
# entities
start = []
end = []
label = []
entity_idx_dict = {}
for i, (res, entity) in enumerate(zip(ser_result, entities)):
if res['pred'] == 'O':
continue
entity_idx_dict[len(start)] = i
start.append(entity['start'])
end.append(entity['end'])
label.append(entities_labels[res['pred']])
entities = dict(start=start, end=end, label=label)
# relations
head = []
tail = []
for i in range(len(entities["label"])):
for j in range(len(entities["label"])):
if entities["label"][i] == 1 and entities["label"][j] == 2:
head.append(i)
tail.append(j)
relations = dict(head=head, tail=tail)
batch_size = ser_input["input_ids"].shape[0]
entities_batch = []
relations_batch = []
for b in range(batch_size):
entities_batch.append(entities)
relations_batch.append(relations)
ser_input['entities'] = entities_batch
ser_input['relations'] = relations_batch
ser_input.pop('segment_offset_id')
return ser_input, entity_idx_dict
class SerReSystem(object):
def __init__(self, args):
self.ser_engine = SerPredictor(args)
self.tokenizer = LayoutXLMTokenizer.from_pretrained(
args.re_model_name_or_path)
self.model = LayoutXLMForRelationExtraction.from_pretrained(
args.re_model_name_or_path)
self.model.eval()
def __call__(self, img):
ser_result, ser_inputs = self.ser_engine(img)
re_input, entity_idx_dict = make_input(ser_inputs, ser_result)
re_result = self.model(**re_input)
pred_relations = re_result['pred_relations'][0]
# convert the predicted relations to OCR info pairs
result = []
used_tail_id = []
for relation in pred_relations:
if relation['tail_id'] in used_tail_id:
continue
used_tail_id.append(relation['tail_id'])
ocr_info_head = ser_result[entity_idx_dict[relation['head_id']]]
ocr_info_tail = ser_result[entity_idx_dict[relation['tail_id']]]
result.append((ocr_info_head, ocr_info_tail))
return result
if __name__ == "__main__":
args = parse_args()
os.makedirs(args.output_dir, exist_ok=True)
# get infer img list
infer_imgs = get_image_file_list(args.infer_imgs)
# loop for infer
ser_re_engine = SerReSystem(args)
with open(os.path.join(args.output_dir, "infer_results.txt"), "w") as fout:
for idx, img_path in enumerate(infer_imgs):
print("process: [{}/{}], {}".format(idx, len(infer_imgs), img_path))
img = cv2.imread(img_path)
result = ser_re_engine(img)
fout.write(img_path + "\t" + json.dumps(
{
"result": result,
}, ensure_ascii=False) + "\n")
img_res = draw_re_results(img, result)
cv2.imwrite(
os.path.join(args.output_dir,
os.path.splitext(os.path.basename(img_path))[0] +
"_re.jpg"), img_res)
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import re
import numpy as np
import logging
logger = logging.getLogger(__name__)
PREFIX_CHECKPOINT_DIR = "checkpoint"
_re_checkpoint = re.compile(r"^" + PREFIX_CHECKPOINT_DIR + r"\-(\d+)$")
def get_last_checkpoint(folder):
content = os.listdir(folder)
checkpoints = [
path for path in content
if _re_checkpoint.search(path) is not None and os.path.isdir(
os.path.join(folder, path))
]
if len(checkpoints) == 0:
return
return os.path.join(
folder,
max(checkpoints,
key=lambda x: int(_re_checkpoint.search(x).groups()[0])))
def re_score(pred_relations, gt_relations, mode="strict"):
"""Evaluate RE predictions
Args:
pred_relations (list) : list of list of predicted relations (several relations in each sentence)
gt_relations (list) : list of list of ground truth relations
rel = { "head": (start_idx (inclusive), end_idx (exclusive)),
"tail": (start_idx (inclusive), end_idx (exclusive)),
"head_type": ent_type,
"tail_type": ent_type,
"type": rel_type}
vocab (Vocab) : dataset vocabulary
mode (str) : in 'strict' or 'boundaries'"""
assert mode in ["strict", "boundaries"]
relation_types = [v for v in [0, 1] if not v == 0]
scores = {
rel: {
"tp": 0,
"fp": 0,
"fn": 0
}
for rel in relation_types + ["ALL"]
}
# Count GT relations and Predicted relations
n_sents = len(gt_relations)
n_rels = sum([len([rel for rel in sent]) for sent in gt_relations])
n_found = sum([len([rel for rel in sent]) for sent in pred_relations])
# Count TP, FP and FN per type
for pred_sent, gt_sent in zip(pred_relations, gt_relations):
for rel_type in relation_types:
# strict mode takes argument types into account
if mode == "strict":
pred_rels = {(rel["head"], rel["head_type"], rel["tail"],
rel["tail_type"])
for rel in pred_sent if rel["type"] == rel_type}
gt_rels = {(rel["head"], rel["head_type"], rel["tail"],
rel["tail_type"])
for rel in gt_sent if rel["type"] == rel_type}
# boundaries mode only takes argument spans into account
elif mode == "boundaries":
pred_rels = {(rel["head"], rel["tail"])
for rel in pred_sent if rel["type"] == rel_type}
gt_rels = {(rel["head"], rel["tail"])
for rel in gt_sent if rel["type"] == rel_type}
scores[rel_type]["tp"] += len(pred_rels & gt_rels)
scores[rel_type]["fp"] += len(pred_rels - gt_rels)
scores[rel_type]["fn"] += len(gt_rels - pred_rels)
# Compute per entity Precision / Recall / F1
for rel_type in scores.keys():
if scores[rel_type]["tp"]:
scores[rel_type]["p"] = scores[rel_type]["tp"] / (
scores[rel_type]["fp"] + scores[rel_type]["tp"])
scores[rel_type]["r"] = scores[rel_type]["tp"] / (
scores[rel_type]["fn"] + scores[rel_type]["tp"])
else:
scores[rel_type]["p"], scores[rel_type]["r"] = 0, 0
if not scores[rel_type]["p"] + scores[rel_type]["r"] == 0:
scores[rel_type]["f1"] = (
2 * scores[rel_type]["p"] * scores[rel_type]["r"] /
(scores[rel_type]["p"] + scores[rel_type]["r"]))
else:
scores[rel_type]["f1"] = 0
# Compute micro F1 Scores
tp = sum([scores[rel_type]["tp"] for rel_type in relation_types])
fp = sum([scores[rel_type]["fp"] for rel_type in relation_types])
fn = sum([scores[rel_type]["fn"] for rel_type in relation_types])
if tp:
precision = tp / (tp + fp)
recall = tp / (tp + fn)
f1 = 2 * precision * recall / (precision + recall)
else:
precision, recall, f1 = 0, 0, 0
scores["ALL"]["p"] = precision
scores["ALL"]["r"] = recall
scores["ALL"]["f1"] = f1
scores["ALL"]["tp"] = tp
scores["ALL"]["fp"] = fp
scores["ALL"]["fn"] = fn
# Compute Macro F1 Scores
scores["ALL"]["Macro_f1"] = np.mean(
[scores[ent_type]["f1"] for ent_type in relation_types])
scores["ALL"]["Macro_p"] = np.mean(
[scores[ent_type]["p"] for ent_type in relation_types])
scores["ALL"]["Macro_r"] = np.mean(
[scores[ent_type]["r"] for ent_type in relation_types])
# logger.info(f"RE Evaluation in *** {mode.upper()} *** mode")
# logger.info(
# "processed {} sentences with {} relations; found: {} relations; correct: {}.".format(
# n_sents, n_rels, n_found, tp
# )
# )
# logger.info(
# "\tALL\t TP: {};\tFP: {};\tFN: {}".format(scores["ALL"]["tp"], scores["ALL"]["fp"], scores["ALL"]["fn"])
# )
# logger.info("\t\t(m avg): precision: {:.2f};\trecall: {:.2f};\tf1: {:.2f} (micro)".format(precision, recall, f1))
# logger.info(
# "\t\t(M avg): precision: {:.2f};\trecall: {:.2f};\tf1: {:.2f} (Macro)\n".format(
# scores["ALL"]["Macro_p"], scores["ALL"]["Macro_r"], scores["ALL"]["Macro_f1"]
# )
# )
# for rel_type in relation_types:
# logger.info(
# "\t{}: \tTP: {};\tFP: {};\tFN: {};\tprecision: {:.2f};\trecall: {:.2f};\tf1: {:.2f};\t{}".format(
# rel_type,
# scores[rel_type]["tp"],
# scores[rel_type]["fp"],
# scores[rel_type]["fn"],
# scores[rel_type]["p"],
# scores[rel_type]["r"],
# scores[rel_type]["f1"],
# scores[rel_type]["tp"] + scores[rel_type]["fp"],
# )
# )
return scores
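Editor's note: a toy check of `re_score` in "boundaries" mode (the spans are illustrative, not from the commit). With one matched relation and one missed gold relation, micro precision is 1.0, recall 0.5, and f1 about 0.67:

```python
pred = [[{"head": (0, 2), "tail": (3, 5), "type": 1}]]
gt = [[
    {"head": (0, 2), "tail": (3, 5), "type": 1},  # found by the model
    {"head": (0, 2), "tail": (6, 8), "type": 1},  # missed by the model
]]
scores = re_score(pred, gt, mode="boundaries")
print(scores["ALL"]["p"], scores["ALL"]["r"], scores["ALL"]["f1"])
# 1.0 0.5 0.666...
```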
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import sys
__dir__ = os.path.dirname(os.path.abspath(__file__))
sys.path.append(__dir__)
sys.path.append(os.path.abspath(os.path.join(__dir__, '../..')))
import random
import numpy as np
import paddle
from paddlenlp.transformers import LayoutXLMTokenizer, LayoutXLMModel, LayoutXLMForRelationExtraction
from xfun import XFUNDataset
from utils import parse_args, get_bio_label_maps, print_arguments
from data_collator import DataCollator
from metric import re_score
from ppocr.utils.logging import get_logger
def set_seed(seed):
random.seed(seed)
np.random.seed(seed)
paddle.seed(seed)
def cal_metric(re_preds, re_labels, entities):
gt_relations = []
for b in range(len(re_labels)):
rel_sent = []
for head, tail in zip(re_labels[b]["head"], re_labels[b]["tail"]):
rel = {}
rel["head_id"] = head
rel["head"] = (entities[b]["start"][rel["head_id"]],
entities[b]["end"][rel["head_id"]])
rel["head_type"] = entities[b]["label"][rel["head_id"]]
rel["tail_id"] = tail
rel["tail"] = (entities[b]["start"][rel["tail_id"]],
entities[b]["end"][rel["tail_id"]])
rel["tail_type"] = entities[b]["label"][rel["tail_id"]]
rel["type"] = 1
rel_sent.append(rel)
gt_relations.append(rel_sent)
re_metrics = re_score(re_preds, gt_relations, mode="boundaries")
return re_metrics
def evaluate(model, eval_dataloader, logger, prefix=""):
# Eval!
logger.info("***** Running evaluation {} *****".format(prefix))
logger.info(" Num examples = {}".format(len(eval_dataloader.dataset)))
re_preds = []
re_labels = []
entities = []
eval_loss = 0.0
model.eval()
for idx, batch in enumerate(eval_dataloader):
with paddle.no_grad():
outputs = model(**batch)
loss = outputs['loss'].mean().item()
if paddle.distributed.get_rank() == 0:
logger.info("[Eval] process: {}/{}, loss: {:.5f}".format(
idx, len(eval_dataloader), loss))
eval_loss += loss
re_preds.extend(outputs['pred_relations'])
re_labels.extend(batch['relations'])
entities.extend(batch['entities'])
re_metrics = cal_metric(re_preds, re_labels, entities)
re_metrics = {
"precision": re_metrics["ALL"]["p"],
"recall": re_metrics["ALL"]["r"],
"f1": re_metrics["ALL"]["f1"],
}
model.train()
return re_metrics
def train(args):
logger = get_logger(log_file=os.path.join(args.output_dir, "train.log"))
print_arguments(args, logger)
# Added here for reproducibility (even between python 2 and 3)
set_seed(args.seed)
label2id_map, id2label_map = get_bio_label_maps(args.label_map_path)
pad_token_label_id = paddle.nn.CrossEntropyLoss().ignore_index
# dist mode
if paddle.distributed.get_world_size() > 1:
paddle.distributed.init_parallel_env()
tokenizer = LayoutXLMTokenizer.from_pretrained(args.model_name_or_path)
model = LayoutXLMModel.from_pretrained(args.model_name_or_path)
model = LayoutXLMForRelationExtraction(model, dropout=None)
# dist mode
if paddle.distributed.get_world_size() > 1:
model = paddle.distributed.DataParallel(model)
train_dataset = XFUNDataset(
tokenizer,
data_dir=args.train_data_dir,
label_path=args.train_label_path,
label2id_map=label2id_map,
img_size=(224, 224),
max_seq_len=args.max_seq_length,
pad_token_label_id=pad_token_label_id,
contains_re=True,
add_special_ids=False,
return_attention_mask=True,
load_mode='all')
eval_dataset = XFUNDataset(
tokenizer,
data_dir=args.eval_data_dir,
label_path=args.eval_label_path,
label2id_map=label2id_map,
img_size=(224, 224),
max_seq_len=args.max_seq_length,
pad_token_label_id=pad_token_label_id,
contains_re=True,
add_special_ids=False,
return_attention_mask=True,
load_mode='all')
train_sampler = paddle.io.DistributedBatchSampler(
train_dataset, batch_size=args.per_gpu_train_batch_size, shuffle=True)
args.train_batch_size = args.per_gpu_train_batch_size * \
max(1, paddle.distributed.get_world_size())
train_dataloader = paddle.io.DataLoader(
train_dataset,
batch_sampler=train_sampler,
num_workers=8,
use_shared_memory=True,
collate_fn=DataCollator())
eval_dataloader = paddle.io.DataLoader(
eval_dataset,
batch_size=args.per_gpu_eval_batch_size,
num_workers=8,
shuffle=False,
collate_fn=DataCollator())
t_total = len(train_dataloader) * args.num_train_epochs
# build linear decay with warmup lr sch
lr_scheduler = paddle.optimizer.lr.PolynomialDecay(
learning_rate=args.learning_rate,
decay_steps=t_total,
end_lr=0.0,
power=1.0)
if args.warmup_steps > 0:
lr_scheduler = paddle.optimizer.lr.LinearWarmup(
lr_scheduler,
args.warmup_steps,
start_lr=0,
end_lr=args.learning_rate, )
grad_clip = paddle.nn.ClipGradByNorm(clip_norm=10)
optimizer = paddle.optimizer.Adam(
learning_rate=args.learning_rate,
parameters=model.parameters(),
epsilon=args.adam_epsilon,
grad_clip=grad_clip,
weight_decay=args.weight_decay)
# Train!
logger.info("***** Running training *****")
logger.info(" Num examples = {}".format(len(train_dataset)))
logger.info(" Num Epochs = {}".format(args.num_train_epochs))
logger.info(" Instantaneous batch size per GPU = {}".format(
args.per_gpu_train_batch_size))
logger.info(
" Total train batch size (w. parallel, distributed & accumulation) = {}".
format(args.train_batch_size * paddle.distributed.get_world_size()))
logger.info(" Total optimization steps = {}".format(t_total))
global_step = 0
model.clear_gradients()
train_dataloader_len = len(train_dataloader)
best_metirc = {'f1': 0}
model.train()
for epoch in range(int(args.num_train_epochs)):
for step, batch in enumerate(train_dataloader):
outputs = model(**batch)
# model outputs are always tuple in ppnlp (see doc)
loss = outputs['loss']
loss = loss.mean()
logger.info(
"epoch: [{}/{}], iter: [{}/{}], global_step:{}, train loss: {}, lr: {}".
format(epoch, args.num_train_epochs, step, train_dataloader_len,
global_step, np.mean(loss.numpy()), optimizer.get_lr()))
loss.backward()
optimizer.step()
optimizer.clear_grad()
# lr_scheduler.step() # Update learning rate schedule
global_step += 1
if (paddle.distributed.get_rank() == 0 and args.eval_steps > 0 and
global_step % args.eval_steps == 0):
# Log metrics
if (paddle.distributed.get_rank() == 0 and args.
evaluate_during_training): # Only evaluate when single GPU otherwise metrics may not average well
results = evaluate(model, eval_dataloader, logger)
if results['f1'] > best_metirc['f1']:
best_metirc = results
output_dir = os.path.join(args.output_dir,
"checkpoint-best")
os.makedirs(output_dir, exist_ok=True)
model.save_pretrained(output_dir)
tokenizer.save_pretrained(output_dir)
paddle.save(args,
os.path.join(output_dir,
"training_args.bin"))
logger.info("Saving model checkpoint to {}".format(
output_dir))
logger.info("eval results: {}".format(results))
logger.info("best_metirc: {}".format(best_metirc))
if (paddle.distributed.get_rank() == 0 and args.save_steps > 0 and
global_step % args.save_steps == 0):
# Save model checkpoint
output_dir = os.path.join(args.output_dir, "checkpoint-latest")
os.makedirs(output_dir, exist_ok=True)
if paddle.distributed.get_rank() == 0:
model.save_pretrained(output_dir)
tokenizer.save_pretrained(output_dir)
paddle.save(args,
os.path.join(output_dir, "training_args.bin"))
logger.info("Saving model checkpoint to {}".format(
output_dir))
logger.info("best_metirc: {}".format(best_metirc))
if __name__ == "__main__":
args = parse_args()
os.makedirs(args.output_dir, exist_ok=True)
train(args)
@@ -12,8 +12,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.

-import sys
import os
+import sys
+
+__dir__ = os.path.dirname(os.path.abspath(__file__))
+sys.path.append(__dir__)
+sys.path.append(os.path.abspath(os.path.join(__dir__, '../..')))
+
import random
import copy
import logging
@@ -26,8 +31,9 @@ from paddlenlp.transformers import LayoutXLMModel, LayoutXLMTokenizer, LayoutXLM
from xfun import XFUNDataset
from utils import parse_args
from utils import get_bio_label_maps
+from utils import print_arguments

-logger = logging.getLogger(__name__)
+from ppocr.utils.logging import get_logger


def set_seed(args):
@@ -38,17 +44,8 @@ def set_seed(args):

def train(args):
    os.makedirs(args.output_dir, exist_ok=True)
-    logging.basicConfig(
-        filename=os.path.join(args.output_dir, "train.log")
-        if paddle.distributed.get_rank() == 0 else None,
-        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
-        datefmt="%m/%d/%Y %H:%M:%S",
-        level=logging.INFO
-        if paddle.distributed.get_rank() == 0 else logging.WARN, )
-
-    ch = logging.StreamHandler()
-    ch.setLevel(logging.DEBUG)
-    logger.addHandler(ch)
+    logger = get_logger(log_file=os.path.join(args.output_dir, "train.log"))
+    print_arguments(args, logger)

    label2id_map, id2label_map = get_bio_label_maps(args.label_map_path)
    pad_token_label_id = paddle.nn.CrossEntropyLoss().ignore_index
@@ -136,10 +133,10 @@ def train(args):
            loss = outputs[0]
            loss = loss.mean()
            logger.info(
-                "[epoch {}/{}][iter: {}/{}] lr: {:.5f}, train loss: {:.5f}, ".
+                "epoch: [{}/{}], iter: [{}/{}], global_step:{}, train loss: {}, lr: {}".
                format(epoch_id, args.num_train_epochs, step,
-                       len(train_dataloader),
-                       lr_scheduler.get_lr(), loss.numpy()[0]))
+                       len(train_dataloader), global_step,
+                       loss.numpy()[0], lr_scheduler.get_lr()))
            loss.backward()
            tr_loss += loss.item()
@@ -154,13 +151,9 @@
                # Only evaluate when single GPU otherwise metrics may not average well
                if paddle.distributed.get_rank(
                ) == 0 and args.evaluate_during_training:
-                    results, _ = evaluate(
-                        args,
-                        model,
-                        tokenizer,
-                        label2id_map,
-                        id2label_map,
-                        pad_token_label_id, )
+                    results, _ = evaluate(args, model, tokenizer, label2id_map,
+                                          id2label_map, pad_token_label_id,
+                                          logger)

                    if best_metrics is None or results["f1"] >= best_metrics[
                            "f1"]:
@@ -204,6 +197,7 @@ def evaluate(args,
             label2id_map,
             id2label_map,
             pad_token_label_id,
+             logger,
             prefix=""):
    eval_dataset = XFUNDataset(
        tokenizer,
@@ -299,15 +293,6 @@ def evaluate(args,
    return results, preds_list


-def print_arguments(args):
-    """print arguments"""
-    print('-----------  Configuration Arguments -----------')
-    for arg, value in sorted(vars(args).items()):
-        print('%s: %s' % (arg, value))
-    print('------------------------------------------------')
-
-
if __name__ == "__main__":
    args = parse_args()
-    print_arguments(args)
    train(args)
...@@ -24,8 +24,6 @@ import paddle ...@@ -24,8 +24,6 @@ import paddle
from PIL import Image, ImageDraw, ImageFont from PIL import Image, ImageDraw, ImageFont
from paddleocr import PaddleOCR
def get_bio_label_maps(label_map_path): def get_bio_label_maps(label_map_path):
with open(label_map_path, "r") as fin: with open(label_map_path, "r") as fin:
...@@ -66,9 +64,9 @@ def get_image_file_list(img_file): ...@@ -66,9 +64,9 @@ def get_image_file_list(img_file):
def draw_ser_results(image, def draw_ser_results(image,
ocr_results, ocr_results,
font_path="../doc/fonts/simfang.ttf", font_path="../../doc/fonts/simfang.ttf",
font_size=18): font_size=18):
np.random.seed(0) np.random.seed(2021)
color = (np.random.permutation(range(255)), color = (np.random.permutation(range(255)),
np.random.permutation(range(255)), np.random.permutation(range(255)),
np.random.permutation(range(255))) np.random.permutation(range(255)))
...@@ -82,38 +80,64 @@ def draw_ser_results(image, ...@@ -82,38 +80,64 @@ def draw_ser_results(image,
draw = ImageDraw.Draw(img_new) draw = ImageDraw.Draw(img_new)
font = ImageFont.truetype(font_path, font_size, encoding="utf-8") font = ImageFont.truetype(font_path, font_size, encoding="utf-8")
for ocr_info in ocr_results: for ocr_info in ocr_results:
if ocr_info["pred_id"] not in color_map: if ocr_info["pred_id"] not in color_map:
continue continue
color = color_map[ocr_info["pred_id"]] color = color_map[ocr_info["pred_id"]]
# draw ocr results outline
bbox = ocr_info["bbox"]
bbox = ((bbox[0], bbox[1]), (bbox[2], bbox[3]))
draw.rectangle(bbox, fill=color)
# draw ocr results
text = "{}: {}".format(ocr_info["pred"], ocr_info["text"]) text = "{}: {}".format(ocr_info["pred"], ocr_info["text"])
start_y = max(0, bbox[0][1] - font_size)
tw = font.getsize(text)[0] draw_box_txt(ocr_info["bbox"], text, draw, font, font_size, color)
draw.rectangle(
[(bbox[0][0] + 1, start_y), (bbox[0][0] + tw + 1,
start_y + font_size)],
fill=(0, 0, 255))
draw.text(
(bbox[0][0] + 1, start_y), text, fill=(255, 255, 255), font=font)
img_new = Image.blend(image, img_new, 0.5) img_new = Image.blend(image, img_new, 0.5)
return np.array(img_new) return np.array(img_new)
def build_ocr_engine(rec_model_dir, det_model_dir): def draw_box_txt(bbox, text, draw, font, font_size, color):
ocr_engine = PaddleOCR( # draw ocr results outline
rec_model_dir=rec_model_dir, bbox = ((bbox[0], bbox[1]), (bbox[2], bbox[3]))
det_model_dir=det_model_dir, draw.rectangle(bbox, fill=color)
use_angle_cls=False)
return ocr_engine # draw ocr results
start_y = max(0, bbox[0][1] - font_size)
tw = font.getsize(text)[0]
draw.rectangle(
[(bbox[0][0] + 1, start_y), (bbox[0][0] + tw + 1, start_y + font_size)],
fill=(0, 0, 255))
draw.text((bbox[0][0] + 1, start_y), text, fill=(255, 255, 255), font=font)
def draw_re_results(image,
result,
font_path="../../doc/fonts/simfang.ttf",
font_size=18):
np.random.seed(0)
if isinstance(image, np.ndarray):
image = Image.fromarray(image)
img_new = image.copy()
draw = ImageDraw.Draw(img_new)
font = ImageFont.truetype(font_path, font_size, encoding="utf-8")
color_head = (0, 0, 255)
color_tail = (255, 0, 0)
color_line = (0, 255, 0)
for ocr_info_head, ocr_info_tail in result:
draw_box_txt(ocr_info_head["bbox"], ocr_info_head["text"], draw, font,
font_size, color_head)
draw_box_txt(ocr_info_tail["bbox"], ocr_info_tail["text"], draw, font,
font_size, color_tail)
center_head = (
(ocr_info_head['bbox'][0] + ocr_info_head['bbox'][2]) // 2,
(ocr_info_head['bbox'][1] + ocr_info_head['bbox'][3]) // 2)
center_tail = (
(ocr_info_tail['bbox'][0] + ocr_info_tail['bbox'][2]) // 2,
(ocr_info_tail['bbox'][1] + ocr_info_tail['bbox'][3]) // 2)
draw.line([center_head, center_tail], fill=color_line, width=5)
img_new = Image.blend(image, img_new, 0.5)
return np.array(img_new)
# pad sentences # pad sentences
...@@ -130,7 +154,7 @@ def pad_sentences(tokenizer, ...@@ -130,7 +154,7 @@ def pad_sentences(tokenizer,
len(encoded_inputs["input_ids"]) // max_seq_len + 1) * max_seq_len len(encoded_inputs["input_ids"]) // max_seq_len + 1) * max_seq_len
needs_to_be_padded = pad_to_max_seq_len and \ needs_to_be_padded = pad_to_max_seq_len and \
max_seq_len and len(encoded_inputs["input_ids"]) < max_seq_len max_seq_len and len(encoded_inputs["input_ids"]) < max_seq_len
if needs_to_be_padded: if needs_to_be_padded:
difference = max_seq_len - len(encoded_inputs["input_ids"]) difference = max_seq_len - len(encoded_inputs["input_ids"])
...@@ -162,6 +186,9 @@ def split_page(encoded_inputs, max_seq_len=512): ...@@ -162,6 +186,9 @@ def split_page(encoded_inputs, max_seq_len=512):
truncate is often used in training process truncate is often used in training process
""" """
for key in encoded_inputs: for key in encoded_inputs:
if key == 'entities':
encoded_inputs[key] = [encoded_inputs[key]]
continue
encoded_inputs[key] = paddle.to_tensor(encoded_inputs[key]) encoded_inputs[key] = paddle.to_tensor(encoded_inputs[key])
if encoded_inputs[key].ndim <= 1: # for input_ids, att_mask and so on if encoded_inputs[key].ndim <= 1: # for input_ids, att_mask and so on
encoded_inputs[key] = encoded_inputs[key].reshape([-1, max_seq_len]) encoded_inputs[key] = encoded_inputs[key].reshape([-1, max_seq_len])
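The change above exempts the new entities key from tensor conversion: entity spans are per-segment bookkeeping for the RE head, not per-token features, so they are wrapped in a list and passed through untouched, while every other key is paged by a plain reshape. A toy illustration of that arithmetic (values invented; assumes paddle is installed):

import paddle

encoded = {"input_ids": list(range(1024))}  # already padded to 2 * 512
max_seq_len = 512
ids = paddle.to_tensor(encoded["input_ids"]).reshape([-1, max_seq_len])
print(ids.shape)  # [2, 512] -> two "pages" fed to the model in turn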
...@@ -184,14 +211,14 @@ def preprocess( ...@@ -184,14 +211,14 @@ def preprocess(
height = ori_img.shape[0] height = ori_img.shape[0]
width = ori_img.shape[1] width = ori_img.shape[1]
img = cv2.resize(ori_img, img = cv2.resize(ori_img, img_size).transpose([2, 0, 1]).astype(np.float32)
(224, 224)).transpose([2, 0, 1]).astype(np.float32)
segment_offset_id = [] segment_offset_id = []
words_list = [] words_list = []
bbox_list = [] bbox_list = []
input_ids_list = [] input_ids_list = []
token_type_ids_list = [] token_type_ids_list = []
entities = []
for info in ocr_info: for info in ocr_info:
# x1, y1, x2, y2 # x1, y1, x2, y2
...@@ -211,6 +238,13 @@ def preprocess( ...@@ -211,6 +238,13 @@ def preprocess(
encode_res["token_type_ids"] = encode_res["token_type_ids"][1:-1] encode_res["token_type_ids"] = encode_res["token_type_ids"][1:-1]
encode_res["attention_mask"] = encode_res["attention_mask"][1:-1] encode_res["attention_mask"] = encode_res["attention_mask"][1:-1]
# for re
entities.append({
"start": len(input_ids_list),
"end": len(input_ids_list) + len(encode_res["input_ids"]),
"label": "O",
})
input_ids_list.extend(encode_res["input_ids"]) input_ids_list.extend(encode_res["input_ids"])
token_type_ids_list.extend(encode_res["token_type_ids"]) token_type_ids_list.extend(encode_res["token_type_ids"])
bbox_list.extend([bbox] * len(encode_res["input_ids"])) bbox_list.extend([bbox] * len(encode_res["input_ids"]))
...@@ -222,6 +256,7 @@ def preprocess( ...@@ -222,6 +256,7 @@ def preprocess(
"token_type_ids": token_type_ids_list, "token_type_ids": token_type_ids_list,
"bbox": bbox_list, "bbox": bbox_list,
"attention_mask": [1] * len(input_ids_list), "attention_mask": [1] * len(input_ids_list),
"entities": entities
} }
encoded_inputs = pad_sentences( encoded_inputs = pad_sentences(
...@@ -294,35 +329,64 @@ def merge_preds_list_with_ocr_info(ocr_info, segment_offset_id, preds_list, ...@@ -294,35 +329,64 @@ def merge_preds_list_with_ocr_info(ocr_info, segment_offset_id, preds_list,
return ocr_info return ocr_info
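The preprocess() additions above thread an entities list through to the output: for every OCR box, the half-open token range it occupies in the flattened sequence is recorded with a placeholder "O" label, which is how the RE model later locates candidate entities. Reduced to the bare bookkeeping (token ids invented):

input_ids_list, entities = [], []
for tokens in [[101, 7592], [2088, 999, 102]]:  # fake ids per OCR segment
    entities.append({
        "start": len(input_ids_list),              # first token of this segment
        "end": len(input_ids_list) + len(tokens),  # one past its last token
        "label": "O",
    })
    input_ids_list.extend(tokens)
print(entities)
# [{'start': 0, 'end': 2, 'label': 'O'}, {'start': 2, 'end': 5, 'label': 'O'}]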
def print_arguments(args, logger=None):
    """print arguments"""
    print_func = logger.info if logger is not None else print
print_func('----------- Configuration Arguments -----------')
for arg, value in sorted(vars(args).items()):
print_func('%s: %s' % (arg, value))
print_func('------------------------------------------------')
def parse_args(): def parse_args():
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
# Required parameters # Required parameters
# yapf: disable # yapf: disable
parser.add_argument("--model_name_or_path", default=None, type=str, required=True,) parser.add_argument("--model_name_or_path",
parser.add_argument("--train_data_dir", default=None, type=str, required=False,) default=None, type=str, required=True,)
parser.add_argument("--train_label_path", default=None, type=str, required=False,) parser.add_argument("--re_model_name_or_path",
parser.add_argument("--eval_data_dir", default=None, type=str, required=False,) default=None, type=str, required=False,)
parser.add_argument("--eval_label_path", default=None, type=str, required=False,) parser.add_argument("--train_data_dir", default=None,
type=str, required=False,)
parser.add_argument("--train_label_path", default=None,
type=str, required=False,)
parser.add_argument("--eval_data_dir", default=None,
type=str, required=False,)
parser.add_argument("--eval_label_path", default=None,
type=str, required=False,)
parser.add_argument("--output_dir", default=None, type=str, required=True,) parser.add_argument("--output_dir", default=None, type=str, required=True,)
parser.add_argument("--max_seq_length", default=512, type=int,) parser.add_argument("--max_seq_length", default=512, type=int,)
parser.add_argument("--evaluate_during_training", action="store_true",) parser.add_argument("--evaluate_during_training", action="store_true",)
parser.add_argument("--per_gpu_train_batch_size", default=8, type=int, help="Batch size per GPU/CPU for training.",) parser.add_argument("--per_gpu_train_batch_size", default=8,
parser.add_argument("--per_gpu_eval_batch_size", default=8, type=int, help="Batch size per GPU/CPU for eval.",) type=int, help="Batch size per GPU/CPU for training.",)
parser.add_argument("--learning_rate", default=5e-5, type=float, help="The initial learning rate for Adam.",) parser.add_argument("--per_gpu_eval_batch_size", default=8,
parser.add_argument("--weight_decay", default=0.0, type=float, help="Weight decay if we apply some.",) type=int, help="Batch size per GPU/CPU for eval.",)
parser.add_argument("--adam_epsilon", default=1e-8, type=float, help="Epsilon for Adam optimizer.",) parser.add_argument("--learning_rate", default=5e-5,
parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.",) type=float, help="The initial learning rate for Adam.",)
parser.add_argument("--num_train_epochs", default=3, type=int, help="Total number of training epochs to perform.",) parser.add_argument("--weight_decay", default=0.0,
parser.add_argument("--warmup_steps", default=0, type=int, help="Linear warmup over warmup_steps.",) type=float, help="Weight decay if we apply some.",)
parser.add_argument("--eval_steps", type=int, default=10, help="eval every X updates steps.",) parser.add_argument("--adam_epsilon", default=1e-8,
parser.add_argument("--save_steps", type=int, default=50, help="Save checkpoint every X updates steps.",) type=float, help="Epsilon for Adam optimizer.",)
parser.add_argument("--seed", type=int, default=2048, help="random seed for initialization",) parser.add_argument("--max_grad_norm", default=1.0,
type=float, help="Max gradient norm.",)
parser.add_argument("--num_train_epochs", default=3, type=int,
help="Total number of training epochs to perform.",)
parser.add_argument("--warmup_steps", default=0, type=int,
help="Linear warmup over warmup_steps.",)
parser.add_argument("--eval_steps", type=int, default=10,
help="eval every X updates steps.",)
parser.add_argument("--save_steps", type=int, default=50,
help="Save checkpoint every X updates steps.",)
parser.add_argument("--seed", type=int, default=2048,
help="random seed for initialization",)
parser.add_argument("--ocr_rec_model_dir", default=None, type=str, ) parser.add_argument("--ocr_rec_model_dir", default=None, type=str, )
parser.add_argument("--ocr_det_model_dir", default=None, type=str, ) parser.add_argument("--ocr_det_model_dir", default=None, type=str, )
parser.add_argument("--label_map_path", default="./labels/labels_ser.txt", type=str, required=False, ) parser.add_argument(
"--label_map_path", default="./labels/labels_ser.txt", type=str, required=False, )
parser.add_argument("--infer_imgs", default=None, type=str, required=False) parser.add_argument("--infer_imgs", default=None, type=str, required=False)
parser.add_argument("--ocr_json_path", default=None, type=str, required=False, help="ocr prediction results") parser.add_argument("--ocr_json_path", default=None,
type=str, required=False, help="ocr prediction results")
# yapf: enable # yapf: enable
args = parser.parse_args() args = parser.parse_args()
return args return args
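A typical entry point wires the two helpers above together so that every run logs its full configuration first; this wiring is not shown in the patch and is sketched here for completeness:

if __name__ == "__main__":
    args = parse_args()
    print_arguments(args)  # log every flag before training or inference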
===========================train_params=========================== ===========================train_params===========================
model_name:PPOCRv2_ocr_det model_name:PPOCRv2_det
python:python3.7 python:python3.7
gpu_list:0|0,1 gpu_list:0|0,1
Global.use_gpu:True|True Global.use_gpu:True|True
...@@ -26,7 +26,7 @@ null:null ...@@ -26,7 +26,7 @@ null:null
## ##
===========================infer_params=========================== ===========================infer_params===========================
Global.save_inference_dir:./output/ Global.save_inference_dir:./output/
Global.pretrained_model: Global.checkpoints:
norm_export:null norm_export:null
quant_export:deploy/slim/quantization/export_model.py -c configs/det/ch_PP-OCRv2/ch_PP-OCRv2_det_cml.yml -o quant_export:deploy/slim/quantization/export_model.py -c configs/det/ch_PP-OCRv2/ch_PP-OCRv2_det_cml.yml -o
fpgm_export: fpgm_export:
......
...@@ -6,7 +6,7 @@ Global.use_gpu:True|True ...@@ -6,7 +6,7 @@ Global.use_gpu:True|True
Global.auto_cast:fp32 Global.auto_cast:fp32
Global.epoch_num:lite_train_lite_infer=3|whole_train_whole_infer=300 Global.epoch_num:lite_train_lite_infer=3|whole_train_whole_infer=300
Global.save_model_dir:./output/ Global.save_model_dir:./output/
Train.loader.batch_size_per_card:lite_train_lite_infer=128|whole_train_whole_infer=128 Train.loader.batch_size_per_card:lite_train_lite_infer=16|whole_train_whole_infer=128
Global.pretrained_model:null Global.pretrained_model:null
train_model_name:latest train_model_name:latest
train_infer_img_dir:./inference/rec_inference train_infer_img_dir:./inference/rec_inference
...@@ -34,7 +34,7 @@ distill_export:null ...@@ -34,7 +34,7 @@ distill_export:null
export1:null export1:null
export2:null export2:null
inference_dir:Student inference_dir:Student
infer_model:./inference/ch_PP-OCRv2_rec_infer/ infer_model:./inference/ch_PP-OCRv2_rec_infer
infer_export:null infer_export:null
infer_quant:False infer_quant:False
inference:tools/infer/predict_rec.py inference:tools/infer/predict_rec.py
...@@ -45,7 +45,7 @@ inference:tools/infer/predict_rec.py ...@@ -45,7 +45,7 @@ inference:tools/infer/predict_rec.py
--use_tensorrt:False|True --use_tensorrt:False|True
--precision:fp32|fp16|int8 --precision:fp32|fp16|int8
--rec_model_dir: --rec_model_dir:
--image_dir:/inference/rec_inference --image_dir:./inference/rec_inference
null:null null:null
--benchmark:True --benchmark:True
null:null null:null
......
...@@ -6,15 +6,15 @@ Global.use_gpu:True|True ...@@ -6,15 +6,15 @@ Global.use_gpu:True|True
Global.auto_cast:fp32 Global.auto_cast:fp32
Global.epoch_num:lite_train_lite_infer=3|whole_train_whole_infer=300 Global.epoch_num:lite_train_lite_infer=3|whole_train_whole_infer=300
Global.save_model_dir:./output/ Global.save_model_dir:./output/
Train.loader.batch_size_per_card:lite_train_lite_infer=128|whole_train_whole_infer=128 Train.loader.batch_size_per_card:lite_train_lite_infer=16|whole_train_whole_infer=128
Global.pretrained_model:null Global.pretrained_model:null
train_model_name:latest train_model_name:latest
train_infer_img_dir:./inference/rec_inference train_infer_img_dir:./inference/rec_inference
null:null null:null
## ##
trainer:pact_train trainer:pact_train
norm_train:deploy/slim/quantization/quant.py -c test_tipc/configs/ch_PP-OCRv2_rec/ch_PP-OCRv2_rec_distillation.yml -o norm_train:null
pact_train:null pact_train:deploy/slim/quantization/quant.py -c test_tipc/configs/ch_PP-OCRv2_rec/ch_PP-OCRv2_rec_distillation.yml -o
fpgm_train:null fpgm_train:null
distill_train:null distill_train:null
null:null null:null
...@@ -27,14 +27,14 @@ null:null ...@@ -27,14 +27,14 @@ null:null
===========================infer_params=========================== ===========================infer_params===========================
Global.save_inference_dir:./output/ Global.save_inference_dir:./output/
Global.pretrained_model: Global.pretrained_model:
norm_export:deploy/slim/quantization/export_model.py -c test_tipc/configs/ch_PP-OCRv2_rec/ch_PP-OCRv2_rec_distillation.yml -o norm_export:null
quant_export: quant_export:deploy/slim/quantization/export_model.py -c test_tipc/configs/ch_PP-OCRv2_rec/ch_PP-OCRv2_rec_distillation.yml -o
fpgm_export: fpgm_export: null
distill_export:null distill_export:null
export1:null export1:null
export2:null export2:null
inference_dir:Student inference_dir:Student
infer_model:./inference/ch_PP-OCRv2_rec_infer/ infer_model:./inference/ch_PP-OCRv2_rec_infer
infer_export:null infer_export:null
infer_quant:True infer_quant:True
inference:tools/infer/predict_rec.py inference:tools/infer/predict_rec.py
...@@ -45,7 +45,7 @@ inference:tools/infer/predict_rec.py ...@@ -45,7 +45,7 @@ inference:tools/infer/predict_rec.py
--use_tensorrt:False|True --use_tensorrt:False|True
--precision:fp32|fp16|int8 --precision:fp32|fp16|int8
--rec_model_dir: --rec_model_dir:
--image_dir:/inference/rec_inference --image_dir:./inference/rec_inference
null:null null:null
--benchmark:True --benchmark:True
null:null null:null
......
...@@ -4,7 +4,7 @@ python:python3.7 ...@@ -4,7 +4,7 @@ python:python3.7
gpu_list:0|0,1 gpu_list:0|0,1
Global.use_gpu:True|True Global.use_gpu:True|True
Global.auto_cast:null Global.auto_cast:null
Global.epoch_num:lite_train_lite_infer=5|whole_train_whole_infer=300 Global.epoch_num:lite_train_lite_infer=100|whole_train_whole_infer=300
Global.save_model_dir:./output/ Global.save_model_dir:./output/
Train.loader.batch_size_per_card:lite_train_lite_infer=2|whole_train_whole_infer=4 Train.loader.batch_size_per_card:lite_train_lite_infer=2|whole_train_whole_infer=4
Global.pretrained_model:null Global.pretrained_model:null
......
...@@ -4,7 +4,7 @@ python:python3.7 ...@@ -4,7 +4,7 @@ python:python3.7
gpu_list:0|0,1 gpu_list:0|0,1
Global.use_gpu:True|True Global.use_gpu:True|True
Global.auto_cast:null Global.auto_cast:null
Global.epoch_num:lite_train_lite_infer=5|whole_train_whole_infer=300 Global.epoch_num:lite_train_lite_infer=20|whole_train_whole_infer=300
Global.save_model_dir:./output/ Global.save_model_dir:./output/
Train.loader.batch_size_per_card:lite_train_lite_infer=2|whole_train_whole_infer=4 Train.loader.batch_size_per_card:lite_train_lite_infer=2|whole_train_whole_infer=4
Global.pretrained_model:null Global.pretrained_model:null
...@@ -26,7 +26,7 @@ null:null ...@@ -26,7 +26,7 @@ null:null
## ##
===========================infer_params=========================== ===========================infer_params===========================
Global.save_inference_dir:./output/ Global.save_inference_dir:./output/
Global.pretrained_model: Global.checkpoints:
norm_export:null norm_export:null
quant_export:deploy/slim/quantization/export_model.py -c configs/det/ch_ppocr_v2.0/ch_det_mv3_db_v2.0.yml -o quant_export:deploy/slim/quantization/export_model.py -c configs/det/ch_ppocr_v2.0/ch_det_mv3_db_v2.0.yml -o
fpgm_export:null fpgm_export:null
...@@ -48,4 +48,4 @@ inference:tools/infer/predict_det.py ...@@ -48,4 +48,4 @@ inference:tools/infer/predict_det.py
--image_dir:./inference/ch_det_data_50/all-sum-510/ --image_dir:./inference/ch_det_data_50/all-sum-510/
null:null null:null
--benchmark:True --benchmark:True
null:null null:null
\ No newline at end of file
...@@ -28,7 +28,7 @@ null:null ...@@ -28,7 +28,7 @@ null:null
Global.save_inference_dir:./output/ Global.save_inference_dir:./output/
Global.checkpoints: Global.checkpoints:
norm_export:null norm_export:null
quant_export:deploy/slim/quantization/export_model.py -ctest_tipc/configs/ch_ppocr_mobile_v2.0_rec_PACT/rec_chinese_lite_train_v2.0.yml -o quant_export:deploy/slim/quantization/export_model.py -c test_tipc/configs/ch_ppocr_mobile_v2.0_rec_PACT/rec_chinese_lite_train_v2.0.yml -o
fpgm_export:null fpgm_export:null
distill_export:null distill_export:null
export1:null export1:null
......
...@@ -12,22 +12,22 @@ train_model_name:latest ...@@ -12,22 +12,22 @@ train_model_name:latest
train_infer_img_dir:./train_data/icdar2015/text_localization/ch4_test_images/ train_infer_img_dir:./train_data/icdar2015/text_localization/ch4_test_images/
null:null null:null
## ##
trainer:norm_train|pact_train|fpgm_export trainer:norm_train
norm_train:tools/train.py -c test_tipc/configs/ppocr_det_server/det_r50_vd_db.yml -o norm_train:tools/train.py -c test_tipc/configs/ch_ppocr_server_v2.0_det/det_r50_vd_db.yml -o
quant_export:deploy/slim/quantization/export_model.py -c test_tipc/configs/ppocr_det_server/det_r50_vd_db.yml -o quant_train:null
fpgm_export:deploy/slim/prune/export_prune_model.py -c test_tipc/configs/ppocr_det_server/det_r50_vd_db.yml -o fpgm_train:null
distill_train:null distill_train:null
null:null null:null
null:null null:null
## ##
===========================eval_params=========================== ===========================eval_params===========================
eval:tools/eval.py -c test_tipc/configs/ppocr_det_server/det_r50_vd_db.yml -o eval:tools/eval.py -c test_tipc/configs/ch_ppocr_server_v2.0_det/det_r50_vd_db.yml -o
null:null null:null
## ##
===========================infer_params=========================== ===========================infer_params===========================
Global.save_inference_dir:./output/ Global.save_inference_dir:./output/
Global.pretrained_model: Global.pretrained_model:
norm_export:tools/export_model.py -c test_tipc/configs/ppocr_det_server/det_r50_vd_db.yml -o norm_export:tools/export_model.py -c test_tipc/configs/ch_ppocr_server_v2.0_det/det_r50_vd_db.yml -o
quant_export:null quant_export:null
fpgm_export:null fpgm_export:null
distill_export:null distill_export:null
......
...@@ -35,7 +35,7 @@ export1:null ...@@ -35,7 +35,7 @@ export1:null
export2:null export2:null
## ##
train_model:./inference/det_r50_vd_pse/best_accuracy train_model:./inference/det_r50_vd_pse/best_accuracy
infer_export:tools/export_model.py -c test_tipc/cconfigs/det_r50_vd_pse_v2.0/det_r50_vd_pse.yml -o infer_export:tools/export_model.py -c test_tipc/configs/det_r50_vd_pse_v2.0/det_r50_vd_pse.yml -o
infer_quant:False infer_quant:False
inference:tools/infer/predict_det.py inference:tools/infer/predict_det.py
--use_gpu:True|False --use_gpu:True|False
......
...@@ -62,7 +62,7 @@ Train: ...@@ -62,7 +62,7 @@ Train:
data_dir: ./train_data/icdar2015/text_localization/ data_dir: ./train_data/icdar2015/text_localization/
label_file_list: label_file_list:
- ./train_data/icdar2015/text_localization/train_icdar2015_label.txt - ./train_data/icdar2015/text_localization/train_icdar2015_label.txt
ratio_list: [0.1, 0.45, 0.3, 0.15] ratio_list: [1.0]
transforms: transforms:
- DecodeImage: # load image - DecodeImage: # load image
img_mode: BGR img_mode: BGR
......
...@@ -48,4 +48,4 @@ inference:tools/infer/predict_det.py ...@@ -48,4 +48,4 @@ inference:tools/infer/predict_det.py
--image_dir:./inference/ch_det_data_50/all-sum-510/ --image_dir:./inference/ch_det_data_50/all-sum-510/
null:null null:null
--benchmark:True --benchmark:True
null:null --det_algorithm:SAST
...@@ -48,4 +48,4 @@ inference:tools/infer/predict_det.py ...@@ -48,4 +48,4 @@ inference:tools/infer/predict_det.py
--image_dir:./inference/ch_det_data_50/all-sum-510/ --image_dir:./inference/ch_det_data_50/all-sum-510/
null:null null:null
--benchmark:True --benchmark:True
null:null --det_algorithm:SAST
...@@ -4,7 +4,7 @@ python:python3.7 ...@@ -4,7 +4,7 @@ python:python3.7
gpu_list:0|0,1 gpu_list:0|0,1
Global.use_gpu:True|True Global.use_gpu:True|True
Global.auto_cast:null Global.auto_cast:null
Global.epoch_num:lite_train_lite_infer=1|whole_train_whole_infer=500 Global.epoch_num:lite_train_lite_infer=5|whole_train_whole_infer=500
Global.save_model_dir:./output/ Global.save_model_dir:./output/
Train.loader.batch_size_per_card:lite_train_lite_infer=2|whole_train_whole_infer=14 Train.loader.batch_size_per_card:lite_train_lite_infer=2|whole_train_whole_infer=14
Global.pretrained_model:null Global.pretrained_model:null
...@@ -42,7 +42,7 @@ inference:tools/infer/predict_e2e.py ...@@ -42,7 +42,7 @@ inference:tools/infer/predict_e2e.py
--enable_mkldnn:True|False --enable_mkldnn:True|False
--cpu_threads:1|6 --cpu_threads:1|6
--rec_batch_num:1 --rec_batch_num:1
--use_tensorrt:False|True --use_tensorrt:False
--precision:fp32|fp16|int8 --precision:fp32|fp16|int8
--e2e_model_dir: --e2e_model_dir:
--image_dir:./inference/ch_det_data_50/all-sum-510/ --image_dir:./inference/ch_det_data_50/all-sum-510/
......
...@@ -6,7 +6,7 @@ Global.use_gpu:True|True ...@@ -6,7 +6,7 @@ Global.use_gpu:True|True
Global.auto_cast:null Global.auto_cast:null
Global.epoch_num:lite_train_lite_infer=2|whole_train_whole_infer=300 Global.epoch_num:lite_train_lite_infer=2|whole_train_whole_infer=300
Global.save_model_dir:./output/ Global.save_model_dir:./output/
Train.loader.batch_size_per_card:lite_train_lite_infer=128|whole_train_whole_infer=128 Train.loader.batch_size_per_card:lite_train_lite_infer=16|whole_train_whole_infer=64
Global.pretrained_model:null Global.pretrained_model:null
train_model_name:latest train_model_name:latest
train_infer_img_dir:./inference/rec_inference train_infer_img_dir:./inference/rec_inference
......
...@@ -6,7 +6,7 @@ Global.use_gpu:True|True ...@@ -6,7 +6,7 @@ Global.use_gpu:True|True
Global.auto_cast:null Global.auto_cast:null
Global.epoch_num:lite_train_lite_infer=2|whole_train_whole_infer=300 Global.epoch_num:lite_train_lite_infer=2|whole_train_whole_infer=300
Global.save_model_dir:./output/ Global.save_model_dir:./output/
Train.loader.batch_size_per_card:lite_train_lite_infer=128|whole_train_whole_infer=128 Train.loader.batch_size_per_card:lite_train_lite_infer=16|whole_train_whole_infer=64
Global.pretrained_model:null Global.pretrained_model:null
train_model_name:latest train_model_name:latest
train_infer_img_dir:./inference/rec_inference train_infer_img_dir:./inference/rec_inference
......
...@@ -37,7 +37,7 @@ export2:null ...@@ -37,7 +37,7 @@ export2:null
infer_model:null infer_model:null
infer_export:tools/export_model.py -c test_tipc/configs/rec_mv3_tps_bilstm_ctc_v2.0/rec_icdar15_train.yml -o infer_export:tools/export_model.py -c test_tipc/configs/rec_mv3_tps_bilstm_ctc_v2.0/rec_icdar15_train.yml -o
infer_quant:False infer_quant:False
inference:tools/infer/predict_rec.py --rec_char_dict_path=./ppocr/utils/ic15_dict.txt --rec_image_shape="3,32,100" inference:tools/infer/predict_rec.py --rec_char_dict_path=./ppocr/utils/ic15_dict.txt --rec_image_shape="3,32,100" --rec_algorithm="StarNet"
--use_gpu:True|False --use_gpu:True|False
--enable_mkldnn:True|False --enable_mkldnn:True|False
--cpu_threads:1|6 --cpu_threads:1|6
......
...@@ -6,7 +6,7 @@ Global.use_gpu:True|True ...@@ -6,7 +6,7 @@ Global.use_gpu:True|True
Global.auto_cast:null Global.auto_cast:null
Global.epoch_num:lite_train_lite_infer=2|whole_train_whole_infer=300 Global.epoch_num:lite_train_lite_infer=2|whole_train_whole_infer=300
Global.save_model_dir:./output/ Global.save_model_dir:./output/
Train.loader.batch_size_per_card:lite_train_lite_infer=128|whole_train_whole_infer=128 Train.loader.batch_size_per_card:lite_train_lite_infer=16|whole_train_whole_infer=64
Global.pretrained_model:null Global.pretrained_model:null
train_model_name:latest train_model_name:latest
train_infer_img_dir:./inference/rec_inference train_infer_img_dir:./inference/rec_inference
......
...@@ -6,7 +6,7 @@ Global.use_gpu:True|True ...@@ -6,7 +6,7 @@ Global.use_gpu:True|True
Global.auto_cast:null Global.auto_cast:null
Global.epoch_num:lite_train_lite_infer=2|whole_train_whole_infer=300 Global.epoch_num:lite_train_lite_infer=2|whole_train_whole_infer=300
Global.save_model_dir:./output/ Global.save_model_dir:./output/
Train.loader.batch_size_per_card:lite_train_lite_infer=128|whole_train_whole_infer=128 Train.loader.batch_size_per_card:lite_train_lite_infer=16|whole_train_whole_infer=64
Global.pretrained_model:null Global.pretrained_model:null
train_model_name:latest train_model_name:latest
train_infer_img_dir:./inference/rec_inference train_infer_img_dir:./inference/rec_inference
......
...@@ -37,7 +37,7 @@ export2:null ...@@ -37,7 +37,7 @@ export2:null
infer_model:null infer_model:null
infer_export:tools/export_model.py -c test_tipc/configs/rec_r34_vd_tps_bilstm_ctc_v2.0/rec_icdar15_train.yml -o infer_export:tools/export_model.py -c test_tipc/configs/rec_r34_vd_tps_bilstm_ctc_v2.0/rec_icdar15_train.yml -o
infer_quant:False infer_quant:False
inference:tools/infer/predict_rec.py --rec_char_dict_path=./ppocr/utils/ic15_dict.txt --rec_image_shape="3,32,100" inference:tools/infer/predict_rec.py --rec_char_dict_path=./ppocr/utils/ic15_dict.txt --rec_image_shape="3,32,100" --rec_algorithm="StarNet"
--use_gpu:True|False --use_gpu:True|False
--enable_mkldnn:True|False --enable_mkldnn:True|False
--cpu_threads:1|6 --cpu_threads:1|6
......
...@@ -6,7 +6,7 @@ Global.use_gpu:True|True ...@@ -6,7 +6,7 @@ Global.use_gpu:True|True
Global.auto_cast:null Global.auto_cast:null
Global.epoch_num:lite_train_lite_infer=2|whole_train_whole_infer=300 Global.epoch_num:lite_train_lite_infer=2|whole_train_whole_infer=300
Global.save_model_dir:./output/ Global.save_model_dir:./output/
Train.loader.batch_size_per_card:lite_train_lite_infer=128|whole_train_whole_infer=128 Train.loader.batch_size_per_card:lite_train_lite_infer=16|whole_train_whole_infer=64
Global.pretrained_model:null Global.pretrained_model:null
train_model_name:latest train_model_name:latest
train_infer_img_dir:./inference/rec_inference train_infer_img_dir:./inference/rec_inference
......
...@@ -25,7 +25,7 @@ if [ ${MODE} = "lite_train_lite_infer" ];then ...@@ -25,7 +25,7 @@ if [ ${MODE} = "lite_train_lite_infer" ];then
# pretrain lite train data # pretrain lite train data
wget -nc -P ./pretrain_models/ https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/MobileNetV3_large_x0_5_pretrained.pdparams --no-check-certificate wget -nc -P ./pretrain_models/ https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/MobileNetV3_large_x0_5_pretrained.pdparams --no-check-certificate
wget -nc -P ./pretrain_models/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/det_mv3_db_v2.0_train.tar --no-check-certificate wget -nc -P ./pretrain_models/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/det_mv3_db_v2.0_train.tar --no-check-certificate
if [ ${model_name} == "ch_PPOCRv2_det" ]; then if [[ ${model_name} =~ "PPOCRv2_det" ]];then
wget -nc -P ./pretrain_models/ https://paddleocr.bj.bcebos.com/PP-OCRv2/chinese/ch_PP-OCRv2_det_distill_train.tar --no-check-certificate wget -nc -P ./pretrain_models/ https://paddleocr.bj.bcebos.com/PP-OCRv2/chinese/ch_PP-OCRv2_det_distill_train.tar --no-check-certificate
cd ./pretrain_models/ && tar xf ch_PP-OCRv2_det_distill_train.tar && cd ../ cd ./pretrain_models/ && tar xf ch_PP-OCRv2_det_distill_train.tar && cd ../
fi fi
...@@ -49,8 +49,8 @@ if [ ${MODE} = "lite_train_lite_infer" ];then ...@@ -49,8 +49,8 @@ if [ ${MODE} = "lite_train_lite_infer" ];then
fi fi
if [ ${model_name} == "det_r50_vd_sast_icdar15_v2.0" ] || [ ${model_name} == "det_r50_vd_sast_totaltext_v2.0" ]; then if [ ${model_name} == "det_r50_vd_sast_icdar15_v2.0" ] || [ ${model_name} == "det_r50_vd_sast_totaltext_v2.0" ]; then
wget -nc -P ./pretrain_models/ https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/ResNet50_vd_ssld_pretrained.pdparams --no-check-certificate wget -nc -P ./pretrain_models/ https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/ResNet50_vd_ssld_pretrained.pdparams --no-check-certificate
wget -nc -P ./train_data/ wget -nc -P ./train_data/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/test/total_text_lite.tar --no-check-certificate wget -nc -P ./train_data/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/test/total_text_lite.tar --no-check-certificate
cd ./train_data && tar xf total_text_lite.tar && ln -s total_text && cd ../ cd ./train_data && tar xf total_text_lite.tar && ln -s total_text_lite total_text && cd ../
fi fi
if [ ${model_name} == "det_mv3_db_v2.0" ]; then if [ ${model_name} == "det_mv3_db_v2.0" ]; then
wget -nc -P ./inference/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/det_mv3_db_v2.0_train.tar --no-check-certificate wget -nc -P ./inference/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/det_mv3_db_v2.0_train.tar --no-check-certificate
...@@ -78,15 +78,15 @@ elif [ ${MODE} = "whole_train_whole_infer" ];then ...@@ -78,15 +78,15 @@ elif [ ${MODE} = "whole_train_whole_infer" ];then
cd ./pretrain_models/ && tar xf ch_PP-OCRv2_det_distill_train.tar && cd ../ cd ./pretrain_models/ && tar xf ch_PP-OCRv2_det_distill_train.tar && cd ../
fi fi
if [ ${model_name} == "en_server_pgnetA" ]; then if [ ${model_name} == "en_server_pgnetA" ]; then
wget -nc -P ./train_data/ https://paddleocr.bj.bcebos.com/dataset/total_text.tar --no-check-certificate wget -nc -P ./train_data/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/test/total_text_lite.tar --no-check-certificate
wget -nc -P ./pretrain_models/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/pgnet/en_server_pgnetA.tar --no-check-certificate wget -nc -P ./pretrain_models/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/pgnet/en_server_pgnetA.tar --no-check-certificate
cd ./pretrain_models/ && tar xf en_server_pgnetA.tar && cd ../ cd ./pretrain_models/ && tar xf en_server_pgnetA.tar && cd ../
cd ./train_data && tar xf total_text.tar && ln -s total_text && cd ../ cd ./train_data && tar xf total_text_lite.tar && ln -s total_text_lite total_text && cd ../
fi fi
if [ ${model_name} == "det_r50_vd_sast_totaltext_v2.0" ]; then if [ ${model_name} == "det_r50_vd_sast_totaltext_v2.0" ]; then
wget -nc -P ./pretrain_models/ https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/ResNet50_vd_ssld_pretrained.pdparams --no-check-certificate wget -nc -P ./pretrain_models/ https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/ResNet50_vd_ssld_pretrained.pdparams --no-check-certificate
wget -nc -P ./train_data/ https://paddleocr.bj.bcebos.com/dataset/total_text.tar --no-check-certificate wget -nc -P ./train_data/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/test/total_text_lite.tar --no-check-certificate
cd ./train_data && tar xf total_text.tar && ln -s total_text && cd ../ cd ./train_data && tar xf total_text_lite.tar && ln -s total_text_lite total_text && cd ../
fi fi
elif [ ${MODE} = "lite_train_whole_infer" ];then elif [ ${MODE} = "lite_train_whole_infer" ];then
wget -nc -P ./pretrain_models/ https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/MobileNetV3_large_x0_5_pretrained.pdparams --no-check-certificate wget -nc -P ./pretrain_models/ https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/MobileNetV3_large_x0_5_pretrained.pdparams --no-check-certificate
...@@ -103,59 +103,67 @@ elif [ ${MODE} = "lite_train_whole_infer" ];then ...@@ -103,59 +103,67 @@ elif [ ${MODE} = "lite_train_whole_infer" ];then
fi fi
elif [ ${MODE} = "whole_infer" ];then elif [ ${MODE} = "whole_infer" ];then
wget -nc -P ./inference https://paddleocr.bj.bcebos.com/dygraph_v2.0/test/ch_det_data_50.tar --no-check-certificate wget -nc -P ./inference https://paddleocr.bj.bcebos.com/dygraph_v2.0/test/ch_det_data_50.tar --no-check-certificate
wget -nc -P ./inference/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/test/rec_inference.tar --no-check-certificate
cd ./inference && tar xf rec_inference.tar && cd ../
if [ ${model_name} = "ch_ppocr_mobile_v2.0_det" ]; then if [ ${model_name} = "ch_ppocr_mobile_v2.0_det" ]; then
eval_model_name="ch_ppocr_mobile_v2.0_det_train" eval_model_name="ch_ppocr_mobile_v2.0_det_train"
rm -rf ./train_data/icdar2015 rm -rf ./train_data/icdar2015
wget -nc -P ./inference https://paddleocr.bj.bcebos.com/dygraph_v2.0/test/ch_det_data_50.tar --no-check-certificate
wget -nc -P ./inference https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_det_train.tar --no-check-certificate wget -nc -P ./inference https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_det_train.tar --no-check-certificate
wget -nc -P ./inference https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_det_infer.tar --no-check-certificate wget -nc -P ./inference https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_det_infer.tar --no-check-certificate
cd ./inference && tar xf ${eval_model_name}.tar && tar xf ch_det_data_50.tar && tar xf ch_ppocr_mobile_v2.0_det_infer.tar && cd ../ cd ./inference && tar xf ${eval_model_name}.tar && tar xf ch_det_data_50.tar && tar xf ch_ppocr_mobile_v2.0_det_infer.tar && cd ../
elif [ ${model_name} = "ch_ppocr_server_v2.0_det" ]; then elif [ ${model_name} = "ch_ppocr_server_v2.0_det" ]; then
wget -nc -P ./inference https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_server_v2.0_det_train.tar --no-check-certificate wget -nc -P ./inference https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_server_v2.0_det_train.tar --no-check-certificate
wget -nc -P ./inference https://paddleocr.bj.bcebos.com/dygraph_v2.0/test/ch_det_data_50.tar --no-check-certificate
cd ./inference && tar xf ch_ppocr_server_v2.0_det_train.tar && tar xf ch_det_data_50.tar && cd ../ cd ./inference && tar xf ch_ppocr_server_v2.0_det_train.tar && tar xf ch_det_data_50.tar && cd ../
elif [ ${model_name} = "ch_ppocr_mobile_v2.0" ]; then elif [ ${model_name} = "ch_ppocr_mobile_v2.0" ]; then
wget -nc -P ./inference https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_det_infer.tar --no-check-certificate wget -nc -P ./inference https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_det_infer.tar --no-check-certificate
wget -nc -P ./inference https://paddleocr.bj.bcebos.com/dygraph_v2.0/test/ch_det_data_50.tar --no-check-certificate
wget -nc -P ./inference https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_rec_infer.tar --no-check-certificate wget -nc -P ./inference https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_rec_infer.tar --no-check-certificate
cd ./inference && tar xf ch_ppocr_mobile_v2.0_det_infer.tar && tar xf ch_ppocr_mobile_v2.0_rec_infer.tar && tar xf ch_det_data_50.tar && cd ../ cd ./inference && tar xf ch_ppocr_mobile_v2.0_det_infer.tar && tar xf ch_ppocr_mobile_v2.0_rec_infer.tar && tar xf ch_det_data_50.tar && cd ../
elif [ ${model_name} = "ch_ppocr_server_v2.0" ]; then elif [ ${model_name} = "ch_ppocr_server_v2.0" ]; then
wget -nc -P ./inference https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_server_v2.0_det_infer.tar --no-check-certificate wget -nc -P ./inference https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_server_v2.0_det_infer.tar --no-check-certificate
wget -nc -P ./inference https://paddleocr.bj.bcebos.com/dygraph_v2.0/test/ch_det_data_50.tar --no-check-certificate
wget -nc -P ./inference https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_server_v2.0_rec_infer.tar --no-check-certificate wget -nc -P ./inference https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_server_v2.0_rec_infer.tar --no-check-certificate
cd ./inference && tar xf ch_ppocr_server_v2.0_det_infer.tar && tar xf ch_ppocr_server_v2.0_rec_infer.tar && tar xf ch_det_data_50.tar && cd ../ cd ./inference && tar xf ch_ppocr_server_v2.0_det_infer.tar && tar xf ch_ppocr_server_v2.0_rec_infer.tar && tar xf ch_det_data_50.tar && cd ../
elif [ ${model_name} = "ch_ppocr_mobile_v2.0_rec" ]; then elif [ ${model_name} = "ch_ppocr_mobile_v2.0_rec" ]; then
eval_model_name="ch_ppocr_mobile_v2.0_rec_infer" eval_model_name="ch_ppocr_mobile_v2.0_rec_infer"
wget -nc -P ./inference/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/test/rec_inference.tar --no-check-certificate
wget -nc -P ./inference https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_rec_infer.tar --no-check-certificate wget -nc -P ./inference https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_rec_infer.tar --no-check-certificate
cd ./inference && tar xf ${eval_model_name}.tar && tar xf rec_inference.tar && cd ../ cd ./inference && tar xf ${eval_model_name}.tar && cd ../
elif [ ${model_name} = "ch_ppocr_server_v2.0_rec" ]; then elif [ ${model_name} = "ch_ppocr_server_v2.0_rec" ]; then
eval_model_name="ch_ppocr_server_v2.0_rec_infer" eval_model_name="ch_ppocr_server_v2.0_rec_infer"
wget -nc -P ./inference/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/test/rec_inference.tar --no-check-certificate
wget -nc -P ./inference https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_server_v2.0_rec_infer.tar --no-check-certificate wget -nc -P ./inference https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_server_v2.0_rec_infer.tar --no-check-certificate
cd ./inference && tar xf ${eval_model_name}.tar && tar xf rec_inference.tar && cd ../ cd ./inference && tar xf ${eval_model_name}.tar && cd ../
elif [ ${model_name} = "ch_ppocr_mobile_v2.0_rec_PACT" ]; then
eval_model_name="ch_PP-OCRv2_rec_slim_quant_train"
wget -nc -P ./inference https://paddleocr.bj.bcebos.com/PP-OCRv2/chinese/ch_PP-OCRv2_rec_slim_quant_train.tar --no-check-certificate
cd ./inference && tar xf ${eval_model_name}.tar && cd ../
elif [ ${model_name} = "ch_ppocr_mobile_v2.0_rec_FPGM" ]; then
eval_model_name="ch_PP-OCRv2_rec_train"
wget -nc -P ./inference https://paddleocr.bj.bcebos.com/PP-OCRv2/chinese/ch_PP-OCRv2_rec_train.tar --no-check-certificate
cd ./inference && tar xf ${eval_model_name}.tar && cd ../
fi fi
if [ ${model_name} = "ch_PPOCRv2_det" ]; then if [[ ${model_name} =~ "ch_PPOCRv2_det" ]]; then
eval_model_name="ch_PP-OCRv2_det_infer" eval_model_name="ch_PP-OCRv2_det_infer"
wget -nc -P ./inference https://paddleocr.bj.bcebos.com/dygraph_v2.0/test/ch_det_data_50.tar --no-check-certificate
wget -nc -P ./inference/ https://paddleocr.bj.bcebos.com/PP-OCRv2/chinese/ch_PP-OCRv2_det_infer.tar --no-check-certificate wget -nc -P ./inference/ https://paddleocr.bj.bcebos.com/PP-OCRv2/chinese/ch_PP-OCRv2_det_infer.tar --no-check-certificate
cd ./inference && tar xf ${eval_model_name}.tar && tar xf ch_det_data_50.tar && cd ../ cd ./inference && tar xf ${eval_model_name}.tar && tar xf ch_det_data_50.tar && cd ../
fi fi
if [[ ${model_name} =~ "PPOCRv2_ocr_rec" ]]; then
eval_model_name="ch_PP-OCRv2_rec_infer"
wget -nc -P ./inference/ https://paddleocr.bj.bcebos.com/PP-OCRv2/chinese/ch_PP-OCRv2_rec_infer.tar --no-check-certificate
cd ./inference && tar xf ${eval_model_name}.tar && cd ../
fi
if [ ${model_name} == "en_server_pgnetA" ]; then if [ ${model_name} == "en_server_pgnetA" ]; then
wget -nc -P ./inference/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/pgnet/en_server_pgnetA.tar --no-check-certificate wget -nc -P ./inference/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/pgnet/en_server_pgnetA.tar --no-check-certificate
cd ./inference && tar xf en_server_pgnetA.tar && cd ../ cd ./inference && tar xf en_server_pgnetA.tar && tar xf ch_det_data_50.tar && cd ../
fi fi
if [ ${model_name} == "det_r50_vd_sast_icdar15_v2.0" ]; then if [ ${model_name} == "det_r50_vd_sast_icdar15_v2.0" ]; then
wget -nc -P ./inference/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/det_r50_vd_sast_icdar15_v2.0_train.tar --no-check-certificate wget -nc -P ./inference/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/det_r50_vd_sast_icdar15_v2.0_train.tar --no-check-certificate
cd ./inference/ && tar xf det_r50_vd_sast_icdar15_v2.0_train.tar && cd ../ cd ./inference/ && tar xf det_r50_vd_sast_icdar15_v2.0_train.tar && tar xf ch_det_data_50.tar && cd ../
fi fi
if [ ${model_name} == "det_mv3_db_v2.0" ]; then if [ ${model_name} == "det_mv3_db_v2.0" ]; then
wget -nc -P ./inference/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/det_mv3_db_v2.0_train.tar --no-check-certificate wget -nc -P ./inference/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/det_mv3_db_v2.0_train.tar --no-check-certificate
cd ./inference/ && tar xf det_mv3_db_v2.0_train.tar && cd ../ cd ./inference/ && tar xf det_mv3_db_v2.0_train.tar && tar xf ch_det_data_50.tar && cd ../
fi fi
if [ ${model_name} == "det_r50_db_v2.0" ]; then if [ ${model_name} == "det_r50_db_v2.0" ]; then
wget -nc -P ./inference/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/det_r50_vd_db_v2.0_train.tar --no-check-certificate wget -nc -P ./inference/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/det_r50_vd_db_v2.0_train.tar --no-check-certificate
cd ./inference/ && tar xf det_r50_vd_db_v2.0_train.tar && cd ../ cd ./inference/ && tar xf det_r50_vd_db_v2.0_train.tar && tar xf ch_det_data_50.tar && cd ../
fi fi
fi fi
if [ ${MODE} = "klquant_whole_infer" ]; then if [ ${MODE} = "klquant_whole_infer" ]; then
......
...@@ -64,10 +64,11 @@ function func_cpp_inference(){ ...@@ -64,10 +64,11 @@ function func_cpp_inference(){
set_infer_data=$(func_set_params "${cpp_image_dir_key}" "${_img_dir}") set_infer_data=$(func_set_params "${cpp_image_dir_key}" "${_img_dir}")
set_benchmark=$(func_set_params "${cpp_benchmark_key}" "${cpp_benchmark_value}") set_benchmark=$(func_set_params "${cpp_benchmark_key}" "${cpp_benchmark_value}")
set_batchsize=$(func_set_params "${cpp_batch_size_key}" "${batch_size}") set_batchsize=$(func_set_params "${cpp_batch_size_key}" "${batch_size}")
set_mkldnn=$(func_set_params "${cpp_use_mkldnn_key}" "${use_mkldnn}")
set_cpu_threads=$(func_set_params "${cpp_cpu_threads_key}" "${threads}") set_cpu_threads=$(func_set_params "${cpp_cpu_threads_key}" "${threads}")
set_model_dir=$(func_set_params "${cpp_infer_model_key}" "${_model_dir}") set_model_dir=$(func_set_params "${cpp_infer_model_key}" "${_model_dir}")
set_infer_params1=$(func_set_params "${cpp_infer_key1}" "${cpp_infer_value1}") set_infer_params1=$(func_set_params "${cpp_infer_key1}" "${cpp_infer_value1}")
command="${_script} ${cpp_use_gpu_key}=${use_gpu} ${cpp_use_mkldnn_key}=${use_mkldnn} ${set_cpu_threads} ${set_model_dir} ${set_batchsize} ${set_infer_data} ${set_benchmark} ${set_infer_params1} > ${_save_log_path} 2>&1 " command="${_script} ${cpp_use_gpu_key}=${use_gpu} ${set_mkldnn} ${set_cpu_threads} ${set_model_dir} ${set_batchsize} ${set_infer_data} ${set_benchmark} ${set_infer_params1} > ${_save_log_path} 2>&1 "
eval $command eval $command
last_status=${PIPESTATUS[0]} last_status=${PIPESTATUS[0]}
eval "cat ${_save_log_path}" eval "cat ${_save_log_path}"
......
...@@ -79,11 +79,12 @@ function func_inference(){ ...@@ -79,11 +79,12 @@ function func_inference(){
set_infer_data=$(func_set_params "${image_dir_key}" "${_img_dir}") set_infer_data=$(func_set_params "${image_dir_key}" "${_img_dir}")
set_benchmark=$(func_set_params "${benchmark_key}" "${benchmark_value}") set_benchmark=$(func_set_params "${benchmark_key}" "${benchmark_value}")
set_batchsize=$(func_set_params "${batch_size_key}" "${batch_size}") set_batchsize=$(func_set_params "${batch_size_key}" "${batch_size}")
set_mkldnn=$(func_set_params "${use_mkldnn_key}" "${use_mkldnn}")
set_cpu_threads=$(func_set_params "${cpu_threads_key}" "${threads}") set_cpu_threads=$(func_set_params "${cpu_threads_key}" "${threads}")
set_model_dir=$(func_set_params "${infer_model_key}" "${_model_dir}") set_model_dir=$(func_set_params "${infer_model_key}" "${_model_dir}")
set_infer_params0=$(func_set_params "${rec_model_key}" "${rec_model_value}") set_infer_params0=$(func_set_params "${rec_model_key}" "${rec_model_value}")
set_infer_params1=$(func_set_params "${infer_key1}" "${infer_value1}") set_infer_params1=$(func_set_params "${infer_key1}" "${infer_value1}")
command="${_python} ${_script} ${use_gpu_key}=${use_gpu} ${use_mkldnn_key}=${use_mkldnn} ${set_cpu_threads} ${set_model_dir} ${set_batchsize} ${set_infer_params0} ${set_infer_data} ${set_benchmark} ${set_precision} ${set_infer_params1} > ${_save_log_path} 2>&1 " command="${_python} ${_script} ${use_gpu_key}=${use_gpu} ${set_mkldnn} ${set_cpu_threads} ${set_model_dir} ${set_batchsize} ${set_infer_params0} ${set_infer_data} ${set_benchmark} ${set_precision} ${set_infer_params1} > ${_save_log_path} 2>&1 "
eval $command eval $command
last_status=${PIPESTATUS[0]} last_status=${PIPESTATUS[0]}
eval "cat ${_save_log_path}" eval "cat ${_save_log_path}"
......
...@@ -160,11 +160,12 @@ function func_inference(){ ...@@ -160,11 +160,12 @@ function func_inference(){
set_infer_data=$(func_set_params "${image_dir_key}" "${_img_dir}") set_infer_data=$(func_set_params "${image_dir_key}" "${_img_dir}")
set_benchmark=$(func_set_params "${benchmark_key}" "${benchmark_value}") set_benchmark=$(func_set_params "${benchmark_key}" "${benchmark_value}")
set_batchsize=$(func_set_params "${batch_size_key}" "${batch_size}") set_batchsize=$(func_set_params "${batch_size_key}" "${batch_size}")
set_mkldnn=$(func_set_params "${use_mkldnn_key}" "${use_mkldnn}")
set_cpu_threads=$(func_set_params "${cpu_threads_key}" "${threads}") set_cpu_threads=$(func_set_params "${cpu_threads_key}" "${threads}")
set_model_dir=$(func_set_params "${infer_model_key}" "${_model_dir}") set_model_dir=$(func_set_params "${infer_model_key}" "${_model_dir}")
set_infer_params0=$(func_set_params "${save_log_key}" "${save_log_value}") set_infer_params0=$(func_set_params "${save_log_key}" "${save_log_value}")
set_infer_params1=$(func_set_params "${infer_key1}" "${infer_value1}") set_infer_params1=$(func_set_params "${infer_key1}" "${infer_value1}")
command="${_python} ${_script} ${use_gpu_key}=${use_gpu} ${use_mkldnn_key}=${use_mkldnn} ${set_cpu_threads} ${set_model_dir} ${set_batchsize} ${set_infer_params0} ${set_infer_data} ${set_benchmark} ${set_precision} ${set_infer_params1} > ${_save_log_path} 2>&1 " command="${_python} ${_script} ${use_gpu_key}=${use_gpu} ${set_mkldnn} ${set_cpu_threads} ${set_model_dir} ${set_batchsize} ${set_infer_params0} ${set_infer_data} ${set_benchmark} ${set_precision} ${set_infer_params1} > ${_save_log_path} 2>&1 "
eval $command eval $command
last_status=${PIPESTATUS[0]} last_status=${PIPESTATUS[0]}
eval "cat ${_save_log_path}" eval "cat ${_save_log_path}"
...@@ -321,10 +322,6 @@ else ...@@ -321,10 +322,6 @@ else
save_log="${LOG_PATH}/${trainer}_gpus_${gpu}_autocast_${autocast}_nodes_${nodes}" save_log="${LOG_PATH}/${trainer}_gpus_${gpu}_autocast_${autocast}_nodes_${nodes}"
fi fi
# load pretrain from norm training if current trainer is pact or fpgm trainer
if ([ ${trainer} = ${pact_key} ] || [ ${trainer} = ${fpgm_key} ]) && [ ${nodes} -le 1 ]; then
set_pretrain="${load_norm_train_model}"
fi
set_save_model=$(func_set_params "${save_model_key}" "${save_log}") set_save_model=$(func_set_params "${save_model_key}" "${save_log}")
if [ ${#gpu} -le 2 ];then # train with cpu or single gpu if [ ${#gpu} -le 2 ];then # train with cpu or single gpu
...@@ -340,10 +337,7 @@ else ...@@ -340,10 +337,7 @@ else
status_check $? "${cmd}" "${status_log}" status_check $? "${cmd}" "${status_log}"
set_eval_pretrain=$(func_set_params "${pretrain_model_key}" "${save_log}/${train_model_name}") set_eval_pretrain=$(func_set_params "${pretrain_model_key}" "${save_log}/${train_model_name}")
# save norm trained models to set pretrain for pact training and fpgm training
if [ ${trainer} = ${trainer_norm} ] && [ ${nodes} -le 1 ]; then
load_norm_train_model=${set_eval_pretrain}
fi
# run eval # run eval
if [ ${eval_py} != "null" ]; then if [ ${eval_py} != "null" ]; then
set_eval_params1=$(func_set_params "${eval_key1}" "${eval_value1}") set_eval_params1=$(func_set_params "${eval_key1}" "${eval_value1}")
......
...@@ -195,6 +195,7 @@ def create_predictor(args, mode, logger): ...@@ -195,6 +195,7 @@ def create_predictor(args, mode, logger):
max_batch_size=args.max_batch_size, max_batch_size=args.max_batch_size,
min_subgraph_size=args.min_subgraph_size) min_subgraph_size=args.min_subgraph_size)
# skip the minimum trt subgraph # skip the minimum trt subgraph
use_dynamic_shape = True
if mode == "det": if mode == "det":
min_input_shape = { min_input_shape = {
"x": [1, 3, 50, 50], "x": [1, 3, 50, 50],
...@@ -260,6 +261,8 @@ def create_predictor(args, mode, logger): ...@@ -260,6 +261,8 @@ def create_predictor(args, mode, logger):
max_input_shape.update(max_pact_shape) max_input_shape.update(max_pact_shape)
opt_input_shape.update(opt_pact_shape) opt_input_shape.update(opt_pact_shape)
elif mode == "rec": elif mode == "rec":
if args.rec_algorithm != "CRNN":
use_dynamic_shape = False
min_input_shape = {"x": [1, 3, 32, 10]} min_input_shape = {"x": [1, 3, 32, 10]}
max_input_shape = {"x": [args.rec_batch_num, 3, 32, 1536]} max_input_shape = {"x": [args.rec_batch_num, 3, 32, 1536]}
opt_input_shape = {"x": [args.rec_batch_num, 3, 32, 320]} opt_input_shape = {"x": [args.rec_batch_num, 3, 32, 320]}
...@@ -268,11 +271,10 @@ def create_predictor(args, mode, logger): ...@@ -268,11 +271,10 @@ def create_predictor(args, mode, logger):
max_input_shape = {"x": [args.rec_batch_num, 3, 48, 1024]} max_input_shape = {"x": [args.rec_batch_num, 3, 48, 1024]}
opt_input_shape = {"x": [args.rec_batch_num, 3, 48, 320]} opt_input_shape = {"x": [args.rec_batch_num, 3, 48, 320]}
else: else:
min_input_shape = {"x": [1, 3, 10, 10]} use_dynamic_shape = False
max_input_shape = {"x": [1, 3, 512, 512]} if use_dynamic_shape:
opt_input_shape = {"x": [1, 3, 256, 256]} config.set_trt_dynamic_shape_info(
config.set_trt_dynamic_shape_info(min_input_shape, max_input_shape, min_input_shape, max_input_shape, opt_input_shape)
opt_input_shape)
else: else:
config.disable_gpu() config.disable_gpu()
......
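The create_predictor change above replaces the old catch-all else branch with a use_dynamic_shape flag: TensorRT dynamic-shape info is registered only where it is known to work (e.g. detection, and CRNN recognition) and skipped otherwise. The gating pattern, reduced to a skeleton with illustrative shape values (the real tables are keyed off mode and algorithm as the diff shows):

from paddle.inference import Config

def maybe_set_dynamic_shape(config: Config, mode: str, rec_algorithm: str):
    # register TRT dynamic shapes only for supported mode/algorithm combos
    use_dynamic_shape = mode == "det" or (mode == "rec" and
                                          rec_algorithm == "CRNN")
    if not use_dynamic_shape:
        return
    min_input_shape = {"x": [1, 3, 50, 50]}      # illustrative values only
    max_input_shape = {"x": [1, 3, 1536, 1536]}
    opt_input_shape = {"x": [1, 3, 640, 640]}
    config.set_trt_dynamic_shape_info(min_input_shape, max_input_shape,
                                      opt_input_shape)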