Commit 5f0c84a8, authored by 文幕地方

Merge branch 'dygraph' of https://github.com/PaddlePaddle/PaddleOCR into table_pr

......@@ -14,6 +14,9 @@ Global:
use_visualdl: False
infer_img: doc/imgs_en/img_10.jpg
save_res_path: ./output/det_db/predicts_db.txt
use_amp: False
amp_level: O2
amp_custom_black_list: ['exp']
Architecture:
name: DistillationModel
......
Global:
use_gpu: true
epoch_num: 5
log_smooth_window: 20
print_batch_step: 20
save_model_dir: ./output/rec/rec_r31_robustscanner/
save_epoch_step: 1
# evaluation is run every 2000 iterations
eval_batch_step: [0, 2000]
cal_metric_during_train: True
pretrained_model:
checkpoints:
save_inference_dir:
use_visualdl: False
infer_img: ./inference/rec_inference
# for data or label process
character_dict_path: ppocr/utils/dict90.txt
max_text_length: &max_text_length 40
infer_mode: False
use_space_char: False
rm_symbol: True
save_res_path: ./output/rec/predicts_robustscanner.txt
Optimizer:
name: Adam
beta1: 0.9
beta2: 0.999
lr:
name: Piecewise
decay_epochs: [3, 4]
values: [0.001, 0.0001, 0.00001]
regularizer:
name: 'L2'
factor: 0
Architecture:
model_type: rec
algorithm: RobustScanner
Transform:
Backbone:
name: ResNet31
init_type: KaimingNormal
Head:
name: RobustScannerHead
enc_outchannles: 128
hybrid_dec_rnn_layers: 2
hybrid_dec_dropout: 0
position_dec_rnn_layers: 2
start_idx: 91
mask: True
padding_idx: 92
encode_value: False
max_text_length: *max_text_length
Loss:
name: SARLoss
PostProcess:
name: SARLabelDecode
Metric:
name: RecMetric
is_filter: True
Train:
dataset:
name: LMDBDataSet
data_dir: ./train_data/data_lmdb_release/training/
transforms:
- DecodeImage: # load image
img_mode: BGR
channel_first: False
- SARLabelEncode: # Class handling label
- RobustScannerRecResizeImg:
image_shape: [3, 48, 48, 160] # h:48 w:[48,160]
width_downsample_ratio: 0.25
max_text_length: *max_text_length
- KeepKeys:
keep_keys: ['image', 'label', 'valid_ratio', 'word_positons'] # dataloader will return list in this order
loader:
shuffle: True
batch_size_per_card: 64
drop_last: True
num_workers: 8
use_shared_memory: False
Eval:
dataset:
name: LMDBDataSet
data_dir: ./train_data/data_lmdb_release/evaluation/
transforms:
- DecodeImage: # load image
img_mode: BGR
channel_first: False
- SARLabelEncode: # Class handling label
- RobustScannerRecResizeImg:
image_shape: [3, 48, 48, 160]
max_text_length: *max_text_length
width_downsample_ratio: 0.25
- KeepKeys:
keep_keys: ['image', 'label', 'valid_ratio', 'word_positons'] # dataloader will return list in this order
loader:
shuffle: False
drop_last: False
batch_size_per_card: 64
num_workers: 4
use_shared_memory: False
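The `Optimizer.lr` block in the config above uses a `Piecewise` schedule with `decay_epochs: [3, 4]` and three `values`. The following is a minimal sketch of how such a schedule maps an epoch to a learning rate. It assumes zero-based epoch indices and ignores warmup; the function name `piecewise_lr` is hypothetical, and the framework itself may work in steps rather than epochs.

```python
from bisect import bisect_right


def piecewise_lr(epoch, decay_epochs, values):
    """Return the learning rate for a zero-based epoch index (illustrative only)."""
    if len(values) != len(decay_epochs) + 1:
        raise ValueError("values must have one more entry than decay_epochs")
    # bisect_right counts how many boundaries are <= epoch,
    # which is the index of the value that applies to this epoch.
    return values[bisect_right(decay_epochs, epoch)]


decay_epochs = [3, 4]                 # from the config above
values = [0.001, 0.0001, 0.00001]     # from the config above
schedule = [piecewise_lr(e, decay_epochs, values) for e in range(5)]  # epoch_num: 5
print(schedule)
```

Under these assumptions, the first three epochs run at `0.001`, and each of the last two epochs drops the rate by a factor of ten.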
......@@ -72,6 +72,7 @@
- [x] [ABINet](./algorithm_rec_abinet.md)
- [x] [VisionLAN](./algorithm_rec_visionlan.md)
- [x] [SPIN](./algorithm_rec_spin.md)
- [x] [RobustScanner](./algorithm_rec_robustscanner.md)
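Following the text recognition training and evaluation procedure of [DTRB](https://arxiv.org/abs/1904.01906)[3], the models are trained on the MJSynth and SynthText datasets and evaluated on IIIT, SVT, IC03, IC13, IC15, SVTP, and CUTE. The results are as follows: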
......@@ -94,6 +95,7 @@
|ABINet|Resnet45| 90.75% | rec_r45_abinet | [trained model](https://paddleocr.bj.bcebos.com/rec_r45_abinet_train.tar) |
|VisionLAN|Resnet45| 90.30% | rec_r45_visionlan | [trained model](https://paddleocr.bj.bcebos.com/rec_r45_visionlan_train.tar) |
|SPIN|ResNet32| 90.00% | rec_r32_gaspin_bilstm_att | coming soon |
|RobustScanner|ResNet31| 87.77% | rec_r31_robustscanner | coming soon |
<a name="2"></a>
......
# RobustScanner
- [1. Introduction](#1)
- [2. Environment](#2)
- [3. Model Training / Evaluation / Prediction](#3)
- [3.1 Training](#3-1)
- [3.2 Evaluation](#3-2)
- [3.3 Prediction](#3-3)
- [4. Inference and Deployment](#4)
- [4.1 Python Inference](#4-1)
- [4.2 C++ Inference](#4-2)
- [4.3 Serving](#4-3)
- [4.4 More](#4-4)
- [5. FAQ](#5)
<a name="1"></a>
## 1. Introduction
Paper:
> [RobustScanner: Dynamically Enhancing Positional Clues for Robust Text Recognition](https://arxiv.org/pdf/2007.07542.pdf)
> Xiaoyu Yue, Zhanghui Kuang, Chenhao Lin, Hongbin Sun, Wayne Zhang
> ECCV, 2020
The model is trained on the two synthetic text recognition datasets MJSynth and SynthText and evaluated on IIIT, SVT, IC13, IC15, SVTP, and CUTE. The reproduced results are as follows:
|Model|Backbone|Config|Acc|Download link|
| --- | --- | --- | --- | --- |
|RobustScanner|ResNet31|[rec_r31_robustscanner.yml](../../configs/rec/rec_r31_robustscanner.yml)|87.77%|coming soon|
Note: in addition to the MJSynth and SynthText datasets, training also used [SynthAdd](https://pan.baidu.com/share/init?surl=uV0LtoNmcxbO-0YA7Ch4dg) data (extraction code: 627x) and some real data. See the paper for the data details.
<a name="2"></a>
## 2. Environment
Please refer to ["Environment Preparation"](./environment.md) to configure the PaddleOCR environment, and to ["Project Clone"](./clone.md) to clone the project code.
<a name="3"></a>
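## 3. Model Training / Evaluation / Prediction
Please refer to the [Text Recognition Tutorial](./recognition.md). PaddleOCR's code is modular, so training a different recognition model only requires **changing the configuration file**.
<a name="3-1"></a>
### 3.1 Training
Once data preparation is complete, training can be started with the following commands: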
```
# Single-GPU training (long training period, not recommended)
python3 tools/train.py -c configs/rec/rec_r31_robustscanner.yml
# Multi-GPU training; specify the GPU ids with the --gpus parameter
python3 -m paddle.distributed.launch --gpus '0,1,2,3' tools/train.py -c configs/rec/rec_r31_robustscanner.yml
```
<a name="3-2"></a>
### 3.2 Evaluation
```
# GPU evaluation; Global.pretrained_model is the weights to evaluate
python3 -m paddle.distributed.launch --gpus '0' tools/eval.py -c configs/rec/rec_r31_robustscanner.yml -o Global.pretrained_model={path/to/weights}/best_accuracy
```
<a name="3-3"></a>
### 3.3 Prediction
```
# The configuration file used for prediction must match the one used for training
python3 tools/infer_rec.py -c configs/rec/rec_r31_robustscanner.yml -o Global.pretrained_model={path/to/weights}/best_accuracy Global.infer_img=doc/imgs_words/en/word_1.png
```
<a name="4"></a>
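## 4. Inference and Deployment
<a name="4-1"></a>
### 4.1 Python Inference
First, convert the model saved during RobustScanner text recognition training into an inference model. The following command performs the conversion: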
```
python3 tools/export_model.py -c configs/rec/rec_r31_robustscanner.yml -o Global.pretrained_model={path/to/weights}/best_accuracy Global.save_inference_dir=./inference/rec_r31_robustscanner
```
To run inference with the RobustScanner text recognition model, execute the following command:
```
python3 tools/infer/predict_rec.py --image_dir="./doc/imgs_words/en/word_1.png" --rec_model_dir="./inference/rec_r31_robustscanner/" --rec_image_shape="3, 48, 48, 160" --rec_algorithm="RobustScanner" --rec_char_dict_path="ppocr/utils/dict90.txt" --use_space_char=False
```
<a name="4-2"></a>
### 4.2 C++ Inference
Not supported yet, because the C++ pre-processing and post-processing do not support RobustScanner.
<a name="4-3"></a>
### 4.3 Serving
Not supported yet
<a name="4-4"></a>
### 4.4 More
Not supported yet
<a name="5"></a>
## 5. FAQ
## Citation
```bibtex
@article{2020RobustScanner,
title={RobustScanner: Dynamically Enhancing Positional Clues for Robust Text Recognition},
author={Xiaoyu Yue and Zhanghui Kuang and Chenhao Lin and Hongbin Sun and Wayne Zhang},
journal={ECCV2020},
year={2020},
}
```
......@@ -79,7 +79,7 @@ python3 tools/export_model.py -c configs/rec/rec_r31_sar.yml -o Global.pretraine
To run inference with the SAR text recognition model, execute the following command:
```
python3 tools/infer/predict_rec.py --image_dir="./doc/imgs_words/en/word_1.png" --rec_model_dir="./inference/rec_sar/" --rec_image_shape="3, 48, 48, 160" --rec_char_type="ch" --rec_algorithm="SAR" --rec_char_dict_path="ppocr/utils/dict90.txt" --max_text_length=30 --use_space_char=False
python3 tools/infer/predict_rec.py --image_dir="./doc/imgs_words/en/word_1.png" --rec_model_dir="./inference/rec_sar/" --rec_image_shape="3, 48, 48, 160" --rec_algorithm="SAR" --rec_char_dict_path="ppocr/utils/dict90.txt" --max_text_length=30 --use_space_char=False
```
<a name="4-2"></a>
......
......@@ -70,6 +70,7 @@ Supported text recognition algorithms (Click the link to get the tutorial):
- [x] [ABINet](./algorithm_rec_abinet_en.md)
- [x] [VisionLAN](./algorithm_rec_visionlan_en.md)
- [x] [SPIN](./algorithm_rec_spin_en.md)
- [x] [RobustScanner](./algorithm_rec_robustscanner_en.md)
Following [DTRB](https://arxiv.org/abs/1904.01906), the text recognition algorithms above are trained on MJSynth and SynthText and evaluated on IIIT, SVT, IC03, IC13, IC15, SVTP, and CUTE. The results are as follows:
......@@ -92,6 +93,7 @@ Refer to [DTRB](https://arxiv.org/abs/1904.01906), the training and evaluation r
|ABINet|Resnet45| 90.75% | rec_r45_abinet | [trained model](https://paddleocr.bj.bcebos.com/rec_r45_abinet_train.tar) |
|VisionLAN|Resnet45| 90.30% | rec_r45_visionlan | [trained model](https://paddleocr.bj.bcebos.com/rec_r45_visionlan_train.tar) |
|SPIN|ResNet32| 90.00% | rec_r32_gaspin_bilstm_att | coming soon |
|RobustScanner|ResNet31| 87.77% | rec_r31_robustscanner | coming soon |
<a name="2"></a>
......
# RobustScanner
- [1. Introduction](#1)
- [2. Environment](#2)
- [3. Model Training / Evaluation / Prediction](#3)
- [3.1 Training](#3-1)
- [3.2 Evaluation](#3-2)
- [3.3 Prediction](#3-3)
- [4. Inference and Deployment](#4)
- [4.1 Python Inference](#4-1)
- [4.2 C++ Inference](#4-2)
- [4.3 Serving](#4-3)
- [4.4 More](#4-4)
- [5. FAQ](#5)
<a name="1"></a>
## 1. Introduction
Paper:
> [RobustScanner: Dynamically Enhancing Positional Clues for Robust Text Recognition](https://arxiv.org/pdf/2007.07542.pdf)
> Xiaoyu Yue, Zhanghui Kuang, Chenhao Lin, Hongbin Sun, Wayne Zhang
> ECCV, 2020
The model is trained on the two text recognition datasets MJSynth and SynthText and evaluated on IIIT, SVT, IC13, IC15, SVTP, and CUTE. The reproduced results are as follows:
|Model|Backbone|config|Acc|Download link|
| --- | --- | --- | --- | --- |
|RobustScanner|ResNet31|[rec_r31_robustscanner.yml](../../configs/rec/rec_r31_robustscanner.yml)|87.77%|coming soon|
Note: In addition to the two text recognition datasets MJSynth and SynthText, [SynthAdd](https://pan.baidu.com/share/init?surl=uV0LtoNmcxbO-0YA7Ch4dg) data (extraction code: 627x) and some real data are used in training. See the paper for the data details.
<a name="2"></a>
## 2. Environment
Please refer to ["Environment Preparation"](./environment_en.md) to configure the PaddleOCR environment, and refer to ["Project Clone"](./clone_en.md) to clone the project code.
<a name="3"></a>
## 3. Model Training / Evaluation / Prediction
Please refer to [Text Recognition Tutorial](./recognition_en.md). PaddleOCR modularizes the code, and training different recognition models only requires **changing the configuration file**.
<a name="3-1"></a>
### 3.1 Training
Once data preparation is complete, training can be started with the following commands:
```
#Single GPU training (long training period, not recommended)
python3 tools/train.py -c configs/rec/rec_r31_robustscanner.yml
#Multi GPU training, specify the gpu number through the --gpus parameter
python3 -m paddle.distributed.launch --gpus '0,1,2,3' tools/train.py -c configs/rec/rec_r31_robustscanner.yml
```
<a name="3-2"></a>
### 3.2 Evaluation
```
# GPU evaluation
python3 -m paddle.distributed.launch --gpus '0' tools/eval.py -c configs/rec/rec_r31_robustscanner.yml -o Global.pretrained_model={path/to/weights}/best_accuracy
```
<a name="3-3"></a>
### 3.3 Prediction
```
# The configuration file used for prediction must match the training
python3 tools/infer_rec.py -c configs/rec/rec_r31_robustscanner.yml -o Global.pretrained_model={path/to/weights}/best_accuracy Global.infer_img=doc/imgs_words/en/word_1.png
```
<a name="4"></a>
## 4. Inference and Deployment
<a name="4-1"></a>
### 4.1 Python Inference
First, convert the model saved during RobustScanner text recognition training into an inference model. The following command performs the conversion:
```
python3 tools/export_model.py -c configs/rec/rec_r31_robustscanner.yml -o Global.pretrained_model={path/to/weights}/best_accuracy Global.save_inference_dir=./inference/rec_r31_robustscanner
```
For RobustScanner text recognition model inference, the following commands can be executed:
```
python3 tools/infer/predict_rec.py --image_dir="./doc/imgs_words/en/word_1.png" --rec_model_dir="./inference/rec_r31_robustscanner/" --rec_image_shape="3, 48, 48, 160" --rec_algorithm="RobustScanner" --rec_char_dict_path="ppocr/utils/dict90.txt" --use_space_char=False
```
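The value passed to `--rec_image_shape` mirrors `image_shape: [3, 48, 48, 160]` in the config, read as channels, height, minimum width, and maximum width. The sketch below shows how such a shape might drive an aspect-preserving resize. The rounding and clamping rule is an assumption made for illustration, and both helper names are hypothetical; `resize_norm_img_sar` in the repository remains the reference.

```python
import math


def parse_rec_image_shape(text):
    """Split '3, 48, 48, 160' into (channels, height, min_width, max_width)."""
    c, h, w_min, w_max = (int(part) for part in text.split(","))
    return c, h, w_min, w_max


def target_width(src_h, src_w, shape, width_downsample_ratio=0.25):
    """Hypothetical helper: width after resizing to the target height.

    Assumed rule: keep the aspect ratio, round the width to a multiple of
    1 / width_downsample_ratio, then clamp it to [min_width, max_width].
    """
    _, h, w_min, w_max = shape
    divisor = int(1 / width_downsample_ratio)
    width = math.ceil(h * src_w / src_h)
    if width % divisor != 0:
        width = round(width / divisor) * divisor
    width = min(max(width, w_min), w_max)
    # Fraction of the padded canvas that holds real image content.
    valid_ratio = min(1.0, width / w_max)
    return width, valid_ratio


shape = parse_rec_image_shape("3, 48, 48, 160")
wide = target_width(24, 100, shape)    # wide crop: clamped to the maximum width
narrow = target_width(48, 20, shape)   # narrow crop: raised to the minimum width
print(shape, wide, narrow)
```

Under this assumed rule, a wide crop fills the whole 160-pixel canvas, while a narrow crop is stretched to the 48-pixel minimum and the rest of the canvas is padding.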
<a name="4-2"></a>
### 4.2 C++ Inference
Not supported
<a name="4-3"></a>
### 4.3 Serving
Not supported
<a name="4-4"></a>
### 4.4 More
Not supported
<a name="5"></a>
## 5. FAQ
## Citation
```bibtex
@article{2020RobustScanner,
title={RobustScanner: Dynamically Enhancing Positional Clues for Robust Text Recognition},
author={Xiaoyu Yue and Zhanghui Kuang and Chenhao Lin and Hongbin Sun and Wayne Zhang},
journal={ECCV2020},
year={2020},
}
```
......@@ -79,7 +79,7 @@ python3 tools/export_model.py -c configs/rec/rec_r31_sar.yml -o Global.pretraine
For SAR text recognition model inference, the following commands can be executed:
```
python3 tools/infer/predict_rec.py --image_dir="./doc/imgs_words/en/word_1.png" --rec_model_dir="./inference/rec_sar/" --rec_image_shape="3, 48, 48, 160" --rec_char_type="ch" --rec_algorithm="SAR" --rec_char_dict_path="ppocr/utils/dict90.txt" --max_text_length=30 --use_space_char=False
python3 tools/infer/predict_rec.py --image_dir="./doc/imgs_words/en/word_1.png" --rec_model_dir="./inference/rec_sar/" --rec_image_shape="3, 48, 48, 160" --rec_algorithm="SAR" --rec_char_dict_path="ppocr/utils/dict90.txt" --max_text_length=30 --use_space_char=False
```
<a name="4-2"></a>
......
......@@ -26,8 +26,7 @@ from .make_pse_gt import MakePseGt
from .rec_img_aug import BaseDataAugmentation, RecAug, RecConAug, RecResizeImg, ClsResizeImg, \
SRNRecResizeImg, GrayRecResizeImg, SARRecResizeImg, PRENResizeImg, \
ABINetRecResizeImg, SVTRRecResizeImg, ABINetRecAug, VLRecResizeImg, SPINRecResizeImg
ABINetRecResizeImg, SVTRRecResizeImg, ABINetRecAug, VLRecResizeImg, SPINRecResizeImg, RobustScannerRecResizeImg
from .ssl_img_aug import SSLRotateResize
from .randaugment import RandAugment
from .copy_paste import CopyPaste
......
......@@ -414,6 +414,23 @@ class SVTRRecResizeImg(object):
data['valid_ratio'] = valid_ratio
return data
class RobustScannerRecResizeImg(object):
def __init__(self, image_shape, max_text_length, width_downsample_ratio=0.25, **kwargs):
self.image_shape = image_shape
self.width_downsample_ratio = width_downsample_ratio
self.max_text_length = max_text_length
def __call__(self, data):
img = data['image']
norm_img, resize_shape, pad_shape, valid_ratio = resize_norm_img_sar(
img, self.image_shape, self.width_downsample_ratio)
word_positons = np.array(range(0, self.max_text_length)).astype('int64')
data['image'] = norm_img
data['resized_shape'] = resize_shape
data['pad_shape'] = pad_shape
data['valid_ratio'] = valid_ratio
data['word_positons'] = word_positons
return data
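The transform above adds five keys to the sample, and the config then keeps only four of them through `KeepKeys`, in the order the dataloader will return them. The following sketch mimics that selection with plain Python values. `KeepKeysSketch` is a hypothetical stand-in rather than the repository's class, and the sample contents are placeholders.

```python
class KeepKeysSketch:
    """Hypothetical stand-in for the KeepKeys step listed in the config."""

    def __init__(self, keep_keys):
        self.keep_keys = keep_keys

    def __call__(self, data):
        # Return the values as a list, in the configured key order.
        return [data[key] for key in self.keep_keys]


max_text_length = 40  # matches max_text_length in the config
sample = {
    "image": "<normalized image placeholder>",
    "label": [5, 12, 7],
    "resized_shape": (3, 48, 152),
    "pad_shape": (3, 48, 160),
    "valid_ratio": 0.95,
    # Same content as np.array(range(0, max_text_length)), kept as a list here.
    "word_positons": list(range(max_text_length)),
}

keep = KeepKeysSketch(["image", "label", "valid_ratio", "word_positons"])
batch_item = keep(sample)
print(len(batch_item), batch_item[3][:5])
```

The key is spelled `word_positons` throughout the commit, so the config and the transform must use that same spelling for the lookup to succeed.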
def resize_norm_img_sar(img, image_shape, width_downsample_ratio=0.25):
imgC, imgH, imgW_min, imgW_max = image_shape
......
......@@ -29,27 +29,29 @@ import numpy as np
__all__ = ["ResNet31"]
def conv3x3(in_channel, out_channel, stride=1):
def conv3x3(in_channel, out_channel, stride=1, conv_weight_attr=None):
return nn.Conv2D(
in_channel,
out_channel,
kernel_size=3,
stride=stride,
padding=1,
weight_attr=conv_weight_attr,
bias_attr=False)
class BasicBlock(nn.Layer):
expansion = 1
def __init__(self, in_channels, channels, stride=1, downsample=False):
def __init__(self, in_channels, channels, stride=1, downsample=False, conv_weight_attr=None, bn_weight_attr=None):
super().__init__()
self.conv1 = conv3x3(in_channels, channels, stride)
self.bn1 = nn.BatchNorm2D(channels)
self.conv1 = conv3x3(in_channels, channels, stride,
conv_weight_attr=conv_weight_attr)
self.bn1 = nn.BatchNorm2D(channels, weight_attr=bn_weight_attr)
self.relu = nn.ReLU()
self.conv2 = conv3x3(channels, channels)
self.bn2 = nn.BatchNorm2D(channels)
self.conv2 = conv3x3(channels, channels,
conv_weight_attr=conv_weight_attr)
self.bn2 = nn.BatchNorm2D(channels, weight_attr=bn_weight_attr)
self.downsample = downsample
if downsample:
self.downsample = nn.Sequential(
......@@ -58,8 +60,9 @@ class BasicBlock(nn.Layer):
channels * self.expansion,
1,
stride,
weight_attr=conv_weight_attr,
bias_attr=False),
nn.BatchNorm2D(channels * self.expansion), )
nn.BatchNorm2D(channels * self.expansion, weight_attr=bn_weight_attr))
else:
self.downsample = nn.Sequential()
self.stride = stride
......@@ -91,6 +94,7 @@ class ResNet31(nn.Layer):
channels (list[int]): List of out_channels of Conv2d layer.
out_indices (None | Sequence[int]): Indices of output stages.
last_stage_pool (bool): If True, add `MaxPool2d` layer to last stage.
init_type (None | str): the config to control the initialization.
'''
def __init__(self,
......@@ -98,7 +102,8 @@ class ResNet31(nn.Layer):
layers=[1, 2, 5, 3],
channels=[64, 128, 256, 256, 512, 512, 512],
out_indices=None,
last_stage_pool=False):
last_stage_pool=False,
init_type=None):
super(ResNet31, self).__init__()
assert isinstance(in_channels, int)
assert isinstance(last_stage_pool, bool)
......@@ -106,42 +111,55 @@ class ResNet31(nn.Layer):
self.out_indices = out_indices
self.last_stage_pool = last_stage_pool
conv_weight_attr = None
bn_weight_attr = None
if init_type is not None:
support_dict = ['KaimingNormal']
assert init_type in support_dict, Exception(
"resnet31 only support {}".format(support_dict))
conv_weight_attr = nn.initializer.KaimingNormal()
bn_weight_attr = ParamAttr(initializer=nn.initializer.Uniform(), learning_rate=1)
# conv 1 (Conv Conv)
self.conv1_1 = nn.Conv2D(
in_channels, channels[0], kernel_size=3, stride=1, padding=1)
self.bn1_1 = nn.BatchNorm2D(channels[0])
in_channels, channels[0], kernel_size=3, stride=1, padding=1, weight_attr=conv_weight_attr)
self.bn1_1 = nn.BatchNorm2D(channels[0], weight_attr=bn_weight_attr)
self.relu1_1 = nn.ReLU()
self.conv1_2 = nn.Conv2D(
channels[0], channels[1], kernel_size=3, stride=1, padding=1)
self.bn1_2 = nn.BatchNorm2D(channels[1])
channels[0], channels[1], kernel_size=3, stride=1, padding=1, weight_attr=conv_weight_attr)
self.bn1_2 = nn.BatchNorm2D(channels[1], weight_attr=bn_weight_attr)
self.relu1_2 = nn.ReLU()
# conv 2 (Max-pooling, Residual block, Conv)
self.pool2 = nn.MaxPool2D(
kernel_size=2, stride=2, padding=0, ceil_mode=True)
self.block2 = self._make_layer(channels[1], channels[2], layers[0])
self.block2 = self._make_layer(channels[1], channels[2], layers[0],
conv_weight_attr=conv_weight_attr, bn_weight_attr=bn_weight_attr)
self.conv2 = nn.Conv2D(
channels[2], channels[2], kernel_size=3, stride=1, padding=1)
self.bn2 = nn.BatchNorm2D(channels[2])
channels[2], channels[2], kernel_size=3, stride=1, padding=1, weight_attr=conv_weight_attr)
self.bn2 = nn.BatchNorm2D(channels[2], weight_attr=bn_weight_attr)
self.relu2 = nn.ReLU()
# conv 3 (Max-pooling, Residual block, Conv)
self.pool3 = nn.MaxPool2D(
kernel_size=2, stride=2, padding=0, ceil_mode=True)
self.block3 = self._make_layer(channels[2], channels[3], layers[1])
self.block3 = self._make_layer(channels[2], channels[3], layers[1],
conv_weight_attr=conv_weight_attr, bn_weight_attr=bn_weight_attr)
self.conv3 = nn.Conv2D(
channels[3], channels[3], kernel_size=3, stride=1, padding=1)
self.bn3 = nn.BatchNorm2D(channels[3])
channels[3], channels[3], kernel_size=3, stride=1, padding=1, weight_attr=conv_weight_attr)
self.bn3 = nn.BatchNorm2D(channels[3], weight_attr=bn_weight_attr)
self.relu3 = nn.ReLU()
# conv 4 (Max-pooling, Residual block, Conv)
self.pool4 = nn.MaxPool2D(
kernel_size=(2, 1), stride=(2, 1), padding=0, ceil_mode=True)
self.block4 = self._make_layer(channels[3], channels[4], layers[2])
self.block4 = self._make_layer(channels[3], channels[4], layers[2],
conv_weight_attr=conv_weight_attr, bn_weight_attr=bn_weight_attr)
self.conv4 = nn.Conv2D(
channels[4], channels[4], kernel_size=3, stride=1, padding=1)
self.bn4 = nn.BatchNorm2D(channels[4])
channels[4], channels[4], kernel_size=3, stride=1, padding=1, weight_attr=conv_weight_attr)
self.bn4 = nn.BatchNorm2D(channels[4], weight_attr=bn_weight_attr)
self.relu4 = nn.ReLU()
# conv 5 ((Max-pooling), Residual block, Conv)
......@@ -149,15 +167,16 @@ class ResNet31(nn.Layer):
if self.last_stage_pool:
self.pool5 = nn.MaxPool2D(
kernel_size=2, stride=2, padding=0, ceil_mode=True)
self.block5 = self._make_layer(channels[4], channels[5], layers[3])
self.block5 = self._make_layer(channels[4], channels[5], layers[3],
conv_weight_attr=conv_weight_attr, bn_weight_attr=bn_weight_attr)
self.conv5 = nn.Conv2D(
channels[5], channels[5], kernel_size=3, stride=1, padding=1)
self.bn5 = nn.BatchNorm2D(channels[5])
channels[5], channels[5], kernel_size=3, stride=1, padding=1, weight_attr=conv_weight_attr)
self.bn5 = nn.BatchNorm2D(channels[5], weight_attr=bn_weight_attr)
self.relu5 = nn.ReLU()
self.out_channels = channels[-1]
def _make_layer(self, input_channels, output_channels, blocks):
def _make_layer(self, input_channels, output_channels, blocks, conv_weight_attr=None, bn_weight_attr=None):
layers = []
for _ in range(blocks):
downsample = None
......@@ -168,12 +187,14 @@ class ResNet31(nn.Layer):
output_channels,
kernel_size=1,
stride=1,
weight_attr=conv_weight_attr,
bias_attr=False),
nn.BatchNorm2D(output_channels), )
nn.BatchNorm2D(output_channels, weight_attr=bn_weight_attr))
layers.append(
BasicBlock(
input_channels, output_channels, downsample=downsample))
input_channels, output_channels, downsample=downsample,
conv_weight_attr=conv_weight_attr, bn_weight_attr=bn_weight_attr))
input_channels = output_channels
return nn.Sequential(*layers)
......
......@@ -35,6 +35,7 @@ def build_head(config):
from .rec_multi_head import MultiHead
from .rec_spin_att_head import SPINAttentionHead
from .rec_abinet_head import ABINetHead
from .rec_robustscanner_head import RobustScannerHead
from .rec_visionlan_head import VLHead
# cls head
......@@ -51,7 +52,7 @@ def build_head(config):
'ClsHead', 'AttentionHead', 'SRNHead', 'PGHead', 'Transformer',
'TableAttentionHead', 'SARHead', 'AsterHead', 'SDMGRHead', 'PRENHead',
'MultiHead', 'ABINetHead', 'TableMasterHead', 'SPINAttentionHead',
'VLHead', 'SLAHead'
'VLHead', 'SLAHead', 'RobustScannerHead'
]
#table head
......
This diff is collapsed.
text
title
figure
figure_caption
table
table_caption
header
footer
reference
equation
\ No newline at end of file
text
title
list
table
figure
\ No newline at end of file
......@@ -50,7 +50,7 @@ def get_check_global_params(mode):
def _check_image_file(path):
img_end = {'jpg', 'bmp', 'png', 'jpeg', 'rgb', 'tif', 'tiff', 'gif'}
img_end = {'jpg', 'bmp', 'png', 'jpeg', 'rgb', 'tif', 'tiff', 'gif', 'pdf'}
return any([path.lower().endswith(e) for e in img_end])
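Because `_check_image_file` tests bare suffixes with `endswith`, a name such as `notes_pdf` also passes. The following sketch shows a stricter variant that compares the real extension. It is an alternative for illustration, not the commit's code, and `has_supported_extension` is a hypothetical name.

```python
import os

IMG_EXTENSIONS = {"jpg", "bmp", "png", "jpeg", "rgb", "tif", "tiff", "gif", "pdf"}


def has_supported_extension(path):
    """Return True when the file extension (case-insensitive) is supported."""
    ext = os.path.splitext(path)[1].lower().lstrip(".")
    return ext in IMG_EXTENSIONS


checks = {
    name: has_supported_extension(name)
    for name in ["scan.PDF", "a/b/page.tiff", "notes_pdf", "readme.txt"]
}
print(checks)
```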
......@@ -59,7 +59,7 @@ def get_image_file_list(img_file):
if img_file is None or not os.path.exists(img_file):
raise Exception("not found any img file in {}".format(img_file))
img_end = {'jpg', 'bmp', 'png', 'jpeg', 'rgb', 'tif', 'tiff', 'gif'}
img_end = {'jpg', 'bmp', 'png', 'jpeg', 'rgb', 'tif', 'tiff', 'gif', 'pdf'}
if os.path.isfile(img_file) and _check_image_file(img_file):
imgs_lists.append(img_file)
elif os.path.isdir(img_file):
......@@ -73,7 +73,7 @@ def get_image_file_list(img_file):
return imgs_lists
def check_and_read_gif(img_path):
def check_and_read(img_path):
if os.path.basename(img_path)[-3:] in ['gif', 'GIF']:
gif = cv2.VideoCapture(img_path)
ret, frame = gif.read()
......@@ -84,8 +84,26 @@ def check_and_read_gif(img_path):
if len(frame.shape) == 2 or frame.shape[-1] == 1:
frame = cv2.cvtColor(frame, cv2.COLOR_GRAY2RGB)
imgvalue = frame[:, :, ::-1]
return imgvalue, True
return None, False
return imgvalue, True, False
elif os.path.basename(img_path)[-3:] in ['pdf']:
import fitz
from PIL import Image
imgs = []
with fitz.open(img_path) as pdf:
for pg in range(0, pdf.pageCount):
page = pdf[pg]
mat = fitz.Matrix(2, 2)
pm = page.getPixmap(matrix=mat, alpha=False)
# if width or height > 2000 pixels, don't enlarge the image
if pm.width > 2000 or pm.height > 2000:
pm = page.getPixmap(matrix=fitz.Matrix(1, 1), alpha=False)
img = Image.frombytes("RGB", [pm.width, pm.height], pm.samples)
img = cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR)
imgs.append(img)
return imgs, False, True
return None, False, False
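With this change `check_and_read` returns three values instead of two, so every caller has to unpack an image (or a list of pages), a GIF flag, and a PDF flag. The sketch below shows one way a caller might branch on that triple. The stub only imitates the return shape and decodes nothing, and `load_pages` and `read_image` are hypothetical names.

```python
def check_and_read_stub(img_path):
    """Stand-in with the same return shape as check_and_read (simplified)."""
    suffix = img_path[-3:].lower()
    if suffix == "gif":
        return "frame-0", True, False
    if suffix == "pdf":
        return ["page-0", "page-1"], False, True
    return None, False, False


def load_pages(img_path, read_image=lambda p: "decoded:" + p):
    """Return a list of images: all PDF pages, one GIF frame, or one regular image."""
    img, is_gif, is_pdf = check_and_read_stub(img_path)
    if is_pdf:
        return list(img)          # one entry per page
    if is_gif:
        return [img]              # only the first frame is read
    return [read_image(img_path)]  # fall back to a normal image reader


for path in ["doc.pdf", "anim.gif", "word_1.png"]:
    print(path, load_pages(path))
```

Returning a list in every branch lets the rest of a pipeline loop over pages without caring which file type it started from.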
def load_vqa_bio_label_maps(label_map_path):
......
This diff is collapsed.
[English](README.md) | 简体中文
# Layout Analysis Usage
- [1. Install the whl package](#1)
- [2. Usage](#2)
- [3. Post-processing](#3)
- [4. Metrics](#4)
- [5. Training a layout analysis model](#5)
<a name="1"></a>
## 1. Install the whl package
```bash
pip install -U https://paddleocr.bj.bcebos.com/whl/layoutparser-0.0.0-py3-none-any.whl
```
<a name="2"></a>
## 2. Usage
Use layoutparser to identify the layout of a given document:
```python
import cv2
import layoutparser as lp
image = cv2.imread("ppstructure/docs/table/layout.jpg")
image = image[..., ::-1]
# Load the model
model = lp.PaddleDetectionLayoutModel(config_path="lp://PubLayNet/ppyolov2_r50vd_dcn_365e_publaynet/config",
threshold=0.5,
label_map={0: "Text", 1: "Title", 2: "List", 3:"Table", 4:"Figure"},
enforce_cpu=False,
enable_mkldnn=True)
# Detect
layout = model.detect(image)
# Show the result
show_img = lp.draw_box(image, layout, box_width=3, show_element_type=True)
show_img.show()
```
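The figure below shows the result. Detection boxes of different colors represent different categories, and `show_element_type` displays the category name in the top-left corner of each box: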
<div align="center">
<img src="../docs/table/result_all.jpg" width = "600" />
</div>
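The parameters of `PaddleDetectionLayoutModel` are as follows:
| Parameter | Meaning | Default | Notes |
| :------------: | :-------------------------: | :---------: | :----------------------------------------------------------: |
| config_path | Model config path | None | Specifying config_path downloads the model automatically (first time only; once the model exists it is not downloaded again) |
| model_path | Model path | None | Local model path. One of config_path and model_path must be set; they cannot both be None |
| threshold | Prediction score threshold | 0.5 | \ |
| input_shape | Image size after reshape | [3,640,640] | \ |
| batch_size | Test batch size | 1 | \ |
| label_map | Category mapping table | None | May be None when config_path is set, in which case label_map is looked up from the dataset name; must be specified manually when model_path is set |
| enforce_cpu | Whether to run on CPU | False | False means use the GPU; True forces CPU |
| enforce_mkldnn | Whether to enable MKLDNN acceleration in CPU prediction | True | \ |
| thread_num | Number of CPU threads | 10 | \ |
The following model configurations and label maps are currently supported. You can use these models by changing `--config_path` and `--label_map` to detect different types of content: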
| dataset | config_path | label_map |
| ------------------------------------------------------------ | ------------------------------------------------------------ | --------------------------------------------------------- |
| [TableBank](https://doc-analysis.github.io/tablebank-page/index.html) word | lp://TableBank/ppyolov2_r50vd_dcn_365e_tableBank_word/config | {0:"Table"} |
| TableBank latex | lp://TableBank/ppyolov2_r50vd_dcn_365e_tableBank_latex/config | {0:"Table"} |
| [PubLayNet](https://github.com/ibm-aur-nlp/PubLayNet) | lp://PubLayNet/ppyolov2_r50vd_dcn_365e_publaynet/config | {0: "Text", 1: "Title", 2: "List", 3:"Table", 4:"Figure"} |
* TableBank word and TableBank latex are trained on datasets of Word documents and LaTeX documents, respectively;
* The downloaded TableBank dataset contains both word and latex.
<a name="3"></a>
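## 3. Post-processing
Layout analysis detects several categories. To get the detection boxes of only one category (for example "Text"), you can use the following code: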
```python
# Continues from the code above
# First, filter the regions of specific types
text_blocks = lp.Layout([b for b in layout if b.type=='Text'])
figure_blocks = lp.Layout([b for b in layout if b.type=='Figure'])
# Text regions may be detected inside figure regions, so remove those
text_blocks = lp.Layout([b for b in text_blocks \
                   if not any(b.is_in(b_fig) for b_fig in figure_blocks)])
# Sort the text regions and assign ids
h, w = image.shape[:2]
left_interval = lp.Interval(0, w/2*1.05, axis='x').put_on_canvas(image)
left_blocks = text_blocks.filter_by(left_interval, center=True)
left_blocks.sort(key = lambda b:b.coordinates[1])
right_blocks = [b for b in text_blocks if b not in left_blocks]
right_blocks.sort(key = lambda b:b.coordinates[1])
# Finally, merge the two lists and add indices in order
text_blocks = lp.Layout([b.set(id = idx) for idx, b in enumerate(left_blocks + right_blocks)])
# Show the result
show_img = lp.draw_box(image, text_blocks,
box_width=3,
show_element_id=True)
show_img.show()
```
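The ordering logic in the post-processing code can be sketched without layoutparser, using plain `(x1, y1, x2, y2)` tuples. This is an approximation for illustration: it assumes a two-column page, treats a box as "left" when its horizontal center falls at or before `w/2*1.05`, and `reading_order` is a hypothetical helper rather than a layoutparser function.

```python
def reading_order(boxes, page_width):
    """Order boxes as left column top-to-bottom, then right column top-to-bottom."""
    split_x = page_width / 2 * 1.05  # same split point as the snippet above
    left = [b for b in boxes if (b[0] + b[2]) / 2 <= split_x]
    right = [b for b in boxes if (b[0] + b[2]) / 2 > split_x]
    left.sort(key=lambda b: b[1])    # sort by top edge (y1)
    right.sort(key=lambda b: b[1])
    # Pair each box with its reading-order id.
    return list(enumerate(left + right))


boxes = [
    (520, 40, 980, 200),   # right column, top
    (20, 300, 480, 500),   # left column, bottom
    (20, 40, 480, 260),    # left column, top
    (520, 240, 980, 600),  # right column, bottom
]
ordered = reading_order(boxes, page_width=1000)
for idx, box in ordered:
    print(idx, box)
```

The 5% margin on the split point keeps slightly wide left-column blocks from being assigned to the right column.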
The result showing only the "Text" category:
<div align="center">
<img src="../docs/table/result_text.jpg" width = "600" />
</div>
<a name="4"></a>
## 4. Metrics
| Dataset | mAP | CPU time cost | GPU time cost |
| --------- | ---- | ------------- | ------------- |
| PubLayNet | 93.6 | 1713.7ms | 66.6ms |
| TableBank | 96.2 | 1968.4ms | 65.1ms |
**Environment:**
**CPU:** Intel(R) Xeon(R) CPU E5-2650 v4 @ 2.20GHz, 24 cores
**GPU:** a single NVIDIA Tesla P40
<a name="5"></a>
## 5. Training a layout analysis model
The models above were trained with [PaddleDetection](https://github.com/PaddlePaddle/PaddleDetection). To train your own layout analysis model, see [train_layoutparser_model](train_layoutparser_model_ch.md).
English | [简体中文](train_layoutparser_model_ch.md)
- [Training layout-parse](#training-layout-parse)
- [1. Installation](#1--installation)
- [1.1 Requirements](#11-requirements)
- [1.2 Install PaddleDetection](#12-install-paddledetection)
- [2. Data preparation](#2-data-preparation)
- [3. Configuration](#3-configuration)
- [4. Training](#4-training)
- [5. Prediction](#5-prediction)
- [6. Deployment](#6-deployment)
- [6.1 Export model](#61-export-model)
- [6.2 Inference](#62-inference)
# Training layout-parse
## 1. Installation
### 1.1 Requirements
- PaddlePaddle 2.1
- OS 64 bit
- Python 3 (3.5.1+/3.6/3.7/3.8/3.9), 64 bit
- pip/pip3 (9.0.1+), 64 bit
- CUDA >= 10.1
- cuDNN >= 7.6
### 1.2 Install PaddleDetection
```bash
# Clone PaddleDetection repository
cd <path/to/clone/PaddleDetection>
git clone https://github.com/PaddlePaddle/PaddleDetection.git
cd PaddleDetection
# Install other dependencies
pip install -r requirements.txt
```
For more installation tutorials, please refer to: [Install doc](https://github.com/PaddlePaddle/PaddleDetection/blob/release/2.1/docs/tutorials/INSTALL_cn.md)
## 2. Data preparation
Download the [PubLayNet](https://github.com/ibm-aur-nlp/PubLayNet) dataset
```bash
cd PaddleDetection/dataset/
mkdir publaynet
# run the command to download PubLayNet
wget -O publaynet.tar.gz https://dax-cdn.cdn.appdomain.cloud/dax-publaynet/1.0.0/publaynet.tar.gz?_ga=2.104193024.1076900768.1622560733-649911202.1622560733
# unpack
tar -xvf publaynet.tar.gz
```
PubLayNet directory structure after decompression:
| File or Folder | Description | num |
| :------------- | :----------------------------------------------- | ------- |
| `train/` | Images in the training subset | 335,703 |
| `val/` | Images in the validation subset | 11,245 |
| `test/` | Images in the testing subset | 11,405 |
| `train.json` | Annotations for training images | 1 |
| `val.json` | Annotations for validation images | 1 |
| `LICENSE.txt` | Plaintext version of the CDLA-Permissive license | 1 |
| `README.txt` | Text file with the file names and description | 1 |
For other datasets, please refer to [PrepareDataSet](https://github.com/PaddlePaddle/PaddleDetection/blob/release/2.1/docs/tutorials/PrepareDataSet.md).
## 3. Configuration
We use the `configs/ppyolo/ppyolov2_r50vd_dcn_365e_coco.yml` configuration for training. The configuration file is as follows:
```yaml
_BASE_: [
'../datasets/coco_detection.yml',
'../runtime.yml',
'./_base_/ppyolov2_r50vd_dcn.yml',
'./_base_/optimizer_365e.yml',
'./_base_/ppyolov2_reader.yml',
]
snapshot_epoch: 8
weights: output/ppyolov2_r50vd_dcn_365e_coco/model_final
```
The `ppyolov2_r50vd_dcn_365e_coco.yml` configuration depends on other configuration files, in this case:
- coco_detection.yml: mainly specifies the paths of the training and validation data
- runtime.yml: mainly describes the common parameters, such as whether to use the GPU and how many epochs pass between model snapshots
- optimizer_365e.yml: mainly specifies the learning rate and optimizer configuration
- ppyolov2_r50vd_dcn.yml: mainly describes the model and the network
- ppyolov2_reader.yml: mainly describes the configuration of the data readers, such as batch size and the number of concurrent loading child processes, and also includes the preprocessing applied after reading, such as resize and data augmentation
Modify the files above as needed, for example the dataset path and batch size.
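To make the `_BASE_` dependency concrete, here is a minimal sketch of how a layered configuration could be resolved. It assumes that base files are merged in order and that keys in the top-level file override the bases, which is how such inheritance schemes commonly work; PaddleDetection's actual loader may differ. All values other than `snapshot_epoch: 8` are made up for illustration, and `merge_config` and `resolve` are hypothetical names.

```python
def merge_config(base, override):
    """Recursively merge two dicts; values in `override` win."""
    merged = dict(base)
    for key, value in override.items():
        if isinstance(value, dict) and isinstance(merged.get(key), dict):
            merged[key] = merge_config(merged[key], value)
        else:
            merged[key] = value
    return merged


def resolve(bases, top_level):
    """Merge the base configs in order, then apply the top-level file."""
    config = {}
    for base in bases:
        config = merge_config(config, base)
    return merge_config(config, top_level)


# Illustrative stand-ins for runtime.yml and ppyolov2_reader.yml (values invented).
runtime = {"use_gpu": True, "snapshot_epoch": 1}
reader = {"TrainReader": {"batch_size": 12, "shuffle": True}}
# Top-level file: snapshot_epoch comes from the config shown above.
top = {"snapshot_epoch": 8, "TrainReader": {"batch_size": 4}}

final = resolve([runtime, reader], top)
print(final)
```

Under these assumptions, changing `batch_size` in one place leaves the other reader settings untouched, which is why editing a single base file is usually enough.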
## 4. Training
PaddleDetection provides single-card/multi-card training mode to meet various training needs of users:
* GPU single card training
```bash
export CUDA_VISIBLE_DEVICES=0 # Not needed on Windows and Mac
python tools/train.py -c configs/ppyolo/ppyolov2_r50vd_dcn_365e_coco.yml
```
* GPU multi-card training
```bash
export CUDA_VISIBLE_DEVICES=0,1,2,3
python -m paddle.distributed.launch --gpus 0,1,2,3 tools/train.py -c configs/ppyolo/ppyolov2_r50vd_dcn_365e_coco.yml --eval
```
`--eval`: evaluate while training
* Model recovery training
If training is interrupted for any reason, you can resume it with the `-r` option:
```bash
export CUDA_VISIBLE_DEVICES=0,1,2,3
python -m paddle.distributed.launch --gpus 0,1,2,3 tools/train.py -c configs/ppyolo/ppyolov2_r50vd_dcn_365e_coco.yml --eval -r output/ppyolov2_r50vd_dcn_365e_coco/10000
```
Note: If you encounter an "`Out of memory error`", try reducing `batch_size` in the `ppyolov2_reader.yml` file.
## 5. Prediction
Set parameters and use PaddleDetection to predict:
```bash
export CUDA_VISIBLE_DEVICES=0
python tools/infer.py -c configs/ppyolo/ppyolov2_r50vd_dcn_365e_coco.yml --infer_img=images/paper-image.jpg --output_dir=infer_output/ --draw_threshold=0.5 -o weights=output/ppyolov2_r50vd_dcn_365e_coco/model_final --use_vdl=True
```
`--draw_threshold` is an optional parameter. Depending on the [NMS](https://ieeexplore.ieee.org/document/1699659) computation, different thresholds produce different results. `keep_top_k` sets the maximum number of output targets; the default value is 100. You can set a different value to suit your situation.
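To make the interaction between the two parameters concrete, here is a minimal sketch that keeps the `keep_top_k` highest-scoring detections and then drops those below `draw_threshold`. The function `select_detections` and the sample data are hypothetical; in PaddleDetection itself `keep_top_k` is applied inside the NMS step, so treat this only as an illustration of the idea.

```python
def select_detections(detections, draw_threshold, keep_top_k):
    """Keep at most `keep_top_k` detections, then filter by score.

    `detections` is a list of (label, score) pairs; both the function and
    the data layout are illustrative assumptions.
    """
    ranked = sorted(detections, key=lambda det: det[1], reverse=True)
    top = ranked[:keep_top_k]
    return [det for det in top if det[1] >= draw_threshold]


sample = [("Text", 0.92), ("Table", 0.40), ("Title", 0.75), ("Figure", 0.55)]
selected = select_detections(sample, draw_threshold=0.5, keep_top_k=3)
```

Raising `draw_threshold` removes low-confidence boxes from the drawn result, while `keep_top_k` caps how many candidates are considered at all.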
## 6. Deployment
Use your trained model in Layout Parser
### 6.1 Export model
During training, the saved model files contain both the forward prediction and the back propagation processes. Industrial deployment does not need back propagation, so the model must be converted to the format required for deployment. PaddleDetection provides the `tools/export_model.py` script to export the model.
The exported model is named `model.*` by default, whereas Layout Parser expects `inference.*`. Therefore, edit [PaddleDetection/ppdet/engine/trainer.py](https://github.com/PaddlePaddle/PaddleDetection/blob/b87a1ea86fa18ce69e44a17ad1b49c1326f19ff9/ppdet/engine/trainer.py#L512) (follow the link to see the exact line of code) and change `model` to `inference`.
Execute the script to export model:
```bash
python tools/export_model.py -c configs/ppyolo/ppyolov2_r50vd_dcn_365e_coco.yml --output_dir=./inference -o weights=output/ppyolov2_r50vd_dcn_365e_coco/model_final.pdparams
```
The prediction model is exported to `inference/ppyolov2_r50vd_dcn_365e_coco`, including: `infer_cfg.yml` (not required for prediction), `inference.pdiparams`, `inference.pdiparams.info`, and `inference.pdmodel`.
For more model export tutorials, please refer to: [EXPORT_MODEL](https://github.com/PaddlePaddle/PaddleDetection/blob/release/2.1/deploy/EXPORT_MODEL.md)
### 6.2 Inference
`model_path` specifies the path of the trained model; layoutparser is then used for prediction:
```python
import layoutparser as lp
model = lp.PaddleDetectionLayoutModel(model_path="inference/ppyolov2_r50vd_dcn_365e_coco", threshold=0.5,label_map={0: "Text", 1: "Title", 2: "List", 3:"Table", 4:"Figure"},enforce_cpu=True,enable_mkldnn=True)
```
***
For more PaddleDetection training tutorials, please refer to: [PaddleDetection Training](https://github.com/PaddlePaddle/PaddleDetection/blob/release/2.1/docs/tutorials/GETTING_STARTED_cn.md)
***
[English](train_layoutparser_model.md) | 简体中文
- [Training Layout Analysis](#training-layout-analysis)
- [1. Installation](#1-installation)
- [1.1 Requirements](#11-requirements)
- [1.2 Install PaddleDetection](#12-install-paddledetection)
- [2. Data Preparation](#2-data-preparation)
- [3. Configuration Changes and Description](#3-configuration-changes-and-description)
- [4. PaddleDetection Training](#4-paddledetection-training)
- [5. PaddleDetection Prediction](#5-paddledetection-prediction)
- [6. Deployment](#6-deployment)
- [6.1 Export Model](#61-export-model)
- [6.2 layout_parser Prediction](#62-layout_parser-prediction)
# Training Layout Analysis
## 1. Installation
### 1.1 Requirements
- PaddlePaddle 2.1
- OS 64 bit
- Python 3(3.5.1+/3.6/3.7/3.8/3.9), 64 bit
- pip/pip3(9.0.1+), 64 bit
- CUDA >= 10.1
- cuDNN >= 7.6
### 1.2 Install PaddleDetection
```bash
# Clone the PaddleDetection repository
cd <path/to/clone/PaddleDetection>
git clone https://github.com/PaddlePaddle/PaddleDetection.git
cd PaddleDetection
# Install the other dependencies
pip install -r requirements.txt
```
For more installation tutorials, please refer to: [Install doc](https://github.com/PaddlePaddle/PaddleDetection/blob/release/2.1/docs/tutorials/INSTALL_cn.md)
## 2. Data Preparation
Download the [PubLayNet](https://github.com/ibm-aur-nlp/PubLayNet) dataset:
```bash
cd PaddleDetection/dataset/
mkdir publaynet
# Run the command to download
wget -O publaynet.tar.gz https://dax-cdn.cdn.appdomain.cloud/dax-publaynet/1.0.0/publaynet.tar.gz?_ga=2.104193024.1076900768.1622560733-649911202.1622560733
# Extract
tar -xvf publaynet.tar.gz
```
Directory structure of PubLayNet after extraction:
| File or Folder | Description | num |
| :------------- | :----------------------------------------------- | ------- |
| `train/` | Images in the training subset | 335,703 |
| `val/` | Images in the validation subset | 11,245 |
| `test/` | Images in the testing subset | 11,405 |
| `train.json` | Annotations for training images | 1 |
| `val.json` | Annotations for validation images | 1 |
| `LICENSE.txt` | Plaintext version of the CDLA-Permissive license | 1 |
| `README.txt` | Text file with the file names and description | 1 |
If you use another dataset, please refer to [Prepare Training Data](https://github.com/PaddlePaddle/PaddleDetection/blob/release/2.1/docs/tutorials/PrepareDataSet.md)
## 3. Configuration Changes and Description
We train with the `configs/ppyolo/ppyolov2_r50vd_dcn_365e_coco.yml` configuration. A summary of the configuration file follows:
```bash
_BASE_: [
'../datasets/coco_detection.yml',
'../runtime.yml',
'./_base_/ppyolov2_r50vd_dcn.yml',
'./_base_/optimizer_365e.yml',
'./_base_/ppyolov2_reader.yml',
]
snapshot_epoch: 8
weights: output/ppyolov2_r50vd_dcn_365e_coco/model_final
```
As shown, the `ppyolov2_r50vd_dcn_365e_coco.yml` configuration depends on other configuration files. In this example it depends on:
- coco_detection.yml: specifies the paths of the training and validation data
- runtime.yml: defines common runtime parameters, such as whether to use the GPU and how many epochs pass between checkpoint saves
- optimizer_365e.yml: configures the learning rate and the optimizer
- ppyolov2_r50vd_dcn.yml: describes the model and the backbone network
- ppyolov2_reader.yml: configures the data readers, such as batch size and the number of concurrent loading subprocesses, and also includes post-read preprocessing, such as resizing and data augmentation
Modify the files above as needed, for example the dataset path and batch size.
## 4. PaddleDetection Training
PaddleDetection provides single-card and multi-card training modes to meet a variety of training needs.
* GPU single-card training
```bash
export CUDA_VISIBLE_DEVICES=0 # Not needed on Windows and Mac
python tools/train.py -c configs/ppyolo/ppyolov2_r50vd_dcn_365e_coco.yml
```
* GPU multi-card training
```bash
export CUDA_VISIBLE_DEVICES=0,1,2,3
python -m paddle.distributed.launch --gpus 0,1,2,3 tools/train.py -c configs/ppyolo/ppyolov2_r50vd_dcn_365e_coco.yml --eval
```
`--eval`: run evaluation during training
* Resuming training
If training is interrupted for some reason, you can resume it with the `-r` option:
```bash
export CUDA_VISIBLE_DEVICES=0,1,2,3
python -m paddle.distributed.launch --gpus 0,1,2,3 tools/train.py -c configs/ppyolo/ppyolov2_r50vd_dcn_365e_coco.yml --eval -r output/ppyolov2_r50vd_dcn_365e_coco/10000
```
Note: If you encounter an "`Out of memory error`", try reducing `batch_size` in the `ppyolov2_reader.yml` file.
## 5. PaddleDetection Prediction
Set the parameters and run prediction with PaddleDetection:
```bash
export CUDA_VISIBLE_DEVICES=0
python tools/infer.py -c configs/ppyolo/ppyolov2_r50vd_dcn_365e_coco.yml --infer_img=images/paper-image.jpg --output_dir=infer_output/ --draw_threshold=0.5 -o weights=output/ppyolov2_r50vd_dcn_365e_coco/model_final --use_vdl=True
```
`--draw_threshold` is an optional parameter. Depending on the [NMS](https://ieeexplore.ieee.org/document/1699659) computation, different thresholds produce different results. `keep_top_k` sets the maximum number of output targets; the default value is 100. You can set it to suit your situation.
## 6. Deployment
Use your own trained model in layout parser.
### 6.1 Export Model
The model files saved during training contain both the forward prediction and the back propagation processes. Industrial deployment does not need back propagation, so the model must be exported to the format required for deployment. PaddleDetection provides the `tools/export_model.py` script to export the model.
The exported model is named `model.*` by default, whereas the layout parser code expects `inference.*`. Therefore, edit [PaddleDetection/ppdet/engine/trainer.py](https://github.com/PaddlePaddle/PaddleDetection/blob/b87a1ea86fa18ce69e44a17ad1b49c1326f19ff9/ppdet/engine/trainer.py#L512) (follow the link to see the exact line of code) and change `model` to `inference`.
Run the export script:
```bash
python tools/export_model.py -c configs/ppyolo/ppyolov2_r50vd_dcn_365e_coco.yml --output_dir=./inference -o weights=output/ppyolov2_r50vd_dcn_365e_coco/model_final.pdparams
```
The prediction model is exported to the `inference/ppyolov2_r50vd_dcn_365e_coco` directory and consists of `infer_cfg.yml` (not required for prediction), `inference.pdiparams`, `inference.pdiparams.info`, and `inference.pdmodel`.
For more model export tutorials, please refer to: [EXPORT_MODEL](https://github.com/PaddlePaddle/PaddleDetection/blob/release/2.1/deploy/EXPORT_MODEL.md)
### 6.2 layout_parser Prediction
`model_path` specifies the path of the trained model; layout parser is then used for prediction:
```python
import layoutparser as lp
model = lp.PaddleDetectionLayoutModel(model_path="inference/ppyolov2_r50vd_dcn_365e_coco", threshold=0.5,label_map={0: "Text", 1: "Title", 2: "List", 3:"Table", 4:"Figure"},enforce_cpu=True,enable_mkldnn=True)
```
***
For more PaddleDetection training tutorials, please refer to: [PaddleDetection Training](https://github.com/PaddlePaddle/PaddleDetection/blob/release/2.1/docs/tutorials/GETTING_STARTED_cn.md)
***
......@@ -28,13 +28,12 @@ import time
import logging
from copy import deepcopy
from ppocr.utils.utility import get_image_file_list, check_and_read_gif
from ppocr.utils.utility import get_image_file_list, check_and_read
from ppocr.utils.logging import get_logger
from tools.infer.predict_system import TextSystem
from ppstructure.layout.predict_layout import LayoutPredictor
from ppstructure.table.predict_table import TableSystem, to_excel
from ppstructure.utility import parse_args, draw_structure_result
from ppstructure.recovery.recovery_to_doc import convert_info_docx
logger = get_logger()
......@@ -78,7 +77,7 @@ class StructureSystem(object):
elif self.mode == 'vqa':
raise NotImplementedError
def __call__(self, img, return_ocr_result_in_table=False):
def __call__(self, img, img_idx=0, return_ocr_result_in_table=False):
time_dict = {
'image_orientation': 0,
'layout': 0,
......@@ -143,8 +142,8 @@ class StructureSystem(object):
time_dict['det'] += ocr_time_dict['det']
time_dict['rec'] += ocr_time_dict['rec']
# remove style char,
# when using the recognition model trained on the PubtabNet dataset,
# remove style char,
# when using the recognition model trained on the PubtabNet dataset,
# it will recognize the text format in the table, such as <b>
style_token = [
'<strike>', '<strike>', '<sup>', '</sub>', '<b>',
......@@ -169,7 +168,8 @@ class StructureSystem(object):
'type': region['label'].lower(),
'bbox': [x1, y1, x2, y2],
'img': roi_img,
'res': res
'res': res,
'img_idx': img_idx
})
end = time.time()
time_dict['all'] = end - start
......@@ -179,26 +179,29 @@ class StructureSystem(object):
return None, None
def save_structure_res(res, save_folder, img_name):
def save_structure_res(res, save_folder, img_name, img_idx=0):
excel_save_folder = os.path.join(save_folder, img_name)
os.makedirs(excel_save_folder, exist_ok=True)
res_cp = deepcopy(res)
# save res
with open(
os.path.join(excel_save_folder, 'res.txt'), 'w',
os.path.join(excel_save_folder, 'res_{}.txt'.format(img_idx)),
'w',
encoding='utf8') as f:
for region in res_cp:
roi_img = region.pop('img')
f.write('{}\n'.format(json.dumps(region)))
if region['type'] == 'table' and len(region[
if region['type'].lower() == 'table' and len(region[
'res']) > 0 and 'html' in region['res']:
excel_path = os.path.join(excel_save_folder,
'{}.xlsx'.format(region['bbox']))
excel_path = os.path.join(
excel_save_folder,
'{}_{}.xlsx'.format(region['bbox'], img_idx))
to_excel(region['res']['html'], excel_path)
elif region['type'] == 'figure':
img_path = os.path.join(excel_save_folder,
'{}.jpg'.format(region['bbox']))
elif region['type'].lower() == 'figure':
img_path = os.path.join(
excel_save_folder,
'{}_{}.jpg'.format(region['bbox'], img_idx))
cv2.imwrite(img_path, roi_img)
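Review note: the hunk above appends the page index to every saved file name, so regions with the same bounding box on different pages of a PDF no longer overwrite each other. A minimal sketch of that naming scheme follows; `region_output_path` is a hypothetical helper written for illustration and is not part of the change.

```python
import os


def region_output_path(save_folder, img_name, bbox, img_idx=0, ext="xlsx"):
    """Build '<save_folder>/<img_name>/<bbox>_<img_idx>.<ext>' (illustrative helper)."""
    file_name = "{}_{}.{}".format(bbox, img_idx, ext)
    return os.path.join(save_folder, img_name, file_name)


# The same bounding box on two different pages maps to two distinct files.
page0 = region_output_path("./output", "paper", [10, 20, 110, 220], 0)
page1 = region_output_path("./output", "paper", [10, 20, 110, 220], 1)
```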
......@@ -214,28 +217,75 @@ def main(args):
for i, image_file in enumerate(image_file_list):
logger.info("[{}/{}] {}".format(i, img_num, image_file))
img, flag = check_and_read_gif(image_file)
img, flag_gif, flag_pdf = check_and_read(image_file)
img_name = os.path.basename(image_file).split('.')[0]
if not flag:
if not flag_gif and not flag_pdf:
img = cv2.imread(image_file)
if img is None:
logger.error("error in loading image:{}".format(image_file))
continue
res, time_dict = structure_sys(img)
if structure_sys.mode == 'structure':
save_structure_res(res, save_folder, img_name)
draw_img = draw_structure_result(img, res, args.vis_font_path)
img_save_path = os.path.join(save_folder, img_name, 'show.jpg')
elif structure_sys.mode == 'vqa':
raise NotImplementedError
# draw_img = draw_ser_results(img, res, args.vis_font_path)
# img_save_path = os.path.join(save_folder, img_name + '.jpg')
cv2.imwrite(img_save_path, draw_img)
logger.info('result save to {}'.format(img_save_path))
if args.recovery:
convert_info_docx(img, res, save_folder, img_name)
if not flag_pdf:
if img is None:
logger.error("error in loading image:{}".format(image_file))
continue
res, time_dict = structure_sys(img)
if structure_sys.mode == 'structure':
save_structure_res(res, save_folder, img_name)
draw_img = draw_structure_result(img, res, args.vis_font_path)
img_save_path = os.path.join(save_folder, img_name, 'show.jpg')
elif structure_sys.mode == 'vqa':
raise NotImplementedError
# draw_img = draw_ser_results(img, res, args.vis_font_path)
# img_save_path = os.path.join(save_folder, img_name + '.jpg')
cv2.imwrite(img_save_path, draw_img)
logger.info('result save to {}'.format(img_save_path))
if args.recovery:
try:
from ppstructure.recovery.recovery_to_doc import sorted_layout_boxes, convert_info_docx
h, w, _ = img.shape
res = sorted_layout_boxes(res, w)
convert_info_docx(img, res, save_folder, img_name,
args.save_pdf)
except Exception as ex:
logger.error(
"error in layout recovery image:{}, err msg: {}".format(
image_file, ex))
continue
else:
pdf_imgs = img
all_res = []
for index, img in enumerate(pdf_imgs):
res, time_dict = structure_sys(img, index)
if structure_sys.mode == 'structure' and res != []:
save_structure_res(res, save_folder, img_name, index)
draw_img = draw_structure_result(img, res,
args.vis_font_path)
img_save_path = os.path.join(save_folder, img_name,
'show_{}.jpg'.format(index))
elif structure_sys.mode == 'vqa':
raise NotImplementedError
# draw_img = draw_ser_results(img, res, args.vis_font_path)
# img_save_path = os.path.join(save_folder, img_name + '.jpg')
if res != []:
cv2.imwrite(img_save_path, draw_img)
logger.info('result save to {}'.format(img_save_path))
if args.recovery and res != []:
from ppstructure.recovery.recovery_to_doc import sorted_layout_boxes, convert_info_docx
h, w, _ = img.shape
res = sorted_layout_boxes(res, w)
all_res += res
if args.recovery and all_res != []:
try:
convert_info_docx(img, all_res, save_folder, img_name,
args.save_pdf)
except Exception as ex:
logger.error(
"error in layout recovery image:{}, err msg: {}".format(
image_file, ex))
continue
logger.info("Predict time : {:.3f}s".format(time_dict['all']))
......
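Review note: for PDF input, the new branch runs the pipeline once per page, tags every region with its page index, and concatenates the non-empty results before layout recovery. The sketch below isolates that control flow under simplifying assumptions: `process_pages` and `fake_analyze` are hypothetical stand-ins, and the real `StructureSystem.__call__` also returns a timing dictionary that is omitted here.

```python
def process_pages(pages, analyze):
    """Run `analyze(page, index)` on each page and merge the non-empty results."""
    all_res = []
    for index, page in enumerate(pages):
        res = analyze(page, index)
        if res:  # pages without detected regions contribute nothing
            all_res += res
    return all_res


def fake_analyze(page, index):
    """Hypothetical stand-in for the structure pipeline: one region per line."""
    return [{"type": "text", "res": line, "img_idx": index} for line in page]


pages = [["a", "b"], [], ["c"]]
merged = process_pages(pages, fake_analyze)
```

Carrying `img_idx` on each region is what later lets the document writer locate the per-page figure files.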
......@@ -78,9 +78,27 @@ wget https://paddleocr.bj.bcebos.com/PP-OCRv3/chinese/ch_PP-OCRv3_det_infer.tar
wget https://paddleocr.bj.bcebos.com/PP-OCRv3/chinese/ch_PP-OCRv3_rec_infer.tar && tar xf ch_PP-OCRv3_rec_infer.tar
# Download the ultra-lightweight English table structure model and unzip it
wget https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_table_structure_infer.tar && tar xf en_ppocr_mobile_v2.0_table_structure_infer.tar
# Download the layout model of publaynet dataset and unzip it
wget https://paddleocr.bj.bcebos.com/ppstructure/models/layout/picodet_lcnet_x1_0_layout_infer.tar && tar xf picodet_lcnet_x1_0_layout_infer.tar
cd ..
# run
python3 predict_system.py --det_model_dir=inference/en_PP-OCRv3_det_infer --rec_model_dir=inference/en_PP-OCRv3_rec_infer --table_model_dir=inference/en_ppocr_mobile_v2.0_table_structure_infer --rec_char_dict_path=../ppocr/utils/en_dict.txt --table_char_dict_path=../ppocr/utils/dict/table_structure_dict.txt --output ./output/table --rec_image_shape=3,48,320 --vis_font_path=../doc/fonts/simfang.ttf --recovery=True --image_dir=./docs/table/1.png
python3 predict_system.py \
--image_dir=./docs/table/1.png \
--det_model_dir=inference/en_PP-OCRv3_det_infer \
--rec_model_dir=inference/en_PP-OCRv3_rec_infer \
--rec_char_dict_path=../ppocr/utils/en_dict.txt \
--output=../output/ \
--table_model_dir=inference/ch_ppstructure_mobile_v2.0_SLANet_infer \
--table_char_dict_path=../ppocr/utils/dict/table_structure_dict.txt \
--table_max_len=488 \
--layout_model_dir=inference/picodet_lcnet_x1_0_layout_infer \
--layout_dict_path=../ppocr/utils/dict/layout_dict/layout_publaynet_dict.txt \
--vis_font_path=../doc/fonts/simfang.ttf \
--recovery=True \
--save_pdf=False
```
After running, the docx of each picture will be saved in the directory specified by the output field
\ No newline at end of file
After running, the docx of each picture will be saved in the directory specified by the output field
The table-to-Word recovery code (`table_process.py`) is adapted from: https://github.com/pqzx/html2docx.git
\ No newline at end of file
......@@ -35,21 +35,15 @@
python3 -m pip install --upgrade pip
# GPU installation
python3 -m pip install "paddlepaddle-gpu>=2.2" -i https://mirror.baidu.com/pypi/simple
python3 -m pip install "paddlepaddle-gpu>=2.3" -i https://mirror.baidu.com/pypi/simple
# CPU installation
python3 -m pip install "paddlepaddle>=2.2" -i https://mirror.baidu.com/pypi/simple
python3 -m pip install "paddlepaddle>=2.3" -i https://mirror.baidu.com/pypi/simple
```
For more requirements, follow the instructions in the [installation documentation](https://www.paddlepaddle.org.cn/install/quick).
* **(2) Install dependencies**
```bash
python3 -m pip install -r ppstructure/recovery/requirements.txt
```
<a name="2.2"></a>
### 2.2 Install PaddleOCR
......@@ -87,11 +81,28 @@ wget https://paddleocr.bj.bcebos.com/PP-OCRv3/chinese/ch_PP-OCRv3_det_infer.tar
# Download the English lightweight PP-OCRv3 recognition model and unzip it
wget https://paddleocr.bj.bcebos.com/PP-OCRv3/chinese/ch_PP-OCRv3_rec_infer.tar && tar xf ch_PP-OCRv3_rec_infer.tar
# Download the ultra-lightweight English table structure model and unzip it
wget https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_table_structure_infer.tar && tar xf en_ppocr_mobile_v2.0_table_structure_infer.tar
wget https://paddleocr.bj.bcebos.com/ppstructure/models/slanet/ch_ppstructure_mobile_v2.0_SLANet_infer.tar && tar xf ch_ppstructure_mobile_v2.0_SLANet_infer.tar
# Download the English layout analysis model
wget https://paddleocr.bj.bcebos.com/ppstructure/models/layout/picodet_lcnet_x1_0_layout_infer.tar && tar xf picodet_lcnet_x1_0_layout_infer.tar
cd ..
# Run prediction
python3 predict_system.py --det_model_dir=inference/en_PP-OCRv3_det_infer --rec_model_dir=inference/en_PP-OCRv3_rec_infer --table_model_dir=inference/en_ppocr_mobile_v2.0_table_structure_infer --rec_char_dict_path=../ppocr/utils/en_dict.txt --table_char_dict_path=../ppocr/utils/dict/table_structure_dict.txt --output ./output/table --rec_image_shape=3,48,320 --vis_font_path=../doc/fonts/simfang.ttf --recovery=True --image_dir=./docs/table/1.png
python3 predict_system.py \
--image_dir=./docs/table/1.png \
--det_model_dir=inference/en_PP-OCRv3_det_infer \
--rec_model_dir=inference/en_PP-OCRv3_rec_infer \
--rec_char_dict_path=../ppocr/utils/en_dict.txt \
--output=../output/ \
--table_model_dir=inference/ch_ppstructure_mobile_v2.0_SLANet_infer \
--table_char_dict_path=../ppocr/utils/dict/table_structure_dict.txt \
--table_max_len=488 \
--layout_model_dir=inference/picodet_lcnet_x1_0_layout_infer \
--layout_dict_path=../ppocr/utils/dict/layout_dict/layout_publaynet_dict.txt \
--vis_font_path=../doc/fonts/simfang.ttf \
--recovery=True \
--save_pdf=False
```
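After running, the docx document of each image is saved in the directory specified by the output field
After running, the docx document of each image is saved in the directory specified by the `output` field
The table-to-Word recovery code [table_process.py] comes from: https://github.com/pqzx/html2docx.git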
......@@ -22,21 +22,23 @@ from docx import shared
from docx.enum.text import WD_ALIGN_PARAGRAPH
from docx.enum.section import WD_SECTION
from docx.oxml.ns import qn
from docx.enum.table import WD_TABLE_ALIGNMENT
from table_process import HtmlToDocx
from ppocr.utils.logging import get_logger
logger = get_logger()
def convert_info_docx(img, res, save_folder, img_name):
def convert_info_docx(img, res, save_folder, img_name, save_pdf):
doc = Document()
doc.styles['Normal'].font.name = 'Times New Roman'
doc.styles['Normal']._element.rPr.rFonts.set(qn('w:eastAsia'), u'宋体')
doc.styles['Normal'].font.size = shared.Pt(6.5)
h, w, _ = img.shape
res = sorted_layout_boxes(res, w)
flag = 1
for i, region in enumerate(res):
img_idx = region['img_idx']
if flag == 2 and region['layout'] == 'single':
section = doc.add_section(WD_SECTION.CONTINUOUS)
section._sectPr.xpath('./w:cols')[0].set(qn('w:num'), '1')
......@@ -46,10 +48,10 @@ def convert_info_docx(img, res, save_folder, img_name):
section._sectPr.xpath('./w:cols')[0].set(qn('w:num'), '2')
flag = 2
if region['type'] == 'Figure':
if region['type'].lower() == 'figure':
excel_save_folder = os.path.join(save_folder, img_name)
img_path = os.path.join(excel_save_folder,
'{}.jpg'.format(region['bbox']))
'{}_{}.jpg'.format(region['bbox'], img_idx))
paragraph_pic = doc.add_paragraph()
paragraph_pic.alignment = WD_ALIGN_PARAGRAPH.CENTER
run = paragraph_pic.add_run("")
......@@ -57,40 +59,38 @@ def convert_info_docx(img, res, save_folder, img_name):
run.add_picture(img_path, width=shared.Inches(5))
elif flag == 2:
run.add_picture(img_path, width=shared.Inches(2))
elif region['type'] == 'Title':
elif region['type'].lower() == 'title':
doc.add_heading(region['res'][0]['text'])
elif region['type'] == 'Text':
elif region['type'].lower() == 'table':
paragraph = doc.add_paragraph()
new_parser = HtmlToDocx()
new_parser.table_style = 'TableGrid'
table = new_parser.handle_table(html=region['res']['html'])
new_table = deepcopy(table)
new_table.alignment = WD_TABLE_ALIGNMENT.CENTER
paragraph.add_run().element.addnext(new_table._tbl)
else:
paragraph = doc.add_paragraph()
paragraph_format = paragraph.paragraph_format
for i, line in enumerate(region['res']):
if i == 0:
paragraph_format.first_line_indent = shared.Inches(0.25)
text_run = paragraph.add_run(line['text'] + ' ')
text_run.font.size = shared.Pt(9)
elif region['type'] == 'Table':
pypandoc.convert(
source=region['res']['html'],
format='html',
to='docx',
outputfile='tmp.docx')
tmp_doc = Document('tmp.docx')
paragraph = doc.add_paragraph()
table = tmp_doc.tables[0]
new_table = deepcopy(table)
new_table.style = doc.styles['Table Grid']
from docx.enum.table import WD_TABLE_ALIGNMENT
new_table.alignment = WD_TABLE_ALIGNMENT.CENTER
paragraph.add_run().element.addnext(new_table._tbl)
os.remove('tmp.docx')
else:
continue
text_run.font.size = shared.Pt(10)
# save to docx
docx_path = os.path.join(save_folder, '{}.docx'.format(img_name))
doc.save(docx_path)
logger.info('docx save to {}'.format(docx_path))
# save to pdf
if save_pdf:
pdf_path = os.path.join(save_folder, '{}.pdf'.format(img_name))
from docx2pdf import convert
convert(docx_path, pdf_path)
logger.info('pdf save to {}'.format(pdf_path))
def sorted_layout_boxes(res, w):
"""
......
opencv-contrib-python==4.4.0.46
pypandoc
python-docx
\ No newline at end of file
python-docx
docx2pdf
fitz
PyMuPDF
\ No newline at end of file
This diff is collapsed.
......@@ -38,7 +38,7 @@ def init_args():
parser.add_argument(
"--layout_dict_path",
type=str,
default="../ppocr/utils/dict/layout_publaynet_dict.txt")
default="../ppocr/utils/dict/layout_dict/layout_publaynet_dict.txt")
parser.add_argument(
"--layout_score_threshold",
type=float,
......@@ -89,6 +89,11 @@ def init_args():
type=bool,
default=False,
help='Whether to enable layout of recovery')
parser.add_argument(
"--save_pdf",
type=bool,
default=False,
help='Whether to save pdf file')
return parser
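Review note: `--recovery` and `--save_pdf` are declared with `type=bool`. With argparse this converts any non-empty string to `True`, so `--save_pdf=False` on the command line would still enable PDF output. The sketch below shows the behaviour and one possible alternative; `str2bool` is a hypothetical helper defined here for illustration, and whether the project substitutes something similar elsewhere is not visible in this hunk.

```python
import argparse


def str2bool(value):
    """Interpret common textual spellings of true; everything else is False."""
    return str(value).lower() in ("true", "t", "1", "yes", "y")


# Parser mirroring the declaration in the hunk above.
naive = argparse.ArgumentParser()
naive.add_argument("--save_pdf", type=bool, default=False)

# Same flag, parsed with the illustrative helper instead.
safer = argparse.ArgumentParser()
safer.add_argument("--save_pdf", type=str2bool, default=False)

naive_value = naive.parse_args(["--save_pdf=False"]).save_pdf
safer_value = safer.parse_args(["--save_pdf=False"]).save_pdf
```

Here `naive_value` ends up `True` because `bool("False")` is truthy, while `safer_value` is `False`.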
......
......@@ -58,10 +58,11 @@ function status_check(){
run_command=$2
run_log=$3
model_name=$4
log_path=$5
if [ $last_status -eq 0 ]; then
echo -e "\033[33m Run successfully with command - ${model_name} - ${run_command}! \033[0m" | tee -a ${run_log}
echo -e "\033[33m Run successfully with command - ${model_name} - ${run_command} - ${log_path} \033[0m" | tee -a ${run_log}
else
echo -e "\033[33m Run failed with command - ${model_name} - ${run_command}! \033[0m" | tee -a ${run_log}
echo -e "\033[33m Run failed with command - ${model_name} - ${run_command} - ${log_path} \033[0m" | tee -a ${run_log}
fi
}
......@@ -54,6 +54,6 @@ random_infer_input:[{float32,[3,488,488]}]
===========================train_benchmark_params==========================
batch_size:32
fp_items:fp32|fp16
epoch:1
epoch:2
--profiler_options:batch_range=[10,20];state=GPU;tracer_option=Default;profile_path=model.profile
flags:FLAGS_eager_delete_tensor_gb=0.0;FLAGS_fraction_of_gpu_memory_to_use=0.98;FLAGS_conv_workspace_size_limit=4096
......@@ -52,7 +52,7 @@ null:null
===========================infer_benchmark_params==========================
random_infer_input:[{float32,[3,224,224]}]
===========================train_benchmark_params==========================
batch_size:4
batch_size:8
fp_items:fp32|fp16
epoch:3
--profiler_options:batch_range=[10,20];state=GPU;tracer_option=Default;profile_path=model.profile
......
Global:
use_gpu: true
epoch_num: 5
log_smooth_window: 20
print_batch_step: 20
save_model_dir: ./output/rec/rec_r31_robustscanner/
save_epoch_step: 1
# evaluation is run every 2000 iterations
eval_batch_step: [0, 2000]
cal_metric_during_train: True
pretrained_model:
checkpoints:
save_inference_dir:
use_visualdl: False
infer_img: ./inference/rec_inference
# for data or label process
character_dict_path: ppocr/utils/dict90.txt
max_text_length: &max_text_length 40
infer_mode: False
use_space_char: False
rm_symbol: True
save_res_path: ./output/rec/predicts_robustscanner.txt
Optimizer:
name: Adam
beta1: 0.9
beta2: 0.999
lr:
name: Piecewise
decay_epochs: [3, 4]
values: [0.001, 0.0001, 0.00001]
regularizer:
name: 'L2'
factor: 0
Architecture:
model_type: rec
algorithm: RobustScanner
Transform:
Backbone:
name: ResNet31
init_type: KaimingNormal
Head:
name: RobustScannerHead
enc_outchannles: 128
hybrid_dec_rnn_layers: 2
hybrid_dec_dropout: 0
position_dec_rnn_layers: 2
start_idx: 91
mask: True
padding_idx: 92
encode_value: False
max_text_length: *max_text_length
Loss:
name: SARLoss
PostProcess:
name: SARLabelDecode
Metric:
name: RecMetric
is_filter: True
Train:
dataset:
name: SimpleDataSet
data_dir: ./train_data/ic15_data/
label_file_list: ["./train_data/ic15_data/rec_gt_train.txt"]
transforms:
- DecodeImage: # load image
img_mode: BGR
channel_first: False
- SARLabelEncode: # Class handling label
- RobustScannerRecResizeImg:
image_shape: [3, 48, 48, 160] # h:48 w:[48,160]
width_downsample_ratio: 0.25
max_text_length: *max_text_length
- KeepKeys:
keep_keys: ['image', 'label', 'valid_ratio', 'word_positons'] # dataloader will return list in this order
loader:
shuffle: True
batch_size_per_card: 16
drop_last: True
num_workers: 0
use_shared_memory: False
Eval:
dataset:
name: SimpleDataSet
data_dir: ./train_data/ic15_data
label_file_list: ["./train_data/ic15_data/rec_gt_test.txt"]
transforms:
- DecodeImage: # load image
img_mode: BGR
channel_first: False
- SARLabelEncode: # Class handling label
- RobustScannerRecResizeImg:
image_shape: [3, 48, 48, 160]
max_text_length: *max_text_length
width_downsample_ratio: 0.25
- KeepKeys:
keep_keys: ['image', 'label', 'valid_ratio', 'word_positons'] # dataloader will return list in this order
loader:
shuffle: False
drop_last: False
batch_size_per_card: 16
num_workers: 0
use_shared_memory: False
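Review note: the `Optimizer.lr` block of the config above uses a `Piecewise` schedule with `decay_epochs: [3, 4]` and three `values`. A minimal sketch of how such a schedule maps an epoch to a learning rate is given below. It works at epoch granularity for readability; the training code is assumed to convert epoch boundaries into iteration counts, and `piecewise_lr` is a hypothetical function, not the library's implementation.

```python
def piecewise_lr(epoch, decay_epochs, values):
    """Return the learning rate for `epoch` under a piecewise-constant schedule.

    `values` must have one more entry than `decay_epochs`.
    """
    for boundary, value in zip(decay_epochs, values):
        if epoch < boundary:
            return value
    return values[-1]


decay_epochs = [3, 4]
values = [0.001, 0.0001, 0.00001]
schedule = [piecewise_lr(epoch, decay_epochs, values) for epoch in range(5)]
```

With `epoch_num: 5`, the first three epochs train at the largest rate and the last two step down by a factor of ten each.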
===========================train_params===========================
model_name:rec_r31_robustscanner
python:python
gpu_list:0|0,1
Global.use_gpu:True|True
Global.auto_cast:null
Global.epoch_num:lite_train_lite_infer=2|whole_train_whole_infer=5
Global.save_model_dir:./output/
Train.loader.batch_size_per_card:lite_train_lite_infer=16|whole_train_whole_infer=64
Global.pretrained_model:null
train_model_name:latest
train_infer_img_dir:./inference/rec_inference
null:null
##
trainer:norm_train
norm_train:tools/train.py -c test_tipc/configs/rec_r31_robustscanner/rec_r31_robustscanner.yml -o
pact_train:null
fpgm_train:null
distill_train:null
null:null
null:null
##
===========================eval_params===========================
eval:tools/eval.py -c test_tipc/configs/rec_r31_robustscanner/rec_r31_robustscanner.yml -o
null:null
##
===========================infer_params===========================
Global.save_inference_dir:./output/
Global.checkpoints:
norm_export:tools/export_model.py -c test_tipc/configs/rec_r31_robustscanner/rec_r31_robustscanner.yml -o
quant_export:null
fpgm_export:null
distill_export:null
export1:null
export2:null
##
train_model:./inference/rec_r31_robustscanner/best_accuracy
infer_export:tools/export_model.py -c test_tipc/configs/rec_r31_robustscanner/rec_r31_robustscanner.yml -o
infer_quant:False
inference:tools/infer/predict_rec.py --rec_char_dict_path=./ppocr/utils/dict90.txt --rec_image_shape="3,48,48,160" --use_space_char=False --rec_algorithm="RobustScanner"
--use_gpu:True|False
--enable_mkldnn:True|False
--cpu_threads:1|6
--rec_batch_num:1|6
--use_tensorrt:False|False
--precision:fp32|int8
--rec_model_dir:
--image_dir:./inference/rec_inference
--save_log_path:./test/output/
--benchmark:True
null:null
===========================infer_benchmark_params==========================
random_infer_input:[{float32,[3,48,160]}]
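Review note: each line of the TIPC parameter file above is a `key:value` pair, where `|` separates alternatives and `mode=value` entries select a setting per test mode. The following sketch illustrates that format only; the actual test_tipc scripts parse these files in shell, and `parse_tipc_line` and `select_mode_value` are hypothetical helpers.

```python
def parse_tipc_line(line):
    """Split a 'key:value' line at the first colon."""
    key, _, value = line.strip().partition(":")
    return key, value


def select_mode_value(value, mode):
    """Pick the setting for `mode` from 'modeA=x|modeB=y'; otherwise return value as is."""
    for option in value.split("|"):
        name, sep, setting = option.partition("=")
        if sep and name == mode:
            return setting
    return value


key, value = parse_tipc_line(
    "Global.epoch_num:lite_train_lite_infer=2|whole_train_whole_infer=5")
lite_epochs = select_mode_value(value, "lite_train_lite_infer")
whole_epochs = select_mode_value(value, "whole_train_whole_infer")
gpu_key, gpu_value = parse_tipc_line("gpu_list:0|0,1")
gpu_options = gpu_value.split("|")
```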
Global:
use_gpu: true
epoch_num: 8
log_smooth_window: 200
print_batch_step: 200
save_model_dir: ./output/rec/r45_visionlan
save_epoch_step: 1
# evaluation is run every 2000 iterations
eval_batch_step: [0, 2000]
cal_metric_during_train: True
pretrained_model:
checkpoints:
save_inference_dir:
use_visualdl: False
infer_img: doc/imgs_words/en/word_2.png
# for data or label process
character_dict_path:
max_text_length: &max_text_length 25
training_step: &training_step LA
infer_mode: False
use_space_char: False
save_res_path: ./output/rec/predicts_visionlan.txt
Optimizer:
name: Adam
beta1: 0.9
beta2: 0.999
clip_norm: 20.0
group_lr: true
training_step: *training_step
lr:
name: Piecewise
decay_epochs: [6]
values: [0.0001, 0.00001]
regularizer:
name: 'L2'
factor: 0
Architecture:
model_type: rec
algorithm: VisionLAN
Transform:
Backbone:
name: ResNet45
strides: [2, 2, 2, 1, 1]
Head:
name: VLHead
n_layers: 3
n_position: 256
n_dim: 512
max_text_length: *max_text_length
training_step: *training_step
Loss:
name: VLLoss
mode: *training_step
weight_res: 0.5
weight_mas: 0.5
PostProcess:
name: VLLabelDecode
Metric:
name: RecMetric
is_filter: true
Train:
dataset:
name: SimpleDataSet
data_dir: ./train_data/ic15_data/
label_file_list: ["./train_data/ic15_data/rec_gt_train.txt"]
transforms:
- DecodeImage: # load image
img_mode: RGB
channel_first: False
- ABINetRecAug:
- VLLabelEncode: # Class handling label
- VLRecResizeImg:
image_shape: [3, 64, 256]
- KeepKeys:
keep_keys: ['image', 'label', 'label_res', 'label_sub', 'label_id', 'length'] # dataloader will return list in this order
loader:
shuffle: True
batch_size_per_card: 220
drop_last: True
num_workers: 4
Eval:
dataset:
name: SimpleDataSet
data_dir: ./train_data/ic15_data
label_file_list: ["./train_data/ic15_data/rec_gt_test.txt"]
transforms:
- DecodeImage: # load image
img_mode: RGB
channel_first: False
- VLLabelEncode: # Class handling label
- VLRecResizeImg:
image_shape: [3, 64, 256]
- KeepKeys:
keep_keys: ['image', 'label', 'label_res', 'label_sub', 'label_id', 'length'] # dataloader will return list in this order
loader:
shuffle: False
drop_last: False
batch_size_per_card: 64
num_workers: 4
===========================train_params===========================
model_name:det_r18_db_v2_0
model_name:rec_r45_visionlan
python:python3.7
gpu_list:0|0,1
Global.use_gpu:True|True
Global.auto_cast:null
Global.epoch_num:lite_train_lite_infer=2|whole_train_whole_infer=300
Global.save_model_dir:./output/
Train.loader.batch_size_per_card:lite_train_lite_infer=2|whole_train_lite_infer=4
Train.loader.batch_size_per_card:lite_train_lite_infer=32|whole_train_whole_infer=64
Global.pretrained_model:null
train_model_name:latest
train_infer_img_dir:./train_data/icdar2015/text_localization/ch4_test_images/
train_infer_img_dir:./inference/rec_inference
null:null
##
trainer:norm_train
norm_train:tools/train.py -c configs/det/det_res18_db_v2.0.yml -o
quant_export:null
fpgm_export:null
norm_train:tools/train.py -c test_tipc/configs/rec_r45_visionlan/rec_r45_visionlan.yml -o
pact_train:null
fpgm_train:null
distill_train:null
null:null
null:null
##
===========================eval_params===========================
eval:null
===========================eval_params===========================
eval:tools/eval.py -c test_tipc/configs/rec_r45_visionlan/rec_r45_visionlan.yml -o
null:null
##
===========================infer_params===========================
Global.save_inference_dir:./output/
Global.checkpoints:
norm_export:null
quant_export:null
norm_export:tools/export_model.py -c test_tipc/configs/rec_r45_visionlan/rec_r45_visionlan.yml -o
quant_export:null
fpgm_export:null
distill_export:null
export1:null
export2:null
##
train_model:null
infer_export:null
train_model:./inference/rec_r45_visionlan_train/best_accuracy
infer_export:tools/export_model.py -c test_tipc/configs/rec_r45_visionlan/rec_r45_visionlan.yml -o
infer_quant:False
inference:tools/infer/predict_det.py
inference:tools/infer/predict_rec.py --rec_char_dict_path=./ppocr/utils/ic15_dict.txt --rec_image_shape="3,64,256" --rec_algorithm="VisionLAN" --use_space_char=False
--use_gpu:True|False
--enable_mkldnn:False
--cpu_threads:6
--rec_batch_num:1
--rec_batch_num:1|6
--use_tensorrt:False
--precision:fp32
--det_model_dir:
--image_dir:./inference/ch_det_data_50/all-sum-510/
--save_log_path:null
--rec_model_dir:
--image_dir:./inference/rec_inference
--save_log_path:./test/output/
--benchmark:True
null:null
===========================infer_benchmark_params==========================
random_infer_input:[{float32,[3,640,640]}];[{float32,[3,960,960]}]
===========================train_benchmark_params==========================
batch_size:8|16
fp_items:fp32|fp16
epoch:15
--profiler_options:batch_range=[10,20];state=GPU;tracer_option=Default;profile_path=model.profile
random_infer_input:[{float32,[3,64,256]}]
......@@ -54,6 +54,7 @@
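| NRTR |rec_mtb_nrtr | recognition | supported | multi-node multi-GPU <br> mixed precision | - | - |
| SAR |rec_r31_sar | recognition | supported | multi-node multi-GPU <br> mixed precision | - | - |
| SPIN |rec_r32_gaspin_bilstm_att | recognition | supported | multi-node multi-GPU <br> mixed precision | - | - |
| RobustScanner |rec_r31_robustscanner | recognition | supported | multi-node multi-GPU <br> mixed precision | - | - |
| PGNet |rec_r34_vd_none_none_ctc_v2.0 | end-to-end | supported | multi-node multi-GPU <br> mixed precision | - | - |
| TableMaster |table_structure_tablemaster_train | table recognition | supported | multi-node multi-GPU <br> mixed precision | - | - |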
......
......@@ -84,7 +84,7 @@ function func_cpp_inference(){
eval $command
last_status=${PIPESTATUS[0]}
eval "cat ${_save_log_path}"
status_check $last_status "${command}" "${status_log}" "${model_name}"
status_check $last_status "${command}" "${status_log}" "${model_name}" "${_save_log_path}"
done
done
done
......@@ -117,7 +117,7 @@ function func_cpp_inference(){
eval $command
last_status=${PIPESTATUS[0]}
eval "cat ${_save_log_path}"
status_check $last_status "${command}" "${status_log}" "${model_name}"
status_check $last_status "${command}" "${status_log}" "${model_name}" "${_save_log_path}"
done
done
......
......@@ -88,7 +88,7 @@ function func_inference(){
eval $command
last_status=${PIPESTATUS[0]}
eval "cat ${_save_log_path}"
status_check $last_status "${command}" "${status_log}" "${model_name}"
status_check $last_status "${command}" "${status_log}" "${model_name}" "${_save_log_path}"
done
done
done
......@@ -119,7 +119,7 @@ function func_inference(){
eval $command
last_status=${PIPESTATUS[0]}
eval "cat ${_save_log_path}"
status_check $last_status "${command}" "${status_log}" "${model_name}"
status_check $last_status "${command}" "${status_log}" "${model_name}" "${_save_log_path}"
done
done
......@@ -146,14 +146,15 @@ if [ ${MODE} = "whole_infer" ]; then
for infer_model in ${infer_model_dir_list[*]}; do
# run export
if [ ${infer_run_exports[Count]} != "null" ];then
_save_log_path="${_log_path}/python_infer_gpu_usetrt_${use_trt}_precision_${precision}_batchsize_${batch_size}_infermodel_${infer_model}.log"
save_infer_dir=$(dirname $infer_model)
set_export_weight=$(func_set_params "${export_weight}" "${infer_model}")
set_save_infer_key=$(func_set_params "${save_infer_key}" "${save_infer_dir}")
export_cmd="${python} ${infer_run_exports[Count]} ${set_export_weight} ${set_save_infer_key}"
export_cmd="${python} ${infer_run_exports[Count]} ${set_export_weight} ${set_save_infer_key} > ${_save_log_path} 2>&1 "
echo ${infer_run_exports[Count]}
eval $export_cmd
status_export=$?
status_check $status_export "${export_cmd}" "${status_log}" "${model_name}"
status_check $status_export "${export_cmd}" "${status_log}" "${model_name}" "${_save_log_path}"
else
save_infer_dir=${infer_model}
fi
......
......@@ -66,7 +66,7 @@ function func_paddle2onnx(){
trans_model_cmd="${padlle2onnx_cmd} ${set_dirname} ${set_model_filename} ${set_params_filename} ${set_save_model} ${set_opset_version} ${set_enable_onnx_checker} > ${trans_det_log} 2>&1 "
eval $trans_model_cmd
last_status=${PIPESTATUS[0]}
status_check $last_status "${trans_model_cmd}" "${status_log}" "${model_name}"
status_check $last_status "${trans_model_cmd}" "${status_log}" "${model_name}" "${trans_det_log}"
# trans rec
set_dirname=$(func_set_params "--model_dir" "${rec_infer_model_dir_value}")
set_model_filename=$(func_set_params "${model_filename_key}" "${model_filename_value}")
......@@ -78,7 +78,7 @@ function func_paddle2onnx(){
trans_model_cmd="${padlle2onnx_cmd} ${set_dirname} ${set_model_filename} ${set_params_filename} ${set_save_model} ${set_opset_version} ${set_enable_onnx_checker} > ${trans_rec_log} 2>&1 "
eval $trans_model_cmd
last_status=${PIPESTATUS[0]}
status_check $last_status "${trans_model_cmd}" "${status_log}" "${model_name}"
status_check $last_status "${trans_model_cmd}" "${status_log}" "${model_name}" "${trans_rec_log}"
elif [[ ${model_name} =~ "det" ]]; then
# trans det
set_dirname=$(func_set_params "--model_dir" "${det_infer_model_dir_value}")
......@@ -91,7 +91,7 @@ function func_paddle2onnx(){
trans_model_cmd="${padlle2onnx_cmd} ${set_dirname} ${set_model_filename} ${set_params_filename} ${set_save_model} ${set_opset_version} ${set_enable_onnx_checker} > ${trans_det_log} 2>&1 "
eval $trans_model_cmd
last_status=${PIPESTATUS[0]}
status_check $last_status "${trans_model_cmd}" "${status_log}" "${model_name}"
status_check $last_status "${trans_model_cmd}" "${status_log}" "${model_name}" "${trans_det_log}"
elif [[ ${model_name} =~ "rec" ]]; then
# trans rec
set_dirname=$(func_set_params "--model_dir" "${rec_infer_model_dir_value}")
......@@ -104,7 +104,7 @@ function func_paddle2onnx(){
trans_model_cmd="${padlle2onnx_cmd} ${set_dirname} ${set_model_filename} ${set_params_filename} ${set_save_model} ${set_opset_version} ${set_enable_onnx_checker} > ${trans_rec_log} 2>&1 "
eval $trans_model_cmd
last_status=${PIPESTATUS[0]}
status_check $last_status "${trans_model_cmd}" "${status_log}" "${model_name}"
status_check $last_status "${trans_model_cmd}" "${status_log}" "${model_name}" "${trans_rec_log}"
fi
# python inference
......@@ -127,7 +127,7 @@ function func_paddle2onnx(){
eval $infer_model_cmd
last_status=${PIPESTATUS[0]}
eval "cat ${_save_log_path}"
status_check $last_status "${infer_model_cmd}" "${status_log}" "${model_name}"
status_check $last_status "${infer_model_cmd}" "${status_log}" "${model_name}" "${_save_log_path}"
elif [ ${use_gpu} = "True" ] || [ ${use_gpu} = "gpu" ]; then
_save_log_path="${LOG_PATH}/paddle2onnx_infer_gpu.log"
set_gpu=$(func_set_params "${use_gpu_key}" "${use_gpu}")
......@@ -146,7 +146,7 @@ function func_paddle2onnx(){
eval $infer_model_cmd
last_status=${PIPESTATUS[0]}
eval "cat ${_save_log_path}"
status_check $last_status "${infer_model_cmd}" "${status_log}" "${model_name}"
status_check $last_status "${infer_model_cmd}" "${status_log}" "${model_name}" "${_save_log_path}"
else
echo "Does not support hardware other than CPU and GPU Currently!"
fi
......@@ -158,4 +158,4 @@ echo "################### run test ###################"
export Count=0
IFS="|"
func_paddle2onnx
\ No newline at end of file
func_paddle2onnx
......@@ -84,7 +84,7 @@ function func_inference(){
eval $command
last_status=${PIPESTATUS[0]}
eval "cat ${_save_log_path}"
status_check $last_status "${command}" "${status_log}" "${model_name}"
status_check $last_status "${command}" "${status_log}" "${model_name}" "${_save_log_path}"
done
done
done
......@@ -109,7 +109,7 @@ function func_inference(){
eval $command
last_status=${PIPESTATUS[0]}
eval "cat ${_save_log_path}"
status_check $last_status "${command}" "${status_log}" "${model_name}"
status_check $last_status "${command}" "${status_log}" "${model_name}" "${_save_log_path}"
done
done
......@@ -145,7 +145,7 @@ if [ ${MODE} = "whole_infer" ]; then
echo $export_cmd
eval $export_cmd
status_export=$?
status_check $status_export "${export_cmd}" "${status_log}" "${model_name}"
status_check $status_export "${export_cmd}" "${status_log}" "${model_name}" "${export_log_path}"
else
save_infer_dir=${infer_model}
fi
......
......@@ -83,7 +83,7 @@ function func_serving(){
trans_model_cmd="${python_list[0]} ${trans_model_py} ${set_dirname} ${set_model_filename} ${set_params_filename} ${set_serving_server} ${set_serving_client} > ${trans_rec_log} 2>&1 "
eval $trans_model_cmd
last_status=${PIPESTATUS[0]}
status_check $last_status "${trans_model_cmd}" "${status_log}" "${model_name}"
status_check $last_status "${trans_model_cmd}" "${status_log}" "${model_name}" "${trans_rec_log}"
set_image_dir=$(func_set_params "${image_dir_key}" "${image_dir_value}")
python_list=(${python_list})
cd ${serving_dir_value}
......@@ -95,14 +95,14 @@ function func_serving(){
web_service_cpp_cmd="nohup ${python_list[0]} ${web_service_py} --model ${det_server_value} ${rec_server_value} ${op_key} ${op_value} ${port_key} ${port_value} > ${server_log_path} 2>&1 &"
eval $web_service_cpp_cmd
last_status=${PIPESTATUS[0]}
status_check $last_status "${web_service_cpp_cmd}" "${status_log}" "${model_name}"
status_check $last_status "${web_service_cpp_cmd}" "${status_log}" "${model_name}" "${server_log_path}"
sleep 5s
_save_log_path="${LOG_PATH}/cpp_client_cpu.log"
cpp_client_cmd="${python_list[0]} ${cpp_client_py} ${det_client_value} ${rec_client_value} > ${_save_log_path} 2>&1"
eval $cpp_client_cmd
last_status=${PIPESTATUS[0]}
eval "cat ${_save_log_path}"
status_check $last_status "${cpp_client_cmd}" "${status_log}" "${model_name}"
status_check $last_status "${cpp_client_cmd}" "${status_log}" "${model_name}" "${_save_log_path}"
ps ux | grep -i ${port_value} | awk '{print $2}' | xargs kill -s 9
else
server_log_path="${LOG_PATH}/cpp_server_gpu.log"
......@@ -114,7 +114,7 @@ function func_serving(){
eval $cpp_client_cmd
last_status=${PIPESTATUS[0]}
eval "cat ${_save_log_path}"
status_check $last_status "${cpp_client_cmd}" "${status_log}" "${model_name}"
status_check $last_status "${cpp_client_cmd}" "${status_log}" "${model_name}" "${_save_log_path}"
ps ux | grep -i ${port_value} | awk '{print $2}' | xargs kill -s 9
fi
done
......
......@@ -126,19 +126,19 @@ function func_serving(){
web_service_cmd="nohup ${python} ${web_service_py} ${web_use_gpu_key}="" ${web_use_mkldnn_key}=${use_mkldnn} ${set_cpu_threads} ${set_det_model_config} ${set_rec_model_config} > ${server_log_path} 2>&1 &"
eval $web_service_cmd
last_status=${PIPESTATUS[0]}
status_check $last_status "${web_service_cmd}" "${status_log}" "${model_name}"
status_check $last_status "${web_service_cmd}" "${status_log}" "${model_name}" "${server_log_path}"
elif [[ ${model_name} =~ "det" ]]; then
set_det_model_config=$(func_set_params "${det_server_key}" "${det_server_value}")
web_service_cmd="nohup ${python} ${web_service_py} ${web_use_gpu_key}="" ${web_use_mkldnn_key}=${use_mkldnn} ${set_cpu_threads} ${set_det_model_config} > ${server_log_path} 2>&1 &"
eval $web_service_cmd
last_status=${PIPESTATUS[0]}
status_check $last_status "${web_service_cmd}" "${status_log}" "${model_name}"
status_check $last_status "${web_service_cmd}" "${status_log}" "${model_name}" "${server_log_path}"
elif [[ ${model_name} =~ "rec" ]]; then
set_rec_model_config=$(func_set_params "${rec_server_key}" "${rec_server_value}")
web_service_cmd="nohup ${python} ${web_service_py} ${web_use_gpu_key}="" ${web_use_mkldnn_key}=${use_mkldnn} ${set_cpu_threads} ${set_rec_model_config} > ${server_log_path} 2>&1 &"
eval $web_service_cmd
last_status=${PIPESTATUS[0]}
status_check $last_status "${web_service_cmd}" "${status_log}" "${model_name}"
status_check $last_status "${web_service_cmd}" "${status_log}" "${model_name}" "${server_log_path}"
fi
sleep 2s
for pipeline in ${pipeline_py[*]}; do
......@@ -147,7 +147,7 @@ function func_serving(){
eval $pipeline_cmd
last_status=${PIPESTATUS[0]}
eval "cat ${_save_log_path}"
status_check $last_status "${pipeline_cmd}" "${status_log}" "${model_name}"
status_check $last_status "${pipeline_cmd}" "${status_log}" "${model_name}" "${_save_log_path}"
sleep 2s
done
ps ux | grep -E 'web_service' | awk '{print $2}' | xargs kill -s 9
......@@ -177,19 +177,19 @@ function func_serving(){
web_service_cmd="nohup ${python} ${web_service_py} ${set_tensorrt} ${set_precision} ${set_det_model_config} ${set_rec_model_config} > ${server_log_path} 2>&1 &"
eval $web_service_cmd
last_status=${PIPESTATUS[0]}
status_check $last_status "${web_service_cmd}" "${status_log}" "${model_name}"
status_check $last_status "${web_service_cmd}" "${status_log}" "${model_name}" "${server_log_path}"
elif [[ ${model_name} =~ "det" ]]; then
set_det_model_config=$(func_set_params "${det_server_key}" "${det_server_value}")
web_service_cmd="nohup ${python} ${web_service_py} ${set_tensorrt} ${set_precision} ${set_det_model_config} > ${server_log_path} 2>&1 &"
eval $web_service_cmd
last_status=${PIPESTATUS[0]}
status_check $last_status "${web_service_cmd}" "${status_log}" "${model_name}"
status_check $last_status "${web_service_cmd}" "${status_log}" "${model_name}" "${server_log_path}"
elif [[ ${model_name} =~ "rec" ]]; then
set_rec_model_config=$(func_set_params "${rec_server_key}" "${rec_server_value}")
web_service_cmd="nohup ${python} ${web_service_py} ${set_tensorrt} ${set_precision} ${set_rec_model_config} > ${server_log_path} 2>&1 &"
eval $web_service_cmd
last_status=${PIPESTATUS[0]}
status_check $last_status "${web_service_cmd}" "${status_log}" "${model_name}"
status_check $last_status "${web_service_cmd}" "${status_log}" "${model_name}" "${server_log_path}"
fi
sleep 2s
for pipeline in ${pipeline_py[*]}; do
......@@ -198,7 +198,7 @@ function func_serving(){
eval $pipeline_cmd
last_status=${PIPESTATUS[0]}
eval "cat ${_save_log_path}"
status_check $last_status "${pipeline_cmd}" "${status_log}" "${model_name}"
status_check $last_status "${pipeline_cmd}" "${status_log}" "${model_name}" "${_save_log_path}"
sleep 2s
done
ps ux | grep -E 'web_service' | awk '{print $2}' | xargs kill -s 9
......
......@@ -133,7 +133,7 @@ function func_inference(){
eval $command
last_status=${PIPESTATUS[0]}
eval "cat ${_save_log_path}"
status_check $last_status "${command}" "${status_log}" "${model_name}"
status_check $last_status "${command}" "${status_log}" "${model_name}" "${_save_log_path}"
done
done
done
......@@ -164,7 +164,7 @@ function func_inference(){
eval $command
last_status=${PIPESTATUS[0]}
eval "cat ${_save_log_path}"
status_check $last_status "${command}" "${status_log}" "${model_name}"
status_check $last_status "${command}" "${status_log}" "${model_name}" "${_save_log_path}"
done
done
......@@ -201,7 +201,7 @@ if [ ${MODE} = "whole_infer" ]; then
echo $export_cmd
eval $export_cmd
status_export=$?
status_check $status_export "${export_cmd}" "${status_log}" "${model_name}"
status_check $status_export "${export_cmd}" "${status_log}" "${model_name}" "${export_log_path}"
else
save_infer_dir=${infer_model}
fi
......@@ -298,7 +298,7 @@ else
# run train
eval $cmd
eval "cat ${save_log}/train.log >> ${save_log}.log"
status_check $? "${cmd}" "${status_log}" "${model_name}"
status_check $? "${cmd}" "${status_log}" "${model_name}" "${save_log}.log"
set_eval_pretrain=$(func_set_params "${pretrain_model_key}" "${save_log}/${train_model_name}")
......@@ -309,7 +309,7 @@ else
eval_log_path="${LOG_PATH}/${trainer}_gpus_${gpu}_autocast_${autocast}_nodes_${nodes}_eval.log"
eval_cmd="${python} ${eval_py} ${set_eval_pretrain} ${set_use_gpu} ${set_eval_params1} > ${eval_log_path} 2>&1 "
eval $eval_cmd
status_check $? "${eval_cmd}" "${status_log}" "${model_name}"
status_check $? "${eval_cmd}" "${status_log}" "${model_name}" "${eval_log_path}"
fi
# run export model
if [ ${run_export} != "null" ]; then
......@@ -320,7 +320,7 @@ else
set_save_infer_key=$(func_set_params "${save_infer_key}" "${save_infer_path}")
export_cmd="${python} ${run_export} ${set_export_weight} ${set_save_infer_key} > ${export_log_path} 2>&1 "
eval $export_cmd
status_check $? "${export_cmd}" "${status_log}" "${model_name}"
status_check $? "${export_cmd}" "${status_log}" "${model_name}" "${export_log_path}"
#run inference
eval $env
......
......@@ -73,7 +73,7 @@ def main():
config['Architecture']["Head"]['out_channels'] = char_num
model = build_model(config['Architecture'])
extra_input_models = ["SRN", "NRTR", "SAR", "SEED", "SVTR", "VisionLAN"]
extra_input_models = ["SRN", "NRTR", "SAR", "SEED", "SVTR", "VisionLAN", "RobustScanner"]
extra_input = False
if config['Architecture']['algorithm'] == 'Distillation':
for key in config['Architecture']["Models"]:
......
......@@ -58,6 +58,8 @@ def export_single_model(model,
other_shape = [
paddle.static.InputSpec(
shape=[None, 3, 48, 160], dtype="float32"),
[paddle.static.InputSpec(
shape=[None], dtype="float32")]
]
model = to_static(model, input_spec=other_shape)
elif arch_config["algorithm"] == "SVTR":
......@@ -109,6 +111,22 @@ def export_single_model(model,
shape=[None, 3, 64, 256], dtype="float32"),
]
model = to_static(model, input_spec=other_shape)
elif arch_config["algorithm"] == "RobustScanner":
max_text_length = arch_config["Head"]["max_text_length"]
other_shape = [
paddle.static.InputSpec(
shape=[None, 3, 48, 160], dtype="float32"),
[
paddle.static.InputSpec(
shape=[None, ],
dtype="float32"),
paddle.static.InputSpec(
shape=[None, max_text_length],
dtype="int64")
]
]
model = to_static(model, input_spec=other_shape)
elif arch_config["algorithm"] in ["LayoutLM", "LayoutLMv2", "LayoutXLM"]:
input_spec = [
paddle.static.InputSpec(
......@@ -128,7 +146,7 @@ def export_single_model(model,
else:
infer_shape = [3, -1, -1]
if arch_config["model_type"] == "rec":
infer_shape = [3, 48, -1] # for rec model, H must be 32
infer_shape = [3, 32, -1] # for rec model, H must be 32
if "Transform" in arch_config and arch_config[
"Transform"] is not None and arch_config["Transform"][
"name"] == "TPS":
......@@ -234,4 +252,4 @@ def main():
if __name__ == "__main__":
main()
\ No newline at end of file
main()
......@@ -68,7 +68,7 @@ class TextRecognizer(object):
'name': 'SARLabelDecode',
"character_dict_path": args.rec_char_dict_path,
"use_space_char": args.use_space_char
}
}
elif self.rec_algorithm == "VisionLAN":
postprocess_params = {
'name': 'VLLabelDecode',
......@@ -93,6 +93,13 @@ class TextRecognizer(object):
"character_dict_path": args.rec_char_dict_path,
"use_space_char": args.use_space_char
}
elif self.rec_algorithm == "RobustScanner":
postprocess_params = {
'name': 'SARLabelDecode',
"character_dict_path": args.rec_char_dict_path,
"use_space_char": args.use_space_char,
"rm_symbol": True
}
self.postprocess_op = build_post_process(postprocess_params)
self.predictor, self.input_tensor, self.output_tensors, self.config = \
utility.create_predictor(args, 'rec', logger)
......@@ -390,6 +397,18 @@ class TextRecognizer(object):
img_list[indices[ino]], self.rec_image_shape)
norm_img = norm_img[np.newaxis, :]
norm_img_batch.append(norm_img)
elif self.rec_algorithm == "RobustScanner":
norm_img, _, _, valid_ratio = self.resize_norm_img_sar(
img_list[indices[ino]], self.rec_image_shape, width_downsample_ratio=0.25)
norm_img = norm_img[np.newaxis, :]
valid_ratio = np.expand_dims(valid_ratio, axis=0)
valid_ratios = []
valid_ratios.append(valid_ratio)
norm_img_batch.append(norm_img)
word_positions_list = []
word_positions = np.array(range(0, 40)).astype('int64')
word_positions = np.expand_dims(word_positions, axis=0)
word_positions_list.append(word_positions)
else:
norm_img = self.resize_norm_img(img_list[indices[ino]],
max_wh_ratio)
......@@ -437,10 +456,40 @@ class TextRecognizer(object):
preds = {"predict": outputs[2]}
elif self.rec_algorithm == "SAR":
valid_ratios = np.concatenate(valid_ratios)
inputs = [
norm_img_batch,
np.array(
[valid_ratios], dtype=np.float32),
]
if self.use_onnx:
input_dict = {}
input_dict[self.input_tensor.name] = norm_img_batch
outputs = self.predictor.run(self.output_tensors,
input_dict)
preds = outputs[0]
else:
input_names = self.predictor.get_input_names()
for i in range(len(input_names)):
input_tensor = self.predictor.get_input_handle(
input_names[i])
input_tensor.copy_from_cpu(inputs[i])
self.predictor.run()
outputs = []
for output_tensor in self.output_tensors:
output = output_tensor.copy_to_cpu()
outputs.append(output)
if self.benchmark:
self.autolog.times.stamp()
preds = outputs[0]
elif self.rec_algorithm == "RobustScanner":
valid_ratios = np.concatenate(valid_ratios)
word_positions_list = np.concatenate(word_positions_list)
inputs = [
norm_img_batch,
valid_ratios,
word_positions_list
]
if self.use_onnx:
input_dict = {}
input_dict[self.input_tensor.name] = norm_img_batch
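
Reviewer note: the RobustScanner branch in this hunk feeds the predictor three arrays (image batch, valid ratios, word positions), which lines up with the three `InputSpec` entries added to the export script. The following is a minimal sketch of the expected shapes only, not the code from this commit. `build_robustscanner_inputs` and `MAX_TEXT_LENGTH` are hypothetical names, and the images are assumed to be already resized and padded to `(3, 48, 160)`.

```python
import numpy as np

MAX_TEXT_LENGTH = 40  # assumption: mirrors range(0, 40) in the hunk above


def build_robustscanner_inputs(images, valid_ratios):
    """Stack per-image data into the three batch inputs (hypothetical helper).

    images: list of float32 arrays shaped (3, H, W), already resized and padded
    valid_ratios: list of floats, the unpadded width fraction of each image
    """
    norm_img_batch = np.stack(images).astype(np.float32)
    valid_ratio_batch = np.asarray(valid_ratios, dtype=np.float32)
    # One row of positions 0..MAX_TEXT_LENGTH-1 per image in the batch.
    word_positions = np.tile(
        np.arange(MAX_TEXT_LENGTH, dtype=np.int64), (len(images), 1))
    return [norm_img_batch, valid_ratio_batch, word_positions]


batch = [np.zeros((3, 48, 160), dtype=np.float32) for _ in range(2)]
inputs = build_robustscanner_inputs(batch, [1.0, 0.75])
print([x.shape for x in inputs])
```

The order of the returned list matters because the hunk copies `inputs[i]` into the i-th input handle returned by `get_input_names()`.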
......
......@@ -231,89 +231,10 @@ def create_predictor(args, mode, logger):
)
config.enable_tuned_tensorrt_dynamic_shape(
args.shape_info_filename, True)
use_dynamic_shape = True
if mode == "det":
min_input_shape = {
"x": [1, 3, 50, 50],
"conv2d_92.tmp_0": [1, 120, 20, 20],
"conv2d_91.tmp_0": [1, 24, 10, 10],
"conv2d_59.tmp_0": [1, 96, 20, 20],
"nearest_interp_v2_1.tmp_0": [1, 256, 10, 10],
"nearest_interp_v2_2.tmp_0": [1, 256, 20, 20],
"conv2d_124.tmp_0": [1, 256, 20, 20],
"nearest_interp_v2_3.tmp_0": [1, 64, 20, 20],
"nearest_interp_v2_4.tmp_0": [1, 64, 20, 20],
"nearest_interp_v2_5.tmp_0": [1, 64, 20, 20],
"elementwise_add_7": [1, 56, 2, 2],
"nearest_interp_v2_0.tmp_0": [1, 256, 2, 2]
}
max_input_shape = {
"x": [1, 3, 1536, 1536],
"conv2d_92.tmp_0": [1, 120, 400, 400],
"conv2d_91.tmp_0": [1, 24, 200, 200],
"conv2d_59.tmp_0": [1, 96, 400, 400],
"nearest_interp_v2_1.tmp_0": [1, 256, 200, 200],
"conv2d_124.tmp_0": [1, 256, 400, 400],
"nearest_interp_v2_2.tmp_0": [1, 256, 400, 400],
"nearest_interp_v2_3.tmp_0": [1, 64, 400, 400],
"nearest_interp_v2_4.tmp_0": [1, 64, 400, 400],
"nearest_interp_v2_5.tmp_0": [1, 64, 400, 400],
"elementwise_add_7": [1, 56, 400, 400],
"nearest_interp_v2_0.tmp_0": [1, 256, 400, 400]
}
opt_input_shape = {
"x": [1, 3, 640, 640],
"conv2d_92.tmp_0": [1, 120, 160, 160],
"conv2d_91.tmp_0": [1, 24, 80, 80],
"conv2d_59.tmp_0": [1, 96, 160, 160],
"nearest_interp_v2_1.tmp_0": [1, 256, 80, 80],
"nearest_interp_v2_2.tmp_0": [1, 256, 160, 160],
"conv2d_124.tmp_0": [1, 256, 160, 160],
"nearest_interp_v2_3.tmp_0": [1, 64, 160, 160],
"nearest_interp_v2_4.tmp_0": [1, 64, 160, 160],
"nearest_interp_v2_5.tmp_0": [1, 64, 160, 160],
"elementwise_add_7": [1, 56, 40, 40],
"nearest_interp_v2_0.tmp_0": [1, 256, 40, 40]
}
min_pact_shape = {
"nearest_interp_v2_26.tmp_0": [1, 256, 20, 20],
"nearest_interp_v2_27.tmp_0": [1, 64, 20, 20],
"nearest_interp_v2_28.tmp_0": [1, 64, 20, 20],
"nearest_interp_v2_29.tmp_0": [1, 64, 20, 20]
}
max_pact_shape = {
"nearest_interp_v2_26.tmp_0": [1, 256, 400, 400],
"nearest_interp_v2_27.tmp_0": [1, 64, 400, 400],
"nearest_interp_v2_28.tmp_0": [1, 64, 400, 400],
"nearest_interp_v2_29.tmp_0": [1, 64, 400, 400]
}
opt_pact_shape = {
"nearest_interp_v2_26.tmp_0": [1, 256, 160, 160],
"nearest_interp_v2_27.tmp_0": [1, 64, 160, 160],
"nearest_interp_v2_28.tmp_0": [1, 64, 160, 160],
"nearest_interp_v2_29.tmp_0": [1, 64, 160, 160]
}
min_input_shape.update(min_pact_shape)
max_input_shape.update(max_pact_shape)
opt_input_shape.update(opt_pact_shape)
elif mode == "rec":
if args.rec_algorithm not in ["CRNN", "SVTR_LCNet"]:
use_dynamic_shape = False
imgH = int(args.rec_image_shape.split(',')[-2])
min_input_shape = {"x": [1, 3, imgH, 10]}
max_input_shape = {"x": [args.rec_batch_num, 3, imgH, 2304]}
opt_input_shape = {"x": [args.rec_batch_num, 3, imgH, 320]}
config.exp_disable_tensorrt_ops(["transpose2"])
elif mode == "cls":
min_input_shape = {"x": [1, 3, 48, 10]}
max_input_shape = {"x": [args.rec_batch_num, 3, 48, 1024]}
opt_input_shape = {"x": [args.rec_batch_num, 3, 48, 320]}
else:
use_dynamic_shape = False
if use_dynamic_shape:
config.set_trt_dynamic_shape_info(
min_input_shape, max_input_shape, opt_input_shape)
logger.info(
f"when using tensorrt, dynamic shape is a suggested option, you can use '--shape_info_filename=shape.txt' for offline dygnamic shape tuning"
)
elif args.use_xpu:
config.enable_xpu(10 * 1024 * 1024)
......
......@@ -96,6 +96,8 @@ def main():
]
elif config['Architecture']['algorithm'] == "SAR":
op[op_name]['keep_keys'] = ['image', 'valid_ratio']
elif config['Architecture']['algorithm'] == "RobustScanner":
op[op_name]['keep_keys'] = ['image', 'valid_ratio', 'word_positons']
else:
op[op_name]['keep_keys'] = ['image']
transforms.append(op)
......@@ -131,12 +133,20 @@ def main():
if config['Architecture']['algorithm'] == "SAR":
valid_ratio = np.expand_dims(batch[-1], axis=0)
img_metas = [paddle.to_tensor(valid_ratio)]
if config['Architecture']['algorithm'] == "RobustScanner":
valid_ratio = np.expand_dims(batch[1], axis=0)
word_positons = np.expand_dims(batch[2], axis=0)
img_metas = [paddle.to_tensor(valid_ratio),
paddle.to_tensor(word_positons),
]
images = np.expand_dims(batch[0], axis=0)
images = paddle.to_tensor(images)
if config['Architecture']['algorithm'] == "SRN":
preds = model(images, others)
elif config['Architecture']['algorithm'] == "SAR":
preds = model(images, img_metas)
elif config['Architecture']['algorithm'] == "RobustScanner":
preds = model(images, img_metas)
else:
preds = model(images)
post_result = post_process_class(preds)
......
......@@ -63,14 +63,14 @@ def main():
elif op_name in ['SRResize']:
op[op_name]['infer_mode'] = True
elif op_name == 'KeepKeys':
op[op_name]['keep_keys'] = ['imge_lr']
op[op_name]['keep_keys'] = ['img_lr']
transforms.append(op)
global_config['infer_mode'] = True
ops = create_operators(transforms, global_config)
save_res_path = config['Global'].get('save_res_path', "./infer_result")
if not os.path.exists(os.path.dirname(save_res_path)):
os.makedirs(os.path.dirname(save_res_path))
save_visual_path = config['Global'].get('save_visual', "infer_result/")
if not os.path.exists(os.path.dirname(save_visual_path)):
os.makedirs(os.path.dirname(save_visual_path))
model.eval()
for file in get_image_file_list(config['Global']['infer_img']):
......@@ -87,7 +87,7 @@ def main():
fm_sr = (sr_img.numpy() * 255).transpose(1, 2, 0).astype(np.uint8)
fm_lr = (lr_img.numpy() * 255).transpose(1, 2, 0).astype(np.uint8)
img_name_pure = os.path.split(file)[-1]
cv2.imwrite("infer_result/sr_{}".format(img_name_pure),
cv2.imwrite("{}/sr_{}".format(save_visual_path, img_name_pure),
fm_sr[:, :, ::-1])
logger.info("The visualized image saved in infer_result/sr_{}".format(
img_name_pure))
......
......@@ -162,18 +162,18 @@ def to_float32(preds):
for k in preds:
if isinstance(preds[k], dict) or isinstance(preds[k], list):
preds[k] = to_float32(preds[k])
else:
preds[k] = paddle.to_tensor(preds[k], dtype='float32')
elif isinstance(preds[k], paddle.Tensor):
preds[k] = preds[k].astype(paddle.float32)
elif isinstance(preds, list):
for k in range(len(preds)):
if isinstance(preds[k], dict):
preds[k] = to_float32(preds[k])
elif isinstance(preds[k], list):
preds[k] = to_float32(preds[k])
else:
preds[k] = paddle.to_tensor(preds[k], dtype='float32')
else:
preds = paddle.to_tensor(preds, dtype='float32')
elif isinstance(preds[k], paddle.Tensor):
preds[k] = preds[k].astype(paddle.float32)
elif isinstance(preds, paddle.Tensor):
preds = preds.astype(paddle.float32)
return preds
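
Reviewer note: the `to_float32` hunk above appears to replace the unconditional `paddle.to_tensor(..., dtype='float32')` fallback with a cast that applies only to `paddle.Tensor` leaves, so non-tensor values pass through unchanged. The sketch below illustrates that behaviour under a stated assumption: NumPy arrays stand in for `paddle.Tensor` so the example runs without Paddle, and the recursion is written uniformly rather than copied from the commit.

```python
import numpy as np


def to_float32(preds):
    """Recursively cast array leaves in a nested dict/list to float32.

    Assumption: np.ndarray stands in for paddle.Tensor. Leaves of any
    other type (strings, ints, None) are returned untouched.
    """
    if isinstance(preds, dict):
        for key in preds:
            preds[key] = to_float32(preds[key])
    elif isinstance(preds, list):
        for idx in range(len(preds)):
            preds[idx] = to_float32(preds[idx])
    elif isinstance(preds, np.ndarray):
        preds = preds.astype(np.float32)
    return preds


outputs = {
    "maps": np.ones((2, 2), dtype=np.float16),
    "aux": [np.zeros(3, dtype=np.float16), "label"],
    "step": 7,
}
outputs = to_float32(outputs)
```

With the earlier fallback, a leaf such as the string `"label"` would have been passed to a tensor constructor; here it is simply left in place.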
......@@ -190,7 +190,8 @@ def train(config,
pre_best_model_dict,
logger,
log_writer=None,
scaler=None):
scaler=None,
amp_level='O2'):
cal_metric_during_train = config['Global'].get('cal_metric_during_train',
False)
calc_epoch_interval = config['Global'].get('calc_epoch_interval', 1)
......@@ -230,7 +231,8 @@ def train(config,
use_srn = config['Architecture']['algorithm'] == "SRN"
extra_input_models = [
"SRN", "NRTR", "SAR", "SEED", "SVTR", "SPIN", "VisionLAN"
"SRN", "NRTR", "SAR", "SEED", "SVTR", "SPIN", "VisionLAN",
"RobustScanner"
]
extra_input = False
if config['Architecture']['algorithm'] == 'Distillation':
......@@ -276,7 +278,8 @@ def train(config,
model_average = True
# use amp
if scaler:
with paddle.amp.auto_cast(level='O2'):
custom_black_list = config['Global'].get('amp_custom_black_list',[])
with paddle.amp.auto_cast(level=amp_level, custom_black_list=custom_black_list):
if model_type == 'table' or extra_input:
preds = model(images, data=batch[1:])
elif model_type in ["kie", 'vqa']:
......@@ -502,18 +505,9 @@ def eval(model,
preds = model(batch)
sr_img = preds["sr_img"]
lr_img = preds["lr_img"]
for i in (range(sr_img.shape[0])):
fm_sr = (sr_img[i].numpy() * 255).transpose(
1, 2, 0).astype(np.uint8)
fm_lr = (lr_img[i].numpy() * 255).transpose(
1, 2, 0).astype(np.uint8)
cv2.imwrite("output/images/{}_{}_sr.jpg".format(
sum_images, i), fm_sr)
cv2.imwrite("output/images/{}_{}_lr.jpg".format(
sum_images, i), fm_lr)
else:
preds = model(images)
preds = to_float32(preds)
else:
if model_type == 'table' or extra_input:
preds = model(images, data=batch[1:])
......@@ -523,16 +517,6 @@ def eval(model,
preds = model(batch)
sr_img = preds["sr_img"]
lr_img = preds["lr_img"]
for i in (range(sr_img.shape[0])):
fm_sr = (sr_img[i].numpy() * 255).transpose(
1, 2, 0).astype(np.uint8)
fm_lr = (lr_img[i].numpy() * 255).transpose(
1, 2, 0).astype(np.uint8)
cv2.imwrite("output/images/{}_{}_sr.jpg".format(
sum_images, i), fm_sr)
cv2.imwrite("output/images/{}_{}_lr.jpg".format(
sum_images, i), fm_lr)
else:
preds = model(images)
......@@ -653,7 +637,7 @@ def preprocess(is_train=False):
'CLS', 'PGNet', 'Distillation', 'NRTR', 'TableAttn', 'SAR', 'PSE',
'SEED', 'SDMGR', 'LayoutXLM', 'LayoutLM', 'LayoutLMv2', 'PREN', 'FCE',
'SVTR', 'ViTSTR', 'ABINet', 'DB++', 'TableMaster', 'SPIN', 'VisionLAN',
'Gestalt', 'SLANet'
'Gestalt', 'SLANet', 'RobustScanner'
]
if use_xpu:
......
......@@ -147,6 +147,7 @@ def main(config, device, logger, vdl_writer):
len(valid_dataloader)))
use_amp = config["Global"].get("use_amp", False)
amp_level = config["Global"].get("amp_level", 'O2')
if use_amp:
AMP_RELATED_FLAGS_SETTING = {
'FLAGS_cudnn_batchnorm_spatial_persistent': 1,
......@@ -159,8 +160,9 @@ def main(config, device, logger, vdl_writer):
scaler = paddle.amp.GradScaler(
init_loss_scaling=scale_loss,
use_dynamic_loss_scaling=use_dynamic_loss_scaling)
model, optimizer = paddle.amp.decorate(
models=model, optimizers=optimizer, level='O2', master_weight=True)
if amp_level == "O2":
model, optimizer = paddle.amp.decorate(
models=model, optimizers=optimizer, level=amp_level, master_weight=True)
else:
scaler = None
......@@ -169,7 +171,7 @@ def main(config, device, logger, vdl_writer):
# start train
program.train(config, train_dataloader, valid_dataloader, device, model,
loss_class, optimizer, lr_scheduler, post_process_class,
eval_class, pre_best_model_dict, logger, vdl_writer, scaler)
eval_class, pre_best_model_dict, logger, vdl_writer, scaler,amp_level)
def test_reader(config, device, logger):
......
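
Reviewer note: taken together, the config, `program.py`, and `train.py` hunks read three AMP keys from `Global` (`use_amp`, `amp_level`, `amp_custom_black_list`) and call `paddle.amp.decorate` only when the level is `O2`. The following sketch restates that decision logic in plain Python. `resolve_amp_settings` is a hypothetical helper, not a function in the repository, and it deliberately makes no Paddle calls.

```python
def resolve_amp_settings(global_cfg):
    """Collect AMP options using the defaults seen in the hunks above.

    Hypothetical helper for illustration; global_cfg is the `Global`
    section of the YAML config loaded as a dict.
    """
    use_amp = global_cfg.get("use_amp", False)
    amp_level = global_cfg.get("amp_level", "O2")
    return {
        "use_amp": use_amp,
        "amp_level": amp_level,
        "custom_black_list": global_cfg.get("amp_custom_black_list", []),
        # In the train.py hunk, only O2 goes through paddle.amp.decorate.
        "decorate_model": use_amp and amp_level == "O2",
    }


settings = resolve_amp_settings(
    {"use_amp": True, "amp_level": "O1", "amp_custom_black_list": ["exp"]})
defaults = resolve_amp_settings({})
```

Under these assumptions an `O1` run keeps the undecorated model and relies on `auto_cast` alone, while a config without any AMP keys leaves AMP off with the level defaulting to `O2`.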