From d2c11969c27032e5178f62e09f95c71e35104851 Mon Sep 17 00:00:00 2001
From: Topdu <784990967@qq.com>
Date: Thu, 28 Apr 2022 13:03:47 +0000
Subject: [PATCH] rm svtrlabeldecode resize
---
doc/doc_ch/algorithm_rec_nrtr.md | 6 +--
doc/doc_ch/algorithm_rec_svtr.md | 17 ++++---
ppocr/data/imaug/__init__.py | 2 +-
ppocr/data/imaug/rec_img_aug.py | 71 ----------------------------
ppocr/modeling/transforms/stn.py | 2 -
ppocr/postprocess/__init__.py | 4 +-
ppocr/postprocess/rec_postprocess.py | 37 ---------------
7 files changed, 14 insertions(+), 125 deletions(-)
diff --git a/doc/doc_ch/algorithm_rec_nrtr.md b/doc/doc_ch/algorithm_rec_nrtr.md
index d5da4f3f..0151247b 100644
--- a/doc/doc_ch/algorithm_rec_nrtr.md
+++ b/doc/doc_ch/algorithm_rec_nrtr.md
@@ -62,7 +62,7 @@ python3 tools/eval.py -c configs/rec/rec_mtb_nrtr.yml -o Global.pretrained_model
使用如下命令进行单张图片预测:
```shell
# 注意将pretrained_model的路径设置为本地路径。
-python3 tools/infer_rec.py -c configs/rec/rec_mtb_nrtr.yml -o Global.infer_img='./doc/imgs_words_en/word_10.png' Global.pretrained_model=./rec_mtb_nrtr_train/best_accuracy Global.load_static_weights=false
+python3 tools/infer_rec.py -c configs/rec/rec_mtb_nrtr.yml -o Global.infer_img='./doc/imgs_words_en/word_10.png' Global.pretrained_model=./rec_mtb_nrtr_train/best_accuracy
# 预测文件夹下所有图像时,可修改infer_img为文件夹,如 Global.infer_img='./doc/imgs_words_en/'。
```
@@ -72,11 +72,11 @@ python3 tools/infer_rec.py -c configs/rec/rec_mtb_nrtr.yml -o Global.infer_img='
### 4.1 Python推理
-首先将训练得到best模型,转换成inference model。这里以训练完成的模型为例([模型下载地址](#model)),可以使用如下命令进行转换:
+首先将训练得到best模型,转换成inference model。这里以训练完成的模型为例([模型下载地址](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_mtb_nrtr_train.tar)),可以使用如下命令进行转换:
```shell
# 注意将pretrained_model的路径设置为本地路径。
-python3 tools/export_model.py -c configs/rec/rec_mtb_nrtr.yml -o Global.pretrained_model=./rec_mtb_nrtr_train/best_accuracy Global.save_inference_dir=./inference/rec_mtb_nrtr/ Global.load_static_weights=False
+python3 tools/export_model.py -c configs/rec/rec_mtb_nrtr.yml -o Global.pretrained_model=./rec_mtb_nrtr_train/best_accuracy Global.save_inference_dir=./inference/rec_mtb_nrtr/
```
执行如下命令进行模型推理:
diff --git a/doc/doc_ch/algorithm_rec_svtr.md b/doc/doc_ch/algorithm_rec_svtr.md
index 09b56474..8c3871c8 100644
--- a/doc/doc_ch/algorithm_rec_svtr.md
+++ b/doc/doc_ch/algorithm_rec_svtr.md
@@ -26,10 +26,13 @@
1. 首次发现单视觉模型可以达到与视觉语言模型相媲美甚至更高的准确率,并且其具有效率高和适应多语言的优点,在实际应用中很有前景。
2. SVTR从字符组件的角度出发,逐渐的合并字符组件,自下而上地完成字符的识别。
3. SVTR引入了局部和全局Mixing,分别用于提取字符组件特征和字符间依赖关系,与多尺度的特征一起,形成多粒度特征描述。
+4. SVTR-L在识别英文和中文场景文本方面实现了最先进的性能。SVTR-T平衡精确度和效率,在一个NVIDIA 1080Ti GPU中,每个英文图像文本平均消耗4.5ms。
-`SVTR`在场景文本识别公开数据集上的精度(%)和模型文件如下:
+SVTR在场景文本识别公开数据集上的精度(%)和模型文件如下:
+
+* 中文数据集来自于[Chinese Benckmark](https://arxiv.org/abs/2112.15093) ,SVTR的中文训练评估策略遵循该论文。
| SVTR |IC13
857 | SVT |IIIT5k
3000 |IC15
1811| SVTP |CUTE80 | Avg_6 |IC15
2077 |IC13
1015 |IC03
867|IC03
860|Avg_10 |Chinese| 英文
链接 | 中文
链接 |
|:-----:|:------:|:-----:|:---------:|:------:|:-----:|:-----:|:-----:|:-------:|:-------:|:-----:|:-----:|:-----:|:-----:|:----------------------------------------------------------------------------------------------:|:----------------------------------------------------------------------------------------------:|
@@ -56,10 +59,6 @@
[英文数据集下载](https://github.com/clovaai/deep-text-recognition-benchmark#download-lmdb-dataset-for-traininig-and-evaluation-from-here)
[中文数据集下载](https://github.com/fudanvi/benchmarking-chinese-text-recognition#download)
-**注意:**
-1. 训练`SVTR`时,需将将配置文件中的测试数据集路径设置为本地的评估数据集路径,例如将中文的`scene_test`数据集修改为`scene_val`。
-2. 训练`SVTR`时,需将配置文件中的`SVTRLableDecode`修改为`CTCLabelDecode`,将`SVTRRecResizeImg`修改为`RecResizeImg`。
-
#### 启动训练
请参考[文本识别训练教程](./recognition.md)。PaddleOCR对代码进行了模块化,训练`SVTR`识别模型时需要**更换配置文件**为`SVTR`的[配置文件](../../configs/rec/rec_svtrnet.yml)。
@@ -67,7 +66,7 @@
### 3.2 评估
-可下载`SVTR`提供模型文件和配置文件[模型下载](#model),以`SVTR-T`为例,使用如下命令进行评估:
+可下载`SVTR`提供模型文件和配置文件:[下载地址](https://paddleocr.bj.bcebos.com/PP-OCRv3/chinese/rec_svtr_tiny_none_ctc_en_train.tar) ,以`SVTR-T`为例,使用如下命令进行评估:
```shell
# 注意将pretrained_model的路径设置为本地路径。
@@ -81,7 +80,7 @@ python3 tools/eval.py -c ./rec_svtr_tiny_en_train/rec_svtr_tiny_6local_6global_s
使用如下命令进行单张图片预测:
```shell
# 注意将pretrained_model的路径设置为本地路径。
-python3 tools/infer_rec.py -c ./rec_svtr_tiny_en_train/rec_svtr_tiny_6local_6global_stn_en.yml -o Global.infer_img='./doc/imgs_words_en/word_10.png' Global.pretrained_model=./rec_svtr_tiny_none_ctc_en_train/best_accuracy Global.load_static_weights=false
+python3 tools/infer_rec.py -c ./rec_svtr_tiny_en_train/rec_svtr_tiny_6local_6global_stn_en.yml -o Global.infer_img='./doc/imgs_words_en/word_10.png' Global.pretrained_model=./rec_svtr_tiny_none_ctc_en_train/best_accuracy
# 预测文件夹下所有图像时,可修改infer_img为文件夹,如 Global.infer_img='./doc/imgs_words_en/'。
```
@@ -91,11 +90,11 @@ python3 tools/infer_rec.py -c ./rec_svtr_tiny_en_train/rec_svtr_tiny_6local_6glo
### 4.1 Python推理
-首先将训练得到best模型,转换成inference model。下面以基于`SVTR-T`,在英文数据集训练的模型为例([模型下载地址](#model)),可以使用如下命令进行转换:
+首先将训练得到best模型,转换成inference model。下面以基于`SVTR-T`,在英文数据集训练的模型为例([模型下载地址](https://paddleocr.bj.bcebos.com/PP-OCRv3/chinese/rec_svtr_tiny_none_ctc_en_train.tar) ),可以使用如下命令进行转换:
```shell
# 注意将pretrained_model的路径设置为本地路径。
-python3 tools/export_model.py -c ./rec_svtr_tiny_en_train/rec_svtr_tiny_6local_6global_stn_en.yml -o Global.pretrained_model=./rec_svtr_tiny_none_ctc_en_train/best_accuracy Global.save_inference_dir=./inference/rec_svtr_tiny_stn_en/ Global.load_static_weights=False
+python3 tools/export_model.py -c ./rec_svtr_tiny_en_train/rec_svtr_tiny_6local_6global_stn_en.yml -o Global.pretrained_model=./rec_svtr_tiny_none_ctc_en_train/best_accuracy Global.save_inference_dir=./inference/rec_svtr_tiny_stn_en
```
执行如下命令进行模型推理:
diff --git a/ppocr/data/imaug/__init__.py b/ppocr/data/imaug/__init__.py
index 20aaf48e..548832fb 100644
--- a/ppocr/data/imaug/__init__.py
+++ b/ppocr/data/imaug/__init__.py
@@ -23,7 +23,7 @@ from .random_crop_data import EastRandomCropData, RandomCropImgMask
from .make_pse_gt import MakePseGt
from .rec_img_aug import RecAug, RecConAug, RecResizeImg, ClsResizeImg, \
- SRNRecResizeImg, NRTRRecResizeImg, SARRecResizeImg, PRENResizeImg, SVTRRecResizeImg
+ SRNRecResizeImg, NRTRRecResizeImg, SARRecResizeImg, PRENResizeImg
from .ssl_img_aug import SSLRotateResize
from .randaugment import RandAugment
from .copy_paste import CopyPaste
diff --git a/ppocr/data/imaug/rec_img_aug.py b/ppocr/data/imaug/rec_img_aug.py
index 2f70b51a..7483dffe 100644
--- a/ppocr/data/imaug/rec_img_aug.py
+++ b/ppocr/data/imaug/rec_img_aug.py
@@ -207,25 +207,6 @@ class PRENResizeImg(object):
return data
-class SVTRRecResizeImg(object):
- def __init__(self,
- image_shape,
- infer_mode=False,
- character_dict_path='./ppocr/utils/ppocr_keys_v1.txt',
- padding=True,
- **kwargs):
- self.image_shape = image_shape
- self.infer_mode = infer_mode
- self.character_dict_path = character_dict_path
- self.padding = padding
-
- def __call__(self, data):
- img = data['image']
- norm_img = resize_norm_img_svtr(img, self.image_shape, self.padding)
- data['image'] = norm_img
- return data
-
-
def resize_norm_img_sar(img, image_shape, width_downsample_ratio=0.25):
imgC, imgH, imgW_min, imgW_max = image_shape
h = img.shape[0]
@@ -344,58 +325,6 @@ def resize_norm_img_srn(img, image_shape):
return np.reshape(img_black, (c, row, col)).astype(np.float32)
-def resize_norm_img_svtr(img, image_shape, padding=False):
- imgC, imgH, imgW = image_shape
- h = img.shape[0]
- w = img.shape[1]
- if not padding:
- if h > 2.0 * w:
- image = Image.fromarray(img)
- image1 = image.rotate(90, expand=True)
- image2 = image.rotate(-90, expand=True)
- img1 = np.array(image1)
- img2 = np.array(image2)
- else:
- img1 = copy.deepcopy(img)
- img2 = copy.deepcopy(img)
-
- resized_image = cv2.resize(
- img, (imgW, imgH), interpolation=cv2.INTER_LINEAR)
- resized_image1 = cv2.resize(
- img1, (imgW, imgH), interpolation=cv2.INTER_LINEAR)
- resized_image2 = cv2.resize(
- img2, (imgW, imgH), interpolation=cv2.INTER_LINEAR)
- resized_w = imgW
- else:
- ratio = w / float(h)
- if math.ceil(imgH * ratio) > imgW:
- resized_w = imgW
- else:
- resized_w = int(math.ceil(imgH * ratio))
- resized_image = cv2.resize(img, (resized_w, imgH))
- resized_image = resized_image.astype('float32')
- resized_image1 = resized_image1.astype('float32')
- resized_image2 = resized_image2.astype('float32')
- if image_shape[0] == 1:
- resized_image = resized_image / 255
- resized_image = resized_image[np.newaxis, :]
- else:
- resized_image = resized_image.transpose((2, 0, 1)) / 255
- resized_image1 = resized_image1.transpose((2, 0, 1)) / 255
- resized_image2 = resized_image2.transpose((2, 0, 1)) / 255
- resized_image -= 0.5
- resized_image /= 0.5
- resized_image1 -= 0.5
- resized_image1 /= 0.5
- resized_image2 -= 0.5
- resized_image2 /= 0.5
- padding_im = np.zeros((3, imgC, imgH, imgW), dtype=np.float32)
- padding_im[0, :, :, 0:resized_w] = resized_image
- padding_im[1, :, :, 0:resized_w] = resized_image1
- padding_im[2, :, :, 0:resized_w] = resized_image2
- return padding_im
-
-
def srn_other_inputs(image_shape, num_heads, max_text_length):
imgC, imgH, imgW = image_shape
diff --git a/ppocr/modeling/transforms/stn.py b/ppocr/modeling/transforms/stn.py
index 1b15d5b8..6f2bdda0 100644
--- a/ppocr/modeling/transforms/stn.py
+++ b/ppocr/modeling/transforms/stn.py
@@ -128,8 +128,6 @@ class STN_ON(nn.Layer):
self.out_channels = in_channels
def forward(self, image):
- if len(image.shape)==5:
- image = image.reshape([0, image.shape[-3], image.shape[-2], image.shape[-1]])
stn_input = paddle.nn.functional.interpolate(
image, self.tps_inputsize, mode="bilinear", align_corners=True)
stn_img_feat, ctrl_points = self.stn_head(stn_input)
diff --git a/ppocr/postprocess/__init__.py b/ppocr/postprocess/__init__.py
index 390f6f45..f50b5f1c 100644
--- a/ppocr/postprocess/__init__.py
+++ b/ppocr/postprocess/__init__.py
@@ -27,7 +27,7 @@ from .sast_postprocess import SASTPostProcess
from .fce_postprocess import FCEPostProcess
from .rec_postprocess import CTCLabelDecode, AttnLabelDecode, SRNLabelDecode, \
DistillationCTCLabelDecode, TableLabelDecode, NRTRLabelDecode, SARLabelDecode, \
- SEEDLabelDecode, PRENLabelDecode, SVTRLabelDecode
+ SEEDLabelDecode, PRENLabelDecode
from .cls_postprocess import ClsPostProcess
from .pg_postprocess import PGPostProcess
from .vqa_token_ser_layoutlm_postprocess import VQASerTokenLayoutLMPostProcess
@@ -42,7 +42,7 @@ def build_post_process(config, global_config=None):
'DistillationDBPostProcess', 'NRTRLabelDecode', 'SARLabelDecode',
'SEEDLabelDecode', 'VQASerTokenLayoutLMPostProcess',
'VQAReTokenLayoutLMPostProcess', 'PRENLabelDecode',
- 'DistillationSARLabelDecode', 'SVTRLabelDecode'
+ 'DistillationSARLabelDecode'
]
if config['name'] == 'PSEPostProcess':
diff --git a/ppocr/postprocess/rec_postprocess.py b/ppocr/postprocess/rec_postprocess.py
index 50f11f89..bf0fd890 100644
--- a/ppocr/postprocess/rec_postprocess.py
+++ b/ppocr/postprocess/rec_postprocess.py
@@ -752,40 +752,3 @@ class PRENLabelDecode(BaseRecLabelDecode):
return text
label = self.decode(label)
return text, label
-
-
-class SVTRLabelDecode(BaseRecLabelDecode):
- """ Convert between text-label and text-index """
-
- def __init__(self, character_dict_path=None, use_space_char=False,
- **kwargs):
- super(SVTRLabelDecode, self).__init__(character_dict_path,
- use_space_char)
-
- def __call__(self, preds, label=None, *args, **kwargs):
- if isinstance(preds, tuple):
- preds = preds[-1]
- if isinstance(preds, paddle.Tensor):
- preds = preds.numpy()
- preds_idx = preds.argmax(axis=-1)
- preds_prob = preds.max(axis=-1)
-
- text = self.decode(preds_idx, preds_prob, is_remove_duplicate=True)
- return_text = []
- for i in range(0, len(text), 3):
- text0 = text[i]
- text1 = text[i + 1]
- text2 = text[i + 2]
-
- text_pred = [text0[0], text1[0], text2[0]]
- text_prob = [text0[1], text1[1], text2[1]]
- id_max = text_prob.index(max(text_prob))
- return_text.append((text_pred[id_max], text_prob[id_max]))
- if label is None:
- return return_text
- label = self.decode(label)
- return return_text, label
-
- def add_special_char(self, dict_character):
- dict_character = ['blank'] + dict_character
- return dict_character
\ No newline at end of file
--
GitLab