Unverified Commit 30201ef9 authored by zhiminzhang0830 and committed by GitHub

add satrn (#8433)

* add satrn

* Fix the SATRN export issue

* Standardize the SATRN config file

* Remove SATRNRecResizeImg

---------
Co-authored-by: zhiminzhang0830 <zhangzhimin04@baidu.com>
Parent 3ded6010
Global:
  use_gpu: true
  epoch_num: 5
  log_smooth_window: 20
  print_batch_step: 50
  save_model_dir: ./output/rec/rec_satrn/
  save_epoch_step: 1
  # evaluation is run every 5000 iterations
  eval_batch_step: [0, 5000]
  cal_metric_during_train: False
  pretrained_model:
  checkpoints:
  save_inference_dir:
  use_visualdl: False
  infer_img:
  # for data or label process
  character_dict_path: ppocr/utils/dict90.txt
  max_text_length: 25
  infer_mode: False
  use_space_char: False
  rm_symbol: True
  save_res_path: ./output/rec/predicts_satrn.txt

Optimizer:
  name: Adam
  beta1: 0.9
  beta2: 0.999
  lr:
    name: Piecewise
    decay_epochs: [3, 4]
    values: [0.0003, 0.00003, 0.000003]
  regularizer:
    name: 'L2'
    factor: 0

Architecture:
  model_type: rec
  algorithm: SATRN
  Backbone:
    name: ShallowCNN
    in_channels: 3
    hidden_dim: 256
  Head:
    name: SATRNHead
    enc_cfg:
      n_layers: 6
      n_head: 8
      d_k: 32
      d_v: 32
      d_model: 256
      n_position: 100
      d_inner: 1024
      dropout: 0.1
    dec_cfg:
      n_layers: 6
      d_embedding: 256
      n_head: 8
      d_model: 256
      d_inner: 1024
      d_k: 32
      d_v: 32
      max_seq_len: 25
      start_idx: 91

Loss:
  name: SATRNLoss

PostProcess:
  name: SATRNLabelDecode

Metric:
  name: RecMetric
  main_indicator: acc

Train:
  dataset:
    name: LMDBDataSet
    data_dir: ./train_data/data_lmdb_release/training/
    transforms:
      - DecodeImage: # load image
          img_mode: BGR
          channel_first: False
      - SATRNLabelEncode: # Class handling label
      - SVTRRecResizeImg:
          image_shape: [3, 32, 100]
          padding: False
      - KeepKeys:
          keep_keys: ['image', 'label', 'valid_ratio'] # dataloader will return list in this order
  loader:
    shuffle: True
    batch_size_per_card: 128
    drop_last: True
    num_workers: 8
    use_shared_memory: False

Eval:
  dataset:
    name: LMDBDataSet
    data_dir: ./train_data/data_lmdb_release/evaluation/
    transforms:
      - DecodeImage: # load image
          img_mode: BGR
          channel_first: False
      - SATRNLabelEncode: # Class handling label
      - SVTRRecResizeImg:
          image_shape: [3, 32, 100]
          padding: False
      - KeepKeys:
          keep_keys: ['image', 'label', 'valid_ratio'] # dataloader will return list in this order
  loader:
    shuffle: False
    drop_last: False
    batch_size_per_card: 128
    num_workers: 4
    use_shared_memory: False
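
The config's start_idx of 91 follows from the dictionary size: dict90.txt holds 90 characters, and SATRNLabelEncode (below) appends three special tokens to it. A minimal sketch of that index arithmetic, matching the defaults used by SATRNLoss and SATRNDecoder further down:

num_chars = 90                       # entries in ppocr/utils/dict90.txt
unknown_idx = num_chars              # 90: <UKN>
start_idx = end_idx = num_chars + 1  # 91: shared <BOS/EOS>, matches dec_cfg.start_idx
padding_idx = num_chars + 2          # 92: <PAD>, the default ignore_index in SATRNLoss
num_classes = num_chars + 3          # 93: the default num_classes in SATRNDecoder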
...
@@ -886,6 +886,62 @@ class SARLabelEncode(BaseRecLabelEncode):
        return [self.padding_idx]
class SATRNLabelEncode(BaseRecLabelEncode):
    """ Convert between text-label and text-index """

    def __init__(self,
                 max_text_length,
                 character_dict_path=None,
                 use_space_char=False,
                 lower=False,
                 **kwargs):
        super(SATRNLabelEncode, self).__init__(
            max_text_length, character_dict_path, use_space_char)
        self.lower = lower

    def add_special_char(self, dict_character):
        beg_end_str = "<BOS/EOS>"
        unknown_str = "<UKN>"
        padding_str = "<PAD>"
        dict_character = dict_character + [unknown_str]
        self.unknown_idx = len(dict_character) - 1
        dict_character = dict_character + [beg_end_str]
        self.start_idx = len(dict_character) - 1
        self.end_idx = len(dict_character) - 1
        dict_character = dict_character + [padding_str]
        self.padding_idx = len(dict_character) - 1

        return dict_character

    def encode(self, text):
        if self.lower:
            text = text.lower()
        text_list = []
        for char in text:
            text_list.append(self.dict.get(char, self.unknown_idx))
        if len(text_list) == 0:
            return None
        return text_list

    def __call__(self, data):
        text = data['label']
        text = self.encode(text)
        if text is None:
            return None
        data['length'] = np.array(len(text))
        target = [self.start_idx] + text + [self.end_idx]
        padded_text = [self.padding_idx for _ in range(self.max_text_len)]

        if len(target) > self.max_text_len:
            padded_text = target[:self.max_text_len]
        else:
            padded_text[:len(target)] = target

        data['label'] = np.array(padded_text)
        return data

    def get_ignored_tokens(self):
        return [self.padding_idx]
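
A minimal sketch of the padding scheme __call__ implements, with hypothetical character indices (91 and 92 follow from dict90.txt as noted above):

start_idx = end_idx = 91
padding_idx = 92
max_text_len = 8
text = [10, 11, 12]                      # an encoded three-character label
target = [start_idx] + text + [end_idx]  # [91, 10, 11, 12, 91]
padded = (target + [padding_idx] * max_text_len)[:max_text_len]
print(padded)                            # [91, 10, 11, 12, 91, 92, 92, 92]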
class PRENLabelEncode(BaseRecLabelEncode):
    def __init__(self,
                 max_text_length,
...
...
@@ -41,6 +41,7 @@ from .rec_vl_loss import VLLoss
from .rec_spin_att_loss import SPINAttentionLoss
from .rec_rfl_loss import RFLLoss
from .rec_can_loss import CANLoss
from .rec_satrn_loss import SATRNLoss

# cls loss
from .cls_loss import ClsLoss
...
@@ -73,7 +74,8 @@ def build_loss(config):
        'CELoss', 'TableAttentionLoss', 'SARLoss', 'AsterLoss', 'SDMGRLoss',
        'VQASerTokenLayoutLMLoss', 'LossFromOutput', 'PRENLoss', 'MultiLoss',
        'TableMasterLoss', 'SPINAttentionLoss', 'VLLoss', 'StrokeFocusLoss',
        'SLALoss', 'CTLoss', 'RFLLoss', 'DRRGLoss', 'CANLoss', 'TelescopeLoss',
        'SATRNLoss'
    ]
    config = copy.deepcopy(config)
    module_name = config.pop('name')
...
# copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This code is adapted from:
https://github.com/open-mmlab/mmocr/blob/1.x/mmocr/models/textrecog/module_losses/ce_module_loss.py
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import paddle
from paddle import nn


class SATRNLoss(nn.Layer):
    def __init__(self, **kwargs):
        super(SATRNLoss, self).__init__()
        # 92 is the <PAD> index when dict90 plus the three special tokens is used
        ignore_index = kwargs.get('ignore_index', 92)
        self.loss_func = paddle.nn.loss.CrossEntropyLoss(
            reduction="none", ignore_index=ignore_index)

    def forward(self, predicts, batch):
        # ignore the last step of outputs so they share seq_len with the targets
        predict = predicts[:, :-1, :]
        # ignore the first (start) token of the targets in the loss calculation
        label = batch[1].astype("int64")[:, 1:]
        batch_size, num_steps, num_classes = predict.shape[0], predict.shape[
            1], predict.shape[2]
        assert len(label.shape) == len(list(predict.shape)) - 1, \
            "The target's shape and input's shape should be [N, num_steps] and [N, num_steps, num_classes]"

        inputs = paddle.reshape(predict, [-1, num_classes])
        targets = paddle.reshape(label, [-1])
        loss = self.loss_func(inputs, targets)
        return {'loss': loss.mean()}
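
The slicing in forward aligns step t of the decoder output with token t+1 of the padded label, so <BOS> is never a target and <PAD> positions are masked out by ignore_index. A toy check of that shift, with hypothetical sizes:

import paddle

N, T, C = 2, 5, 93                           # batch, padded length, classes
predicts = paddle.randn([N, T, C])           # decoder outputs for steps 0..4
labels = paddle.randint(0, C, shape=[N, T])  # padded [BOS, c1, ..., EOS, PAD]
loss_fn = paddle.nn.CrossEntropyLoss(reduction="none", ignore_index=92)
loss = loss_fn(
    predicts[:, :-1, :].reshape([-1, C]),    # step t scores token t+1
    labels[:, 1:].reshape([-1]))             # drop BOS from the targets
print(loss.mean())                           # scalar, as SATRNLoss returns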
...
@@ -44,11 +44,12 @@ def build_backbone(config, model_type):
        from .rec_vitstr import ViTSTR
        from .rec_resnet_rfl import ResNetRFL
        from .rec_densenet import DenseNet
        from .rec_shallow_cnn import ShallowCNN
        support_dict = [
            'MobileNetV1Enhance', 'MobileNetV3', 'ResNet', 'ResNetFPN', 'MTB',
            'ResNet31', 'ResNet45', 'ResNet_ASTER', 'MicroNet',
            'EfficientNetb3_PREN', 'SVTRNet', 'ViTSTR', 'ResNet32', 'ResNetRFL',
            'DenseNet', 'ShallowCNN'
        ]
    elif model_type == 'e2e':
        from .e2e_resnet_vd_pg import ResNet
...
# copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This code is adapted from:
https://github.com/open-mmlab/mmocr/blob/1.x/mmocr/models/textrecog/backbones/shallow_cnn.py
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import math
import numpy as np
import paddle
from paddle import ParamAttr
import paddle.nn as nn
import paddle.nn.functional as F
from paddle.nn import MaxPool2D
from paddle.nn.initializer import KaimingNormal, Uniform, Constant


class ConvBNLayer(nn.Layer):
    def __init__(self,
                 num_channels,
                 filter_size,
                 num_filters,
                 stride,
                 padding,
                 num_groups=1):
        super(ConvBNLayer, self).__init__()

        self.conv = nn.Conv2D(
            in_channels=num_channels,
            out_channels=num_filters,
            kernel_size=filter_size,
            stride=stride,
            padding=padding,
            groups=num_groups,
            weight_attr=ParamAttr(initializer=KaimingNormal()),
            bias_attr=False)

        self.bn = nn.BatchNorm2D(
            num_filters,
            weight_attr=ParamAttr(initializer=Uniform(0, 1)),
            bias_attr=ParamAttr(initializer=Constant(0)))
        self.relu = nn.ReLU()

    def forward(self, inputs):
        y = self.conv(inputs)
        y = self.bn(y)
        y = self.relu(y)
        return y


class ShallowCNN(nn.Layer):
    def __init__(self, in_channels=1, hidden_dim=512):
        super().__init__()
        assert isinstance(in_channels, int)
        assert isinstance(hidden_dim, int)

        self.conv1 = ConvBNLayer(
            in_channels, 3, hidden_dim // 2, stride=1, padding=1)
        self.conv2 = ConvBNLayer(
            hidden_dim // 2, 3, hidden_dim, stride=1, padding=1)
        self.pool = nn.MaxPool2D(kernel_size=2, stride=2, padding=0)
        self.out_channels = hidden_dim

    def forward(self, x):
        x = self.conv1(x)
        x = self.pool(x)
        x = self.conv2(x)
        x = self.pool(x)
        return x
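
With the config above (in_channels: 3, hidden_dim: 256) the two stride-2 poolings shrink a 32x100 crop to an 8x25 feature map, which becomes the encoder's token grid. A quick shape check, assuming the ShallowCNN defined here is in scope:

import paddle

# [N, 3, 32, 100] -conv1-> [N, 128, 32, 100] -pool-> [N, 128, 16, 50]
#                 -conv2-> [N, 256, 16, 50]  -pool-> [N, 256, 8, 25]
backbone = ShallowCNN(in_channels=3, hidden_dim=256)
feat = backbone(paddle.rand([1, 3, 32, 100]))
print(feat.shape)  # [1, 256, 8, 25]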
...
@@ -40,6 +40,7 @@ def build_head(config):
    from .rec_visionlan_head import VLHead
    from .rec_rfl_head import RFLHead
    from .rec_can_head import CANHead
    from .rec_satrn_head import SATRNHead

    # cls head
    from .cls_head import ClsHead
...
@@ -56,7 +57,7 @@ def build_head(config):
        'TableAttentionHead', 'SARHead', 'AsterHead', 'SDMGRHead', 'PRENHead',
        'MultiHead', 'ABINetHead', 'TableMasterHead', 'SPINAttentionHead',
        'VLHead', 'SLAHead', 'RobustScannerHead', 'CT_Head', 'RFLHead',
        'DRRGHead', 'CANHead', 'SATRNHead'
    ]
    if config['name'] == 'DRRGHead':
...
# copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This code is adapted from:
https://github.com/open-mmlab/mmocr/blob/1.x/mmocr/models/textrecog/encoders/satrn_encoder.py
https://github.com/open-mmlab/mmocr/blob/1.x/mmocr/models/textrecog/decoders/nrtr_decoder.py
"""

import math
import numpy as np
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from paddle import ParamAttr, reshape, transpose
from paddle.nn import Conv2D, BatchNorm, Linear, Dropout
from paddle.nn import AdaptiveAvgPool2D, MaxPool2D, AvgPool2D
from paddle.nn.initializer import KaimingNormal, Uniform, Constant


class ConvBNLayer(nn.Layer):
    def __init__(self,
                 num_channels,
                 filter_size,
                 num_filters,
                 stride,
                 padding,
                 num_groups=1):
        super(ConvBNLayer, self).__init__()

        self.conv = nn.Conv2D(
            in_channels=num_channels,
            out_channels=num_filters,
            kernel_size=filter_size,
            stride=stride,
            padding=padding,
            groups=num_groups,
            bias_attr=False)

        self.bn = nn.BatchNorm2D(
            num_filters,
            weight_attr=ParamAttr(initializer=Constant(1)),
            bias_attr=ParamAttr(initializer=Constant(0)))
        self.relu = nn.ReLU()

    def forward(self, inputs):
        y = self.conv(inputs)
        y = self.bn(y)
        y = self.relu(y)
        return y


class SATRNEncoderLayer(nn.Layer):
    def __init__(self,
                 d_model=512,
                 d_inner=512,
                 n_head=8,
                 d_k=64,
                 d_v=64,
                 dropout=0.1,
                 qkv_bias=False):
        super().__init__()
        self.norm1 = nn.LayerNorm(d_model)
        self.attn = MultiHeadAttention(
            n_head, d_model, d_k, d_v, qkv_bias=qkv_bias, dropout=dropout)
        self.norm2 = nn.LayerNorm(d_model)
        self.feed_forward = LocalityAwareFeedforward(
            d_model, d_inner, dropout=dropout)

    def forward(self, x, h, w, mask=None):
        n, hw, c = x.shape
        residual = x
        x = self.norm1(x)
        x = residual + self.attn(x, x, x, mask)
        residual = x
        x = self.norm2(x)
        # the feedforward is convolutional, so restore the 2-D layout first
        x = x.transpose([0, 2, 1]).reshape([n, c, h, w])
        x = self.feed_forward(x)
        x = x.reshape([n, c, hw]).transpose([0, 2, 1])
        x = residual + x
        return x


class LocalityAwareFeedforward(nn.Layer):
    def __init__(self, d_in, d_hid, dropout=0.1):
        super().__init__()
        self.conv1 = ConvBNLayer(d_in, 1, d_hid, stride=1, padding=0)

        self.depthwise_conv = ConvBNLayer(
            d_hid, 3, d_hid, stride=1, padding=1, num_groups=d_hid)

        self.conv2 = ConvBNLayer(d_hid, 1, d_in, stride=1, padding=0)

    def forward(self, x):
        x = self.conv1(x)
        x = self.depthwise_conv(x)
        x = self.conv2(x)
        return x


class Adaptive2DPositionalEncoding(nn.Layer):
    def __init__(self, d_hid=512, n_height=100, n_width=100, dropout=0.1):
        super().__init__()

        h_position_encoder = self._get_sinusoid_encoding_table(n_height, d_hid)
        h_position_encoder = h_position_encoder.transpose([1, 0])
        h_position_encoder = h_position_encoder.reshape([1, d_hid, n_height, 1])

        w_position_encoder = self._get_sinusoid_encoding_table(n_width, d_hid)
        w_position_encoder = w_position_encoder.transpose([1, 0])
        w_position_encoder = w_position_encoder.reshape([1, d_hid, 1, n_width])

        self.register_buffer('h_position_encoder', h_position_encoder)
        self.register_buffer('w_position_encoder', w_position_encoder)

        self.h_scale = self.scale_factor_generate(d_hid)
        self.w_scale = self.scale_factor_generate(d_hid)
        self.pool = nn.AdaptiveAvgPool2D(1)
        self.dropout = nn.Dropout(p=dropout)

    def _get_sinusoid_encoding_table(self, n_position, d_hid):
        """Sinusoid position encoding table."""
        denominator = paddle.to_tensor([
            1.0 / np.power(10000, 2 * (hid_j // 2) / d_hid)
            for hid_j in range(d_hid)
        ])
        denominator = denominator.reshape([1, -1])
        pos_tensor = paddle.cast(
            paddle.arange(n_position).unsqueeze(-1), 'float32')
        sinusoid_table = pos_tensor * denominator
        sinusoid_table[:, 0::2] = paddle.sin(sinusoid_table[:, 0::2])
        sinusoid_table[:, 1::2] = paddle.cos(sinusoid_table[:, 1::2])

        return sinusoid_table

    def scale_factor_generate(self, d_hid):
        scale_factor = nn.Sequential(
            nn.Conv2D(d_hid, d_hid, 1),
            nn.ReLU(), nn.Conv2D(d_hid, d_hid, 1), nn.Sigmoid())

        return scale_factor

    def forward(self, x):
        b, c, h, w = x.shape

        avg_pool = self.pool(x)

        h_pos_encoding = \
            self.h_scale(avg_pool) * self.h_position_encoder[:, :, :h, :]
        w_pos_encoding = \
            self.w_scale(avg_pool) * self.w_position_encoder[:, :, :, :w]

        out = x + h_pos_encoding + w_pos_encoding
        out = self.dropout(out)

        return out
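
Unlike a fixed 2-D sinusoid, the height and width tables here are gated per sample by two conv-sigmoid branches fed with the globally pooled feature, so the model can weight vertical versus horizontal position by content. A shapes-only sketch of the broadcast in forward, assuming the class above is in scope:

# x                 [B, C, H, W]
# pool(x)           [B, C, 1, 1]   global context for both gates
# h_scale(pool(x))  [B, C, 1, 1] * h_position_encoder[:, :, :H, :]  ([1, C, H, 1])
# w_scale(pool(x))  [B, C, 1, 1] * w_position_encoder[:, :, :, :W]  ([1, C, 1, W])
# out = x + h_term + w_term, everything broadcasting back to [B, C, H, W]
pe = Adaptive2DPositionalEncoding(d_hid=256, n_height=100, n_width=100)
out = pe(paddle.rand([2, 256, 8, 25]))  # -> [2, 256, 8, 25]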
class ScaledDotProductAttention(nn.Layer):
    def __init__(self, temperature, attn_dropout=0.1):
        super().__init__()
        self.temperature = temperature
        self.dropout = nn.Dropout(attn_dropout)

    def forward(self, q, k, v, mask=None):
        def masked_fill(x, mask, value):
            y = paddle.full(x.shape, value, x.dtype)
            return paddle.where(mask, y, x)

        attn = paddle.matmul(q / self.temperature, k.transpose([0, 1, 3, 2]))
        if mask is not None:
            attn = masked_fill(attn, mask == 0, -1e9)
            # attn = attn.masked_fill(mask == 0, float('-inf'))
            # attn += mask
        attn = self.dropout(F.softmax(attn, axis=-1))
        output = paddle.matmul(attn, v)

        return output, attn


class MultiHeadAttention(nn.Layer):
    def __init__(self,
                 n_head=8,
                 d_model=512,
                 d_k=64,
                 d_v=64,
                 dropout=0.1,
                 qkv_bias=False):
        super().__init__()
        self.n_head = n_head
        self.d_k = d_k
        self.d_v = d_v

        self.dim_k = n_head * d_k
        self.dim_v = n_head * d_v

        self.linear_q = nn.Linear(self.dim_k, self.dim_k, bias_attr=qkv_bias)
        self.linear_k = nn.Linear(self.dim_k, self.dim_k, bias_attr=qkv_bias)
        self.linear_v = nn.Linear(self.dim_v, self.dim_v, bias_attr=qkv_bias)

        self.attention = ScaledDotProductAttention(d_k**0.5, dropout)

        self.fc = nn.Linear(self.dim_v, d_model, bias_attr=qkv_bias)
        self.proj_drop = nn.Dropout(dropout)

    def forward(self, q, k, v, mask=None):
        batch_size, len_q, _ = q.shape
        _, len_k, _ = k.shape

        q = self.linear_q(q).reshape(
            [batch_size, len_q, self.n_head, self.d_k])
        k = self.linear_k(k).reshape(
            [batch_size, len_k, self.n_head, self.d_k])
        v = self.linear_v(v).reshape(
            [batch_size, len_k, self.n_head, self.d_v])

        q, k, v = q.transpose([0, 2, 1, 3]), k.transpose(
            [0, 2, 1, 3]), v.transpose([0, 2, 1, 3])

        if mask is not None:
            if mask.dim() == 3:
                mask = mask.unsqueeze(1)
            elif mask.dim() == 2:
                mask = mask.unsqueeze(1).unsqueeze(1)

        attn_out, _ = self.attention(q, k, v, mask=mask)

        attn_out = attn_out.transpose([0, 2, 1, 3]).reshape(
            [batch_size, len_q, self.dim_v])

        attn_out = self.fc(attn_out)
        attn_out = self.proj_drop(attn_out)

        return attn_out


class SATRNEncoder(nn.Layer):
    def __init__(self,
                 n_layers=12,
                 n_head=8,
                 d_k=64,
                 d_v=64,
                 d_model=512,
                 n_position=100,
                 d_inner=256,
                 dropout=0.1):
        super().__init__()
        self.d_model = d_model
        self.position_enc = Adaptive2DPositionalEncoding(
            d_hid=d_model,
            n_height=n_position,
            n_width=n_position,
            dropout=dropout)
        self.layer_stack = nn.LayerList([
            SATRNEncoderLayer(
                d_model, d_inner, n_head, d_k, d_v, dropout=dropout)
            for _ in range(n_layers)
        ])
        self.layer_norm = nn.LayerNorm(d_model)

    def forward(self, feat, valid_ratios=None):
        """
        Args:
            feat (Tensor): Feature tensor of shape :math:`(N, D_m, H, W)`.
            valid_ratios (list[float]): Valid width ratio of each image in
                the batch; defaults to 1.0 for every sample.

        Returns:
            Tensor: A tensor of shape :math:`(N, T, D_m)`.
        """
        if valid_ratios is None:
            valid_ratios = [1.0 for _ in range(feat.shape[0])]
        feat = self.position_enc(feat)
        n, c, h, w = feat.shape

        mask = paddle.zeros((n, h, w))
        for i, valid_ratio in enumerate(valid_ratios):
            valid_width = min(w, math.ceil(w * valid_ratio))
            mask[i, :, :valid_width] = 1

        mask = mask.reshape([n, h * w])
        feat = feat.reshape([n, c, h * w])

        output = feat.transpose([0, 2, 1])
        for enc_layer in self.layer_stack:
            output = enc_layer(output, h, w, mask)
        output = self.layer_norm(output)

        return output
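
The valid_ratio mask flags which columns of the (width-padded) feature map carry real image content, so attention ignores the padding. A worked example for the 8x25 grid produced by the backbone above, with an illustrative ratio:

import math
import paddle

h, w, valid_ratio = 8, 25, 0.6                    # image fills 60% of the width
valid_width = min(w, math.ceil(w * valid_ratio))  # 15
mask = paddle.zeros((1, h, w))
mask[0, :, :valid_width] = 1                      # columns 0..14 are valid
mask = mask.reshape([1, h * w])                   # matches the [n, h*w] tokens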
class PositionwiseFeedForward(nn.Layer):
    def __init__(self, d_in, d_hid, dropout=0.1):
        super().__init__()
        self.w_1 = nn.Linear(d_in, d_hid)
        self.w_2 = nn.Linear(d_hid, d_in)
        self.act = nn.GELU()
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        x = self.w_1(x)
        x = self.act(x)
        x = self.w_2(x)
        x = self.dropout(x)
        return x


class PositionalEncoding(nn.Layer):
    def __init__(self, d_hid=512, n_position=200, dropout=0):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        # Not a parameter
        # Position table of shape (1, n_position, d_hid)
        self.register_buffer(
            'position_table',
            self._get_sinusoid_encoding_table(n_position, d_hid))

    def _get_sinusoid_encoding_table(self, n_position, d_hid):
        """Sinusoid position encoding table."""
        denominator = paddle.to_tensor([
            1.0 / np.power(10000, 2 * (hid_j // 2) / d_hid)
            for hid_j in range(d_hid)
        ])
        denominator = denominator.reshape([1, -1])
        pos_tensor = paddle.cast(
            paddle.arange(n_position).unsqueeze(-1), 'float32')
        sinusoid_table = pos_tensor * denominator
        sinusoid_table[:, 0::2] = paddle.sin(sinusoid_table[:, 0::2])
        sinusoid_table[:, 1::2] = paddle.cos(sinusoid_table[:, 1::2])

        return sinusoid_table.unsqueeze(0)

    def forward(self, x):
        x = x + self.position_table[:, :x.shape[1]].clone().detach()
        return self.dropout(x)


class TFDecoderLayer(nn.Layer):
    def __init__(self,
                 d_model=512,
                 d_inner=256,
                 n_head=8,
                 d_k=64,
                 d_v=64,
                 dropout=0.1,
                 qkv_bias=False,
                 operation_order=None):
        super().__init__()

        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)

        self.self_attn = MultiHeadAttention(
            n_head, d_model, d_k, d_v, dropout=dropout, qkv_bias=qkv_bias)

        self.enc_attn = MultiHeadAttention(
            n_head, d_model, d_k, d_v, dropout=dropout, qkv_bias=qkv_bias)

        self.mlp = PositionwiseFeedForward(d_model, d_inner, dropout=dropout)

        self.operation_order = operation_order
        if self.operation_order is None:
            self.operation_order = ('norm', 'self_attn', 'norm',
                                    'enc_dec_attn', 'norm', 'ffn')
        assert self.operation_order in [
            ('norm', 'self_attn', 'norm', 'enc_dec_attn', 'norm', 'ffn'),
            ('self_attn', 'norm', 'enc_dec_attn', 'norm', 'ffn', 'norm')
        ]

    def forward(self,
                dec_input,
                enc_output,
                self_attn_mask=None,
                dec_enc_attn_mask=None):
        if self.operation_order == ('self_attn', 'norm', 'enc_dec_attn',
                                    'norm', 'ffn', 'norm'):
            # post-norm variant
            dec_attn_out = self.self_attn(dec_input, dec_input, dec_input,
                                          self_attn_mask)
            dec_attn_out += dec_input
            dec_attn_out = self.norm1(dec_attn_out)

            enc_dec_attn_out = self.enc_attn(dec_attn_out, enc_output,
                                             enc_output, dec_enc_attn_mask)
            enc_dec_attn_out += dec_attn_out
            enc_dec_attn_out = self.norm2(enc_dec_attn_out)

            mlp_out = self.mlp(enc_dec_attn_out)
            mlp_out += enc_dec_attn_out
            mlp_out = self.norm3(mlp_out)
        elif self.operation_order == ('norm', 'self_attn', 'norm',
                                      'enc_dec_attn', 'norm', 'ffn'):
            # pre-norm variant (the default)
            dec_input_norm = self.norm1(dec_input)
            dec_attn_out = self.self_attn(dec_input_norm, dec_input_norm,
                                          dec_input_norm, self_attn_mask)
            dec_attn_out += dec_input

            enc_dec_attn_in = self.norm2(dec_attn_out)
            enc_dec_attn_out = self.enc_attn(enc_dec_attn_in, enc_output,
                                             enc_output, dec_enc_attn_mask)
            enc_dec_attn_out += dec_attn_out

            mlp_out = self.mlp(self.norm3(enc_dec_attn_out))
            mlp_out += enc_dec_attn_out

        return mlp_out


class SATRNDecoder(nn.Layer):
    def __init__(self,
                 n_layers=6,
                 d_embedding=512,
                 n_head=8,
                 d_k=64,
                 d_v=64,
                 d_model=512,
                 d_inner=256,
                 n_position=200,
                 dropout=0.1,
                 num_classes=93,
                 max_seq_len=40,
                 start_idx=1,
                 padding_idx=92):
        super().__init__()

        self.padding_idx = padding_idx
        self.start_idx = start_idx
        self.max_seq_len = max_seq_len

        self.trg_word_emb = nn.Embedding(
            num_classes, d_embedding, padding_idx=padding_idx)

        self.position_enc = PositionalEncoding(
            d_embedding, n_position=n_position)
        self.dropout = nn.Dropout(p=dropout)

        self.layer_stack = nn.LayerList([
            TFDecoderLayer(
                d_model, d_inner, n_head, d_k, d_v, dropout=dropout)
            for _ in range(n_layers)
        ])
        self.layer_norm = nn.LayerNorm(d_model, epsilon=1e-6)

        pred_num_class = num_classes - 1  # ignore padding_idx
        self.classifier = nn.Linear(d_model, pred_num_class)

    @staticmethod
    def get_pad_mask(seq, pad_idx):
        return (seq != pad_idx).unsqueeze(-2)

    @staticmethod
    def get_subsequent_mask(seq):
        """For masking out the subsequent info."""
        len_s = seq.shape[1]
        subsequent_mask = 1 - paddle.triu(
            paddle.ones((len_s, len_s)), diagonal=1)
        subsequent_mask = paddle.cast(subsequent_mask.unsqueeze(0), 'bool')
        return subsequent_mask

    def _attention(self, trg_seq, src, src_mask=None):
        trg_embedding = self.trg_word_emb(trg_seq)
        trg_pos_encoded = self.position_enc(trg_embedding)
        tgt = self.dropout(trg_pos_encoded)

        trg_mask = self.get_pad_mask(
            trg_seq,
            pad_idx=self.padding_idx) & self.get_subsequent_mask(trg_seq)
        output = tgt
        for dec_layer in self.layer_stack:
            output = dec_layer(
                output,
                src,
                self_attn_mask=trg_mask,
                dec_enc_attn_mask=src_mask)
        output = self.layer_norm(output)

        return output

    def _get_mask(self, logit, valid_ratios):
        N, T, _ = logit.shape
        mask = None
        if valid_ratios is not None:
            mask = paddle.zeros((N, T))
            for i, valid_ratio in enumerate(valid_ratios):
                valid_width = min(T, math.ceil(T * valid_ratio))
                mask[i, :valid_width] = 1

        return mask

    def forward_train(self, feat, out_enc, targets, valid_ratio):
        src_mask = self._get_mask(out_enc, valid_ratio)
        attn_output = self._attention(targets, out_enc, src_mask=src_mask)
        outputs = self.classifier(attn_output)

        return outputs

    def forward_test(self, feat, out_enc, valid_ratio):
        src_mask = self._get_mask(out_enc, valid_ratio)
        N = out_enc.shape[0]
        init_target_seq = paddle.full(
            (N, self.max_seq_len + 1), self.padding_idx, dtype='int64')
        # bsz * seq_len
        init_target_seq[:, 0] = self.start_idx

        outputs = []
        for step in range(0, paddle.to_tensor(self.max_seq_len)):
            decoder_output = self._attention(
                init_target_seq, out_enc, src_mask=src_mask)
            # bsz * seq_len * C
            step_result = F.softmax(
                self.classifier(decoder_output[:, step, :]), axis=-1)
            # bsz * num_classes
            outputs.append(step_result)
            step_max_index = paddle.argmax(step_result, axis=-1)
            init_target_seq[:, step + 1] = step_max_index

        outputs = paddle.stack(outputs, axis=1)

        return outputs

    def forward(self, feat, out_enc, targets=None, valid_ratio=None):
        if self.training:
            return self.forward_train(feat, out_enc, targets, valid_ratio)
        else:
            return self.forward_test(feat, out_enc, valid_ratio)
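
forward_test decodes greedily: the target sequence starts as all <PAD> with <BOS> in slot 0, and each of the max_seq_len iterations re-runs the full decoder and writes the argmax of step t into slot t+1 (no incremental caching). The causal mask that keeps this sound is get_subsequent_mask; its effect for a length-4 sequence:

import paddle

# Position i may only attend to positions j <= i.
mask = 1 - paddle.triu(paddle.ones((4, 4)), diagonal=1)
print(mask.astype('int64'))
# [[1, 0, 0, 0],
#  [1, 1, 0, 0],
#  [1, 1, 1, 0],
#  [1, 1, 1, 1]]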
class SATRNHead(nn.Layer):
    def __init__(self, enc_cfg, dec_cfg, **kwargs):
        super(SATRNHead, self).__init__()

        # encoder module
        self.encoder = SATRNEncoder(**enc_cfg)

        # decoder module
        self.decoder = SATRNDecoder(**dec_cfg)

    def forward(self, feat, targets=None):
        if targets is not None:
            targets, valid_ratio = targets
        else:
            targets, valid_ratio = None, None
        holistic_feat = self.encoder(feat, valid_ratio)  # bsz c

        final_out = self.decoder(feat, holistic_feat, targets, valid_ratio)

        return final_out
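
Putting the pieces together, an end-to-end shape sketch for the config at the top of this commit (a summary, not additional code):

# image [N, 3, 32, 100]
#   -> ShallowCNN           [N, 256, 8, 25]
#   -> SATRNEncoder         [N, 200, 256]   h*w = 200 tokens of width d_model
#   -> SATRNDecoder (train) [N, 25, 92]     max_seq_len steps, num_classes - 1 scores
#      SATRNDecoder (test)  [N, 25, 92]     greedy softmax, one step at a time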
...
@@ -28,7 +28,7 @@ from .fce_postprocess import FCEPostProcess
from .rec_postprocess import CTCLabelDecode, AttnLabelDecode, SRNLabelDecode, \
    DistillationCTCLabelDecode, NRTRLabelDecode, SARLabelDecode, \
    SEEDLabelDecode, PRENLabelDecode, ViTSTRLabelDecode, ABINetLabelDecode, \
    SPINLabelDecode, VLLabelDecode, RFLLabelDecode, SATRNLabelDecode
from .cls_postprocess import ClsPostProcess
from .pg_postprocess import PGPostProcess
from .vqa_token_ser_layoutlm_postprocess import VQASerTokenLayoutLMPostProcess, DistillationSerPostProcess
...
@@ -52,7 +52,8 @@ def build_post_process(config, global_config=None):
        'TableMasterLabelDecode', 'SPINLabelDecode',
        'DistillationSerPostProcess', 'DistillationRePostProcess',
        'VLLabelDecode', 'PicoDetPostProcess', 'CTPostProcess',
        'RFLLabelDecode', 'DRRGPostprocess', 'CANLabelDecode',
        'SATRNLabelDecode'
    ]
    if config['name'] == 'PSEPostProcess':
...
...
@@ -568,6 +568,82 @@ class SARLabelDecode(BaseRecLabelDecode):
        return [self.padding_idx]
class SATRNLabelDecode(BaseRecLabelDecode):
    """ Convert between text-label and text-index """

    def __init__(self, character_dict_path=None, use_space_char=False,
                 **kwargs):
        super(SATRNLabelDecode, self).__init__(character_dict_path,
                                               use_space_char)

        self.rm_symbol = kwargs.get('rm_symbol', False)

    def add_special_char(self, dict_character):
        beg_end_str = "<BOS/EOS>"
        unknown_str = "<UKN>"
        padding_str = "<PAD>"
        dict_character = dict_character + [unknown_str]
        self.unknown_idx = len(dict_character) - 1
        dict_character = dict_character + [beg_end_str]
        self.start_idx = len(dict_character) - 1
        self.end_idx = len(dict_character) - 1
        dict_character = dict_character + [padding_str]
        self.padding_idx = len(dict_character) - 1

        return dict_character

    def decode(self, text_index, text_prob=None, is_remove_duplicate=False):
        """ convert text-index into text-label. """
        result_list = []
        ignored_tokens = self.get_ignored_tokens()
        batch_size = len(text_index)
        for batch_idx in range(batch_size):
            char_list = []
            conf_list = []
            for idx in range(len(text_index[batch_idx])):
                if text_index[batch_idx][idx] in ignored_tokens:
                    continue
                if int(text_index[batch_idx][idx]) == int(self.end_idx):
                    if text_prob is None and idx == 0:
                        continue
                    else:
                        break
                if is_remove_duplicate:
                    # only for predict
                    if idx > 0 and text_index[batch_idx][idx - 1] == text_index[
                            batch_idx][idx]:
                        continue
                char_list.append(self.character[int(text_index[batch_idx][
                    idx])])
                if text_prob is not None:
                    conf_list.append(text_prob[batch_idx][idx])
                else:
                    conf_list.append(1)
            text = ''.join(char_list)
            if self.rm_symbol:
                comp = re.compile('[^A-Z^a-z^0-9^\u4e00-\u9fa5]')
                text = text.lower()
                text = comp.sub('', text)
            result_list.append((text, np.mean(conf_list).tolist()))
        return result_list

    def __call__(self, preds, label=None, *args, **kwargs):
        if isinstance(preds, paddle.Tensor):
            preds = preds.numpy()
        preds_idx = preds.argmax(axis=2)
        preds_prob = preds.max(axis=2)

        text = self.decode(preds_idx, preds_prob, is_remove_duplicate=False)

        if label is None:
            return text
        label = self.decode(label, is_remove_duplicate=False)
        return text, label

    def get_ignored_tokens(self):
        return [self.padding_idx]
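
A hypothetical walk-through of decode, assuming the dict90 token layout derived earlier (<BOS/EOS> = 91, <PAD> = 92):

# preds_idx row [19, 10, 91, 92, 92]
#   -> emit character[19], character[10]; stop at end_idx 91; <PAD> is ignored
#   -> confidence = mean of the two kept probabilities
# With rm_symbol: True (as the config sets), the text is then lower-cased and
# stripped of everything outside A-Za-z, 0-9 and the CJK range by the regex above.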
class DistillationSARLabelDecode(SARLabelDecode):
    """
    Convert
...
...
@@ -105,6 +105,12 @@ def export_single_model(model,
                shape=[None, 1, 32, 100], dtype="float32"),
        ]
        model = to_static(model, input_spec=other_shape)
    elif arch_config["algorithm"] == 'SATRN':
        other_shape = [
            paddle.static.InputSpec(
                shape=[None, 3, 32, 100], dtype="float32"),
        ]
        model = to_static(model, input_spec=other_shape)
    elif arch_config["algorithm"] == "VisionLAN":
        other_shape = [
            paddle.static.InputSpec(
...
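
Since SATRN is exported with a fixed [None, 3, 32, 100] InputSpec matching the SVTRRecResizeImg image_shape in the config, the usual PaddleOCR export entry point applies; a typical invocation (paths illustrative, not part of this diff):

python3 tools/export_model.py -c configs/rec/rec_satrn.yml \
    -o Global.pretrained_model=./output/rec/rec_satrn/best_accuracy \
       Global.save_inference_dir=./inference/rec_satrn/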
...
@@ -106,6 +106,13 @@ class TextRecognizer(object):
                "character_dict_path": None,
                "use_space_char": args.use_space_char
            }
        elif self.rec_algorithm == "SATRN":
            postprocess_params = {
                'name': 'SATRNLabelDecode',
                "character_dict_path": args.rec_char_dict_path,
                "use_space_char": args.use_space_char,
                "rm_symbol": True
            }
        elif self.rec_algorithm == "PREN":
            postprocess_params = {'name': 'PRENLabelDecode'}
        elif self.rec_algorithm == "CAN":
...
@@ -429,7 +436,7 @@ class TextRecognizer(object):
                gsrm_slf_attn_bias1_list.append(norm_img[3])
                gsrm_slf_attn_bias2_list.append(norm_img[4])
                norm_img_batch.append(norm_img[0])
            elif self.rec_algorithm in ["SVTR", "SATRN"]:
                norm_img = self.resize_norm_img_svtr(img_list[indices[ino]],
                                                     self.rec_image_shape)
                norm_img = norm_img[np.newaxis, :]
...
...
@@ -220,7 +220,7 @@ def train(config,
    use_srn = config['Architecture']['algorithm'] == "SRN"
    extra_input_models = [
        "SRN", "NRTR", "SAR", "SEED", "SVTR", "SPIN", "VisionLAN",
        "RobustScanner", "RFL", 'DRRG', 'SATRN'
    ]
    extra_input = False
    if config['Architecture']['algorithm'] == 'Distillation':
...
@@ -643,7 +643,7 @@ def preprocess(is_train=False):
        'SEED', 'SDMGR', 'LayoutXLM', 'LayoutLM', 'LayoutLMv2', 'PREN', 'FCE',
        'SVTR', 'ViTSTR', 'ABINet', 'DB++', 'TableMaster', 'SPIN', 'VisionLAN',
        'Gestalt', 'SLANet', 'RobustScanner', 'CT', 'RFL', 'DRRG', 'CAN',
        'Telescope', 'SATRN'
    ]
    if use_xpu:
...