未验证 提交 a485740f 编写于 作者: Z zhoujun 提交者: GitHub

Merge pull request #6842 from smilelite/robustscanner_branch

添加robustscanner(第三次)
Global:
use_gpu: true
epoch_num: 5
log_smooth_window: 20
print_batch_step: 20
save_model_dir: ./output/rec/rec_r31_robustscanner/
save_epoch_step: 1
# evaluation is run every 2000 iterations
eval_batch_step: [0, 2000]
cal_metric_during_train: True
pretrained_model:
checkpoints:
save_inference_dir:
use_visualdl: False
infer_img: ./inference/rec_inference
# for data or label process
character_dict_path: ppocr/utils/dict90.txt
max_text_length: &max_text_length 40
infer_mode: False
use_space_char: False
rm_symbol: True
save_res_path: ./output/rec/predicts_robustscanner.txt
Optimizer:
name: Adam
beta1: 0.9
beta2: 0.999
lr:
name: Piecewise
decay_epochs: [3, 4]
values: [0.001, 0.0001, 0.00001]
regularizer:
name: 'L2'
factor: 0
Architecture:
model_type: rec
algorithm: RobustScanner
Transform:
Backbone:
name: ResNet31
init_type: KaimingNormal
Head:
name: RobustScannerHead
enc_outchannles: 128
hybrid_dec_rnn_layers: 2
hybrid_dec_dropout: 0
position_dec_rnn_layers: 2
start_idx: 91
mask: True
padding_idx: 92
encode_value: False
max_text_length: *max_text_length
Loss:
name: SARLoss
PostProcess:
name: SARLabelDecode
Metric:
name: RecMetric
is_filter: True
Train:
dataset:
name: LMDBDataSet
data_dir: ./train_data/data_lmdb_release/training/
transforms:
- DecodeImage: # load image
img_mode: BGR
channel_first: False
- SARLabelEncode: # Class handling label
- RobustScannerRecResizeImg:
image_shape: [3, 48, 48, 160] # h:48 w:[48,160]
width_downsample_ratio: 0.25
max_text_length: *max_text_length
- KeepKeys:
keep_keys: ['image', 'label', 'valid_ratio', 'word_positons'] # dataloader will return list in this order
loader:
shuffle: True
batch_size_per_card: 64
drop_last: True
num_workers: 8
use_shared_memory: False
Eval:
dataset:
name: LMDBDataSet
data_dir: ./train_data/data_lmdb_release/evaluation/
transforms:
- DecodeImage: # load image
img_mode: BGR
channel_first: False
- SARLabelEncode: # Class handling label
- RobustScannerRecResizeImg:
image_shape: [3, 48, 48, 160]
max_text_length: *max_text_length
width_downsample_ratio: 0.25
- KeepKeys:
keep_keys: ['image', 'label', 'valid_ratio', 'word_positons'] # dataloader will return list in this order
loader:
shuffle: False
drop_last: False
batch_size_per_card: 64
num_workers: 4
use_shared_memory: False
......@@ -72,6 +72,7 @@
- [x] [ABINet](./algorithm_rec_abinet.md)
- [x] [VisionLAN](./algorithm_rec_visionlan.md)
- [x] [SPIN](./algorithm_rec_spin.md)
- [x] [RobustScanner](./algorithm_rec_robustscanner.md)
参考[DTRB](https://arxiv.org/abs/1904.01906)[3]文字识别训练和评估流程,使用MJSynth和SynthText两个文字识别数据集训练,在IIIT, SVT, IC03, IC13, IC15, SVTP, CUTE数据集上进行评估,算法效果如下:
......@@ -94,6 +95,7 @@
|ABINet|Resnet45| 90.75% | rec_r45_abinet | [训练模型](https://paddleocr.bj.bcebos.com/rec_r45_abinet_train.tar) |
|VisionLAN|Resnet45| 90.30% | rec_r45_visionlan | [训练模型](https://paddleocr.bj.bcebos.com/rec_r45_visionlan_train.tar) |
|SPIN|ResNet32| 90.00% | rec_r32_gaspin_bilstm_att | coming soon |
|RobustScanner|ResNet31| 87.77% | rec_r31_robustscanner | coming soon |
<a name="2"></a>
......
# RobustScanner
- [1. 算法简介](#1)
- [2. 环境配置](#2)
- [3. 模型训练、评估、预测](#3)
- [3.1 训练](#3-1)
- [3.2 评估](#3-2)
- [3.3 预测](#3-3)
- [4. 推理部署](#4)
- [4.1 Python推理](#4-1)
- [4.2 C++推理](#4-2)
- [4.3 Serving服务化部署](#4-3)
- [4.4 更多推理部署](#4-4)
- [5. FAQ](#5)
<a name="1"></a>
## 1. 算法简介
论文信息:
> [RobustScanner: Dynamically Enhancing Positional Clues for Robust Text Recognition](https://arxiv.org/pdf/2007.07542.pdf)
> Xiaoyu Yue, Zhanghui Kuang, Chenhao Lin, Hongbin Sun, Wayne
Zhang
> ECCV, 2020
使用MJSynth和SynthText两个合成文字识别数据集训练,在IIIT, SVT, IC13, IC15, SVTP, CUTE数据集上进行评估,算法复现效果如下:
|模型|骨干网络|配置文件|Acc|下载链接|
| --- | --- | --- | --- | --- |
|RobustScanner|ResNet31|[rec_r31_robustscanner.yml](../../configs/rec/rec_r31_robustscanner.yml)|87.77%|coming soon|
注:除了使用MJSynth和SynthText两个文字识别数据集外,还加入了[SynthAdd](https://pan.baidu.com/share/init?surl=uV0LtoNmcxbO-0YA7Ch4dg)数据(提取码:627x),和部分真实数据,具体数据细节可以参考论文。
<a name="2"></a>
## 2. 环境配置
请先参考[《运行环境准备》](./environment.md)配置PaddleOCR运行环境,参考[《项目克隆》](./clone.md)克隆项目代码。
<a name="3"></a>
## 3. 模型训练、评估、预测
请参考[文本识别教程](./recognition.md)。PaddleOCR对代码进行了模块化,训练不同的识别模型只需要**更换配置文件**即可。
训练
具体地,在完成数据准备后,便可以启动训练,训练命令如下:
```
#单卡训练(训练周期长,不建议)
python3 tools/train.py -c configs/rec/rec_r31_robustscanner.yml
#多卡训练,通过--gpus参数指定卡号
python3 -m paddle.distributed.launch --gpus '0,1,2,3' tools/train.py -c configs/rec/rec_r31_robustscanner.yml
```
评估
```
# GPU 评估, Global.pretrained_model 为待测权重
python3 -m paddle.distributed.launch --gpus '0' tools/eval.py -c configs/rec/rec_r31_robustscanner.yml -o Global.pretrained_model={path/to/weights}/best_accuracy
```
预测:
```
# 预测使用的配置文件必须与训练一致
python3 tools/infer_rec.py -c configs/rec/rec_r31_robustscanner.yml -o Global.pretrained_model={path/to/weights}/best_accuracy Global.infer_img=doc/imgs_words/en/word_1.png
```
<a name="4"></a>
## 4. 推理部署
<a name="4-1"></a>
### 4.1 Python推理
首先将RobustScanner文本识别训练过程中保存的模型,转换成inference model。可以使用如下命令进行转换:
```
python3 tools/export_model.py -c configs/rec/rec_r31_robustscanner.yml -o Global.pretrained_model={path/to/weights}/best_accuracy Global.save_inference_dir=./inference/rec_r31_robustscanner
```
RobustScanner文本识别模型推理,可以执行如下命令:
```
python3 tools/infer/predict_rec.py --image_dir="./doc/imgs_words/en/word_1.png" --rec_model_dir="./inference/rec_r31_robustscanner/" --rec_image_shape="3, 48, 48, 160" --rec_algorithm="RobustScanner" --rec_char_dict_path="ppocr/utils/dict90.txt" --use_space_char=False
```
<a name="4-2"></a>
### 4.2 C++推理
由于C++预处理后处理还未支持RobustScanner,所以暂未支持
<a name="4-3"></a>
### 4.3 Serving服务化部署
暂不支持
<a name="4-4"></a>
### 4.4 更多推理部署
暂不支持
<a name="5"></a>
## 5. FAQ
## 引用
```bibtex
@article{2020RobustScanner,
title={RobustScanner: Dynamically Enhancing Positional Clues for Robust Text Recognition},
author={Xiaoyu Yue and Zhanghui Kuang and Chenhao Lin and Hongbin Sun and Wayne Zhang},
journal={ECCV2020},
year={2020},
}
```
......@@ -70,6 +70,7 @@ Supported text recognition algorithms (Click the link to get the tutorial):
- [x] [ABINet](./algorithm_rec_abinet_en.md)
- [x] [VisionLAN](./algorithm_rec_visionlan_en.md)
- [x] [SPIN](./algorithm_rec_spin_en.md)
- [x] [RobustScanner](./algorithm_rec_robustscanner_en.md)
Refer to [DTRB](https://arxiv.org/abs/1904.01906), the training and evaluation result of these above text recognition (using MJSynth and SynthText for training, evaluate on IIIT, SVT, IC03, IC13, IC15, SVTP, CUTE) is as follow:
......@@ -92,6 +93,7 @@ Refer to [DTRB](https://arxiv.org/abs/1904.01906), the training and evaluation r
|ABINet|Resnet45| 90.75% | rec_r45_abinet | [trained model](https://paddleocr.bj.bcebos.com/rec_r45_abinet_train.tar) |
|VisionLAN|Resnet45| 90.30% | rec_r45_visionlan | [trained model](https://paddleocr.bj.bcebos.com/rec_r45_visionlan_train.tar) |
|SPIN|ResNet32| 90.00% | rec_r32_gaspin_bilstm_att | coming soon |
|RobustScanner|ResNet31| 87.77% | rec_r31_robustscanner | coming soon |
<a name="2"></a>
......
# RobustScanner
- [1. Introduction](#1)
- [2. Environment](#2)
- [3. Model Training / Evaluation / Prediction](#3)
- [3.1 Training](#3-1)
- [3.2 Evaluation](#3-2)
- [3.3 Prediction](#3-3)
- [4. Inference and Deployment](#4)
- [4.1 Python Inference](#4-1)
- [4.2 C++ Inference](#4-2)
- [4.3 Serving](#4-3)
- [4.4 More](#4-4)
- [5. FAQ](#5)
<a name="1"></a>
## 1. Introduction
Paper:
> [RobustScanner: Dynamically Enhancing Positional Clues for Robust Text Recognition](https://arxiv.org/pdf/2007.07542.pdf)
> Xiaoyu Yue, Zhanghui Kuang, Chenhao Lin, Hongbin Sun, Wayne
Zhang
> ECCV, 2020
Using MJSynth and SynthText two text recognition datasets for training, and evaluating on IIIT, SVT, IC13, IC15, SVTP, CUTE datasets, the algorithm reproduction effect is as follows:
|Model|Backbone|config|Acc|Download link|
| --- | --- | --- | --- | --- |
|RobustScanner|ResNet31|[rec_r31_robustscanner.yml](../../configs/rec/rec_r31_robustscanner.yml)|87.77%|coming soon|
Note:In addition to using the two text recognition datasets MJSynth and SynthText, [SynthAdd](https://pan.baidu.com/share/init?surl=uV0LtoNmcxbO-0YA7Ch4dg) data (extraction code: 627x), and some real data are used in training, the specific data details can refer to the paper.
<a name="2"></a>
## 2. Environment
Please refer to ["Environment Preparation"](./environment_en.md) to configure the PaddleOCR environment, and refer to ["Project Clone"](./clone_en.md) to clone the project code.
<a name="3"></a>
## 3. Model Training / Evaluation / Prediction
Please refer to [Text Recognition Tutorial](./recognition_en.md). PaddleOCR modularizes the code, and training different recognition models only requires **changing the configuration file**.
Training:
Specifically, after the data preparation is completed, the training can be started. The training command is as follows:
```
#Single GPU training (long training period, not recommended)
python3 tools/train.py -c configs/rec/rec_r31_robustscanner.yml
#Multi GPU training, specify the gpu number through the --gpus parameter
python3 -m paddle.distributed.launch --gpus '0,1,2,3' tools/train.py -c configs/rec/rec_r31_robustscanner.yml
```
Evaluation:
```
# GPU evaluation
python3 -m paddle.distributed.launch --gpus '0' tools/eval.py -c configs/rec/rec_r31_robustscanner.yml -o Global.pretrained_model={path/to/weights}/best_accuracy
```
Prediction:
```
# The configuration file used for prediction must match the training
python3 tools/infer_rec.py -c configs/rec/rec_r31_robustscanner.yml -o Global.pretrained_model={path/to/weights}/best_accuracy Global.infer_img=doc/imgs_words/en/word_1.png
```
<a name="4"></a>
## 4. Inference and Deployment
<a name="4-1"></a>
### 4.1 Python Inference
First, the model saved during the RobustScanner text recognition training process is converted into an inference model. you can use the following command to convert:
```
python3 tools/export_model.py -c configs/rec/rec_r31_robustscanner.yml -o Global.pretrained_model={path/to/weights}/best_accuracy Global.save_inference_dir=./inference/rec_r31_robustscanner
```
For RobustScanner text recognition model inference, the following commands can be executed:
```
python3 tools/infer/predict_rec.py --image_dir="./doc/imgs_words/en/word_1.png" --rec_model_dir="./inference/rec_r31_robustscanner/" --rec_image_shape="3, 48, 48, 160" --rec_algorithm="RobustScanner" --rec_char_dict_path="ppocr/utils/dict90.txt" --use_space_char=False
```
<a name="4-2"></a>
### 4.2 C++ Inference
Not supported
<a name="4-3"></a>
### 4.3 Serving
Not supported
<a name="4-4"></a>
### 4.4 More
Not supported
<a name="5"></a>
## 5. FAQ
## Citation
```bibtex
@article{2020RobustScanner,
title={RobustScanner: Dynamically Enhancing Positional Clues for Robust Text Recognition},
author={Xiaoyu Yue and Zhanghui Kuang and Chenhao Lin and Hongbin Sun and Wayne Zhang},
journal={ECCV2020},
year={2020},
}
```
......@@ -26,8 +26,7 @@ from .make_pse_gt import MakePseGt
from .rec_img_aug import BaseDataAugmentation, RecAug, RecConAug, RecResizeImg, ClsResizeImg, \
SRNRecResizeImg, GrayRecResizeImg, SARRecResizeImg, PRENResizeImg, \
ABINetRecResizeImg, SVTRRecResizeImg, ABINetRecAug, VLRecResizeImg, SPINRecResizeImg
ABINetRecResizeImg, SVTRRecResizeImg, ABINetRecAug, VLRecResizeImg, SPINRecResizeImg, RobustScannerRecResizeImg
from .ssl_img_aug import SSLRotateResize
from .randaugment import RandAugment
from .copy_paste import CopyPaste
......
......@@ -414,6 +414,23 @@ class SVTRRecResizeImg(object):
data['valid_ratio'] = valid_ratio
return data
class RobustScannerRecResizeImg(object):
def __init__(self, image_shape, max_text_length, width_downsample_ratio=0.25, **kwargs):
self.image_shape = image_shape
self.width_downsample_ratio = width_downsample_ratio
self.max_text_length = max_text_length
def __call__(self, data):
img = data['image']
norm_img, resize_shape, pad_shape, valid_ratio = resize_norm_img_sar(
img, self.image_shape, self.width_downsample_ratio)
word_positons = np.array(range(0, self.max_text_length)).astype('int64')
data['image'] = norm_img
data['resized_shape'] = resize_shape
data['pad_shape'] = pad_shape
data['valid_ratio'] = valid_ratio
data['word_positons'] = word_positons
return data
def resize_norm_img_sar(img, image_shape, width_downsample_ratio=0.25):
imgC, imgH, imgW_min, imgW_max = image_shape
......
......@@ -29,27 +29,29 @@ import numpy as np
__all__ = ["ResNet31"]
def conv3x3(in_channel, out_channel, stride=1):
def conv3x3(in_channel, out_channel, stride=1, conv_weight_attr=None):
return nn.Conv2D(
in_channel,
out_channel,
kernel_size=3,
stride=stride,
padding=1,
weight_attr=conv_weight_attr,
bias_attr=False)
class BasicBlock(nn.Layer):
expansion = 1
def __init__(self, in_channels, channels, stride=1, downsample=False):
def __init__(self, in_channels, channels, stride=1, downsample=False, conv_weight_attr=None, bn_weight_attr=None):
super().__init__()
self.conv1 = conv3x3(in_channels, channels, stride)
self.bn1 = nn.BatchNorm2D(channels)
self.conv1 = conv3x3(in_channels, channels, stride,
conv_weight_attr=conv_weight_attr)
self.bn1 = nn.BatchNorm2D(channels, weight_attr=bn_weight_attr)
self.relu = nn.ReLU()
self.conv2 = conv3x3(channels, channels)
self.bn2 = nn.BatchNorm2D(channels)
self.conv2 = conv3x3(channels, channels,
conv_weight_attr=conv_weight_attr)
self.bn2 = nn.BatchNorm2D(channels, weight_attr=bn_weight_attr)
self.downsample = downsample
if downsample:
self.downsample = nn.Sequential(
......@@ -58,8 +60,9 @@ class BasicBlock(nn.Layer):
channels * self.expansion,
1,
stride,
weight_attr=conv_weight_attr,
bias_attr=False),
nn.BatchNorm2D(channels * self.expansion), )
nn.BatchNorm2D(channels * self.expansion, weight_attr=bn_weight_attr))
else:
self.downsample = nn.Sequential()
self.stride = stride
......@@ -91,6 +94,7 @@ class ResNet31(nn.Layer):
channels (list[int]): List of out_channels of Conv2d layer.
out_indices (None | Sequence[int]): Indices of output stages.
last_stage_pool (bool): If True, add `MaxPool2d` layer to last stage.
init_type (None | str): the config to control the initialization.
'''
def __init__(self,
......@@ -98,7 +102,8 @@ class ResNet31(nn.Layer):
layers=[1, 2, 5, 3],
channels=[64, 128, 256, 256, 512, 512, 512],
out_indices=None,
last_stage_pool=False):
last_stage_pool=False,
init_type=None):
super(ResNet31, self).__init__()
assert isinstance(in_channels, int)
assert isinstance(last_stage_pool, bool)
......@@ -106,42 +111,55 @@ class ResNet31(nn.Layer):
self.out_indices = out_indices
self.last_stage_pool = last_stage_pool
conv_weight_attr = None
bn_weight_attr = None
if init_type is not None:
support_dict = ['KaimingNormal']
assert init_type in support_dict, Exception(
"resnet31 only support {}".format(support_dict))
conv_weight_attr = nn.initializer.KaimingNormal()
bn_weight_attr = ParamAttr(initializer=nn.initializer.Uniform(), learning_rate=1)
# conv 1 (Conv Conv)
self.conv1_1 = nn.Conv2D(
in_channels, channels[0], kernel_size=3, stride=1, padding=1)
self.bn1_1 = nn.BatchNorm2D(channels[0])
in_channels, channels[0], kernel_size=3, stride=1, padding=1, weight_attr=conv_weight_attr)
self.bn1_1 = nn.BatchNorm2D(channels[0], weight_attr=bn_weight_attr)
self.relu1_1 = nn.ReLU()
self.conv1_2 = nn.Conv2D(
channels[0], channels[1], kernel_size=3, stride=1, padding=1)
self.bn1_2 = nn.BatchNorm2D(channels[1])
channels[0], channels[1], kernel_size=3, stride=1, padding=1, weight_attr=conv_weight_attr)
self.bn1_2 = nn.BatchNorm2D(channels[1], weight_attr=bn_weight_attr)
self.relu1_2 = nn.ReLU()
# conv 2 (Max-pooling, Residual block, Conv)
self.pool2 = nn.MaxPool2D(
kernel_size=2, stride=2, padding=0, ceil_mode=True)
self.block2 = self._make_layer(channels[1], channels[2], layers[0])
self.block2 = self._make_layer(channels[1], channels[2], layers[0],
conv_weight_attr=conv_weight_attr, bn_weight_attr=bn_weight_attr)
self.conv2 = nn.Conv2D(
channels[2], channels[2], kernel_size=3, stride=1, padding=1)
self.bn2 = nn.BatchNorm2D(channels[2])
channels[2], channels[2], kernel_size=3, stride=1, padding=1, weight_attr=conv_weight_attr)
self.bn2 = nn.BatchNorm2D(channels[2], weight_attr=bn_weight_attr)
self.relu2 = nn.ReLU()
# conv 3 (Max-pooling, Residual block, Conv)
self.pool3 = nn.MaxPool2D(
kernel_size=2, stride=2, padding=0, ceil_mode=True)
self.block3 = self._make_layer(channels[2], channels[3], layers[1])
self.block3 = self._make_layer(channels[2], channels[3], layers[1],
conv_weight_attr=conv_weight_attr, bn_weight_attr=bn_weight_attr)
self.conv3 = nn.Conv2D(
channels[3], channels[3], kernel_size=3, stride=1, padding=1)
self.bn3 = nn.BatchNorm2D(channels[3])
channels[3], channels[3], kernel_size=3, stride=1, padding=1, weight_attr=conv_weight_attr)
self.bn3 = nn.BatchNorm2D(channels[3], weight_attr=bn_weight_attr)
self.relu3 = nn.ReLU()
# conv 4 (Max-pooling, Residual block, Conv)
self.pool4 = nn.MaxPool2D(
kernel_size=(2, 1), stride=(2, 1), padding=0, ceil_mode=True)
self.block4 = self._make_layer(channels[3], channels[4], layers[2])
self.block4 = self._make_layer(channels[3], channels[4], layers[2],
conv_weight_attr=conv_weight_attr, bn_weight_attr=bn_weight_attr)
self.conv4 = nn.Conv2D(
channels[4], channels[4], kernel_size=3, stride=1, padding=1)
self.bn4 = nn.BatchNorm2D(channels[4])
channels[4], channels[4], kernel_size=3, stride=1, padding=1, weight_attr=conv_weight_attr)
self.bn4 = nn.BatchNorm2D(channels[4], weight_attr=bn_weight_attr)
self.relu4 = nn.ReLU()
# conv 5 ((Max-pooling), Residual block, Conv)
......@@ -149,15 +167,16 @@ class ResNet31(nn.Layer):
if self.last_stage_pool:
self.pool5 = nn.MaxPool2D(
kernel_size=2, stride=2, padding=0, ceil_mode=True)
self.block5 = self._make_layer(channels[4], channels[5], layers[3])
self.block5 = self._make_layer(channels[4], channels[5], layers[3],
conv_weight_attr=conv_weight_attr, bn_weight_attr=bn_weight_attr)
self.conv5 = nn.Conv2D(
channels[5], channels[5], kernel_size=3, stride=1, padding=1)
self.bn5 = nn.BatchNorm2D(channels[5])
channels[5], channels[5], kernel_size=3, stride=1, padding=1, weight_attr=conv_weight_attr)
self.bn5 = nn.BatchNorm2D(channels[5], weight_attr=bn_weight_attr)
self.relu5 = nn.ReLU()
self.out_channels = channels[-1]
def _make_layer(self, input_channels, output_channels, blocks):
def _make_layer(self, input_channels, output_channels, blocks, conv_weight_attr=None, bn_weight_attr=None):
layers = []
for _ in range(blocks):
downsample = None
......@@ -168,12 +187,14 @@ class ResNet31(nn.Layer):
output_channels,
kernel_size=1,
stride=1,
weight_attr=conv_weight_attr,
bias_attr=False),
nn.BatchNorm2D(output_channels), )
nn.BatchNorm2D(output_channels, weight_attr=bn_weight_attr))
layers.append(
BasicBlock(
input_channels, output_channels, downsample=downsample))
input_channels, output_channels, downsample=downsample,
conv_weight_attr=conv_weight_attr, bn_weight_attr=bn_weight_attr))
input_channels = output_channels
return nn.Sequential(*layers)
......
......@@ -35,6 +35,7 @@ def build_head(config):
from .rec_multi_head import MultiHead
from .rec_spin_att_head import SPINAttentionHead
from .rec_abinet_head import ABINetHead
from .rec_robustscanner_head import RobustScannerHead
from .rec_visionlan_head import VLHead
# cls head
......@@ -51,7 +52,7 @@ def build_head(config):
'ClsHead', 'AttentionHead', 'SRNHead', 'PGHead', 'Transformer',
'TableAttentionHead', 'SARHead', 'AsterHead', 'SDMGRHead', 'PRENHead',
'MultiHead', 'ABINetHead', 'TableMasterHead', 'SPINAttentionHead',
'VLHead', 'SLAHead'
'VLHead', 'SLAHead', 'RobustScannerHead'
]
#table head
......
此差异已折叠。
Global:
use_gpu: true
epoch_num: 5
log_smooth_window: 20
print_batch_step: 20
save_model_dir: ./output/rec/rec_r31_robustscanner/
save_epoch_step: 1
# evaluation is run every 2000 iterations
eval_batch_step: [0, 2000]
cal_metric_during_train: True
pretrained_model:
checkpoints:
save_inference_dir:
use_visualdl: False
infer_img: ./inference/rec_inference
# for data or label process
character_dict_path: ppocr/utils/dict90.txt
max_text_length: &max_text_length 40
infer_mode: False
use_space_char: False
rm_symbol: True
save_res_path: ./output/rec/predicts_robustscanner.txt
Optimizer:
name: Adam
beta1: 0.9
beta2: 0.999
lr:
name: Piecewise
decay_epochs: [3, 4]
values: [0.001, 0.0001, 0.00001]
regularizer:
name: 'L2'
factor: 0
Architecture:
model_type: rec
algorithm: RobustScanner
Transform:
Backbone:
name: ResNet31
init_type: KaimingNormal
Head:
name: RobustScannerHead
enc_outchannles: 128
hybrid_dec_rnn_layers: 2
hybrid_dec_dropout: 0
position_dec_rnn_layers: 2
start_idx: 91
mask: True
padding_idx: 92
encode_value: False
max_text_length: *max_text_length
Loss:
name: SARLoss
PostProcess:
name: SARLabelDecode
Metric:
name: RecMetric
is_filter: True
Train:
dataset:
name: SimpleDataSet
data_dir: ./train_data/ic15_data/
label_file_list: ["./train_data/ic15_data/rec_gt_train.txt"]
transforms:
- DecodeImage: # load image
img_mode: BGR
channel_first: False
- SARLabelEncode: # Class handling label
- RobustScannerRecResizeImg:
image_shape: [3, 48, 48, 160] # h:48 w:[48,160]
width_downsample_ratio: 0.25
max_text_length: *max_text_length
- KeepKeys:
keep_keys: ['image', 'label', 'valid_ratio', 'word_positons'] # dataloader will return list in this order
loader:
shuffle: True
batch_size_per_card: 16
drop_last: True
num_workers: 0
use_shared_memory: False
Eval:
dataset:
name: SimpleDataSet
data_dir: ./train_data/ic15_data
label_file_list: ["./train_data/ic15_data/rec_gt_test.txt"]
transforms:
- DecodeImage: # load image
img_mode: BGR
channel_first: False
- SARLabelEncode: # Class handling label
- RobustScannerRecResizeImg:
image_shape: [3, 48, 48, 160]
max_text_length: *max_text_length
width_downsample_ratio: 0.25
- KeepKeys:
keep_keys: ['image', 'label', 'valid_ratio', 'word_positons'] # dataloader will return list in this order
loader:
shuffle: False
drop_last: False
batch_size_per_card: 16
num_workers: 0
use_shared_memory: False
===========================train_params===========================
model_name:rec_r31_robustscanner
python:python
gpu_list:0|0,1
Global.use_gpu:True|True
Global.auto_cast:null
Global.epoch_num:lite_train_lite_infer=2|whole_train_whole_infer=5
Global.save_model_dir:./output/
Train.loader.batch_size_per_card:lite_train_lite_infer=16|whole_train_whole_infer=64
Global.pretrained_model:null
train_model_name:latest
train_infer_img_dir:./inference/rec_inference
null:null
##
trainer:norm_train
norm_train:tools/train.py -c test_tipc/configs/rec_r31_robustscanner/rec_r31_robustscanner.yml -o
pact_train:null
fpgm_train:null
distill_train:null
null:null
null:null
##
===========================eval_params===========================
eval:tools/eval.py -c test_tipc/configs/rec_r31_robustscanner/rec_r31_robustscanner.yml -o
null:null
##
===========================infer_params===========================
Global.save_inference_dir:./output/
Global.checkpoints:
norm_export:tools/export_model.py -c test_tipc/configs/rec_r31_robustscanner/rec_r31_robustscanner.yml -o
quant_export:null
fpgm_export:null
distill_export:null
export1:null
export2:null
##
train_model:./inference/rec_r31_robustscanner/best_accuracy
infer_export:tools/export_model.py -c test_tipc/configs/rec_r31_robustscanner/rec_r31_robustscanner.yml -o
infer_quant:False
inference:tools/infer/predict_rec.py --rec_char_dict_path=./ppocr/utils/dict90.txt --rec_image_shape="3,48,48,160" --use_space_char=False --rec_algorithm="RobustScanner"
--use_gpu:True|False
--enable_mkldnn:True|False
--cpu_threads:1|6
--rec_batch_num:1|6
--use_tensorrt:False|False
--precision:fp32|int8
--rec_model_dir:
--image_dir:./inference/rec_inference
--save_log_path:./test/output/
--benchmark:True
null:null
===========================infer_benchmark_params==========================
random_infer_input:[{float32,[3,48,160]}]
......@@ -54,6 +54,7 @@
| NRTR |rec_mtb_nrtr | 识别 | 支持 | 多机多卡 <br> 混合精度 | - | - |
| SAR |rec_r31_sar | 识别 | 支持 | 多机多卡 <br> 混合精度 | - | - |
| SPIN |rec_r32_gaspin_bilstm_att | 识别 | 支持 | 多机多卡 <br> 混合精度 | - | - |
| RobustScanner |rec_r31_robustscanner | 识别 | 支持 | 多机多卡 <br> 混合精度 | - | - |
| PGNet |rec_r34_vd_none_none_ctc_v2.0 | 端到端| 支持 | 多机多卡 <br> 混合精度 | - | - |
| TableMaster |table_structure_tablemaster_train | 表格识别| 支持 | 多机多卡 <br> 混合精度 | - | - |
......
......@@ -73,7 +73,7 @@ def main():
config['Architecture']["Head"]['out_channels'] = char_num
model = build_model(config['Architecture'])
extra_input_models = ["SRN", "NRTR", "SAR", "SEED", "SVTR", "VisionLAN"]
extra_input_models = ["SRN", "NRTR", "SAR", "SEED", "SVTR", "VisionLAN", "RobustScanner"]
extra_input = False
if config['Architecture']['algorithm'] == 'Distillation':
for key in config['Architecture']["Models"]:
......
......@@ -111,6 +111,22 @@ def export_single_model(model,
shape=[None, 3, 64, 256], dtype="float32"),
]
model = to_static(model, input_spec=other_shape)
elif arch_config["algorithm"] == "RobustScanner":
max_text_length = arch_config["Head"]["max_text_length"]
other_shape = [
paddle.static.InputSpec(
shape=[None, 3, 48, 160], dtype="float32"),
[
paddle.static.InputSpec(
shape=[None, ],
dtype="float32"),
paddle.static.InputSpec(
shape=[None, max_text_length],
dtype="int64")
]
]
model = to_static(model, input_spec=other_shape)
elif arch_config["algorithm"] in ["LayoutLM", "LayoutLMv2", "LayoutXLM"]:
input_spec = [
paddle.static.InputSpec(
......
......@@ -68,7 +68,7 @@ class TextRecognizer(object):
'name': 'SARLabelDecode',
"character_dict_path": args.rec_char_dict_path,
"use_space_char": args.use_space_char
}
}
elif self.rec_algorithm == "VisionLAN":
postprocess_params = {
'name': 'VLLabelDecode',
......@@ -93,6 +93,13 @@ class TextRecognizer(object):
"character_dict_path": args.rec_char_dict_path,
"use_space_char": args.use_space_char
}
elif self.rec_algorithm == "RobustScanner":
postprocess_params = {
'name': 'SARLabelDecode',
"character_dict_path": args.rec_char_dict_path,
"use_space_char": args.use_space_char,
"rm_symbol": True
}
self.postprocess_op = build_post_process(postprocess_params)
self.predictor, self.input_tensor, self.output_tensors, self.config = \
utility.create_predictor(args, 'rec', logger)
......@@ -390,6 +397,18 @@ class TextRecognizer(object):
img_list[indices[ino]], self.rec_image_shape)
norm_img = norm_img[np.newaxis, :]
norm_img_batch.append(norm_img)
elif self.rec_algorithm == "RobustScanner":
norm_img, _, _, valid_ratio = self.resize_norm_img_sar(
img_list[indices[ino]], self.rec_image_shape, width_downsample_ratio=0.25)
norm_img = norm_img[np.newaxis, :]
valid_ratio = np.expand_dims(valid_ratio, axis=0)
valid_ratios = []
valid_ratios.append(valid_ratio)
norm_img_batch.append(norm_img)
word_positions_list = []
word_positions = np.array(range(0, 40)).astype('int64')
word_positions = np.expand_dims(word_positions, axis=0)
word_positions_list.append(word_positions)
else:
norm_img = self.resize_norm_img(img_list[indices[ino]],
max_wh_ratio)
......@@ -442,6 +461,35 @@ class TextRecognizer(object):
np.array(
[valid_ratios], dtype=np.float32),
]
if self.use_onnx:
input_dict = {}
input_dict[self.input_tensor.name] = norm_img_batch
outputs = self.predictor.run(self.output_tensors,
input_dict)
preds = outputs[0]
else:
input_names = self.predictor.get_input_names()
for i in range(len(input_names)):
input_tensor = self.predictor.get_input_handle(
input_names[i])
input_tensor.copy_from_cpu(inputs[i])
self.predictor.run()
outputs = []
for output_tensor in self.output_tensors:
output = output_tensor.copy_to_cpu()
outputs.append(output)
if self.benchmark:
self.autolog.times.stamp()
preds = outputs[0]
elif self.rec_algorithm == "RobustScanner":
valid_ratios = np.concatenate(valid_ratios)
word_positions_list = np.concatenate(word_positions_list)
inputs = [
norm_img_batch,
valid_ratios,
word_positions_list
]
if self.use_onnx:
input_dict = {}
input_dict[self.input_tensor.name] = norm_img_batch
......
......@@ -96,6 +96,8 @@ def main():
]
elif config['Architecture']['algorithm'] == "SAR":
op[op_name]['keep_keys'] = ['image', 'valid_ratio']
elif config['Architecture']['algorithm'] == "RobustScanner":
op[op_name]['keep_keys'] = ['image', 'valid_ratio', 'word_positons']
else:
op[op_name]['keep_keys'] = ['image']
transforms.append(op)
......@@ -131,12 +133,20 @@ def main():
if config['Architecture']['algorithm'] == "SAR":
valid_ratio = np.expand_dims(batch[-1], axis=0)
img_metas = [paddle.to_tensor(valid_ratio)]
if config['Architecture']['algorithm'] == "RobustScanner":
valid_ratio = np.expand_dims(batch[1], axis=0)
word_positons = np.expand_dims(batch[2], axis=0)
img_metas = [paddle.to_tensor(valid_ratio),
paddle.to_tensor(word_positons),
]
images = np.expand_dims(batch[0], axis=0)
images = paddle.to_tensor(images)
if config['Architecture']['algorithm'] == "SRN":
preds = model(images, others)
elif config['Architecture']['algorithm'] == "SAR":
preds = model(images, img_metas)
elif config['Architecture']['algorithm'] == "RobustScanner":
preds = model(images, img_metas)
else:
preds = model(images)
post_result = post_process_class(preds)
......
......@@ -230,7 +230,7 @@ def train(config,
use_srn = config['Architecture']['algorithm'] == "SRN"
extra_input_models = [
"SRN", "NRTR", "SAR", "SEED", "SVTR", "SPIN", "VisionLAN"
"SRN", "NRTR", "SAR", "SEED", "SVTR", "SPIN", "VisionLAN", "RobustScanner"
]
extra_input = False
if config['Architecture']['algorithm'] == 'Distillation':
......@@ -653,7 +653,7 @@ def preprocess(is_train=False):
'CLS', 'PGNet', 'Distillation', 'NRTR', 'TableAttn', 'SAR', 'PSE',
'SEED', 'SDMGR', 'LayoutXLM', 'LayoutLM', 'LayoutLMv2', 'PREN', 'FCE',
'SVTR', 'ViTSTR', 'ABINet', 'DB++', 'TableMaster', 'SPIN', 'VisionLAN',
'Gestalt', 'SLANet'
'Gestalt', 'SLANet', 'RobustScanner'
]
if use_xpu:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册