diff --git a/configs/rec/rec_mtb_nrtr.yml b/configs/rec/rec_mtb_nrtr.yml new file mode 100644 index 0000000000000000000000000000000000000000..635c392d705acd1fcfbf9f744a8d7167c448d74c --- /dev/null +++ b/configs/rec/rec_mtb_nrtr.yml @@ -0,0 +1,102 @@ +Global: + use_gpu: True + epoch_num: 21 + log_smooth_window: 20 + print_batch_step: 10 + save_model_dir: ./output/rec/nrtr/ + save_epoch_step: 1 + # evaluation is run every 2000 iterations + eval_batch_step: [0, 2000] + cal_metric_during_train: True + pretrained_model: + checkpoints: + save_inference_dir: + use_visualdl: False + infer_img: doc/imgs_words_en/word_10.png + # for data or label process + character_dict_path: + character_type: EN_symbol + max_text_length: 25 + infer_mode: False + use_space_char: True + save_res_path: ./output/rec/predicts_nrtr.txt + +Optimizer: + name: Adam + beta1: 0.9 + beta2: 0.99 + clip_norm: 5.0 + lr: + name: Cosine + learning_rate: 0.0005 + warmup_epoch: 2 + regularizer: + name: 'L2' + factor: 0. + +Architecture: + model_type: rec + algorithm: NRTR + in_channels: 1 + Transform: + Backbone: + name: MTB + cnn_num: 2 + Head: + name: Transformer + d_model: 512 + num_encoder_layers: 6 + beam_size: 10 # When Beam size is greater than 0, it means to use beam search when evaluation. + + +Loss: + name: NRTRLoss + smoothing: True + +PostProcess: + name: NRTRLabelDecode + +Metric: + name: RecMetric + main_indicator: acc + +Train: + dataset: + name: LMDBDataSet + data_dir: ./train_data/data_lmdb_release/training/ + transforms: + - NRTRDecodeImage: # load image + img_mode: BGR + channel_first: False + - NRTRLabelEncode: # Class handling label + - NRTRRecResizeImg: + image_shape: [100, 32] + resize_type: PIL # PIL or OpenCV + - KeepKeys: + keep_keys: ['image', 'label', 'length'] # dataloader will return list in this order + loader: + shuffle: True + batch_size_per_card: 512 + drop_last: True + num_workers: 8 + +Eval: + dataset: + name: LMDBDataSet + data_dir: ./train_data/data_lmdb_release/evaluation/ + transforms: + - NRTRDecodeImage: # load image + img_mode: BGR + channel_first: False + - NRTRLabelEncode: # Class handling label + - NRTRRecResizeImg: + image_shape: [100, 32] + resize_type: PIL # PIL or OpenCV + - KeepKeys: + keep_keys: ['image', 'label', 'length'] # dataloader will return list in this order + loader: + shuffle: False + drop_last: False + batch_size_per_card: 256 + num_workers: 1 + use_shared_memory: False diff --git a/doc/doc_ch/algorithm_overview.md b/doc/doc_ch/algorithm_overview.md index 19d7a69c7fb08a8e7fb36c3043aa211de19b9295..e8f23b54e50be7ac2f5388a726ed6540d2c159c3 100755 --- a/doc/doc_ch/algorithm_overview.md +++ b/doc/doc_ch/algorithm_overview.md @@ -44,6 +44,7 @@ PaddleOCR基于动态图开源的文本识别算法列表: - [x] STAR-Net([paper](http://www.bmva.org/bmvc/2016/papers/paper043/index.html))[11] - [x] RARE([paper](https://arxiv.org/abs/1603.03915v1))[12] - [x] SRN([paper](https://arxiv.org/abs/2003.12294))[5] +- [x] NRTR([paper](https://arxiv.org/abs/1806.00926v2)) 参考[DTRB][3](https://arxiv.org/abs/1904.01906)文字识别训练和评估流程,使用MJSynth和SynthText两个文字识别数据集训练,在IIIT, SVT, IC03, IC13, IC15, SVTP, CUTE数据集上进行评估,算法效果如下: @@ -58,6 +59,7 @@ PaddleOCR基于动态图开源的文本识别算法列表: |RARE|MobileNetV3|82.5%|rec_mv3_tps_bilstm_att |[下载链接](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_mv3_tps_bilstm_att_v2.0_train.tar)| |RARE|Resnet34_vd|83.6%|rec_r34_vd_tps_bilstm_att |[下载链接](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_r34_vd_tps_bilstm_att_v2.0_train.tar)| |SRN|Resnet50_vd_fpn| 88.52% | rec_r50fpn_vd_none_srn | [下载链接](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_r50_vd_srn_train.tar) | +|NRTR|NRTR_MTB| 84.3% | rec_mtb_nrtr | [下载链接](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_mtb_nrtr_train.tar) | PaddleOCR文本识别算法的训练和使用请参考文档教程中[模型训练/评估中的文本识别部分](./recognition.md)。 diff --git a/doc/doc_ch/recognition.md b/doc/doc_ch/recognition.md index 3c79c03ec4e82147b0f235a4b204baa2d1ec92ad..9d5bf2861ba2e700975cdb5584efcf6a8ab26801 100644 --- a/doc/doc_ch/recognition.md +++ b/doc/doc_ch/recognition.md @@ -215,6 +215,7 @@ PaddleOCR支持训练和评估交替进行, 可以在 `configs/rec/rec_icdar15_t | rec_mv3_tps_bilstm_att.yml | CRNN | Mobilenet_v3 | TPS | BiLSTM | att | | rec_r34_vd_tps_bilstm_att.yml | CRNN | Resnet34_vd | TPS | BiLSTM | att | | rec_r50fpn_vd_none_srn.yml | SRN | Resnet50_fpn_vd | None | rnn | srn | +| rec_mtb_nrtr.yml | NRTR | nrtr_mtb | None | transformer encoder | transformer decoder | 训练中文数据,推荐使用[rec_chinese_lite_train_v2.0.yml](../../configs/rec/ch_ppocr_v2.0/rec_chinese_lite_train_v2.0.yml),如您希望尝试其他算法在中文数据集上的效果,请参考下列说明修改配置文件: diff --git a/doc/doc_en/algorithm_overview_en.md b/doc/doc_en/algorithm_overview_en.md index d70f99bb5c5b0bdcb7d39209dfc9a77c56918260..8e8f0d3f8c36885515ea65f17e910c31838bc603 100755 --- a/doc/doc_en/algorithm_overview_en.md +++ b/doc/doc_en/algorithm_overview_en.md @@ -46,6 +46,7 @@ PaddleOCR open-source text recognition algorithms list: - [x] STAR-Net([paper](http://www.bmva.org/bmvc/2016/papers/paper043/index.html))[11] - [x] RARE([paper](https://arxiv.org/abs/1603.03915v1))[12] - [x] SRN([paper](https://arxiv.org/abs/2003.12294))[5] +- [x] NRTR([paper](https://arxiv.org/abs/1806.00926v2)) Refer to [DTRB](https://arxiv.org/abs/1904.01906), the training and evaluation result of these above text recognition (using MJSynth and SynthText for training, evaluate on IIIT, SVT, IC03, IC13, IC15, SVTP, CUTE) is as follow: @@ -60,5 +61,6 @@ Refer to [DTRB](https://arxiv.org/abs/1904.01906), the training and evaluation r |RARE|MobileNetV3|82.5%|rec_mv3_tps_bilstm_att |[Download link](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_mv3_tps_bilstm_att_v2.0_train.tar)| |RARE|Resnet34_vd|83.6%|rec_r34_vd_tps_bilstm_att |[Download link](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_r34_vd_tps_bilstm_att_v2.0_train.tar)| |SRN|Resnet50_vd_fpn| 88.52% | rec_r50fpn_vd_none_srn |[Download link](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_r50_vd_srn_train.tar)| +|NRTR|NRTR_MTB| 84.3% | rec_mtb_nrtr | [Download link](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_mtb_nrtr_train.tar) | Please refer to the document for training guide and use of PaddleOCR text recognition algorithms [Text recognition model training/evaluation/prediction](./recognition_en.md) diff --git a/doc/doc_en/recognition_en.md b/doc/doc_en/recognition_en.md index 4a36db1e0a1940342d329fbbce2e79bf862dcefb..ff5b802d8dbf54e842c035bce9ed480428eade83 100644 --- a/doc/doc_en/recognition_en.md +++ b/doc/doc_en/recognition_en.md @@ -207,7 +207,7 @@ If the evaluation set is large, the test will be time-consuming. It is recommend | rec_mv3_tps_bilstm_att.yml | CRNN | Mobilenet_v3 | TPS | BiLSTM | att | | rec_r34_vd_tps_bilstm_att.yml | CRNN | Resnet34_vd | TPS | BiLSTM | att | | rec_r50fpn_vd_none_srn.yml | SRN | Resnet50_fpn_vd | None | rnn | srn | - +| rec_mtb_nrtr.yml | NRTR | nrtr_mtb | None | transformer encoder | transformer decoder | For training Chinese data, it is recommended to use [rec_chinese_lite_train_v2.0.yml](../../configs/rec/ch_ppocr_v2.0/rec_chinese_lite_train_v2.0.yml). If you want to try the result of other algorithms on the Chinese data set, please refer to the following instructions to modify the configuration file: diff --git a/ppocr/data/imaug/__init__.py b/ppocr/data/imaug/__init__.py index 52194eb964f7a7fd159cc1a42b73d280f8ee5fb4..4418d075cb297880017b4839438aa1cbb9b5ebcd 100644 --- a/ppocr/data/imaug/__init__.py +++ b/ppocr/data/imaug/__init__.py @@ -21,7 +21,7 @@ from .make_border_map import MakeBorderMap from .make_shrink_map import MakeShrinkMap from .random_crop_data import EastRandomCropData, PSERandomCrop -from .rec_img_aug import RecAug, RecResizeImg, ClsResizeImg, SRNRecResizeImg +from .rec_img_aug import RecAug, RecResizeImg, ClsResizeImg, SRNRecResizeImg, NRTRRecResizeImg from .randaugment import RandAugment from .copy_paste import CopyPaste from .operators import * diff --git a/ppocr/data/imaug/label_ops.py b/ppocr/data/imaug/label_ops.py index d222c4109c3723bc1adb71ee7c21a27a010f8f45..f6263950959b0ee6a96647fb248098bb5c567651 100644 --- a/ppocr/data/imaug/label_ops.py +++ b/ppocr/data/imaug/label_ops.py @@ -161,6 +161,34 @@ class BaseRecLabelEncode(object): return text_list +class NRTRLabelEncode(BaseRecLabelEncode): + """ Convert between text-label and text-index """ + + def __init__(self, + max_text_length, + character_dict_path=None, + character_type='EN_symbol', + use_space_char=False, + **kwargs): + + super(NRTRLabelEncode, + self).__init__(max_text_length, character_dict_path, + character_type, use_space_char) + def __call__(self, data): + text = data['label'] + text = self.encode(text) + if text is None: + return None + data['length'] = np.array(len(text)) + text.insert(0, 2) + text.append(3) + text = text + [0] * (self.max_text_len - len(text)) + data['label'] = np.array(text) + return data + def add_special_char(self, dict_character): + dict_character = ['blank','','',''] + dict_character + return dict_character + class CTCLabelEncode(BaseRecLabelEncode): """ Convert between text-label and text-index """ diff --git a/ppocr/data/imaug/operators.py b/ppocr/data/imaug/operators.py index 2535b4420c503f2e9e9cc5a677ef70c4dd9c36be..aa3acd1d1264fb2af75c773efd1e5c575b465fe5 100644 --- a/ppocr/data/imaug/operators.py +++ b/ppocr/data/imaug/operators.py @@ -57,6 +57,38 @@ class DecodeImage(object): return data +class NRTRDecodeImage(object): + """ decode image """ + + def __init__(self, img_mode='RGB', channel_first=False, **kwargs): + self.img_mode = img_mode + self.channel_first = channel_first + + def __call__(self, data): + img = data['image'] + if six.PY2: + assert type(img) is str and len( + img) > 0, "invalid input 'img' in DecodeImage" + else: + assert type(img) is bytes and len( + img) > 0, "invalid input 'img' in DecodeImage" + img = np.frombuffer(img, dtype='uint8') + + img = cv2.imdecode(img, 1) + + if img is None: + return None + if self.img_mode == 'GRAY': + img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) + elif self.img_mode == 'RGB': + assert img.shape[2] == 3, 'invalid shape of image[%s]' % (img.shape) + img = img[:, :, ::-1] + img = cv2.cvtColor(img,cv2.COLOR_BGR2GRAY) + if self.channel_first: + img = img.transpose((2, 0, 1)) + data['image'] = img + return data + class NormalizeImage(object): """ normalize image such as substract mean, divide std """ diff --git a/ppocr/data/imaug/rec_img_aug.py b/ppocr/data/imaug/rec_img_aug.py index 28e6bd0bce768c45dbc334c15ace601fd6403f5d..e914d3844606b5b88333a89e5d0e5fda65729458 100644 --- a/ppocr/data/imaug/rec_img_aug.py +++ b/ppocr/data/imaug/rec_img_aug.py @@ -16,7 +16,7 @@ import math import cv2 import numpy as np import random - +from PIL import Image from .text_image_aug import tia_perspective, tia_stretch, tia_distort @@ -43,6 +43,25 @@ class ClsResizeImg(object): return data +class NRTRRecResizeImg(object): + def __init__(self, image_shape, resize_type, **kwargs): + self.image_shape = image_shape + self.resize_type = resize_type + + def __call__(self, data): + img = data['image'] + if self.resize_type == 'PIL': + image_pil = Image.fromarray(np.uint8(img)) + img = image_pil.resize(self.image_shape, Image.ANTIALIAS) + img = np.array(img) + if self.resize_type == 'OpenCV': + img = cv2.resize(img, self.image_shape) + norm_img = np.expand_dims(img, -1) + norm_img = norm_img.transpose((2, 0, 1)) + data['image'] = norm_img.astype(np.float32) / 128. - 1. + return data + + class RecResizeImg(object): def __init__(self, image_shape, diff --git a/ppocr/losses/__init__.py b/ppocr/losses/__init__.py index 025ae7ca5cc604eea59423ca7f523c37c1492e35..eed5a46efcb94cce5e38101640f7ba50d5c60801 100755 --- a/ppocr/losses/__init__.py +++ b/ppocr/losses/__init__.py @@ -25,7 +25,7 @@ from .det_sast_loss import SASTLoss from .rec_ctc_loss import CTCLoss from .rec_att_loss import AttentionLoss from .rec_srn_loss import SRNLoss - +from .rec_nrtr_loss import NRTRLoss # cls loss from .cls_loss import ClsLoss @@ -44,8 +44,9 @@ from .table_att_loss import TableAttentionLoss def build_loss(config): support_dict = [ 'DBLoss', 'EASTLoss', 'SASTLoss', 'CTCLoss', 'ClsLoss', 'AttentionLoss', - 'SRNLoss', 'PGLoss', 'CombinedLoss', 'TableAttentionLoss' + 'SRNLoss', 'PGLoss', 'CombinedLoss', 'NRTRLoss', 'TableAttentionLoss' ] + config = copy.deepcopy(config) module_name = config.pop('name') assert module_name in support_dict, Exception('loss only support {}'.format( diff --git a/ppocr/losses/rec_nrtr_loss.py b/ppocr/losses/rec_nrtr_loss.py new file mode 100644 index 0000000000000000000000000000000000000000..41714dd2a3ae15eeedc62521d97935f68271c598 --- /dev/null +++ b/ppocr/losses/rec_nrtr_loss.py @@ -0,0 +1,30 @@ +import paddle +from paddle import nn +import paddle.nn.functional as F + + +class NRTRLoss(nn.Layer): + def __init__(self, smoothing=True, **kwargs): + super(NRTRLoss, self).__init__() + self.loss_func = nn.CrossEntropyLoss(reduction='mean', ignore_index=0) + self.smoothing = smoothing + + def forward(self, pred, batch): + pred = pred.reshape([-1, pred.shape[2]]) + max_len = batch[2].max() + tgt = batch[1][:, 1:2 + max_len] + tgt = tgt.reshape([-1]) + if self.smoothing: + eps = 0.1 + n_class = pred.shape[1] + one_hot = F.one_hot(tgt, pred.shape[1]) + one_hot = one_hot * (1 - eps) + (1 - one_hot) * eps / (n_class - 1) + log_prb = F.log_softmax(pred, axis=1) + non_pad_mask = paddle.not_equal( + tgt, paddle.zeros( + tgt.shape, dtype='int64')) + loss = -(one_hot * log_prb).sum(axis=1) + loss = loss.masked_select(non_pad_mask).mean() + else: + loss = self.loss_func(pred, tgt) + return {'loss': loss} diff --git a/ppocr/metrics/rec_metric.py b/ppocr/metrics/rec_metric.py index 66c084d771dece0e2974bc72a177b53f564a8f2e..3e82fe756b1e42d6a77192af41e27e0fe30a86ab 100644 --- a/ppocr/metrics/rec_metric.py +++ b/ppocr/metrics/rec_metric.py @@ -57,3 +57,4 @@ class RecMetric(object): self.correct_num = 0 self.all_num = 0 self.norm_edit_dis = 0 + diff --git a/ppocr/modeling/architectures/base_model.py b/ppocr/modeling/architectures/base_model.py index dbd18070b36f7e99c62de94048ab53d1bedcebe0..c498d9862abcfc85eaf29ed1d949230a1dc1629c 100644 --- a/ppocr/modeling/architectures/base_model.py +++ b/ppocr/modeling/architectures/base_model.py @@ -14,7 +14,6 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function - from paddle import nn from ppocr.modeling.transforms import build_transform from ppocr.modeling.backbones import build_backbone diff --git a/ppocr/modeling/backbones/__init__.py b/ppocr/modeling/backbones/__init__.py index f4fe8c76be0835f55f402f35ad6a91a5ca116d88..f8ca7e408ac5fb9118114d1dfaae4a9eb481160f 100755 --- a/ppocr/modeling/backbones/__init__.py +++ b/ppocr/modeling/backbones/__init__.py @@ -26,8 +26,9 @@ def build_backbone(config, model_type): from .rec_resnet_vd import ResNet from .rec_resnet_fpn import ResNetFPN from .rec_mv1_enhance import MobileNetV1Enhance + from .rec_nrtr_mtb import MTB support_dict = [ - "MobileNetV1Enhance", "MobileNetV3", "ResNet", "ResNetFPN" + 'MobileNetV1Enhance', 'MobileNetV3', 'ResNet', 'ResNetFPN', 'MTB' ] elif model_type == "e2e": from .e2e_resnet_vd_pg import ResNet diff --git a/ppocr/modeling/backbones/rec_nrtr_mtb.py b/ppocr/modeling/backbones/rec_nrtr_mtb.py new file mode 100644 index 0000000000000000000000000000000000000000..04b5c9bb5fdff448fbf7ad366bc39bf0e3ebfe6b --- /dev/null +++ b/ppocr/modeling/backbones/rec_nrtr_mtb.py @@ -0,0 +1,46 @@ +# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from paddle import nn + + +class MTB(nn.Layer): + def __init__(self, cnn_num, in_channels): + super(MTB, self).__init__() + self.block = nn.Sequential() + self.out_channels = in_channels + self.cnn_num = cnn_num + if self.cnn_num == 2: + for i in range(self.cnn_num): + self.block.add_sublayer( + 'conv_{}'.format(i), + nn.Conv2D( + in_channels=in_channels + if i == 0 else 32 * (2**(i - 1)), + out_channels=32 * (2**i), + kernel_size=3, + stride=2, + padding=1)) + self.block.add_sublayer('relu_{}'.format(i), nn.ReLU()) + self.block.add_sublayer('bn_{}'.format(i), + nn.BatchNorm2D(32 * (2**i))) + + def forward(self, images): + x = self.block(images) + if self.cnn_num == 2: + # (b, w, h, c) + x = x.transpose([0, 3, 2, 1]) + x_shape = x.shape + x = x.reshape([x_shape[0], x_shape[1], x_shape[2] * x_shape[3]]) + return x diff --git a/ppocr/modeling/heads/__init__.py b/ppocr/modeling/heads/__init__.py index 5096479415f504aa9f074d55bd9b2e4a31c730b4..572ec4aa8a0c9cdee236ef66b0277a2f614cfd47 100755 --- a/ppocr/modeling/heads/__init__.py +++ b/ppocr/modeling/heads/__init__.py @@ -26,12 +26,14 @@ def build_head(config): from .rec_ctc_head import CTCHead from .rec_att_head import AttentionHead from .rec_srn_head import SRNHead + from .rec_nrtr_head import Transformer # cls head from .cls_head import ClsHead support_dict = [ 'DBHead', 'EASTHead', 'SASTHead', 'CTCHead', 'ClsHead', 'AttentionHead', - 'SRNHead', 'PGHead', 'TableAttentionHead'] + 'SRNHead', 'PGHead', 'Transformer', 'TableAttentionHead' + ] #table head from .table_att_head import TableAttentionHead diff --git a/ppocr/modeling/heads/multiheadAttention.py b/ppocr/modeling/heads/multiheadAttention.py new file mode 100755 index 0000000000000000000000000000000000000000..651d4f577d2f5d1c11e36f90d1c7fea5fc3ab86e --- /dev/null +++ b/ppocr/modeling/heads/multiheadAttention.py @@ -0,0 +1,178 @@ +# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import paddle +from paddle import nn +import paddle.nn.functional as F +from paddle.nn import Linear +from paddle.nn.initializer import XavierUniform as xavier_uniform_ +from paddle.nn.initializer import Constant as constant_ +from paddle.nn.initializer import XavierNormal as xavier_normal_ + +zeros_ = constant_(value=0.) +ones_ = constant_(value=1.) + + +class MultiheadAttention(nn.Layer): + """Allows the model to jointly attend to information + from different representation subspaces. + See reference: Attention Is All You Need + + .. math:: + \text{MultiHead}(Q, K, V) = \text{Concat}(head_1,\dots,head_h)W^O + \text{where} head_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V) + + Args: + embed_dim: total dimension of the model + num_heads: parallel attention layers, or heads + + """ + + def __init__(self, + embed_dim, + num_heads, + dropout=0., + bias=True, + add_bias_kv=False, + add_zero_attn=False): + super(MultiheadAttention, self).__init__() + self.embed_dim = embed_dim + self.num_heads = num_heads + self.dropout = dropout + self.head_dim = embed_dim // num_heads + assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads" + self.scaling = self.head_dim**-0.5 + self.out_proj = Linear(embed_dim, embed_dim, bias_attr=bias) + self._reset_parameters() + self.conv1 = paddle.nn.Conv2D( + in_channels=embed_dim, out_channels=embed_dim, kernel_size=(1, 1)) + self.conv2 = paddle.nn.Conv2D( + in_channels=embed_dim, out_channels=embed_dim, kernel_size=(1, 1)) + self.conv3 = paddle.nn.Conv2D( + in_channels=embed_dim, out_channels=embed_dim, kernel_size=(1, 1)) + + def _reset_parameters(self): + xavier_uniform_(self.out_proj.weight) + + def forward(self, + query, + key, + value, + key_padding_mask=None, + incremental_state=None, + need_weights=True, + static_kv=False, + attn_mask=None): + """ + Inputs of forward function + query: [target length, batch size, embed dim] + key: [sequence length, batch size, embed dim] + value: [sequence length, batch size, embed dim] + key_padding_mask: if True, mask padding based on batch size + incremental_state: if provided, previous time steps are cashed + need_weights: output attn_output_weights + static_kv: key and value are static + + Outputs of forward function + attn_output: [target length, batch size, embed dim] + attn_output_weights: [batch size, target length, sequence length] + """ + tgt_len, bsz, embed_dim = query.shape + assert embed_dim == self.embed_dim + assert list(query.shape) == [tgt_len, bsz, embed_dim] + assert key.shape == value.shape + + q = self._in_proj_q(query) + k = self._in_proj_k(key) + v = self._in_proj_v(value) + q *= self.scaling + + q = q.reshape([tgt_len, bsz * self.num_heads, self.head_dim]).transpose( + [1, 0, 2]) + k = k.reshape([-1, bsz * self.num_heads, self.head_dim]).transpose( + [1, 0, 2]) + v = v.reshape([-1, bsz * self.num_heads, self.head_dim]).transpose( + [1, 0, 2]) + + src_len = k.shape[1] + + if key_padding_mask is not None: + assert key_padding_mask.shape[0] == bsz + assert key_padding_mask.shape[1] == src_len + + attn_output_weights = paddle.bmm(q, k.transpose([0, 2, 1])) + assert list(attn_output_weights. + shape) == [bsz * self.num_heads, tgt_len, src_len] + + if attn_mask is not None: + attn_mask = attn_mask.unsqueeze(0) + attn_output_weights += attn_mask + if key_padding_mask is not None: + attn_output_weights = attn_output_weights.reshape( + [bsz, self.num_heads, tgt_len, src_len]) + key = key_padding_mask.unsqueeze(1).unsqueeze(2).astype('float32') + y = paddle.full(shape=key.shape, dtype='float32', fill_value='-inf') + y = paddle.where(key == 0., key, y) + attn_output_weights += y + attn_output_weights = attn_output_weights.reshape( + [bsz * self.num_heads, tgt_len, src_len]) + + attn_output_weights = F.softmax( + attn_output_weights.astype('float32'), + axis=-1, + dtype=paddle.float32 if attn_output_weights.dtype == paddle.float16 + else attn_output_weights.dtype) + attn_output_weights = F.dropout( + attn_output_weights, p=self.dropout, training=self.training) + + attn_output = paddle.bmm(attn_output_weights, v) + assert list(attn_output. + shape) == [bsz * self.num_heads, tgt_len, self.head_dim] + attn_output = attn_output.transpose([1, 0, 2]).reshape( + [tgt_len, bsz, embed_dim]) + attn_output = self.out_proj(attn_output) + + if need_weights: + # average attention weights over heads + attn_output_weights = attn_output_weights.reshape( + [bsz, self.num_heads, tgt_len, src_len]) + attn_output_weights = attn_output_weights.sum( + axis=1) / self.num_heads + else: + attn_output_weights = None + return attn_output, attn_output_weights + + def _in_proj_q(self, query): + query = query.transpose([1, 2, 0]) + query = paddle.unsqueeze(query, axis=2) + res = self.conv1(query) + res = paddle.squeeze(res, axis=2) + res = res.transpose([2, 0, 1]) + return res + + def _in_proj_k(self, key): + key = key.transpose([1, 2, 0]) + key = paddle.unsqueeze(key, axis=2) + res = self.conv2(key) + res = paddle.squeeze(res, axis=2) + res = res.transpose([2, 0, 1]) + return res + + def _in_proj_v(self, value): + value = value.transpose([1, 2, 0]) #(1, 2, 0) + value = paddle.unsqueeze(value, axis=2) + res = self.conv3(value) + res = paddle.squeeze(res, axis=2) + res = res.transpose([2, 0, 1]) + return res diff --git a/ppocr/modeling/heads/rec_nrtr_head.py b/ppocr/modeling/heads/rec_nrtr_head.py new file mode 100644 index 0000000000000000000000000000000000000000..05dba677b4109897b6a20888151e680e652d6741 --- /dev/null +++ b/ppocr/modeling/heads/rec_nrtr_head.py @@ -0,0 +1,844 @@ +# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +import paddle +import copy +from paddle import nn +import paddle.nn.functional as F +from paddle.nn import LayerList +from paddle.nn.initializer import XavierNormal as xavier_uniform_ +from paddle.nn import Dropout, Linear, LayerNorm, Conv2D +import numpy as np +from ppocr.modeling.heads.multiheadAttention import MultiheadAttention +from paddle.nn.initializer import Constant as constant_ +from paddle.nn.initializer import XavierNormal as xavier_normal_ + +zeros_ = constant_(value=0.) +ones_ = constant_(value=1.) + + +class Transformer(nn.Layer): + """A transformer model. User is able to modify the attributes as needed. The architechture + is based on the paper "Attention Is All You Need". Ashish Vaswani, Noam Shazeer, + Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N Gomez, Lukasz Kaiser, and + Illia Polosukhin. 2017. Attention is all you need. In Advances in Neural Information + Processing Systems, pages 6000-6010. + + Args: + d_model: the number of expected features in the encoder/decoder inputs (default=512). + nhead: the number of heads in the multiheadattention models (default=8). + num_encoder_layers: the number of sub-encoder-layers in the encoder (default=6). + num_decoder_layers: the number of sub-decoder-layers in the decoder (default=6). + dim_feedforward: the dimension of the feedforward network model (default=2048). + dropout: the dropout value (default=0.1). + custom_encoder: custom encoder (default=None). + custom_decoder: custom decoder (default=None). + + """ + + def __init__(self, + d_model=512, + nhead=8, + num_encoder_layers=6, + beam_size=0, + num_decoder_layers=6, + dim_feedforward=1024, + attention_dropout_rate=0.0, + residual_dropout_rate=0.1, + custom_encoder=None, + custom_decoder=None, + in_channels=0, + out_channels=0, + dst_vocab_size=99, + scale_embedding=True): + super(Transformer, self).__init__() + self.embedding = Embeddings( + d_model=d_model, + vocab=dst_vocab_size, + padding_idx=0, + scale_embedding=scale_embedding) + self.positional_encoding = PositionalEncoding( + dropout=residual_dropout_rate, + dim=d_model, ) + if custom_encoder is not None: + self.encoder = custom_encoder + else: + if num_encoder_layers > 0: + encoder_layer = TransformerEncoderLayer( + d_model, nhead, dim_feedforward, attention_dropout_rate, + residual_dropout_rate) + self.encoder = TransformerEncoder(encoder_layer, + num_encoder_layers) + else: + self.encoder = None + + if custom_decoder is not None: + self.decoder = custom_decoder + else: + decoder_layer = TransformerDecoderLayer( + d_model, nhead, dim_feedforward, attention_dropout_rate, + residual_dropout_rate) + self.decoder = TransformerDecoder(decoder_layer, num_decoder_layers) + + self._reset_parameters() + self.beam_size = beam_size + self.d_model = d_model + self.nhead = nhead + self.tgt_word_prj = nn.Linear(d_model, dst_vocab_size, bias_attr=False) + w0 = np.random.normal(0.0, d_model**-0.5, + (d_model, dst_vocab_size)).astype(np.float32) + self.tgt_word_prj.weight.set_value(w0) + self.apply(self._init_weights) + + def _init_weights(self, m): + + if isinstance(m, nn.Conv2D): + xavier_normal_(m.weight) + if m.bias is not None: + zeros_(m.bias) + + def forward_train(self, src, tgt): + tgt = tgt[:, :-1] + + tgt_key_padding_mask = self.generate_padding_mask(tgt) + tgt = self.embedding(tgt).transpose([1, 0, 2]) + tgt = self.positional_encoding(tgt) + tgt_mask = self.generate_square_subsequent_mask(tgt.shape[0]) + + if self.encoder is not None: + src = self.positional_encoding(src.transpose([1, 0, 2])) + memory = self.encoder(src) + else: + memory = src.squeeze(2).transpose([2, 0, 1]) + output = self.decoder( + tgt, + memory, + tgt_mask=tgt_mask, + memory_mask=None, + tgt_key_padding_mask=tgt_key_padding_mask, + memory_key_padding_mask=None) + output = output.transpose([1, 0, 2]) + logit = self.tgt_word_prj(output) + return logit + + def forward(self, src, targets=None): + """Take in and process masked source/target sequences. + Args: + src: the sequence to the encoder (required). + tgt: the sequence to the decoder (required). + Shape: + - src: :math:`(S, N, E)`. + - tgt: :math:`(T, N, E)`. + Examples: + >>> output = transformer_model(src, tgt) + """ + + if self.training: + max_len = targets[1].max() + tgt = targets[0][:, :2 + max_len] + return self.forward_train(src, tgt) + else: + if self.beam_size > 0: + return self.forward_beam(src) + else: + return self.forward_test(src) + + def forward_test(self, src): + bs = src.shape[0] + if self.encoder is not None: + src = self.positional_encoding(src.transpose([1, 0, 2])) + memory = self.encoder(src) + else: + memory = src.squeeze(2).transpose([2, 0, 1]) + dec_seq = paddle.full((bs, 1), 2, dtype=paddle.int64) + for len_dec_seq in range(1, 25): + src_enc = memory.clone() + tgt_key_padding_mask = self.generate_padding_mask(dec_seq) + dec_seq_embed = self.embedding(dec_seq).transpose([1, 0, 2]) + dec_seq_embed = self.positional_encoding(dec_seq_embed) + tgt_mask = self.generate_square_subsequent_mask(dec_seq_embed.shape[ + 0]) + output = self.decoder( + dec_seq_embed, + src_enc, + tgt_mask=tgt_mask, + memory_mask=None, + tgt_key_padding_mask=tgt_key_padding_mask, + memory_key_padding_mask=None) + dec_output = output.transpose([1, 0, 2]) + + dec_output = dec_output[:, + -1, :] # Pick the last step: (bh * bm) * d_h + word_prob = F.log_softmax(self.tgt_word_prj(dec_output), axis=1) + word_prob = word_prob.reshape([1, bs, -1]) + preds_idx = word_prob.argmax(axis=2) + + if paddle.equal_all( + preds_idx[-1], + paddle.full( + preds_idx[-1].shape, 3, dtype='int64')): + break + + preds_prob = word_prob.max(axis=2) + dec_seq = paddle.concat( + [dec_seq, preds_idx.reshape([-1, 1])], axis=1) + + return dec_seq + + def forward_beam(self, images): + ''' Translation work in one batch ''' + + def get_inst_idx_to_tensor_position_map(inst_idx_list): + ''' Indicate the position of an instance in a tensor. ''' + return { + inst_idx: tensor_position + for tensor_position, inst_idx in enumerate(inst_idx_list) + } + + def collect_active_part(beamed_tensor, curr_active_inst_idx, + n_prev_active_inst, n_bm): + ''' Collect tensor parts associated to active instances. ''' + + _, *d_hs = beamed_tensor.shape + n_curr_active_inst = len(curr_active_inst_idx) + new_shape = (n_curr_active_inst * n_bm, *d_hs) + + beamed_tensor = beamed_tensor.reshape([n_prev_active_inst, -1]) + beamed_tensor = beamed_tensor.index_select( + paddle.to_tensor(curr_active_inst_idx), axis=0) + beamed_tensor = beamed_tensor.reshape([*new_shape]) + + return beamed_tensor + + def collate_active_info(src_enc, inst_idx_to_position_map, + active_inst_idx_list): + # Sentences which are still active are collected, + # so the decoder will not run on completed sentences. + + n_prev_active_inst = len(inst_idx_to_position_map) + active_inst_idx = [ + inst_idx_to_position_map[k] for k in active_inst_idx_list + ] + active_inst_idx = paddle.to_tensor(active_inst_idx, dtype='int64') + active_src_enc = collect_active_part( + src_enc.transpose([1, 0, 2]), active_inst_idx, + n_prev_active_inst, n_bm).transpose([1, 0, 2]) + active_inst_idx_to_position_map = get_inst_idx_to_tensor_position_map( + active_inst_idx_list) + return active_src_enc, active_inst_idx_to_position_map + + def beam_decode_step(inst_dec_beams, len_dec_seq, enc_output, + inst_idx_to_position_map, n_bm, + memory_key_padding_mask): + ''' Decode and update beam status, and then return active beam idx ''' + + def prepare_beam_dec_seq(inst_dec_beams, len_dec_seq): + dec_partial_seq = [ + b.get_current_state() for b in inst_dec_beams if not b.done + ] + dec_partial_seq = paddle.stack(dec_partial_seq) + + dec_partial_seq = dec_partial_seq.reshape([-1, len_dec_seq]) + return dec_partial_seq + + def prepare_beam_memory_key_padding_mask( + inst_dec_beams, memory_key_padding_mask, n_bm): + keep = [] + for idx in (memory_key_padding_mask): + if not inst_dec_beams[idx].done: + keep.append(idx) + memory_key_padding_mask = memory_key_padding_mask[ + paddle.to_tensor(keep)] + len_s = memory_key_padding_mask.shape[-1] + n_inst = memory_key_padding_mask.shape[0] + memory_key_padding_mask = paddle.concat( + [memory_key_padding_mask for i in range(n_bm)], axis=1) + memory_key_padding_mask = memory_key_padding_mask.reshape( + [n_inst * n_bm, len_s]) #repeat(1, n_bm) + return memory_key_padding_mask + + def predict_word(dec_seq, enc_output, n_active_inst, n_bm, + memory_key_padding_mask): + tgt_key_padding_mask = self.generate_padding_mask(dec_seq) + dec_seq = self.embedding(dec_seq).transpose([1, 0, 2]) + dec_seq = self.positional_encoding(dec_seq) + tgt_mask = self.generate_square_subsequent_mask(dec_seq.shape[ + 0]) + dec_output = self.decoder( + dec_seq, + enc_output, + tgt_mask=tgt_mask, + tgt_key_padding_mask=tgt_key_padding_mask, + memory_key_padding_mask=memory_key_padding_mask, + ).transpose([1, 0, 2]) + dec_output = dec_output[:, + -1, :] # Pick the last step: (bh * bm) * d_h + word_prob = F.log_softmax(self.tgt_word_prj(dec_output), axis=1) + word_prob = word_prob.reshape([n_active_inst, n_bm, -1]) + return word_prob + + def collect_active_inst_idx_list(inst_beams, word_prob, + inst_idx_to_position_map): + active_inst_idx_list = [] + for inst_idx, inst_position in inst_idx_to_position_map.items(): + is_inst_complete = inst_beams[inst_idx].advance(word_prob[ + inst_position]) + if not is_inst_complete: + active_inst_idx_list += [inst_idx] + + return active_inst_idx_list + + n_active_inst = len(inst_idx_to_position_map) + dec_seq = prepare_beam_dec_seq(inst_dec_beams, len_dec_seq) + memory_key_padding_mask = None + word_prob = predict_word(dec_seq, enc_output, n_active_inst, n_bm, + memory_key_padding_mask) + # Update the beam with predicted word prob information and collect incomplete instances + active_inst_idx_list = collect_active_inst_idx_list( + inst_dec_beams, word_prob, inst_idx_to_position_map) + return active_inst_idx_list + + def collect_hypothesis_and_scores(inst_dec_beams, n_best): + all_hyp, all_scores = [], [] + for inst_idx in range(len(inst_dec_beams)): + scores, tail_idxs = inst_dec_beams[inst_idx].sort_scores() + all_scores += [scores[:n_best]] + hyps = [ + inst_dec_beams[inst_idx].get_hypothesis(i) + for i in tail_idxs[:n_best] + ] + all_hyp += [hyps] + return all_hyp, all_scores + + with paddle.no_grad(): + #-- Encode + + if self.encoder is not None: + src = self.positional_encoding(images.transpose([1, 0, 2])) + src_enc = self.encoder(src).transpose([1, 0, 2]) + else: + src_enc = images.squeeze(2).transpose([0, 2, 1]) + + #-- Repeat data for beam search + n_bm = self.beam_size + n_inst, len_s, d_h = src_enc.shape + src_enc = paddle.concat([src_enc for i in range(n_bm)], axis=1) + src_enc = src_enc.reshape([n_inst * n_bm, len_s, d_h]).transpose( + [1, 0, 2]) + #-- Prepare beams + inst_dec_beams = [Beam(n_bm) for _ in range(n_inst)] + + #-- Bookkeeping for active or not + active_inst_idx_list = list(range(n_inst)) + inst_idx_to_position_map = get_inst_idx_to_tensor_position_map( + active_inst_idx_list) + #-- Decode + for len_dec_seq in range(1, 25): + src_enc_copy = src_enc.clone() + active_inst_idx_list = beam_decode_step( + inst_dec_beams, len_dec_seq, src_enc_copy, + inst_idx_to_position_map, n_bm, None) + if not active_inst_idx_list: + break # all instances have finished their path to + src_enc, inst_idx_to_position_map = collate_active_info( + src_enc_copy, inst_idx_to_position_map, + active_inst_idx_list) + batch_hyp, batch_scores = collect_hypothesis_and_scores(inst_dec_beams, + 1) + result_hyp = [] + for bs_hyp in batch_hyp: + bs_hyp_pad = bs_hyp[0] + [3] * (25 - len(bs_hyp[0])) + result_hyp.append(bs_hyp_pad) + return paddle.to_tensor(np.array(result_hyp), dtype=paddle.int64) + + def generate_square_subsequent_mask(self, sz): + """Generate a square mask for the sequence. The masked positions are filled with float('-inf'). + Unmasked positions are filled with float(0.0). + """ + mask = paddle.zeros([sz, sz], dtype='float32') + mask_inf = paddle.triu( + paddle.full( + shape=[sz, sz], dtype='float32', fill_value='-inf'), + diagonal=1) + mask = mask + mask_inf + return mask + + def generate_padding_mask(self, x): + padding_mask = x.equal(paddle.to_tensor(0, dtype=x.dtype)) + return padding_mask + + def _reset_parameters(self): + """Initiate parameters in the transformer model.""" + + for p in self.parameters(): + if p.dim() > 1: + xavier_uniform_(p) + + +class TransformerEncoder(nn.Layer): + """TransformerEncoder is a stack of N encoder layers + Args: + encoder_layer: an instance of the TransformerEncoderLayer() class (required). + num_layers: the number of sub-encoder-layers in the encoder (required). + norm: the layer normalization component (optional). + """ + + def __init__(self, encoder_layer, num_layers): + super(TransformerEncoder, self).__init__() + self.layers = _get_clones(encoder_layer, num_layers) + self.num_layers = num_layers + + def forward(self, src): + """Pass the input through the endocder layers in turn. + Args: + src: the sequnce to the encoder (required). + mask: the mask for the src sequence (optional). + src_key_padding_mask: the mask for the src keys per batch (optional). + """ + output = src + + for i in range(self.num_layers): + output = self.layers[i](output, + src_mask=None, + src_key_padding_mask=None) + + return output + + +class TransformerDecoder(nn.Layer): + """TransformerDecoder is a stack of N decoder layers + + Args: + decoder_layer: an instance of the TransformerDecoderLayer() class (required). + num_layers: the number of sub-decoder-layers in the decoder (required). + norm: the layer normalization component (optional). + + """ + + def __init__(self, decoder_layer, num_layers): + super(TransformerDecoder, self).__init__() + self.layers = _get_clones(decoder_layer, num_layers) + self.num_layers = num_layers + + def forward(self, + tgt, + memory, + tgt_mask=None, + memory_mask=None, + tgt_key_padding_mask=None, + memory_key_padding_mask=None): + """Pass the inputs (and mask) through the decoder layer in turn. + + Args: + tgt: the sequence to the decoder (required). + memory: the sequnce from the last layer of the encoder (required). + tgt_mask: the mask for the tgt sequence (optional). + memory_mask: the mask for the memory sequence (optional). + tgt_key_padding_mask: the mask for the tgt keys per batch (optional). + memory_key_padding_mask: the mask for the memory keys per batch (optional). + """ + output = tgt + for i in range(self.num_layers): + output = self.layers[i]( + output, + memory, + tgt_mask=tgt_mask, + memory_mask=memory_mask, + tgt_key_padding_mask=tgt_key_padding_mask, + memory_key_padding_mask=memory_key_padding_mask) + + return output + + +class TransformerEncoderLayer(nn.Layer): + """TransformerEncoderLayer is made up of self-attn and feedforward network. + This standard encoder layer is based on the paper "Attention Is All You Need". + Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N Gomez, + Lukasz Kaiser, and Illia Polosukhin. 2017. Attention is all you need. In Advances in + Neural Information Processing Systems, pages 6000-6010. Users may modify or implement + in a different way during application. + + Args: + d_model: the number of expected features in the input (required). + nhead: the number of heads in the multiheadattention models (required). + dim_feedforward: the dimension of the feedforward network model (default=2048). + dropout: the dropout value (default=0.1). + + """ + + def __init__(self, + d_model, + nhead, + dim_feedforward=2048, + attention_dropout_rate=0.0, + residual_dropout_rate=0.1): + super(TransformerEncoderLayer, self).__init__() + self.self_attn = MultiheadAttention( + d_model, nhead, dropout=attention_dropout_rate) + + self.conv1 = Conv2D( + in_channels=d_model, + out_channels=dim_feedforward, + kernel_size=(1, 1)) + self.conv2 = Conv2D( + in_channels=dim_feedforward, + out_channels=d_model, + kernel_size=(1, 1)) + + self.norm1 = LayerNorm(d_model) + self.norm2 = LayerNorm(d_model) + self.dropout1 = Dropout(residual_dropout_rate) + self.dropout2 = Dropout(residual_dropout_rate) + + def forward(self, src, src_mask=None, src_key_padding_mask=None): + """Pass the input through the endocder layer. + Args: + src: the sequnce to the encoder layer (required). + src_mask: the mask for the src sequence (optional). + src_key_padding_mask: the mask for the src keys per batch (optional). + """ + src2 = self.self_attn( + src, + src, + src, + attn_mask=src_mask, + key_padding_mask=src_key_padding_mask)[0] + src = src + self.dropout1(src2) + src = self.norm1(src) + + src = src.transpose([1, 2, 0]) + src = paddle.unsqueeze(src, 2) + src2 = self.conv2(F.relu(self.conv1(src))) + src2 = paddle.squeeze(src2, 2) + src2 = src2.transpose([2, 0, 1]) + src = paddle.squeeze(src, 2) + src = src.transpose([2, 0, 1]) + + src = src + self.dropout2(src2) + src = self.norm2(src) + return src + + +class TransformerDecoderLayer(nn.Layer): + """TransformerDecoderLayer is made up of self-attn, multi-head-attn and feedforward network. + This standard decoder layer is based on the paper "Attention Is All You Need". + Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N Gomez, + Lukasz Kaiser, and Illia Polosukhin. 2017. Attention is all you need. In Advances in + Neural Information Processing Systems, pages 6000-6010. Users may modify or implement + in a different way during application. + + Args: + d_model: the number of expected features in the input (required). + nhead: the number of heads in the multiheadattention models (required). + dim_feedforward: the dimension of the feedforward network model (default=2048). + dropout: the dropout value (default=0.1). + + """ + + def __init__(self, + d_model, + nhead, + dim_feedforward=2048, + attention_dropout_rate=0.0, + residual_dropout_rate=0.1): + super(TransformerDecoderLayer, self).__init__() + self.self_attn = MultiheadAttention( + d_model, nhead, dropout=attention_dropout_rate) + self.multihead_attn = MultiheadAttention( + d_model, nhead, dropout=attention_dropout_rate) + + self.conv1 = Conv2D( + in_channels=d_model, + out_channels=dim_feedforward, + kernel_size=(1, 1)) + self.conv2 = Conv2D( + in_channels=dim_feedforward, + out_channels=d_model, + kernel_size=(1, 1)) + + self.norm1 = LayerNorm(d_model) + self.norm2 = LayerNorm(d_model) + self.norm3 = LayerNorm(d_model) + self.dropout1 = Dropout(residual_dropout_rate) + self.dropout2 = Dropout(residual_dropout_rate) + self.dropout3 = Dropout(residual_dropout_rate) + + def forward(self, + tgt, + memory, + tgt_mask=None, + memory_mask=None, + tgt_key_padding_mask=None, + memory_key_padding_mask=None): + """Pass the inputs (and mask) through the decoder layer. + + Args: + tgt: the sequence to the decoder layer (required). + memory: the sequnce from the last layer of the encoder (required). + tgt_mask: the mask for the tgt sequence (optional). + memory_mask: the mask for the memory sequence (optional). + tgt_key_padding_mask: the mask for the tgt keys per batch (optional). + memory_key_padding_mask: the mask for the memory keys per batch (optional). + + """ + tgt2 = self.self_attn( + tgt, + tgt, + tgt, + attn_mask=tgt_mask, + key_padding_mask=tgt_key_padding_mask)[0] + tgt = tgt + self.dropout1(tgt2) + tgt = self.norm1(tgt) + tgt2 = self.multihead_attn( + tgt, + memory, + memory, + attn_mask=memory_mask, + key_padding_mask=memory_key_padding_mask)[0] + tgt = tgt + self.dropout2(tgt2) + tgt = self.norm2(tgt) + + # default + tgt = tgt.transpose([1, 2, 0]) + tgt = paddle.unsqueeze(tgt, 2) + tgt2 = self.conv2(F.relu(self.conv1(tgt))) + tgt2 = paddle.squeeze(tgt2, 2) + tgt2 = tgt2.transpose([2, 0, 1]) + tgt = paddle.squeeze(tgt, 2) + tgt = tgt.transpose([2, 0, 1]) + + tgt = tgt + self.dropout3(tgt2) + tgt = self.norm3(tgt) + return tgt + + +def _get_clones(module, N): + return LayerList([copy.deepcopy(module) for i in range(N)]) + + +class PositionalEncoding(nn.Layer): + """Inject some information about the relative or absolute position of the tokens + in the sequence. The positional encodings have the same dimension as + the embeddings, so that the two can be summed. Here, we use sine and cosine + functions of different frequencies. + .. math:: + \text{PosEncoder}(pos, 2i) = sin(pos/10000^(2i/d_model)) + \text{PosEncoder}(pos, 2i+1) = cos(pos/10000^(2i/d_model)) + \text{where pos is the word position and i is the embed idx) + Args: + d_model: the embed dim (required). + dropout: the dropout value (default=0.1). + max_len: the max. length of the incoming sequence (default=5000). + Examples: + >>> pos_encoder = PositionalEncoding(d_model) + """ + + def __init__(self, dropout, dim, max_len=5000): + super(PositionalEncoding, self).__init__() + self.dropout = nn.Dropout(p=dropout) + + pe = paddle.zeros([max_len, dim]) + position = paddle.arange(0, max_len, dtype=paddle.float32).unsqueeze(1) + div_term = paddle.exp( + paddle.arange(0, dim, 2).astype('float32') * + (-math.log(10000.0) / dim)) + pe[:, 0::2] = paddle.sin(position * div_term) + pe[:, 1::2] = paddle.cos(position * div_term) + pe = pe.unsqueeze(0) + pe = pe.transpose([1, 0, 2]) + self.register_buffer('pe', pe) + + def forward(self, x): + """Inputs of forward function + Args: + x: the sequence fed to the positional encoder model (required). + Shape: + x: [sequence length, batch size, embed dim] + output: [sequence length, batch size, embed dim] + Examples: + >>> output = pos_encoder(x) + """ + x = x + self.pe[:x.shape[0], :] + return self.dropout(x) + + +class PositionalEncoding_2d(nn.Layer): + """Inject some information about the relative or absolute position of the tokens + in the sequence. The positional encodings have the same dimension as + the embeddings, so that the two can be summed. Here, we use sine and cosine + functions of different frequencies. + .. math:: + \text{PosEncoder}(pos, 2i) = sin(pos/10000^(2i/d_model)) + \text{PosEncoder}(pos, 2i+1) = cos(pos/10000^(2i/d_model)) + \text{where pos is the word position and i is the embed idx) + Args: + d_model: the embed dim (required). + dropout: the dropout value (default=0.1). + max_len: the max. length of the incoming sequence (default=5000). + Examples: + >>> pos_encoder = PositionalEncoding(d_model) + """ + + def __init__(self, dropout, dim, max_len=5000): + super(PositionalEncoding_2d, self).__init__() + self.dropout = nn.Dropout(p=dropout) + + pe = paddle.zeros([max_len, dim]) + position = paddle.arange(0, max_len, dtype=paddle.float32).unsqueeze(1) + div_term = paddle.exp( + paddle.arange(0, dim, 2).astype('float32') * + (-math.log(10000.0) / dim)) + pe[:, 0::2] = paddle.sin(position * div_term) + pe[:, 1::2] = paddle.cos(position * div_term) + pe = pe.unsqueeze(0).transpose([1, 0, 2]) + self.register_buffer('pe', pe) + + self.avg_pool_1 = nn.AdaptiveAvgPool2D((1, 1)) + self.linear1 = nn.Linear(dim, dim) + self.linear1.weight.data.fill_(1.) + self.avg_pool_2 = nn.AdaptiveAvgPool2D((1, 1)) + self.linear2 = nn.Linear(dim, dim) + self.linear2.weight.data.fill_(1.) + + def forward(self, x): + """Inputs of forward function + Args: + x: the sequence fed to the positional encoder model (required). + Shape: + x: [sequence length, batch size, embed dim] + output: [sequence length, batch size, embed dim] + Examples: + >>> output = pos_encoder(x) + """ + w_pe = self.pe[:x.shape[-1], :] + w1 = self.linear1(self.avg_pool_1(x).squeeze()).unsqueeze(0) + w_pe = w_pe * w1 + w_pe = w_pe.transpose([1, 2, 0]) + w_pe = w_pe.unsqueeze(2) + + h_pe = self.pe[:x.shape[-2], :] + w2 = self.linear2(self.avg_pool_2(x).squeeze()).unsqueeze(0) + h_pe = h_pe * w2 + h_pe = h_pe.transpose([1, 2, 0]) + h_pe = h_pe.unsqueeze(3) + + x = x + w_pe + h_pe + x = x.reshape( + [x.shape[0], x.shape[1], x.shape[2] * x.shape[3]]).transpose( + [2, 0, 1]) + + return self.dropout(x) + + +class Embeddings(nn.Layer): + def __init__(self, d_model, vocab, padding_idx, scale_embedding): + super(Embeddings, self).__init__() + self.embedding = nn.Embedding(vocab, d_model, padding_idx=padding_idx) + w0 = np.random.normal(0.0, d_model**-0.5, + (vocab, d_model)).astype(np.float32) + self.embedding.weight.set_value(w0) + self.d_model = d_model + self.scale_embedding = scale_embedding + + def forward(self, x): + if self.scale_embedding: + x = self.embedding(x) + return x * math.sqrt(self.d_model) + return self.embedding(x) + + +class Beam(): + ''' Beam search ''' + + def __init__(self, size, device=False): + + self.size = size + self._done = False + # The score for each translation on the beam. + self.scores = paddle.zeros((size, ), dtype=paddle.float32) + self.all_scores = [] + # The backpointers at each time-step. + self.prev_ks = [] + # The outputs at each time-step. + self.next_ys = [paddle.full((size, ), 0, dtype=paddle.int64)] + self.next_ys[0][0] = 2 + + def get_current_state(self): + "Get the outputs for the current timestep." + return self.get_tentative_hypothesis() + + def get_current_origin(self): + "Get the backpointers for the current timestep." + return self.prev_ks[-1] + + @property + def done(self): + return self._done + + def advance(self, word_prob): + "Update beam status and check if finished or not." + num_words = word_prob.shape[1] + + # Sum the previous scores. + if len(self.prev_ks) > 0: + beam_lk = word_prob + self.scores.unsqueeze(1).expand_as(word_prob) + else: + beam_lk = word_prob[0] + + flat_beam_lk = beam_lk.reshape([-1]) + best_scores, best_scores_id = flat_beam_lk.topk(self.size, 0, True, + True) # 1st sort + self.all_scores.append(self.scores) + self.scores = best_scores + # bestScoresId is flattened as a (beam x word) array, + # so we need to calculate which word and beam each score came from + prev_k = best_scores_id // num_words + self.prev_ks.append(prev_k) + self.next_ys.append(best_scores_id - prev_k * num_words) + # End condition is when top-of-beam is EOS. + if self.next_ys[-1][0] == 3: + self._done = True + self.all_scores.append(self.scores) + + return self._done + + def sort_scores(self): + "Sort the scores." + return self.scores, paddle.to_tensor( + [i for i in range(self.scores.shape[0])], dtype='int32') + + def get_the_best_score_and_idx(self): + "Get the score of the best in the beam." + scores, ids = self.sort_scores() + return scores[1], ids[1] + + def get_tentative_hypothesis(self): + "Get the decoded sequence for the current timestep." + if len(self.next_ys) == 1: + dec_seq = self.next_ys[0].unsqueeze(1) + else: + _, keys = self.sort_scores() + hyps = [self.get_hypothesis(k) for k in keys] + hyps = [[2] + h for h in hyps] + dec_seq = paddle.to_tensor(hyps, dtype='int64') + return dec_seq + + def get_hypothesis(self, k): + """ Walk back to construct the full hypothesis. """ + hyp = [] + for j in range(len(self.prev_ks) - 1, -1, -1): + hyp.append(self.next_ys[j + 1][k]) + k = self.prev_ks[j][k] + return list(map(lambda x: x.item(), hyp[::-1])) diff --git a/ppocr/postprocess/__init__.py b/ppocr/postprocess/__init__.py index 654ddf39d23590fbaf7f7b9b57f38cc86a1b6669..6159398770e796dc9f5a6dc97e3f845667de95f0 100644 --- a/ppocr/postprocess/__init__.py +++ b/ppocr/postprocess/__init__.py @@ -24,18 +24,16 @@ __all__ = ['build_post_process'] from .db_postprocess import DBPostProcess, DistillationDBPostProcess from .east_postprocess import EASTPostProcess from .sast_postprocess import SASTPostProcess -from .rec_postprocess import CTCLabelDecode, AttnLabelDecode, SRNLabelDecode, DistillationCTCLabelDecode, \ +from .rec_postprocess import CTCLabelDecode, AttnLabelDecode, SRNLabelDecode, DistillationCTCLabelDecode, NRTRLabelDecode, \ TableLabelDecode from .cls_postprocess import ClsPostProcess from .pg_postprocess import PGPostProcess - def build_post_process(config, global_config=None): support_dict = [ 'DBPostProcess', 'EASTPostProcess', 'SASTPostProcess', 'CTCLabelDecode', 'AttnLabelDecode', 'ClsPostProcess', 'SRNLabelDecode', 'PGPostProcess', - 'DistillationCTCLabelDecode', 'TableLabelDecode', - 'DistillationDBPostProcess' + 'DistillationCTCLabelDecode', 'NRTRLabelDecode', 'TableLabelDecode', 'DistillationDBPostProcess' ] config = copy.deepcopy(config) diff --git a/ppocr/postprocess/rec_postprocess.py b/ppocr/postprocess/rec_postprocess.py index 8ebe5b2741b77537b46b8057d9aa9c36dc99aeec..9f23b5495f63a41283656ceaf9df76f96b8d1592 100644 --- a/ppocr/postprocess/rec_postprocess.py +++ b/ppocr/postprocess/rec_postprocess.py @@ -156,6 +156,69 @@ class DistillationCTCLabelDecode(CTCLabelDecode): return output +class NRTRLabelDecode(BaseRecLabelDecode): + """ Convert between text-label and text-index """ + + def __init__(self, + character_dict_path=None, + character_type='EN_symbol', + use_space_char=True, + **kwargs): + super(NRTRLabelDecode, self).__init__(character_dict_path, + character_type, use_space_char) + + def __call__(self, preds, label=None, *args, **kwargs): + if preds.dtype == paddle.int64: + if isinstance(preds, paddle.Tensor): + preds = preds.numpy() + if preds[0][0]==2: + preds_idx = preds[:,1:] + else: + preds_idx = preds + + text = self.decode(preds_idx) + if label is None: + return text + label = self.decode(label[:,1:]) + else: + if isinstance(preds, paddle.Tensor): + preds = preds.numpy() + preds_idx = preds.argmax(axis=2) + preds_prob = preds.max(axis=2) + text = self.decode(preds_idx, preds_prob, is_remove_duplicate=False) + if label is None: + return text + label = self.decode(label[:,1:]) + return text, label + + def add_special_char(self, dict_character): + dict_character = ['blank','','',''] + dict_character + return dict_character + + def decode(self, text_index, text_prob=None, is_remove_duplicate=False): + """ convert text-index into text-label. """ + result_list = [] + batch_size = len(text_index) + for batch_idx in range(batch_size): + char_list = [] + conf_list = [] + for idx in range(len(text_index[batch_idx])): + if text_index[batch_idx][idx] == 3: # end + break + try: + char_list.append(self.character[int(text_index[batch_idx][idx])]) + except: + continue + if text_prob is not None: + conf_list.append(text_prob[batch_idx][idx]) + else: + conf_list.append(1) + text = ''.join(char_list) + result_list.append((text.lower(), np.mean(conf_list))) + return result_list + + + class AttnLabelDecode(BaseRecLabelDecode): """ Convert between text-label and text-index """ @@ -193,8 +256,7 @@ class AttnLabelDecode(BaseRecLabelDecode): if idx > 0 and text_index[batch_idx][idx - 1] == text_index[ batch_idx][idx]: continue - char_list.append(self.character[int(text_index[batch_idx][ - idx])]) + char_list.append(self.character[int(text_index[batch_idx][idx])]) if text_prob is not None: conf_list.append(text_prob[batch_idx][idx]) else: diff --git a/tools/program.py b/tools/program.py index 595fe4cb96c0379b1a33504e0ebdd85e70086340..e7742a8f608db679ff0c71de007a819e8ba3567f 100755 --- a/tools/program.py +++ b/tools/program.py @@ -186,9 +186,11 @@ def train(config, model.train() use_srn = config['Architecture']['algorithm'] == "SRN" - try: + use_nrtr = config['Architecture']['algorithm'] == "NRTR" + + try: model_type = config['Architecture']['model_type'] - except: + except: model_type = None if 'start_epoch' in best_model_dict: @@ -213,7 +215,7 @@ def train(config, images = batch[0] if use_srn: model_average = True - if use_srn or model_type == 'table': + if use_srn or model_type == 'table' or use_nrtr: preds = model(images, data=batch[1:]) else: preds = model(images) @@ -398,7 +400,7 @@ def preprocess(is_train=False): alg = config['Architecture']['algorithm'] assert alg in [ 'EAST', 'DB', 'SAST', 'Rosetta', 'CRNN', 'STARNet', 'RARE', 'SRN', - 'CLS', 'PGNet', 'Distillation', 'TableAttn' + 'CLS', 'PGNet', 'Distillation', 'NRTR', 'TableAttn' ] device = 'gpu:{}'.format(dist.ParallelEnv().dev_id) if use_gpu else 'cpu'