diff --git a/configs/rec/rec_mtb_nrtr.yml b/configs/rec/rec_mtb_nrtr.yml
index 04267500854310dc6d5df9318bb8c056c65cd5b5..dfe2cc9811120f0a5960d02a28e39ada83b98104 100644
--- a/configs/rec/rec_mtb_nrtr.yml
+++ b/configs/rec/rec_mtb_nrtr.yml
@@ -49,7 +49,7 @@ Architecture:
Loss:
- name: NRTRLoss
+ name: CESmoothingLoss
smoothing: True
PostProcess:
@@ -68,8 +68,8 @@ Train:
img_mode: BGR
channel_first: False
- NRTRLabelEncode: # Class handling label
- - NRTRRecResizeImg:
- image_shape: [100, 32]
+ - GrayRecResizeImg:
+ image_shape: [100, 32] # W H
resize_type: PIL # PIL or OpenCV
- KeepKeys:
keep_keys: ['image', 'label', 'length'] # dataloader will return list in this order
@@ -88,8 +88,8 @@ Eval:
img_mode: BGR
channel_first: False
- NRTRLabelEncode: # Class handling label
- - NRTRRecResizeImg:
- image_shape: [100, 32]
+ - GrayRecResizeImg:
+ image_shape: [100, 32] # W H
resize_type: PIL # PIL or OpenCV
- KeepKeys:
keep_keys: ['image', 'label', 'length'] # dataloader will return list in this order
diff --git a/configs/rec/rec_svtrnet.yml b/configs/rec/rec_svtrnet.yml
index 233d5e276577cad0144456ef7df1e20de99891f9..a700e4bd92995e8c0da2bf7623fe25e746483b1b 100644
--- a/configs/rec/rec_svtrnet.yml
+++ b/configs/rec/rec_svtrnet.yml
@@ -77,7 +77,7 @@ Metric:
Train:
dataset:
name: LMDBDataSet
- data_dir: ./train_data/data_lmdb_release/training/
+ data_dir: ./train_data/data_lmdb_release/training
transforms:
- DecodeImage: # load image
img_mode: BGR
@@ -98,7 +98,7 @@ Train:
Eval:
dataset:
name: LMDBDataSet
- data_dir: ./train_data/data_lmdb_release/validation/
+ data_dir: ./train_data/data_lmdb_release/validation
transforms:
- DecodeImage: # load image
img_mode: BGR
diff --git a/configs/rec/rec_vitstr.yml b/configs/rec/rec_vitstr.yml
new file mode 100644
index 0000000000000000000000000000000000000000..005db0184ae3319edffacb29a1dfd1751460a00a
--- /dev/null
+++ b/configs/rec/rec_vitstr.yml
@@ -0,0 +1,100 @@
+Global:
+ use_gpu: True
+ epoch_num: 20
+ log_smooth_window: 20
+ print_batch_step: 10
+ save_model_dir: ./output/rec/vitstr/
+ save_epoch_step: 1
+ # evaluation is run every 2000 iterations after the 0th iteration#
+ eval_batch_step: [0, 50]
+ cal_metric_during_train: True
+ pretrained_model:
+ checkpoints:
+ save_inference_dir:
+ use_visualdl: False
+ infer_img: doc/imgs_words_en/word_10.png
+ # for data or label process
+ character_dict_path: ppocr/utils/EN_symbol_dict.txt
+ max_text_length: 25
+ infer_mode: False
+ use_space_char: False
+ save_res_path: ./output/rec/predicts_vitstr.txt
+
+
+Optimizer:
+ name: Adadelta
+ epsilon: 0.00000001
+ rho: 0.95
+ clip_norm: 5.0
+ lr:
+ learning_rate: 1.0
+
+Architecture:
+ model_type: rec
+ algorithm: ViTSTR
+ in_channels: 1
+ Transform:
+ Backbone:
+ name: ViTSTR
+ scale: tiny
+ Neck:
+ name: SequenceEncoder
+ encoder_type: reshape
+ Head:
+ name: CTCHead
+
+Loss:
+ name: CESmoothingLoss
+ smoothing: False
+ with_all: True
+
+PostProcess:
+ name: ViTSTRLabelDecode
+
+Metric:
+ name: RecMetric
+ main_indicator: acc
+
+Train:
+ dataset:
+ name: LMDBDataSet
+ data_dir: ./train_data/data_lmdb_release/training
+ transforms:
+ - DecodeImage: # load image
+ img_mode: BGR
+ channel_first: False
+ - ViTSTRLabelEncode: # Class handling label
+ - GrayRecResizeImg:
+ image_shape: [224, 224] # W H
+ resize_type: PIL # PIL or OpenCV
+ inter_type: 'Image.BICUBIC'
+ scale: false
+ - KeepKeys:
+ keep_keys: ['image', 'label', 'length'] # dataloader will return list in this order
+ loader:
+ shuffle: True
+ batch_size_per_card: 48
+ drop_last: True
+ num_workers: 2
+
+Eval:
+ dataset:
+ name: LMDBDataSet
+ data_dir: ./train_data/data_lmdb_release/validation
+ transforms:
+ - DecodeImage: # load image
+ img_mode: BGR
+ channel_first: False
+ - ViTSTRLabelEncode: # Class handling label
+ - GrayRecResizeImg:
+ image_shape: [224, 224] # W H
+ resize_type: PIL # PIL or OpenCV
+ inter_type: 'Image.BICUBIC'
+ scale: false
+ - KeepKeys:
+ keep_keys: ['image', 'label', 'length'] # dataloader will return list in this order
+ loader:
+ shuffle: False
+ drop_last: False
+ batch_size_per_card: 256
+ num_workers: 2
diff --git a/doc/doc_ch/algorithm_overview.md b/doc/doc_ch/algorithm_overview.md
index 6227a21498eda7d8527e21e7f2567995251d9e47..934ac08537504fe6fa4d78c1d3635ac43a201efb 100755
--- a/doc/doc_ch/algorithm_overview.md
+++ b/doc/doc_ch/algorithm_overview.md
@@ -66,6 +66,7 @@
- [x] [SAR](./algorithm_rec_sar.md)
- [x] [SEED](./algorithm_rec_seed.md)
- [x] [SVTR](./algorithm_rec_svtr.md)
+- [x] [ViTSTR](./algorithm_rec_vitstr.md)
参考[DTRB](https://arxiv.org/abs/1904.01906)[3]文字识别训练和评估流程,使用MJSynth和SynthText两个文字识别数据集训练,在IIIT, SVT, IC03, IC13, IC15, SVTP, CUTE数据集上进行评估,算法效果如下:
@@ -84,7 +85,7 @@
|SAR|Resnet31| 87.20% | rec_r31_sar | [训练模型](https://paddleocr.bj.bcebos.com/dygraph_v2.1/rec/rec_r31_sar_train.tar) |
|SEED|Aster_Resnet| 85.35% | rec_resnet_stn_bilstm_att | [训练模型](https://paddleocr.bj.bcebos.com/dygraph_v2.1/rec/rec_resnet_stn_bilstm_att.tar) |
|SVTR|SVTR-Tiny| 89.25% | rec_svtr_tiny_none_ctc_en | [训练模型](https://paddleocr.bj.bcebos.com/PP-OCRv3/chinese/rec_svtr_tiny_none_ctc_en_train.tar) |
-
+|ViTSTR|ViTSTR| 79.82% | rec_vitstr_none_ctc_en | [训练模型](https://paddleocr.bj.bcebos.com/rec_vitstr_none_none_train.tar) |
diff --git a/doc/doc_ch/algorithm_rec_vitstr.md b/doc/doc_ch/algorithm_rec_vitstr.md
new file mode 100644
index 0000000000000000000000000000000000000000..bd83b8d9c2d9474310cc12d716e9d34467bf74a5
--- /dev/null
+++ b/doc/doc_ch/algorithm_rec_vitstr.md
@@ -0,0 +1,154 @@
+# 场景文本识别算法-ViTSTR
+
+- [1. 算法简介](#1)
+- [2. 环境配置](#2)
+- [3. 模型训练、评估、预测](#3)
+ - [3.1 训练](#3-1)
+ - [3.2 评估](#3-2)
+ - [3.3 预测](#3-3)
+- [4. 推理部署](#4)
+ - [4.1 Python推理](#4-1)
+ - [4.2 C++推理](#4-2)
+ - [4.3 Serving服务化部署](#4-3)
+ - [4.4 更多推理部署](#4-4)
+- [5. FAQ](#5)
+
+
+## 1. 算法简介
+
+论文信息:
+> [Vision Transformer for Fast and Efficient Scene Text Recognition](https://arxiv.org/abs/2105.08582)
+> Rowel Atienza
+> ICDAR, 2021
+
+
+
+`ViTSTR`使用MJSynth和SynthText两个文字识别数据集训练,在IIIT, SVT, IC03, IC13, IC15, SVTP, CUTE数据集上进行评估,算法复现效果如下:
+
+|模型|骨干网络|配置文件|Acc|下载链接|
+| --- | --- | --- | --- | --- |
+|ViTSTR|ViTSTR|[rec_vitstr.yml](../../configs/rec/rec_vitstr.yml)|79.82%|[训练模型](https://paddleocr.bj.bcebos.com/rec_vitstr_none_none_train.tar)|
+
+
+## 2. 环境配置
+请先参考[《运行环境准备》](./environment.md)配置PaddleOCR运行环境,参考[《项目克隆》](./clone.md)克隆项目代码。
+
+
+
+## 3. 模型训练、评估、预测
+
+
+### 3.1 模型训练
+
+请参考[文本识别训练教程](./recognition.md)。PaddleOCR对代码进行了模块化,训练`ViTSTR`识别模型时需要**更换配置文件**为`ViTSTR`的[配置文件](../../configs/rec/rec_ViTSTR.yml)。
+
+#### 启动训练
+
+
+具体地,在完成数据准备后,便可以启动训练,训练命令如下:
+```shell
+#单卡训练(训练周期长,不建议)
+python3 tools/train.py -c configs/rec/rec_vitstr.yml
+
+#多卡训练,通过--gpus参数指定卡号
+python3 -m paddle.distributed.launch --gpus '0,1,2,3' tools/train.py -c configs/rec/rec_vitstr.yml
+```
+
+
+### 3.2 评估
+
+可下载已训练完成的[模型文件](#model),使用如下命令进行评估:
+
+```shell
+# 注意将pretrained_model的路径设置为本地路径。
+python3 -m paddle.distributed.launch --gpus '0' tools/eval.py -c configs/rec/rec_vitstr.yml -o Global.pretrained_model=./rec_vitstr_train/best_accuracy
+```
+
+
+### 3.3 预测
+
+使用如下命令进行单张图片预测:
+```shell
+# 注意将pretrained_model的路径设置为本地路径。
+python3 tools/infer_rec.py -c configs/rec/rec_vitstr.yml -o Global.infer_img='./doc/imgs_words_en/word_10.png' Global.pretrained_model=./rec_vitstr_train/best_accuracy
+# 预测文件夹下所有图像时,可修改infer_img为文件夹,如 Global.infer_img='./doc/imgs_words_en/'。
+```
+
+
+
+## 4. 推理部署
+
+
+### 4.1 Python推理
+首先将训练得到best模型,转换成inference model。这里以训练完成的模型为例([模型下载地址](https://paddleocr.bj.bcebos.com/rec_vitstr_none_none_train.tar) ),可以使用如下命令进行转换:
+
+```shell
+# 注意将pretrained_model的路径设置为本地路径。
+python3 tools/export_model.py -c configs/rec/rec_vitstr.yml -o Global.pretrained_model=./rec_vitstr_train/best_accuracy Global.save_inference_dir=./inference/rec_vitstr/
+```
+**注意:**
+- 如果您是在自己的数据集上训练的模型,并且调整了字典文件,请注意修改配置文件中的`character_dict_path`是否是所需要的字典文件。
+- 如果您修改了训练时的输入大小,请修改`tools/export_model.py`文件中的对应NRTR的`infer_shape`。
+
+转换成功后,在目录下有三个文件:
+```
+/inference/rec_vitstr/
+ ├── inference.pdiparams # 识别inference模型的参数文件
+ ├── inference.pdiparams.info # 识别inference模型的参数信息,可忽略
+ └── inference.pdmodel # 识别inference模型的program文件
+```
+
+执行如下命令进行模型推理:
+
+```shell
+python3 tools/infer/predict_rec.py --image_dir='./doc/imgs_words_en/word_10.png' --rec_model_dir='./inference/rec_vitstr/' --rec_algorithm='ViTSTR' --rec_image_shape='1,224,224' --rec_char_dict_path='./ppocr/utils/EN_symbol_dict.txt'
+# 预测文件夹下所有图像时,可修改image_dir为文件夹,如 --image_dir='./doc/imgs_words_en/'。
+```
+
+![](../imgs_words_en/word_10.png)
+
+执行命令后,上面图像的预测结果(识别的文本和得分)会打印到屏幕上,示例如下:
+结果如下:
+```shell
+Predicts of ./doc/imgs_words_en/word_10.png:('pain', 0.9265879392623901)
+```
+
+**注意**:
+
+- 训练上述模型采用的图像分辨率是[1,224,224],需要通过参数`rec_image_shape`设置为您训练时的识别图像形状。
+- 在推理时需要设置参数`rec_char_dict_path`指定字典,如果您修改了字典,请修改该参数为您的字典文件。
+- 如果您修改了预处理方法,需修改`tools/infer/predict_rec.py`中NRTR的预处理为您的预处理方法。
+
+
+
+### 4.2 C++推理部署
+
+由于C++预处理后处理还未支持NRTR,所以暂未支持
+
+
+### 4.3 Serving服务化部署
+
+暂不支持
+
+
+### 4.4 更多推理部署
+
+暂不支持
+
+
+## 5. FAQ
+
+1. 在`ViTSTR`论文中,使用在ImageNet1k上的预训练权重进行初始化训练,我们在训练未采用预训练权重,最终精度没有变化甚至有所提高。
+2. 我们仅仅复现了`ViTSTR`中的tiny版本,如果有需要使用small、base版本,可直接使用源开源repo中的预训练权重转为Paddle权重即可使用。
+
+## 引用
+
+```bibtex
+@article{Atienza2021ViTSTR,
+ title = {Vision Transformer for Fast and Efficient Scene Text Recognition},
+ author = {Rowel Atienza},
+ booktitle = {ICDAR},
+ year = {2021},
+ url = {https://arxiv.org/abs/2105.08582}
+}
+```
diff --git a/doc/doc_en/algorithm_overview_en.md b/doc/doc_en/algorithm_overview_en.md
index 383cbe39bbd2eb8ca85f497888920ce87cb1837e..213d95807dd14189b27051679b1791e43307d328 100755
--- a/doc/doc_en/algorithm_overview_en.md
+++ b/doc/doc_en/algorithm_overview_en.md
@@ -65,6 +65,7 @@ Supported text recognition algorithms (Click the link to get the tutorial):
- [x] [SAR](./algorithm_rec_sar_en.md)
- [x] [SEED](./algorithm_rec_seed_en.md)
- [x] [SVTR](./algorithm_rec_svtr_en.md)
+- [x] [ViTSTR](./algorithm_rec_vitstr_en.md)
Refer to [DTRB](https://arxiv.org/abs/1904.01906), the training and evaluation result of these above text recognition (using MJSynth and SynthText for training, evaluate on IIIT, SVT, IC03, IC13, IC15, SVTP, CUTE) is as follow:
@@ -83,7 +84,7 @@ Refer to [DTRB](https://arxiv.org/abs/1904.01906), the training and evaluation r
|SAR|Resnet31| 87.20% | rec_r31_sar | [trained model](https://paddleocr.bj.bcebos.com/dygraph_v2.1/rec/rec_r31_sar_train.tar) |
|SEED|Aster_Resnet| 85.35% | rec_resnet_stn_bilstm_att | [trained model](https://paddleocr.bj.bcebos.com/dygraph_v2.1/rec/rec_resnet_stn_bilstm_att.tar) |
|SVTR|SVTR-Tiny| 89.25% | rec_svtr_tiny_none_ctc_en | [trained model](https://paddleocr.bj.bcebos.com/PP-OCRv3/chinese/rec_svtr_tiny_none_ctc_en_train.tar) |
-
+|ViTSTR|ViTSTR| 79.82% | rec_vitstr_none_ctc_en | [trained model](https://paddleocr.bj.bcebos.com/rec_vitstr_none_none_train.tar) |
diff --git a/doc/doc_en/algorithm_rec_vitstr_en.md b/doc/doc_en/algorithm_rec_vitstr_en.md
new file mode 100644
index 0000000000000000000000000000000000000000..666798c3e0b6fccf21fcbbc0e09fc1fad0c8acff
--- /dev/null
+++ b/doc/doc_en/algorithm_rec_vitstr_en.md
@@ -0,0 +1,134 @@
+# ViTSTR
+
+- [1. Introduction](#1)
+- [2. Environment](#2)
+- [3. Model Training / Evaluation / Prediction](#3)
+ - [3.1 Training](#3-1)
+ - [3.2 Evaluation](#3-2)
+ - [3.3 Prediction](#3-3)
+- [4. Inference and Deployment](#4)
+ - [4.1 Python Inference](#4-1)
+ - [4.2 C++ Inference](#4-2)
+ - [4.3 Serving](#4-3)
+ - [4.4 More](#4-4)
+- [5. FAQ](#5)
+
+
+## 1. Introduction
+
+Paper:
+> [Vision Transformer for Fast and Efficient Scene Text Recognition](https://arxiv.org/abs/2105.08582)
+> Rowel Atienza
+> ICDAR, 2021
+
+Using MJSynth and SynthText two text recognition datasets for training, and evaluating on IIIT, SVT, IC03, IC13, IC15, SVTP, CUTE datasets, the algorithm reproduction effect is as follows:
+
+|Model|Backbone|config|Acc|Download link|
+| --- | --- | --- | --- | --- |
+|ViTSTR|ViTSTR|[rec_vitstr.yml](../../configs/rec/rec_vitstr.yml)|79.82%|[训练模型](https://paddleocr.bj.bcebos.com/rec_vitstr_none_none_train.tar)|
+
+
+## 2. Environment
+Please refer to ["Environment Preparation"](./environment_en.md) to configure the PaddleOCR environment, and refer to ["Project Clone"](./clone_en.md) to clone the project code.
+
+
+
+## 3. Model Training / Evaluation / Prediction
+
+Please refer to [Text Recognition Tutorial](./recognition_en.md). PaddleOCR modularizes the code, and training different recognition models only requires **changing the configuration file**.
+
+Training:
+
+Specifically, after the data preparation is completed, the training can be started. The training command is as follows:
+
+```
+#Single GPU training (long training period, not recommended)
+python3 tools/train.py -c configs/rec/rec_vitstr.yml
+
+#Multi GPU training, specify the gpu number through the --gpus parameter
+python3 -m paddle.distributed.launch --gpus '0,1,2,3' tools/train.py -c configs/rec/rec_vitstr.yml
+```
+
+Evaluation:
+
+```
+# GPU evaluation
+python3 -m paddle.distributed.launch --gpus '0' tools/eval.py -c configs/rec/rec_vitstr.yml -o Global.pretrained_model={path/to/weights}/best_accuracy
+```
+
+Prediction:
+
+```
+# The configuration file used for prediction must match the training
+python3 tools/infer_rec.py -c configs/rec/rec_vitstr.yml -o Global.infer_img='./doc/imgs_words_en/word_10.png' Global.pretrained_model=./rec_vitstr_train/best_accuracy
+```
+
+
+## 4. Inference and Deployment
+
+
+### 4.1 Python Inference
+First, the model saved during the ViTSTR text recognition training process is converted into an inference model. ( [Model download link](https://paddleocr.bj.bcebos.com/rec_vitstr_none_none_train.tar)) ), you can use the following command to convert:
+
+```
+python3 tools/export_model.py -c configs/rec/rec_vitstr.yml -o Global.pretrained_model=./rec_vitstr_train/best_accuracy Global.save_inference_dir=./inference/rec_vitstr
+```
+
+**Note:**
+- If you are training the model on your own dataset and have modified the dictionary file, please pay attention to modify the `character_dict_path` in the configuration file to the modified dictionary file.
+- If you modified the input size during training, please modify the `infer_shape` corresponding to ViTSTR in the `tools/export_model.py` file.
+
+After the conversion is successful, there are three files in the directory:
+```
+/inference/rec_vitstr/
+ ├── inference.pdiparams
+ ├── inference.pdiparams.info
+ └── inference.pdmodel
+```
+
+
+For ViTSTR text recognition model inference, the following commands can be executed:
+
+```
+python3 tools/infer/predict_rec.py --image_dir='./doc/imgs_words_en/word_10.png' --rec_model_dir='./inference/rec_vitstr/' --rec_algorithm='ViTSTR' --rec_image_shape='1,224,224' --rec_char_dict_path='./ppocr/utils/EN_symbol_dict.txt'
+```
+
+![](../imgs_words_en/word_10.png)
+
+After executing the command, the prediction result (recognized text and score) of the image above is printed to the screen, an example is as follows:
+The result is as follows:
+```shell
+Predicts of ./doc/imgs_words_en/word_10.png:('pain', 0.9265879392623901)
+```
+
+
+### 4.2 C++ Inference
+
+Not supported
+
+
+### 4.3 Serving
+
+Not supported
+
+
+### 4.4 More
+
+Not supported
+
+
+## 5. FAQ
+
+1. In the `ViTSTR` paper, using pre-trained weights on ImageNet1k for initial training, we did not use pre-trained weights in training, and the final accuracy did not change or even improved.
+
+## Citation
+
+```bibtex
+@article{Atienza2021ViTSTR,
+ title = {Vision Transformer for Fast and Efficient Scene Text Recognition},
+ author = {Rowel Atienza},
+ booktitle = {ICDAR},
+ year = {2021},
+ url = {https://arxiv.org/abs/2105.08582}
+}
+```
diff --git a/ppocr/data/imaug/__init__.py b/ppocr/data/imaug/__init__.py
index 548832fb0d116ba2de622bd97562b591d74501d8..2dbc92a7037c58b09753330e9c5f1b9791252ef6 100644
--- a/ppocr/data/imaug/__init__.py
+++ b/ppocr/data/imaug/__init__.py
@@ -23,7 +23,7 @@ from .random_crop_data import EastRandomCropData, RandomCropImgMask
from .make_pse_gt import MakePseGt
from .rec_img_aug import RecAug, RecConAug, RecResizeImg, ClsResizeImg, \
- SRNRecResizeImg, NRTRRecResizeImg, SARRecResizeImg, PRENResizeImg
+ SRNRecResizeImg, GrayRecResizeImg, SARRecResizeImg, PRENResizeImg
from .ssl_img_aug import SSLRotateResize
from .randaugment import RandAugment
from .copy_paste import CopyPaste
diff --git a/ppocr/data/imaug/label_ops.py b/ppocr/data/imaug/label_ops.py
index 02a5187dad27b76d04e866de45333d79383c1347..0366e3f7854513b79350dea4ddd6b29178c7fffc 100644
--- a/ppocr/data/imaug/label_ops.py
+++ b/ppocr/data/imaug/label_ops.py
@@ -443,7 +443,9 @@ class KieLabelEncode(object):
elif 'key_cls' in anno.keys():
labels.append(anno['key_cls'])
else:
- raise ValueError("Cannot found 'key_cls' in ann.keys(), please check your training annotation.")
+ raise ValueError(
+ "Cannot found 'key_cls' in ann.keys(), please check your training annotation."
+ )
edges.append(ann.get('edge', 0))
ann_infos = dict(
image=data['image'],
@@ -838,6 +840,37 @@ class PRENLabelEncode(BaseRecLabelEncode):
return data
+class ViTSTRLabelEncode(BaseRecLabelEncode):
+ """ Convert between text-label and text-index """
+
+ def __init__(self,
+ max_text_length,
+ character_dict_path=None,
+ use_space_char=False,
+ **kwargs):
+
+ super(ViTSTRLabelEncode, self).__init__(
+ max_text_length, character_dict_path, use_space_char)
+
+ def __call__(self, data):
+ text = data['label']
+ text = self.encode(text)
+ if text is None:
+ return None
+ if len(text) >= self.max_text_len:
+ return None
+ data['length'] = np.array(len(text))
+ text.insert(0, 0)
+ text.append(1)
+ text = text + [0] * (self.max_text_len + 2 - len(text))
+ data['label'] = np.array(text)
+ return data
+
+ def add_special_char(self, dict_character):
+ dict_character = ['', ''] + dict_character
+ return dict_character
+
+
class VQATokenLabelEncode(object):
"""
Label encode for NLP VQA methods
diff --git a/ppocr/data/imaug/rec_img_aug.py b/ppocr/data/imaug/rec_img_aug.py
index 7483dffe5b6d9a0a2204702757fcb49762a1cc7a..0697baf436fa1f345bbd33c7e0847be0d8f1df8c 100644
--- a/ppocr/data/imaug/rec_img_aug.py
+++ b/ppocr/data/imaug/rec_img_aug.py
@@ -87,11 +87,19 @@ class ClsResizeImg(object):
return data
-class NRTRRecResizeImg(object):
- def __init__(self, image_shape, resize_type, padding=False, **kwargs):
+class GrayRecResizeImg(object):
+ def __init__(self,
+ image_shape,
+ resize_type,
+ inter_type='Image.ANTIALIAS',
+ scale=True,
+ padding=False,
+ **kwargs):
self.image_shape = image_shape
self.resize_type = resize_type
self.padding = padding
+ self.inter_type = eval(inter_type)
+ self.scale = scale
def __call__(self, data):
img = data['image']
@@ -117,13 +125,16 @@ class NRTRRecResizeImg(object):
return data
if self.resize_type == 'PIL':
image_pil = Image.fromarray(np.uint8(img))
- img = image_pil.resize(self.image_shape, Image.ANTIALIAS)
+ img = image_pil.resize(self.image_shape, self.inter_type)
img = np.array(img)
if self.resize_type == 'OpenCV':
img = cv2.resize(img, self.image_shape)
norm_img = np.expand_dims(img, -1)
norm_img = norm_img.transpose((2, 0, 1))
- data['image'] = norm_img.astype(np.float32) / 128. - 1.
+ if self.scale:
+ data['image'] = norm_img.astype(np.float32) / 128. - 1.
+ else:
+ data['image'] = norm_img.astype(np.float32) / 255.
return data
diff --git a/ppocr/losses/__init__.py b/ppocr/losses/__init__.py
index de8419b7c1cf6a30ab7195a1cbcbb10a5e52642d..6c4545eb21a2a2cf7ddd1b0a0f2023b56b41e196 100755
--- a/ppocr/losses/__init__.py
+++ b/ppocr/losses/__init__.py
@@ -30,7 +30,7 @@ from .det_fce_loss import FCELoss
from .rec_ctc_loss import CTCLoss
from .rec_att_loss import AttentionLoss
from .rec_srn_loss import SRNLoss
-from .rec_nrtr_loss import NRTRLoss
+from .rec_ce_smooth_loss import CESmoothingLoss
from .rec_sar_loss import SARLoss
from .rec_aster_loss import AsterLoss
from .rec_pren_loss import PRENLoss
@@ -60,8 +60,9 @@ def build_loss(config):
support_dict = [
'DBLoss', 'PSELoss', 'EASTLoss', 'SASTLoss', 'FCELoss', 'CTCLoss',
'ClsLoss', 'AttentionLoss', 'SRNLoss', 'PGLoss', 'CombinedLoss',
- 'NRTRLoss', 'TableAttentionLoss', 'SARLoss', 'AsterLoss', 'SDMGRLoss',
- 'VQASerTokenLayoutLMLoss', 'LossFromOutput', 'PRENLoss', 'MultiLoss'
+ 'CESmoothingLoss', 'TableAttentionLoss', 'SARLoss', 'AsterLoss',
+ 'SDMGRLoss', 'VQASerTokenLayoutLMLoss', 'LossFromOutput', 'PRENLoss',
+ 'MultiLoss'
]
config = copy.deepcopy(config)
module_name = config.pop('name')
diff --git a/ppocr/losses/rec_nrtr_loss.py b/ppocr/losses/rec_ce_smooth_loss.py
similarity index 73%
rename from ppocr/losses/rec_nrtr_loss.py
rename to ppocr/losses/rec_ce_smooth_loss.py
index 200a6d0486dbf6f76dd674eb58f641b31a70f31c..22243ed41f4ce739377a39112c640c00cb4b7792 100644
--- a/ppocr/losses/rec_nrtr_loss.py
+++ b/ppocr/losses/rec_ce_smooth_loss.py
@@ -3,16 +3,20 @@ from paddle import nn
import paddle.nn.functional as F
-class NRTRLoss(nn.Layer):
- def __init__(self, smoothing=True, **kwargs):
- super(NRTRLoss, self).__init__()
+class CESmoothingLoss(nn.Layer):
+ def __init__(self, smoothing=True, with_all=False, **kwargs):
+ super(CESmoothingLoss, self).__init__()
self.loss_func = nn.CrossEntropyLoss(reduction='mean', ignore_index=0)
self.smoothing = smoothing
+ self.with_all = with_all
def forward(self, pred, batch):
pred = pred.reshape([-1, pred.shape[2]])
- max_len = batch[2].max()
- tgt = batch[1][:, 1:2 + max_len]
+ if self.with_all:
+ tgt = batch[1]
+ else:
+ max_len = batch[2].max()
+ tgt = batch[1][:, 1:2 + max_len]
tgt = tgt.reshape([-1])
if self.smoothing:
eps = 0.1
diff --git a/ppocr/modeling/backbones/__init__.py b/ppocr/modeling/backbones/__init__.py
index 072d6e0f84d4126d256c26aa5baf17c9dc4e63df..a368e7481628cd836963618cd4cbfca12ba2080b 100755
--- a/ppocr/modeling/backbones/__init__.py
+++ b/ppocr/modeling/backbones/__init__.py
@@ -32,10 +32,11 @@ def build_backbone(config, model_type):
from .rec_micronet import MicroNet
from .rec_efficientb3_pren import EfficientNetb3_PREN
from .rec_svtrnet import SVTRNet
+ from .rec_vitstr import ViTSTR
support_dict = [
'MobileNetV1Enhance', 'MobileNetV3', 'ResNet', 'ResNetFPN', 'MTB',
"ResNet31", "ResNet_ASTER", 'MicroNet', 'EfficientNetb3_PREN',
- 'SVTRNet'
+ 'SVTRNet', 'ViTSTR'
]
elif model_type == "e2e":
from .e2e_resnet_vd_pg import ResNet
diff --git a/ppocr/modeling/backbones/rec_svtrnet.py b/ppocr/modeling/backbones/rec_svtrnet.py
index c57bf46345d6e08f23b9258358f77f2285366314..c2c07f4476929d49237c8e9a10713f881f5f556b 100644
--- a/ppocr/modeling/backbones/rec_svtrnet.py
+++ b/ppocr/modeling/backbones/rec_svtrnet.py
@@ -147,7 +147,7 @@ class Attention(nn.Layer):
dim,
num_heads=8,
mixer='Global',
- HW=[8, 25],
+ HW=None,
local_k=[7, 11],
qkv_bias=False,
qk_scale=None,
@@ -210,7 +210,7 @@ class Block(nn.Layer):
num_heads,
mixer='Global',
local_mixer=[7, 11],
- HW=[8, 25],
+ HW=None,
mlp_ratio=4.,
qkv_bias=False,
qk_scale=None,
@@ -274,7 +274,9 @@ class PatchEmbed(nn.Layer):
img_size=[32, 100],
in_channels=3,
embed_dim=768,
- sub_num=2):
+ sub_num=2,
+ patch_size=[4, 4],
+ mode='pope'):
super().__init__()
num_patches = (img_size[1] // (2 ** sub_num)) * \
(img_size[0] // (2 ** sub_num))
@@ -282,50 +284,56 @@ class PatchEmbed(nn.Layer):
self.num_patches = num_patches
self.embed_dim = embed_dim
self.norm = None
- if sub_num == 2:
- self.proj = nn.Sequential(
- ConvBNLayer(
- in_channels=in_channels,
- out_channels=embed_dim // 2,
- kernel_size=3,
- stride=2,
- padding=1,
- act=nn.GELU,
- bias_attr=None),
- ConvBNLayer(
- in_channels=embed_dim // 2,
- out_channels=embed_dim,
- kernel_size=3,
- stride=2,
- padding=1,
- act=nn.GELU,
- bias_attr=None))
- if sub_num == 3:
- self.proj = nn.Sequential(
- ConvBNLayer(
- in_channels=in_channels,
- out_channels=embed_dim // 4,
- kernel_size=3,
- stride=2,
- padding=1,
- act=nn.GELU,
- bias_attr=None),
- ConvBNLayer(
- in_channels=embed_dim // 4,
- out_channels=embed_dim // 2,
- kernel_size=3,
- stride=2,
- padding=1,
- act=nn.GELU,
- bias_attr=None),
- ConvBNLayer(
- in_channels=embed_dim // 2,
- out_channels=embed_dim,
- kernel_size=3,
- stride=2,
- padding=1,
- act=nn.GELU,
- bias_attr=None))
+ if mode == 'pope':
+ if sub_num == 2:
+ self.proj = nn.Sequential(
+ ConvBNLayer(
+ in_channels=in_channels,
+ out_channels=embed_dim // 2,
+ kernel_size=3,
+ stride=2,
+ padding=1,
+ act=nn.GELU,
+ bias_attr=None),
+ ConvBNLayer(
+ in_channels=embed_dim // 2,
+ out_channels=embed_dim,
+ kernel_size=3,
+ stride=2,
+ padding=1,
+ act=nn.GELU,
+ bias_attr=None))
+ if sub_num == 3:
+ self.proj = nn.Sequential(
+ ConvBNLayer(
+ in_channels=in_channels,
+ out_channels=embed_dim // 4,
+ kernel_size=3,
+ stride=2,
+ padding=1,
+ act=nn.GELU,
+ bias_attr=None),
+ ConvBNLayer(
+ in_channels=embed_dim // 4,
+ out_channels=embed_dim // 2,
+ kernel_size=3,
+ stride=2,
+ padding=1,
+ act=nn.GELU,
+ bias_attr=None),
+ ConvBNLayer(
+ in_channels=embed_dim // 2,
+ out_channels=embed_dim,
+ kernel_size=3,
+ stride=2,
+ padding=1,
+ act=nn.GELU,
+ bias_attr=None))
+ elif mode == 'linear':
+ self.proj = nn.Conv2D(
+ 1, embed_dim, kernel_size=patch_size, stride=patch_size)
+ self.num_patches = img_size[0] // patch_size[0] * img_size[
+ 1] // patch_size[1]
def forward(self, x):
B, C, H, W = x.shape
diff --git a/ppocr/modeling/backbones/rec_vitstr.py b/ppocr/modeling/backbones/rec_vitstr.py
new file mode 100644
index 0000000000000000000000000000000000000000..d5d7d5148a1120e6f97a321b4135c6780c0c5db2
--- /dev/null
+++ b/ppocr/modeling/backbones/rec_vitstr.py
@@ -0,0 +1,120 @@
+# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+This code is refer from:
+https://github.com/roatienza/deep-text-recognition-benchmark/blob/master/modules/vitstr.py
+"""
+
+import numpy as np
+import paddle
+import paddle.nn as nn
+from ppocr.modeling.backbones.rec_svtrnet import Block, PatchEmbed, zeros_, trunc_normal_, ones_
+
+scale_dim_heads = {'tiny': [192, 3], 'small': [384, 6], 'base': [768, 12]}
+
+
+class ViTSTR(nn.Layer):
+ def __init__(self,
+ img_size=[224, 224],
+ in_channels=1,
+ scale='tiny',
+ seqlen=27,
+ patch_size=[16, 16],
+ embed_dim=None,
+ depth=12,
+ num_heads=None,
+ mlp_ratio=4,
+ qkv_bias=True,
+ qk_scale=None,
+ drop_path_rate=0.,
+ drop_rate=0.,
+ attn_drop_rate=0.,
+ norm_layer='nn.LayerNorm',
+ act_layer='nn.GELU',
+ epsilon=1e-6,
+ out_channels=None,
+ **kwargs):
+ super().__init__()
+ self.seqlen = seqlen
+ embed_dim = embed_dim if embed_dim is not None else scale_dim_heads[
+ scale][0]
+ num_heads = num_heads if num_heads is not None else scale_dim_heads[
+ scale][1]
+ out_channels = out_channels if out_channels is not None else embed_dim
+ self.patch_embed = PatchEmbed(
+ img_size=img_size,
+ in_channels=in_channels,
+ embed_dim=embed_dim,
+ patch_size=patch_size,
+ mode='linear')
+ num_patches = self.patch_embed.num_patches
+
+ self.pos_embed = self.create_parameter(
+ shape=[1, num_patches + 1, embed_dim], default_initializer=zeros_)
+ self.add_parameter("pos_embed", self.pos_embed)
+ self.cls_token = self.create_parameter(
+ shape=[1, 1, embed_dim], default_initializer=zeros_)
+ self.add_parameter("cls_token", self.cls_token)
+
+ self.pos_drop = nn.Dropout(p=drop_rate)
+
+ dpr = np.linspace(0, drop_path_rate, depth)
+ self.blocks = nn.LayerList([
+ Block(
+ dim=embed_dim,
+ num_heads=num_heads,
+ mlp_ratio=mlp_ratio,
+ qkv_bias=qkv_bias,
+ qk_scale=qk_scale,
+ drop=drop_rate,
+ attn_drop=attn_drop_rate,
+ drop_path=dpr[i],
+ norm_layer=norm_layer,
+ act_layer=eval(act_layer),
+ epsilon=epsilon,
+ prenorm=False) for i in range(depth)
+ ])
+ self.norm = eval(norm_layer)(embed_dim, epsilon=epsilon)
+
+ self.out_channels = out_channels
+
+ trunc_normal_(self.pos_embed)
+ trunc_normal_(self.cls_token)
+ self.apply(self._init_weights)
+
+ def _init_weights(self, m):
+ if isinstance(m, nn.Linear):
+ trunc_normal_(m.weight)
+ if isinstance(m, nn.Linear) and m.bias is not None:
+ zeros_(m.bias)
+ elif isinstance(m, nn.LayerNorm):
+ zeros_(m.bias)
+ ones_(m.weight)
+
+ def forward_features(self, x):
+ B = x.shape[0]
+ x = self.patch_embed(x)
+ cls_tokens = paddle.tile(self.cls_token, repeat_times=[B, 1, 1])
+ x = paddle.concat((cls_tokens, x), axis=1)
+ x = x + self.pos_embed
+ x = self.pos_drop(x)
+ for blk in self.blocks:
+ x = blk(x)
+ x = self.norm(x)
+ return x
+
+ def forward(self, x):
+ x = self.forward_features(x)
+ x = x[:, :self.seqlen]
+ return x.transpose([0, 2, 1]).unsqueeze(2)
diff --git a/ppocr/postprocess/__init__.py b/ppocr/postprocess/__init__.py
index f50b5f1c5f8e617066bb47636c8f4d2b171b6ecb..4f900ee1fc716ebec78feaaed07bf258d3d0df0a 100644
--- a/ppocr/postprocess/__init__.py
+++ b/ppocr/postprocess/__init__.py
@@ -27,7 +27,7 @@ from .sast_postprocess import SASTPostProcess
from .fce_postprocess import FCEPostProcess
from .rec_postprocess import CTCLabelDecode, AttnLabelDecode, SRNLabelDecode, \
DistillationCTCLabelDecode, TableLabelDecode, NRTRLabelDecode, SARLabelDecode, \
- SEEDLabelDecode, PRENLabelDecode
+ SEEDLabelDecode, PRENLabelDecode, ViTSTRLabelDecode
from .cls_postprocess import ClsPostProcess
from .pg_postprocess import PGPostProcess
from .vqa_token_ser_layoutlm_postprocess import VQASerTokenLayoutLMPostProcess
@@ -42,7 +42,7 @@ def build_post_process(config, global_config=None):
'DistillationDBPostProcess', 'NRTRLabelDecode', 'SARLabelDecode',
'SEEDLabelDecode', 'VQASerTokenLayoutLMPostProcess',
'VQAReTokenLayoutLMPostProcess', 'PRENLabelDecode',
- 'DistillationSARLabelDecode'
+ 'DistillationSARLabelDecode', 'ViTSTRLabelDecode'
]
if config['name'] == 'PSEPostProcess':
diff --git a/ppocr/postprocess/rec_postprocess.py b/ppocr/postprocess/rec_postprocess.py
index bf0fd890bf25949361665d212bf8e1a657054e5b..df6203fadaae9e99c1df125e7172166e5e3d8acd 100644
--- a/ppocr/postprocess/rec_postprocess.py
+++ b/ppocr/postprocess/rec_postprocess.py
@@ -188,13 +188,13 @@ class NRTRLabelDecode(BaseRecLabelDecode):
char_list = []
conf_list = []
for idx in range(len(text_index[batch_idx])):
- if text_index[batch_idx][idx] == 3: # end
- break
try:
- char_list.append(self.character[int(text_index[batch_idx][
- idx])])
+ char_idx = self.character[int(text_index[batch_idx][idx])]
except:
continue
+ if char_idx == '': # end
+ break
+ char_list.append(char_idx)
if text_prob is not None:
conf_list.append(text_prob[batch_idx][idx])
else:
@@ -204,6 +204,32 @@ class NRTRLabelDecode(BaseRecLabelDecode):
return result_list
+class ViTSTRLabelDecode(NRTRLabelDecode):
+ """ Convert between text-label and text-index """
+
+ def __init__(self, character_dict_path=None, use_space_char=False,
+ **kwargs):
+ super(ViTSTRLabelDecode, self).__init__(character_dict_path,
+ use_space_char)
+
+ def __call__(self, preds, label=None, *args, **kwargs):
+ if isinstance(preds, paddle.Tensor):
+ preds = preds[:, 1:].numpy()
+ else:
+ preds = preds[:, 1:]
+ preds_idx = preds.argmax(axis=2)
+ preds_prob = preds.max(axis=2)
+ text = self.decode(preds_idx, preds_prob, is_remove_duplicate=False)
+ if label is None:
+ return text
+ label = self.decode(label[:, 1:])
+ return text, label
+
+ def add_special_char(self, dict_character):
+ dict_character = ['', ''] + dict_character
+ return dict_character
+
+
class AttnLabelDecode(BaseRecLabelDecode):
""" Convert between text-label and text-index """
diff --git a/test_tipc/configs/rec_mtb_nrtr/rec_mtb_nrtr.yml b/test_tipc/configs/rec_mtb_nrtr/rec_mtb_nrtr.yml
index 15119bb2a9de02c19684d21ad5a1859db94895ce..3936ab58adfca2b5f900b99c84766e3c1058236e 100644
--- a/test_tipc/configs/rec_mtb_nrtr/rec_mtb_nrtr.yml
+++ b/test_tipc/configs/rec_mtb_nrtr/rec_mtb_nrtr.yml
@@ -49,7 +49,7 @@ Architecture:
Loss:
- name: NRTRLoss
+ name: CESmoothingLoss
smoothing: True
PostProcess:
@@ -69,7 +69,7 @@ Train:
img_mode: BGR
channel_first: False
- NRTRLabelEncode: # Class handling label
- - NRTRRecResizeImg:
+ - GrayRecResizeImg:
image_shape: [100, 32]
resize_type: PIL # PIL or OpenCV
- KeepKeys:
@@ -90,7 +90,7 @@ Eval:
img_mode: BGR
channel_first: False
- NRTRLabelEncode: # Class handling label
- - NRTRRecResizeImg:
+ - GrayRecResizeImg:
image_shape: [100, 32]
resize_type: PIL # PIL or OpenCV
- KeepKeys:
diff --git a/test_tipc/configs/rec_svtrnet/rec_svtrnet.yml b/test_tipc/configs/rec_svtrnet/rec_svtrnet.yml
new file mode 100644
index 0000000000000000000000000000000000000000..26facca34d20536cb19c3b1f80b0828ebc817e50
--- /dev/null
+++ b/test_tipc/configs/rec_svtrnet/rec_svtrnet.yml
@@ -0,0 +1,119 @@
+Global:
+ use_gpu: True
+ epoch_num: 20
+ log_smooth_window: 20
+ print_batch_step: 10
+ save_model_dir: ./output/rec/svtr/
+ save_epoch_step: 1
+ # evaluation is run every 2000 iterations after the 0th iteration
+ eval_batch_step: [0, 2000]
+ cal_metric_during_train: True
+ pretrained_model:
+ checkpoints:
+ save_inference_dir:
+ use_visualdl: False
+ infer_img: doc/imgs_words_en/word_10.png
+ # for data or label process
+ character_dict_path:
+ character_type: en
+ max_text_length: 25
+ infer_mode: False
+ use_space_char: False
+ save_res_path: ./output/rec/predicts_svtr_tiny.txt
+
+
+Optimizer:
+ name: AdamW
+ beta1: 0.9
+ beta2: 0.99
+ epsilon: 0.00000008
+ weight_decay: 0.05
+ no_weight_decay_name: norm pos_embed
+ one_dim_param_no_weight_decay: true
+ lr:
+ name: Cosine
+ learning_rate: 0.0005
+ warmup_epoch: 2
+
+Architecture:
+ model_type: rec
+ algorithm: SVTR
+ Transform:
+ name: STN_ON
+ tps_inputsize: [32, 64]
+ tps_outputsize: [32, 100]
+ num_control_points: 20
+ tps_margins: [0.05,0.05]
+ stn_activation: none
+ Backbone:
+ name: SVTRNet
+ img_size: [32, 100]
+ out_char_num: 25
+ out_channels: 192
+ patch_merging: 'Conv'
+ embed_dim: [64, 128, 256]
+ depth: [3, 6, 3]
+ num_heads: [2, 4, 8]
+ mixer: ['Local','Local','Local','Local','Local','Local','Global','Global','Global','Global','Global','Global']
+ local_mixer: [[7, 11], [7, 11], [7, 11]]
+ last_stage: True
+ prenorm: false
+ Neck:
+ name: SequenceEncoder
+ encoder_type: reshape
+ Head:
+ name: CTCHead
+
+Loss:
+ name: CTCLoss
+
+PostProcess:
+ name: CTCLabelDecode
+
+Metric:
+ name: RecMetric
+ main_indicator: acc
+
+Train:
+ dataset:
+ name: SimpleDataSet
+ data_dir: ./train_data/ic15_data/
+ label_file_list: ["./train_data/ic15_data/rec_gt_train.txt"]
+ transforms:
+ - DecodeImage: # load image
+ img_mode: BGR
+ channel_first: False
+ - CTCLabelEncode: # Class handling label
+ - RecResizeImg:
+ character_dict_path:
+ image_shape: [3, 64, 256]
+ padding: False
+ - KeepKeys:
+ keep_keys: ['image', 'label', 'length'] # dataloader will return list in this order
+ loader:
+ shuffle: True
+ batch_size_per_card: 512
+ drop_last: True
+ num_workers: 4
+
+Eval:
+ dataset:
+ name: SimpleDataSet
+ data_dir: ./train_data/ic15_data
+ label_file_list: ["./train_data/ic15_data/rec_gt_test.txt"]
+ transforms:
+ - DecodeImage: # load image
+ img_mode: BGR
+ channel_first: False
+ - CTCLabelEncode: # Class handling label
+ - RecResizeImg:
+ character_dict_path:
+ image_shape: [3, 64, 256]
+ padding: False
+ - KeepKeys:
+ keep_keys: ['image', 'label', 'length'] # dataloader will return list in this order
+ loader:
+ shuffle: False
+ drop_last: False
+ batch_size_per_card: 256
+ num_workers: 2
diff --git a/test_tipc/configs/rec_svtrnet/train_infer_python.txt b/test_tipc/configs/rec_svtrnet/train_infer_python.txt
new file mode 100644
index 0000000000000000000000000000000000000000..72526063e090b40b8926f8fdc2acc42a705841e6
--- /dev/null
+++ b/test_tipc/configs/rec_svtrnet/train_infer_python.txt
@@ -0,0 +1,53 @@
+===========================train_params===========================
+model_name:rec_svtrnet
+python:python3.7
+gpu_list:0|0,1
+Global.use_gpu:True|True
+Global.auto_cast:null
+Global.epoch_num:lite_train_lite_infer=2|whole_train_whole_infer=300
+Global.save_model_dir:./output/
+Train.loader.batch_size_per_card:lite_train_lite_infer=16|whole_train_whole_infer=64
+Global.pretrained_model:null
+train_model_name:latest
+train_infer_img_dir:./inference/rec_inference
+null:null
+##
+trainer:norm_train
+norm_train:tools/train.py -c test_tipc/configs/rec_svtrnet/rec_svtrnet.yml -o
+pact_train:null
+fpgm_train:null
+distill_train:null
+null:null
+null:null
+##
+===========================eval_params===========================
+eval:tools/eval.py -c test_tipc/configs/rec_svtrnet/rec_svtrnet.yml -o
+null:null
+##
+===========================infer_params===========================
+Global.save_inference_dir:./output/
+Global.checkpoints:
+norm_export:tools/export_model.py -c test_tipc/configs/rec_svtrnet/rec_svtrnet.yml -o
+quant_export:null
+fpgm_export:null
+distill_export:null
+export1:null
+export2:null
+##
+train_model:./inference/rec_svtrnet_train/best_accuracy
+infer_export:tools/export_model.py -c test_tipc/configs/rec_svtrnet/rec_svtrnet.yml -o
+infer_quant:False
+inference:tools/infer/predict_rec.py --rec_char_dict_path=./ppocr/utils/EN_symbol_dict.txt --rec_image_shape="1,224,224" --rec_algorithm="SVTR"
+--use_gpu:True|False
+--enable_mkldnn:False
+--cpu_threads:1|6
+--rec_batch_num:1|6
+--use_tensorrt:False
+--precision:fp32|int8
+--rec_model_dir:
+--image_dir:./inference/rec_inference
+--save_log_path:./test/output/
+--benchmark:True
+null:null
+===========================infer_benchmark_params==========================
+random_infer_input:[{float32,[1,224,224]}]
diff --git a/test_tipc/configs/rec_vitstr/rec_vitstr.yml b/test_tipc/configs/rec_vitstr/rec_vitstr.yml
new file mode 100644
index 0000000000000000000000000000000000000000..427bce4b5adfd6ddb51e162741d10a9ba003d001
--- /dev/null
+++ b/test_tipc/configs/rec_vitstr/rec_vitstr.yml
@@ -0,0 +1,101 @@
+Global:
+ use_gpu: True
+ epoch_num: 20
+ log_smooth_window: 20
+ print_batch_step: 10
+ save_model_dir: ./output/rec/vitstr/
+ save_epoch_step: 1
+ # evaluation is run every 2000 iterations after the 0th iteration#
+ eval_batch_step: [0, 2000]
+ cal_metric_during_train: True
+ pretrained_model:
+ checkpoints:
+ save_inference_dir:
+ use_visualdl: False
+ infer_img: doc/imgs_words_en/word_10.png
+ # for data or label process
+ character_dict_path: ppocr/utils/EN_symbol_dict.txt
+ max_text_length: 25
+ infer_mode: False
+ use_space_char: False
+ save_res_path: ./output/rec/predicts_vitstr.txt
+
+
+Optimizer:
+ name: Adadelta
+ epsilon: 0.00000001
+ rho: 0.95
+ clip_norm: 5.0
+ lr:
+ learning_rate: 1.0
+
+Architecture:
+ model_type: rec
+ algorithm: ViTSTR
+ in_channels: 1
+ Transform:
+ Backbone:
+ name: ViTSTR
+ Neck:
+ name: SequenceEncoder
+ encoder_type: reshape
+ Head:
+ name: CTCHead
+
+Loss:
+ name: CESmoothingLoss
+ smoothing: False
+ with_all: True
+
+PostProcess:
+ name: ViTSTRLabelDecode
+
+Metric:
+ name: RecMetric
+ main_indicator: acc
+
+Train:
+ dataset:
+ name: SimpleDataSet
+ data_dir: ./train_data/ic15_data/
+ label_file_list: ["./train_data/ic15_data/rec_gt_train.txt"]
+ transforms:
+ - DecodeImage: # load image
+ img_mode: BGR
+ channel_first: False
+ - ViTSTRLabelEncode: # Class handling label
+ - GrayRecResizeImg:
+ image_shape: [224, 224] # W H
+ resize_type: PIL # PIL or OpenCV
+ inter_type: 'Image.BICUBIC'
+ scale: false
+ - KeepKeys:
+ keep_keys: ['image', 'label', 'length'] # dataloader will return list in this order
+ loader:
+ shuffle: True
+ batch_size_per_card: 48
+ drop_last: True
+ num_workers: 2
+
+Eval:
+ dataset:
+ name: SimpleDataSet
+ data_dir: ./train_data/ic15_data
+ label_file_list: ["./train_data/ic15_data/rec_gt_test.txt"]
+ transforms:
+ - DecodeImage: # load image
+ img_mode: BGR
+ channel_first: False
+ - ViTSTRLabelEncode: # Class handling label
+ - GrayRecResizeImg:
+ image_shape: [224, 224] # W H
+ resize_type: PIL # PIL or OpenCV
+ inter_type: 'Image.BICUBIC'
+ scale: false
+ - KeepKeys:
+ keep_keys: ['image', 'label', 'length'] # dataloader will return list in this order
+ loader:
+ shuffle: False
+ drop_last: False
+ batch_size_per_card: 256
+ num_workers: 2
diff --git a/test_tipc/configs/rec_vitstr/train_infer_python.txt b/test_tipc/configs/rec_vitstr/train_infer_python.txt
new file mode 100644
index 0000000000000000000000000000000000000000..6c7af1fb87375922350b7b54af6997a03cc91b1a
--- /dev/null
+++ b/test_tipc/configs/rec_vitstr/train_infer_python.txt
@@ -0,0 +1,53 @@
+===========================train_params===========================
+model_name:rec_vitstr
+python:python3.7
+gpu_list:0|0,1
+Global.use_gpu:True|True
+Global.auto_cast:null
+Global.epoch_num:lite_train_lite_infer=2|whole_train_whole_infer=300
+Global.save_model_dir:./output/
+Train.loader.batch_size_per_card:lite_train_lite_infer=16|whole_train_whole_infer=64
+Global.pretrained_model:null
+train_model_name:latest
+train_infer_img_dir:./inference/rec_inference
+null:null
+##
+trainer:norm_train
+norm_train:tools/train.py -c test_tipc/configs/rec_vitstr/rec_vitstr.yml -o
+pact_train:null
+fpgm_train:null
+distill_train:null
+null:null
+null:null
+##
+===========================eval_params===========================
+eval:tools/eval.py -c test_tipc/configs/rec_vitstr/rec_vitstr.yml -o
+null:null
+##
+===========================infer_params===========================
+Global.save_inference_dir:./output/
+Global.checkpoints:
+norm_export:tools/export_model.py -c test_tipc/configs/rec_vitstr/rec_vitstr.yml -o
+quant_export:null
+fpgm_export:null
+distill_export:null
+export1:null
+export2:null
+##
+train_model:./inference/rec_vitstr_train/best_accuracy
+infer_export:tools/export_model.py -c test_tipc/configs/rec_vitstr/rec_vitstr.yml -o
+infer_quant:False
+inference:tools/infer/predict_rec.py --rec_char_dict_path=./ppocr/utils/EN_symbol_dict.txt --rec_image_shape="1,224,224" --rec_algorithm="ViTSTR"
+--use_gpu:True|False
+--enable_mkldnn:False
+--cpu_threads:1|6
+--rec_batch_num:1|6
+--use_tensorrt:False
+--precision:fp32|int8
+--rec_model_dir:
+--image_dir:./inference/rec_inference
+--save_log_path:./test/output/
+--benchmark:True
+null:null
+===========================infer_benchmark_params==========================
+random_infer_input:[{float32,[1,224,224]}]
diff --git a/tools/export_model.py b/tools/export_model.py
index c0cbcd361cec31c51616a7154836c234f076a86e..6e003f2ffa8755913fe76ee893677da70e2459b2 100755
--- a/tools/export_model.py
+++ b/tools/export_model.py
@@ -73,6 +73,12 @@ def export_single_model(model, arch_config, save_path, logger, quanter=None):
shape=[None, 3, 64, 512], dtype="float32"),
]
model = to_static(model, input_spec=other_shape)
+ elif arch_config["algorithm"] == "ViTSTR":
+ other_shape = [
+ paddle.static.InputSpec(
+ shape=[None, 1, 224, 224], dtype="float32"),
+ ]
+ model = to_static(model, input_spec=other_shape)
else:
infer_shape = [3, -1, -1]
if arch_config["model_type"] == "rec":
diff --git a/tools/infer/predict_rec.py b/tools/infer/predict_rec.py
index 3664ef2caf4b888d6a3918202256c99cc54c5eb1..1945667972310cdef03daac7d2bfdb52373b950b 100755
--- a/tools/infer/predict_rec.py
+++ b/tools/infer/predict_rec.py
@@ -69,6 +69,12 @@ class TextRecognizer(object):
"character_dict_path": args.rec_char_dict_path,
"use_space_char": args.use_space_char
}
+ elif self.rec_algorithm == 'ViTSTR':
+ postprocess_params = {
+ 'name': 'ViTSTRLabelDecode',
+ "character_dict_path": args.rec_char_dict_path,
+ "use_space_char": args.use_space_char
+ }
self.postprocess_op = build_post_process(postprocess_params)
self.predictor, self.input_tensor, self.output_tensors, self.config = \
utility.create_predictor(args, 'rec', logger)
@@ -96,15 +102,22 @@ class TextRecognizer(object):
def resize_norm_img(self, img, max_wh_ratio):
imgC, imgH, imgW = self.rec_image_shape
- if self.rec_algorithm == 'NRTR':
+ if self.rec_algorithm == 'NRTR' or self.rec_algorithm == 'ViTSTR':
img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
# return padding_im
image_pil = Image.fromarray(np.uint8(img))
- img = image_pil.resize([100, 32], Image.ANTIALIAS)
+ if self.rec_algorithm == 'ViTSTR':
+ img = image_pil.resize([imgW, imgH], Image.BICUBIC)
+ else:
+ img = image_pil.resize([imgW, imgH], Image.ANTIALIAS)
img = np.array(img)
norm_img = np.expand_dims(img, -1)
norm_img = norm_img.transpose((2, 0, 1))
- return norm_img.astype(np.float32) / 128. - 1.
+ if self.rec_algorithm == 'ViTSTR':
+ norm_img = norm_img.astype(np.float32) / 255.
+ else:
+ norm_img = norm_img.astype(np.float32) / 128. - 1.
+ return norm_img
assert imgC == img.shape[2]
imgW = int((imgH * max_wh_ratio))
diff --git a/tools/program.py b/tools/program.py
index aa0d2698cf66c928f87217996c31c042e1c8aa02..745c28b87292480159ed41285f2a79c6bf5d0abe 100755
--- a/tools/program.py
+++ b/tools/program.py
@@ -307,7 +307,8 @@ def train(config,
train_stats.update(stats)
if log_writer is not None and dist.get_rank() == 0:
- log_writer.log_metrics(metrics=train_stats.get(), prefix="TRAIN", step=global_step)
+ log_writer.log_metrics(
+ metrics=train_stats.get(), prefix="TRAIN", step=global_step)
if dist.get_rank() == 0 and (
(global_step > 0 and global_step % print_batch_step == 0) or
@@ -354,7 +355,8 @@ def train(config,
# logger metric
if log_writer is not None:
- log_writer.log_metrics(metrics=cur_metric, prefix="EVAL", step=global_step)
+ log_writer.log_metrics(
+ metrics=cur_metric, prefix="EVAL", step=global_step)
if cur_metric[main_indicator] >= best_model_dict[
main_indicator]:
@@ -377,11 +379,18 @@ def train(config,
logger.info(best_str)
# logger best metric
if log_writer is not None:
- log_writer.log_metrics(metrics={
- "best_{}".format(main_indicator): best_model_dict[main_indicator]
- }, prefix="EVAL", step=global_step)
-
- log_writer.log_model(is_best=True, prefix="best_accuracy", metadata=best_model_dict)
+ log_writer.log_metrics(
+ metrics={
+ "best_{}".format(main_indicator):
+ best_model_dict[main_indicator]
+ },
+ prefix="EVAL",
+ step=global_step)
+
+ log_writer.log_model(
+ is_best=True,
+ prefix="best_accuracy",
+ metadata=best_model_dict)
reader_start = time.time()
if dist.get_rank() == 0:
@@ -413,7 +422,8 @@ def train(config,
epoch=epoch,
global_step=global_step)
if log_writer is not None:
- log_writer.log_model(is_best=False, prefix='iter_epoch_{}'.format(epoch))
+ log_writer.log_model(
+ is_best=False, prefix='iter_epoch_{}'.format(epoch))
best_str = 'best metric, {}'.format(', '.join(
['{}: {}'.format(k, v) for k, v in best_model_dict.items()]))
@@ -564,7 +574,8 @@ def preprocess(is_train=False):
assert alg in [
'EAST', 'DB', 'SAST', 'Rosetta', 'CRNN', 'STARNet', 'RARE', 'SRN',
'CLS', 'PGNet', 'Distillation', 'NRTR', 'TableAttn', 'SAR', 'PSE',
- 'SEED', 'SDMGR', 'LayoutXLM', 'LayoutLM', 'PREN', 'FCE', 'SVTR'
+ 'SEED', 'SDMGR', 'LayoutXLM', 'LayoutLM', 'PREN', 'FCE', 'SVTR',
+ 'ViTSTR'
]
if use_xpu:
@@ -585,7 +596,8 @@ def preprocess(is_train=False):
vdl_writer_path = '{}/vdl/'.format(save_model_dir)
log_writer = VDLLogger(save_model_dir)
loggers.append(log_writer)
- if ('use_wandb' in config['Global'] and config['Global']['use_wandb']) or 'wandb' in config:
+ if ('use_wandb' in config['Global'] and
+ config['Global']['use_wandb']) or 'wandb' in config:
save_dir = config['Global']['save_model_dir']
wandb_writer_path = "{}/wandb".format(save_dir)
if "wandb" in config: