diff --git a/configs/rec/rec_satrn.yml b/configs/rec/rec_satrn.yml new file mode 100644 index 0000000000000000000000000000000000000000..8ed688b65b75ab4fad5f3c06b58ec8e78bcf59fd --- /dev/null +++ b/configs/rec/rec_satrn.yml @@ -0,0 +1,117 @@ +Global: + use_gpu: true + epoch_num: 5 + log_smooth_window: 20 + print_batch_step: 50 + save_model_dir: ./output/rec/rec_satrn/ + save_epoch_step: 1 + # evaluation is run every 5000 iterations + eval_batch_step: [0, 5000] + cal_metric_during_train: False + pretrained_model: + checkpoints: + save_inference_dir: + use_visualdl: False + infer_img: + # for data or label process + character_dict_path: ppocr/utils/dict90.txt + max_text_length: 25 + infer_mode: False + use_space_char: False + rm_symbol: True + save_res_path: ./output/rec/predicts_satrn.txt + +Optimizer: + name: Adam + beta1: 0.9 + beta2: 0.999 + lr: + name: Piecewise + decay_epochs: [3, 4] + values: [0.0003, 0.00003, 0.000003] + regularizer: + name: 'L2' + factor: 0 + +Architecture: + model_type: rec + algorithm: SATRN + Backbone: + name: ShallowCNN + in_channels: 3 + hidden_dim: 256 + Head: + name: SATRNHead + enc_cfg: + n_layers: 6 + n_head: 8 + d_k: 32 + d_v: 32 + d_model: 256 + n_position: 100 + d_inner: 1024 + dropout: 0.1 + dec_cfg: + n_layers: 6 + d_embedding: 256 + n_head: 8 + d_model: 256 + d_inner: 1024 + d_k: 32 + d_v: 32 + max_seq_len: 25 + start_idx: 91 + +Loss: + name: SATRNLoss + +PostProcess: + name: SATRNLabelDecode + +Metric: + name: RecMetric + main_indicator: acc + +Train: + dataset: + name: LMDBDataSet + data_dir: ./train_data/data_lmdb_release/training/ + transforms: + - DecodeImage: # load image + img_mode: BGR + channel_first: False + - SATRNLabelEncode: # Class handling label + - SVTRRecResizeImg: + image_shape: [3, 32, 100] + padding: False + - KeepKeys: + keep_keys: ['image', 'label', 'valid_ratio'] # dataloader will return list in this order + loader: + shuffle: True + batch_size_per_card: 128 + drop_last: True + num_workers: 8 + 
class SATRNLabelEncode(BaseRecLabelEncode):
    """Encode a text label into a fixed-length index sequence for SATRN.

    The produced label is ``[start] + char indices + [end]``, truncated or
    right-padded with the padding index to ``max_text_len`` entries.
    """

    def __init__(self,
                 max_text_length,
                 character_dict_path=None,
                 use_space_char=False,
                 lower=False,
                 **kwargs):
        super(SATRNLabelEncode, self).__init__(
            max_text_length, character_dict_path, use_space_char)
        # When true, labels are lower-cased before dictionary lookup.
        self.lower = lower

    def add_special_char(self, dict_character):
        """Append the unknown, begin/end and padding entries to the dictionary.

        Records ``unknown_idx``, ``start_idx``/``end_idx`` (one shared entry)
        and ``padding_idx`` as the positions of the three appended entries and
        returns a new list; the input list is not modified.
        """
        # NOTE(review): all three marker strings are empty in this revision;
        # confirm they were not stripped from the original source.
        unknown_str, beg_end_str, padding_str = "", "", ""
        first_special = len(dict_character)
        self.unknown_idx = first_special
        self.start_idx = first_special + 1
        self.end_idx = first_special + 1
        self.padding_idx = first_special + 2
        return dict_character + [unknown_str, beg_end_str, padding_str]

    def encode(self, text):
        """Map each character to its index; unseen characters map to unknown.

        Returns ``None`` when the text is empty.
        """
        if self.lower:
            text = text.lower()
        indices = [self.dict.get(ch, self.unknown_idx) for ch in text]
        return indices if indices else None

    def __call__(self, data):
        """Replace ``data['label']`` with the padded index array.

        Also stores the number of encoded characters (before start/end tokens
        and before any truncation) in ``data['length']``. Returns ``None``
        when the label encodes to nothing.
        """
        encoded = self.encode(data['label'])
        if encoded is None:
            return None
        data['length'] = np.array(len(encoded))
        target = [self.start_idx] + encoded + [self.end_idx]
        limit = self.max_text_len
        if len(target) > limit:
            # Over-long targets are cut; the end token may be dropped here.
            padded = target[:limit]
        else:
            padded = target + [self.padding_idx] * (limit - len(target))
        data['label'] = np.array(padded)
        return data

    def get_ignored_tokens(self):
        """Return the indices that decoding should skip (padding only)."""
        return [self.padding_idx]
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This code is adapted from:
https://github.com/open-mmlab/mmocr/blob/1.x/mmocr/models/textrecog/module_losses/ce_module_loss.py
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import paddle
from paddle import nn


class SATRNLoss(nn.Layer):
    """Cross-entropy loss for the SATRN decoder.

    Args:
        ignore_index (int, keyword-only via ``**kwargs``): label value that
            contributes no loss. Defaults to 92, the same value as
            ``SATRNDecoder``'s default ``padding_idx``.
    """

    def __init__(self, **kwargs):
        super(SATRNLoss, self).__init__()
        ignore_index = kwargs.get('ignore_index', 92)
        # Per-position losses are kept ("none") and averaged in forward().
        self.loss_func = paddle.nn.loss.CrossEntropyLoss(
            reduction="none", ignore_index=ignore_index)

    def forward(self, predicts, batch):
        """Compute the mean per-position cross-entropy.

        Args:
            predicts: 3-D logits; the last step along axis 1 is dropped so the
                prediction length matches the shifted targets.
            batch: sequence whose element ``1`` holds the label indices; the
                first label column is dropped before comparison.

        Returns:
            dict: ``{'loss': scalar mean over all flattened positions}``.
        """
        # Drop the last prediction step and the first target column so both
        # have the same sequence length.
        predict = predicts[:, :-1, :]
        label = batch[1].astype("int64")[:, 1:]
        num_classes = predict.shape[2]
        assert len(label.shape) == len(list(predict.shape)) - 1, \
            "The target's shape and inputs's shape is [N, d] and [N, num_steps]"

        inputs = paddle.reshape(predict, [-1, num_classes])
        targets = paddle.reshape(label, [-1])
        loss = self.loss_func(inputs, targets)
        # The mean runs over every flattened position returned by loss_func.
        return {'loss': loss.mean()}
.rec_densenet import DenseNet + from .rec_shallow_cnn import ShallowCNN support_dict = [ 'MobileNetV1Enhance', 'MobileNetV3', 'ResNet', 'ResNetFPN', 'MTB', 'ResNet31', 'ResNet45', 'ResNet_ASTER', 'MicroNet', 'EfficientNetb3_PREN', 'SVTRNet', 'ViTSTR', 'ResNet32', 'ResNetRFL', - 'DenseNet' + 'DenseNet', 'ShallowCNN' ] elif model_type == 'e2e': from .e2e_resnet_vd_pg import ResNet diff --git a/ppocr/modeling/backbones/rec_shallow_cnn.py b/ppocr/modeling/backbones/rec_shallow_cnn.py new file mode 100644 index 0000000000000000000000000000000000000000..544f108d26397421ae77ee025b15f31e319ab54c --- /dev/null +++ b/ppocr/modeling/backbones/rec_shallow_cnn.py @@ -0,0 +1,87 @@ +# copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
"""
This code is adapted from:
https://github.com/open-mmlab/mmocr/blob/1.x/mmocr/models/textrecog/backbones/shallow_cnn.py
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import math
import numpy as np
import paddle
from paddle import ParamAttr
import paddle.nn as nn
import paddle.nn.functional as F
from paddle.nn import MaxPool2D
from paddle.nn.initializer import KaimingNormal, Uniform, Constant


class ConvBNLayer(nn.Layer):
    """Bias-free Conv2D followed by BatchNorm2D and ReLU.

    The convolution weights use Kaiming-normal initialisation; the batch-norm
    scale is drawn from Uniform(0, 1) and its bias starts at zero.
    """

    def __init__(self,
                 num_channels,
                 filter_size,
                 num_filters,
                 stride,
                 padding,
                 num_groups=1):
        super().__init__()

        conv_weight = ParamAttr(initializer=KaimingNormal())
        self.conv = nn.Conv2D(
            in_channels=num_channels,
            out_channels=num_filters,
            kernel_size=filter_size,
            stride=stride,
            padding=padding,
            groups=num_groups,
            weight_attr=conv_weight,
            bias_attr=False)

        bn_scale = ParamAttr(initializer=Uniform(0, 1))
        bn_shift = ParamAttr(initializer=Constant(0))
        self.bn = nn.BatchNorm2D(
            num_filters, weight_attr=bn_scale, bias_attr=bn_shift)
        self.relu = nn.ReLU()

    def forward(self, inputs):
        """Apply convolution, batch normalisation and ReLU in that order."""
        return self.relu(self.bn(self.conv(inputs)))


class ShallowCNN(nn.Layer):
    """Two conv-bn-relu stages, each followed by a 2x2 stride-2 max pool.

    Args:
        in_channels (int): channels of the input image.
        hidden_dim (int): channels of the output feature map; the first stage
            produces ``hidden_dim // 2`` channels.
    """

    def __init__(self, in_channels=1, hidden_dim=512):
        super().__init__()
        assert isinstance(in_channels, int)
        assert isinstance(hidden_dim, int)

        half_dim = hidden_dim // 2
        self.conv1 = ConvBNLayer(in_channels, 3, half_dim, stride=1, padding=1)
        self.conv2 = ConvBNLayer(half_dim, 3, hidden_dim, stride=1, padding=1)
        # One pooling layer (no parameters) is shared by both stages.
        self.pool = nn.MaxPool2D(kernel_size=2, stride=2, padding=0)
        self.out_channels = hidden_dim

    def forward(self, x):
        """Run both stages; each pool uses kernel 2 and stride 2."""
        for stage in (self.conv1, self.conv2):
            x = self.pool(stage(x))
        return x
-40,6 +40,7 @@ def build_head(config): from .rec_visionlan_head import VLHead from .rec_rfl_head import RFLHead from .rec_can_head import CANHead + from .rec_satrn_head import SATRNHead # cls head from .cls_head import ClsHead @@ -56,7 +57,7 @@ def build_head(config): 'TableAttentionHead', 'SARHead', 'AsterHead', 'SDMGRHead', 'PRENHead', 'MultiHead', 'ABINetHead', 'TableMasterHead', 'SPINAttentionHead', 'VLHead', 'SLAHead', 'RobustScannerHead', 'CT_Head', 'RFLHead', - 'DRRGHead', 'CANHead' + 'DRRGHead', 'CANHead', 'SATRNHead' ] if config['name'] == 'DRRGHead': diff --git a/ppocr/modeling/heads/rec_satrn_head.py b/ppocr/modeling/heads/rec_satrn_head.py new file mode 100644 index 0000000000000000000000000000000000000000..b969c89693b677489b7191a9120f16d02c322802 --- /dev/null +++ b/ppocr/modeling/heads/rec_satrn_head.py @@ -0,0 +1,568 @@ +# copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
"""
This code is adapted from:
https://github.com/open-mmlab/mmocr/blob/1.x/mmocr/models/textrecog/encoders/satrn_encoder.py
https://github.com/open-mmlab/mmocr/blob/1.x/mmocr/models/textrecog/decoders/nrtr_decoder.py
"""

import math
import numpy as np
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from paddle import ParamAttr, reshape, transpose
from paddle.nn import Conv2D, BatchNorm, Linear, Dropout
from paddle.nn import AdaptiveAvgPool2D, MaxPool2D, AvgPool2D
from paddle.nn.initializer import KaimingNormal, Uniform, Constant


def _build_sinusoid_table(n_position, d_hid):
    """Return an ``(n_position, d_hid)`` sinusoid position table.

    Even columns hold the sine and odd columns the cosine of
    ``pos / 10000 ** (2 * (j // 2) / d_hid)``.
    """
    denominator = paddle.to_tensor([
        1.0 / np.power(10000, 2 * (hid_j // 2) / d_hid)
        for hid_j in range(d_hid)
    ])
    denominator = denominator.reshape([1, -1])
    pos_tensor = paddle.cast(
        paddle.arange(n_position).unsqueeze(-1), 'float32')
    sinusoid_table = pos_tensor * denominator
    sinusoid_table[:, 0::2] = paddle.sin(sinusoid_table[:, 0::2])
    sinusoid_table[:, 1::2] = paddle.cos(sinusoid_table[:, 1::2])

    return sinusoid_table


class ConvBNLayer(nn.Layer):
    """Bias-free Conv2D followed by BatchNorm2D (scale 1, bias 0) and ReLU."""

    def __init__(self,
                 num_channels,
                 filter_size,
                 num_filters,
                 stride,
                 padding,
                 num_groups=1):
        super(ConvBNLayer, self).__init__()

        self.conv = nn.Conv2D(
            in_channels=num_channels,
            out_channels=num_filters,
            kernel_size=filter_size,
            stride=stride,
            padding=padding,
            groups=num_groups,
            bias_attr=False)

        self.bn = nn.BatchNorm2D(
            num_filters,
            weight_attr=ParamAttr(initializer=Constant(1)),
            bias_attr=ParamAttr(initializer=Constant(0)))
        self.relu = nn.ReLU()

    def forward(self, inputs):
        y = self.conv(inputs)
        y = self.bn(y)
        y = self.relu(y)
        return y


class SATRNEncoderLayer(nn.Layer):
    """Pre-norm encoder layer: self-attention then a convolutional feed-forward.

    Both sub-layers are wrapped in a residual connection.
    """

    def __init__(self,
                 d_model=512,
                 d_inner=512,
                 n_head=8,
                 d_k=64,
                 d_v=64,
                 dropout=0.1,
                 qkv_bias=False):
        super().__init__()
        self.norm1 = nn.LayerNorm(d_model)
        self.attn = MultiHeadAttention(
            n_head, d_model, d_k, d_v, qkv_bias=qkv_bias, dropout=dropout)
        self.norm2 = nn.LayerNorm(d_model)
        self.feed_forward = LocalityAwareFeedforward(
            d_model, d_inner, dropout=dropout)

    def forward(self, x, h, w, mask=None):
        """Process a flattened feature map.

        Args:
            x: tensor of shape ``(n, h * w, c)``.
            h, w: spatial size used to restore the 2-D layout for the
                convolutional feed-forward.
            mask: optional attention mask forwarded to the attention module.

        Returns:
            Tensor with the same shape as ``x``.
        """
        n, hw, c = x.shape
        residual = x
        x = self.norm1(x)
        x = residual + self.attn(x, x, x, mask)
        residual = x
        x = self.norm2(x)
        # The feed-forward is convolutional, so go back to (n, c, h, w).
        x = x.transpose([0, 2, 1]).reshape([n, c, h, w])
        x = self.feed_forward(x)
        x = x.reshape([n, c, hw]).transpose([0, 2, 1])
        x = residual + x
        return x


class LocalityAwareFeedforward(nn.Layer):
    """1x1 conv -> 3x3 depthwise conv -> 1x1 conv, each with BN and ReLU.

    ``dropout`` is accepted for signature compatibility but is not used.
    """

    def __init__(
            self,
            d_in,
            d_hid,
            dropout=0.1, ):
        super().__init__()
        self.conv1 = ConvBNLayer(d_in, 1, d_hid, stride=1, padding=0)

        self.depthwise_conv = ConvBNLayer(
            d_hid, 3, d_hid, stride=1, padding=1, num_groups=d_hid)

        self.conv2 = ConvBNLayer(d_hid, 1, d_in, stride=1, padding=0)

    def forward(self, x):
        x = self.conv1(x)
        x = self.depthwise_conv(x)
        x = self.conv2(x)

        return x


class Adaptive2DPositionalEncoding(nn.Layer):
    """Add height and width sinusoid encodings, each scaled per channel.

    The scales are predicted from the globally average-pooled input by two
    small conv-ReLU-conv-sigmoid branches.
    """

    def __init__(self, d_hid=512, n_height=100, n_width=100, dropout=0.1):
        super().__init__()

        h_position_encoder = self._get_sinusoid_encoding_table(n_height, d_hid)
        h_position_encoder = h_position_encoder.transpose([1, 0])
        h_position_encoder = h_position_encoder.reshape([1, d_hid, n_height, 1])

        w_position_encoder = self._get_sinusoid_encoding_table(n_width, d_hid)
        w_position_encoder = w_position_encoder.transpose([1, 0])
        w_position_encoder = w_position_encoder.reshape([1, d_hid, 1, n_width])

        self.register_buffer('h_position_encoder', h_position_encoder)
        self.register_buffer('w_position_encoder', w_position_encoder)

        self.h_scale = self.scale_factor_generate(d_hid)
        self.w_scale = self.scale_factor_generate(d_hid)
        self.pool = nn.AdaptiveAvgPool2D(1)
        self.dropout = nn.Dropout(p=dropout)

    def _get_sinusoid_encoding_table(self, n_position, d_hid):
        """Sinusoid position encoding table."""
        return _build_sinusoid_table(n_position, d_hid)

    def scale_factor_generate(self, d_hid):
        """Build one conv-ReLU-conv-sigmoid scale predictor."""
        scale_factor = nn.Sequential(
            nn.Conv2D(d_hid, d_hid, 1),
            nn.ReLU(), nn.Conv2D(d_hid, d_hid, 1), nn.Sigmoid())

        return scale_factor

    def forward(self, x):
        """Return ``x`` plus both scaled encodings, followed by dropout.

        ``x`` has shape ``(b, c, h, w)``; the tables are sliced to ``h``/``w``.
        """
        b, c, h, w = x.shape

        avg_pool = self.pool(x)

        h_pos_encoding = \
            self.h_scale(avg_pool) * self.h_position_encoder[:, :, :h, :]
        w_pos_encoding = \
            self.w_scale(avg_pool) * self.w_position_encoder[:, :, :, :w]

        out = x + h_pos_encoding + w_pos_encoding

        out = self.dropout(out)

        return out


class ScaledDotProductAttention(nn.Layer):
    """Softmax attention with queries divided by ``temperature``."""

    def __init__(self, temperature, attn_dropout=0.1):
        super().__init__()
        self.temperature = temperature
        self.dropout = nn.Dropout(attn_dropout)

    def forward(self, q, k, v, mask=None):
        """Return ``(output, attention_weights)``.

        Positions where ``mask == 0`` get a score of -1e9 before the softmax.
        """

        def masked_fill(x, mask, value):
            y = paddle.full(x.shape, value, x.dtype)
            return paddle.where(mask, y, x)

        attn = paddle.matmul(q / self.temperature, k.transpose([0, 1, 3, 2]))
        if mask is not None:
            attn = masked_fill(attn, mask == 0, -1e9)

        attn = self.dropout(F.softmax(attn, axis=-1))
        output = paddle.matmul(attn, v)

        return output, attn


class MultiHeadAttention(nn.Layer):
    """Multi-head attention with separate q/k/v projections.

    Args:
        n_head (int): number of heads.
        d_model (int): width of the inputs and of the output.
        d_k (int): per-head width of queries and keys.
        d_v (int): per-head width of values.
        dropout (float): dropout on attention weights and on the output.
        qkv_bias (bool): whether the linear layers carry a bias.
    """

    def __init__(self,
                 n_head=8,
                 d_model=512,
                 d_k=64,
                 d_v=64,
                 dropout=0.1,
                 qkv_bias=False):
        super().__init__()
        self.n_head = n_head
        self.d_k = d_k
        self.d_v = d_v

        self.dim_k = n_head * d_k
        self.dim_v = n_head * d_v

        # The projections consume d_model-wide inputs. Using d_model as the
        # input width gives the same weight shapes as before whenever
        # d_model == n_head * d_k == n_head * d_v, and also works otherwise.
        self.linear_q = nn.Linear(d_model, self.dim_k, bias_attr=qkv_bias)
        self.linear_k = nn.Linear(d_model, self.dim_k, bias_attr=qkv_bias)
        self.linear_v = nn.Linear(d_model, self.dim_v, bias_attr=qkv_bias)

        self.attention = ScaledDotProductAttention(d_k**0.5, dropout)

        self.fc = nn.Linear(self.dim_v, d_model, bias_attr=qkv_bias)
        self.proj_drop = nn.Dropout(dropout)

    def forward(self, q, k, v, mask=None):
        """Attend from ``q`` over ``k``/``v``.

        ``q`` is ``(batch, len_q, d_model)``; ``k`` and ``v`` share their
        length. A 3-D mask gets a head axis inserted; a 2-D mask gets head and
        query axes inserted. Returns ``(batch, len_q, d_model)``.
        """
        batch_size, len_q, _ = q.shape
        _, len_k, _ = k.shape

        q = self.linear_q(q).reshape([batch_size, len_q, self.n_head, self.d_k])
        k = self.linear_k(k).reshape([batch_size, len_k, self.n_head, self.d_k])
        v = self.linear_v(v).reshape([batch_size, len_k, self.n_head, self.d_v])

        # Move the head axis in front of the sequence axis.
        q, k, v = q.transpose([0, 2, 1, 3]), k.transpose(
            [0, 2, 1, 3]), v.transpose([0, 2, 1, 3])

        if mask is not None:
            if mask.dim() == 3:
                mask = mask.unsqueeze(1)
            elif mask.dim() == 2:
                mask = mask.unsqueeze(1).unsqueeze(1)

        attn_out, _ = self.attention(q, k, v, mask=mask)

        attn_out = attn_out.transpose([0, 2, 1, 3]).reshape(
            [batch_size, len_q, self.dim_v])

        attn_out = self.fc(attn_out)
        attn_out = self.proj_drop(attn_out)

        return attn_out


class SATRNEncoder(nn.Layer):
    """Adaptive 2-D positional encoding followed by a stack of encoder layers."""

    def __init__(self,
                 n_layers=12,
                 n_head=8,
                 d_k=64,
                 d_v=64,
                 d_model=512,
                 n_position=100,
                 d_inner=256,
                 dropout=0.1):
        super().__init__()
        self.d_model = d_model
        self.position_enc = Adaptive2DPositionalEncoding(
            d_hid=d_model,
            n_height=n_position,
            n_width=n_position,
            dropout=dropout)
        self.layer_stack = nn.LayerList([
            SATRNEncoderLayer(
                d_model, d_inner, n_head, d_k, d_v, dropout=dropout)
            for _ in range(n_layers)
        ])
        self.layer_norm = nn.LayerNorm(d_model)

    def forward(self, feat, valid_ratios=None):
        """
        Args:
            feat (Tensor): Feature tensor of shape :math:`(N, D_m, H, W)`.
            valid_ratios (iterable, optional): one width ratio per sample;
                columns beyond ``ceil(W * ratio)`` are masked out of the
                attention. Defaults to 1.0 (nothing masked) for every sample.

        Returns:
            Tensor: A tensor of shape :math:`(N, H * W, D_m)`.
        """
        if valid_ratios is None:
            valid_ratios = [1.0 for _ in range(feat.shape[0])]
        feat = self.position_enc(feat)
        n, c, h, w = feat.shape

        # 1 marks a valid position, 0 a padded column.
        mask = paddle.zeros((n, h, w))
        for i, valid_ratio in enumerate(valid_ratios):
            valid_width = min(w, math.ceil(w * valid_ratio))
            mask[i, :, :valid_width] = 1

        mask = mask.reshape([n, h * w])
        feat = feat.reshape([n, c, h * w])

        output = feat.transpose([0, 2, 1])
        for enc_layer in self.layer_stack:
            output = enc_layer(output, h, w, mask)
        output = self.layer_norm(output)

        return output


class PositionwiseFeedForward(nn.Layer):
    """Linear -> GELU -> Linear -> Dropout."""

    def __init__(self, d_in, d_hid, dropout=0.1):
        super().__init__()
        self.w_1 = nn.Linear(d_in, d_hid)
        self.w_2 = nn.Linear(d_hid, d_in)
        self.act = nn.GELU()
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        x = self.w_1(x)
        x = self.act(x)
        x = self.w_2(x)
        x = self.dropout(x)

        return x


class PositionalEncoding(nn.Layer):
    """Add a fixed 1-D sinusoid position table to a sequence of embeddings."""

    def __init__(self, d_hid=512, n_position=200, dropout=0):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        # Not a parameter
        # Position table of shape (1, n_position, d_hid)
        self.register_buffer(
            'position_table',
            self._get_sinusoid_encoding_table(n_position, d_hid))

    def _get_sinusoid_encoding_table(self, n_position, d_hid):
        """Sinusoid position encoding table."""
        return _build_sinusoid_table(n_position, d_hid).unsqueeze(0)

    def forward(self, x):
        # Only the first x.shape[1] positions of the table are used.
        x = x + self.position_table[:, :x.shape[1]].clone().detach()
        return self.dropout(x)


class TFDecoderLayer(nn.Layer):
    """Decoder layer: self-attention, encoder-decoder attention, feed-forward.

    ``operation_order`` selects pre-norm (the default,
    ``('norm', 'self_attn', 'norm', 'enc_dec_attn', 'norm', 'ffn')``) or
    post-norm (``('self_attn', 'norm', 'enc_dec_attn', 'norm', 'ffn',
    'norm')``); any other value fails the assertion in ``__init__``.
    """

    def __init__(self,
                 d_model=512,
                 d_inner=256,
                 n_head=8,
                 d_k=64,
                 d_v=64,
                 dropout=0.1,
                 qkv_bias=False,
                 operation_order=None):
        super().__init__()

        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)

        self.self_attn = MultiHeadAttention(
            n_head, d_model, d_k, d_v, dropout=dropout, qkv_bias=qkv_bias)

        self.enc_attn = MultiHeadAttention(
            n_head, d_model, d_k, d_v, dropout=dropout, qkv_bias=qkv_bias)

        self.mlp = PositionwiseFeedForward(d_model, d_inner, dropout=dropout)

        self.operation_order = operation_order
        if self.operation_order is None:
            self.operation_order = ('norm', 'self_attn', 'norm', 'enc_dec_attn',
                                    'norm', 'ffn')
        assert self.operation_order in [
            ('norm', 'self_attn', 'norm', 'enc_dec_attn', 'norm', 'ffn'),
            ('self_attn', 'norm', 'enc_dec_attn', 'norm', 'ffn', 'norm')
        ]

    def forward(self,
                dec_input,
                enc_output,
                self_attn_mask=None,
                dec_enc_attn_mask=None):
        """Run the three residual sub-layers in the configured order."""
        if self.operation_order == ('self_attn', 'norm', 'enc_dec_attn', 'norm',
                                    'ffn', 'norm'):
            # Post-norm: normalise after each residual addition.
            dec_attn_out = self.self_attn(dec_input, dec_input, dec_input,
                                          self_attn_mask)
            dec_attn_out += dec_input
            dec_attn_out = self.norm1(dec_attn_out)

            enc_dec_attn_out = self.enc_attn(dec_attn_out, enc_output,
                                             enc_output, dec_enc_attn_mask)
            enc_dec_attn_out += dec_attn_out
            enc_dec_attn_out = self.norm2(enc_dec_attn_out)

            mlp_out = self.mlp(enc_dec_attn_out)
            mlp_out += enc_dec_attn_out
            mlp_out = self.norm3(mlp_out)
        elif self.operation_order == ('norm', 'self_attn', 'norm',
                                      'enc_dec_attn', 'norm', 'ffn'):
            # Pre-norm: normalise the input of each sub-layer.
            dec_input_norm = self.norm1(dec_input)
            dec_attn_out = self.self_attn(dec_input_norm, dec_input_norm,
                                          dec_input_norm, self_attn_mask)
            dec_attn_out += dec_input

            enc_dec_attn_in = self.norm2(dec_attn_out)
            enc_dec_attn_out = self.enc_attn(enc_dec_attn_in, enc_output,
                                             enc_output, dec_enc_attn_mask)
            enc_dec_attn_out += dec_attn_out

            mlp_out = self.mlp(self.norm3(enc_dec_attn_out))
            mlp_out += enc_dec_attn_out

        return mlp_out


class SATRNDecoder(nn.Layer):
    """Transformer decoder that predicts one character index per step.

    The classifier has ``num_classes - 1`` outputs because the padding index
    is never predicted.
    """

    def __init__(self,
                 n_layers=6,
                 d_embedding=512,
                 n_head=8,
                 d_k=64,
                 d_v=64,
                 d_model=512,
                 d_inner=256,
                 n_position=200,
                 dropout=0.1,
                 num_classes=93,
                 max_seq_len=40,
                 start_idx=1,
                 padding_idx=92):
        super().__init__()

        self.padding_idx = padding_idx
        self.start_idx = start_idx
        self.max_seq_len = max_seq_len

        self.trg_word_emb = nn.Embedding(
            num_classes, d_embedding, padding_idx=padding_idx)

        self.position_enc = PositionalEncoding(
            d_embedding, n_position=n_position)
        self.dropout = nn.Dropout(p=dropout)

        self.layer_stack = nn.LayerList([
            TFDecoderLayer(
                d_model, d_inner, n_head, d_k, d_v, dropout=dropout)
            for _ in range(n_layers)
        ])
        self.layer_norm = nn.LayerNorm(d_model, epsilon=1e-6)

        pred_num_class = num_classes - 1  # ignore padding_idx
        self.classifier = nn.Linear(d_model, pred_num_class)

    @staticmethod
    def get_pad_mask(seq, pad_idx):
        """Return a mask that is true where ``seq`` is not ``pad_idx``."""
        return (seq != pad_idx).unsqueeze(-2)

    @staticmethod
    def get_subsequent_mask(seq):
        """For masking out the subsequent info."""
        len_s = seq.shape[1]
        # Lower-triangular (including the diagonal) boolean matrix.
        subsequent_mask = 1 - paddle.triu(
            paddle.ones((len_s, len_s)), diagonal=1)
        subsequent_mask = paddle.cast(subsequent_mask.unsqueeze(0), 'bool')

        return subsequent_mask

    def _attention(self, trg_seq, src, src_mask=None):
        """Embed ``trg_seq`` and run it through the decoder stack over ``src``."""
        trg_embedding = self.trg_word_emb(trg_seq)
        trg_pos_encoded = self.position_enc(trg_embedding)
        tgt = self.dropout(trg_pos_encoded)

        # A target position may attend only to non-padding, non-future tokens.
        trg_mask = self.get_pad_mask(
            trg_seq,
            pad_idx=self.padding_idx) & self.get_subsequent_mask(trg_seq)
        output = tgt
        for dec_layer in self.layer_stack:
            output = dec_layer(
                output,
                src,
                self_attn_mask=trg_mask,
                dec_enc_attn_mask=src_mask)
        output = self.layer_norm(output)

        return output

    def _get_mask(self, logit, valid_ratios):
        """Build an ``(N, T)`` validity mask, or ``None`` without ratios."""
        N, T, _ = logit.shape
        mask = None
        if valid_ratios is not None:
            mask = paddle.zeros((N, T))
            for i, valid_ratio in enumerate(valid_ratios):
                valid_width = min(T, math.ceil(T * valid_ratio))
                mask[i, :valid_width] = 1

        return mask

    def forward_train(self, feat, out_enc, targets, valid_ratio):
        """Teacher-forced pass: return logits for every target position."""
        src_mask = self._get_mask(out_enc, valid_ratio)
        attn_output = self._attention(targets, out_enc, src_mask=src_mask)
        outputs = self.classifier(attn_output)

        return outputs

    def forward_test(self, feat, out_enc, valid_ratio):
        """Greedy decoding for ``max_seq_len`` steps.

        Returns the per-step softmax probabilities stacked along axis 1.
        """
        src_mask = self._get_mask(out_enc, valid_ratio)
        N = out_enc.shape[0]
        init_target_seq = paddle.full(
            (N, self.max_seq_len + 1), self.padding_idx, dtype='int64')
        # bsz * seq_len
        init_target_seq[:, 0] = self.start_idx

        outputs = []
        # NOTE(review): the loop bound is wrapped in a tensor, presumably for
        # dynamic-to-static export; confirm before simplifying to a plain int.
        for step in range(0, paddle.to_tensor(self.max_seq_len)):
            decoder_output = self._attention(
                init_target_seq, out_enc, src_mask=src_mask)
            # bsz * seq_len * C
            step_result = F.softmax(
                self.classifier(decoder_output[:, step, :]), axis=-1)
            # bsz * num_classes
            outputs.append(step_result)
            step_max_index = paddle.argmax(step_result, axis=-1)
            # Feed the greedy choice back as the next input token.
            init_target_seq[:, step + 1] = step_max_index

        outputs = paddle.stack(outputs, axis=1)

        return outputs

    def forward(self, feat, out_enc, targets=None, valid_ratio=None):
        """Dispatch to teacher forcing in training mode, greedy decoding otherwise."""
        if self.training:
            return self.forward_train(feat, out_enc, targets, valid_ratio)
        else:
            return self.forward_test(feat, out_enc, valid_ratio)


class SATRNHead(nn.Layer):
    """SATRN recognition head: ``SATRNEncoder`` followed by ``SATRNDecoder``.

    Args:
        enc_cfg (dict): keyword arguments for ``SATRNEncoder``.
        dec_cfg (dict): keyword arguments for ``SATRNDecoder``.
    """

    def __init__(self, enc_cfg, dec_cfg, **kwargs):
        super(SATRNHead, self).__init__()

        # encoder module
        self.encoder = SATRNEncoder(**enc_cfg)

        # decoder module
        self.decoder = SATRNDecoder(**dec_cfg)

    def forward(self, feat, targets=None):
        """Encode ``feat`` and decode it.

        ``targets``, when given, is unpacked as ``(targets, valid_ratio)``.
        """
        if targets is not None:
            targets, valid_ratio = targets
        else:
            targets, valid_ratio = None, None
        holistic_feat = self.encoder(feat, valid_ratio)  # bsz c

        final_out = self.decoder(feat, holistic_feat, targets, valid_ratio)

        return final_out
class SATRNLabelDecode(BaseRecLabelDecode):
    """Decode SATRN index sequences back into text.

    Args:
        character_dict_path: forwarded to the base class.
        use_space_char: forwarded to the base class.
        rm_symbol (bool, via ``**kwargs``): when true, decoded text is
            lower-cased and every character matched by the symbol pattern in
            ``decode`` is removed. Defaults to ``False``.
    """

    def __init__(self, character_dict_path=None, use_space_char=False,
                 **kwargs):
        super(SATRNLabelDecode, self).__init__(character_dict_path,
                                               use_space_char)

        self.rm_symbol = kwargs.get('rm_symbol', False)

    def add_special_char(self, dict_character):
        """Append the unknown, begin/end and padding entries to the dictionary.

        Records ``unknown_idx``, ``start_idx``/``end_idx`` (one shared entry)
        and ``padding_idx`` as the positions of the appended entries.
        """
        # NOTE(review): all three marker strings are empty in this revision;
        # confirm they were not stripped from the original source.
        beg_end_str = ""
        unknown_str = ""
        padding_str = ""
        dict_character = dict_character + [unknown_str]
        self.unknown_idx = len(dict_character) - 1
        dict_character = dict_character + [beg_end_str]
        self.start_idx = len(dict_character) - 1
        self.end_idx = len(dict_character) - 1
        dict_character = dict_character + [padding_str]
        self.padding_idx = len(dict_character) - 1
        return dict_character

    def decode(self, text_index, text_prob=None, is_remove_duplicate=False):
        """Convert batches of text indices into ``(text, confidence)`` pairs.

        Args:
            text_index: per-sample sequences of character indices.
            text_prob: optional per-sample sequences of probabilities; when
                omitted every kept character gets confidence 1.
            is_remove_duplicate (bool): skip an index equal to its predecessor.

        Returns:
            list of ``(text, mean_confidence)`` tuples, one per sample.
        """
        result_list = []
        ignored_tokens = self.get_ignored_tokens()
        end_idx = int(self.end_idx)
        # Compile the symbol filter once per call instead of once per sample.
        symbol_filter = re.compile(
            '[^A-Z^a-z^0-9^\u4e00-\u9fa5]') if self.rm_symbol else None

        for batch_idx, indices in enumerate(text_index):
            char_list = []
            conf_list = []
            for idx, token in enumerate(indices):
                if token in ignored_tokens:
                    continue
                if int(token) == end_idx:
                    # Labels (no probabilities) start with the shared
                    # start/end token; skip it there, otherwise stop.
                    if text_prob is None and idx == 0:
                        continue
                    else:
                        break
                if is_remove_duplicate:
                    # only for predict
                    if idx > 0 and indices[idx - 1] == token:
                        continue
                char_list.append(self.character[int(token)])
                if text_prob is not None:
                    conf_list.append(text_prob[batch_idx][idx])
                else:
                    conf_list.append(1)
            text = ''.join(char_list)
            if symbol_filter is not None:
                text = text.lower()
                text = symbol_filter.sub('', text)
            # NOTE(review): an empty conf_list yields NaN here (unchanged
            # behaviour); confirm whether callers expect a numeric fallback.
            result_list.append((text, np.mean(conf_list).tolist()))
        return result_list

    def __call__(self, preds, label=None, *args, **kwargs):
        """Decode predictions, and the labels too when they are supplied.

        ``preds`` is reduced along axis 2 with argmax/max to obtain indices
        and confidences. Returns the decoded text list, or a
        ``(text, label)`` pair when ``label`` is given.
        """
        if isinstance(preds, paddle.Tensor):
            preds = preds.numpy()
        preds_idx = preds.argmax(axis=2)
        preds_prob = preds.max(axis=2)

        text = self.decode(preds_idx, preds_prob, is_remove_duplicate=False)

        if label is None:
            return text
        label = self.decode(label, is_remove_duplicate=False)
        return text, label

    def get_ignored_tokens(self):
        """Return the indices skipped during decoding (padding only)."""
        return [self.padding_idx]
4b90fcae435619a53a3def8cc4dc46b4e2963bff..9d3bf5629e639bc0a7112090cd18e4cb57bd55d0 100755 --- a/tools/export_model.py +++ b/tools/export_model.py @@ -105,6 +105,12 @@ def export_single_model(model, shape=[None, 1, 32, 100], dtype="float32"), ] model = to_static(model, input_spec=other_shape) + elif arch_config["algorithm"] == 'SATRN': + other_shape = [ + paddle.static.InputSpec( + shape=[None, 3, 32, 100], dtype="float32"), + ] + model = to_static(model, input_spec=other_shape) elif arch_config["algorithm"] == "VisionLAN": other_shape = [ paddle.static.InputSpec( diff --git a/tools/infer/predict_rec.py b/tools/infer/predict_rec.py index d3231545a99cb42d71ed3f78a0f915825c3f821f..1c7e03c71d1ec793755e9ccbfd5057bd70f83396 100755 --- a/tools/infer/predict_rec.py +++ b/tools/infer/predict_rec.py @@ -106,6 +106,13 @@ class TextRecognizer(object): "character_dict_path": None, "use_space_char": args.use_space_char } + elif self.rec_algorithm == "SATRN": + postprocess_params = { + 'name': 'SATRNLabelDecode', + "character_dict_path": args.rec_char_dict_path, + "use_space_char": args.use_space_char, + "rm_symbol": True + } elif self.rec_algorithm == "PREN": postprocess_params = {'name': 'PRENLabelDecode'} elif self.rec_algorithm == "CAN": @@ -429,7 +436,7 @@ class TextRecognizer(object): gsrm_slf_attn_bias1_list.append(norm_img[3]) gsrm_slf_attn_bias2_list.append(norm_img[4]) norm_img_batch.append(norm_img[0]) - elif self.rec_algorithm == "SVTR": + elif self.rec_algorithm in ["SVTR", "SATRN"]: norm_img = self.resize_norm_img_svtr(img_list[indices[ino]], self.rec_image_shape) norm_img = norm_img[np.newaxis, :] diff --git a/tools/program.py b/tools/program.py index afb8a47254b9847e4a4d432b7f17902c3ee78725..9134472b83f785bbabcf44dc0675c0e0d9a08fc9 100755 --- a/tools/program.py +++ b/tools/program.py @@ -220,7 +220,7 @@ def train(config, use_srn = config['Architecture']['algorithm'] == "SRN" extra_input_models = [ "SRN", "NRTR", "SAR", "SEED", "SVTR", "SPIN", "VisionLAN", - 
"RobustScanner", "RFL", 'DRRG' + "RobustScanner", "RFL", 'DRRG', 'SATRN' ] extra_input = False if config['Architecture']['algorithm'] == 'Distillation': @@ -643,7 +643,7 @@ def preprocess(is_train=False): 'SEED', 'SDMGR', 'LayoutXLM', 'LayoutLM', 'LayoutLMv2', 'PREN', 'FCE', 'SVTR', 'ViTSTR', 'ABINet', 'DB++', 'TableMaster', 'SPIN', 'VisionLAN', 'Gestalt', 'SLANet', 'RobustScanner', 'CT', 'RFL', 'DRRG', 'CAN', - 'Telescope' + 'Telescope', 'SATRN' ] if use_xpu: