提交 9816aebd 编写于 作者: T Topdu

add rec vitstr algorithm.

上级 e2fafc30
...@@ -49,7 +49,7 @@ Architecture: ...@@ -49,7 +49,7 @@ Architecture:
Loss: Loss:
name: NRTRLoss name: CESmoothingLoss
smoothing: True smoothing: True
PostProcess: PostProcess:
...@@ -68,8 +68,8 @@ Train: ...@@ -68,8 +68,8 @@ Train:
img_mode: BGR img_mode: BGR
channel_first: False channel_first: False
- NRTRLabelEncode: # Class handling label - NRTRLabelEncode: # Class handling label
- NRTRRecResizeImg: - GrayRecResizeImg:
image_shape: [100, 32] image_shape: [100, 32] # W H
resize_type: PIL # PIL or OpenCV resize_type: PIL # PIL or OpenCV
- KeepKeys: - KeepKeys:
keep_keys: ['image', 'label', 'length'] # dataloader will return list in this order keep_keys: ['image', 'label', 'length'] # dataloader will return list in this order
...@@ -88,8 +88,8 @@ Eval: ...@@ -88,8 +88,8 @@ Eval:
img_mode: BGR img_mode: BGR
channel_first: False channel_first: False
- NRTRLabelEncode: # Class handling label - NRTRLabelEncode: # Class handling label
- NRTRRecResizeImg: - GrayRecResizeImg:
image_shape: [100, 32] image_shape: [100, 32] # W H
resize_type: PIL # PIL or OpenCV resize_type: PIL # PIL or OpenCV
- KeepKeys: - KeepKeys:
keep_keys: ['image', 'label', 'length'] # dataloader will return list in this order keep_keys: ['image', 'label', 'length'] # dataloader will return list in this order
......
...@@ -77,7 +77,7 @@ Metric: ...@@ -77,7 +77,7 @@ Metric:
Train: Train:
dataset: dataset:
name: LMDBDataSet name: LMDBDataSet
data_dir: ./train_data/data_lmdb_release/training/ data_dir: ./train_data/data_lmdb_release/training
transforms: transforms:
- DecodeImage: # load image - DecodeImage: # load image
img_mode: BGR img_mode: BGR
...@@ -98,7 +98,7 @@ Train: ...@@ -98,7 +98,7 @@ Train:
Eval: Eval:
dataset: dataset:
name: LMDBDataSet name: LMDBDataSet
data_dir: ./train_data/data_lmdb_release/validation/ data_dir: ./train_data/data_lmdb_release/validation
transforms: transforms:
- DecodeImage: # load image - DecodeImage: # load image
img_mode: BGR img_mode: BGR
......
Global:
use_gpu: True
epoch_num: 20
log_smooth_window: 20
print_batch_step: 10
save_model_dir: ./output/rec/vitstr/
save_epoch_step: 1
# evaluation is run every 2000 iterations after the 0th iteration#
eval_batch_step: [0, 50]
cal_metric_during_train: True
pretrained_model:
checkpoints:
save_inference_dir:
use_visualdl: False
infer_img: doc/imgs_words_en/word_10.png
# for data or label process
character_dict_path: ppocr/utils/EN_symbol_dict.txt
max_text_length: 25
infer_mode: False
use_space_char: False
save_res_path: ./output/rec/predicts_vitstr.txt
Optimizer:
name: Adadelta
epsilon: 0.00000001
rho: 0.95
clip_norm: 5.0
lr:
learning_rate: 1.0
Architecture:
model_type: rec
algorithm: ViTSTR
in_channels: 1
Transform:
Backbone:
name: ViTSTR
scale: tiny
Neck:
name: SequenceEncoder
encoder_type: reshape
Head:
name: CTCHead
Loss:
name: CESmoothingLoss
smoothing: False
with_all: True
PostProcess:
name: ViTSTRLabelDecode
Metric:
name: RecMetric
main_indicator: acc
Train:
dataset:
name: LMDBDataSet
data_dir: ./train_data/data_lmdb_release/training
transforms:
- DecodeImage: # load image
img_mode: BGR
channel_first: False
- ViTSTRLabelEncode: # Class handling label
- GrayRecResizeImg:
image_shape: [224, 224] # W H
resize_type: PIL # PIL or OpenCV
inter_type: 'Image.BICUBIC'
scale: false
- KeepKeys:
keep_keys: ['image', 'label', 'length'] # dataloader will return list in this order
loader:
shuffle: True
batch_size_per_card: 48
drop_last: True
num_workers: 2
Eval:
dataset:
name: LMDBDataSet
data_dir: ./train_data/data_lmdb_release/validation
transforms:
- DecodeImage: # load image
img_mode: BGR
channel_first: False
- ViTSTRLabelEncode: # Class handling label
- GrayRecResizeImg:
image_shape: [224, 224] # W H
resize_type: PIL # PIL or OpenCV
inter_type: 'Image.BICUBIC'
scale: false
- KeepKeys:
keep_keys: ['image', 'label', 'length'] # dataloader will return list in this order
loader:
shuffle: False
drop_last: False
batch_size_per_card: 256
num_workers: 2
...@@ -66,6 +66,7 @@ ...@@ -66,6 +66,7 @@
- [x] [SAR](./algorithm_rec_sar.md) - [x] [SAR](./algorithm_rec_sar.md)
- [x] [SEED](./algorithm_rec_seed.md) - [x] [SEED](./algorithm_rec_seed.md)
- [x] [SVTR](./algorithm_rec_svtr.md) - [x] [SVTR](./algorithm_rec_svtr.md)
- [x] [ViTSTR](./algorithm_rec_vitstr.md)
参考[DTRB](https://arxiv.org/abs/1904.01906)[3]文字识别训练和评估流程,使用MJSynth和SynthText两个文字识别数据集训练,在IIIT, SVT, IC03, IC13, IC15, SVTP, CUTE数据集上进行评估,算法效果如下: 参考[DTRB](https://arxiv.org/abs/1904.01906)[3]文字识别训练和评估流程,使用MJSynth和SynthText两个文字识别数据集训练,在IIIT, SVT, IC03, IC13, IC15, SVTP, CUTE数据集上进行评估,算法效果如下:
...@@ -84,7 +85,7 @@ ...@@ -84,7 +85,7 @@
|SAR|Resnet31| 87.20% | rec_r31_sar | [训练模型](https://paddleocr.bj.bcebos.com/dygraph_v2.1/rec/rec_r31_sar_train.tar) | |SAR|Resnet31| 87.20% | rec_r31_sar | [训练模型](https://paddleocr.bj.bcebos.com/dygraph_v2.1/rec/rec_r31_sar_train.tar) |
|SEED|Aster_Resnet| 85.35% | rec_resnet_stn_bilstm_att | [训练模型](https://paddleocr.bj.bcebos.com/dygraph_v2.1/rec/rec_resnet_stn_bilstm_att.tar) | |SEED|Aster_Resnet| 85.35% | rec_resnet_stn_bilstm_att | [训练模型](https://paddleocr.bj.bcebos.com/dygraph_v2.1/rec/rec_resnet_stn_bilstm_att.tar) |
|SVTR|SVTR-Tiny| 89.25% | rec_svtr_tiny_none_ctc_en | [训练模型](https://paddleocr.bj.bcebos.com/PP-OCRv3/chinese/rec_svtr_tiny_none_ctc_en_train.tar) | |SVTR|SVTR-Tiny| 89.25% | rec_svtr_tiny_none_ctc_en | [训练模型](https://paddleocr.bj.bcebos.com/PP-OCRv3/chinese/rec_svtr_tiny_none_ctc_en_train.tar) |
|ViTSTR|ViTSTR| 79.82% | rec_vitstr_none_ctc_en | [训练模型](https://paddleocr.bj.bcebos.com/rec_vitstr_none_none_train.tar) |
<a name="2"></a> <a name="2"></a>
......
# 场景文本识别算法-ViTSTR
- [1. 算法简介](#1)
- [2. 环境配置](#2)
- [3. 模型训练、评估、预测](#3)
- [3.1 训练](#3-1)
- [3.2 评估](#3-2)
- [3.3 预测](#3-3)
- [4. 推理部署](#4)
- [4.1 Python推理](#4-1)
- [4.2 C++推理](#4-2)
- [4.3 Serving服务化部署](#4-3)
- [4.4 更多推理部署](#4-4)
- [5. FAQ](#5)
<a name="1"></a>
## 1. 算法简介
论文信息:
> [Vision Transformer for Fast and Efficient Scene Text Recognition](https://arxiv.org/abs/2105.08582)
> Rowel Atienza
> ICDAR, 2021
<a name="model"></a>
`ViTSTR`使用MJSynth和SynthText两个文字识别数据集训练,在IIIT, SVT, IC03, IC13, IC15, SVTP, CUTE数据集上进行评估,算法复现效果如下:
|模型|骨干网络|配置文件|Acc|下载链接|
| --- | --- | --- | --- | --- |
|ViTSTR|ViTSTR|[rec_vitstr.yml](../../configs/rec/rec_vitstr.yml)|79.82%|[训练模型](https://paddleocr.bj.bcebos.com/rec_vitstr_none_none_train.tar)|
<a name="2"></a>
## 2. 环境配置
请先参考[《运行环境准备》](./environment.md)配置PaddleOCR运行环境,参考[《项目克隆》](./clone.md)克隆项目代码。
<a name="3"></a>
## 3. 模型训练、评估、预测
<a name="3-1"></a>
### 3.1 模型训练
请参考[文本识别训练教程](./recognition.md)。PaddleOCR对代码进行了模块化,训练`ViTSTR`识别模型时需要**更换配置文件**`ViTSTR`[配置文件](../../configs/rec/rec_ViTSTR.yml)
#### 启动训练
具体地,在完成数据准备后,便可以启动训练,训练命令如下:
```shell
#单卡训练(训练周期长,不建议)
python3 tools/train.py -c configs/rec/rec_vitstr.yml
#多卡训练,通过--gpus参数指定卡号
python3 -m paddle.distributed.launch --gpus '0,1,2,3' tools/train.py -c configs/rec/rec_vitstr.yml
```
<a name="3-2"></a>
### 3.2 评估
可下载已训练完成的[模型文件](#model),使用如下命令进行评估:
```shell
# 注意将pretrained_model的路径设置为本地路径。
python3 -m paddle.distributed.launch --gpus '0' tools/eval.py -c configs/rec/rec_vitstr.yml -o Global.pretrained_model=./rec_vitstr_train/best_accuracy
```
<a name="3-3"></a>
### 3.3 预测
使用如下命令进行单张图片预测:
```shell
# 注意将pretrained_model的路径设置为本地路径。
python3 tools/infer_rec.py -c configs/rec/rec_vitstr.yml -o Global.infer_img='./doc/imgs_words_en/word_10.png' Global.pretrained_model=./rec_vitstr_train/best_accuracy
# 预测文件夹下所有图像时,可修改infer_img为文件夹,如 Global.infer_img='./doc/imgs_words_en/'。
```
<a name="4"></a>
## 4. 推理部署
<a name="4-1"></a>
### 4.1 Python推理
首先将训练得到best模型,转换成inference model。这里以训练完成的模型为例([模型下载地址](https://paddleocr.bj.bcebos.com/rec_vitstr_none_none_train.tar) ),可以使用如下命令进行转换:
```shell
# 注意将pretrained_model的路径设置为本地路径。
python3 tools/export_model.py -c configs/rec/rec_vitstr.yml -o Global.pretrained_model=./rec_vitstr_train/best_accuracy Global.save_inference_dir=./inference/rec_vitstr/
```
**注意:**
- 如果您是在自己的数据集上训练的模型,并且调整了字典文件,请注意修改配置文件中的`character_dict_path`是否是所需要的字典文件。
- 如果您修改了训练时的输入大小,请修改`tools/export_model.py`文件中的对应NRTR的`infer_shape`
转换成功后,在目录下有三个文件:
```
/inference/rec_vitstr/
├── inference.pdiparams # 识别inference模型的参数文件
├── inference.pdiparams.info # 识别inference模型的参数信息,可忽略
└── inference.pdmodel # 识别inference模型的program文件
```
执行如下命令进行模型推理:
```shell
python3 tools/infer/predict_rec.py --image_dir='./doc/imgs_words_en/word_10.png' --rec_model_dir='./inference/rec_vitstr/' --rec_algorithm='ViTSTR' --rec_image_shape='1,224,224' --rec_char_dict_path='./ppocr/utils/EN_symbol_dict.txt'
# 预测文件夹下所有图像时,可修改image_dir为文件夹,如 --image_dir='./doc/imgs_words_en/'。
```
![](../imgs_words_en/word_10.png)
执行命令后,上面图像的预测结果(识别的文本和得分)会打印到屏幕上,示例如下:
结果如下:
```shell
Predicts of ./doc/imgs_words_en/word_10.png:('pain', 0.9265879392623901)
```
**注意**
- 训练上述模型采用的图像分辨率是[1,224,224],需要通过参数`rec_image_shape`设置为您训练时的识别图像形状。
- 在推理时需要设置参数`rec_char_dict_path`指定字典,如果您修改了字典,请修改该参数为您的字典文件。
- 如果您修改了预处理方法,需修改`tools/infer/predict_rec.py`中NRTR的预处理为您的预处理方法。
<a name="4-2"></a>
### 4.2 C++推理部署
由于C++预处理后处理还未支持NRTR,所以暂未支持
<a name="4-3"></a>
### 4.3 Serving服务化部署
暂不支持
<a name="4-4"></a>
### 4.4 更多推理部署
暂不支持
<a name="5"></a>
## 5. FAQ
1.`ViTSTR`论文中,使用在ImageNet1k上的预训练权重进行初始化训练,我们在训练未采用预训练权重,最终精度没有变化甚至有所提高。
2. 我们仅仅复现了`ViTSTR`中的tiny版本,如果有需要使用small、base版本,可直接使用源开源repo中的预训练权重转为Paddle权重即可使用。
## 引用
```bibtex
@article{Atienza2021ViTSTR,
title = {Vision Transformer for Fast and Efficient Scene Text Recognition},
author = {Rowel Atienza},
booktitle = {ICDAR},
year = {2021},
url = {https://arxiv.org/abs/2105.08582}
}
```
...@@ -65,6 +65,7 @@ Supported text recognition algorithms (Click the link to get the tutorial): ...@@ -65,6 +65,7 @@ Supported text recognition algorithms (Click the link to get the tutorial):
- [x] [SAR](./algorithm_rec_sar_en.md) - [x] [SAR](./algorithm_rec_sar_en.md)
- [x] [SEED](./algorithm_rec_seed_en.md) - [x] [SEED](./algorithm_rec_seed_en.md)
- [x] [SVTR](./algorithm_rec_svtr_en.md) - [x] [SVTR](./algorithm_rec_svtr_en.md)
- [x] [ViTSTR](./algorithm_rec_vitstr_en.md)
Refer to [DTRB](https://arxiv.org/abs/1904.01906), the training and evaluation result of these above text recognition (using MJSynth and SynthText for training, evaluate on IIIT, SVT, IC03, IC13, IC15, SVTP, CUTE) is as follow: Refer to [DTRB](https://arxiv.org/abs/1904.01906), the training and evaluation result of these above text recognition (using MJSynth and SynthText for training, evaluate on IIIT, SVT, IC03, IC13, IC15, SVTP, CUTE) is as follow:
...@@ -83,7 +84,7 @@ Refer to [DTRB](https://arxiv.org/abs/1904.01906), the training and evaluation r ...@@ -83,7 +84,7 @@ Refer to [DTRB](https://arxiv.org/abs/1904.01906), the training and evaluation r
|SAR|Resnet31| 87.20% | rec_r31_sar | [trained model](https://paddleocr.bj.bcebos.com/dygraph_v2.1/rec/rec_r31_sar_train.tar) | |SAR|Resnet31| 87.20% | rec_r31_sar | [trained model](https://paddleocr.bj.bcebos.com/dygraph_v2.1/rec/rec_r31_sar_train.tar) |
|SEED|Aster_Resnet| 85.35% | rec_resnet_stn_bilstm_att | [trained model](https://paddleocr.bj.bcebos.com/dygraph_v2.1/rec/rec_resnet_stn_bilstm_att.tar) | |SEED|Aster_Resnet| 85.35% | rec_resnet_stn_bilstm_att | [trained model](https://paddleocr.bj.bcebos.com/dygraph_v2.1/rec/rec_resnet_stn_bilstm_att.tar) |
|SVTR|SVTR-Tiny| 89.25% | rec_svtr_tiny_none_ctc_en | [trained model](https://paddleocr.bj.bcebos.com/PP-OCRv3/chinese/rec_svtr_tiny_none_ctc_en_train.tar) | |SVTR|SVTR-Tiny| 89.25% | rec_svtr_tiny_none_ctc_en | [trained model](https://paddleocr.bj.bcebos.com/PP-OCRv3/chinese/rec_svtr_tiny_none_ctc_en_train.tar) |
|ViTSTR|ViTSTR| 79.82% | rec_vitstr_none_ctc_en | [trained model](https://paddleocr.bj.bcebos.com/rec_vitstr_none_none_train.tar) |
<a name="2"></a> <a name="2"></a>
......
# ViTSTR
- [1. Introduction](#1)
- [2. Environment](#2)
- [3. Model Training / Evaluation / Prediction](#3)
- [3.1 Training](#3-1)
- [3.2 Evaluation](#3-2)
- [3.3 Prediction](#3-3)
- [4. Inference and Deployment](#4)
- [4.1 Python Inference](#4-1)
- [4.2 C++ Inference](#4-2)
- [4.3 Serving](#4-3)
- [4.4 More](#4-4)
- [5. FAQ](#5)
<a name="1"></a>
## 1. Introduction
Paper:
> [Vision Transformer for Fast and Efficient Scene Text Recognition](https://arxiv.org/abs/2105.08582)
> Rowel Atienza
> ICDAR, 2021
Using MJSynth and SynthText two text recognition datasets for training, and evaluating on IIIT, SVT, IC03, IC13, IC15, SVTP, CUTE datasets, the algorithm reproduction effect is as follows:
|Model|Backbone|config|Acc|Download link|
| --- | --- | --- | --- | --- |
|ViTSTR|ViTSTR|[rec_vitstr.yml](../../configs/rec/rec_vitstr.yml)|79.82%|[训练模型](https://paddleocr.bj.bcebos.com/rec_vitstr_none_none_train.tar)|
<a name="2"></a>
## 2. Environment
Please refer to ["Environment Preparation"](./environment_en.md) to configure the PaddleOCR environment, and refer to ["Project Clone"](./clone_en.md) to clone the project code.
<a name="3"></a>
## 3. Model Training / Evaluation / Prediction
Please refer to [Text Recognition Tutorial](./recognition_en.md). PaddleOCR modularizes the code, and training different recognition models only requires **changing the configuration file**.
Training:
Specifically, after the data preparation is completed, the training can be started. The training command is as follows:
```
#Single GPU training (long training period, not recommended)
python3 tools/train.py -c configs/rec/rec_vitstr.yml
#Multi GPU training, specify the gpu number through the --gpus parameter
python3 -m paddle.distributed.launch --gpus '0,1,2,3' tools/train.py -c configs/rec/rec_vitstr.yml
```
Evaluation:
```
# GPU evaluation
python3 -m paddle.distributed.launch --gpus '0' tools/eval.py -c configs/rec/rec_vitstr.yml -o Global.pretrained_model={path/to/weights}/best_accuracy
```
Prediction:
```
# The configuration file used for prediction must match the training
python3 tools/infer_rec.py -c configs/rec/rec_vitstr.yml -o Global.infer_img='./doc/imgs_words_en/word_10.png' Global.pretrained_model=./rec_vitstr_train/best_accuracy
```
<a name="4"></a>
## 4. Inference and Deployment
<a name="4-1"></a>
### 4.1 Python Inference
First, the model saved during the ViTSTR text recognition training process is converted into an inference model. ( [Model download link](https://paddleocr.bj.bcebos.com/rec_vitstr_none_none_train.tar)) ), you can use the following command to convert:
```
python3 tools/export_model.py -c configs/rec/rec_vitstr.yml -o Global.pretrained_model=./rec_vitstr_train/best_accuracy Global.save_inference_dir=./inference/rec_vitstr
```
**Note:**
- If you are training the model on your own dataset and have modified the dictionary file, please pay attention to modify the `character_dict_path` in the configuration file to the modified dictionary file.
- If you modified the input size during training, please modify the `infer_shape` corresponding to ViTSTR in the `tools/export_model.py` file.
After the conversion is successful, there are three files in the directory:
```
/inference/rec_vitstr/
├── inference.pdiparams
├── inference.pdiparams.info
└── inference.pdmodel
```
For ViTSTR text recognition model inference, the following commands can be executed:
```
python3 tools/infer/predict_rec.py --image_dir='./doc/imgs_words_en/word_10.png' --rec_model_dir='./inference/rec_vitstr/' --rec_algorithm='ViTSTR' --rec_image_shape='1,224,224' --rec_char_dict_path='./ppocr/utils/EN_symbol_dict.txt'
```
![](../imgs_words_en/word_10.png)
After executing the command, the prediction result (recognized text and score) of the image above is printed to the screen, an example is as follows:
The result is as follows:
```shell
Predicts of ./doc/imgs_words_en/word_10.png:('pain', 0.9265879392623901)
```
<a name="4-2"></a>
### 4.2 C++ Inference
Not supported
<a name="4-3"></a>
### 4.3 Serving
Not supported
<a name="4-4"></a>
### 4.4 More
Not supported
<a name="5"></a>
## 5. FAQ
1. In the `ViTSTR` paper, using pre-trained weights on ImageNet1k for initial training, we did not use pre-trained weights in training, and the final accuracy did not change or even improved.
## Citation
```bibtex
@article{Atienza2021ViTSTR,
title = {Vision Transformer for Fast and Efficient Scene Text Recognition},
author = {Rowel Atienza},
booktitle = {ICDAR},
year = {2021},
url = {https://arxiv.org/abs/2105.08582}
}
```
...@@ -23,7 +23,7 @@ from .random_crop_data import EastRandomCropData, RandomCropImgMask ...@@ -23,7 +23,7 @@ from .random_crop_data import EastRandomCropData, RandomCropImgMask
from .make_pse_gt import MakePseGt from .make_pse_gt import MakePseGt
from .rec_img_aug import RecAug, RecConAug, RecResizeImg, ClsResizeImg, \ from .rec_img_aug import RecAug, RecConAug, RecResizeImg, ClsResizeImg, \
SRNRecResizeImg, NRTRRecResizeImg, SARRecResizeImg, PRENResizeImg SRNRecResizeImg, GrayRecResizeImg, SARRecResizeImg, PRENResizeImg
from .ssl_img_aug import SSLRotateResize from .ssl_img_aug import SSLRotateResize
from .randaugment import RandAugment from .randaugment import RandAugment
from .copy_paste import CopyPaste from .copy_paste import CopyPaste
......
...@@ -443,7 +443,9 @@ class KieLabelEncode(object): ...@@ -443,7 +443,9 @@ class KieLabelEncode(object):
elif 'key_cls' in anno.keys(): elif 'key_cls' in anno.keys():
labels.append(anno['key_cls']) labels.append(anno['key_cls'])
else: else:
raise ValueError("Cannot found 'key_cls' in ann.keys(), please check your training annotation.") raise ValueError(
"Cannot found 'key_cls' in ann.keys(), please check your training annotation."
)
edges.append(ann.get('edge', 0)) edges.append(ann.get('edge', 0))
ann_infos = dict( ann_infos = dict(
image=data['image'], image=data['image'],
...@@ -838,6 +840,37 @@ class PRENLabelEncode(BaseRecLabelEncode): ...@@ -838,6 +840,37 @@ class PRENLabelEncode(BaseRecLabelEncode):
return data return data
class ViTSTRLabelEncode(BaseRecLabelEncode):
""" Convert between text-label and text-index """
def __init__(self,
max_text_length,
character_dict_path=None,
use_space_char=False,
**kwargs):
super(ViTSTRLabelEncode, self).__init__(
max_text_length, character_dict_path, use_space_char)
def __call__(self, data):
text = data['label']
text = self.encode(text)
if text is None:
return None
if len(text) >= self.max_text_len:
return None
data['length'] = np.array(len(text))
text.insert(0, 0)
text.append(1)
text = text + [0] * (self.max_text_len + 2 - len(text))
data['label'] = np.array(text)
return data
def add_special_char(self, dict_character):
dict_character = ['<s>', '</s>'] + dict_character
return dict_character
class VQATokenLabelEncode(object): class VQATokenLabelEncode(object):
""" """
Label encode for NLP VQA methods Label encode for NLP VQA methods
......
...@@ -87,11 +87,19 @@ class ClsResizeImg(object): ...@@ -87,11 +87,19 @@ class ClsResizeImg(object):
return data return data
class NRTRRecResizeImg(object): class GrayRecResizeImg(object):
def __init__(self, image_shape, resize_type, padding=False, **kwargs): def __init__(self,
image_shape,
resize_type,
inter_type='Image.ANTIALIAS',
scale=True,
padding=False,
**kwargs):
self.image_shape = image_shape self.image_shape = image_shape
self.resize_type = resize_type self.resize_type = resize_type
self.padding = padding self.padding = padding
self.inter_type = eval(inter_type)
self.scale = scale
def __call__(self, data): def __call__(self, data):
img = data['image'] img = data['image']
...@@ -117,13 +125,16 @@ class NRTRRecResizeImg(object): ...@@ -117,13 +125,16 @@ class NRTRRecResizeImg(object):
return data return data
if self.resize_type == 'PIL': if self.resize_type == 'PIL':
image_pil = Image.fromarray(np.uint8(img)) image_pil = Image.fromarray(np.uint8(img))
img = image_pil.resize(self.image_shape, Image.ANTIALIAS) img = image_pil.resize(self.image_shape, self.inter_type)
img = np.array(img) img = np.array(img)
if self.resize_type == 'OpenCV': if self.resize_type == 'OpenCV':
img = cv2.resize(img, self.image_shape) img = cv2.resize(img, self.image_shape)
norm_img = np.expand_dims(img, -1) norm_img = np.expand_dims(img, -1)
norm_img = norm_img.transpose((2, 0, 1)) norm_img = norm_img.transpose((2, 0, 1))
if self.scale:
data['image'] = norm_img.astype(np.float32) / 128. - 1. data['image'] = norm_img.astype(np.float32) / 128. - 1.
else:
data['image'] = norm_img.astype(np.float32) / 255.
return data return data
......
...@@ -30,7 +30,7 @@ from .det_fce_loss import FCELoss ...@@ -30,7 +30,7 @@ from .det_fce_loss import FCELoss
from .rec_ctc_loss import CTCLoss from .rec_ctc_loss import CTCLoss
from .rec_att_loss import AttentionLoss from .rec_att_loss import AttentionLoss
from .rec_srn_loss import SRNLoss from .rec_srn_loss import SRNLoss
from .rec_nrtr_loss import NRTRLoss from .rec_ce_smooth_loss import CESmoothingLoss
from .rec_sar_loss import SARLoss from .rec_sar_loss import SARLoss
from .rec_aster_loss import AsterLoss from .rec_aster_loss import AsterLoss
from .rec_pren_loss import PRENLoss from .rec_pren_loss import PRENLoss
...@@ -60,8 +60,9 @@ def build_loss(config): ...@@ -60,8 +60,9 @@ def build_loss(config):
support_dict = [ support_dict = [
'DBLoss', 'PSELoss', 'EASTLoss', 'SASTLoss', 'FCELoss', 'CTCLoss', 'DBLoss', 'PSELoss', 'EASTLoss', 'SASTLoss', 'FCELoss', 'CTCLoss',
'ClsLoss', 'AttentionLoss', 'SRNLoss', 'PGLoss', 'CombinedLoss', 'ClsLoss', 'AttentionLoss', 'SRNLoss', 'PGLoss', 'CombinedLoss',
'NRTRLoss', 'TableAttentionLoss', 'SARLoss', 'AsterLoss', 'SDMGRLoss', 'CESmoothingLoss', 'TableAttentionLoss', 'SARLoss', 'AsterLoss',
'VQASerTokenLayoutLMLoss', 'LossFromOutput', 'PRENLoss', 'MultiLoss' 'SDMGRLoss', 'VQASerTokenLayoutLMLoss', 'LossFromOutput', 'PRENLoss',
'MultiLoss'
] ]
config = copy.deepcopy(config) config = copy.deepcopy(config)
module_name = config.pop('name') module_name = config.pop('name')
......
...@@ -3,14 +3,18 @@ from paddle import nn ...@@ -3,14 +3,18 @@ from paddle import nn
import paddle.nn.functional as F import paddle.nn.functional as F
class NRTRLoss(nn.Layer): class CESmoothingLoss(nn.Layer):
def __init__(self, smoothing=True, **kwargs): def __init__(self, smoothing=True, with_all=False, **kwargs):
super(NRTRLoss, self).__init__() super(CESmoothingLoss, self).__init__()
self.loss_func = nn.CrossEntropyLoss(reduction='mean', ignore_index=0) self.loss_func = nn.CrossEntropyLoss(reduction='mean', ignore_index=0)
self.smoothing = smoothing self.smoothing = smoothing
self.with_all = with_all
def forward(self, pred, batch): def forward(self, pred, batch):
pred = pred.reshape([-1, pred.shape[2]]) pred = pred.reshape([-1, pred.shape[2]])
if self.with_all:
tgt = batch[1]
else:
max_len = batch[2].max() max_len = batch[2].max()
tgt = batch[1][:, 1:2 + max_len] tgt = batch[1][:, 1:2 + max_len]
tgt = tgt.reshape([-1]) tgt = tgt.reshape([-1])
......
...@@ -32,10 +32,11 @@ def build_backbone(config, model_type): ...@@ -32,10 +32,11 @@ def build_backbone(config, model_type):
from .rec_micronet import MicroNet from .rec_micronet import MicroNet
from .rec_efficientb3_pren import EfficientNetb3_PREN from .rec_efficientb3_pren import EfficientNetb3_PREN
from .rec_svtrnet import SVTRNet from .rec_svtrnet import SVTRNet
from .rec_vitstr import ViTSTR
support_dict = [ support_dict = [
'MobileNetV1Enhance', 'MobileNetV3', 'ResNet', 'ResNetFPN', 'MTB', 'MobileNetV1Enhance', 'MobileNetV3', 'ResNet', 'ResNetFPN', 'MTB',
"ResNet31", "ResNet_ASTER", 'MicroNet', 'EfficientNetb3_PREN', "ResNet31", "ResNet_ASTER", 'MicroNet', 'EfficientNetb3_PREN',
'SVTRNet' 'SVTRNet', 'ViTSTR'
] ]
elif model_type == "e2e": elif model_type == "e2e":
from .e2e_resnet_vd_pg import ResNet from .e2e_resnet_vd_pg import ResNet
......
...@@ -147,7 +147,7 @@ class Attention(nn.Layer): ...@@ -147,7 +147,7 @@ class Attention(nn.Layer):
dim, dim,
num_heads=8, num_heads=8,
mixer='Global', mixer='Global',
HW=[8, 25], HW=None,
local_k=[7, 11], local_k=[7, 11],
qkv_bias=False, qkv_bias=False,
qk_scale=None, qk_scale=None,
...@@ -210,7 +210,7 @@ class Block(nn.Layer): ...@@ -210,7 +210,7 @@ class Block(nn.Layer):
num_heads, num_heads,
mixer='Global', mixer='Global',
local_mixer=[7, 11], local_mixer=[7, 11],
HW=[8, 25], HW=None,
mlp_ratio=4., mlp_ratio=4.,
qkv_bias=False, qkv_bias=False,
qk_scale=None, qk_scale=None,
...@@ -274,7 +274,9 @@ class PatchEmbed(nn.Layer): ...@@ -274,7 +274,9 @@ class PatchEmbed(nn.Layer):
img_size=[32, 100], img_size=[32, 100],
in_channels=3, in_channels=3,
embed_dim=768, embed_dim=768,
sub_num=2): sub_num=2,
patch_size=[4, 4],
mode='pope'):
super().__init__() super().__init__()
num_patches = (img_size[1] // (2 ** sub_num)) * \ num_patches = (img_size[1] // (2 ** sub_num)) * \
(img_size[0] // (2 ** sub_num)) (img_size[0] // (2 ** sub_num))
...@@ -282,6 +284,7 @@ class PatchEmbed(nn.Layer): ...@@ -282,6 +284,7 @@ class PatchEmbed(nn.Layer):
self.num_patches = num_patches self.num_patches = num_patches
self.embed_dim = embed_dim self.embed_dim = embed_dim
self.norm = None self.norm = None
if mode == 'pope':
if sub_num == 2: if sub_num == 2:
self.proj = nn.Sequential( self.proj = nn.Sequential(
ConvBNLayer( ConvBNLayer(
...@@ -326,6 +329,11 @@ class PatchEmbed(nn.Layer): ...@@ -326,6 +329,11 @@ class PatchEmbed(nn.Layer):
padding=1, padding=1,
act=nn.GELU, act=nn.GELU,
bias_attr=None)) bias_attr=None))
elif mode == 'linear':
self.proj = nn.Conv2D(
1, embed_dim, kernel_size=patch_size, stride=patch_size)
self.num_patches = img_size[0] // patch_size[0] * img_size[
1] // patch_size[1]
def forward(self, x): def forward(self, x):
B, C, H, W = x.shape B, C, H, W = x.shape
......
# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This code is refer from:
https://github.com/roatienza/deep-text-recognition-benchmark/blob/master/modules/vitstr.py
"""
import numpy as np
import paddle
import paddle.nn as nn
from ppocr.modeling.backbones.rec_svtrnet import Block, PatchEmbed, zeros_, trunc_normal_, ones_
scale_dim_heads = {'tiny': [192, 3], 'small': [384, 6], 'base': [768, 12]}
class ViTSTR(nn.Layer):
def __init__(self,
img_size=[224, 224],
in_channels=1,
scale='tiny',
seqlen=27,
patch_size=[16, 16],
embed_dim=None,
depth=12,
num_heads=None,
mlp_ratio=4,
qkv_bias=True,
qk_scale=None,
drop_path_rate=0.,
drop_rate=0.,
attn_drop_rate=0.,
norm_layer='nn.LayerNorm',
act_layer='nn.GELU',
epsilon=1e-6,
out_channels=None,
**kwargs):
super().__init__()
self.seqlen = seqlen
embed_dim = embed_dim if embed_dim is not None else scale_dim_heads[
scale][0]
num_heads = num_heads if num_heads is not None else scale_dim_heads[
scale][1]
out_channels = out_channels if out_channels is not None else embed_dim
self.patch_embed = PatchEmbed(
img_size=img_size,
in_channels=in_channels,
embed_dim=embed_dim,
patch_size=patch_size,
mode='linear')
num_patches = self.patch_embed.num_patches
self.pos_embed = self.create_parameter(
shape=[1, num_patches + 1, embed_dim], default_initializer=zeros_)
self.add_parameter("pos_embed", self.pos_embed)
self.cls_token = self.create_parameter(
shape=[1, 1, embed_dim], default_initializer=zeros_)
self.add_parameter("cls_token", self.cls_token)
self.pos_drop = nn.Dropout(p=drop_rate)
dpr = np.linspace(0, drop_path_rate, depth)
self.blocks = nn.LayerList([
Block(
dim=embed_dim,
num_heads=num_heads,
mlp_ratio=mlp_ratio,
qkv_bias=qkv_bias,
qk_scale=qk_scale,
drop=drop_rate,
attn_drop=attn_drop_rate,
drop_path=dpr[i],
norm_layer=norm_layer,
act_layer=eval(act_layer),
epsilon=epsilon,
prenorm=False) for i in range(depth)
])
self.norm = eval(norm_layer)(embed_dim, epsilon=epsilon)
self.out_channels = out_channels
trunc_normal_(self.pos_embed)
trunc_normal_(self.cls_token)
self.apply(self._init_weights)
def _init_weights(self, m):
if isinstance(m, nn.Linear):
trunc_normal_(m.weight)
if isinstance(m, nn.Linear) and m.bias is not None:
zeros_(m.bias)
elif isinstance(m, nn.LayerNorm):
zeros_(m.bias)
ones_(m.weight)
def forward_features(self, x):
B = x.shape[0]
x = self.patch_embed(x)
cls_tokens = paddle.tile(self.cls_token, repeat_times=[B, 1, 1])
x = paddle.concat((cls_tokens, x), axis=1)
x = x + self.pos_embed
x = self.pos_drop(x)
for blk in self.blocks:
x = blk(x)
x = self.norm(x)
return x
def forward(self, x):
x = self.forward_features(x)
x = x[:, :self.seqlen]
return x.transpose([0, 2, 1]).unsqueeze(2)
...@@ -27,7 +27,7 @@ from .sast_postprocess import SASTPostProcess ...@@ -27,7 +27,7 @@ from .sast_postprocess import SASTPostProcess
from .fce_postprocess import FCEPostProcess from .fce_postprocess import FCEPostProcess
from .rec_postprocess import CTCLabelDecode, AttnLabelDecode, SRNLabelDecode, \ from .rec_postprocess import CTCLabelDecode, AttnLabelDecode, SRNLabelDecode, \
DistillationCTCLabelDecode, TableLabelDecode, NRTRLabelDecode, SARLabelDecode, \ DistillationCTCLabelDecode, TableLabelDecode, NRTRLabelDecode, SARLabelDecode, \
SEEDLabelDecode, PRENLabelDecode SEEDLabelDecode, PRENLabelDecode, ViTSTRLabelDecode
from .cls_postprocess import ClsPostProcess from .cls_postprocess import ClsPostProcess
from .pg_postprocess import PGPostProcess from .pg_postprocess import PGPostProcess
from .vqa_token_ser_layoutlm_postprocess import VQASerTokenLayoutLMPostProcess from .vqa_token_ser_layoutlm_postprocess import VQASerTokenLayoutLMPostProcess
...@@ -42,7 +42,7 @@ def build_post_process(config, global_config=None): ...@@ -42,7 +42,7 @@ def build_post_process(config, global_config=None):
'DistillationDBPostProcess', 'NRTRLabelDecode', 'SARLabelDecode', 'DistillationDBPostProcess', 'NRTRLabelDecode', 'SARLabelDecode',
'SEEDLabelDecode', 'VQASerTokenLayoutLMPostProcess', 'SEEDLabelDecode', 'VQASerTokenLayoutLMPostProcess',
'VQAReTokenLayoutLMPostProcess', 'PRENLabelDecode', 'VQAReTokenLayoutLMPostProcess', 'PRENLabelDecode',
'DistillationSARLabelDecode' 'DistillationSARLabelDecode', 'ViTSTRLabelDecode'
] ]
if config['name'] == 'PSEPostProcess': if config['name'] == 'PSEPostProcess':
......
...@@ -188,13 +188,13 @@ class NRTRLabelDecode(BaseRecLabelDecode): ...@@ -188,13 +188,13 @@ class NRTRLabelDecode(BaseRecLabelDecode):
char_list = [] char_list = []
conf_list = [] conf_list = []
for idx in range(len(text_index[batch_idx])): for idx in range(len(text_index[batch_idx])):
if text_index[batch_idx][idx] == 3: # end
break
try: try:
char_list.append(self.character[int(text_index[batch_idx][ char_idx = self.character[int(text_index[batch_idx][idx])]
idx])])
except: except:
continue continue
if char_idx == '</s>': # end
break
char_list.append(char_idx)
if text_prob is not None: if text_prob is not None:
conf_list.append(text_prob[batch_idx][idx]) conf_list.append(text_prob[batch_idx][idx])
else: else:
...@@ -204,6 +204,32 @@ class NRTRLabelDecode(BaseRecLabelDecode): ...@@ -204,6 +204,32 @@ class NRTRLabelDecode(BaseRecLabelDecode):
return result_list return result_list
class ViTSTRLabelDecode(NRTRLabelDecode):
""" Convert between text-label and text-index """
def __init__(self, character_dict_path=None, use_space_char=False,
**kwargs):
super(ViTSTRLabelDecode, self).__init__(character_dict_path,
use_space_char)
def __call__(self, preds, label=None, *args, **kwargs):
if isinstance(preds, paddle.Tensor):
preds = preds[:, 1:].numpy()
else:
preds = preds[:, 1:]
preds_idx = preds.argmax(axis=2)
preds_prob = preds.max(axis=2)
text = self.decode(preds_idx, preds_prob, is_remove_duplicate=False)
if label is None:
return text
label = self.decode(label[:, 1:])
return text, label
def add_special_char(self, dict_character):
dict_character = ['<s>', '</s>'] + dict_character
return dict_character
class AttnLabelDecode(BaseRecLabelDecode): class AttnLabelDecode(BaseRecLabelDecode):
""" Convert between text-label and text-index """ """ Convert between text-label and text-index """
......
...@@ -49,7 +49,7 @@ Architecture: ...@@ -49,7 +49,7 @@ Architecture:
Loss: Loss:
name: NRTRLoss name: CESmoothingLoss
smoothing: True smoothing: True
PostProcess: PostProcess:
...@@ -69,7 +69,7 @@ Train: ...@@ -69,7 +69,7 @@ Train:
img_mode: BGR img_mode: BGR
channel_first: False channel_first: False
- NRTRLabelEncode: # Class handling label - NRTRLabelEncode: # Class handling label
- NRTRRecResizeImg: - GrayRecResizeImg:
image_shape: [100, 32] image_shape: [100, 32]
resize_type: PIL # PIL or OpenCV resize_type: PIL # PIL or OpenCV
- KeepKeys: - KeepKeys:
...@@ -90,7 +90,7 @@ Eval: ...@@ -90,7 +90,7 @@ Eval:
img_mode: BGR img_mode: BGR
channel_first: False channel_first: False
- NRTRLabelEncode: # Class handling label - NRTRLabelEncode: # Class handling label
- NRTRRecResizeImg: - GrayRecResizeImg:
image_shape: [100, 32] image_shape: [100, 32]
resize_type: PIL # PIL or OpenCV resize_type: PIL # PIL or OpenCV
- KeepKeys: - KeepKeys:
......
Global:
use_gpu: True
epoch_num: 20
log_smooth_window: 20
print_batch_step: 10
save_model_dir: ./output/rec/svtr/
save_epoch_step: 1
# evaluation is run every 2000 iterations after the 0th iteration
eval_batch_step: [0, 2000]
cal_metric_during_train: True
pretrained_model:
checkpoints:
save_inference_dir:
use_visualdl: False
infer_img: doc/imgs_words_en/word_10.png
# for data or label process
character_dict_path:
character_type: en
max_text_length: 25
infer_mode: False
use_space_char: False
save_res_path: ./output/rec/predicts_svtr_tiny.txt
Optimizer:
name: AdamW
beta1: 0.9
beta2: 0.99
epsilon: 0.00000008
weight_decay: 0.05
no_weight_decay_name: norm pos_embed
one_dim_param_no_weight_decay: true
lr:
name: Cosine
learning_rate: 0.0005
warmup_epoch: 2
Architecture:
model_type: rec
algorithm: SVTR
Transform:
name: STN_ON
tps_inputsize: [32, 64]
tps_outputsize: [32, 100]
num_control_points: 20
tps_margins: [0.05,0.05]
stn_activation: none
Backbone:
name: SVTRNet
img_size: [32, 100]
out_char_num: 25
out_channels: 192
patch_merging: 'Conv'
embed_dim: [64, 128, 256]
depth: [3, 6, 3]
num_heads: [2, 4, 8]
mixer: ['Local','Local','Local','Local','Local','Local','Global','Global','Global','Global','Global','Global']
local_mixer: [[7, 11], [7, 11], [7, 11]]
last_stage: True
prenorm: false
Neck:
name: SequenceEncoder
encoder_type: reshape
Head:
name: CTCHead
Loss:
name: CTCLoss
PostProcess:
name: CTCLabelDecode
Metric:
name: RecMetric
main_indicator: acc
Train:
dataset:
name: SimpleDataSet
data_dir: ./train_data/ic15_data/
label_file_list: ["./train_data/ic15_data/rec_gt_train.txt"]
transforms:
- DecodeImage: # load image
img_mode: BGR
channel_first: False
- CTCLabelEncode: # Class handling label
- RecResizeImg:
character_dict_path:
image_shape: [3, 64, 256]
padding: False
- KeepKeys:
keep_keys: ['image', 'label', 'length'] # dataloader will return list in this order
loader:
shuffle: True
batch_size_per_card: 512
drop_last: True
num_workers: 4
Eval:
dataset:
name: SimpleDataSet
data_dir: ./train_data/ic15_data
label_file_list: ["./train_data/ic15_data/rec_gt_test.txt"]
transforms:
- DecodeImage: # load image
img_mode: BGR
channel_first: False
- CTCLabelEncode: # Class handling label
- RecResizeImg:
character_dict_path:
image_shape: [3, 64, 256]
padding: False
- KeepKeys:
keep_keys: ['image', 'label', 'length'] # dataloader will return list in this order
loader:
shuffle: False
drop_last: False
batch_size_per_card: 256
num_workers: 2
===========================train_params===========================
model_name:rec_svtrnet
python:python3.7
gpu_list:0|0,1
Global.use_gpu:True|True
Global.auto_cast:null
Global.epoch_num:lite_train_lite_infer=2|whole_train_whole_infer=300
Global.save_model_dir:./output/
Train.loader.batch_size_per_card:lite_train_lite_infer=16|whole_train_whole_infer=64
Global.pretrained_model:null
train_model_name:latest
train_infer_img_dir:./inference/rec_inference
null:null
##
trainer:norm_train
norm_train:tools/train.py -c test_tipc/configs/rec_svtrnet/rec_svtrnet.yml -o
pact_train:null
fpgm_train:null
distill_train:null
null:null
null:null
##
===========================eval_params===========================
eval:tools/eval.py -c test_tipc/configs/rec_svtrnet/rec_svtrnet.yml -o
null:null
##
===========================infer_params===========================
Global.save_inference_dir:./output/
Global.checkpoints:
norm_export:tools/export_model.py -c test_tipc/configs/rec_svtrnet/rec_svtrnet.yml -o
quant_export:null
fpgm_export:null
distill_export:null
export1:null
export2:null
##
train_model:./inference/rec_svtrnet_train/best_accuracy
infer_export:tools/export_model.py -c test_tipc/configs/rec_svtrnet/rec_svtrnet.yml -o
infer_quant:False
inference:tools/infer/predict_rec.py --rec_char_dict_path=./ppocr/utils/EN_symbol_dict.txt --rec_image_shape="1,224,224" --rec_algorithm="SVTR"
--use_gpu:True|False
--enable_mkldnn:False
--cpu_threads:1|6
--rec_batch_num:1|6
--use_tensorrt:False
--precision:fp32|int8
--rec_model_dir:
--image_dir:./inference/rec_inference
--save_log_path:./test/output/
--benchmark:True
null:null
===========================infer_benchmark_params==========================
random_infer_input:[{float32,[1,224,224]}]
Global:
use_gpu: True
epoch_num: 20
log_smooth_window: 20
print_batch_step: 10
save_model_dir: ./output/rec/vitstr/
save_epoch_step: 1
# evaluation is run every 2000 iterations after the 0th iteration#
eval_batch_step: [0, 2000]
cal_metric_during_train: True
pretrained_model:
checkpoints:
save_inference_dir:
use_visualdl: False
infer_img: doc/imgs_words_en/word_10.png
# for data or label process
character_dict_path: ppocr/utils/EN_symbol_dict.txt
max_text_length: 25
infer_mode: False
use_space_char: False
save_res_path: ./output/rec/predicts_vitstr.txt
Optimizer:
name: Adadelta
epsilon: 0.00000001
rho: 0.95
clip_norm: 5.0
lr:
learning_rate: 1.0
Architecture:
model_type: rec
algorithm: ViTSTR
in_channels: 1
Transform:
Backbone:
name: ViTSTR
Neck:
name: SequenceEncoder
encoder_type: reshape
Head:
name: CTCHead
Loss:
name: CESmoothingLoss
smoothing: False
with_all: True
PostProcess:
name: ViTSTRLabelDecode
Metric:
name: RecMetric
main_indicator: acc
Train:
dataset:
name: SimpleDataSet
data_dir: ./train_data/ic15_data/
label_file_list: ["./train_data/ic15_data/rec_gt_train.txt"]
transforms:
- DecodeImage: # load image
img_mode: BGR
channel_first: False
- ViTSTRLabelEncode: # Class handling label
- GrayRecResizeImg:
image_shape: [224, 224] # W H
resize_type: PIL # PIL or OpenCV
inter_type: 'Image.BICUBIC'
scale: false
- KeepKeys:
keep_keys: ['image', 'label', 'length'] # dataloader will return list in this order
loader:
shuffle: True
batch_size_per_card: 48
drop_last: True
num_workers: 2
Eval:
dataset:
name: SimpleDataSet
data_dir: ./train_data/ic15_data
label_file_list: ["./train_data/ic15_data/rec_gt_test.txt"]
transforms:
- DecodeImage: # load image
img_mode: BGR
channel_first: False
- ViTSTRLabelEncode: # Class handling label
- GrayRecResizeImg:
image_shape: [224, 224] # W H
resize_type: PIL # PIL or OpenCV
inter_type: 'Image.BICUBIC'
scale: false
- KeepKeys:
keep_keys: ['image', 'label', 'length'] # dataloader will return list in this order
loader:
shuffle: False
drop_last: False
batch_size_per_card: 256
num_workers: 2
===========================train_params===========================
model_name:rec_vitstr
python:python3.7
gpu_list:0|0,1
Global.use_gpu:True|True
Global.auto_cast:null
Global.epoch_num:lite_train_lite_infer=2|whole_train_whole_infer=300
Global.save_model_dir:./output/
Train.loader.batch_size_per_card:lite_train_lite_infer=16|whole_train_whole_infer=64
Global.pretrained_model:null
train_model_name:latest
train_infer_img_dir:./inference/rec_inference
null:null
##
trainer:norm_train
norm_train:tools/train.py -c test_tipc/configs/rec_vitstr/rec_vitstr.yml -o
pact_train:null
fpgm_train:null
distill_train:null
null:null
null:null
##
===========================eval_params===========================
eval:tools/eval.py -c test_tipc/configs/rec_vitstr/rec_vitstr.yml -o
null:null
##
===========================infer_params===========================
Global.save_inference_dir:./output/
Global.checkpoints:
norm_export:tools/export_model.py -c test_tipc/configs/rec_vitstr/rec_vitstr.yml -o
quant_export:null
fpgm_export:null
distill_export:null
export1:null
export2:null
##
train_model:./inference/rec_vitstr_train/best_accuracy
infer_export:tools/export_model.py -c test_tipc/configs/rec_vitstr/rec_vitstr.yml -o
infer_quant:False
inference:tools/infer/predict_rec.py --rec_char_dict_path=./ppocr/utils/EN_symbol_dict.txt --rec_image_shape="1,224,224" --rec_algorithm="ViTSTR"
--use_gpu:True|False
--enable_mkldnn:False
--cpu_threads:1|6
--rec_batch_num:1|6
--use_tensorrt:False
--precision:fp32|int8
--rec_model_dir:
--image_dir:./inference/rec_inference
--save_log_path:./test/output/
--benchmark:True
null:null
===========================infer_benchmark_params==========================
random_infer_input:[{float32,[1,224,224]}]
...@@ -73,6 +73,12 @@ def export_single_model(model, arch_config, save_path, logger, quanter=None): ...@@ -73,6 +73,12 @@ def export_single_model(model, arch_config, save_path, logger, quanter=None):
shape=[None, 3, 64, 512], dtype="float32"), shape=[None, 3, 64, 512], dtype="float32"),
] ]
model = to_static(model, input_spec=other_shape) model = to_static(model, input_spec=other_shape)
elif arch_config["algorithm"] == "ViTSTR":
other_shape = [
paddle.static.InputSpec(
shape=[None, 1, 224, 224], dtype="float32"),
]
model = to_static(model, input_spec=other_shape)
else: else:
infer_shape = [3, -1, -1] infer_shape = [3, -1, -1]
if arch_config["model_type"] == "rec": if arch_config["model_type"] == "rec":
......
...@@ -69,6 +69,12 @@ class TextRecognizer(object): ...@@ -69,6 +69,12 @@ class TextRecognizer(object):
"character_dict_path": args.rec_char_dict_path, "character_dict_path": args.rec_char_dict_path,
"use_space_char": args.use_space_char "use_space_char": args.use_space_char
} }
elif self.rec_algorithm == 'ViTSTR':
postprocess_params = {
'name': 'ViTSTRLabelDecode',
"character_dict_path": args.rec_char_dict_path,
"use_space_char": args.use_space_char
}
self.postprocess_op = build_post_process(postprocess_params) self.postprocess_op = build_post_process(postprocess_params)
self.predictor, self.input_tensor, self.output_tensors, self.config = \ self.predictor, self.input_tensor, self.output_tensors, self.config = \
utility.create_predictor(args, 'rec', logger) utility.create_predictor(args, 'rec', logger)
...@@ -96,15 +102,22 @@ class TextRecognizer(object): ...@@ -96,15 +102,22 @@ class TextRecognizer(object):
def resize_norm_img(self, img, max_wh_ratio): def resize_norm_img(self, img, max_wh_ratio):
imgC, imgH, imgW = self.rec_image_shape imgC, imgH, imgW = self.rec_image_shape
if self.rec_algorithm == 'NRTR': if self.rec_algorithm == 'NRTR' or self.rec_algorithm == 'ViTSTR':
img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
# return padding_im # return padding_im
image_pil = Image.fromarray(np.uint8(img)) image_pil = Image.fromarray(np.uint8(img))
img = image_pil.resize([100, 32], Image.ANTIALIAS) if self.rec_algorithm == 'ViTSTR':
img = image_pil.resize([imgW, imgH], Image.BICUBIC)
else:
img = image_pil.resize([imgW, imgH], Image.ANTIALIAS)
img = np.array(img) img = np.array(img)
norm_img = np.expand_dims(img, -1) norm_img = np.expand_dims(img, -1)
norm_img = norm_img.transpose((2, 0, 1)) norm_img = norm_img.transpose((2, 0, 1))
return norm_img.astype(np.float32) / 128. - 1. if self.rec_algorithm == 'ViTSTR':
norm_img = norm_img.astype(np.float32) / 255.
else:
norm_img = norm_img.astype(np.float32) / 128. - 1.
return norm_img
assert imgC == img.shape[2] assert imgC == img.shape[2]
imgW = int((imgH * max_wh_ratio)) imgW = int((imgH * max_wh_ratio))
......
...@@ -307,7 +307,8 @@ def train(config, ...@@ -307,7 +307,8 @@ def train(config,
train_stats.update(stats) train_stats.update(stats)
if log_writer is not None and dist.get_rank() == 0: if log_writer is not None and dist.get_rank() == 0:
log_writer.log_metrics(metrics=train_stats.get(), prefix="TRAIN", step=global_step) log_writer.log_metrics(
metrics=train_stats.get(), prefix="TRAIN", step=global_step)
if dist.get_rank() == 0 and ( if dist.get_rank() == 0 and (
(global_step > 0 and global_step % print_batch_step == 0) or (global_step > 0 and global_step % print_batch_step == 0) or
...@@ -354,7 +355,8 @@ def train(config, ...@@ -354,7 +355,8 @@ def train(config,
# logger metric # logger metric
if log_writer is not None: if log_writer is not None:
log_writer.log_metrics(metrics=cur_metric, prefix="EVAL", step=global_step) log_writer.log_metrics(
metrics=cur_metric, prefix="EVAL", step=global_step)
if cur_metric[main_indicator] >= best_model_dict[ if cur_metric[main_indicator] >= best_model_dict[
main_indicator]: main_indicator]:
...@@ -377,11 +379,18 @@ def train(config, ...@@ -377,11 +379,18 @@ def train(config,
logger.info(best_str) logger.info(best_str)
# logger best metric # logger best metric
if log_writer is not None: if log_writer is not None:
log_writer.log_metrics(metrics={ log_writer.log_metrics(
"best_{}".format(main_indicator): best_model_dict[main_indicator] metrics={
}, prefix="EVAL", step=global_step) "best_{}".format(main_indicator):
best_model_dict[main_indicator]
log_writer.log_model(is_best=True, prefix="best_accuracy", metadata=best_model_dict) },
prefix="EVAL",
step=global_step)
log_writer.log_model(
is_best=True,
prefix="best_accuracy",
metadata=best_model_dict)
reader_start = time.time() reader_start = time.time()
if dist.get_rank() == 0: if dist.get_rank() == 0:
...@@ -413,7 +422,8 @@ def train(config, ...@@ -413,7 +422,8 @@ def train(config,
epoch=epoch, epoch=epoch,
global_step=global_step) global_step=global_step)
if log_writer is not None: if log_writer is not None:
log_writer.log_model(is_best=False, prefix='iter_epoch_{}'.format(epoch)) log_writer.log_model(
is_best=False, prefix='iter_epoch_{}'.format(epoch))
best_str = 'best metric, {}'.format(', '.join( best_str = 'best metric, {}'.format(', '.join(
['{}: {}'.format(k, v) for k, v in best_model_dict.items()])) ['{}: {}'.format(k, v) for k, v in best_model_dict.items()]))
...@@ -564,7 +574,8 @@ def preprocess(is_train=False): ...@@ -564,7 +574,8 @@ def preprocess(is_train=False):
assert alg in [ assert alg in [
'EAST', 'DB', 'SAST', 'Rosetta', 'CRNN', 'STARNet', 'RARE', 'SRN', 'EAST', 'DB', 'SAST', 'Rosetta', 'CRNN', 'STARNet', 'RARE', 'SRN',
'CLS', 'PGNet', 'Distillation', 'NRTR', 'TableAttn', 'SAR', 'PSE', 'CLS', 'PGNet', 'Distillation', 'NRTR', 'TableAttn', 'SAR', 'PSE',
'SEED', 'SDMGR', 'LayoutXLM', 'LayoutLM', 'PREN', 'FCE', 'SVTR' 'SEED', 'SDMGR', 'LayoutXLM', 'LayoutLM', 'PREN', 'FCE', 'SVTR',
'ViTSTR'
] ]
if use_xpu: if use_xpu:
...@@ -585,7 +596,8 @@ def preprocess(is_train=False): ...@@ -585,7 +596,8 @@ def preprocess(is_train=False):
vdl_writer_path = '{}/vdl/'.format(save_model_dir) vdl_writer_path = '{}/vdl/'.format(save_model_dir)
log_writer = VDLLogger(save_model_dir) log_writer = VDLLogger(save_model_dir)
loggers.append(log_writer) loggers.append(log_writer)
if ('use_wandb' in config['Global'] and config['Global']['use_wandb']) or 'wandb' in config: if ('use_wandb' in config['Global'] and
config['Global']['use_wandb']) or 'wandb' in config:
save_dir = config['Global']['save_model_dir'] save_dir = config['Global']['save_model_dir']
wandb_writer_path = "{}/wandb".format(save_dir) wandb_writer_path = "{}/wandb".format(save_dir)
if "wandb" in config: if "wandb" in config:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册