diff --git a/deploy/slim/prune/sensitivity_anal.py b/deploy/slim/prune/sensitivity_anal.py index 306f1a83ae0945614518514dcd00ca869254d5f8..be64a6bcf860c3e2e7a8a6fa20c4c241149a147b 100644 --- a/deploy/slim/prune/sensitivity_anal.py +++ b/deploy/slim/prune/sensitivity_anal.py @@ -94,7 +94,7 @@ def main(config, device, logger, vdl_writer): config['Optimizer'], epochs=config['Global']['epoch_num'], step_each_epoch=len(train_dataloader), - parameters=model.parameters()) + model=model) # build metric eval_class = build_metric(config['Metric']) diff --git a/doc/doc_ch/algorithm_rec_nrtr.md b/doc/doc_ch/algorithm_rec_nrtr.md new file mode 100644 index 0000000000000000000000000000000000000000..f05b8c7ba82a2b490dce9948fbe2abcaa7495c62 --- /dev/null +++ b/doc/doc_ch/algorithm_rec_nrtr.md @@ -0,0 +1,154 @@ +# 场景文本识别算法-NRTR + +- [1. 算法简介](#1) +- [2. 环境配置](#2) +- [3. 模型训练、评估、预测](#3) + - [3.1 训练](#3-1) + - [3.2 评估](#3-2) + - [3.3 预测](#3-3) +- [4. 推理部署](#4) + - [4.1 Python推理](#4-1) + - [4.2 C++推理](#4-2) + - [4.3 Serving服务化部署](#4-3) + - [4.4 更多推理部署](#4-4) +- [5. FAQ](#5) + + +## 1. 算法简介 + +论文信息: +> [NRTR: A No-Recurrence Sequence-to-Sequence Model For Scene Text Recognition](https://arxiv.org/abs/1806.00926) +> Fenfen Sheng and Zhineng Chen and Bo Xu +> ICDAR, 2019 + + + +`NRTR`使用MJSynth和SynthText两个文字识别数据集训练,在IIIT, SVT, IC03, IC13, IC15, SVTP, CUTE数据集上进行评估,算法复现效果如下: + +|模型|骨干网络|配置文件|Acc|下载链接| +| --- | --- | --- | --- | --- | +|NRTR|MTB|[rec_mtb_nrtr.yml](../../configs/rec/rec_mtb_nrtr.yml)|84.21%|[训练模型](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_mtb_nrtr_train.tar)| + + +## 2. 环境配置 +请先参考[《运行环境准备》](./environment.md)配置PaddleOCR运行环境,参考[《项目克隆》](./clone.md)克隆项目代码。 + + + +## 3. 
模型训练、评估、预测 + + +### 3.1 模型训练 + +请参考[文本识别训练教程](./recognition.md)。PaddleOCR对代码进行了模块化,训练`NRTR`识别模型时需要**更换配置文件**为`NRTR`的[配置文件](../../configs/rec/rec_mtb_nrtr.yml)。 + +#### 启动训练 + + +具体地,在完成数据准备后,便可以启动训练,训练命令如下: +```shell +#单卡训练(训练周期长,不建议) +python3 tools/train.py -c configs/rec/rec_mtb_nrtr.yml + +#多卡训练,通过--gpus参数指定卡号 +python3 -m paddle.distributed.launch --gpus '0,1,2,3' tools/train.py -c configs/rec/rec_mtb_nrtr.yml +``` + + +### 3.2 评估 + +可下载已训练完成的[模型文件](#model),使用如下命令进行评估: + +```shell +# 注意将pretrained_model的路径设置为本地路径。 +python3 tools/eval.py -c configs/rec/rec_mtb_nrtr.yml -o Global.pretrained_model=./rec_mtb_nrtr_train/best_accuracy +``` + + +### 3.3 预测 + +使用如下命令进行单张图片预测: +```shell +# 注意将pretrained_model的路径设置为本地路径。 +python3 tools/infer_rec.py -c configs/rec/rec_mtb_nrtr.yml -o Global.infer_img='./doc/imgs_words_en/word_10.png' Global.pretrained_model=./rec_mtb_nrtr_train/best_accuracy +# 预测文件夹下所有图像时,可修改infer_img为文件夹,如 Global.infer_img='./doc/imgs_words_en/'。 +``` + + + +## 4. 推理部署 + + +### 4.1 Python推理 +首先将训练得到best模型,转换成inference model。这里以训练完成的模型为例([模型下载地址](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_mtb_nrtr_train.tar) ),可以使用如下命令进行转换: + +```shell +# 注意将pretrained_model的路径设置为本地路径。 +python3 tools/export_model.py -c configs/rec/rec_mtb_nrtr.yml -o Global.pretrained_model=./rec_mtb_nrtr_train/best_accuracy Global.save_inference_dir=./inference/rec_mtb_nrtr/ +``` +**注意:** +- 如果您是在自己的数据集上训练的模型,并且调整了字典文件,请注意修改配置文件中的`character_dict_path`是否是所需要的字典文件。 +- 如果您修改了训练时的输入大小,请修改`tools/export_model.py`文件中的对应NRTR的`infer_shape`。 + +转换成功后,在目录下有三个文件: +``` +/inference/rec_mtb_nrtr/ + ├── inference.pdiparams # 识别inference模型的参数文件 + ├── inference.pdiparams.info # 识别inference模型的参数信息,可忽略 + └── inference.pdmodel # 识别inference模型的program文件 +``` + +执行如下命令进行模型推理: + +```shell +python3 tools/infer/predict_rec.py --image_dir='./doc/imgs_words_en/word_10.png' --rec_model_dir='./inference/rec_mtb_nrtr/' --rec_algorithm='NRTR' --rec_image_shape='1,32,100' 
--rec_char_dict_path='./ppocr/utils/EN_symbol_dict.txt' +# 预测文件夹下所有图像时,可修改image_dir为文件夹,如 --image_dir='./doc/imgs_words_en/'。 +``` + +![](../imgs_words_en/word_10.png) + +执行命令后,上面图像的预测结果(识别的文本和得分)会打印到屏幕上,示例如下: +结果如下: +```shell +Predicts of ./doc/imgs_words_en/word_10.png:('pain', 0.9265879392623901) +``` + +**注意**: + +- 训练上述模型采用的图像分辨率是[1,32,100],需要通过参数`rec_image_shape`设置为您训练时的识别图像形状。 +- 在推理时需要设置参数`rec_char_dict_path`指定字典,如果您修改了字典,请修改该参数为您的字典文件。 +- 如果您修改了预处理方法,需修改`tools/infer/predict_rec.py`中NRTR的预处理为您的预处理方法。 + + + +### 4.2 C++推理部署 + +由于C++预处理后处理还未支持NRTR,所以暂未支持 + + +### 4.3 Serving服务化部署 + +暂不支持 + + +### 4.4 更多推理部署 + +暂不支持 + + +## 5. FAQ + +1. `NRTR`论文中使用Beam搜索进行解码字符,但是速度较慢,这里默认未使用Beam搜索,以贪婪搜索进行解码字符。 + +## 引用 + +```bibtex +@article{Sheng2019NRTR, + author = {Fenfen Sheng and Zhineng Chen and Bo Xu}, + title = {NRTR: A No-Recurrence Sequence-to-Sequence Model For Scene Text Recognition}, + journal = {ICDAR}, + year = {2019}, + url = {http://arxiv.org/abs/1806.00926}, + pages = {781-786} +} +``` diff --git a/doc/doc_ch/algorithm_rec_svtr.md b/doc/doc_ch/algorithm_rec_svtr.md new file mode 100644 index 0000000000000000000000000000000000000000..71e3f5c5e67fd1f92356ae8306af015b96ab06b8 --- /dev/null +++ b/doc/doc_ch/algorithm_rec_svtr.md @@ -0,0 +1,161 @@ +# 场景文本识别算法-SVTR + +- [1. 算法简介](#1) +- [2. 环境配置](#2) +- [3. 模型训练、评估、预测](#3) + - [3.1 训练](#3-1) + - [3.2 评估](#3-2) + - [3.3 预测](#3-3) +- [4. 推理部署](#4) + - [4.1 Python推理](#4-1) + - [4.2 C++推理](#4-2) + - [4.3 Serving服务化部署](#4-3) + - [4.4 更多推理部署](#4-4) +- [5. FAQ](#5) + + +## 1. 
算法简介 + +论文信息: +> [SVTR: Scene Text Recognition with a Single Visual Model](https://arxiv.org/abs/2205.00159) +> Yongkun Du and Zhineng Chen and Caiyan Jia and Xiaoting Yin and Tianlun Zheng and Chenxia Li and Yuning Du and Yu-Gang Jiang +> IJCAI, 2022 + +场景文本识别旨在将自然图像中的文本转录为数字字符序列,从而传达对场景理解至关重要的高级语义。这项任务由于文本变形、字体、遮挡、杂乱背景等方面的变化具有一定的挑战性。先前的方法为提高识别精度做出了许多工作。然而文本识别器除了准确度外,还因为实际需求需要考虑推理速度等因素。 + +### SVTR算法简介 + +主流的场景文本识别模型通常包含两个模块:用于特征提取的视觉模型和用于文本转录的序列模型。这种架构虽然准确,但复杂且效率较低,限制了在实际场景中的应用。SVTR提出了一种用于场景文本识别的单视觉模型,该模型在patch-wise image tokenization框架内,完全摒弃了序列建模,在精度具有竞争力的前提下,模型参数量更少,速度更快,主要有以下几点贡献: +1. 首次发现单视觉模型可以达到与视觉语言模型相媲美甚至更高的准确率,并且其具有效率高和适应多语言的优点,在实际应用中很有前景。 +2. SVTR从字符组件的角度出发,逐渐地合并字符组件,自下而上地完成字符的识别。 +3. SVTR引入了局部和全局Mixing,分别用于提取字符组件特征和字符间依赖关系,与多尺度的特征一起,形成多粒度特征描述。 + + + +SVTR在场景文本识别公开数据集上的精度(%)和模型文件如下: + +* 中文数据集来自于[Chinese Benchmark](https://arxiv.org/abs/2112.15093) ,SVTR的中文训练评估策略遵循该论文。 + +| 模型 |IC13
857 | SVT |IIIT5k
3000 |IC15
1811| SVTP |CUTE80 | Avg_6 |IC15
2077 |IC13
1015 |IC03
867|IC03
860|Avg_10 | Chinese
scene_test| 下载链接 | +|:----------:|:------:|:-----:|:---------:|:------:|:-----:|:-----:|:-----:|:-------:|:-------:|:-----:|:-----:|:---------------------------------------------:|:-----:|:-----:| +| SVTR Tiny | 96.85 | 91.34 | 94.53 | 83.99 | 85.43 | 89.24 | 90.87 | 80.55 | 95.37 | 95.27 | 95.70 | 90.13 | 67.90 | [英文](https://paddleocr.bj.bcebos.com/PP-OCRv3/chinese/rec_svtr_tiny_none_ctc_en_train.tar) / [中文](https://paddleocr.bj.bcebos.com/PP-OCRv3/chinese/rec_svtr_tiny_none_ctc_ch_train.tar) | +| SVTR Small | 95.92 | 93.04 | 95.03 | 84.70 | 87.91 | 92.01 | 91.63 | 82.72 | 94.88 | 96.08 | 96.28 | 91.02 | 69.00 | [英文](https://paddleocr.bj.bcebos.com/PP-OCRv3/chinese/rec_svtr_small_none_ctc_en_train.tar) / [中文](https://paddleocr.bj.bcebos.com/PP-OCRv3/chinese/rec_svtr_small_none_ctc_ch_train.tar) | +| SVTR Base | 97.08 | 91.50 | 96.03 | 85.20 | 89.92 | 91.67 | 92.33 | 83.73 | 95.66 | 95.62 | 95.81 | 91.61 | 71.40 | [英文](https://paddleocr.bj.bcebos.com/PP-OCRv3/chinese/rec_svtr_base_none_ctc_en_train.tar) / - | +| SVTR Large | 97.20 | 91.65 | 96.30 | 86.58 | 88.37 | 95.14 | 92.82 | 84.54 | 96.35 | 96.54 | 96.74 | 92.24 | 72.10 | [英文](https://paddleocr.bj.bcebos.com/PP-OCRv3/chinese/rec_svtr_large_none_ctc_en_train.tar) / [中文](https://paddleocr.bj.bcebos.com/PP-OCRv3/chinese/rec_svtr_large_none_ctc_ch_train.tar) | + + + + +## 2. 环境配置 +请先参考[《运行环境准备》](./environment.md)配置PaddleOCR运行环境,参考[《项目克隆》](./clone.md)克隆项目代码。 + + + +## 3. 
模型训练、评估、预测 + + +### 3.1 模型训练 + +#### 数据集准备 + +[英文数据集下载](https://github.com/clovaai/deep-text-recognition-benchmark#download-lmdb-dataset-for-traininig-and-evaluation-from-here) +[中文数据集下载](https://github.com/fudanvi/benchmarking-chinese-text-recognition#download) + +#### 启动训练 + +请参考[文本识别训练教程](./recognition.md)。PaddleOCR对代码进行了模块化,训练`SVTR`识别模型时需要**更换配置文件**为`SVTR`的[配置文件](../../configs/rec/rec_svtrnet.yml)。 + +具体地,在完成数据准备后,便可以启动训练,训练命令如下: +```shell +#单卡训练(训练周期长,不建议) +python3 tools/train.py -c configs/rec/rec_svtrnet.yml + +#多卡训练,通过--gpus参数指定卡号 +python3 -m paddle.distributed.launch --gpus '0,1,2,3' tools/train.py -c configs/rec/rec_svtrnet.yml +``` + + +### 3.2 评估 + +可下载`SVTR`提供的模型文件和配置文件:[下载地址](https://paddleocr.bj.bcebos.com/PP-OCRv3/chinese/rec_svtr_tiny_none_ctc_en_train.tar) ,以`SVTR-T`为例,使用如下命令进行评估: + +```shell +# 注意将pretrained_model的路径设置为本地路径。 +python3 tools/eval.py -c ./rec_svtr_tiny_en_train/rec_svtr_tiny_6local_6global_stn_en.yml -o Global.pretrained_model=./rec_svtr_tiny_none_ctc_en_train/best_accuracy +``` + + +### 3.3 预测 + +使用如下命令进行单张图片预测: +```shell +# 注意将pretrained_model的路径设置为本地路径。 +python3 tools/infer_rec.py -c ./rec_svtr_tiny_en_train/rec_svtr_tiny_6local_6global_stn_en.yml -o Global.infer_img='./doc/imgs_words_en/word_10.png' Global.pretrained_model=./rec_svtr_tiny_none_ctc_en_train/best_accuracy +# 预测文件夹下所有图像时,可修改infer_img为文件夹,如 Global.infer_img='./doc/imgs_words_en/'。 +``` + + + +## 4. 
推理部署 + + +### 4.1 Python推理 +首先将训练得到的best模型,转换成inference model。下面以基于`SVTR-T`,在英文数据集训练的模型为例([模型下载地址](https://paddleocr.bj.bcebos.com/PP-OCRv3/chinese/rec_svtr_tiny_none_ctc_en_train.tar) ),可以使用如下命令进行转换: + +```shell +# 注意将pretrained_model的路径设置为本地路径。 +python3 tools/export_model.py -c ./rec_svtr_tiny_en_train/rec_svtr_tiny_6local_6global_stn_en.yml -o Global.pretrained_model=./rec_svtr_tiny_none_ctc_en_train/best_accuracy Global.save_inference_dir=./inference/rec_svtr_tiny_stn_en +``` + +**注意:** +- 如果您是在自己的数据集上训练的模型,并且调整了字典文件,请注意确认配置文件中的`character_dict_path`是否为正确的字典文件。 +- 如果您修改了训练时的输入大小,请修改`tools/export_model.py`文件中的对应SVTR的`infer_shape`。 + +转换成功后,在目录下有三个文件: +``` +/inference/rec_svtr_tiny_stn_en/ + ├── inference.pdiparams # 识别inference模型的参数文件 + ├── inference.pdiparams.info # 识别inference模型的参数信息,可忽略 + └── inference.pdmodel # 识别inference模型的program文件 +``` + + +执行如下命令进行模型推理: + +```shell +python3 tools/infer/predict_rec.py --image_dir='./doc/imgs_words_en/word_10.png' --rec_model_dir='./inference/rec_svtr_tiny_stn_en/' --rec_algorithm='SVTR' --rec_image_shape='3,64,256' --rec_char_dict_path='./ppocr/utils/ic15_dict.txt' +# 预测文件夹下所有图像时,可修改image_dir为文件夹,如 --image_dir='./doc/imgs_words_en/'。 +``` +![](../imgs_words_en/word_10.png) + +执行命令后,上面图像的预测结果(识别的文本和得分)会打印到屏幕上,示例如下: +结果如下: +```shell +Predicts of ./doc/imgs_words_en/word_10.png:('pain', 0.9999998807907104) +``` + +**注意**: + +- 如果您调整了训练时的输入分辨率,需要通过参数`rec_image_shape`设置为您需要的识别图像形状。 +- 在推理时需要设置参数`rec_char_dict_path`指定字典,如果您修改了字典,请修改该参数为您的字典文件。 +- 如果您修改了预处理方法,需修改`tools/infer/predict_rec.py`中SVTR的预处理为您的预处理方法。 + + +### 4.2 C++推理部署 + +由于C++预处理后处理还未支持SVTR,所以暂未支持 + + +### 4.3 Serving服务化部署 + +暂不支持 + + +### 4.4 更多推理部署 + +暂不支持 + + +## 5. FAQ + +1. 
由于`SVTR`使用的op算子大多为矩阵相乘,在GPU环境下,速度具有优势,但在CPU开启mkldnn加速环境下,`SVTR`相比于被优化的卷积网络没有优势。 diff --git a/ppocr/data/imaug/__init__.py b/ppocr/data/imaug/__init__.py index 20aaf48e119d68e6c37ce9246a87701fb149d5e7..548832fb0d116ba2de622bd97562b591d74501d8 100644 --- a/ppocr/data/imaug/__init__.py +++ b/ppocr/data/imaug/__init__.py @@ -23,7 +23,7 @@ from .random_crop_data import EastRandomCropData, RandomCropImgMask from .make_pse_gt import MakePseGt from .rec_img_aug import RecAug, RecConAug, RecResizeImg, ClsResizeImg, \ - SRNRecResizeImg, NRTRRecResizeImg, SARRecResizeImg, PRENResizeImg, SVTRRecResizeImg + SRNRecResizeImg, NRTRRecResizeImg, SARRecResizeImg, PRENResizeImg from .ssl_img_aug import SSLRotateResize from .randaugment import RandAugment from .copy_paste import CopyPaste diff --git a/ppocr/data/imaug/rec_img_aug.py b/ppocr/data/imaug/rec_img_aug.py index 2f70b51a3b88422274353046209c6d0d4dc79489..7483dffe5b6d9a0a2204702757fcb49762a1cc7a 100644 --- a/ppocr/data/imaug/rec_img_aug.py +++ b/ppocr/data/imaug/rec_img_aug.py @@ -207,25 +207,6 @@ class PRENResizeImg(object): return data -class SVTRRecResizeImg(object): - def __init__(self, - image_shape, - infer_mode=False, - character_dict_path='./ppocr/utils/ppocr_keys_v1.txt', - padding=True, - **kwargs): - self.image_shape = image_shape - self.infer_mode = infer_mode - self.character_dict_path = character_dict_path - self.padding = padding - - def __call__(self, data): - img = data['image'] - norm_img = resize_norm_img_svtr(img, self.image_shape, self.padding) - data['image'] = norm_img - return data - - def resize_norm_img_sar(img, image_shape, width_downsample_ratio=0.25): imgC, imgH, imgW_min, imgW_max = image_shape h = img.shape[0] @@ -344,58 +325,6 @@ def resize_norm_img_srn(img, image_shape): return np.reshape(img_black, (c, row, col)).astype(np.float32) -def resize_norm_img_svtr(img, image_shape, padding=False): - imgC, imgH, imgW = image_shape - h = img.shape[0] - w = img.shape[1] - if not padding: - if h > 2.0 * 
w: - image = Image.fromarray(img) - image1 = image.rotate(90, expand=True) - image2 = image.rotate(-90, expand=True) - img1 = np.array(image1) - img2 = np.array(image2) - else: - img1 = copy.deepcopy(img) - img2 = copy.deepcopy(img) - - resized_image = cv2.resize( - img, (imgW, imgH), interpolation=cv2.INTER_LINEAR) - resized_image1 = cv2.resize( - img1, (imgW, imgH), interpolation=cv2.INTER_LINEAR) - resized_image2 = cv2.resize( - img2, (imgW, imgH), interpolation=cv2.INTER_LINEAR) - resized_w = imgW - else: - ratio = w / float(h) - if math.ceil(imgH * ratio) > imgW: - resized_w = imgW - else: - resized_w = int(math.ceil(imgH * ratio)) - resized_image = cv2.resize(img, (resized_w, imgH)) - resized_image = resized_image.astype('float32') - resized_image1 = resized_image1.astype('float32') - resized_image2 = resized_image2.astype('float32') - if image_shape[0] == 1: - resized_image = resized_image / 255 - resized_image = resized_image[np.newaxis, :] - else: - resized_image = resized_image.transpose((2, 0, 1)) / 255 - resized_image1 = resized_image1.transpose((2, 0, 1)) / 255 - resized_image2 = resized_image2.transpose((2, 0, 1)) / 255 - resized_image -= 0.5 - resized_image /= 0.5 - resized_image1 -= 0.5 - resized_image1 /= 0.5 - resized_image2 -= 0.5 - resized_image2 /= 0.5 - padding_im = np.zeros((3, imgC, imgH, imgW), dtype=np.float32) - padding_im[0, :, :, 0:resized_w] = resized_image - padding_im[1, :, :, 0:resized_w] = resized_image1 - padding_im[2, :, :, 0:resized_w] = resized_image2 - return padding_im - - def srn_other_inputs(image_shape, num_heads, max_text_length): imgC, imgH, imgW = image_shape diff --git a/ppocr/modeling/backbones/rec_svtrnet.py b/ppocr/modeling/backbones/rec_svtrnet.py index 5ded74378c60e6f08a4adf68671afaa1168737b6..b699386c52afc17f556fc073d5a4e13216dd23ec 100644 --- a/ppocr/modeling/backbones/rec_svtrnet.py +++ b/ppocr/modeling/backbones/rec_svtrnet.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions 
and # limitations under the License. -from collections import Callable from paddle import ParamAttr from paddle.nn.initializer import KaimingNormal import numpy as np @@ -228,11 +227,8 @@ class Block(nn.Layer): super().__init__() if isinstance(norm_layer, str): self.norm1 = eval(norm_layer)(dim, epsilon=epsilon) - elif isinstance(norm_layer, Callable): - self.norm1 = norm_layer(dim) else: - raise TypeError( - "The norm_layer must be str or paddle.nn.layer.Layer class") + self.norm1 = norm_layer(dim) if mixer == 'Global' or mixer == 'Local': self.mixer = Attention( dim, @@ -250,15 +246,11 @@ class Block(nn.Layer): else: raise TypeError("The mixer must be one of [Global, Local, Conv]") - # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here self.drop_path = DropPath(drop_path) if drop_path > 0. else Identity() if isinstance(norm_layer, str): self.norm2 = eval(norm_layer)(dim, epsilon=epsilon) - elif isinstance(norm_layer, Callable): - self.norm2 = norm_layer(dim) else: - raise TypeError( - "The norm_layer must be str or paddle.nn.layer.Layer class") + self.norm2 = norm_layer(dim) mlp_hidden_dim = int(dim * mlp_ratio) self.mlp_ratio = mlp_ratio self.mlp = Mlp(in_features=dim, @@ -330,8 +322,6 @@ class PatchEmbed(nn.Layer): act=nn.GELU, bias_attr=None), ConvBNLayer( - embed_dim // 2, - embed_dim, in_channels=embed_dim // 2, out_channels=embed_dim, kernel_size=3, diff --git a/ppocr/modeling/transforms/stn.py b/ppocr/modeling/transforms/stn.py index 1b15d5b8a7b7a1b1ab686d20acea750437463939..6f2bdda050f217d8253740001901fbff4065782a 100644 --- a/ppocr/modeling/transforms/stn.py +++ b/ppocr/modeling/transforms/stn.py @@ -128,8 +128,6 @@ class STN_ON(nn.Layer): self.out_channels = in_channels def forward(self, image): - if len(image.shape)==5: - image = image.reshape([0, image.shape[-3], image.shape[-2], image.shape[-1]]) stn_input = paddle.nn.functional.interpolate( image, self.tps_inputsize, mode="bilinear", align_corners=True) 
stn_img_feat, ctrl_points = self.stn_head(stn_input) diff --git a/ppocr/postprocess/__init__.py b/ppocr/postprocess/__init__.py index 390f6f4560f9814a3af757a4fd16c55fe93d01f9..f50b5f1c5f8e617066bb47636c8f4d2b171b6ecb 100644 --- a/ppocr/postprocess/__init__.py +++ b/ppocr/postprocess/__init__.py @@ -27,7 +27,7 @@ from .sast_postprocess import SASTPostProcess from .fce_postprocess import FCEPostProcess from .rec_postprocess import CTCLabelDecode, AttnLabelDecode, SRNLabelDecode, \ DistillationCTCLabelDecode, TableLabelDecode, NRTRLabelDecode, SARLabelDecode, \ - SEEDLabelDecode, PRENLabelDecode, SVTRLabelDecode + SEEDLabelDecode, PRENLabelDecode from .cls_postprocess import ClsPostProcess from .pg_postprocess import PGPostProcess from .vqa_token_ser_layoutlm_postprocess import VQASerTokenLayoutLMPostProcess @@ -42,7 +42,7 @@ def build_post_process(config, global_config=None): 'DistillationDBPostProcess', 'NRTRLabelDecode', 'SARLabelDecode', 'SEEDLabelDecode', 'VQASerTokenLayoutLMPostProcess', 'VQAReTokenLayoutLMPostProcess', 'PRENLabelDecode', - 'DistillationSARLabelDecode', 'SVTRLabelDecode' + 'DistillationSARLabelDecode' ] if config['name'] == 'PSEPostProcess': diff --git a/ppocr/postprocess/rec_postprocess.py b/ppocr/postprocess/rec_postprocess.py index 50f11f899fb4dd49da75199095772a92cc4a8d7b..bf0fd890bf25949361665d212bf8e1a657054e5b 100644 --- a/ppocr/postprocess/rec_postprocess.py +++ b/ppocr/postprocess/rec_postprocess.py @@ -752,40 +752,3 @@ class PRENLabelDecode(BaseRecLabelDecode): return text label = self.decode(label) return text, label - - -class SVTRLabelDecode(BaseRecLabelDecode): - """ Convert between text-label and text-index """ - - def __init__(self, character_dict_path=None, use_space_char=False, - **kwargs): - super(SVTRLabelDecode, self).__init__(character_dict_path, - use_space_char) - - def __call__(self, preds, label=None, *args, **kwargs): - if isinstance(preds, tuple): - preds = preds[-1] - if isinstance(preds, paddle.Tensor): - preds = 
preds.numpy() - preds_idx = preds.argmax(axis=-1) - preds_prob = preds.max(axis=-1) - - text = self.decode(preds_idx, preds_prob, is_remove_duplicate=True) - return_text = [] - for i in range(0, len(text), 3): - text0 = text[i] - text1 = text[i + 1] - text2 = text[i + 2] - - text_pred = [text0[0], text1[0], text2[0]] - text_prob = [text0[1], text1[1], text2[1]] - id_max = text_prob.index(max(text_prob)) - return_text.append((text_pred[id_max], text_prob[id_max])) - if label is None: - return return_text - label = self.decode(label) - return return_text, label - - def add_special_char(self, dict_character): - dict_character = ['blank'] + dict_character - return dict_character \ No newline at end of file