From 8a95b3352df44307a1e9f0aff5458881356170ec Mon Sep 17 00:00:00 2001 From: andyjpaddle Date: Tue, 24 Aug 2021 03:49:26 +0000 Subject: [PATCH] add_rec_sar, test=dygraph --- doc/doc_ch/algorithm_overview.md | 3 +- doc/doc_ch/recognition.md | 1 + doc/doc_en/algorithm_overview_en.md | 2 + doc/doc_en/recognition_en.md | 1 + ppocr/data/imaug/__init__.py | 2 +- ppocr/data/imaug/label_ops.py | 46 +++++++++++++++++ ppocr/data/imaug/rec_img_aug.py | 50 ++++++++++++++++++ ppocr/losses/__init__.py | 3 +- ppocr/modeling/backbones/__init__.py | 3 +- ppocr/modeling/heads/__init__.py | 3 +- ppocr/postprocess/__init__.py | 4 +- ppocr/postprocess/rec_postprocess.py | 77 ++++++++++++++++++++++++++++ tools/eval.py | 3 +- tools/infer_rec.py | 9 ++++ tools/program.py | 13 +++-- 15 files changed, 207 insertions(+), 13 deletions(-) diff --git a/doc/doc_ch/algorithm_overview.md b/doc/doc_ch/algorithm_overview.md index 19d7a69c..b6a365b3 100755 --- a/doc/doc_ch/algorithm_overview.md +++ b/doc/doc_ch/algorithm_overview.md @@ -44,6 +44,7 @@ PaddleOCR基于动态图开源的文本识别算法列表: - [x] STAR-Net([paper](http://www.bmva.org/bmvc/2016/papers/paper043/index.html))[11] - [x] RARE([paper](https://arxiv.org/abs/1603.03915v1))[12] - [x] SRN([paper](https://arxiv.org/abs/2003.12294))[5] +- [x] SAR([paper](https://arxiv.org/abs/1811.00751v2)) 参考[DTRB][3](https://arxiv.org/abs/1904.01906)文字识别训练和评估流程,使用MJSynth和SynthText两个文字识别数据集训练,在IIIT, SVT, IC03, IC13, IC15, SVTP, CUTE数据集上进行评估,算法效果如下: @@ -58,6 +59,6 @@ PaddleOCR基于动态图开源的文本识别算法列表: |RARE|MobileNetV3|82.5%|rec_mv3_tps_bilstm_att |[下载链接](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_mv3_tps_bilstm_att_v2.0_train.tar)| |RARE|Resnet34_vd|83.6%|rec_r34_vd_tps_bilstm_att |[下载链接](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_r34_vd_tps_bilstm_att_v2.0_train.tar)| |SRN|Resnet50_vd_fpn| 88.52% | rec_r50fpn_vd_none_srn | [下载链接](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_r50_vd_srn_train.tar) | - +|SAR|Resnet31| 87.1% | rec_r31_sar | 
[下载链接](https://paddleocr.bj.bcebos.com/dygraph_v2.1/rec/rec_r31_sar_train.tar) | PaddleOCR文本识别算法的训练和使用请参考文档教程中[模型训练/评估中的文本识别部分](./recognition.md)。 diff --git a/doc/doc_ch/recognition.md b/doc/doc_ch/recognition.md index 0ff0513a..0ac6da87 100644 --- a/doc/doc_ch/recognition.md +++ b/doc/doc_ch/recognition.md @@ -215,6 +215,7 @@ PaddleOCR支持训练和评估交替进行, 可以在 `configs/rec/rec_icdar15_t | rec_mv3_tps_bilstm_att.yml | CRNN | Mobilenet_v3 | TPS | BiLSTM | att | | rec_r34_vd_tps_bilstm_att.yml | CRNN | Resnet34_vd | TPS | BiLSTM | att | | rec_r50fpn_vd_none_srn.yml | SRN | Resnet50_fpn_vd | None | rnn | srn | +| rec_r31_sar.yml | SAR | ResNet31 | None | LSTM encoder | LSTM decoder | 训练中文数据,推荐使用[rec_chinese_lite_train_v2.0.yml](../../configs/rec/ch_ppocr_v2.0/rec_chinese_lite_train_v2.0.yml),如您希望尝试其他算法在中文数据集上的效果,请参考下列说明修改配置文件: diff --git a/doc/doc_en/algorithm_overview_en.md b/doc/doc_en/algorithm_overview_en.md index d70f99bb..f201589a 100755 --- a/doc/doc_en/algorithm_overview_en.md +++ b/doc/doc_en/algorithm_overview_en.md @@ -46,6 +46,7 @@ PaddleOCR open-source text recognition algorithms list: - [x] STAR-Net([paper](http://www.bmva.org/bmvc/2016/papers/paper043/index.html))[11] - [x] RARE([paper](https://arxiv.org/abs/1603.03915v1))[12] - [x] SRN([paper](https://arxiv.org/abs/2003.12294))[5] +- [x] SAR([paper](https://arxiv.org/abs/1811.00751v2)) Refer to [DTRB](https://arxiv.org/abs/1904.01906), the training and evaluation result of these above text recognition (using MJSynth and SynthText for training, evaluate on IIIT, SVT, IC03, IC13, IC15, SVTP, CUTE) is as follow: @@ -60,5 +61,6 @@ Refer to [DTRB](https://arxiv.org/abs/1904.01906), the training and evaluation r |RARE|MobileNetV3|82.5%|rec_mv3_tps_bilstm_att |[Download link](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_mv3_tps_bilstm_att_v2.0_train.tar)| |RARE|Resnet34_vd|83.6%|rec_r34_vd_tps_bilstm_att |[Download link](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_r34_vd_tps_bilstm_att_v2.0_train.tar)| 
|SRN|Resnet50_vd_fpn| 88.52% | rec_r50fpn_vd_none_srn |[Download link](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_r50_vd_srn_train.tar)| +|SAR|Resnet31| 87.1% | rec_r31_sar | [Download link](https://paddleocr.bj.bcebos.com/dygraph_v2.1/rec/rec_r31_sar_train.tar) | Please refer to the document for training guide and use of PaddleOCR text recognition algorithms [Text recognition model training/evaluation/prediction](./recognition_en.md) diff --git a/doc/doc_en/recognition_en.md b/doc/doc_en/recognition_en.md index 634ec783..91f81a6a 100644 --- a/doc/doc_en/recognition_en.md +++ b/doc/doc_en/recognition_en.md @@ -207,6 +207,7 @@ If the evaluation set is large, the test will be time-consuming. It is recommend | rec_mv3_tps_bilstm_att.yml | CRNN | Mobilenet_v3 | TPS | BiLSTM | att | | rec_r34_vd_tps_bilstm_att.yml | CRNN | Resnet34_vd | TPS | BiLSTM | att | | rec_r50fpn_vd_none_srn.yml | SRN | Resnet50_fpn_vd | None | rnn | srn | +| rec_r31_sar.yml | SAR | ResNet31 | None | LSTM encoder | LSTM decoder | For training Chinese data, it is recommended to use diff --git a/ppocr/data/imaug/__init__.py b/ppocr/data/imaug/__init__.py index 52194eb9..6f0492e1 100644 --- a/ppocr/data/imaug/__init__.py +++ b/ppocr/data/imaug/__init__.py @@ -21,7 +21,7 @@ from .make_border_map import MakeBorderMap from .make_shrink_map import MakeShrinkMap from .random_crop_data import EastRandomCropData, PSERandomCrop -from .rec_img_aug import RecAug, RecResizeImg, ClsResizeImg, SRNRecResizeImg +from .rec_img_aug import RecAug, RecResizeImg, ClsResizeImg, SRNRecResizeImg, SARRecResizeImg from .randaugment import RandAugment from .copy_paste import CopyPaste from .operators import * diff --git a/ppocr/data/imaug/label_ops.py b/ppocr/data/imaug/label_ops.py index d222c410..56da029b 100644 --- a/ppocr/data/imaug/label_ops.py +++ b/ppocr/data/imaug/label_ops.py @@ -521,3 +521,49 @@ class TableLabelEncode(object): assert False, "Unsupport type %s in char_or_elem" \ % char_or_elem return idx + 
class SARLabelEncode(BaseRecLabelEncode):
    """Encode a text label into a fixed-length index sequence for SAR.

    The encoded label is laid out as ``[start, c_1, ..., c_n, end, pad, ...]``
    and always has ``max_text_len`` entries.  Start and end share one index.
    """

    def __init__(self,
                 max_text_length,
                 character_dict_path=None,
                 character_type='ch',
                 use_space_char=False,
                 **kwargs):
        # Extra keyword arguments are accepted and ignored.
        super(SARLabelEncode, self).__init__(
            max_text_length, character_dict_path, character_type,
            use_space_char)

    def add_special_char(self, dict_character):
        """Append the unknown, begin/end and padding entries to the dictionary.

        Records the index of each appended entry on the instance and returns
        the extended list.
        """
        # NOTE(review): the three marker strings are empty in this copy of
        # the file; they look like angle-bracket tokens that were lost in
        # extraction -- confirm against the upstream repository.
        beg_end_str = ""
        unknown_str = ""
        padding_str = ""

        extended = dict_character + [unknown_str]
        self.unknown_idx = len(extended) - 1

        extended = extended + [beg_end_str]
        # A single entry serves as both the start and the end marker.
        self.start_idx = self.end_idx = len(extended) - 1

        extended = extended + [padding_str]
        self.padding_idx = len(extended) - 1
        return extended

    def __call__(self, data):
        """Replace ``data['label']`` with its padded index array.

        Returns ``None`` when the label cannot be encoded or is too long to
        fit together with the start and end markers; otherwise returns
        ``data`` with ``'length'`` and ``'label'`` set.
        """
        encoded = self.encode(data['label'])
        if encoded is None:
            return None
        # Two slots are reserved for the start and end markers.
        if len(encoded) >= self.max_text_len - 1:
            return None

        data['length'] = np.array(len(encoded))
        sequence = [self.start_idx] + encoded + [self.end_idx]
        n_pad = self.max_text_len - len(sequence)
        data['label'] = np.array(sequence + [self.padding_idx] * n_pad)
        return data

    def get_ignored_tokens(self):
        """Return the indices that carry no text (only the padding index)."""
        return [self.padding_idx]
def resize_norm_img_sar(img, image_shape, width_downsample_ratio=0.25):
    """Resize an image to a fixed height, normalize it and right-pad it.

    Args:
        img: image array indexed as ``img.shape[0]`` = height and
            ``img.shape[1]`` = width.
        image_shape: ``(imgC, imgH, imgW_min, imgW_max)``.  ``imgW_min`` and
            ``imgW_max`` may each be ``None`` to disable that bound; without
            ``imgW_max`` the output is not padded beyond the resized width.
        width_downsample_ratio: the resized width is rounded to a multiple
            of ``int(1 / width_downsample_ratio)``.

    Returns:
        ``(padding_im, resize_shape, pad_shape, valid_ratio)`` where
        ``padding_im`` is a float32 ``(imgC, imgH, pad_w)`` array filled with
        ``-1.0`` right of the image, ``resize_shape`` is the shape of the
        normalized image before padding, ``pad_shape`` is
        ``padding_im.shape`` and ``valid_ratio`` is the fraction of the
        padded width occupied by image content (capped at 1.0).
    """
    imgC, imgH, imgW_min, imgW_max = image_shape
    h, w = img.shape[0], img.shape[1]
    valid_ratio = 1.0

    # Make sure the new width is an integral multiple of width_divisor.
    width_divisor = int(1 / width_downsample_ratio)

    # Keep the aspect ratio while scaling the height to imgH.
    resize_w = math.ceil(imgH * (w / float(h)))
    if resize_w % width_divisor != 0:
        resize_w = round(resize_w / width_divisor) * width_divisor
    if imgW_min is not None:
        resize_w = max(imgW_min, resize_w)
    if resize_w < 1:
        # Rounding a very narrow crop can yield 0, which cv2.resize rejects;
        # fall back to the smallest valid multiple.
        resize_w = width_divisor
    if imgW_max is not None:
        valid_ratio = min(1.0, 1.0 * resize_w / imgW_max)
        resize_w = min(imgW_max, resize_w)

    resized_image = cv2.resize(img, (resize_w, imgH)).astype('float32')

    # Scale by 1/255, then map to [-1, 1] via (x - 0.5) / 0.5.
    # assumes pixel values are in 0..255 -- TODO confirm against the loader.
    if image_shape[0] == 1:
        resized_image = (resized_image / 255)[np.newaxis, :]
    else:
        resized_image = resized_image.transpose((2, 0, 1)) / 255
    resized_image -= 0.5
    resized_image /= 0.5
    resize_shape = resized_image.shape

    # The original built the canvas with imgW_max unconditionally, which
    # raises when imgW_max is None; without an upper bound, pad to resize_w.
    pad_w = imgW_max if imgW_max is not None else resize_w
    padding_im = -1.0 * np.ones((imgC, imgH, pad_w), dtype=np.float32)
    padding_im[:, :, 0:resize_w] = resized_image
    pad_shape = padding_im.shape

    return padding_im, resize_shape, pad_shape, valid_ratio
'TableAttentionLoss' + 'SRNLoss', 'PGLoss', 'CombinedLoss', 'TableAttentionLoss', 'SARLoss' ] config = copy.deepcopy(config) module_name = config.pop('name') diff --git a/ppocr/modeling/backbones/__init__.py b/ppocr/modeling/backbones/__init__.py index f4fe8c76..ce1d6bcf 100755 --- a/ppocr/modeling/backbones/__init__.py +++ b/ppocr/modeling/backbones/__init__.py @@ -26,8 +26,9 @@ def build_backbone(config, model_type): from .rec_resnet_vd import ResNet from .rec_resnet_fpn import ResNetFPN from .rec_mv1_enhance import MobileNetV1Enhance + from .rec_resnet_31 import ResNet31 support_dict = [ - "MobileNetV1Enhance", "MobileNetV3", "ResNet", "ResNetFPN" + "MobileNetV1Enhance", "MobileNetV3", "ResNet", "ResNetFPN", "ResNet31" ] elif model_type == "e2e": from .e2e_resnet_vd_pg import ResNet diff --git a/ppocr/modeling/heads/__init__.py b/ppocr/modeling/heads/__init__.py index 50964794..8414f2ad 100755 --- a/ppocr/modeling/heads/__init__.py +++ b/ppocr/modeling/heads/__init__.py @@ -26,12 +26,13 @@ def build_head(config): from .rec_ctc_head import CTCHead from .rec_att_head import AttentionHead from .rec_srn_head import SRNHead + from .rec_sar_head import SARHead # cls head from .cls_head import ClsHead support_dict = [ 'DBHead', 'EASTHead', 'SASTHead', 'CTCHead', 'ClsHead', 'AttentionHead', - 'SRNHead', 'PGHead', 'TableAttentionHead'] + 'SRNHead', 'PGHead', 'TableAttentionHead', 'SARHead'] #table head from .table_att_head import TableAttentionHead diff --git a/ppocr/postprocess/__init__.py b/ppocr/postprocess/__init__.py index 654ddf39..86f9ede4 100644 --- a/ppocr/postprocess/__init__.py +++ b/ppocr/postprocess/__init__.py @@ -25,7 +25,7 @@ from .db_postprocess import DBPostProcess, DistillationDBPostProcess from .east_postprocess import EASTPostProcess from .sast_postprocess import SASTPostProcess from .rec_postprocess import CTCLabelDecode, AttnLabelDecode, SRNLabelDecode, DistillationCTCLabelDecode, \ - TableLabelDecode + TableLabelDecode, SARLabelDecode from 
class SARLabelDecode(BaseRecLabelDecode):
    """Decode SAR index sequences (predictions or labels) into text."""

    def __init__(self,
                 character_dict_path=None,
                 character_type='ch',
                 use_space_char=False,
                 **kwargs):
        # Extra keyword arguments are accepted and ignored.
        super(SARLabelDecode, self).__init__(character_dict_path,
                                             character_type, use_space_char)

    def add_special_char(self, dict_character):
        """Append the unknown, begin/end and padding entries to the dictionary.

        Records the index of each appended entry on the instance and returns
        the extended list.
        """
        # NOTE(review): the three marker strings are empty in this copy of
        # the file; they look like angle-bracket tokens that were lost in
        # extraction -- confirm against the upstream repository.
        beg_end_str = ""
        unknown_str = ""
        padding_str = ""
        dict_character = dict_character + [unknown_str]
        self.unknown_idx = len(dict_character) - 1
        dict_character = dict_character + [beg_end_str]
        # A single entry serves as both the start and the end marker.
        self.start_idx = len(dict_character) - 1
        self.end_idx = len(dict_character) - 1
        dict_character = dict_character + [padding_str]
        self.padding_idx = len(dict_character) - 1
        return dict_character

    def decode(self, text_index, text_prob=None, is_remove_duplicate=False):
        """Convert batches of indices into ``(text, mean_confidence)`` pairs.

        Args:
            text_index: per-sample sequences of character indices.
            text_prob: per-sample sequences of confidences aligned with
                ``text_index``; when ``None`` the input is treated as a label
                and each kept character scores 1.
            is_remove_duplicate: drop an index equal to its predecessor.

        Returns:
            A list with one ``(text, score)`` tuple per sample.  ``text`` is
            lower-cased and stripped of everything except ASCII letters,
            digits and characters in U+4E00..U+9FA5.  ``score`` is 0 when no
            character was decoded.
        """
        result_list = []
        ignored_tokens = self.get_ignored_tokens()
        end_idx = int(self.end_idx)
        # Compiled once per call instead of once per sample.  The original
        # pattern had stray '^' inside the character class, which let literal
        # caret characters through the filter.
        comp = re.compile(r'[^A-Za-z0-9\u4e00-\u9fa5]')

        for batch_idx, indices in enumerate(text_index):
            char_list = []
            conf_list = []
            for idx, token in enumerate(indices):
                if token in ignored_tokens:
                    continue
                if int(token) == end_idx:
                    if text_prob is None and idx == 0:
                        # Labels begin with the start marker, which shares
                        # its index with the end marker: skip, do not stop.
                        continue
                    break
                if is_remove_duplicate and idx > 0 \
                        and indices[idx - 1] == token:
                    # only for predict
                    continue
                char_list.append(self.character[int(token)])
                if text_prob is not None:
                    conf_list.append(text_prob[batch_idx][idx])
                else:
                    conf_list.append(1)

            text = comp.sub('', ''.join(char_list).lower())
            if not conf_list:
                # np.mean([]) would return nan and emit a RuntimeWarning.
                conf_list = [0]
            result_list.append((text, np.mean(conf_list)))
        return result_list

    def __call__(self, preds, label=None, *args, **kwargs):
        """Decode model output and, when given, the ground-truth label.

        Takes the arg-max and max over axis 2 of ``preds`` as the predicted
        indices and their confidences.  Returns the decoded predictions, or
        ``(predictions, labels)`` when ``label`` is provided.
        """
        if isinstance(preds, paddle.Tensor):
            preds = preds.numpy()
        preds_idx = preds.argmax(axis=2)
        preds_prob = preds.max(axis=2)

        text = self.decode(preds_idx, preds_prob, is_remove_duplicate=False)

        if label is None:
            return text
        label = self.decode(label, is_remove_duplicate=False)
        return text, label

    def get_ignored_tokens(self):
        """Return the indices that carry no text (only the padding index)."""
        return [self.padding_idx]
for k, v in metric.items(): logger.info('{}:{}'.format(k, v)) diff --git a/tools/infer_rec.py b/tools/infer_rec.py index 09f5a0c7..f16cd7d3 100755 --- a/tools/infer_rec.py +++ b/tools/infer_rec.py @@ -74,6 +74,10 @@ def main(): 'image', 'encoder_word_pos', 'gsrm_word_pos', 'gsrm_slf_attn_bias1', 'gsrm_slf_attn_bias2' ] + elif config['Architecture']['algorithm'] == "SAR": + op[op_name]['keep_keys'] = [ + 'image', 'valid_ratio' + ] else: op[op_name]['keep_keys'] = ['image'] transforms.append(op) @@ -106,11 +110,16 @@ def main(): paddle.to_tensor(gsrm_slf_attn_bias1_list), paddle.to_tensor(gsrm_slf_attn_bias2_list) ] + if config['Architecture']['algorithm'] == "SAR": + valid_ratio = np.expand_dims(batch[-1], axis=0) + img_metas = [paddle.to_tensor(valid_ratio)] images = np.expand_dims(batch[0], axis=0) images = paddle.to_tensor(images) if config['Architecture']['algorithm'] == "SRN": preds = model(images, others) + elif config['Architecture']['algorithm'] == "SAR": + preds = model(images, img_metas) else: preds = model(images) post_result = post_process_class(preds) diff --git a/tools/program.py b/tools/program.py index 595fe4cb..cb6f8a8b 100755 --- a/tools/program.py +++ b/tools/program.py @@ -186,6 +186,7 @@ def train(config, model.train() use_srn = config['Architecture']['algorithm'] == "SRN" + use_sar = config['Architecture']['algorithm'] == 'SAR' try: model_type = config['Architecture']['model_type'] except: @@ -213,7 +214,7 @@ def train(config, images = batch[0] if use_srn: model_average = True - if use_srn or model_type == 'table': + if use_srn or model_type == 'table' or use_sar: preds = model(images, data=batch[1:]) else: preds = model(images) @@ -277,7 +278,8 @@ def train(config, post_process_class, eval_class, model_type, - use_srn=use_srn) + use_srn=use_srn, + use_sar=use_sar) cur_metric_str = 'cur metric, {}'.format(', '.join( ['{}: {}'.format(k, v) for k, v in cur_metric.items()])) logger.info(cur_metric_str) @@ -349,7 +351,8 @@ def eval(model, 
post_process_class, eval_class, model_type, - use_srn=False): + use_srn=False, + use_sar=False): model.eval() with paddle.no_grad(): total_frame = 0.0 @@ -362,7 +365,7 @@ def eval(model, break images = batch[0] start = time.time() - if use_srn or model_type == 'table': + if use_srn or model_type == 'table' or use_sar: preds = model(images, data=batch[1:]) else: preds = model(images) @@ -398,7 +401,7 @@ def preprocess(is_train=False): alg = config['Architecture']['algorithm'] assert alg in [ 'EAST', 'DB', 'SAST', 'Rosetta', 'CRNN', 'STARNet', 'RARE', 'SRN', - 'CLS', 'PGNet', 'Distillation', 'TableAttn' + 'CLS', 'PGNet', 'Distillation', 'TableAttn', 'SAR' ] device = 'gpu:{}'.format(dist.ParallelEnv().dev_id) if use_gpu else 'cpu' -- GitLab