diff --git a/doc/doc_ch/algorithm_rec_nrtr.md b/doc/doc_ch/algorithm_rec_nrtr.md
index d5da4f3f2a395c1d8b4360a11bb9841805ab2155..0151247bd44833d71aa216bce56fe0526d531cae 100644
--- a/doc/doc_ch/algorithm_rec_nrtr.md
+++ b/doc/doc_ch/algorithm_rec_nrtr.md
@@ -62,7 +62,7 @@ python3 tools/eval.py -c configs/rec/rec_mtb_nrtr.yml -o Global.pretrained_model
使用如下命令进行单张图片预测:
```shell
# 注意将pretrained_model的路径设置为本地路径。
-python3 tools/infer_rec.py -c configs/rec/rec_mtb_nrtr.yml -o Global.infer_img='./doc/imgs_words_en/word_10.png' Global.pretrained_model=./rec_mtb_nrtr_train/best_accuracy Global.load_static_weights=false
+python3 tools/infer_rec.py -c configs/rec/rec_mtb_nrtr.yml -o Global.infer_img='./doc/imgs_words_en/word_10.png' Global.pretrained_model=./rec_mtb_nrtr_train/best_accuracy
# 预测文件夹下所有图像时,可修改infer_img为文件夹,如 Global.infer_img='./doc/imgs_words_en/'。
```
@@ -72,11 +72,11 @@ python3 tools/infer_rec.py -c configs/rec/rec_mtb_nrtr.yml -o Global.infer_img='
### 4.1 Python推理
-首先将训练得到best模型,转换成inference model。这里以训练完成的模型为例([模型下载地址](#model)),可以使用如下命令进行转换:
+首先将训练得到best模型,转换成inference model。这里以训练完成的模型为例([模型下载地址](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_mtb_nrtr_train.tar)),可以使用如下命令进行转换:
```shell
# 注意将pretrained_model的路径设置为本地路径。
-python3 tools/export_model.py -c configs/rec/rec_mtb_nrtr.yml -o Global.pretrained_model=./rec_mtb_nrtr_train/best_accuracy Global.save_inference_dir=./inference/rec_mtb_nrtr/ Global.load_static_weights=False
+python3 tools/export_model.py -c configs/rec/rec_mtb_nrtr.yml -o Global.pretrained_model=./rec_mtb_nrtr_train/best_accuracy Global.save_inference_dir=./inference/rec_mtb_nrtr/
```
执行如下命令进行模型推理:
diff --git a/doc/doc_ch/algorithm_rec_svtr.md b/doc/doc_ch/algorithm_rec_svtr.md
index 09b564740018e50e2aa8c459aefd7d5896b3f0cf..8c3871c8df1033ca5ee32a507e5de008a0671c83 100644
--- a/doc/doc_ch/algorithm_rec_svtr.md
+++ b/doc/doc_ch/algorithm_rec_svtr.md
@@ -26,10 +26,13 @@
1. 首次发现单视觉模型可以达到与视觉语言模型相媲美甚至更高的准确率,并且其具有效率高和适应多语言的优点,在实际应用中很有前景。
2. SVTR从字符组件的角度出发,逐渐的合并字符组件,自下而上地完成字符的识别。
3. SVTR引入了局部和全局Mixing,分别用于提取字符组件特征和字符间依赖关系,与多尺度的特征一起,形成多粒度特征描述。
+4. SVTR-L在识别英文和中文场景文本方面实现了最先进的性能。SVTR-T平衡精确度和效率,在一个NVIDIA 1080Ti GPU中,每个英文图像文本平均消耗4.5ms。
-`SVTR`在场景文本识别公开数据集上的精度(%)和模型文件如下:
+SVTR在场景文本识别公开数据集上的精度(%)和模型文件如下:
+
+* 中文数据集来自于[Chinese Benckmark](https://arxiv.org/abs/2112.15093) ,SVTR的中文训练评估策略遵循该论文。
| SVTR |IC13
857 | SVT |IIIT5k
3000 |IC15
1811| SVTP |CUTE80 | Avg_6 |IC15
2077 |IC13
1015 |IC03
867|IC03
860|Avg_10 |Chinese| 英文
链接 | 中文
链接 |
|:-----:|:------:|:-----:|:---------:|:------:|:-----:|:-----:|:-----:|:-------:|:-------:|:-----:|:-----:|:-----:|:-----:|:----------------------------------------------------------------------------------------------:|:----------------------------------------------------------------------------------------------:|
@@ -56,10 +59,6 @@
[英文数据集下载](https://github.com/clovaai/deep-text-recognition-benchmark#download-lmdb-dataset-for-traininig-and-evaluation-from-here)
[中文数据集下载](https://github.com/fudanvi/benchmarking-chinese-text-recognition#download)
-**注意:**
-1. 训练`SVTR`时,需将将配置文件中的测试数据集路径设置为本地的评估数据集路径,例如将中文的`scene_test`数据集修改为`scene_val`。
-2. 训练`SVTR`时,需将配置文件中的`SVTRLableDecode`修改为`CTCLabelDecode`,将`SVTRRecResizeImg`修改为`RecResizeImg`。
-
#### 启动训练
请参考[文本识别训练教程](./recognition.md)。PaddleOCR对代码进行了模块化,训练`SVTR`识别模型时需要**更换配置文件**为`SVTR`的[配置文件](../../configs/rec/rec_svtrnet.yml)。
@@ -67,7 +66,7 @@
### 3.2 评估
-可下载`SVTR`提供模型文件和配置文件[模型下载](#model),以`SVTR-T`为例,使用如下命令进行评估:
+可下载`SVTR`提供模型文件和配置文件:[下载地址](https://paddleocr.bj.bcebos.com/PP-OCRv3/chinese/rec_svtr_tiny_none_ctc_en_train.tar) ,以`SVTR-T`为例,使用如下命令进行评估:
```shell
# 注意将pretrained_model的路径设置为本地路径。
@@ -81,7 +80,7 @@ python3 tools/eval.py -c ./rec_svtr_tiny_en_train/rec_svtr_tiny_6local_6global_s
使用如下命令进行单张图片预测:
```shell
# 注意将pretrained_model的路径设置为本地路径。
-python3 tools/infer_rec.py -c ./rec_svtr_tiny_en_train/rec_svtr_tiny_6local_6global_stn_en.yml -o Global.infer_img='./doc/imgs_words_en/word_10.png' Global.pretrained_model=./rec_svtr_tiny_none_ctc_en_train/best_accuracy Global.load_static_weights=false
+python3 tools/infer_rec.py -c ./rec_svtr_tiny_en_train/rec_svtr_tiny_6local_6global_stn_en.yml -o Global.infer_img='./doc/imgs_words_en/word_10.png' Global.pretrained_model=./rec_svtr_tiny_none_ctc_en_train/best_accuracy
# 预测文件夹下所有图像时,可修改infer_img为文件夹,如 Global.infer_img='./doc/imgs_words_en/'。
```
@@ -91,11 +90,11 @@ python3 tools/infer_rec.py -c ./rec_svtr_tiny_en_train/rec_svtr_tiny_6local_6glo
### 4.1 Python推理
-首先将训练得到best模型,转换成inference model。下面以基于`SVTR-T`,在英文数据集训练的模型为例([模型下载地址](#model)),可以使用如下命令进行转换:
+首先将训练得到best模型,转换成inference model。下面以基于`SVTR-T`,在英文数据集训练的模型为例([模型下载地址](https://paddleocr.bj.bcebos.com/PP-OCRv3/chinese/rec_svtr_tiny_none_ctc_en_train.tar) ),可以使用如下命令进行转换:
```shell
# 注意将pretrained_model的路径设置为本地路径。
-python3 tools/export_model.py -c ./rec_svtr_tiny_en_train/rec_svtr_tiny_6local_6global_stn_en.yml -o Global.pretrained_model=./rec_svtr_tiny_none_ctc_en_train/best_accuracy Global.save_inference_dir=./inference/rec_svtr_tiny_stn_en/ Global.load_static_weights=False
+python3 tools/export_model.py -c ./rec_svtr_tiny_en_train/rec_svtr_tiny_6local_6global_stn_en.yml -o Global.pretrained_model=./rec_svtr_tiny_none_ctc_en_train/best_accuracy Global.save_inference_dir=./inference/rec_svtr_tiny_stn_en
```
执行如下命令进行模型推理:
diff --git a/ppocr/data/imaug/__init__.py b/ppocr/data/imaug/__init__.py
index 20aaf48e119d68e6c37ce9246a87701fb149d5e7..548832fb0d116ba2de622bd97562b591d74501d8 100644
--- a/ppocr/data/imaug/__init__.py
+++ b/ppocr/data/imaug/__init__.py
@@ -23,7 +23,7 @@ from .random_crop_data import EastRandomCropData, RandomCropImgMask
from .make_pse_gt import MakePseGt
from .rec_img_aug import RecAug, RecConAug, RecResizeImg, ClsResizeImg, \
- SRNRecResizeImg, NRTRRecResizeImg, SARRecResizeImg, PRENResizeImg, SVTRRecResizeImg
+ SRNRecResizeImg, NRTRRecResizeImg, SARRecResizeImg, PRENResizeImg
from .ssl_img_aug import SSLRotateResize
from .randaugment import RandAugment
from .copy_paste import CopyPaste
diff --git a/ppocr/data/imaug/rec_img_aug.py b/ppocr/data/imaug/rec_img_aug.py
index 2f70b51a3b88422274353046209c6d0d4dc79489..7483dffe5b6d9a0a2204702757fcb49762a1cc7a 100644
--- a/ppocr/data/imaug/rec_img_aug.py
+++ b/ppocr/data/imaug/rec_img_aug.py
@@ -207,25 +207,6 @@ class PRENResizeImg(object):
return data
-class SVTRRecResizeImg(object):
- def __init__(self,
- image_shape,
- infer_mode=False,
- character_dict_path='./ppocr/utils/ppocr_keys_v1.txt',
- padding=True,
- **kwargs):
- self.image_shape = image_shape
- self.infer_mode = infer_mode
- self.character_dict_path = character_dict_path
- self.padding = padding
-
- def __call__(self, data):
- img = data['image']
- norm_img = resize_norm_img_svtr(img, self.image_shape, self.padding)
- data['image'] = norm_img
- return data
-
-
def resize_norm_img_sar(img, image_shape, width_downsample_ratio=0.25):
imgC, imgH, imgW_min, imgW_max = image_shape
h = img.shape[0]
@@ -344,58 +325,6 @@ def resize_norm_img_srn(img, image_shape):
return np.reshape(img_black, (c, row, col)).astype(np.float32)
-def resize_norm_img_svtr(img, image_shape, padding=False):
- imgC, imgH, imgW = image_shape
- h = img.shape[0]
- w = img.shape[1]
- if not padding:
- if h > 2.0 * w:
- image = Image.fromarray(img)
- image1 = image.rotate(90, expand=True)
- image2 = image.rotate(-90, expand=True)
- img1 = np.array(image1)
- img2 = np.array(image2)
- else:
- img1 = copy.deepcopy(img)
- img2 = copy.deepcopy(img)
-
- resized_image = cv2.resize(
- img, (imgW, imgH), interpolation=cv2.INTER_LINEAR)
- resized_image1 = cv2.resize(
- img1, (imgW, imgH), interpolation=cv2.INTER_LINEAR)
- resized_image2 = cv2.resize(
- img2, (imgW, imgH), interpolation=cv2.INTER_LINEAR)
- resized_w = imgW
- else:
- ratio = w / float(h)
- if math.ceil(imgH * ratio) > imgW:
- resized_w = imgW
- else:
- resized_w = int(math.ceil(imgH * ratio))
- resized_image = cv2.resize(img, (resized_w, imgH))
- resized_image = resized_image.astype('float32')
- resized_image1 = resized_image1.astype('float32')
- resized_image2 = resized_image2.astype('float32')
- if image_shape[0] == 1:
- resized_image = resized_image / 255
- resized_image = resized_image[np.newaxis, :]
- else:
- resized_image = resized_image.transpose((2, 0, 1)) / 255
- resized_image1 = resized_image1.transpose((2, 0, 1)) / 255
- resized_image2 = resized_image2.transpose((2, 0, 1)) / 255
- resized_image -= 0.5
- resized_image /= 0.5
- resized_image1 -= 0.5
- resized_image1 /= 0.5
- resized_image2 -= 0.5
- resized_image2 /= 0.5
- padding_im = np.zeros((3, imgC, imgH, imgW), dtype=np.float32)
- padding_im[0, :, :, 0:resized_w] = resized_image
- padding_im[1, :, :, 0:resized_w] = resized_image1
- padding_im[2, :, :, 0:resized_w] = resized_image2
- return padding_im
-
-
def srn_other_inputs(image_shape, num_heads, max_text_length):
imgC, imgH, imgW = image_shape
diff --git a/ppocr/modeling/transforms/stn.py b/ppocr/modeling/transforms/stn.py
index 1b15d5b8a7b7a1b1ab686d20acea750437463939..6f2bdda050f217d8253740001901fbff4065782a 100644
--- a/ppocr/modeling/transforms/stn.py
+++ b/ppocr/modeling/transforms/stn.py
@@ -128,8 +128,6 @@ class STN_ON(nn.Layer):
self.out_channels = in_channels
def forward(self, image):
- if len(image.shape)==5:
- image = image.reshape([0, image.shape[-3], image.shape[-2], image.shape[-1]])
stn_input = paddle.nn.functional.interpolate(
image, self.tps_inputsize, mode="bilinear", align_corners=True)
stn_img_feat, ctrl_points = self.stn_head(stn_input)
diff --git a/ppocr/postprocess/__init__.py b/ppocr/postprocess/__init__.py
index 390f6f4560f9814a3af757a4fd16c55fe93d01f9..f50b5f1c5f8e617066bb47636c8f4d2b171b6ecb 100644
--- a/ppocr/postprocess/__init__.py
+++ b/ppocr/postprocess/__init__.py
@@ -27,7 +27,7 @@ from .sast_postprocess import SASTPostProcess
from .fce_postprocess import FCEPostProcess
from .rec_postprocess import CTCLabelDecode, AttnLabelDecode, SRNLabelDecode, \
DistillationCTCLabelDecode, TableLabelDecode, NRTRLabelDecode, SARLabelDecode, \
- SEEDLabelDecode, PRENLabelDecode, SVTRLabelDecode
+ SEEDLabelDecode, PRENLabelDecode
from .cls_postprocess import ClsPostProcess
from .pg_postprocess import PGPostProcess
from .vqa_token_ser_layoutlm_postprocess import VQASerTokenLayoutLMPostProcess
@@ -42,7 +42,7 @@ def build_post_process(config, global_config=None):
'DistillationDBPostProcess', 'NRTRLabelDecode', 'SARLabelDecode',
'SEEDLabelDecode', 'VQASerTokenLayoutLMPostProcess',
'VQAReTokenLayoutLMPostProcess', 'PRENLabelDecode',
- 'DistillationSARLabelDecode', 'SVTRLabelDecode'
+ 'DistillationSARLabelDecode'
]
if config['name'] == 'PSEPostProcess':
diff --git a/ppocr/postprocess/rec_postprocess.py b/ppocr/postprocess/rec_postprocess.py
index 50f11f899fb4dd49da75199095772a92cc4a8d7b..bf0fd890bf25949361665d212bf8e1a657054e5b 100644
--- a/ppocr/postprocess/rec_postprocess.py
+++ b/ppocr/postprocess/rec_postprocess.py
@@ -752,40 +752,3 @@ class PRENLabelDecode(BaseRecLabelDecode):
return text
label = self.decode(label)
return text, label
-
-
-class SVTRLabelDecode(BaseRecLabelDecode):
- """ Convert between text-label and text-index """
-
- def __init__(self, character_dict_path=None, use_space_char=False,
- **kwargs):
- super(SVTRLabelDecode, self).__init__(character_dict_path,
- use_space_char)
-
- def __call__(self, preds, label=None, *args, **kwargs):
- if isinstance(preds, tuple):
- preds = preds[-1]
- if isinstance(preds, paddle.Tensor):
- preds = preds.numpy()
- preds_idx = preds.argmax(axis=-1)
- preds_prob = preds.max(axis=-1)
-
- text = self.decode(preds_idx, preds_prob, is_remove_duplicate=True)
- return_text = []
- for i in range(0, len(text), 3):
- text0 = text[i]
- text1 = text[i + 1]
- text2 = text[i + 2]
-
- text_pred = [text0[0], text1[0], text2[0]]
- text_prob = [text0[1], text1[1], text2[1]]
- id_max = text_prob.index(max(text_prob))
- return_text.append((text_pred[id_max], text_prob[id_max]))
- if label is None:
- return return_text
- label = self.decode(label)
- return return_text, label
-
- def add_special_char(self, dict_character):
- dict_character = ['blank'] + dict_character
- return dict_character
\ No newline at end of file