提交 d2c11969 编写于 作者: T Topdu

rm svtrlabeldecode resize

上级 ac703e56
......@@ -62,7 +62,7 @@ python3 tools/eval.py -c configs/rec/rec_mtb_nrtr.yml -o Global.pretrained_model
使用如下命令进行单张图片预测:
```shell
# 注意将pretrained_model的路径设置为本地路径。
python3 tools/infer_rec.py -c configs/rec/rec_mtb_nrtr.yml -o Global.infer_img='./doc/imgs_words_en/word_10.png' Global.pretrained_model=./rec_mtb_nrtr_train/best_accuracy Global.load_static_weights=false
python3 tools/infer_rec.py -c configs/rec/rec_mtb_nrtr.yml -o Global.infer_img='./doc/imgs_words_en/word_10.png' Global.pretrained_model=./rec_mtb_nrtr_train/best_accuracy
# 预测文件夹下所有图像时,可修改infer_img为文件夹,如 Global.infer_img='./doc/imgs_words_en/'。
```
......@@ -72,11 +72,11 @@ python3 tools/infer_rec.py -c configs/rec/rec_mtb_nrtr.yml -o Global.infer_img='
<a name="4-1"></a>
### 4.1 Python推理
首先将训练得到best模型,转换成inference model。这里以训练完成的模型为例([模型下载地址](#model)),可以使用如下命令进行转换:
首先将训练得到best模型,转换成inference model。这里以训练完成的模型为例([模型下载地址](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_mtb_nrtr_train.tar)),可以使用如下命令进行转换:
```shell
# 注意将pretrained_model的路径设置为本地路径。
python3 tools/export_model.py -c configs/rec/rec_mtb_nrtr.yml -o Global.pretrained_model=./rec_mtb_nrtr_train/best_accuracy Global.save_inference_dir=./inference/rec_mtb_nrtr/ Global.load_static_weights=False
python3 tools/export_model.py -c configs/rec/rec_mtb_nrtr.yml -o Global.pretrained_model=./rec_mtb_nrtr_train/best_accuracy Global.save_inference_dir=./inference/rec_mtb_nrtr/
```
执行如下命令进行模型推理:
......
......@@ -26,10 +26,13 @@
1. 首次发现单视觉模型可以达到与视觉语言模型相媲美甚至更高的准确率,并且其具有效率高和适应多语言的优点,在实际应用中很有前景。
2. SVTR从字符组件的角度出发,逐渐的合并字符组件,自下而上地完成字符的识别。
3. SVTR引入了局部和全局Mixing,分别用于提取字符组件特征和字符间依赖关系,与多尺度的特征一起,形成多粒度特征描述。
4. SVTR-L在识别英文和中文场景文本方面实现了最先进的性能。SVTR-T平衡精确度和效率,在一个NVIDIA 1080Ti GPU中,每个英文图像文本平均消耗4.5ms。
<a name="model"></a>
`SVTR`在场景文本识别公开数据集上的精度(%)和模型文件如下:
SVTR在场景文本识别公开数据集上的精度(%)和模型文件如下:
* 中文数据集来自于[Chinese Benckmark](https://arxiv.org/abs/2112.15093) ,SVTR的中文训练评估策略遵循该论文。
| SVTR |IC13<br/>857 | SVT |IIIT5k<br/>3000 |IC15<br/>1811| SVTP |CUTE80 | Avg_6 |IC15<br/>2077 |IC13<br/>1015 |IC03<br/>867|IC03<br/>860|Avg_10 |Chinese| 英文<br/>链接 | 中文<br/>链接 |
|:-----:|:------:|:-----:|:---------:|:------:|:-----:|:-----:|:-----:|:-------:|:-------:|:-----:|:-----:|:-----:|:-----:|:----------------------------------------------------------------------------------------------:|:----------------------------------------------------------------------------------------------:|
......@@ -56,10 +59,6 @@
[英文数据集下载](https://github.com/clovaai/deep-text-recognition-benchmark#download-lmdb-dataset-for-traininig-and-evaluation-from-here)
[中文数据集下载](https://github.com/fudanvi/benchmarking-chinese-text-recognition#download)
**注意:**
1. 训练`SVTR`时,需将将配置文件中的测试数据集路径设置为本地的评估数据集路径,例如将中文的`scene_test`数据集修改为`scene_val`
2. 训练`SVTR`时,需将配置文件中的`SVTRLableDecode`修改为`CTCLabelDecode`,将`SVTRRecResizeImg`修改为`RecResizeImg`
#### 启动训练
请参考[文本识别训练教程](./recognition.md)。PaddleOCR对代码进行了模块化,训练`SVTR`识别模型时需要**更换配置文件**`SVTR`[配置文件](../../configs/rec/rec_svtrnet.yml)
......@@ -67,7 +66,7 @@
<a name="3-2"></a>
### 3.2 评估
可下载`SVTR`提供模型文件和配置文件[模型下载](#model),以`SVTR-T`为例,使用如下命令进行评估:
可下载`SVTR`提供模型文件和配置文件[下载地址](https://paddleocr.bj.bcebos.com/PP-OCRv3/chinese/rec_svtr_tiny_none_ctc_en_train.tar) ,以`SVTR-T`为例,使用如下命令进行评估:
```shell
# 注意将pretrained_model的路径设置为本地路径。
......@@ -81,7 +80,7 @@ python3 tools/eval.py -c ./rec_svtr_tiny_en_train/rec_svtr_tiny_6local_6global_s
使用如下命令进行单张图片预测:
```shell
# 注意将pretrained_model的路径设置为本地路径。
python3 tools/infer_rec.py -c ./rec_svtr_tiny_en_train/rec_svtr_tiny_6local_6global_stn_en.yml -o Global.infer_img='./doc/imgs_words_en/word_10.png' Global.pretrained_model=./rec_svtr_tiny_none_ctc_en_train/best_accuracy Global.load_static_weights=false
python3 tools/infer_rec.py -c ./rec_svtr_tiny_en_train/rec_svtr_tiny_6local_6global_stn_en.yml -o Global.infer_img='./doc/imgs_words_en/word_10.png' Global.pretrained_model=./rec_svtr_tiny_none_ctc_en_train/best_accuracy
# 预测文件夹下所有图像时,可修改infer_img为文件夹,如 Global.infer_img='./doc/imgs_words_en/'。
```
......@@ -91,11 +90,11 @@ python3 tools/infer_rec.py -c ./rec_svtr_tiny_en_train/rec_svtr_tiny_6local_6glo
<a name="4-1"></a>
### 4.1 Python推理
首先将训练得到best模型,转换成inference model。下面以基于`SVTR-T`,在英文数据集训练的模型为例([模型下载地址](#model)),可以使用如下命令进行转换:
首先将训练得到best模型,转换成inference model。下面以基于`SVTR-T`,在英文数据集训练的模型为例([模型下载地址](https://paddleocr.bj.bcebos.com/PP-OCRv3/chinese/rec_svtr_tiny_none_ctc_en_train.tar) ),可以使用如下命令进行转换:
```shell
# 注意将pretrained_model的路径设置为本地路径。
python3 tools/export_model.py -c ./rec_svtr_tiny_en_train/rec_svtr_tiny_6local_6global_stn_en.yml -o Global.pretrained_model=./rec_svtr_tiny_none_ctc_en_train/best_accuracy Global.save_inference_dir=./inference/rec_svtr_tiny_stn_en/ Global.load_static_weights=False
python3 tools/export_model.py -c ./rec_svtr_tiny_en_train/rec_svtr_tiny_6local_6global_stn_en.yml -o Global.pretrained_model=./rec_svtr_tiny_none_ctc_en_train/best_accuracy Global.save_inference_dir=./inference/rec_svtr_tiny_stn_en
```
执行如下命令进行模型推理:
......
......@@ -23,7 +23,7 @@ from .random_crop_data import EastRandomCropData, RandomCropImgMask
from .make_pse_gt import MakePseGt
from .rec_img_aug import RecAug, RecConAug, RecResizeImg, ClsResizeImg, \
SRNRecResizeImg, NRTRRecResizeImg, SARRecResizeImg, PRENResizeImg, SVTRRecResizeImg
SRNRecResizeImg, NRTRRecResizeImg, SARRecResizeImg, PRENResizeImg
from .ssl_img_aug import SSLRotateResize
from .randaugment import RandAugment
from .copy_paste import CopyPaste
......
......@@ -207,25 +207,6 @@ class PRENResizeImg(object):
return data
class SVTRRecResizeImg(object):
def __init__(self,
image_shape,
infer_mode=False,
character_dict_path='./ppocr/utils/ppocr_keys_v1.txt',
padding=True,
**kwargs):
self.image_shape = image_shape
self.infer_mode = infer_mode
self.character_dict_path = character_dict_path
self.padding = padding
def __call__(self, data):
img = data['image']
norm_img = resize_norm_img_svtr(img, self.image_shape, self.padding)
data['image'] = norm_img
return data
def resize_norm_img_sar(img, image_shape, width_downsample_ratio=0.25):
imgC, imgH, imgW_min, imgW_max = image_shape
h = img.shape[0]
......@@ -344,58 +325,6 @@ def resize_norm_img_srn(img, image_shape):
return np.reshape(img_black, (c, row, col)).astype(np.float32)
def resize_norm_img_svtr(img, image_shape, padding=False):
imgC, imgH, imgW = image_shape
h = img.shape[0]
w = img.shape[1]
if not padding:
if h > 2.0 * w:
image = Image.fromarray(img)
image1 = image.rotate(90, expand=True)
image2 = image.rotate(-90, expand=True)
img1 = np.array(image1)
img2 = np.array(image2)
else:
img1 = copy.deepcopy(img)
img2 = copy.deepcopy(img)
resized_image = cv2.resize(
img, (imgW, imgH), interpolation=cv2.INTER_LINEAR)
resized_image1 = cv2.resize(
img1, (imgW, imgH), interpolation=cv2.INTER_LINEAR)
resized_image2 = cv2.resize(
img2, (imgW, imgH), interpolation=cv2.INTER_LINEAR)
resized_w = imgW
else:
ratio = w / float(h)
if math.ceil(imgH * ratio) > imgW:
resized_w = imgW
else:
resized_w = int(math.ceil(imgH * ratio))
resized_image = cv2.resize(img, (resized_w, imgH))
resized_image = resized_image.astype('float32')
resized_image1 = resized_image1.astype('float32')
resized_image2 = resized_image2.astype('float32')
if image_shape[0] == 1:
resized_image = resized_image / 255
resized_image = resized_image[np.newaxis, :]
else:
resized_image = resized_image.transpose((2, 0, 1)) / 255
resized_image1 = resized_image1.transpose((2, 0, 1)) / 255
resized_image2 = resized_image2.transpose((2, 0, 1)) / 255
resized_image -= 0.5
resized_image /= 0.5
resized_image1 -= 0.5
resized_image1 /= 0.5
resized_image2 -= 0.5
resized_image2 /= 0.5
padding_im = np.zeros((3, imgC, imgH, imgW), dtype=np.float32)
padding_im[0, :, :, 0:resized_w] = resized_image
padding_im[1, :, :, 0:resized_w] = resized_image1
padding_im[2, :, :, 0:resized_w] = resized_image2
return padding_im
def srn_other_inputs(image_shape, num_heads, max_text_length):
imgC, imgH, imgW = image_shape
......
......@@ -128,8 +128,6 @@ class STN_ON(nn.Layer):
self.out_channels = in_channels
def forward(self, image):
if len(image.shape)==5:
image = image.reshape([0, image.shape[-3], image.shape[-2], image.shape[-1]])
stn_input = paddle.nn.functional.interpolate(
image, self.tps_inputsize, mode="bilinear", align_corners=True)
stn_img_feat, ctrl_points = self.stn_head(stn_input)
......
......@@ -27,7 +27,7 @@ from .sast_postprocess import SASTPostProcess
from .fce_postprocess import FCEPostProcess
from .rec_postprocess import CTCLabelDecode, AttnLabelDecode, SRNLabelDecode, \
DistillationCTCLabelDecode, TableLabelDecode, NRTRLabelDecode, SARLabelDecode, \
SEEDLabelDecode, PRENLabelDecode, SVTRLabelDecode
SEEDLabelDecode, PRENLabelDecode
from .cls_postprocess import ClsPostProcess
from .pg_postprocess import PGPostProcess
from .vqa_token_ser_layoutlm_postprocess import VQASerTokenLayoutLMPostProcess
......@@ -42,7 +42,7 @@ def build_post_process(config, global_config=None):
'DistillationDBPostProcess', 'NRTRLabelDecode', 'SARLabelDecode',
'SEEDLabelDecode', 'VQASerTokenLayoutLMPostProcess',
'VQAReTokenLayoutLMPostProcess', 'PRENLabelDecode',
'DistillationSARLabelDecode', 'SVTRLabelDecode'
'DistillationSARLabelDecode'
]
if config['name'] == 'PSEPostProcess':
......
......@@ -752,40 +752,3 @@ class PRENLabelDecode(BaseRecLabelDecode):
return text
label = self.decode(label)
return text, label
class SVTRLabelDecode(BaseRecLabelDecode):
""" Convert between text-label and text-index """
def __init__(self, character_dict_path=None, use_space_char=False,
**kwargs):
super(SVTRLabelDecode, self).__init__(character_dict_path,
use_space_char)
def __call__(self, preds, label=None, *args, **kwargs):
if isinstance(preds, tuple):
preds = preds[-1]
if isinstance(preds, paddle.Tensor):
preds = preds.numpy()
preds_idx = preds.argmax(axis=-1)
preds_prob = preds.max(axis=-1)
text = self.decode(preds_idx, preds_prob, is_remove_duplicate=True)
return_text = []
for i in range(0, len(text), 3):
text0 = text[i]
text1 = text[i + 1]
text2 = text[i + 2]
text_pred = [text0[0], text1[0], text2[0]]
text_prob = [text0[1], text1[1], text2[1]]
id_max = text_prob.index(max(text_prob))
return_text.append((text_pred[id_max], text_prob[id_max]))
if label is None:
return return_text
label = self.decode(label)
return return_text, label
def add_special_char(self, dict_character):
dict_character = ['blank'] + dict_character
return dict_character
\ No newline at end of file
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册