diff --git a/configs/rec/rec_r45_visionlan.yml b/configs/rec/rec_r45_visionlan.yml
new file mode 100644
index 0000000000000000000000000000000000000000..25017653a37941af07cfc3cfa092e17309c966b9
--- /dev/null
+++ b/configs/rec/rec_r45_visionlan.yml
@@ -0,0 +1,106 @@
+Global:
+ use_gpu: true
+ epoch_num: 8
+ log_smooth_window: 200
+ print_batch_step: 200
+ save_model_dir: ./output/rec/r45_visionlan
+ save_epoch_step: 1
+ # evaluation is run every 2000 iterations
+ eval_batch_step: [0, 2000]
+ cal_metric_during_train: True
+ pretrained_model:
+ checkpoints:
+ save_inference_dir:
+ use_visualdl: True
+ infer_img: doc/imgs_words/en/word_2.png
+ # for data or label process
+ character_dict_path:
+ max_text_length: &max_text_length 25
+ training_step: &training_step LA
+ infer_mode: False
+ use_space_char: False
+ save_res_path: ./output/rec/predicts_visionlan.txt
+
+Optimizer:
+ name: Adam
+ beta1: 0.9
+ beta2: 0.999
+ clip_norm: 20.0
+ group_lr: true
+ training_step: *training_step
+ lr:
+ name: Piecewise
+ decay_epochs: [6]
+ values: [0.0001, 0.00001]
+ regularizer:
+ name: 'L2'
+ factor: 0
+
+Architecture:
+ model_type: rec
+ algorithm: VisionLAN
+ Transform:
+ Backbone:
+ name: ResNet45
+ strides: [2, 2, 2, 1, 1]
+ Head:
+ name: VLHead
+ n_layers: 3
+ n_position: 256
+ n_dim: 512
+ max_text_length: *max_text_length
+ training_step: *training_step
+
+Loss:
+ name: VLLoss
+ mode: *training_step
+ weight_res: 0.5
+ weight_mas: 0.5
+
+PostProcess:
+ name: VLLabelDecode
+
+Metric:
+ name: RecMetric
+ is_filter: true
+
+
+Train:
+ dataset:
+ name: LMDBDataSet
+ data_dir: ./train_data/data_lmdb_release/training/
+ transforms:
+ - DecodeImage: # load image
+ img_mode: RGB
+ channel_first: False
+ - ABINetRecAug:
+ - VLLabelEncode: # Class handling label
+ - VLRecResizeImg:
+ image_shape: [3, 64, 256]
+ - KeepKeys:
+ keep_keys: ['image', 'label', 'label_res', 'label_sub', 'label_id', 'length'] # dataloader will return list in this order
+ loader:
+ shuffle: True
+ batch_size_per_card: 220
+ drop_last: True
+ num_workers: 4
+
+Eval:
+ dataset:
+ name: LMDBDataSet
+ data_dir: ./train_data/data_lmdb_release/validation/
+ transforms:
+ - DecodeImage: # load image
+ img_mode: RGB
+ channel_first: False
+ - VLLabelEncode: # Class handling label
+ - VLRecResizeImg:
+ image_shape: [3, 64, 256]
+ - KeepKeys:
+ keep_keys: ['image', 'label', 'label_res', 'label_sub', 'label_id', 'length'] # dataloader will return list in this order
+ loader:
+ shuffle: False
+ drop_last: False
+ batch_size_per_card: 64
+ num_workers: 4
+
diff --git a/doc/doc_ch/algorithm_overview.md b/doc/doc_ch/algorithm_overview.md
index fbd3ce9ebccec0b1c2133b52e4aeb9d4d5e21114..9d725a86ab8f48051fdb36fe20e94fbe88abc2f6 100755
--- a/doc/doc_ch/algorithm_overview.md
+++ b/doc/doc_ch/algorithm_overview.md
@@ -69,6 +69,7 @@
- [x] [SVTR](./algorithm_rec_svtr.md)
- [x] [ViTSTR](./algorithm_rec_vitstr.md)
- [x] [ABINet](./algorithm_rec_abinet.md)
+- [x] [VisionLAN](./algorithm_rec_visionlan.md)
- [x] [SPIN](./algorithm_rec_spin.md)
参考[DTRB](https://arxiv.org/abs/1904.01906)[3]文字识别训练和评估流程,使用MJSynth和SynthText两个文字识别数据集训练,在IIIT, SVT, IC03, IC13, IC15, SVTP, CUTE数据集上进行评估,算法效果如下:
@@ -90,6 +91,7 @@
|SVTR|SVTR-Tiny| 89.25% | rec_svtr_tiny_none_ctc_en | [训练模型](https://paddleocr.bj.bcebos.com/PP-OCRv3/chinese/rec_svtr_tiny_none_ctc_en_train.tar) |
|ViTSTR|ViTSTR| 79.82% | rec_vitstr_none_ce | [训练模型](https://paddleocr.bj.bcebos.com/rec_vitstr_none_ce_train.tar) |
|ABINet|Resnet45| 90.75% | rec_r45_abinet | [训练模型](https://paddleocr.bj.bcebos.com/rec_r45_abinet_train.tar) |
+|VisionLAN|Resnet45| 90.30% | rec_r45_visionlan | [训练模型](https://paddleocr.bj.bcebos.com/rec_r45_visionlan_train.tar) |
|SPIN|ResNet32| 90.00% | rec_r32_gaspin_bilstm_att | coming soon |
diff --git a/doc/doc_ch/algorithm_rec_visionlan.md b/doc/doc_ch/algorithm_rec_visionlan.md
new file mode 100644
index 0000000000000000000000000000000000000000..0c4fe86e58831f4f5480483f5c21ff1da4176d2b
--- /dev/null
+++ b/doc/doc_ch/algorithm_rec_visionlan.md
@@ -0,0 +1,154 @@
+# 场景文本识别算法-VisionLAN
+
+- [1. 算法简介](#1)
+- [2. 环境配置](#2)
+- [3. 模型训练、评估、预测](#3)
+ - [3.1 训练](#3-1)
+ - [3.2 评估](#3-2)
+ - [3.3 预测](#3-3)
+- [4. 推理部署](#4)
+ - [4.1 Python推理](#4-1)
+ - [4.2 C++推理](#4-2)
+ - [4.3 Serving服务化部署](#4-3)
+ - [4.4 更多推理部署](#4-4)
+- [5. FAQ](#5)
+
+
+## 1. 算法简介
+
+论文信息:
+> [From Two to One: A New Scene Text Recognizer with Visual Language Modeling Network](https://arxiv.org/abs/2108.09661)
+> Yuxin Wang, Hongtao Xie, Shancheng Fang, Jing Wang, Shenggao Zhu, Yongdong Zhang
+> ICCV, 2021
+
+
+
+`VisionLAN`使用MJSynth和SynthText两个文字识别数据集训练,在IIIT, SVT, IC13, IC15, SVTP, CUTE数据集上进行评估,算法复现效果如下:
+
+|模型|骨干网络|配置文件|Acc|下载链接|
+| --- | --- | --- | --- | --- |
+|VisionLAN|ResNet45|[rec_r45_visionlan.yml](../../configs/rec/rec_r45_visionlan.yml)|90.3%|[预训练、训练模型](https://paddleocr.bj.bcebos.com/rec_r45_visionlan_train.tar)|
+
+
+## 2. 环境配置
+请先参考[《运行环境准备》](./environment.md)配置PaddleOCR运行环境,参考[《项目克隆》](./clone.md)克隆项目代码。
+
+
+
+## 3. 模型训练、评估、预测
+
+
+### 3.1 模型训练
+
+请参考[文本识别训练教程](./recognition.md)。PaddleOCR对代码进行了模块化,训练`VisionLAN`识别模型时需要**更换配置文件**为`VisionLAN`的[配置文件](../../configs/rec/rec_r45_visionlan.yml)。
+
+#### 启动训练
+
+
+具体地,在完成数据准备后,便可以启动训练,训练命令如下:
+```shell
+#单卡训练(训练周期长,不建议)
+python3 tools/train.py -c configs/rec/rec_r45_visionlan.yml
+
+#多卡训练,通过--gpus参数指定卡号
+python3 -m paddle.distributed.launch --gpus '0,1,2,3' tools/train.py -c configs/rec/rec_r45_visionlan.yml
+```
+
+
+### 3.2 评估
+
+可下载已训练完成的[模型文件](#model),使用如下命令进行评估:
+
+```shell
+# 注意将pretrained_model的路径设置为本地路径。
+python3 tools/eval.py -c configs/rec/rec_r45_visionlan.yml -o Global.pretrained_model=./rec_r45_visionlan_train/best_accuracy
+```
+
+
+### 3.3 预测
+
+使用如下命令进行单张图片预测:
+```shell
+# 注意将pretrained_model的路径设置为本地路径。
+python3 tools/infer_rec.py -c configs/rec/rec_r45_visionlan.yml -o Global.infer_img='./doc/imgs_words/en/word_2.png' Global.pretrained_model=./rec_r45_visionlan_train/best_accuracy
+# 预测文件夹下所有图像时,可修改infer_img为文件夹,如 Global.infer_img='./doc/imgs_words_en/'。
+```
+
+
+
+## 4. 推理部署
+
+
+### 4.1 Python推理
+首先将训练得到best模型,转换成inference model。这里以训练完成的模型为例([模型下载地址](https://paddleocr.bj.bcebos.com/rec_r45_visionlan_train.tar)),可以使用如下命令进行转换:
+
+```shell
+# 注意将pretrained_model的路径设置为本地路径。
+python3 tools/export_model.py -c configs/rec/rec_r45_visionlan.yml -o Global.pretrained_model=./rec_r45_visionlan_train/best_accuracy Global.save_inference_dir=./inference/rec_r45_visionlan/
+```
+**注意:**
+- 如果您是在自己的数据集上训练的模型,并且调整了字典文件,请注意修改配置文件中的`character_dict_path`是否是所需要的字典文件。
+- 如果您修改了训练时的输入大小,请修改`tools/export_model.py`文件中的对应VisionLAN的`infer_shape`。
+
+转换成功后,在目录下有三个文件:
+```
+./inference/rec_r45_visionlan/
+ ├── inference.pdiparams # 识别inference模型的参数文件
+ ├── inference.pdiparams.info # 识别inference模型的参数信息,可忽略
+ └── inference.pdmodel # 识别inference模型的program文件
+```
+
+执行如下命令进行模型推理:
+
+```shell
+python3 tools/infer/predict_rec.py --image_dir='./doc/imgs_words/en/word_2.png' --rec_model_dir='./inference/rec_r45_visionlan/' --rec_algorithm='VisionLAN' --rec_image_shape='3,64,256' --rec_char_dict_path='./ppocr/utils/dict36.txt'
+# 预测文件夹下所有图像时,可修改image_dir为文件夹,如 --image_dir='./doc/imgs_words_en/'。
+```
+
+![](../imgs_words/en/word_2.png)
+
+执行命令后,上面图像的预测结果(识别的文本和得分)会打印到屏幕上,示例如下:
+结果如下:
+```shell
+Predicts of ./doc/imgs_words/en/word_2.png:('yourself', 0.97076982)
+```
+
+**注意**:
+
+- 训练上述模型采用的图像分辨率是[3,64,256],需要通过参数`rec_image_shape`设置为您训练时的识别图像形状。
+- 在推理时需要设置参数`rec_char_dict_path`指定字典,如果您修改了字典,请修改该参数为您的字典文件。
+- 如果您修改了预处理方法,需修改`tools/infer/predict_rec.py`中VisionLAN的预处理为您的预处理方法。
+
+
+
+### 4.2 C++推理部署
+
+由于C++预处理后处理还未支持VisionLAN,所以暂未支持
+
+
+### 4.3 Serving服务化部署
+
+暂不支持
+
+
+### 4.4 更多推理部署
+
+暂不支持
+
+
+## 5. FAQ
+
+1. MJSynth和SynthText两种数据集来自于[VisionLAN源repo](https://github.com/wangyuxin87/VisionLAN) 。
+2. 我们使用VisionLAN作者提供的预训练模型进行finetune训练。
+
+## 引用
+
+```bibtex
+@inproceedings{wang2021two,
+ title={From Two to One: A New Scene Text Recognizer with Visual Language Modeling Network},
+ author={Wang, Yuxin and Xie, Hongtao and Fang, Shancheng and Wang, Jing and Zhu, Shenggao and Zhang, Yongdong},
+ booktitle={Proceedings of the IEEE/CVF International Conference on Computer Vision},
+ pages={14194--14203},
+ year={2021}
+}
+```
diff --git a/doc/doc_en/algorithm_overview_en.md b/doc/doc_en/algorithm_overview_en.md
index a579d2447c52067e05d16af5e9d6cf50defc2b1c..dfd8ecda5c306aeb41902caccc2b6079f4f86542 100755
--- a/doc/doc_en/algorithm_overview_en.md
+++ b/doc/doc_en/algorithm_overview_en.md
@@ -68,6 +68,7 @@ Supported text recognition algorithms (Click the link to get the tutorial):
- [x] [SVTR](./algorithm_rec_svtr_en.md)
- [x] [ViTSTR](./algorithm_rec_vitstr_en.md)
- [x] [ABINet](./algorithm_rec_abinet_en.md)
+- [x] [VisionLAN](./algorithm_rec_visionlan_en.md)
- [x] [SPIN](./algorithm_rec_spin_en.md)
Refer to [DTRB](https://arxiv.org/abs/1904.01906), the training and evaluation result of these above text recognition (using MJSynth and SynthText for training, evaluate on IIIT, SVT, IC03, IC13, IC15, SVTP, CUTE) is as follow:
@@ -89,6 +90,7 @@ Refer to [DTRB](https://arxiv.org/abs/1904.01906), the training and evaluation r
|SVTR|SVTR-Tiny| 89.25% | rec_svtr_tiny_none_ctc_en | [trained model](https://paddleocr.bj.bcebos.com/PP-OCRv3/chinese/rec_svtr_tiny_none_ctc_en_train.tar) |
|ViTSTR|ViTSTR| 79.82% | rec_vitstr_none_ce | [trained model](https://paddleocr.bj.bcebos.com/rec_vitstr_none_none_train.tar) |
|ABINet|Resnet45| 90.75% | rec_r45_abinet | [trained model](https://paddleocr.bj.bcebos.com/rec_r45_abinet_train.tar) |
+|VisionLAN|Resnet45| 90.30% | rec_r45_visionlan | [trained model](https://paddleocr.bj.bcebos.com/rec_r45_visionlan_train.tar) |
|SPIN|ResNet32| 90.00% | rec_r32_gaspin_bilstm_att | coming soon |
diff --git a/doc/doc_en/algorithm_rec_visionlan_en.md b/doc/doc_en/algorithm_rec_visionlan_en.md
new file mode 100644
index 0000000000000000000000000000000000000000..ebd02d52f4252c672b4a76c940ccdd621f5354ef
--- /dev/null
+++ b/doc/doc_en/algorithm_rec_visionlan_en.md
@@ -0,0 +1,135 @@
+# VisionLAN
+
+- [1. Introduction](#1)
+- [2. Environment](#2)
+- [3. Model Training / Evaluation / Prediction](#3)
+ - [3.1 Training](#3-1)
+ - [3.2 Evaluation](#3-2)
+ - [3.3 Prediction](#3-3)
+- [4. Inference and Deployment](#4)
+ - [4.1 Python Inference](#4-1)
+ - [4.2 C++ Inference](#4-2)
+ - [4.3 Serving](#4-3)
+ - [4.4 More](#4-4)
+- [5. FAQ](#5)
+
+
+## 1. Introduction
+
+Paper:
+> [From Two to One: A New Scene Text Recognizer with Visual Language Modeling Network](https://arxiv.org/abs/2108.09661)
+> Yuxin Wang, Hongtao Xie, Shancheng Fang, Jing Wang, Shenggao Zhu, Yongdong Zhang
+> ICCV, 2021
+
+Using MJSynth and SynthText two text recognition datasets for training, and evaluating on IIIT, SVT, IC13, IC15, SVTP, CUTE datasets, the algorithm reproduction effect is as follows:
+
+|Model|Backbone|config|Acc|Download link|
+| --- | --- | --- | --- | --- |
+|VisionLAN|ResNet45|[rec_r45_visionlan.yml](../../configs/rec/rec_r45_visionlan.yml)|90.3%|[预训练、训练模型](https://paddleocr.bj.bcebos.com/rec_r45_visionlan_train.tar)|
+
+
+## 2. Environment
+Please refer to ["Environment Preparation"](./environment_en.md) to configure the PaddleOCR environment, and refer to ["Project Clone"](./clone_en.md) to clone the project code.
+
+
+
+## 3. Model Training / Evaluation / Prediction
+
+Please refer to [Text Recognition Tutorial](./recognition_en.md). PaddleOCR modularizes the code, and training different recognition models only requires **changing the configuration file**.
+
+Training:
+
+Specifically, after the data preparation is completed, the training can be started. The training command is as follows:
+
+```
+#Single GPU training (long training period, not recommended)
+python3 tools/train.py -c configs/rec/rec_r45_visionlan.yml
+
+#Multi GPU training, specify the gpu number through the --gpus parameter
+python3 -m paddle.distributed.launch --gpus '0,1,2,3' tools/train.py -c configs/rec/rec_r45_visionlan.yml
+```
+
+Evaluation:
+
+```
+# GPU evaluation
+python3 tools/eval.py -c configs/rec/rec_r45_visionlan.yml -o Global.pretrained_model={path/to/weights}/best_accuracy
+```
+
+Prediction:
+
+```
+# The configuration file used for prediction must match the training
+python3 tools/infer_rec.py -c configs/rec/rec_r45_visionlan.yml -o Global.infer_img='./doc/imgs_words/en/word_2.png' Global.pretrained_model=./rec_r45_visionlan_train/best_accuracy
+```
+
+
+## 4. Inference and Deployment
+
+
+### 4.1 Python Inference
+First, the model saved during the VisionLAN text recognition training process is converted into an inference model. ( [Model download link](https://paddleocr.bj.bcebos.com/rec_r45_visionlan_train.tar)) ), you can use the following command to convert:
+
+```
+python3 tools/export_model.py -c configs/rec/rec_r45_visionlan.yml -o Global.pretrained_model=./rec_r45_visionlan_train/best_accuracy Global.save_inference_dir=./inference/rec_r45_visionlan/
+```
+
+**Note:**
+- If you are training the model on your own dataset and have modified the dictionary file, please pay attention to modify the `character_dict_path` in the configuration file to the modified dictionary file.
+- If you modified the input size during training, please modify the `infer_shape` corresponding to VisionLAN in the `tools/export_model.py` file.
+
+After the conversion is successful, there are three files in the directory:
+```
+./inference/rec_r45_visionlan/
+ ├── inference.pdiparams
+ ├── inference.pdiparams.info
+ └── inference.pdmodel
+```
+
+
+For VisionLAN text recognition model inference, the following commands can be executed:
+
+```
+python3 tools/infer/predict_rec.py --image_dir='./doc/imgs_words/en/word_2.png' --rec_model_dir='./inference/rec_r45_visionlan/' --rec_algorithm='VisionLAN' --rec_image_shape='3,64,256' --rec_char_dict_path='./ppocr/utils/dict36.txt'
+```
+
+![](../imgs_words/en/word_2.png)
+
+After executing the command, the prediction result (recognized text and score) of the image above is printed to the screen, an example is as follows:
+The result is as follows:
+```shell
+Predicts of ./doc/imgs_words/en/word_2.png:('yourself', 0.97076982)
+```
+
+
+### 4.2 C++ Inference
+
+Not supported
+
+
+### 4.3 Serving
+
+Not supported
+
+
+### 4.4 More
+
+Not supported
+
+
+## 5. FAQ
+
+1. Note that the MJSynth and SynthText datasets come from [VisionLAN repo](https://github.com/wangyuxin87/VisionLAN).
+2. We use the pre-trained model provided by the VisionLAN authors for finetune training.
+
+## Citation
+
+```bibtex
+@inproceedings{wang2021two,
+ title={From Two to One: A New Scene Text Recognizer with Visual Language Modeling Network},
+ author={Wang, Yuxin and Xie, Hongtao and Fang, Shancheng and Wang, Jing and Zhu, Shenggao and Zhang, Yongdong},
+ booktitle={Proceedings of the IEEE/CVF International Conference on Computer Vision},
+ pages={14194--14203},
+ year={2021}
+}
+```
diff --git a/ppocr/data/imaug/__init__.py b/ppocr/data/imaug/__init__.py
index d41eed9dfbd2980242e76fa8d8aae380a6594cd4..a2332b6c07be63ecfe2fa9003cbe9d0c1b0e8001 100644
--- a/ppocr/data/imaug/__init__.py
+++ b/ppocr/data/imaug/__init__.py
@@ -25,8 +25,9 @@ from .make_pse_gt import MakePseGt
from .rec_img_aug import BaseDataAugmentation, RecAug, RecConAug, RecResizeImg, ClsResizeImg, \
- SRNRecResizeImg, GrayRecResizeImg, SARRecResizeImg, PRENResizeImg, \
- ABINetRecResizeImg, SVTRRecResizeImg, ABINetRecAug, SPINRecResizeImg
+ SRNRecResizeImg, GrayRecResizeImg, SARRecResizeImg, PRENResizeImg, \
+ ABINetRecResizeImg, SVTRRecResizeImg, ABINetRecAug, VLRecResizeImg, SPINRecResizeImg
+
from .ssl_img_aug import SSLRotateResize
from .randaugment import RandAugment
from .copy_paste import CopyPaste
diff --git a/ppocr/data/imaug/label_ops.py b/ppocr/data/imaug/label_ops.py
index ce539dcea9608762f725e5a3ae501e384360d04d..03314dde3a8b5d52373f1fc1d74411e126c304cb 100644
--- a/ppocr/data/imaug/label_ops.py
+++ b/ppocr/data/imaug/label_ops.py
@@ -23,6 +23,8 @@ import string
from shapely.geometry import LineString, Point, Polygon
import json
import copy
+from random import sample
+
from ppocr.utils.logging import get_logger
from ppocr.data.imaug.vqa.augment import order_by_tbyx
@@ -98,12 +100,13 @@ class BaseRecLabelEncode(object):
def __init__(self,
max_text_length,
character_dict_path=None,
- use_space_char=False):
+ use_space_char=False,
+ lower=False):
self.max_text_len = max_text_length
self.beg_str = "sos"
self.end_str = "eos"
- self.lower = False
+ self.lower = lower
if character_dict_path is None:
logger = get_logger()
@@ -1273,3 +1276,67 @@ class SPINLabelEncode(AttnLabelEncode):
padded_text[:len(target)] = target
data['label'] = np.array(padded_text)
return data
+
+
+class VLLabelEncode(BaseRecLabelEncode):
+ """ Convert between text-label and text-index """
+
+ def __init__(self,
+ max_text_length,
+ character_dict_path=None,
+ use_space_char=False,
+ lower=True,
+ **kwargs):
+ super(VLLabelEncode, self).__init__(
+ max_text_length, character_dict_path, use_space_char, lower)
+ self.character = self.character[10:] + self.character[
+ 1:10] + [self.character[0]]
+ self.dict = {}
+ for i, char in enumerate(self.character):
+ self.dict[char] = i
+
+ def __call__(self, data):
+ text = data['label'] # original string
+ # generate occluded text
+ len_str = len(text)
+ if len_str <= 0:
+ return None
+ change_num = 1
+ order = list(range(len_str))
+ change_id = sample(order, change_num)[0]
+ label_sub = text[change_id]
+ if change_id == (len_str - 1):
+ label_res = text[:change_id]
+ elif change_id == 0:
+ label_res = text[1:]
+ else:
+ label_res = text[:change_id] + text[change_id + 1:]
+
+ data['label_res'] = label_res # remaining string
+ data['label_sub'] = label_sub # occluded character
+ data['label_id'] = change_id # character index
+ # encode label
+ text = self.encode(text)
+ if text is None:
+ return None
+ text = [i + 1 for i in text]
+ data['length'] = np.array(len(text))
+ text = text + [0] * (self.max_text_len - len(text))
+ data['label'] = np.array(text)
+ label_res = self.encode(label_res)
+ label_sub = self.encode(label_sub)
+ if label_res is None:
+ label_res = []
+ else:
+ label_res = [i + 1 for i in label_res]
+ if label_sub is None:
+ label_sub = []
+ else:
+ label_sub = [i + 1 for i in label_sub]
+ data['length_res'] = np.array(len(label_res))
+ data['length_sub'] = np.array(len(label_sub))
+ label_res = label_res + [0] * (self.max_text_len - len(label_res))
+ label_sub = label_sub + [0] * (self.max_text_len - len(label_sub))
+ data['label_res'] = np.array(label_res)
+ data['label_sub'] = np.array(label_sub)
+ return data
diff --git a/ppocr/data/imaug/rec_img_aug.py b/ppocr/data/imaug/rec_img_aug.py
index c5d8a3b2fd773a1877a788401a926d7fbca07adf..725b4b0617c2f0808c7bf99077e2f62caa3afbf0 100644
--- a/ppocr/data/imaug/rec_img_aug.py
+++ b/ppocr/data/imaug/rec_img_aug.py
@@ -205,6 +205,38 @@ class RecResizeImg(object):
return data
+class VLRecResizeImg(object):
+ def __init__(self,
+ image_shape,
+ infer_mode=False,
+ character_dict_path='./ppocr/utils/ppocr_keys_v1.txt',
+ padding=True,
+ **kwargs):
+ self.image_shape = image_shape
+ self.infer_mode = infer_mode
+ self.character_dict_path = character_dict_path
+ self.padding = padding
+
+ def __call__(self, data):
+ img = data['image']
+
+ imgC, imgH, imgW = self.image_shape
+ resized_image = cv2.resize(
+ img, (imgW, imgH), interpolation=cv2.INTER_LINEAR)
+ resized_w = imgW
+ resized_image = resized_image.astype('float32')
+ if self.image_shape[0] == 1:
+ resized_image = resized_image / 255
+ norm_img = resized_image[np.newaxis, :]
+ else:
+ norm_img = resized_image.transpose((2, 0, 1)) / 255
+ valid_ratio = min(1.0, float(resized_w / imgW))
+
+ data['image'] = norm_img
+ data['valid_ratio'] = valid_ratio
+ return data
+
+
class SRNRecResizeImg(object):
def __init__(self, image_shape, num_heads, max_text_length, **kwargs):
self.image_shape = image_shape
@@ -259,6 +291,7 @@ class PRENResizeImg(object):
data['image'] = resized_img.astype(np.float32)
return data
+
class SPINRecResizeImg(object):
def __init__(self,
image_shape,
@@ -267,7 +300,7 @@ class SPINRecResizeImg(object):
std=(127.5, 127.5, 127.5),
**kwargs):
self.image_shape = image_shape
-
+
self.mean = np.array(mean, dtype=np.float32)
self.std = np.array(std, dtype=np.float32)
self.interpolation = interpolation
@@ -303,6 +336,7 @@ class SPINRecResizeImg(object):
data['image'] = img
return data
+
class GrayRecResizeImg(object):
def __init__(self,
image_shape,
diff --git a/ppocr/losses/__init__.py b/ppocr/losses/__init__.py
index 4629f0fe4478b783f5d9f4a7c41a626c413678bc..8f3adfccd46b7cedd3141e1cfce5baba621c8676 100755
--- a/ppocr/losses/__init__.py
+++ b/ppocr/losses/__init__.py
@@ -35,6 +35,7 @@ from .rec_sar_loss import SARLoss
from .rec_aster_loss import AsterLoss
from .rec_pren_loss import PRENLoss
from .rec_multi_loss import MultiLoss
+from .rec_vl_loss import VLLoss
from .rec_spin_att_loss import SPINAttentionLoss
# cls loss
@@ -63,7 +64,7 @@ def build_loss(config):
'ClsLoss', 'AttentionLoss', 'SRNLoss', 'PGLoss', 'CombinedLoss',
'CELoss', 'TableAttentionLoss', 'SARLoss', 'AsterLoss', 'SDMGRLoss',
'VQASerTokenLayoutLMLoss', 'LossFromOutput', 'PRENLoss', 'MultiLoss',
- 'TableMasterLoss', 'SPINAttentionLoss', 'SLANetLoss'
+ 'TableMasterLoss', 'SPINAttentionLoss', 'VLLoss', 'SLANetLoss'
]
config = copy.deepcopy(config)
module_name = config.pop('name')
diff --git a/ppocr/losses/rec_vl_loss.py b/ppocr/losses/rec_vl_loss.py
new file mode 100644
index 0000000000000000000000000000000000000000..5cd87c709bbf81d2fd83d721d49086256f2ab629
--- /dev/null
+++ b/ppocr/losses/rec_vl_loss.py
@@ -0,0 +1,70 @@
+# copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+This code is refer from:
+https://github.com/wangyuxin87/VisionLAN
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import paddle
+from paddle import nn
+
+
+class VLLoss(nn.Layer):
+ def __init__(self, mode='LF_1', weight_res=0.5, weight_mas=0.5, **kwargs):
+ super(VLLoss, self).__init__()
+ self.loss_func = paddle.nn.loss.CrossEntropyLoss(reduction="mean")
+ assert mode in ['LF_1', 'LF_2', 'LA']
+ self.mode = mode
+ self.weight_res = weight_res
+ self.weight_mas = weight_mas
+
+ def flatten_label(self, target):
+ label_flatten = []
+ label_length = []
+ for i in range(0, target.shape[0]):
+ cur_label = target[i].tolist()
+ label_flatten += cur_label[:cur_label.index(0) + 1]
+ label_length.append(cur_label.index(0) + 1)
+ label_flatten = paddle.to_tensor(label_flatten, dtype='int64')
+ label_length = paddle.to_tensor(label_length, dtype='int32')
+ return (label_flatten, label_length)
+
+ def _flatten(self, sources, lengths):
+ return paddle.concat([t[:l] for t, l in zip(sources, lengths)])
+
+ def forward(self, predicts, batch):
+ text_pre = predicts[0]
+ target = batch[1].astype('int64')
+ label_flatten, length = self.flatten_label(target)
+ text_pre = self._flatten(text_pre, length)
+ if self.mode == 'LF_1':
+ loss = self.loss_func(text_pre, label_flatten)
+ else:
+ text_rem = predicts[1]
+ text_mas = predicts[2]
+ target_res = batch[2].astype('int64')
+ target_sub = batch[3].astype('int64')
+ label_flatten_res, length_res = self.flatten_label(target_res)
+ label_flatten_sub, length_sub = self.flatten_label(target_sub)
+ text_rem = self._flatten(text_rem, length_res)
+ text_mas = self._flatten(text_mas, length_sub)
+ loss_ori = self.loss_func(text_pre, label_flatten)
+ loss_res = self.loss_func(text_rem, label_flatten_res)
+ loss_mas = self.loss_func(text_mas, label_flatten_sub)
+ loss = loss_ori + loss_res * self.weight_res + loss_mas * self.weight_mas
+ return {'loss': loss}
diff --git a/ppocr/modeling/backbones/rec_resnet_45.py b/ppocr/modeling/backbones/rec_resnet_45.py
index 9093d0bc99b78806d36662dec36b6cfbdd4ae493..083eb7f48811cf6887845f98bbeae315b727287d 100644
--- a/ppocr/modeling/backbones/rec_resnet_45.py
+++ b/ppocr/modeling/backbones/rec_resnet_45.py
@@ -84,11 +84,15 @@ class BasicBlock(nn.Layer):
class ResNet45(nn.Layer):
- def __init__(self, block=BasicBlock, layers=[3, 4, 6, 6, 3], in_channels=3):
+ def __init__(self,
+ in_channels=3,
+ block=BasicBlock,
+ layers=[3, 4, 6, 6, 3],
+ strides=[2, 1, 2, 1, 1]):
self.inplanes = 32
super(ResNet45, self).__init__()
self.conv1 = nn.Conv2D(
- 3,
+ in_channels,
32,
kernel_size=3,
stride=1,
@@ -98,18 +102,13 @@ class ResNet45(nn.Layer):
self.bn1 = nn.BatchNorm2D(32)
self.relu = nn.ReLU()
- self.layer1 = self._make_layer(block, 32, layers[0], stride=2)
- self.layer2 = self._make_layer(block, 64, layers[1], stride=1)
- self.layer3 = self._make_layer(block, 128, layers[2], stride=2)
- self.layer4 = self._make_layer(block, 256, layers[3], stride=1)
- self.layer5 = self._make_layer(block, 512, layers[4], stride=1)
+ self.layer1 = self._make_layer(block, 32, layers[0], stride=strides[0])
+ self.layer2 = self._make_layer(block, 64, layers[1], stride=strides[1])
+ self.layer3 = self._make_layer(block, 128, layers[2], stride=strides[2])
+ self.layer4 = self._make_layer(block, 256, layers[3], stride=strides[3])
+ self.layer5 = self._make_layer(block, 512, layers[4], stride=strides[4])
self.out_channels = 512
- # for m in self.modules():
- # if isinstance(m, nn.Conv2D):
- # n = m._kernel_size[0] * m._kernel_size[1] * m._out_channels
- # m.weight.data.normal_(0, math.sqrt(2. / n))
-
def _make_layer(self, block, planes, blocks, stride=1):
downsample = None
if stride != 1 or self.inplanes != planes * block.expansion:
@@ -137,11 +136,9 @@ class ResNet45(nn.Layer):
x = self.conv1(x)
x = self.bn1(x)
x = self.relu(x)
- # print(x)
x = self.layer1(x)
x = self.layer2(x)
x = self.layer3(x)
- # print(x)
x = self.layer4(x)
x = self.layer5(x)
return x
diff --git a/ppocr/modeling/backbones/rec_resnet_aster.py b/ppocr/modeling/backbones/rec_resnet_aster.py
index 6a2710dfa079b4d910146c10ca2cff31321b2513..782dc393ea3c8b67d68fb9f4b038afc85ffcad93 100644
--- a/ppocr/modeling/backbones/rec_resnet_aster.py
+++ b/ppocr/modeling/backbones/rec_resnet_aster.py
@@ -140,4 +140,4 @@ class ResNet_ASTER(nn.Layer):
rnn_feat, _ = self.rnn(cnn_feat)
return rnn_feat
else:
- return cnn_feat
+ return cnn_feat
\ No newline at end of file
diff --git a/ppocr/modeling/heads/__init__.py b/ppocr/modeling/heads/__init__.py
index d8289d458d50f476b74b3a75e58795bdb2385a6c..3f6ff0c4e0240ff4f241f475e70dc6211106a659 100755
--- a/ppocr/modeling/heads/__init__.py
+++ b/ppocr/modeling/heads/__init__.py
@@ -35,6 +35,7 @@ def build_head(config):
from .rec_multi_head import MultiHead
from .rec_spin_att_head import SPINAttentionHead
from .rec_abinet_head import ABINetHead
+ from .rec_visionlan_head import VLHead
# cls head
from .cls_head import ClsHead
@@ -50,7 +51,7 @@ def build_head(config):
'ClsHead', 'AttentionHead', 'SRNHead', 'PGHead', 'Transformer',
'TableAttentionHead', 'SARHead', 'AsterHead', 'SDMGRHead', 'PRENHead',
'MultiHead', 'ABINetHead', 'TableMasterHead', 'SPINAttentionHead',
- 'SLAHead'
+ 'VLHead', 'SLAHead'
]
#table head
diff --git a/ppocr/modeling/heads/rec_visionlan_head.py b/ppocr/modeling/heads/rec_visionlan_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..86054d9bbb12613e3119b4c0d72f4670344d773a
--- /dev/null
+++ b/ppocr/modeling/heads/rec_visionlan_head.py
@@ -0,0 +1,468 @@
+# copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+This code is refer from:
+https://github.com/wangyuxin87/VisionLAN
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import paddle
+from paddle import ParamAttr
+import paddle.nn as nn
+import paddle.nn.functional as F
+from paddle.nn.initializer import Normal, XavierNormal
+import numpy as np
+
+
+class PositionalEncoding(nn.Layer):
+ def __init__(self, d_hid, n_position=200):
+ super(PositionalEncoding, self).__init__()
+ self.register_buffer(
+ 'pos_table', self._get_sinusoid_encoding_table(n_position, d_hid))
+
+ def _get_sinusoid_encoding_table(self, n_position, d_hid):
+ ''' Sinusoid position encoding table '''
+
+ def get_position_angle_vec(position):
+ return [
+ position / np.power(10000, 2 * (hid_j // 2) / d_hid)
+ for hid_j in range(d_hid)
+ ]
+
+ sinusoid_table = np.array(
+ [get_position_angle_vec(pos_i) for pos_i in range(n_position)])
+ sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2]) # dim 2i
+ sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2]) # dim 2i+1
+ sinusoid_table = paddle.to_tensor(sinusoid_table, dtype='float32')
+ sinusoid_table = paddle.unsqueeze(sinusoid_table, axis=0)
+ return sinusoid_table
+
+ def forward(self, x):
+ return x + self.pos_table[:, :x.shape[1]].clone().detach()
+
+
+class ScaledDotProductAttention(nn.Layer):
+ "Scaled Dot-Product Attention"
+
+ def __init__(self, temperature, attn_dropout=0.1):
+ super(ScaledDotProductAttention, self).__init__()
+ self.temperature = temperature
+ self.dropout = nn.Dropout(attn_dropout)
+ self.softmax = nn.Softmax(axis=2)
+
+ def forward(self, q, k, v, mask=None):
+ k = paddle.transpose(k, perm=[0, 2, 1])
+ attn = paddle.bmm(q, k)
+ attn = attn / self.temperature
+ if mask is not None:
+ attn = attn.masked_fill(mask, -1e9)
+ if mask.dim() == 3:
+ mask = paddle.unsqueeze(mask, axis=1)
+ elif mask.dim() == 2:
+ mask = paddle.unsqueeze(mask, axis=1)
+ mask = paddle.unsqueeze(mask, axis=1)
+ repeat_times = [
+ attn.shape[1] // mask.shape[1], attn.shape[2] // mask.shape[2]
+ ]
+ mask = paddle.tile(mask, [1, repeat_times[0], repeat_times[1], 1])
+ attn[mask == 0] = -1e9
+ attn = self.softmax(attn)
+ attn = self.dropout(attn)
+ output = paddle.bmm(attn, v)
+ return output
+
+
+class MultiHeadAttention(nn.Layer):
+ " Multi-Head Attention module"
+
+ def __init__(self, n_head, d_model, d_k, d_v, dropout=0.1):
+ super(MultiHeadAttention, self).__init__()
+ self.n_head = n_head
+ self.d_k = d_k
+ self.d_v = d_v
+ self.w_qs = nn.Linear(
+ d_model,
+ n_head * d_k,
+ weight_attr=ParamAttr(initializer=Normal(
+ mean=0, std=np.sqrt(2.0 / (d_model + d_k)))))
+ self.w_ks = nn.Linear(
+ d_model,
+ n_head * d_k,
+ weight_attr=ParamAttr(initializer=Normal(
+ mean=0, std=np.sqrt(2.0 / (d_model + d_k)))))
+ self.w_vs = nn.Linear(
+ d_model,
+ n_head * d_v,
+ weight_attr=ParamAttr(initializer=Normal(
+ mean=0, std=np.sqrt(2.0 / (d_model + d_v)))))
+
+ self.attention = ScaledDotProductAttention(temperature=np.power(d_k,
+ 0.5))
+ self.layer_norm = nn.LayerNorm(d_model)
+ self.fc = nn.Linear(
+ n_head * d_v,
+ d_model,
+ weight_attr=ParamAttr(initializer=XavierNormal()))
+ self.dropout = nn.Dropout(dropout)
+
+ def forward(self, q, k, v, mask=None):
+ d_k, d_v, n_head = self.d_k, self.d_v, self.n_head
+ sz_b, len_q, _ = q.shape
+ sz_b, len_k, _ = k.shape
+ sz_b, len_v, _ = v.shape
+ residual = q
+
+ q = self.w_qs(q)
+ q = paddle.reshape(
+ q, shape=[-1, len_q, n_head, d_k]) # 4*21*512 ---- 4*21*8*64
+ k = self.w_ks(k)
+ k = paddle.reshape(k, shape=[-1, len_k, n_head, d_k])
+ v = self.w_vs(v)
+ v = paddle.reshape(v, shape=[-1, len_v, n_head, d_v])
+
+ q = paddle.transpose(q, perm=[2, 0, 1, 3])
+ q = paddle.reshape(q, shape=[-1, len_q, d_k]) # (n*b) x lq x dk
+ k = paddle.transpose(k, perm=[2, 0, 1, 3])
+ k = paddle.reshape(k, shape=[-1, len_k, d_k]) # (n*b) x lk x dk
+ v = paddle.transpose(v, perm=[2, 0, 1, 3])
+ v = paddle.reshape(v, shape=[-1, len_v, d_v]) # (n*b) x lv x dv
+
+ mask = paddle.tile(
+ mask,
+ [n_head, 1, 1]) if mask is not None else None # (n*b) x .. x ..
+ output = self.attention(q, k, v, mask=mask)
+ output = paddle.reshape(output, shape=[n_head, -1, len_q, d_v])
+ output = paddle.transpose(output, perm=[1, 2, 0, 3])
+ output = paddle.reshape(
+ output, shape=[-1, len_q, n_head * d_v]) # b x lq x (n*dv)
+ output = self.dropout(self.fc(output))
+ output = self.layer_norm(output + residual)
+ return output
+
+
+class PositionwiseFeedForward(nn.Layer):
+ def __init__(self, d_in, d_hid, dropout=0.1):
+ super(PositionwiseFeedForward, self).__init__()
+ self.w_1 = nn.Conv1D(d_in, d_hid, 1) # position-wise
+ self.w_2 = nn.Conv1D(d_hid, d_in, 1) # position-wise
+ self.layer_norm = nn.LayerNorm(d_in)
+ self.dropout = nn.Dropout(dropout)
+
+ def forward(self, x):
+ residual = x
+ x = paddle.transpose(x, perm=[0, 2, 1])
+ x = self.w_2(F.relu(self.w_1(x)))
+ x = paddle.transpose(x, perm=[0, 2, 1])
+ x = self.dropout(x)
+ x = self.layer_norm(x + residual)
+ return x
+
+
+class EncoderLayer(nn.Layer):
+ ''' Compose with two layers '''
+
+ def __init__(self, d_model, d_inner, n_head, d_k, d_v, dropout=0.1):
+ super(EncoderLayer, self).__init__()
+ self.slf_attn = MultiHeadAttention(
+ n_head, d_model, d_k, d_v, dropout=dropout)
+ self.pos_ffn = PositionwiseFeedForward(
+ d_model, d_inner, dropout=dropout)
+
+ def forward(self, enc_input, slf_attn_mask=None):
+ enc_output = self.slf_attn(
+ enc_input, enc_input, enc_input, mask=slf_attn_mask)
+ enc_output = self.pos_ffn(enc_output)
+ return enc_output
+
+
+class Transformer_Encoder(nn.Layer):
+ def __init__(self,
+ n_layers=2,
+ n_head=8,
+ d_word_vec=512,
+ d_k=64,
+ d_v=64,
+ d_model=512,
+ d_inner=2048,
+ dropout=0.1,
+ n_position=256):
+ super(Transformer_Encoder, self).__init__()
+ self.position_enc = PositionalEncoding(
+ d_word_vec, n_position=n_position)
+ self.dropout = nn.Dropout(p=dropout)
+ self.layer_stack = nn.LayerList([
+ EncoderLayer(
+ d_model, d_inner, n_head, d_k, d_v, dropout=dropout)
+ for _ in range(n_layers)
+ ])
+ self.layer_norm = nn.LayerNorm(d_model, epsilon=1e-6)
+
+ def forward(self, enc_output, src_mask, return_attns=False):
+ enc_output = self.dropout(
+ self.position_enc(enc_output)) # position embeding
+ for enc_layer in self.layer_stack:
+ enc_output = enc_layer(enc_output, slf_attn_mask=src_mask)
+ enc_output = self.layer_norm(enc_output)
+ return enc_output
+
+
+class PP_layer(nn.Layer):
+ def __init__(self, n_dim=512, N_max_character=25, n_position=256):
+
+ super(PP_layer, self).__init__()
+ self.character_len = N_max_character
+ self.f0_embedding = nn.Embedding(N_max_character, n_dim)
+ self.w0 = nn.Linear(N_max_character, n_position)
+ self.wv = nn.Linear(n_dim, n_dim)
+ self.we = nn.Linear(n_dim, N_max_character)
+ self.active = nn.Tanh()
+ self.softmax = nn.Softmax(axis=2)
+
+ def forward(self, enc_output):
+ # enc_output: b,256,512
+ reading_order = paddle.arange(self.character_len, dtype='int64')
+ reading_order = reading_order.unsqueeze(0).expand(
+ [enc_output.shape[0], self.character_len]) # (S,) -> (B, S)
+ reading_order = self.f0_embedding(reading_order) # b,25,512
+
+ # calculate attention
+ reading_order = paddle.transpose(reading_order, perm=[0, 2, 1])
+ t = self.w0(reading_order) # b,512,256
+ t = self.active(
+ paddle.transpose(
+ t, perm=[0, 2, 1]) + self.wv(enc_output)) # b,256,512
+ t = self.we(t) # b,256,25
+ t = self.softmax(paddle.transpose(t, perm=[0, 2, 1])) # b,25,256
+ g_output = paddle.bmm(t, enc_output) # b,25,512
+ return g_output
+
+
+class Prediction(nn.Layer):
+ def __init__(self,
+ n_dim=512,
+ n_position=256,
+ N_max_character=25,
+ n_class=37):
+ super(Prediction, self).__init__()
+ self.pp = PP_layer(
+ n_dim=n_dim, N_max_character=N_max_character, n_position=n_position)
+ self.pp_share = PP_layer(
+ n_dim=n_dim, N_max_character=N_max_character, n_position=n_position)
+ self.w_vrm = nn.Linear(n_dim, n_class) # output layer
+ self.w_share = nn.Linear(n_dim, n_class) # output layer
+ self.nclass = n_class
+
+ def forward(self, cnn_feature, f_res, f_sub, train_mode=False,
+ use_mlm=True):
+ if train_mode:
+ if not use_mlm:
+ g_output = self.pp(cnn_feature) # b,25,512
+ g_output = self.w_vrm(g_output)
+ f_res = 0
+ f_sub = 0
+ return g_output, f_res, f_sub
+ g_output = self.pp(cnn_feature) # b,25,512
+ f_res = self.pp_share(f_res)
+ f_sub = self.pp_share(f_sub)
+ g_output = self.w_vrm(g_output)
+ f_res = self.w_share(f_res)
+ f_sub = self.w_share(f_sub)
+ return g_output, f_res, f_sub
+ else:
+ g_output = self.pp(cnn_feature) # b,25,512
+ g_output = self.w_vrm(g_output)
+ return g_output
+
+
+class MLM(nn.Layer):
+ "Architecture of MLM"
+
+ def __init__(self, n_dim=512, n_position=256, max_text_length=25):
+ super(MLM, self).__init__()
+ self.MLM_SequenceModeling_mask = Transformer_Encoder(
+ n_layers=2, n_position=n_position)
+ self.MLM_SequenceModeling_WCL = Transformer_Encoder(
+ n_layers=1, n_position=n_position)
+ self.pos_embedding = nn.Embedding(max_text_length, n_dim)
+ self.w0_linear = nn.Linear(1, n_position)
+ self.wv = nn.Linear(n_dim, n_dim)
+ self.active = nn.Tanh()
+ self.we = nn.Linear(n_dim, 1)
+ self.sigmoid = nn.Sigmoid()
+
+ def forward(self, x, label_pos):
+ # transformer unit for generating mask_c
+ feature_v_seq = self.MLM_SequenceModeling_mask(x, src_mask=None)
+ # position embedding layer
+ label_pos = paddle.to_tensor(label_pos, dtype='int64')
+ pos_emb = self.pos_embedding(label_pos)
+ pos_emb = self.w0_linear(paddle.unsqueeze(pos_emb, axis=2))
+ pos_emb = paddle.transpose(pos_emb, perm=[0, 2, 1])
+ # fusion position embedding with features V & generate mask_c
+ att_map_sub = self.active(pos_emb + self.wv(feature_v_seq))
+ att_map_sub = self.we(att_map_sub) # b,256,1
+ att_map_sub = paddle.transpose(att_map_sub, perm=[0, 2, 1])
+ att_map_sub = self.sigmoid(att_map_sub) # b,1,256
+ # WCL
+ ## generate inputs for WCL
+ att_map_sub = paddle.transpose(att_map_sub, perm=[0, 2, 1])
+ f_res = x * (1 - att_map_sub) # second path with remaining string
+ f_sub = x * att_map_sub # first path with occluded character
+ ## transformer units in WCL
+ f_res = self.MLM_SequenceModeling_WCL(f_res, src_mask=None)
+ f_sub = self.MLM_SequenceModeling_WCL(f_sub, src_mask=None)
+ return f_res, f_sub, att_map_sub
+
+
+def trans_1d_2d(x):
+ b, w_h, c = x.shape # b, 256, 512
+ x = paddle.transpose(x, perm=[0, 2, 1])
+ x = paddle.reshape(x, [-1, c, 32, 8])
+ x = paddle.transpose(x, perm=[0, 1, 3, 2]) # [b, c, 8, 32]
+ return x
+
+
+class MLM_VRM(nn.Layer):
+ """
+ MLM+VRM, MLM is only used in training.
+ ratio controls the occluded number in a batch.
+ The pipeline of VisionLAN in testing is very concise with only a backbone + sequence modeling(transformer unit) + prediction layer(pp layer).
+ x: input image
+ label_pos: character index
+ training_step: LF or LA process
+ output
+ text_pre: prediction of VRM
+ test_rem: prediction of remaining string in MLM
+ text_mas: prediction of occluded character in MLM
+ mask_c_show: visualization of Mask_c
+ """
+
+ def __init__(self,
+ n_layers=3,
+ n_position=256,
+ n_dim=512,
+ max_text_length=25,
+ nclass=37):
+ super(MLM_VRM, self).__init__()
+ self.MLM = MLM(n_dim=n_dim,
+ n_position=n_position,
+ max_text_length=max_text_length)
+ self.SequenceModeling = Transformer_Encoder(
+ n_layers=n_layers, n_position=n_position)
+ self.Prediction = Prediction(
+ n_dim=n_dim,
+ n_position=n_position,
+ N_max_character=max_text_length +
+ 1, # N_max_character = 1 eos + 25 characters
+ n_class=nclass)
+ self.nclass = nclass
+ self.max_text_length = max_text_length
+
+ def forward(self, x, label_pos, training_step, train_mode=False):
+ b, c, h, w = x.shape
+ nT = self.max_text_length
+ x = paddle.transpose(x, perm=[0, 1, 3, 2])
+ x = paddle.reshape(x, [-1, c, h * w])
+ x = paddle.transpose(x, perm=[0, 2, 1])
+ if train_mode:
+ if training_step == 'LF_1':
+ f_res = 0
+ f_sub = 0
+ x = self.SequenceModeling(x, src_mask=None)
+ text_pre, test_rem, text_mas = self.Prediction(
+ x, f_res, f_sub, train_mode=True, use_mlm=False)
+ return text_pre, text_pre, text_pre, text_pre
+ elif training_step == 'LF_2':
+ # MLM
+ f_res, f_sub, mask_c = self.MLM(x, label_pos)
+ x = self.SequenceModeling(x, src_mask=None)
+ text_pre, test_rem, text_mas = self.Prediction(
+ x, f_res, f_sub, train_mode=True)
+ mask_c_show = trans_1d_2d(mask_c)
+ return text_pre, test_rem, text_mas, mask_c_show
+ elif training_step == 'LA':
+ # MLM
+ f_res, f_sub, mask_c = self.MLM(x, label_pos)
+ ## use the mask_c (1 for occluded character and 0 for remaining characters) to occlude input
+ ## ratio controls the occluded number in a batch
+ character_mask = paddle.zeros_like(mask_c)
+
+ ratio = b // 2
+ if ratio >= 1:
+ with paddle.no_grad():
+ character_mask[0:ratio, :, :] = mask_c[0:ratio, :, :]
+ else:
+ character_mask = mask_c
+ x = x * (1 - character_mask)
+ # VRM
+ ## transformer unit for VRM
+ x = self.SequenceModeling(x, src_mask=None)
+ ## prediction layer for MLM and VSR
+ text_pre, test_rem, text_mas = self.Prediction(
+ x, f_res, f_sub, train_mode=True)
+ mask_c_show = trans_1d_2d(mask_c)
+ return text_pre, test_rem, text_mas, mask_c_show
+ else:
+ raise NotImplementedError
+ else: # VRM is only used in the testing stage
+ f_res = 0
+ f_sub = 0
+ contextual_feature = self.SequenceModeling(x, src_mask=None)
+ text_pre = self.Prediction(
+ contextual_feature,
+ f_res,
+ f_sub,
+ train_mode=False,
+ use_mlm=False)
+ text_pre = paddle.transpose(
+ text_pre, perm=[1, 0, 2]) # (26, b, 37))
+ return text_pre, x
+
+
+class VLHead(nn.Layer):
+ """
+ Architecture of VisionLAN
+ """
+
+ def __init__(self,
+ in_channels,
+ out_channels=36,
+ n_layers=3,
+ n_position=256,
+ n_dim=512,
+ max_text_length=25,
+ training_step='LA'):
+ super(VLHead, self).__init__()
+ self.MLM_VRM = MLM_VRM(
+ n_layers=n_layers,
+ n_position=n_position,
+ n_dim=n_dim,
+ max_text_length=max_text_length,
+ nclass=out_channels + 1)
+ self.training_step = training_step
+
+ def forward(self, feat, targets=None):
+
+ if self.training:
+ label_pos = targets[-2]
+ text_pre, test_rem, text_mas, mask_map = self.MLM_VRM(
+ feat, label_pos, self.training_step, train_mode=True)
+ return text_pre, test_rem, text_mas, mask_map
+ else:
+ text_pre, x = self.MLM_VRM(
+ feat, targets, self.training_step, train_mode=False)
+ return text_pre, x
diff --git a/ppocr/optimizer/optimizer.py b/ppocr/optimizer/optimizer.py
index dd8544e2e7d39be33a9096cad16c4d58eb58bcad..144f011c79ec2303b7fbc73ac078afe3ce92c255 100644
--- a/ppocr/optimizer/optimizer.py
+++ b/ppocr/optimizer/optimizer.py
@@ -77,11 +77,62 @@ class Adam(object):
self.grad_clip = grad_clip
self.name = name
self.lazy_mode = lazy_mode
+ self.group_lr = kwargs.get('group_lr', False)
+ self.training_step = kwargs.get('training_step', None)
def __call__(self, model):
- train_params = [
- param for param in model.parameters() if param.trainable is True
- ]
+ if self.group_lr:
+ if self.training_step == 'LF_2':
+ import paddle
+ if isinstance(model, paddle.fluid.dygraph.parallel.
+ DataParallel): # multi gpu
+ mlm = model._layers.head.MLM_VRM.MLM.parameters()
+ pre_mlm_pp = model._layers.head.MLM_VRM.Prediction.pp_share.parameters(
+ )
+ pre_mlm_w = model._layers.head.MLM_VRM.Prediction.w_share.parameters(
+ )
+ else: # single gpu
+ mlm = model.head.MLM_VRM.MLM.parameters()
+ pre_mlm_pp = model.head.MLM_VRM.Prediction.pp_share.parameters(
+ )
+ pre_mlm_w = model.head.MLM_VRM.Prediction.w_share.parameters(
+ )
+
+ total = []
+ for param in mlm:
+ total.append(id(param))
+ for param in pre_mlm_pp:
+ total.append(id(param))
+ for param in pre_mlm_w:
+ total.append(id(param))
+
+ group_base_params = [
+ param for param in model.parameters() if id(param) in total
+ ]
+ group_small_params = [
+ param for param in model.parameters()
+ if id(param) not in total
+ ]
+ train_params = [{
+ 'params': group_base_params
+ }, {
+ 'params': group_small_params,
+ 'learning_rate': self.learning_rate.values[0] * 0.1
+ }]
+
+ else:
+ print(
+ 'group lr currently only support VisionLAN in LF_2 training step'
+ )
+ train_params = [
+ param for param in model.parameters()
+ if param.trainable is True
+ ]
+ else:
+ train_params = [
+ param for param in model.parameters() if param.trainable is True
+ ]
+
opt = optim.Adam(
learning_rate=self.learning_rate,
beta1=self.beta1,
diff --git a/ppocr/postprocess/__init__.py b/ppocr/postprocess/__init__.py
index 6fa871a45cdce9f5fc308f7e54f8980d852ebc8c..7c0c7fd003a38966a24fd116d8cfd3805aed6797 100644
--- a/ppocr/postprocess/__init__.py
+++ b/ppocr/postprocess/__init__.py
@@ -28,7 +28,7 @@ from .fce_postprocess import FCEPostProcess
from .rec_postprocess import CTCLabelDecode, AttnLabelDecode, SRNLabelDecode, \
DistillationCTCLabelDecode, NRTRLabelDecode, SARLabelDecode, \
SEEDLabelDecode, PRENLabelDecode, ViTSTRLabelDecode, ABINetLabelDecode, \
- SPINLabelDecode
+ SPINLabelDecode, VLLabelDecode
from .cls_postprocess import ClsPostProcess
from .pg_postprocess import PGPostProcess
from .vqa_token_ser_layoutlm_postprocess import VQASerTokenLayoutLMPostProcess, DistillationSerPostProcess
@@ -38,31 +38,16 @@ from .table_postprocess import TableMasterLabelDecode, TableLabelDecode
def build_post_process(config, global_config=None):
support_dict = [
- 'DBPostProcess',
- 'EASTPostProcess',
- 'SASTPostProcess',
- 'FCEPostProcess',
- 'CTCLabelDecode',
- 'AttnLabelDecode',
- 'ClsPostProcess',
- 'SRNLabelDecode',
- 'PGPostProcess',
- 'DistillationCTCLabelDecode',
- 'TableLabelDecode',
- 'DistillationDBPostProcess',
- 'NRTRLabelDecode',
- 'SARLabelDecode',
- 'SEEDLabelDecode',
- 'VQASerTokenLayoutLMPostProcess',
- 'VQAReTokenLayoutLMPostProcess',
- 'PRENLabelDecode',
- 'DistillationSARLabelDecode',
- 'ViTSTRLabelDecode',
- 'ABINetLabelDecode',
- 'TableMasterLabelDecode',
- 'SPINLabelDecode',
- 'DistillationSerPostProcess',
- 'DistillationRePostProcess',
+ 'DBPostProcess', 'EASTPostProcess', 'SASTPostProcess', 'FCEPostProcess',
+ 'CTCLabelDecode', 'AttnLabelDecode', 'ClsPostProcess', 'SRNLabelDecode',
+ 'PGPostProcess', 'DistillationCTCLabelDecode', 'TableLabelDecode',
+ 'DistillationDBPostProcess', 'NRTRLabelDecode', 'SARLabelDecode',
+ 'SEEDLabelDecode', 'VQASerTokenLayoutLMPostProcess',
+ 'VQAReTokenLayoutLMPostProcess', 'PRENLabelDecode',
+ 'DistillationSARLabelDecode', 'ViTSTRLabelDecode', 'ABINetLabelDecode',
+ 'TableMasterLabelDecode', 'SPINLabelDecode',
+ 'DistillationSerPostProcess', 'DistillationRePostProcess',
+ 'VLLabelDecode'
]
if config['name'] == 'PSEPostProcess':
diff --git a/ppocr/postprocess/rec_postprocess.py b/ppocr/postprocess/rec_postprocess.py
index 3fe29aabe58f42faa02d1b25b4255ba8a19b3ea3..7b994f810d6747a91aceec82641f433d816b3feb 100644
--- a/ppocr/postprocess/rec_postprocess.py
+++ b/ppocr/postprocess/rec_postprocess.py
@@ -668,6 +668,7 @@ class ABINetLabelDecode(NRTRLabelDecode):
dict_character = [''] + dict_character
return dict_character
+
class SPINLabelDecode(AttnLabelDecode):
""" Convert between text-label and text-index """
@@ -681,4 +682,106 @@ class SPINLabelDecode(AttnLabelDecode):
self.end_str = "eos"
dict_character = dict_character
dict_character = [self.beg_str] + [self.end_str] + dict_character
- return dict_character
\ No newline at end of file
+ return dict_character
+
+
+class VLLabelDecode(BaseRecLabelDecode):
+ """ Convert between text-label and text-index """
+
+ def __init__(self, character_dict_path=None, use_space_char=False,
+ **kwargs):
+ super(VLLabelDecode, self).__init__(character_dict_path, use_space_char)
+ self.max_text_length = kwargs.get('max_text_length', 25)
+ self.nclass = len(self.character) + 1
+ self.character = self.character[10:] + self.character[
+ 1:10] + [self.character[0]]
+
+ def decode(self, text_index, text_prob=None, is_remove_duplicate=False):
+ """ convert text-index into text-label. """
+ result_list = []
+ ignored_tokens = self.get_ignored_tokens()
+ batch_size = len(text_index)
+ for batch_idx in range(batch_size):
+ selection = np.ones(len(text_index[batch_idx]), dtype=bool)
+ if is_remove_duplicate:
+ selection[1:] = text_index[batch_idx][1:] != text_index[
+ batch_idx][:-1]
+ for ignored_token in ignored_tokens:
+ selection &= text_index[batch_idx] != ignored_token
+
+ char_list = [
+ self.character[text_id - 1]
+ for text_id in text_index[batch_idx][selection]
+ ]
+ if text_prob is not None:
+ conf_list = text_prob[batch_idx][selection]
+ else:
+ conf_list = [1] * len(selection)
+ if len(conf_list) == 0:
+ conf_list = [0]
+
+ text = ''.join(char_list)
+ result_list.append((text, np.mean(conf_list).tolist()))
+ return result_list
+
+ def __call__(self, preds, label=None, length=None, *args, **kwargs):
+ if len(preds) == 2: # eval mode
+ text_pre, x = preds
+ b = text_pre.shape[1]
+ lenText = self.max_text_length
+ nsteps = self.max_text_length
+
+ if not isinstance(text_pre, paddle.Tensor):
+ text_pre = paddle.to_tensor(text_pre, dtype='float32')
+
+ out_res = paddle.zeros(
+ shape=[lenText, b, self.nclass], dtype=x.dtype)
+ out_length = paddle.zeros(shape=[b], dtype=x.dtype)
+ now_step = 0
+ for _ in range(nsteps):
+ if 0 in out_length and now_step < nsteps:
+ tmp_result = text_pre[now_step, :, :]
+ out_res[now_step] = tmp_result
+ tmp_result = tmp_result.topk(1)[1].squeeze(axis=1)
+ for j in range(b):
+ if out_length[j] == 0 and tmp_result[j] == 0:
+ out_length[j] = now_step + 1
+ now_step += 1
+ for j in range(0, b):
+ if int(out_length[j]) == 0:
+ out_length[j] = nsteps
+ start = 0
+ output = paddle.zeros(
+ shape=[int(out_length.sum()), self.nclass], dtype=x.dtype)
+ for i in range(0, b):
+ cur_length = int(out_length[i])
+ output[start:start + cur_length] = out_res[0:cur_length, i, :]
+ start += cur_length
+ net_out = output
+ length = out_length
+
+ else: # train mode
+ net_out = preds[0]
+ length = length
+ net_out = paddle.concat([t[:l] for t, l in zip(net_out, length)])
+ text = []
+ if not isinstance(net_out, paddle.Tensor):
+ net_out = paddle.to_tensor(net_out, dtype='float32')
+ net_out = F.softmax(net_out, axis=1)
+ for i in range(0, length.shape[0]):
+ preds_idx = net_out[int(length[:i].sum()):int(length[:i].sum(
+ ) + length[i])].topk(1)[1][:, 0].tolist()
+ preds_text = ''.join([
+ self.character[idx - 1]
+ if idx > 0 and idx <= len(self.character) else ''
+ for idx in preds_idx
+ ])
+ preds_prob = net_out[int(length[:i].sum()):int(length[:i].sum(
+ ) + length[i])].topk(1)[0][:, 0]
+ preds_prob = paddle.exp(
+ paddle.log(preds_prob).sum() / (preds_prob.shape[0] + 1e-6))
+ text.append((preds_text, preds_prob))
+ if label is None:
+ return text
+ label = self.decode(label)
+ return text, label
diff --git a/ppocr/utils/save_load.py b/ppocr/utils/save_load.py
index 8fded687c62e8de9ff126037ec2a9fd88db9590d..e77a6ce0183611569193e1996e935f4bd30400a0 100644
--- a/ppocr/utils/save_load.py
+++ b/ppocr/utils/save_load.py
@@ -53,6 +53,7 @@ def load_model(config, model, optimizer=None, model_type='det'):
checkpoints = global_config.get('checkpoints')
pretrained_model = global_config.get('pretrained_model')
best_model_dict = {}
+ is_float16 = False
if model_type == 'vqa':
# NOTE: for vqa model, resume training is not supported now
@@ -100,6 +101,9 @@ def load_model(config, model, optimizer=None, model_type='det'):
key, params.keys()))
continue
pre_value = params[key]
+ if pre_value.dtype == paddle.float16:
+ pre_value = pre_value.astype(paddle.float32)
+ is_float16 = True
if list(value.shape) == list(pre_value.shape):
new_state_dict[key] = pre_value
else:
@@ -107,7 +111,10 @@ def load_model(config, model, optimizer=None, model_type='det'):
"The shape of model params {} {} not matched with loaded params shape {} !".
format(key, value.shape, pre_value.shape))
model.set_state_dict(new_state_dict)
-
+ if is_float16:
+ logger.info(
+ "The parameter type is float16, which is converted to float32 when loading"
+ )
if optimizer is not None:
if os.path.exists(checkpoints + '.pdopt'):
optim_dict = paddle.load(checkpoints + '.pdopt')
@@ -126,9 +133,10 @@ def load_model(config, model, optimizer=None, model_type='det'):
best_model_dict['start_epoch'] = states_dict['epoch'] + 1
logger.info("resume from {}".format(checkpoints))
elif pretrained_model:
- load_pretrained_params(model, pretrained_model)
+ is_float16 = load_pretrained_params(model, pretrained_model)
else:
logger.info('train from scratch')
+ best_model_dict['is_float16'] = is_float16
return best_model_dict
@@ -142,19 +150,28 @@ def load_pretrained_params(model, path):
params = paddle.load(path + '.pdparams')
state_dict = model.state_dict()
new_state_dict = {}
+ is_float16 = False
for k1 in params.keys():
if k1 not in state_dict.keys():
logger.warning("The pretrained params {} not in model".format(k1))
else:
+ if params[k1].dtype == paddle.float16:
+ params[k1] = params[k1].astype(paddle.float32)
+ is_float16 = True
if list(state_dict[k1].shape) == list(params[k1].shape):
new_state_dict[k1] = params[k1]
else:
logger.warning(
"The shape of model params {} {} not matched with loaded params {} {} !".
format(k1, state_dict[k1].shape, k1, params[k1].shape))
+
model.set_state_dict(new_state_dict)
+ if is_float16:
+ logger.info(
+ "The parameter type is float16, which is converted to float32 when loading"
+ )
logger.info("load pretrain successful from {}".format(path))
- return model
+ return is_float16
def save_model(model,
diff --git a/test_tipc/configs/det_r50_dcn_fce_ctw_v2_0/det_r50_vd_dcn_fce_ctw.yml b/test_tipc/configs/det_r50_dcn_fce_ctw_v2_0/det_r50_vd_dcn_fce_ctw.yml
index 3a513b8f38cd5abf800c86f8fbeda789cb3d056a..29f6f32a58739e181d0c0f54d62021e3754a324a 100644
--- a/test_tipc/configs/det_r50_dcn_fce_ctw_v2_0/det_r50_vd_dcn_fce_ctw.yml
+++ b/test_tipc/configs/det_r50_dcn_fce_ctw_v2_0/det_r50_vd_dcn_fce_ctw.yml
@@ -8,7 +8,7 @@ Global:
# evaluation is run every 835 iterations
eval_batch_step: [0, 4000]
cal_metric_during_train: False
- pretrained_model: ./pretrain_models/ResNet50_vd_ssld_pretrained
+ pretrained_model: pretrain_models/det_r50_dcn_fce_ctw_v2.0_train/best_accuracy.pdparams
checkpoints:
save_inference_dir:
use_visualdl: False
diff --git a/test_tipc/configs/layoutxlm_ser/train_infer_python.txt b/test_tipc/configs/layoutxlm_ser/train_infer_python.txt
index 887c3285eccf59c1833eead48893a807ded12fee..34082bc193a2ebd8f4c7a9e7c9ce55dc8dbf8e40 100644
--- a/test_tipc/configs/layoutxlm_ser/train_infer_python.txt
+++ b/test_tipc/configs/layoutxlm_ser/train_infer_python.txt
@@ -6,14 +6,14 @@ Global.use_gpu:True|True
Global.auto_cast:fp32
Global.epoch_num:lite_train_lite_infer=1|whole_train_whole_infer=17
Global.save_model_dir:./output/
-Train.loader.batch_size_per_card:lite_train_lite_infer=8|whole_train_whole_infer=8
+Train.loader.batch_size_per_card:lite_train_lite_infer=4|whole_train_whole_infer=8
Architecture.Backbone.checkpoints:null
train_model_name:latest
train_infer_img_dir:ppstructure/docs/vqa/input/zh_val_42.jpg
null:null
##
trainer:norm_train
-norm_train:tools/train.py -c configs/vqa/ser/layoutxlm_xfund_zh.yml -o Global.print_batch_step=1 Global.eval_batch_step=[1000,1000] Train.loader.shuffle=false
+norm_train:tools/train.py -c configs/kie/layoutlm_series/ser_layoutlm_xfund_zh.yml -o Global.print_batch_step=1 Global.eval_batch_step=[1000,1000] Train.loader.shuffle=false
pact_train:null
fpgm_train:null
distill_train:null
@@ -27,7 +27,7 @@ null:null
===========================infer_params===========================
Global.save_inference_dir:./output/
Architecture.Backbone.checkpoints:
-norm_export:tools/export_model.py -c configs/vqa/ser/layoutxlm_xfund_zh.yml -o
+norm_export:tools/export_model.py -c configs/kie/layoutlm_series/ser_layoutlm_xfund_zh.yml -o
quant_export:
fpgm_export:
distill_export:null
diff --git a/test_tipc/prepare.sh b/test_tipc/prepare.sh
index 931024382ee48637b09c22c2f20297a5591c13ad..76543f39e4952b40368cdd392acc430dda8fcd9b 100644
--- a/test_tipc/prepare.sh
+++ b/test_tipc/prepare.sh
@@ -108,7 +108,7 @@ if [ ${MODE} = "benchmark_train" ];then
fi
if [ ${model_name} == "layoutxlm_ser" ]; then
pip install -r ppstructure/vqa/requirements.txt
- pip install paddlenlp\>=2.3.5 --force-reinstall
+ pip install paddlenlp\>=2.3.5 --force-reinstall -i https://mirrors.aliyun.com/pypi/simple/
wget -nc -P ./train_data/ https://paddleocr.bj.bcebos.com/ppstructure/dataset/XFUND.tar --no-check-certificate
cd ./train_data/ && tar xf XFUND.tar
# expand gt.txt 10 times
@@ -222,7 +222,7 @@ if [ ${MODE} = "lite_train_lite_infer" ];then
fi
if [ ${model_name} == "layoutxlm_ser" ]; then
pip install -r ppstructure/vqa/requirements.txt
- pip install paddlenlp\>=2.3.5 --force-reinstall
+ pip install paddlenlp\>=2.3.5 --force-reinstall -i https://mirrors.aliyun.com/pypi/simple/
wget -nc -P ./train_data/ https://paddleocr.bj.bcebos.com/ppstructure/dataset/XFUND.tar --no-check-certificate
cd ./train_data/ && tar xf XFUND.tar
cd ../
diff --git a/tools/eval.py b/tools/eval.py
index cab28334396c54f1526f830044de0772b5402a11..2fc53488efa2c4c475d31af47f69b3560e6cc69a 100755
--- a/tools/eval.py
+++ b/tools/eval.py
@@ -73,7 +73,7 @@ def main():
config['Architecture']["Head"]['out_channels'] = char_num
model = build_model(config['Architecture'])
- extra_input_models = ["SRN", "NRTR", "SAR", "SEED", "SVTR"]
+ extra_input_models = ["SRN", "NRTR", "SAR", "SEED", "SVTR", "VisionLAN"]
extra_input = False
if config['Architecture']['algorithm'] == 'Distillation':
for key in config['Architecture']["Models"]:
diff --git a/tools/export_model.py b/tools/export_model.py
index 69ac904c661fad77255c70563fdf1f16c5c29875..78932c987d8bc57216ef3586c2bdc0cdbd6a9037 100755
--- a/tools/export_model.py
+++ b/tools/export_model.py
@@ -97,6 +97,12 @@ def export_single_model(model,
shape=[None, 1, 32, 100], dtype="float32"),
]
model = to_static(model, input_spec=other_shape)
+ elif arch_config["algorithm"] == "VisionLAN":
+ other_shape = [
+ paddle.static.InputSpec(
+ shape=[None, 3, 64, 256], dtype="float32"),
+ ]
+ model = to_static(model, input_spec=other_shape)
elif arch_config["algorithm"] in ["LayoutLM", "LayoutLMv2", "LayoutXLM"]:
input_spec = [
paddle.static.InputSpec(
@@ -217,4 +223,4 @@ def main():
if __name__ == "__main__":
- main()
+ main()
\ No newline at end of file
diff --git a/tools/infer/predict_rec.py b/tools/infer/predict_rec.py
index fdbf429be0ef2008d05c141504fcc216987112b3..4e4150c515fc2d0ee4eb7e635cb8c81a467e748f 100755
--- a/tools/infer/predict_rec.py
+++ b/tools/infer/predict_rec.py
@@ -69,6 +69,12 @@ class TextRecognizer(object):
"character_dict_path": args.rec_char_dict_path,
"use_space_char": args.use_space_char
}
+ elif self.rec_algorithm == "VisionLAN":
+ postprocess_params = {
+ 'name': 'VLLabelDecode',
+ "character_dict_path": args.rec_char_dict_path,
+ "use_space_char": args.use_space_char
+ }
elif self.rec_algorithm == 'ViTSTR':
postprocess_params = {
'name': 'ViTSTRLabelDecode',
@@ -157,6 +163,16 @@ class TextRecognizer(object):
padding_im[:, :, 0:resized_w] = resized_image
return padding_im
+ def resize_norm_img_vl(self, img, image_shape):
+
+ imgC, imgH, imgW = image_shape
+ img = img[:, :, ::-1] # bgr2rgb
+ resized_image = cv2.resize(
+ img, (imgW, imgH), interpolation=cv2.INTER_LINEAR)
+ resized_image = resized_image.astype('float32')
+ resized_image = resized_image.transpose((2, 0, 1)) / 255
+ return resized_image
+
def resize_norm_img_srn(self, img, image_shape):
imgC, imgH, imgW = image_shape
@@ -280,6 +296,7 @@ class TextRecognizer(object):
img -= mean
img *= stdinv
return img
+
def resize_norm_img_svtr(self, img, image_shape):
imgC, imgH, imgW = image_shape
@@ -359,6 +376,11 @@ class TextRecognizer(object):
self.rec_image_shape)
norm_img = norm_img[np.newaxis, :]
norm_img_batch.append(norm_img)
+ elif self.rec_algorithm == "VisionLAN":
+ norm_img = self.resize_norm_img_vl(img_list[indices[ino]],
+ self.rec_image_shape)
+ norm_img = norm_img[np.newaxis, :]
+ norm_img_batch.append(norm_img)
elif self.rec_algorithm == 'SPIN':
norm_img = self.resize_norm_img_spin(img_list[indices[ino]])
norm_img = norm_img[np.newaxis, :]
diff --git a/tools/infer_rec.py b/tools/infer_rec.py
index a08fa25b467482da4a2996912ad2cc8cc7c398da..182694e6cda12ead0e263bb94a7d6483a6f7f212 100755
--- a/tools/infer_rec.py
+++ b/tools/infer_rec.py
@@ -131,7 +131,6 @@ def main():
if config['Architecture']['algorithm'] == "SAR":
valid_ratio = np.expand_dims(batch[-1], axis=0)
img_metas = [paddle.to_tensor(valid_ratio)]
-
images = np.expand_dims(batch[0], axis=0)
images = paddle.to_tensor(images)
if config['Architecture']['algorithm'] == "SRN":
diff --git a/tools/program.py b/tools/program.py
index 1802e8529d4943993cd5ef7bff75fe5dc42d41d5..17f26003022f2f1cd158ad7d13f516e621e4bcab 100755
--- a/tools/program.py
+++ b/tools/program.py
@@ -161,7 +161,7 @@ def to_float32(preds):
if isinstance(preds[k], dict) or isinstance(preds[k], list):
preds[k] = to_float32(preds[k])
else:
- preds[k] = preds[k].astype(paddle.float32)
+ preds[k] = paddle.to_tensor(preds[k], dtype='float32')
elif isinstance(preds, list):
for k in range(len(preds)):
if isinstance(preds[k], dict):
@@ -169,9 +169,9 @@ def to_float32(preds):
elif isinstance(preds[k], list):
preds[k] = to_float32(preds[k])
else:
- preds[k] = preds[k].astype(paddle.float32)
+ preds[k] = paddle.to_tensor(preds[k], dtype='float32')
else:
- preds = preds.astype(paddle.float32)
+ preds = paddle.to_tensor(preds, dtype='float32')
return preds
@@ -227,7 +227,9 @@ def train(config,
model.train()
use_srn = config['Architecture']['algorithm'] == "SRN"
- extra_input_models = ["SRN", "NRTR", "SAR", "SEED", "SVTR", "SPIN"]
+ extra_input_models = [
+ "SRN", "NRTR", "SAR", "SEED", "SVTR", "SPIN", "VisionLAN"
+ ]
extra_input = False
if config['Architecture']['algorithm'] == 'Distillation':
for key in config['Architecture']["Models"]:
@@ -269,7 +271,6 @@ def train(config,
images = batch[0]
if use_srn:
model_average = True
-
# use amp
if scaler:
with paddle.amp.auto_cast(level='O2'):
@@ -310,6 +311,9 @@ def train(config,
]: # for multi head loss
post_result = post_process_class(
preds['ctc'], batch[1]) # for CTC head out
+ elif config['Loss']['name'] in ['VLLoss']:
+ post_result = post_process_class(preds, batch[1],
+ batch[-1])
else:
post_result = post_process_class(preds, batch[1])
eval_class(post_result, batch)
@@ -372,7 +376,8 @@ def train(config,
post_process_class,
eval_class,
model_type,
- extra_input=extra_input)
+ extra_input=extra_input,
+ scaler=scaler)
cur_metric_str = 'cur metric, {}'.format(', '.join(
['{}: {}'.format(k, v) for k, v in cur_metric.items()]))
logger.info(cur_metric_str)
@@ -462,7 +467,8 @@ def eval(model,
post_process_class,
eval_class,
model_type=None,
- extra_input=False):
+ extra_input=False,
+ scaler=None):
model.eval()
with paddle.no_grad():
total_frame = 0.0
@@ -479,12 +485,24 @@ def eval(model,
break
images = batch[0]
start = time.time()
- if model_type == 'table' or extra_input:
- preds = model(images, data=batch[1:])
- elif model_type in ["kie", 'vqa']:
- preds = model(batch)
+
+ # use amp
+ if scaler:
+ with paddle.amp.auto_cast(level='O2'):
+ if model_type == 'table' or extra_input:
+ preds = model(images, data=batch[1:])
+ elif model_type in ["kie", 'vqa']:
+ preds = model(batch)
+ else:
+ preds = model(images)
else:
- preds = model(images)
+ if model_type == 'table' or extra_input:
+ preds = model(images, data=batch[1:])
+ elif model_type in ["kie", 'vqa']:
+ preds = model(batch)
+ else:
+ preds = model(images)
+
batch_numpy = []
for item in batch:
if isinstance(item, paddle.Tensor):
@@ -598,7 +616,8 @@ def preprocess(is_train=False):
'EAST', 'DB', 'SAST', 'Rosetta', 'CRNN', 'STARNet', 'RARE', 'SRN',
'CLS', 'PGNet', 'Distillation', 'NRTR', 'TableAttn', 'SAR', 'PSE',
'SEED', 'SDMGR', 'LayoutXLM', 'LayoutLM', 'LayoutLMv2', 'PREN', 'FCE',
- 'SVTR', 'ViTSTR', 'ABINet', 'DB++', 'TableMaster', 'SPIN', 'SLANet'
+ 'SVTR', 'ViTSTR', 'ABINet', 'DB++', 'TableMaster', 'SPIN', 'VisionLAN',
+ 'SLANet'
]
if use_xpu:
@@ -617,7 +636,7 @@ def preprocess(is_train=False):
if 'use_visualdl' in config['Global'] and config['Global']['use_visualdl']:
save_model_dir = config['Global']['save_model_dir']
vdl_writer_path = '{}/vdl/'.format(save_model_dir)
- log_writer = VDLLogger(save_model_dir)
+ log_writer = VDLLogger(vdl_writer_path)
loggers.append(log_writer)
if ('use_wandb' in config['Global'] and
config['Global']['use_wandb']) or 'wandb' in config: