From 1623c17cdccd6108ebe68ee7cef2ffbae1a1cbf3 Mon Sep 17 00:00:00 2001
From: Topdu <784990967@qq.com>
Date: Wed, 11 Aug 2021 09:50:51 +0000
Subject: [PATCH] add rec_nrtr

---
 configs/rec/rec_mtb_nrtr.yml                  |  16 +
 doc/doc_ch/algorithm_overview.md              |   2 +
 doc/doc_ch/recognition.md                     |   1 +
 doc/doc_en/algorithm_overview_en.md           |   1 +
 doc/doc_en/recognition_en.md                  |   2 +-
 ppocr/data/imaug/__init__.py                  |   2 +-
 ppocr/data/imaug/label_ops.py                 |   2 +-
 ppocr/data/imaug/operators.py                 |  32 ++
 ppocr/data/imaug/rec_img_aug.py               |  30 +-
 ppocr/losses/__init__.py                      |   6 +-
 ppocr/losses/rec_nrtr_loss.py                 |  38 ++
 ppocr/metrics/rec_metric.py                   |   7 +-
 ppocr/modeling/architectures/base_model.py    |   2 +-
 ppocr/modeling/backbones/__init__.py          |   5 +-
 .../modeling/backbones/multiheadAttention.py  | 365 ++++++++++++++++++
 ppocr/modeling/backbones/rec_nrtr_mtb.py      |  28 ++
 ppocr/modeling/heads/rec_nrtr_optim_head.py   |   4 +
 ppocr/modeling/necks/__init__.py              |   2 +-
 ppocr/postprocess/__init__.py                 |   5 +-
 ppocr/postprocess/rec_postprocess.py          |   5 +-
 ppocr/utils/dict_99.txt                       |  95 +++++
 tools/eval.py                                 |   2 +
 tools/program.py                              |   9 +-
 23 files changed, 638 insertions(+), 23 deletions(-)
 create mode 100644 ppocr/losses/rec_nrtr_loss.py
 create mode 100755 ppocr/modeling/backbones/multiheadAttention.py
 create mode 100644 ppocr/modeling/backbones/rec_nrtr_mtb.py
 create mode 100644 ppocr/utils/dict_99.txt

diff --git a/configs/rec/rec_mtb_nrtr.yml b/configs/rec/rec_mtb_nrtr.yml
index d5d36cfa..86a833c5 100644
--- a/configs/rec/rec_mtb_nrtr.yml
+++ b/configs/rec/rec_mtb_nrtr.yml
@@ -3,22 +3,22 @@ Global:
   epoch_num: 21
   log_smooth_window: 20
   print_batch_step: 10
+  save_model_dir: ./output/rec/nrtr_final/
+  save_epoch_step: 1
+  # evaluation is run every 2000 iterations
+  eval_batch_step: [0, 2000]
+  cal_metric_during_train: True
   pretrained_model:
   checkpoints:
   save_inference_dir:
   use_visualdl: False
   infer_img: doc/imgs_words_en/word_10.png
   # for data or label process
+  character_dict_path:
+  character_type: EN_symbol
+  max_text_length: 25
+  infer_mode: False
+  use_space_char: True
   save_res_path: ./output/rec/predicts_nrtr.txt
 
 Optimizer:
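The config above trains NRTR end to end on the English benchmarks. A typical single-machine launch uses the standard PaddleOCR entry point (GPU selection flags are environment-specific):

    python3 tools/train.py -c configs/rec/rec_mtb_nrtr.yml

`character_type: EN_symbol` keeps evaluation comparable to the other English models listed in the documentation tables below.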
diff --git a/doc/doc_ch/algorithm_overview.md b/doc/doc_ch/algorithm_overview.md
index 19d7a69c..9c352549 100755
--- a/doc/doc_ch/algorithm_overview.md
+++ b/doc/doc_ch/algorithm_overview.md
@@ -44,6 +44,7 @@ PaddleOCR基于动态图开源的文本识别算法列表:
 - [x] STAR-Net([paper](http://www.bmva.org/bmvc/2016/papers/paper043/index.html))[11]
 - [x] RARE([paper](https://arxiv.org/abs/1603.03915v1))[12]
 - [x] SRN([paper](https://arxiv.org/abs/2003.12294))[5]
+- [x] NRTR([paper](https://arxiv.org/abs/1806.00926v2))
 
 参考[DTRB](https://arxiv.org/abs/1904.01906)[3]文字识别训练和评估流程,使用MJSynth和SynthText两个文字识别数据集训练,在IIIT, SVT, IC03, IC13, IC15, SVTP, CUTE数据集上进行评估,算法效果如下:
 
@@ -58,6 +59,7 @@ PaddleOCR基于动态图开源的文本识别算法列表:
 |RARE|MobileNetV3|82.5%|rec_mv3_tps_bilstm_att |[下载链接](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_mv3_tps_bilstm_att_v2.0_train.tar)|
 |RARE|Resnet34_vd|83.6%|rec_r34_vd_tps_bilstm_att |[下载链接](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_r34_vd_tps_bilstm_att_v2.0_train.tar)|
 |SRN|Resnet50_vd_fpn| 88.52% | rec_r50fpn_vd_none_srn | [下载链接](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_r50_vd_srn_train.tar) |
+|NRTR|NRTR_MTB| 84.1% | rec_mtb_nrtr | [下载链接](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_mtb_nrtr_train.tar) |
 
 PaddleOCR文本识别算法的训练和使用请参考文档教程中[模型训练/评估中的文本识别部分](./recognition.md)。
 
diff --git a/doc/doc_ch/recognition.md b/doc/doc_ch/recognition.md
index 0f860065..6ce3003c 100644
--- a/doc/doc_ch/recognition.md
+++ b/doc/doc_ch/recognition.md
@@ -215,6 +215,7 @@ PaddleOCR支持训练和评估交替进行, 可以在 `configs/rec/rec_icdar15_t
 | rec_mv3_tps_bilstm_att.yml | CRNN | Mobilenet_v3 | TPS | BiLSTM | att |
 | rec_r34_vd_tps_bilstm_att.yml | CRNN | Resnet34_vd | TPS | BiLSTM | att |
 | rec_r50fpn_vd_none_srn.yml | SRN | Resnet50_fpn_vd | None | rnn | srn |
+| rec_mtb_nrtr.yml | NRTR | nrtr_mtb | None | transformer encoder | transformer decoder |
 
 训练中文数据,推荐使用[rec_chinese_lite_train_v2.0.yml](../../configs/rec/ch_ppocr_v2.0/rec_chinese_lite_train_v2.0.yml),如您希望尝试其他算法在中文数据集上的效果,请参考下列说明修改配置文件:
 
diff --git a/doc/doc_en/algorithm_overview_en.md b/doc/doc_en/algorithm_overview_en.md
index d70f99bb..fed9cf44 100755
--- a/doc/doc_en/algorithm_overview_en.md
+++ b/doc/doc_en/algorithm_overview_en.md
@@ -60,5 +60,6 @@ Refer to [DTRB](https://arxiv.org/abs/1904.01906), the training and evaluation r
 |RARE|MobileNetV3|82.5%|rec_mv3_tps_bilstm_att |[Download link](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_mv3_tps_bilstm_att_v2.0_train.tar)|
 |RARE|Resnet34_vd|83.6%|rec_r34_vd_tps_bilstm_att |[Download link](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_r34_vd_tps_bilstm_att_v2.0_train.tar)|
 |SRN|Resnet50_vd_fpn| 88.52% | rec_r50fpn_vd_none_srn |[Download link](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_r50_vd_srn_train.tar)|
+|NRTR|NRTR_MTB| 84.1% | rec_mtb_nrtr | [Download link](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_mtb_nrtr_train.tar) |
 
 Please refer to the document for training guide and use of PaddleOCR text recognition algorithms [Text recognition model training/evaluation/prediction](./recognition_en.md)
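The released weights linked in the tables can be checked directly against the benchmark sets with the standard evaluation entry point. A sketch, assuming the tar unpacks to ./rec_mtb_nrtr_train/ and contains the usual best_accuracy checkpoint:

    wget https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_mtb_nrtr_train.tar
    tar xf rec_mtb_nrtr_train.tar
    python3 tools/eval.py -c configs/rec/rec_mtb_nrtr.yml -o Global.checkpoints=./rec_mtb_nrtr_train/best_accuracy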
diff --git a/doc/doc_en/recognition_en.md b/doc/doc_en/recognition_en.md
index e23166e0..7a5e827d 100644
--- a/doc/doc_en/recognition_en.md
+++ b/doc/doc_en/recognition_en.md
@@ -207,7 +207,7 @@ If the evaluation set is large, the test will be time-consuming. It is recommend
 | rec_mv3_tps_bilstm_att.yml | CRNN | Mobilenet_v3 | TPS | BiLSTM | att |
 | rec_r34_vd_tps_bilstm_att.yml | CRNN | Resnet34_vd | TPS | BiLSTM | att |
 | rec_r50fpn_vd_none_srn.yml | SRN | Resnet50_fpn_vd | None | rnn | srn |
-
+| rec_mtb_nrtr.yml | NRTR | nrtr_mtb | None | transformer encoder | transformer decoder |
 
 For training Chinese data, it is recommended to use [rec_chinese_lite_train_v2.0.yml](../../configs/rec/ch_ppocr_v2.0/rec_chinese_lite_train_v2.0.yml). If you want to try the result of other algorithms on the Chinese data set, please refer to the following instructions to modify the configuration file:
 
diff --git a/ppocr/data/imaug/__init__.py b/ppocr/data/imaug/__init__.py
index a808fd58..9f175382 100644
--- a/ppocr/data/imaug/__init__.py
+++ b/ppocr/data/imaug/__init__.py
@@ -21,7 +21,7 @@ from .make_border_map import MakeBorderMap
 from .make_shrink_map import MakeShrinkMap
 from .random_crop_data import EastRandomCropData, PSERandomCrop
 
-from .rec_img_aug import RecAug, RecResizeImg, ClsResizeImg, SRNRecResizeImg
+from .rec_img_aug import RecAug, RecResizeImg, ClsResizeImg, SRNRecResizeImg, PILResize, CVResize
 from .randaugment import RandAugment
 from .operators import *
 from .label_ops import *
 
diff --git a/ppocr/data/imaug/label_ops.py b/ppocr/data/imaug/label_ops.py
index 39ff8930..a233738c 100644
--- a/ppocr/data/imaug/label_ops.py
+++ b/ppocr/data/imaug/label_ops.py
@@ -96,7 +96,7 @@ class BaseRecLabelEncode(object):
             'ch', 'en', 'EN_symbol', 'french', 'german', 'japan', 'korean',
             'EN', 'it', 'xi', 'pu', 'ru', 'ar', 'ta', 'ug', 'fa', 'ur', 'rs',
             'oc', 'rsc', 'bg', 'uk', 'be', 'te', 'ka', 'chinese_cht', 'hi',
-            'mr', 'ne', 'latin', 'arabic', 'cyrillic', 'devanagari'
+            'mr', 'ne', 'latin', 'arabic', 'cyrillic', 'devanagari', 'dict_99'
         ]
         assert character_type in support_character_type, "Only {} are supported now but get {}".format(
             support_character_type, character_type)
 
diff --git a/ppocr/data/imaug/operators.py b/ppocr/data/imaug/operators.py
index 9c48b096..950c9377 100644
--- a/ppocr/data/imaug/operators.py
+++ b/ppocr/data/imaug/operators.py
@@ -57,6 +57,38 @@ class DecodeImage(object):
         return data
 
 
+class NRTRDecodeImage(object):
+    """ decode an image buffer to a single-channel image for NRTR """
+
+    def __init__(self, img_mode='RGB', channel_first=False, **kwargs):
+        self.img_mode = img_mode
+        self.channel_first = channel_first
+
+    def __call__(self, data):
+        img = data['image']
+        if six.PY2:
+            assert type(img) is str and len(
+                img) > 0, "invalid input 'img' in NRTRDecodeImage"
+        else:
+            assert type(img) is bytes and len(
+                img) > 0, "invalid input 'img' in NRTRDecodeImage"
+        img = np.frombuffer(img, dtype='uint8')
+        img = cv2.imdecode(img, 1)  # always decodes to 3-channel BGR
+        if img is None:
+            return None
+        if self.img_mode == 'GRAY':
+            img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
+        elif self.img_mode == 'RGB':
+            assert img.shape[2] == 3, 'invalid shape of image[%s]' % (img.shape)
+            img = img[:, :, ::-1]
+            # NRTR consumes grayscale input regardless of the decode mode
+            img = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
+        if self.channel_first:
+            img = img.transpose((2, 0, 1))
+        data['image'] = img
+        return data
+
+
 class NormalizeImage(object):
     """ normalize image such as substract mean, divide std
     """
 
diff --git a/ppocr/data/imaug/rec_img_aug.py b/ppocr/data/imaug/rec_img_aug.py
index 28e6bd0b..13a5c71d 100644
--- a/ppocr/data/imaug/rec_img_aug.py
+++ b/ppocr/data/imaug/rec_img_aug.py
@@ -16,7 +16,7 @@ import math
 import cv2
 import numpy as np
 import random
-
+from PIL import Image
 from .text_image_aug import tia_perspective, tia_stretch, tia_distort
 
 
@@ -42,6 +42,34 @@ class ClsResizeImg(object):
         data['image'] = norm_img
         return data
 
 
+class PILResize(object):
+    def __init__(self, image_shape, **kwargs):
+        # image_shape is (w, h), the order PIL's resize expects
+        self.image_shape = image_shape
+
+    def __call__(self, data):
+        img = data['image']
+        image_pil = Image.fromarray(np.uint8(img))
+        norm_img = image_pil.resize(self.image_shape, Image.ANTIALIAS)
+        norm_img = np.array(norm_img)
+        norm_img = np.expand_dims(norm_img, -1)
+        norm_img = norm_img.transpose((2, 0, 1))
+        # scale pixel values from [0, 255] to [-1, 1)
+        data['image'] = norm_img.astype(np.float32) / 128. - 1.
+        return data
+
+
+class CVResize(object):
+    def __init__(self, image_shape, **kwargs):
+        # image_shape is (w, h), the order cv2.resize expects for dsize
+        self.image_shape = image_shape
+
+    def __call__(self, data):
+        img = data['image']
+        norm_img = cv2.resize(img, tuple(self.image_shape))
+        norm_img = np.expand_dims(norm_img, -1)
+        norm_img = norm_img.transpose((2, 0, 1))
+        # scale pixel values from [0, 255] to [-1, 1)
+        data['image'] = norm_img.astype(np.float32) / 128. - 1.
+        return data
+
 
 class RecResizeImg(object):
     def __init__(self,
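Both resize ops share one output contract: a (1, h, w) float32 tensor scaled to [-1, 1). A minimal sketch of the cv2 path, assuming the config passes image_shape: [100, 32] (width first, matching the dsize convention):

    import cv2
    import numpy as np

    img = np.random.randint(0, 256, (48, 160), dtype=np.uint8)  # an h x w grayscale crop
    norm = cv2.resize(img, (100, 32))                           # dsize is (w, h) -> 32 x 100 result
    norm = norm[None, :, :].astype(np.float32) / 128. - 1.      # (1, 32, 100), values in [-1, 1)
    print(norm.shape, float(norm.min()), float(norm.max()))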
diff --git a/ppocr/losses/__init__.py b/ppocr/losses/__init__.py
index bf10d298..e1c3ed95 100755
--- a/ppocr/losses/__init__.py
+++ b/ppocr/losses/__init__.py
@@ -25,7 +25,7 @@ from .det_sast_loss import SASTLoss
 from .rec_ctc_loss import CTCLoss
 from .rec_att_loss import AttentionLoss
 from .rec_srn_loss import SRNLoss
-
+from .rec_nrtr_loss import NRTRLoss
 # cls loss
 from .cls_loss import ClsLoss
 
@@ -42,8 +42,8 @@ from .combined_loss import CombinedLoss
 def build_loss(config):
     support_dict = [
         'DBLoss', 'EASTLoss', 'SASTLoss', 'CTCLoss', 'ClsLoss', 'AttentionLoss',
-        'SRNLoss', 'PGLoss', 'CombinedLoss'
+        'SRNLoss', 'PGLoss', 'CombinedLoss', 'NRTRLoss'
     ]
 
     config = copy.deepcopy(config)
     module_name = config.pop('name')
     assert module_name in support_dict, Exception('loss only support {}'.format(
 
diff --git a/ppocr/losses/rec_nrtr_loss.py b/ppocr/losses/rec_nrtr_loss.py
new file mode 100644
index 00000000..915f506d
--- /dev/null
+++ b/ppocr/losses/rec_nrtr_loss.py
@@ -0,0 +1,38 @@
+import paddle
+from paddle import nn
+import paddle.nn.functional as F
+
+
+def cal_performance(pred, tgt):
+    # count correct predictions on non-padding positions (index 0 is padding)
+    pred = pred.argmax(axis=1)
+    tgt = tgt.reshape([-1])
+    non_pad_mask = tgt != 0
+    n_correct = (pred == tgt).masked_select(non_pad_mask).astype('int64').sum().item()
+    return n_correct
+
+
+class NRTRLoss(nn.Layer):
+    def __init__(self, smoothing=True, **kwargs):
+        super(NRTRLoss, self).__init__()
+        self.loss_func = nn.CrossEntropyLoss(reduction='mean', ignore_index=0)
+        self.smoothing = smoothing
+
+    def forward(self, pred, batch):
+        pred = pred.reshape([-1, pred.shape[2]])
+        max_len = batch[2].max()
+        tgt = batch[1][:, 1:2 + max_len]
+        tgt = tgt.reshape([-1])
+        if self.smoothing:
+            eps = 0.1
+            n_class = pred.shape[1]
+            # label smoothing: 1 - eps on the target class, eps spread over the rest
+            one_hot = F.one_hot(tgt, pred.shape[1])
+            one_hot = one_hot * (1 - eps) + (1 - one_hot) * eps / (n_class - 1)
+            log_prb = F.log_softmax(pred, axis=1)
+            non_pad_mask = paddle.not_equal(
+                tgt, paddle.zeros(tgt.shape, dtype='int64'))
+            loss = -(one_hot * log_prb).sum(axis=1)
+            loss = loss.masked_select(non_pad_mask).mean()
+        else:
+            loss = self.loss_func(pred, tgt)
+        return {'loss': loss}
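The smoothing branch is standard label smoothing with eps = 0.1: the target class keeps probability 1 - eps and the remaining eps is spread evenly over the other n_class - 1 classes, after which padding positions (index 0) are masked out. A toy check of the same arithmetic, with hypothetical sizes and no padding tokens:

    import paddle
    import paddle.nn.functional as F

    eps, n_class = 0.1, 5
    pred = paddle.randn([4, n_class])                    # logits for 4 tokens
    tgt = paddle.to_tensor([1, 2, 3, 4], dtype='int64')

    one_hot = F.one_hot(tgt, n_class)
    one_hot = one_hot * (1 - eps) + (1 - one_hot) * eps / (n_class - 1)
    # each row now holds 0.9 at the target index and 0.025 elsewhere
    loss = -(one_hot * F.log_softmax(pred, axis=1)).sum(axis=1).mean()
    print(float(loss))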
diff --git a/ppocr/metrics/rec_metric.py b/ppocr/metrics/rec_metric.py
index 66c084d7..3712e6e9 100644
--- a/ppocr/metrics/rec_metric.py
+++ b/ppocr/metrics/rec_metric.py
@@ -30,7 +30,7 @@ class RecMetric(object):
             target = target.replace(" ", "")
             norm_edit_dis += Levenshtein.distance(pred, target) / max(
                 len(pred), len(target), 1)
-            if pred == target:
+            if pred.lower() == target.lower():
                 correct_num += 1
             all_num += 1
         self.correct_num += correct_num
@@ -48,8 +48,8 @@ class RecMetric(object):
                 'norm_edit_dis': 0,
             }
         """
-        acc = 1.0 * self.correct_num / self.all_num
-        norm_edit_dis = 1 - self.norm_edit_dis / self.all_num
+        acc = 1.0 * self.correct_num / max(self.all_num, 1)
+        norm_edit_dis = 1 - self.norm_edit_dis / max(self.all_num, 1)
         self.reset()
         return {'acc': acc, 'norm_edit_dis': norm_edit_dis}
 
diff --git a/ppocr/modeling/architectures/base_model.py b/ppocr/modeling/architectures/base_model.py
index 4c941fcf..66da4b33 100644
--- a/ppocr/modeling/architectures/base_model.py
+++ b/ppocr/modeling/architectures/base_model.py
@@ -14,7 +14,7 @@
 from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
-
+import paddle
 from paddle import nn
 from ppocr.modeling.transforms import build_transform
 from ppocr.modeling.backbones import build_backbone
 
diff --git a/ppocr/modeling/backbones/__init__.py b/ppocr/modeling/backbones/__init__.py
index fe2c9bc3..73afbe11 100755
--- a/ppocr/modeling/backbones/__init__.py
+++ b/ppocr/modeling/backbones/__init__.py
@@ -25,7 +25,8 @@ def build_backbone(config, model_type):
         from .rec_mobilenet_v3 import MobileNetV3
         from .rec_resnet_vd import ResNet
         from .rec_resnet_fpn import ResNetFPN
-        support_dict = ['MobileNetV3', 'ResNet', 'ResNetFPN']
+        from .rec_nrtr_mtb import MTB
+        support_dict = ['MobileNetV3', 'ResNet', 'ResNetFPN', 'MTB']
     elif model_type == 'e2e':
         from .e2e_resnet_vd_pg import ResNet
         support_dict = ['ResNet']
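The registry is name-based, following the same pop-and-dispatch pattern as build_loss above: the 'name' key selects the class from support_dict and the remaining keys become constructor kwargs. A sketch of how the new MTB entry is reached, with hypothetical values matching the NRTR config:

    from ppocr.modeling.backbones import build_backbone

    config = {'name': 'MTB', 'cnn_num': 2, 'in_channels': 1}
    backbone = build_backbone(config, model_type='rec')  # returns an MTB instance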
diff --git a/ppocr/modeling/backbones/multiheadAttention.py b/ppocr/modeling/backbones/multiheadAttention.py
new file mode 100755
index 00000000..f18e9957
--- /dev/null
+++ b/ppocr/modeling/backbones/multiheadAttention.py
@@ -0,0 +1,365 @@
+import paddle
+from paddle import nn
+import paddle.nn.functional as F
+from paddle.nn import Linear
+from paddle.nn.initializer import XavierUniform as xavier_uniform_
+from paddle.nn.initializer import Constant as constant_
+from paddle.nn.initializer import XavierNormal as xavier_normal_
+
+zeros_ = constant_(value=0.)
+ones_ = constant_(value=1.)
+
+
+class MultiheadAttention(nn.Layer):
+    r"""Allows the model to jointly attend to information
+    from different representation subspaces.
+    See reference: Attention Is All You Need
+
+    .. math::
+        \text{MultiHead}(Q, K, V) = \text{Concat}(head_1, \dots, head_h)W^O
+        \text{where } head_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V)
+
+    Args:
+        embed_dim: total dimension of the model
+        num_heads: parallel attention layers, or heads
+
+    Examples::
+
+        >>> multihead_attn = MultiheadAttention(embed_dim, num_heads)
+        >>> attn_output, attn_output_weights = multihead_attn(query, key, value)
+    """
+
+    def __init__(self,
+                 embed_dim,
+                 num_heads,
+                 dropout=0.,
+                 bias=True,
+                 add_bias_kv=False,
+                 add_zero_attn=False):
+        super(MultiheadAttention, self).__init__()
+        self.embed_dim = embed_dim
+        self.num_heads = num_heads
+        self.dropout = dropout
+        self.head_dim = embed_dim // num_heads
+        assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"
+        self.scaling = self.head_dim**-0.5
+
+        self.out_proj = Linear(embed_dim, embed_dim, bias_attr=bias)
+
+        if add_bias_kv:
+            self.bias_k = self.create_parameter(
+                shape=(1, 1, embed_dim), default_initializer=zeros_)
+            self.add_parameter("bias_k", self.bias_k)
+            self.bias_v = self.create_parameter(
+                shape=(1, 1, embed_dim), default_initializer=zeros_)
+            self.add_parameter("bias_v", self.bias_v)
+        else:
+            self.bias_k = self.bias_v = None
+
+        self.add_zero_attn = add_zero_attn
+
+        self._reset_parameters()
+
+        # fused input projections implemented as 1x1 convolutions:
+        # conv1 projects q, conv2 projects k and v jointly, conv3 projects q, k and v jointly
+        self.conv1 = paddle.nn.Conv2D(
+            in_channels=embed_dim, out_channels=embed_dim, kernel_size=(1, 1))
+        self.conv2 = paddle.nn.Conv2D(
+            in_channels=embed_dim, out_channels=embed_dim * 2, kernel_size=(1, 1))
+        self.conv3 = paddle.nn.Conv2D(
+            in_channels=embed_dim, out_channels=embed_dim * 3, kernel_size=(1, 1))
+
+    def _reset_parameters(self):
+        xavier_uniform_(self.out_proj.weight)
+        if self.bias_k is not None:
+            xavier_normal_(self.bias_k)
+        if self.bias_v is not None:
+            xavier_normal_(self.bias_v)
+
+    def forward(self,
+                query,
+                key,
+                value,
+                key_padding_mask=None,
+                incremental_state=None,
+                need_weights=True,
+                static_kv=False,
+                attn_mask=None,
+                qkv_=(False, False, False)):
+        """
+        Inputs of forward function
+            query: [target length, batch size, embed dim]
+            key: [sequence length, batch size, embed dim]
+            value: [sequence length, batch size, embed dim]
+            key_padding_mask: if True, mask padding based on batch size
+            incremental_state: if provided, previous time steps are cached
+            need_weights: output attn_output_weights
+            static_kv: key and value are static
+
+        Outputs of forward function
+            attn_output: [target length, batch size, embed dim]
+            attn_output_weights: [batch size, target length, sequence length]
+        """
+        qkv_same = qkv_[0]
+        kv_same = qkv_[1]
+
+        tgt_len, bsz, embed_dim = query.shape
+        assert embed_dim == self.embed_dim
+        assert list(query.shape) == [tgt_len, bsz, embed_dim]
+        assert key.shape == value.shape
+
+        if qkv_same:
+            # self-attention: one fused projection yields q, k and v
+            q, k, v = self._in_proj_qkv(query)
+        elif kv_same:
+            # encoder-decoder attention: q from the query, k and v from the key
+            q = self._in_proj_q(query)
+            if key is None:
+                assert value is None
+                k = v = None
+            else:
+                k, v = self._in_proj_kv(key)
+        else:
+            q = self._in_proj_q(query)
+            k = self._in_proj_k(key)
+            v = self._in_proj_v(value)
+        q *= self.scaling
+        if self.bias_k is not None:
+            assert self.bias_v is not None
+            # tile the learned biases across the batch without overwriting the parameters
+            bias_k = paddle.concat([self.bias_k for _ in range(bsz)], axis=1)
+            bias_v = paddle.concat([self.bias_v for _ in range(bsz)], axis=1)
+            k = paddle.concat([k, bias_k])
+            v = paddle.concat([v, bias_v])
+            if attn_mask is not None:
+                attn_mask = paddle.concat(
+                    [attn_mask, paddle.zeros([attn_mask.shape[0], 1], dtype=attn_mask.dtype)],
+                    axis=1)
+            if key_padding_mask is not None:
+                key_padding_mask = paddle.concat(
+                    [key_padding_mask,
+                     paddle.zeros([key_padding_mask.shape[0], 1], dtype=key_padding_mask.dtype)],
+                    axis=1)
+
+        q = q.reshape([tgt_len, bsz * self.num_heads, self.head_dim]).transpose([1, 0, 2])
+        if k is not None:
+            k = k.reshape([-1, bsz * self.num_heads, self.head_dim]).transpose([1, 0, 2])
+        if v is not None:
+            v = v.reshape([-1, bsz * self.num_heads, self.head_dim]).transpose([1, 0, 2])
+
+        src_len = k.shape[1]
+
+        if key_padding_mask is not None:
+            assert key_padding_mask.shape[0] == bsz
+            assert key_padding_mask.shape[1] == src_len
+
+        if self.add_zero_attn:
+            src_len += 1
+            k = paddle.concat(
+                [k, paddle.zeros([k.shape[0], 1] + k.shape[2:], dtype=k.dtype)], axis=1)
+            v = paddle.concat(
+                [v, paddle.zeros([v.shape[0], 1] + v.shape[2:], dtype=v.dtype)], axis=1)
+            if attn_mask is not None:
+                attn_mask = paddle.concat(
+                    [attn_mask, paddle.zeros([attn_mask.shape[0], 1], dtype=attn_mask.dtype)],
+                    axis=1)
+            if key_padding_mask is not None:
+                key_padding_mask = paddle.concat(
+                    [key_padding_mask,
+                     paddle.zeros([key_padding_mask.shape[0], 1], dtype=key_padding_mask.dtype)],
+                    axis=1)
+
+        attn_output_weights = paddle.bmm(q, k.transpose([0, 2, 1]))
+        assert list(attn_output_weights.shape) == [bsz * self.num_heads, tgt_len, src_len]
+
+        if attn_mask is not None:
+            attn_mask = attn_mask.unsqueeze(0)
+            attn_output_weights += attn_mask
+        if key_padding_mask is not None:
+            attn_output_weights = attn_output_weights.reshape(
+                [bsz, self.num_heads, tgt_len, src_len])
+            # broadcast -inf into every masked key position
+            key = key_padding_mask.unsqueeze(1).unsqueeze(2).astype('float32')
+            y = paddle.full(shape=key.shape, dtype='float32', fill_value=float('-inf'))
+            y = paddle.where(key == 0., key, y)
+            attn_output_weights += y
+            attn_output_weights = attn_output_weights.reshape(
+                [bsz * self.num_heads, tgt_len, src_len])
+
+        attn_output_weights = F.softmax(attn_output_weights.astype('float32'), axis=-1)
+        attn_output_weights = F.dropout(
+            attn_output_weights, p=self.dropout, training=self.training)
+
+        attn_output = paddle.bmm(attn_output_weights, v)
+        assert list(attn_output.shape) == [bsz * self.num_heads, tgt_len, self.head_dim]
+        attn_output = attn_output.transpose([1, 0, 2]).reshape([tgt_len, bsz, embed_dim])
+        attn_output = self.out_proj(attn_output)
+        if need_weights:
+            # average attention weights over heads
+            attn_output_weights = attn_output_weights.reshape(
+                [bsz, self.num_heads, tgt_len, src_len])
+            attn_output_weights = attn_output_weights.sum(axis=1) / self.num_heads
+        else:
+            attn_output_weights = None
+
+        return attn_output, attn_output_weights
+
+    def _in_proj_qkv(self, query):
+        # fused q/k/v projection: run conv3 over [batch, embed, 1, length], then split
+        query = query.transpose([1, 2, 0])
+        query = paddle.unsqueeze(query, axis=2)
+        res = self.conv3(query)
+        res = paddle.squeeze(res, axis=2)
+        res = res.transpose([2, 0, 1])
+        return res.chunk(3, axis=-1)
+
+    def _in_proj_kv(self, key):
+        key = key.transpose([1, 2, 0])
+        key = paddle.unsqueeze(key, axis=2)
+        res = self.conv2(key)
+        res = paddle.squeeze(res, axis=2)
+        res = res.transpose([2, 0, 1])
+        return res.chunk(2, axis=-1)
+
+    def _in_proj_q(self, query):
+        query = query.transpose([1, 2, 0])
+        query = paddle.unsqueeze(query, axis=2)
+        res = self.conv1(query)
+        res = paddle.squeeze(res, axis=2)
+        res = res.transpose([2, 0, 1])
+        return res
+
+    def _in_proj_k(self, key):
+        # use the key half of the fused kv projection so this path
+        # stays consistent with _in_proj_kv
+        key = key.transpose([1, 2, 0])
+        key = paddle.unsqueeze(key, axis=2)
+        res = self.conv2(key)
+        res = paddle.squeeze(res, axis=2)
+        res = res.transpose([2, 0, 1])
+        return res.chunk(2, axis=-1)[0]
+    def _in_proj_v(self, value):
+        # use the value half of the fused kv projection
+        value = value.transpose([1, 2, 0])
+        value = paddle.unsqueeze(value, axis=2)
+        res = self.conv2(value)
+        res = paddle.squeeze(res, axis=2)
+        res = res.transpose([2, 0, 1])
+        return res.chunk(2, axis=-1)[1]
+
+
+class MultiheadAttentionOptim(nn.Layer):
+    r"""Allows the model to jointly attend to information
+    from different representation subspaces.
+    See reference: Attention Is All You Need
+
+    .. math::
+        \text{MultiHead}(Q, K, V) = \text{Concat}(head_1, \dots, head_h)W^O
+        \text{where } head_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V)
+
+    Args:
+        embed_dim: total dimension of the model
+        num_heads: parallel attention layers, or heads
+
+    Examples::
+
+        >>> multihead_attn = MultiheadAttentionOptim(embed_dim, num_heads)
+        >>> attn_output, attn_output_weights = multihead_attn(query, key, value)
+    """
+
+    def __init__(self,
+                 embed_dim,
+                 num_heads,
+                 dropout=0.,
+                 bias=True,
+                 add_bias_kv=False,
+                 add_zero_attn=False):
+        super(MultiheadAttentionOptim, self).__init__()
+        self.embed_dim = embed_dim
+        self.num_heads = num_heads
+        self.dropout = dropout
+        self.head_dim = embed_dim // num_heads
+        assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"
+        self.scaling = self.head_dim**-0.5
+
+        self.out_proj = Linear(embed_dim, embed_dim, bias_attr=bias)
+
+        self._reset_parameters()
+
+        # separate 1x1-convolution projections for q, k and v
+        self.conv1 = paddle.nn.Conv2D(
+            in_channels=embed_dim, out_channels=embed_dim, kernel_size=(1, 1))
+        self.conv2 = paddle.nn.Conv2D(
+            in_channels=embed_dim, out_channels=embed_dim, kernel_size=(1, 1))
+        self.conv3 = paddle.nn.Conv2D(
+            in_channels=embed_dim, out_channels=embed_dim, kernel_size=(1, 1))
+
+    def _reset_parameters(self):
+        xavier_uniform_(self.out_proj.weight)
+
+    def forward(self,
+                query,
+                key,
+                value,
+                key_padding_mask=None,
+                incremental_state=None,
+                need_weights=True,
+                static_kv=False,
+                attn_mask=None):
+        """
+        Inputs of forward function
+            query: [target length, batch size, embed dim]
+            key: [sequence length, batch size, embed dim]
+            value: [sequence length, batch size, embed dim]
+            key_padding_mask: if True, mask padding based on batch size
+            incremental_state: if provided, previous time steps are cached
+            need_weights: output attn_output_weights
+            static_kv: key and value are static
+
+        Outputs of forward function
+            attn_output: [target length, batch size, embed dim]
+            attn_output_weights: [batch size, target length, sequence length]
+        """
+        tgt_len, bsz, embed_dim = query.shape
+        assert embed_dim == self.embed_dim
+        assert list(query.shape) == [tgt_len, bsz, embed_dim]
+        assert key.shape == value.shape
+
+        q = self._in_proj_q(query)
+        k = self._in_proj_k(key)
+        v = self._in_proj_v(value)
+        q *= self.scaling
+
+        q = q.reshape([tgt_len, bsz * self.num_heads, self.head_dim]).transpose([1, 0, 2])
+        k = k.reshape([-1, bsz * self.num_heads, self.head_dim]).transpose([1, 0, 2])
+        v = v.reshape([-1, bsz * self.num_heads, self.head_dim]).transpose([1, 0, 2])
+
+        src_len = k.shape[1]
+
+        if key_padding_mask is not None:
+            assert key_padding_mask.shape[0] == bsz
+            assert key_padding_mask.shape[1] == src_len
+
+        attn_output_weights = paddle.bmm(q, k.transpose([0, 2, 1]))
+        assert list(attn_output_weights.shape) == [bsz * self.num_heads, tgt_len, src_len]
+
+        if attn_mask is not None:
+            attn_mask = attn_mask.unsqueeze(0)
+            attn_output_weights += attn_mask
+        if key_padding_mask is not None:
+            attn_output_weights = attn_output_weights.reshape(
+                [bsz, self.num_heads, tgt_len, src_len])
+            # broadcast -inf into every masked key position
+            key = key_padding_mask.unsqueeze(1).unsqueeze(2).astype('float32')
+            y = paddle.full(shape=key.shape, dtype='float32', fill_value=float('-inf'))
+            y = paddle.where(key == 0., key, y)
+            attn_output_weights += y
+            attn_output_weights = attn_output_weights.reshape(
+                [bsz * self.num_heads, tgt_len, src_len])
+
+        attn_output_weights = F.softmax(attn_output_weights.astype('float32'), axis=-1)
+        attn_output_weights = F.dropout(
+            attn_output_weights, p=self.dropout, training=self.training)
+
+        attn_output = paddle.bmm(attn_output_weights, v)
+        assert list(attn_output.shape) == [bsz * self.num_heads, tgt_len, self.head_dim]
+        attn_output = attn_output.transpose([1, 0, 2]).reshape([tgt_len, bsz, embed_dim])
+        attn_output = self.out_proj(attn_output)
+
+        if need_weights:
+            # average attention weights over heads
+            attn_output_weights = attn_output_weights.reshape(
+                [bsz, self.num_heads, tgt_len, src_len])
+            attn_output_weights = attn_output_weights.sum(axis=1) / self.num_heads
+        else:
+            attn_output_weights = None
+
+        return attn_output, attn_output_weights
+
+    def _in_proj_q(self, query):
+        # [length, batch, embed] -> [batch, embed, 1, length] for the 1x1 conv, then back
+        query = query.transpose([1, 2, 0])
+        query = paddle.unsqueeze(query, axis=2)
+        res = self.conv1(query)
+        res = paddle.squeeze(res, axis=2)
+        res = res.transpose([2, 0, 1])
+        return res
+
+    def _in_proj_k(self, key):
+        key = key.transpose([1, 2, 0])
+        key = paddle.unsqueeze(key, axis=2)
+        res = self.conv2(key)
+        res = paddle.squeeze(res, axis=2)
+        res = res.transpose([2, 0, 1])
+        return res
+
+    def _in_proj_v(self, value):
+        value = value.transpose([1, 2, 0])
+        value = paddle.unsqueeze(value, axis=2)
+        res = self.conv3(value)
+        res = paddle.squeeze(res, axis=2)
+        res = res.transpose([2, 0, 1])
+        return res
\ No newline at end of file
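A smoke test for the optimized variant, using the [length, batch, embed] layout from the docstring (hypothetical sizes; no masks):

    import paddle
    from ppocr.modeling.backbones.multiheadAttention import MultiheadAttentionOptim

    attn = MultiheadAttentionOptim(embed_dim=512, num_heads=8)
    q = paddle.randn([25, 4, 512])    # [target length, batch, embed dim]
    kv = paddle.randn([25, 4, 512])   # [sequence length, batch, embed dim]
    out, weights = attn(q, kv, kv)
    print(out.shape)      # [25, 4, 512]
    print(weights.shape)  # [4, 25, 25], averaged over heads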
diff --git a/ppocr/modeling/backbones/rec_nrtr_mtb.py b/ppocr/modeling/backbones/rec_nrtr_mtb.py
new file mode 100644
index 00000000..26a0dc7f
--- /dev/null
+++ b/ppocr/modeling/backbones/rec_nrtr_mtb.py
@@ -0,0 +1,28 @@
+from paddle import nn
+
+
+class MTB(nn.Layer):
+    def __init__(self, cnn_num, in_channels):
+        super(MTB, self).__init__()
+        self.block = nn.Sequential()
+        self.out_channels = in_channels
+        self.cnn_num = cnn_num
+        if self.cnn_num == 2:
+            # two stride-2 convolutions: 1/4 of the input height and width remain
+            for i in range(self.cnn_num):
+                self.block.add_sublayer(
+                    'conv_{}'.format(i),
+                    nn.Conv2D(
+                        in_channels=in_channels if i == 0 else 32 * (2**(i - 1)),
+                        out_channels=32 * (2**i),
+                        kernel_size=3,
+                        stride=2,
+                        padding=1))
+                self.block.add_sublayer('relu_{}'.format(i), nn.ReLU())
+                self.block.add_sublayer('bn_{}'.format(i), nn.BatchNorm2D(32 * (2**i)))
+
+    def forward(self, images):
+        x = self.block(images)
+        if self.cnn_num == 2:
+            # (b, c, h, w) -> (b, w, h, c) -> (b, w, h*c): one feature vector per image column
+            x = x.transpose([0, 3, 2, 1])
+            x_shape = x.shape
+            x = x.reshape([x_shape[0], x_shape[1], x_shape[2] * x_shape[3]])
+        return x
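For the NRTR input size (1-channel 32 x 100 crops) the two stride-2 convolutions leave a 64 x 8 x 25 feature map, and the transpose/reshape turns each of the 25 remaining image columns into one 512-dim step of the encoder sequence. A sketch:

    import paddle
    from ppocr.modeling.backbones.rec_nrtr_mtb import MTB

    mtb = MTB(cnn_num=2, in_channels=1)
    x = paddle.randn([2, 1, 32, 100])  # (b, c, h, w)
    seq = mtb(x)
    print(seq.shape)  # [2, 25, 512]: 25 time steps of 8 * 64 features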
diff --git a/ppocr/modeling/heads/rec_nrtr_optim_head.py b/ppocr/modeling/heads/rec_nrtr_optim_head.py
index b9a5100a..1537b0ca 100644
--- a/ppocr/modeling/heads/rec_nrtr_optim_head.py
+++ b/ppocr/modeling/heads/rec_nrtr_optim_head.py
@@ -7,7 +7,7 @@ from paddle.nn import LayerList
 from paddle.nn.initializer import XavierNormal as xavier_uniform_
 from paddle.nn import Dropout, Linear, LayerNorm, Conv2D
 import numpy as np
-from ppocr.modeling.heads.multiheadAttention import MultiheadAttentionOptim
+from ppocr.modeling.backbones.multiheadAttention import MultiheadAttentionOptim
 from paddle.nn.initializer import Constant as constant_
 from paddle.nn.initializer import XavierNormal as xavier_normal_
 
diff --git a/ppocr/modeling/necks/__init__.py b/ppocr/modeling/necks/__init__.py
index 37a5cf78..1be38e93 100644
--- a/ppocr/modeling/necks/__init__.py
+++ b/ppocr/modeling/necks/__init__.py
@@ -21,7 +21,7 @@ def build_neck(config):
     from .sast_fpn import SASTFPN
     from .rnn import SequenceEncoder
     from .pg_fpn import PGFPN
-    support_dict = ['DBFPN', 'EASTFPN', 'SASTFPN', 'SequenceEncoder', 'PGFPN']
+    support_dict = ['DBFPN', 'EASTFPN', 'SASTFPN', 'SequenceEncoder', 'PGFPN', 'TFEncoder']
 
     module_name = config.pop('name')
     assert module_name in support_dict, Exception('neck only support {}'.format(
 
diff --git a/ppocr/postprocess/__init__.py b/ppocr/postprocess/__init__.py
index cd2b7ea7..f7f1bcd6 100644
--- a/ppocr/postprocess/__init__.py
+++ b/ppocr/postprocess/__init__.py
@@ -24,16 +24,15 @@ __all__ = ['build_post_process']
 from .db_postprocess import DBPostProcess
 from .east_postprocess import EASTPostProcess
 from .sast_postprocess import SASTPostProcess
-from .rec_postprocess import CTCLabelDecode, AttnLabelDecode, SRNLabelDecode, DistillationCTCLabelDecode
+from .rec_postprocess import CTCLabelDecode, AttnLabelDecode, SRNLabelDecode, DistillationCTCLabelDecode, NRTRLabelDecode
 from .cls_postprocess import ClsPostProcess
 from .pg_postprocess import PGPostProcess
 
-
 def build_post_process(config, global_config=None):
     support_dict = [
         'DBPostProcess', 'EASTPostProcess', 'SASTPostProcess', 'CTCLabelDecode',
         'AttnLabelDecode', 'ClsPostProcess', 'SRNLabelDecode', 'PGPostProcess',
-        'DistillationCTCLabelDecode'
+        'DistillationCTCLabelDecode', 'NRTRLabelDecode'
     ]
 
     config = copy.deepcopy(config)
 
diff --git a/ppocr/postprocess/rec_postprocess.py b/ppocr/postprocess/rec_postprocess.py
index 7350b4ec..e0f3b740 100644
--- a/ppocr/postprocess/rec_postprocess.py
+++ b/ppocr/postprocess/rec_postprocess.py
@@ -28,7 +28,7 @@ class BaseRecLabelDecode(object):
             'ch', 'en', 'EN_symbol', 'french', 'german', 'japan', 'korean',
             'it', 'xi', 'pu', 'ru', 'ar', 'ta', 'ug', 'fa', 'ur', 'rs',
             'oc', 'rsc', 'bg', 'uk', 'be', 'te', 'ka', 'chinese_cht', 'hi', 'mr',
-            'ne', 'EN', 'latin', 'arabic', 'cyrillic', 'devanagari'
+            'ne', 'EN', 'latin', 'arabic', 'cyrillic', 'devanagari', 'dict_99'
         ]
         assert character_type in support_character_type, "Only {} are supported now but get {}".format(
             support_character_type, character_type)
@@ -256,8 +256,7 @@ class AttnLabelDecode(BaseRecLabelDecode):
                 if idx > 0 and text_index[batch_idx][idx - 1] == text_index[
                         batch_idx][idx]:
                     continue
-                char_list.append(self.character[int(text_index[batch_idx][
-                    idx])])
+                char_list.append(self.character[int(text_index[batch_idx][idx])])
                 if text_prob is not None:
                     conf_list.append(text_prob[batch_idx][idx])
                 else:
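The 'dict_99' character type is backed by ppocr/utils/dict_99.txt below: the 95 printable ASCII characters, one per line. How the indices line up with the loss is an assumption worth stating explicitly: NRTRLoss uses ignore_index=0, so index 0 is treated as padding and real characters start above it. A hedged sketch of the resulting mapping (the exact special-token layout is defined by the label encoder, not by this file):

    # hypothetical index map for a dict-file character set; the +1 offset assumes
    # only index 0 is reserved (for padding), matching ignore_index=0 above
    with open('ppocr/utils/dict_99.txt', 'rb') as f:
        chars = [line.decode('utf-8').strip('\n') for line in f]
    char_to_idx = {c: i + 1 for i, c in enumerate(chars)}
    print(len(chars))  # 95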
diff --git a/ppocr/utils/dict_99.txt b/ppocr/utils/dict_99.txt
new file mode 100644
index 00000000..e00863bf
--- /dev/null
+++ b/ppocr/utils/dict_99.txt
@@ -0,0 +1,95 @@
+!
+"
+#
+$
+%
+&
+'
+(
+)
+*
++
+,
+-
+.
+/
+0
+1
+2
+3
+4
+5
+6
+7
+8
+9
+:
+;
+<
+=
+>
+?
+@
+A
+B
+C
+D
+E
+F
+G
+H
+I
+J
+K
+L
+M
+N
+O
+P
+Q
+R
+S
+T
+U
+V
+W
+X
+Y
+Z
+[
+\
+]
+^
+_
+`
+a
+b
+c
+d
+e
+f
+g
+h
+i
+j
+k
+l
+m
+n
+o
+p
+q
+r
+s
+t
+u
+v
+w
+x
+y
+z
+{
+|
+}
+~
+ 
\ No newline at end of file
diff --git a/tools/eval.py b/tools/eval.py
index d26f2a04..66eb315f 100755
--- a/tools/eval.py
+++ b/tools/eval.py
@@ -22,6 +22,7 @@ import sys
 __dir__ = os.path.dirname(os.path.abspath(__file__))
 sys.path.append(__dir__)
 sys.path.append(os.path.abspath(os.path.join(__dir__, '..')))
+
 from ppocr.data import build_dataloader
 from ppocr.modeling.architectures import build_model
 from ppocr.postprocess import build_post_process
@@ -30,6 +31,7 @@ from ppocr.utils.save_load import init_model
 from ppocr.utils.utility import print_dict
 import tools.program as program
 
+
 def main():
     global_config = config['Global']
     # build dataloader
 
diff --git a/tools/program.py b/tools/program.py
index 7641bed7..4b6dc9e4 100755
--- a/tools/program.py
+++ b/tools/program.py
@@ -186,7 +186,7 @@ def train(config,
     model.train()
 
     use_srn = config['Architecture']['algorithm'] == "SRN"
-
+    use_nrtr = config['Architecture']['algorithm'] == "NRTR"
     if 'start_epoch' in best_model_dict:
         start_epoch = best_model_dict['start_epoch']
     else:
@@ -211,6 +211,9 @@ def train(config,
                 others = batch[-4:]
                 preds = model(images, others)
                 model_average = True
+            elif use_nrtr:
+                max_len = batch[2].max()
+                preds = model(images, batch[1][:, :2 + max_len])  # teacher forcing with the target tokens
             else:
                 preds = model(images)
             loss = loss_class(preds, batch)
@@ -350,13 +353,11 @@ def eval(model, valid_dataloader, post_process_class, eval_class,
                 break
             images = batch[0]
             start = time.time()
-
             if use_srn:
                 others = batch[-4:]
                 preds = model(images, others)
             else:
                 preds = model(images)
-
             batch = [item.numpy() for item in batch]
             # Obtain usable results from post-processing methods
             post_result = post_process_class(preds, batch[1])
@@ -386,7 +387,7 @@ def preprocess(is_train=False):
     alg = config['Architecture']['algorithm']
     assert alg in [
         'EAST', 'DB', 'SAST', 'Rosetta', 'CRNN', 'STARNet', 'RARE', 'SRN',
-        'CLS', 'PGNet', 'Distillation'
+        'CLS', 'PGNet', 'Distillation', 'NRTR'
     ]
 
     device = 'gpu:{}'.format(dist.ParallelEnv().dev_id) if use_gpu else 'cpu'
-- 
GitLab