未验证 提交 21a0efea 编写于 作者: M MissPenguin 提交者: GitHub

Merge pull request #6080 from Topdu/dygraph

add svtr and nrtr docs
...@@ -94,7 +94,7 @@ def main(config, device, logger, vdl_writer): ...@@ -94,7 +94,7 @@ def main(config, device, logger, vdl_writer):
config['Optimizer'], config['Optimizer'],
epochs=config['Global']['epoch_num'], epochs=config['Global']['epoch_num'],
step_each_epoch=len(train_dataloader), step_each_epoch=len(train_dataloader),
parameters=model.parameters()) model=model)
# build metric # build metric
eval_class = build_metric(config['Metric']) eval_class = build_metric(config['Metric'])
......
# 场景文本识别算法-NRTR
- [1. 算法简介](#1)
- [2. 环境配置](#2)
- [3. 模型训练、评估、预测](#3)
- [3.1 训练](#3-1)
- [3.2 评估](#3-2)
- [3.3 预测](#3-3)
- [4. 推理部署](#4)
- [4.1 Python推理](#4-1)
- [4.2 C++推理](#4-2)
- [4.3 Serving服务化部署](#4-3)
- [4.4 更多推理部署](#4-4)
- [5. FAQ](#5)
<a name="1"></a>
## 1. 算法简介
论文信息:
> [NRTR: A No-Recurrence Sequence-to-Sequence Model For Scene Text Recognition](https://arxiv.org/abs/1806.00926)
> Fenfen Sheng and Zhineng Chen and Bo Xu
> ICDAR, 2019
<a name="model"></a>
`NRTR`使用MJSynth和SynthText两个文字识别数据集训练,在IIIT, SVT, IC03, IC13, IC15, SVTP, CUTE数据集上进行评估,算法复现效果如下:
|模型|骨干网络|配置文件|Acc|下载链接|
| --- | --- | --- | --- | --- |
|NRTR|MTB|[rec_mtb_nrtr.yml](../../configs/rec/rec_mtb_nrtr.yml)|84.21%|[训练模型](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_mtb_nrtr_train.tar)|
<a name="2"></a>
## 2. 环境配置
请先参考[《运行环境准备》](./environment.md)配置PaddleOCR运行环境,参考[《项目克隆》](./clone.md)克隆项目代码。
<a name="3"></a>
## 3. 模型训练、评估、预测
<a name="3-1"></a>
### 3.1 模型训练
请参考[文本识别训练教程](./recognition.md)。PaddleOCR对代码进行了模块化,训练`NRTR`识别模型时需要**更换配置文件**`NRTR`[配置文件](../../configs/rec/rec_mtb_nrtr.yml)
#### 启动训练
具体地,在完成数据准备后,便可以启动训练,训练命令如下:
```shell
#单卡训练(训练周期长,不建议)
python3 tools/train.py -c configs/rec/rec_mtb_nrtr.yml
#多卡训练,通过--gpus参数指定卡号
python3 -m paddle.distributed.launch --gpus '0,1,2,3' tools/train.py -c configs/rec/rec_mtb_nrtr.yml
```
<a name="3-2"></a>
### 3.2 评估
可下载已训练完成的[模型文件](#model),使用如下命令进行评估:
```shell
# 注意将pretrained_model的路径设置为本地路径。
python3 tools/eval.py -c configs/rec/rec_mtb_nrtr.yml -o Global.pretrained_model=./rec_mtb_nrtr_train/best_accuracy
```
<a name="3-3"></a>
### 3.3 预测
使用如下命令进行单张图片预测:
```shell
# 注意将pretrained_model的路径设置为本地路径。
python3 tools/infer_rec.py -c configs/rec/rec_mtb_nrtr.yml -o Global.infer_img='./doc/imgs_words_en/word_10.png' Global.pretrained_model=./rec_mtb_nrtr_train/best_accuracy
# 预测文件夹下所有图像时,可修改infer_img为文件夹,如 Global.infer_img='./doc/imgs_words_en/'。
```
<a name="4"></a>
## 4. 推理部署
<a name="4-1"></a>
### 4.1 Python推理
首先将训练得到best模型,转换成inference model。这里以训练完成的模型为例([模型下载地址](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_mtb_nrtr_train.tar) ),可以使用如下命令进行转换:
```shell
# 注意将pretrained_model的路径设置为本地路径。
python3 tools/export_model.py -c configs/rec/rec_mtb_nrtr.yml -o Global.pretrained_model=./rec_mtb_nrtr_train/best_accuracy Global.save_inference_dir=./inference/rec_mtb_nrtr/
```
**注意:**
- 如果您是在自己的数据集上训练的模型,并且调整了字典文件,请注意修改配置文件中的`character_dict_path`是否是所需要的字典文件。
- 如果您修改了训练时的输入大小,请修改`tools/export_model.py`文件中的对应NRTR的`infer_shape`
转换成功后,在目录下有三个文件:
```
/inference/rec_mtb_nrtr/
├── inference.pdiparams # 识别inference模型的参数文件
├── inference.pdiparams.info # 识别inference模型的参数信息,可忽略
└── inference.pdmodel # 识别inference模型的program文件
```
执行如下命令进行模型推理:
```shell
python3 tools/infer/predict_rec.py --image_dir='./doc/imgs_words_en/word_10.png' --rec_model_dir='./inference/rec_mtb_nrtr/' --rec_algorithm='NRTR' --rec_image_shape='1,32,100' --rec_char_dict_path='./ppocr/utils/EN_symbol_dict.txt'
# 预测文件夹下所有图像时,可修改image_dir为文件夹,如 --image_dir='./doc/imgs_words_en/'。
```
![](../imgs_words_en/word_10.png)
执行命令后,上面图像的预测结果(识别的文本和得分)会打印到屏幕上,示例如下:
结果如下:
```shell
Predicts of ./doc/imgs_words_en/word_10.png:('pain', 0.9265879392623901)
```
**注意**
- 训练上述模型采用的图像分辨率是[1,32,100],需要通过参数`rec_image_shape`设置为您训练时的识别图像形状。
- 在推理时需要设置参数`rec_char_dict_path`指定字典,如果您修改了字典,请修改该参数为您的字典文件。
- 如果您修改了预处理方法,需修改`tools/infer/predict_rec.py`中NRTR的预处理为您的预处理方法。
<a name="4-2"></a>
### 4.2 C++推理部署
由于C++预处理后处理还未支持NRTR,所以暂未支持
<a name="4-3"></a>
### 4.3 Serving服务化部署
暂不支持
<a name="4-4"></a>
### 4.4 更多推理部署
暂不支持
<a name="5"></a>
## 5. FAQ
1. `NRTR`论文中使用Beam搜索进行解码字符,但是速度较慢,这里默认未使用Beam搜索,以贪婪搜索进行解码字符。
## 引用
```bibtex
@article{Sheng2019NRTR,
author = {Fenfen Sheng and Zhineng Chen andBo Xu},
title = {NRTR: A No-Recurrence Sequence-to-Sequence Model For Scene Text Recognition},
journal = {ICDAR},
year = {2019},
url = {http://arxiv.org/abs/1806.00926},
pages = {781-786}
}
```
# 场景文本识别算法-SVTR
- [1. 算法简介](#1)
- [2. 环境配置](#2)
- [3. 模型训练、评估、预测](#3)
- [3.1 训练](#3-1)
- [3.2 评估](#3-2)
- [3.3 预测](#3-3)
- [4. 推理部署](#4)
- [4.1 Python推理](#4-1)
- [4.2 C++推理](#4-2)
- [4.3 Serving服务化部署](#4-3)
- [4.4 更多推理部署](#4-4)
- [5. FAQ](#5)
<a name="1"></a>
## 1. 算法简介
论文信息:
> [SVTR: Scene Text Recognition with a Single Visual Model]()
> Yongkun Du and Zhineng Chen and Caiyan Jia Xiaoting Yin and Tianlun Zheng and Chenxia Li and Yuning Du and Yu-Gang Jiang
> IJCAI, 2022
场景文本识别旨在将自然图像中的文本转录为数字字符序列,从而传达对场景理解至关重要的高级语义。这项任务由于文本变形、字体、遮挡、杂乱背景等方面的变化具有一定的挑战性。先前的方法为提高识别精度做出了许多工作。然而文本识别器除了准确度外,还因为实际需求需要考虑推理速度等因素。
### SVTR算法简介
主流的场景文本识别模型通常包含两个模块:用于特征提取的视觉模型和用于文本转录的序列模型。这种架构虽然准确,但复杂且效率较低,限制了在实际场景中的应用。SVTR提出了一种用于场景文本识别的单视觉模型,该模型在patch-wise image tokenization框架内,完全摒弃了序列建模,在精度具有竞争力的前提下,模型参数量更少,速度更快,主要有以下几点贡献:
1. 首次发现单视觉模型可以达到与视觉语言模型相媲美甚至更高的准确率,并且其具有效率高和适应多语言的优点,在实际应用中很有前景。
2. SVTR从字符组件的角度出发,逐渐的合并字符组件,自下而上地完成字符的识别。
3. SVTR引入了局部和全局Mixing,分别用于提取字符组件特征和字符间依赖关系,与多尺度的特征一起,形成多粒度特征描述。
<a name="model"></a>
SVTR在场景文本识别公开数据集上的精度(%)和模型文件如下:
* 中文数据集来自于[Chinese Benckmark](https://arxiv.org/abs/2112.15093) ,SVTR的中文训练评估策略遵循该论文。
| 模型 |IC13<br/>857 | SVT |IIIT5k<br/>3000 |IC15<br/>1811| SVTP |CUTE80 | Avg_6 |IC15<br/>2077 |IC13<br/>1015 |IC03<br/>867|IC03<br/>860|Avg_10 | Chinese<br/>scene_test| 下载链接 |
|:----------:|:------:|:-----:|:---------:|:------:|:-----:|:-----:|:-----:|:-------:|:-------:|:-----:|:-----:|:---------------------------------------------:|:-----:|:-----:|
| SVTR Tiny | 96.85 | 91.34 | 94.53 | 83.99 | 85.43 | 89.24 | 90.87 | 80.55 | 95.37 | 95.27 | 95.70 | 90.13 | 67.90 | [英文](https://paddleocr.bj.bcebos.com/PP-OCRv3/chinese/rec_svtr_tiny_none_ctc_en_train.tar) / [中文](https://paddleocr.bj.bcebos.com/PP-OCRv3/chinese/rec_svtr_tiny_none_ctc_ch_train.tar) |
| SVTR Small | 95.92 | 93.04 | 95.03 | 84.70 | 87.91 | 92.01 | 91.63 | 82.72 | 94.88 | 96.08 | 96.28 | 91.02 | 69.00 | [英文](https://paddleocr.bj.bcebos.com/PP-OCRv3/chinese/rec_svtr_small_none_ctc_en_train.tar) / [中文](https://paddleocr.bj.bcebos.com/PP-OCRv3/chinese/rec_svtr_small_none_ctc_ch_train.tar) |
| SVTR Base | 97.08 | 91.50 | 96.03 | 85.20 | 89.92 | 91.67 | 92.33 | 83.73 | 95.66 | 95.62 | 95.81 | 91.61 | 71.40 | [英文](https://paddleocr.bj.bcebos.com/PP-OCRv3/chinese/rec_svtr_base_none_ctc_en_train.tar) / - |
| SVTR Large | 97.20 | 91.65 | 96.30 | 86.58 | 88.37 | 95.14 | 92.82 | 84.54 | 96.35 | 96.54 | 96.74 | 92.24 | 72.10 | [英文](https://paddleocr.bj.bcebos.com/PP-OCRv3/chinese/rec_svtr_large_none_ctc_en_train.tar) / [中文](https://paddleocr.bj.bcebos.com/PP-OCRv3/chinese/rec_svtr_large_none_ctc_ch_train.tar) |
<a name="2"></a>
## 2. 环境配置
请先参考[《运行环境准备》](./environment.md)配置PaddleOCR运行环境,参考[《项目克隆》](./clone.md)克隆项目代码。
<a name="3"></a>
## 3. 模型训练、评估、预测
<a name="3-1"></a>
### 3.1 模型训练
#### 数据集准备
[英文数据集下载](https://github.com/clovaai/deep-text-recognition-benchmark#download-lmdb-dataset-for-traininig-and-evaluation-from-here)
[中文数据集下载](https://github.com/fudanvi/benchmarking-chinese-text-recognition#download)
#### 启动训练
请参考[文本识别训练教程](./recognition.md)。PaddleOCR对代码进行了模块化,训练`SVTR`识别模型时需要**更换配置文件**`SVTR`[配置文件](../../configs/rec/rec_svtrnet.yml)
具体地,在完成数据准备后,便可以启动训练,训练命令如下:
```shell
#单卡训练(训练周期长,不建议)
python3 tools/train.py -c configs/rec/rec_svtrnet.yml
#多卡训练,通过--gpus参数指定卡号
python3 -m paddle.distributed.launch --gpus '0,1,2,3' tools/train.py -c configs/rec/rec_svtrnet.yml
```
<a name="3-2"></a>
### 3.2 评估
可下载`SVTR`提供的模型文件和配置文件:[下载地址](https://paddleocr.bj.bcebos.com/PP-OCRv3/chinese/rec_svtr_tiny_none_ctc_en_train.tar) ,以`SVTR-T`为例,使用如下命令进行评估:
```shell
# 注意将pretrained_model的路径设置为本地路径。
python3 tools/eval.py -c ./rec_svtr_tiny_en_train/rec_svtr_tiny_6local_6global_stn_en.yml -o Global.pretrained_model=./rec_svtr_tiny_none_ctc_en_train/best_accuracy
```
<a name="3-3"></a>
### 3.3 预测
使用如下命令进行单张图片预测:
```shell
# 注意将pretrained_model的路径设置为本地路径。
python3 tools/infer_rec.py -c ./rec_svtr_tiny_en_train/rec_svtr_tiny_6local_6global_stn_en.yml -o Global.infer_img='./doc/imgs_words_en/word_10.png' Global.pretrained_model=./rec_svtr_tiny_none_ctc_en_train/best_accuracy
# 预测文件夹下所有图像时,可修改infer_img为文件夹,如 Global.infer_img='./doc/imgs_words_en/'。
```
<a name="4"></a>
## 4. 推理部署
<a name="4-1"></a>
### 4.1 Python推理
首先将训练得到best模型,转换成inference model。下面以基于`SVTR-T`,在英文数据集训练的模型为例([模型下载地址](https://paddleocr.bj.bcebos.com/PP-OCRv3/chinese/rec_svtr_tiny_none_ctc_en_train.tar) ),可以使用如下命令进行转换:
```shell
# 注意将pretrained_model的路径设置为本地路径。
python3 tools/export_model.py -c ./rec_svtr_tiny_en_train/rec_svtr_tiny_6local_6global_stn_en.yml -o Global.pretrained_model=./rec_svtr_tiny_none_ctc_en_train/best_accuracy Global.save_inference_dir=./inference/rec_svtr_tiny_stn_en
```
**注意:**
- 如果您是在自己的数据集上训练的模型,并且调整了字典文件,请注意修改配置文件中的`character_dict_path`是否为所正确的字典文件。
- 如果您修改了训练时的输入大小,请修改`tools/export_model.py`文件中的对应SVTR的`infer_shape`
转换成功后,在目录下有三个文件:
```
/inference/rec_svtr_tiny_stn_en/
├── inference.pdiparams # 识别inference模型的参数文件
├── inference.pdiparams.info # 识别inference模型的参数信息,可忽略
└── inference.pdmodel # 识别inference模型的program文件
```
执行如下命令进行模型推理:
```shell
python3 tools/infer/predict_rec.py --image_dir='./doc/imgs_words_en/word_10.png' --rec_model_dir='./inference/rec_svtr_tiny_stn_en/' --rec_algorithm='SVTR' --rec_image_shape='3,64,256' --rec_char_dict_path='./ppocr/utils/ic15_dict.txt'
# 预测文件夹下所有图像时,可修改image_dir为文件夹,如 --image_dir='./doc/imgs_words_en/'。
```
![](../imgs_words_en/word_10.png)
执行命令后,上面图像的预测结果(识别的文本和得分)会打印到屏幕上,示例如下:
结果如下:
```shell
Predicts of ./doc/imgs_words_en/word_10.png:('pain', 0.9999998807907104)
```
**注意**
- 如果您调整了训练时的输入分辨率,需要通过参数`rec_image_shape`设置为您需要的识别图像形状。
- 在推理时需要设置参数`rec_char_dict_path`指定字典,如果您修改了字典,请修改该参数为您的字典文件。
- 如果您修改了预处理方法,需修改`tools/infer/predict_rec.py`中SVTR的预处理为您的预处理方法。
<a name="4-2"></a>
### 4.2 C++推理部署
由于C++预处理后处理还未支持SVTR,所以暂未支持
<a name="4-3"></a>
### 4.3 Serving服务化部署
暂不支持
<a name="4-4"></a>
### 4.4 更多推理部署
暂不支持
<a name="5"></a>
## 5. FAQ
1. 由于`SVTR`使用的op算子大多为矩阵相乘,在GPU环境下,速度具有优势,但在CPU开启mkldnn加速环境下,`SVTR`相比于被优化的卷积网络没有优势。
...@@ -23,7 +23,7 @@ from .random_crop_data import EastRandomCropData, RandomCropImgMask ...@@ -23,7 +23,7 @@ from .random_crop_data import EastRandomCropData, RandomCropImgMask
from .make_pse_gt import MakePseGt from .make_pse_gt import MakePseGt
from .rec_img_aug import RecAug, RecConAug, RecResizeImg, ClsResizeImg, \ from .rec_img_aug import RecAug, RecConAug, RecResizeImg, ClsResizeImg, \
SRNRecResizeImg, NRTRRecResizeImg, SARRecResizeImg, PRENResizeImg, SVTRRecResizeImg SRNRecResizeImg, NRTRRecResizeImg, SARRecResizeImg, PRENResizeImg
from .ssl_img_aug import SSLRotateResize from .ssl_img_aug import SSLRotateResize
from .randaugment import RandAugment from .randaugment import RandAugment
from .copy_paste import CopyPaste from .copy_paste import CopyPaste
......
...@@ -207,25 +207,6 @@ class PRENResizeImg(object): ...@@ -207,25 +207,6 @@ class PRENResizeImg(object):
return data return data
class SVTRRecResizeImg(object):
def __init__(self,
image_shape,
infer_mode=False,
character_dict_path='./ppocr/utils/ppocr_keys_v1.txt',
padding=True,
**kwargs):
self.image_shape = image_shape
self.infer_mode = infer_mode
self.character_dict_path = character_dict_path
self.padding = padding
def __call__(self, data):
img = data['image']
norm_img = resize_norm_img_svtr(img, self.image_shape, self.padding)
data['image'] = norm_img
return data
def resize_norm_img_sar(img, image_shape, width_downsample_ratio=0.25): def resize_norm_img_sar(img, image_shape, width_downsample_ratio=0.25):
imgC, imgH, imgW_min, imgW_max = image_shape imgC, imgH, imgW_min, imgW_max = image_shape
h = img.shape[0] h = img.shape[0]
...@@ -344,58 +325,6 @@ def resize_norm_img_srn(img, image_shape): ...@@ -344,58 +325,6 @@ def resize_norm_img_srn(img, image_shape):
return np.reshape(img_black, (c, row, col)).astype(np.float32) return np.reshape(img_black, (c, row, col)).astype(np.float32)
def resize_norm_img_svtr(img, image_shape, padding=False):
imgC, imgH, imgW = image_shape
h = img.shape[0]
w = img.shape[1]
if not padding:
if h > 2.0 * w:
image = Image.fromarray(img)
image1 = image.rotate(90, expand=True)
image2 = image.rotate(-90, expand=True)
img1 = np.array(image1)
img2 = np.array(image2)
else:
img1 = copy.deepcopy(img)
img2 = copy.deepcopy(img)
resized_image = cv2.resize(
img, (imgW, imgH), interpolation=cv2.INTER_LINEAR)
resized_image1 = cv2.resize(
img1, (imgW, imgH), interpolation=cv2.INTER_LINEAR)
resized_image2 = cv2.resize(
img2, (imgW, imgH), interpolation=cv2.INTER_LINEAR)
resized_w = imgW
else:
ratio = w / float(h)
if math.ceil(imgH * ratio) > imgW:
resized_w = imgW
else:
resized_w = int(math.ceil(imgH * ratio))
resized_image = cv2.resize(img, (resized_w, imgH))
resized_image = resized_image.astype('float32')
resized_image1 = resized_image1.astype('float32')
resized_image2 = resized_image2.astype('float32')
if image_shape[0] == 1:
resized_image = resized_image / 255
resized_image = resized_image[np.newaxis, :]
else:
resized_image = resized_image.transpose((2, 0, 1)) / 255
resized_image1 = resized_image1.transpose((2, 0, 1)) / 255
resized_image2 = resized_image2.transpose((2, 0, 1)) / 255
resized_image -= 0.5
resized_image /= 0.5
resized_image1 -= 0.5
resized_image1 /= 0.5
resized_image2 -= 0.5
resized_image2 /= 0.5
padding_im = np.zeros((3, imgC, imgH, imgW), dtype=np.float32)
padding_im[0, :, :, 0:resized_w] = resized_image
padding_im[1, :, :, 0:resized_w] = resized_image1
padding_im[2, :, :, 0:resized_w] = resized_image2
return padding_im
def srn_other_inputs(image_shape, num_heads, max_text_length): def srn_other_inputs(image_shape, num_heads, max_text_length):
imgC, imgH, imgW = image_shape imgC, imgH, imgW = image_shape
......
...@@ -12,7 +12,6 @@ ...@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from collections import Callable
from paddle import ParamAttr from paddle import ParamAttr
from paddle.nn.initializer import KaimingNormal from paddle.nn.initializer import KaimingNormal
import numpy as np import numpy as np
...@@ -228,11 +227,8 @@ class Block(nn.Layer): ...@@ -228,11 +227,8 @@ class Block(nn.Layer):
super().__init__() super().__init__()
if isinstance(norm_layer, str): if isinstance(norm_layer, str):
self.norm1 = eval(norm_layer)(dim, epsilon=epsilon) self.norm1 = eval(norm_layer)(dim, epsilon=epsilon)
elif isinstance(norm_layer, Callable):
self.norm1 = norm_layer(dim)
else: else:
raise TypeError( self.norm1 = norm_layer(dim)
"The norm_layer must be str or paddle.nn.layer.Layer class")
if mixer == 'Global' or mixer == 'Local': if mixer == 'Global' or mixer == 'Local':
self.mixer = Attention( self.mixer = Attention(
dim, dim,
...@@ -250,15 +246,11 @@ class Block(nn.Layer): ...@@ -250,15 +246,11 @@ class Block(nn.Layer):
else: else:
raise TypeError("The mixer must be one of [Global, Local, Conv]") raise TypeError("The mixer must be one of [Global, Local, Conv]")
# NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
self.drop_path = DropPath(drop_path) if drop_path > 0. else Identity() self.drop_path = DropPath(drop_path) if drop_path > 0. else Identity()
if isinstance(norm_layer, str): if isinstance(norm_layer, str):
self.norm2 = eval(norm_layer)(dim, epsilon=epsilon) self.norm2 = eval(norm_layer)(dim, epsilon=epsilon)
elif isinstance(norm_layer, Callable):
self.norm2 = norm_layer(dim)
else: else:
raise TypeError( self.norm2 = norm_layer(dim)
"The norm_layer must be str or paddle.nn.layer.Layer class")
mlp_hidden_dim = int(dim * mlp_ratio) mlp_hidden_dim = int(dim * mlp_ratio)
self.mlp_ratio = mlp_ratio self.mlp_ratio = mlp_ratio
self.mlp = Mlp(in_features=dim, self.mlp = Mlp(in_features=dim,
...@@ -330,8 +322,6 @@ class PatchEmbed(nn.Layer): ...@@ -330,8 +322,6 @@ class PatchEmbed(nn.Layer):
act=nn.GELU, act=nn.GELU,
bias_attr=None), bias_attr=None),
ConvBNLayer( ConvBNLayer(
embed_dim // 2,
embed_dim,
in_channels=embed_dim // 2, in_channels=embed_dim // 2,
out_channels=embed_dim, out_channels=embed_dim,
kernel_size=3, kernel_size=3,
......
...@@ -128,8 +128,6 @@ class STN_ON(nn.Layer): ...@@ -128,8 +128,6 @@ class STN_ON(nn.Layer):
self.out_channels = in_channels self.out_channels = in_channels
def forward(self, image): def forward(self, image):
if len(image.shape)==5:
image = image.reshape([0, image.shape[-3], image.shape[-2], image.shape[-1]])
stn_input = paddle.nn.functional.interpolate( stn_input = paddle.nn.functional.interpolate(
image, self.tps_inputsize, mode="bilinear", align_corners=True) image, self.tps_inputsize, mode="bilinear", align_corners=True)
stn_img_feat, ctrl_points = self.stn_head(stn_input) stn_img_feat, ctrl_points = self.stn_head(stn_input)
......
...@@ -27,7 +27,7 @@ from .sast_postprocess import SASTPostProcess ...@@ -27,7 +27,7 @@ from .sast_postprocess import SASTPostProcess
from .fce_postprocess import FCEPostProcess from .fce_postprocess import FCEPostProcess
from .rec_postprocess import CTCLabelDecode, AttnLabelDecode, SRNLabelDecode, \ from .rec_postprocess import CTCLabelDecode, AttnLabelDecode, SRNLabelDecode, \
DistillationCTCLabelDecode, TableLabelDecode, NRTRLabelDecode, SARLabelDecode, \ DistillationCTCLabelDecode, TableLabelDecode, NRTRLabelDecode, SARLabelDecode, \
SEEDLabelDecode, PRENLabelDecode, SVTRLabelDecode SEEDLabelDecode, PRENLabelDecode
from .cls_postprocess import ClsPostProcess from .cls_postprocess import ClsPostProcess
from .pg_postprocess import PGPostProcess from .pg_postprocess import PGPostProcess
from .vqa_token_ser_layoutlm_postprocess import VQASerTokenLayoutLMPostProcess from .vqa_token_ser_layoutlm_postprocess import VQASerTokenLayoutLMPostProcess
...@@ -42,7 +42,7 @@ def build_post_process(config, global_config=None): ...@@ -42,7 +42,7 @@ def build_post_process(config, global_config=None):
'DistillationDBPostProcess', 'NRTRLabelDecode', 'SARLabelDecode', 'DistillationDBPostProcess', 'NRTRLabelDecode', 'SARLabelDecode',
'SEEDLabelDecode', 'VQASerTokenLayoutLMPostProcess', 'SEEDLabelDecode', 'VQASerTokenLayoutLMPostProcess',
'VQAReTokenLayoutLMPostProcess', 'PRENLabelDecode', 'VQAReTokenLayoutLMPostProcess', 'PRENLabelDecode',
'DistillationSARLabelDecode', 'SVTRLabelDecode' 'DistillationSARLabelDecode'
] ]
if config['name'] == 'PSEPostProcess': if config['name'] == 'PSEPostProcess':
......
...@@ -752,40 +752,3 @@ class PRENLabelDecode(BaseRecLabelDecode): ...@@ -752,40 +752,3 @@ class PRENLabelDecode(BaseRecLabelDecode):
return text return text
label = self.decode(label) label = self.decode(label)
return text, label return text, label
class SVTRLabelDecode(BaseRecLabelDecode):
""" Convert between text-label and text-index """
def __init__(self, character_dict_path=None, use_space_char=False,
**kwargs):
super(SVTRLabelDecode, self).__init__(character_dict_path,
use_space_char)
def __call__(self, preds, label=None, *args, **kwargs):
if isinstance(preds, tuple):
preds = preds[-1]
if isinstance(preds, paddle.Tensor):
preds = preds.numpy()
preds_idx = preds.argmax(axis=-1)
preds_prob = preds.max(axis=-1)
text = self.decode(preds_idx, preds_prob, is_remove_duplicate=True)
return_text = []
for i in range(0, len(text), 3):
text0 = text[i]
text1 = text[i + 1]
text2 = text[i + 2]
text_pred = [text0[0], text1[0], text2[0]]
text_prob = [text0[1], text1[1], text2[1]]
id_max = text_prob.index(max(text_prob))
return_text.append((text_pred[id_max], text_prob[id_max]))
if label is None:
return return_text
label = self.decode(label)
return return_text, label
def add_special_char(self, dict_character):
dict_character = ['blank'] + dict_character
return dict_character
\ No newline at end of file
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册