Unverified commit 43abe2fa, authored by topduke, committed by GitHub

V4Rec code pr (#9725)

* v4rec code

* v4rec add nrtrloss

* Add V4rec backbone file

* Add V4Rec config file.

* Fix V4rec reparameters when export_model

* convert lvnetv3

* fix codestyle

* fix infer_rec v4rec
Parent 385a1f99
Global:
  debug: false
  use_gpu: true
  epoch_num: 200
  log_smooth_window: 20
  print_batch_step: 10
  save_model_dir: ./output/rec_ppocr_v4
  save_epoch_step: 10
  eval_batch_step: [0, 2000]
  cal_metric_during_train: true
  pretrained_model:
  checkpoints:
  save_inference_dir:
  use_visualdl: false
  infer_img: doc/imgs_words/ch/word_1.jpg
  character_dict_path: ppocr/utils/ppocr_keys_v1.txt
  max_text_length: &max_text_length 25
  infer_mode: false
  use_space_char: true
  distributed: true
  save_res_path: ./output/rec/predicts_ppocrv3.txt

Optimizer:
  name: Adam
  beta1: 0.9
  beta2: 0.999
  lr:
    name: Cosine
    learning_rate: 0.001
    warmup_epoch: 5
  regularizer:
    name: L2
    factor: 3.0e-05

Architecture:
  model_type: rec
  algorithm: SVTR_LCNet
  Transform:
  Backbone:
    name: LCNetv3
  Head:
    name: MultiHead
    head_list:
      - CTCHead:
          Neck:
            name: svtr
            dims: 120
            depth: 2
            hidden_dims: 120
            kernel_size: [1, 3]
            use_guide: True
          Head:
            fc_decay: 0.00001
      - NRTRHead:
          nrtr_dim: 384
          max_text_length: *max_text_length

Loss:
  name: MultiLoss
  loss_config_list:
    - CTCLoss:
    - NRTRLoss:

PostProcess:
  name: CTCLabelDecode

Metric:
  name: RecMetric
  main_indicator: acc

Train:
  dataset:
    name: SimpleDataSet
    data_dir: ./train_data/
    ext_op_transform_idx: 1
    label_file_list:
    - ./train_data/train_list.txt
    transforms:
    - DecodeImage:
        img_mode: BGR
        channel_first: false
    - RecConAug:
        prob: 0.5
        ext_data_num: 2
        image_shape: [48, 320, 3]
        max_text_length: *max_text_length
    - RecAug:
    - MultiLabelEncode:
        gtc_encode: NRTRLabelEncode
    - RecResizeImg:
        image_shape: [3, 48, 320]
    - KeepKeys:
        keep_keys:
        - image
        - label_ctc
        - label_gtc
        - length
        - valid_ratio
  loader:
    shuffle: true
    batch_size_per_card: 128
    drop_last: true
    num_workers: 4

Eval:
  dataset:
    name: SimpleDataSet
    data_dir: ./train_data
    label_file_list:
    - ./train_data/val_list.txt
    transforms:
    - DecodeImage:
        img_mode: BGR
        channel_first: false
    - MultiLabelEncode:
        gtc_encode: NRTRLabelEncode
    - RecResizeImg:
        image_shape: [3, 48, 320]
    - KeepKeys:
        keep_keys:
        - image
        - label_ctc
        - label_gtc
        - length
        - valid_ratio
  loader:
    shuffle: false
    drop_last: false
    batch_size_per_card: 128
    num_workers: 4
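For orientation, the sketch below shows roughly how this config is consumed; the file path and the character count are assumptions for illustration, not part of this diff. The train/export scripts further down fill in `out_channels_list` for the MultiHead from the PostProcess name before the model is built.

```python
# Hedged sketch: assumes a PaddleOCR checkout and that the config above is
# saved as configs/rec/PP-OCRv4/ch_PP-OCRv4_rec.yml (path not shown in the diff).
import yaml
from ppocr.modeling.architectures import build_model

with open("configs/rec/PP-OCRv4/ch_PP-OCRv4_rec.yml") as f:
    cfg = yaml.safe_load(f)

arch = cfg["Architecture"]
char_num = 6625  # illustrative only; the real value comes from character_dict_path
arch["Head"]["out_channels_list"] = {
    "CTCLabelDecode": char_num,        # CTC branch
    "NRTRLabelDecode": char_num + 3,   # GTC (NRTR) branch, mirroring tools/train.py
}
model = build_model(arch)
```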
......@@ -1241,27 +1241,36 @@ class MultiLabelEncode(BaseRecLabelEncode):
max_text_length,
character_dict_path=None,
use_space_char=False,
gtc_encode=None,
**kwargs):
super(MultiLabelEncode, self).__init__(
max_text_length, character_dict_path, use_space_char)
self.ctc_encode = CTCLabelEncode(max_text_length, character_dict_path,
use_space_char, **kwargs)
self.sar_encode = SARLabelEncode(max_text_length, character_dict_path,
use_space_char, **kwargs)
self.gtc_encode_type = gtc_encode
if gtc_encode is None:
self.gtc_encode = SARLabelEncode(
max_text_length, character_dict_path, use_space_char, **kwargs)
else:
self.gtc_encode = eval(gtc_encode)(
max_text_length, character_dict_path, use_space_char, **kwargs)
def __call__(self, data):
data_ctc = copy.deepcopy(data)
data_sar = copy.deepcopy(data)
data_gtc = copy.deepcopy(data)
data_out = dict()
data_out['img_path'] = data.get('img_path', None)
data_out['image'] = data['image']
ctc = self.ctc_encode.__call__(data_ctc)
sar = self.sar_encode.__call__(data_sar)
if ctc is None or sar is None:
gtc = self.gtc_encode.__call__(data_gtc)
if ctc is None or gtc is None:
return None
data_out['label_ctc'] = ctc['label']
data_out['label_sar'] = sar['label']
if self.gtc_encode_type is not None:
data_out['label_gtc'] = gtc['label']
else:
data_out['label_sar'] = gtc['label']
data_out['length'] = ctc['length']
return data_out
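A minimal usage sketch of the new gtc_encode switch (dummy image and label; assumes ppocr is importable and the script runs from the repo root so the dict path resolves): with gtc_encode set to NRTRLabelEncode the sample gains a label_gtc key instead of the old label_sar.

```python
import numpy as np
from ppocr.data.imaug.label_ops import MultiLabelEncode

encode = MultiLabelEncode(
    max_text_length=25,
    character_dict_path="ppocr/utils/ppocr_keys_v1.txt",
    use_space_char=True,
    gtc_encode="NRTRLabelEncode")

sample = {"image": np.zeros((48, 320, 3), dtype="uint8"), "label": "123"}
out = encode(sample)
print(sorted(out.keys()))  # expect image, img_path, label_ctc, label_gtc, length
```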
......
......@@ -42,6 +42,7 @@ from .rec_spin_att_loss import SPINAttentionLoss
from .rec_rfl_loss import RFLLoss
from .rec_can_loss import CANLoss
from .rec_satrn_loss import SATRNLoss
from .rec_nrtr_loss import NRTRLoss
# cls loss
from .cls_loss import ClsLoss
......@@ -75,7 +76,7 @@ def build_loss(config):
'VQASerTokenLayoutLMLoss', 'LossFromOutput', 'PRENLoss', 'MultiLoss',
'TableMasterLoss', 'SPINAttentionLoss', 'VLLoss', 'StrokeFocusLoss',
'SLALoss', 'CTLoss', 'RFLLoss', 'DRRGLoss', 'CANLoss', 'TelescopeLoss',
'SATRNLoss'
'SATRNLoss', 'NRTRLoss'
]
config = copy.deepcopy(config)
module_name = config.pop('name')
......
......@@ -21,6 +21,7 @@ from paddle import nn
from .rec_ctc_loss import CTCLoss
from .rec_sar_loss import SARLoss
from .rec_nrtr_loss import NRTRLoss
class MultiLoss(nn.Layer):
......@@ -30,7 +31,6 @@ class MultiLoss(nn.Layer):
self.loss_list = kwargs.pop('loss_config_list')
self.weight_1 = kwargs.get('weight_1', 1.0)
self.weight_2 = kwargs.get('weight_2', 1.0)
self.gtc_loss = kwargs.get('gtc_loss', 'sar')
for loss_info in self.loss_list:
for name, param in loss_info.items():
if param is not None:
......@@ -49,6 +49,9 @@ class MultiLoss(nn.Layer):
elif name == 'SARLoss':
loss = loss_func(predicts['sar'],
batch[:1] + batch[2:])['loss'] * self.weight_2
elif name == 'NRTRLoss':
loss = loss_func(predicts['nrtr'],
batch[:1] + batch[2:])['loss'] * self.weight_2
else:
raise NotImplementedError(
'{} is not supported in MultiLoss yet'.format(name))
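The new NRTRLoss branch reuses the same batch slicing as SARLoss; a tiny illustration of what `batch[:1] + batch[2:]` selects, given the keep_keys order from the config above (plain strings stand in for the tensors):

```python
# keep_keys order: image, label_ctc, label_gtc, length, valid_ratio
batch = ["image", "label_ctc", "label_gtc", "length", "valid_ratio"]
# CTCLoss sees the full batch; the GTC branch (SAR or NRTR) drops label_ctc:
print(batch[:1] + batch[2:])  # ['image', 'label_gtc', 'length', 'valid_ratio']
```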
......
import paddle
from paddle import nn
import paddle.nn.functional as F
class NRTRLoss(nn.Layer):
def __init__(self, smoothing=True, ignore_index=0, **kwargs):
super(NRTRLoss, self).__init__()
if ignore_index >= 0 and not smoothing:
self.loss_func = nn.CrossEntropyLoss(
reduction='mean', ignore_index=ignore_index)
self.smoothing = smoothing
def forward(self, pred, batch):
max_len = batch[2].max()
tgt = batch[1][:, 1:2 + max_len]
pred = pred.reshape([-1, pred.shape[2]])
tgt = tgt.reshape([-1])
if self.smoothing:
eps = 0.1
n_class = pred.shape[1]
one_hot = F.one_hot(tgt, pred.shape[1])
one_hot = one_hot * (1 - eps) + (1 - one_hot) * eps / (n_class - 1)
log_prb = F.log_softmax(pred, axis=1)
non_pad_mask = paddle.not_equal(
tgt, paddle.zeros(
tgt.shape, dtype=tgt.dtype))
loss = -(one_hot * log_prb).sum(axis=1)
loss = loss.masked_select(non_pad_mask).mean()
else:
loss = self.loss_func(pred, tgt)
return {'loss': loss}
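A self-contained smoke test of the loss with toy shapes (random values, not from the PR): batch[1] is the GTC label, batch[2] the text length, and the logits cover max_len + 1 decoding steps so the reshape lines up.

```python
import paddle
from ppocr.losses.rec_nrtr_loss import NRTRLoss

batch_size, num_classes, max_len = 2, 40, 5
label = paddle.randint(1, num_classes, shape=[batch_size, max_len + 2])  # id 0 = pad
length = paddle.full([batch_size], max_len, dtype="int64")
logits = paddle.randn([batch_size, max_len + 1, num_classes])

loss_fn = NRTRLoss(smoothing=True)
out = loss_fn(logits, [None, label, length])  # batch[0] (the image) is unused
print(float(out["loss"]))
```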
......@@ -45,11 +45,12 @@ def build_backbone(config, model_type):
from .rec_resnet_rfl import ResNetRFL
from .rec_densenet import DenseNet
from .rec_shallow_cnn import ShallowCNN
from .rec_lcnetv3 import LCNetv3
support_dict = [
'MobileNetV1Enhance', 'MobileNetV3', 'ResNet', 'ResNetFPN', 'MTB',
'ResNet31', 'ResNet45', 'ResNet_ASTER', 'MicroNet',
'EfficientNetb3_PREN', 'SVTRNet', 'ViTSTR', 'ResNet32', 'ResNetRFL',
'DenseNet', 'ShallowCNN'
'DenseNet', 'ShallowCNN', 'LCNetv3'
]
elif model_type == 'e2e':
from .e2e_resnet_vd_pg import ResNet
......
# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from paddle import ParamAttr
from paddle.nn.initializer import Constant, KaimingNormal
from paddle.nn import AdaptiveAvgPool2D, BatchNorm2D, Conv2D, Dropout, Hardsigmoid, Hardswish, Identity, Linear, ReLU
from paddle.regularizer import L2Decay
NET_CONFIG = {
"blocks2":
#k, in_c, out_c, s, use_se
[[3, 16, 32, 1, False]],
"blocks3": [[3, 32, 64, 1, False], [3, 64, 64, 1, False]],
"blocks4": [[3, 64, 128, (2, 1), False], [3, 128, 128, 1, False]],
"blocks5":
[[3, 128, 256, (1, 2), False], [5, 256, 256, 1, False],
[5, 256, 256, 1, False], [5, 256, 256, 1, False], [5, 256, 256, 1, False]],
"blocks6": [[5, 256, 512, (2, 1), True], [5, 512, 512, 1, True],
[5, 512, 512, (2, 1), False], [5, 512, 512, 1, False]]
}
def make_divisible(v, divisor=16, min_value=None):
if min_value is None:
min_value = divisor
new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
if new_v < 0.9 * v:
new_v += divisor
return new_v
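The rounding rule in one line: snap the scaled channel count to the nearest multiple of the divisor (16), but never drop below 90% of the requested width. For the scale=0.95 variant used by LCNetv3 below:

```python
from ppocr.modeling.backbones.rec_lcnetv3 import make_divisible

assert make_divisible(16 * 0.95) == 16    # 15.2 -> 16
assert make_divisible(512 * 0.95) == 480  # 486.4 -> 480 (still >= 90%)
```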
class LearnableAffineBlock(nn.Layer):
def __init__(self, scale_value=1.0, bias_value=0.0, lr_mult=1.0,
lab_lr=0.1):
super().__init__()
self.scale = self.create_parameter(
shape=[1, ],
default_initializer=Constant(value=scale_value),
attr=ParamAttr(learning_rate=lr_mult * lab_lr))
self.add_parameter("scale", self.scale)
self.bias = self.create_parameter(
shape=[1, ],
default_initializer=Constant(value=bias_value),
attr=ParamAttr(learning_rate=lr_mult * lab_lr))
self.add_parameter("bias", self.bias)
def forward(self, x):
return self.scale * x + self.bias
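LAB is a per-layer learnable affine y = scale * x + bias with scalar parameters, initialised to the identity (scale 1, bias 0) and trained with a reduced learning rate (lr_mult * lab_lr). A quick check of the identity initialisation (illustrative, not part of the PR):

```python
import paddle
from ppocr.modeling.backbones.rec_lcnetv3 import LearnableAffineBlock

lab = LearnableAffineBlock()
x = paddle.randn([2, 16, 8, 8])
assert float(paddle.abs(lab(x) - x).max()) < 1e-8  # identity at init
```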
class ConvBNLayer(nn.Layer):
def __init__(self,
in_channels,
out_channels,
kernel_size,
stride,
groups=1,
lr_mult=1.0):
super().__init__()
self.conv = Conv2D(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=kernel_size,
stride=stride,
padding=(kernel_size - 1) // 2,
groups=groups,
weight_attr=ParamAttr(
initializer=KaimingNormal(), learning_rate=lr_mult),
bias_attr=False)
self.bn = BatchNorm2D(
out_channels,
weight_attr=ParamAttr(
regularizer=L2Decay(0.0), learning_rate=lr_mult),
bias_attr=ParamAttr(
regularizer=L2Decay(0.0), learning_rate=lr_mult))
def forward(self, x):
x = self.conv(x)
x = self.bn(x)
return x
class Act(nn.Layer):
def __init__(self, act="hswish", lr_mult=1.0, lab_lr=0.1):
super().__init__()
if act == "hswish":
self.act = Hardswish()
else:
assert act == "relu"
self.act = ReLU()
self.lab = LearnableAffineBlock(lr_mult=lr_mult, lab_lr=lab_lr)
def forward(self, x):
return self.lab(self.act(x))
class LearnableRepLayer(nn.Layer):
def __init__(self,
in_channels,
out_channels,
kernel_size,
stride=1,
groups=1,
num_conv_branches=1,
lr_mult=1.0,
lab_lr=0.1):
super().__init__()
self.is_repped = False
self.groups = groups
self.stride = stride
self.kernel_size = kernel_size
self.in_channels = in_channels
self.out_channels = out_channels
self.num_conv_branches = num_conv_branches
self.padding = (kernel_size - 1) // 2
self.identity = BatchNorm2D(
num_features=in_channels,
weight_attr=ParamAttr(learning_rate=lr_mult),
bias_attr=ParamAttr(learning_rate=lr_mult)
) if out_channels == in_channels and stride == 1 else None
self.conv_kxk = nn.LayerList([
ConvBNLayer(
in_channels,
out_channels,
kernel_size,
stride,
groups=groups,
lr_mult=lr_mult) for _ in range(self.num_conv_branches)
])
self.conv_1x1 = ConvBNLayer(
in_channels,
out_channels,
1,
stride,
groups=groups,
lr_mult=lr_mult) if kernel_size > 1 else None
self.lab = LearnableAffineBlock(lr_mult=lr_mult, lab_lr=lab_lr)
self.act = Act(lr_mult=lr_mult, lab_lr=lab_lr)
def forward(self, x):
# export/inference path: use the single re-parameterized conv
if self.is_repped:
out = self.lab(self.reparam_conv(x))
if self.stride != 2:
out = self.act(out)
return out
out = 0
if self.identity is not None:
out += self.identity(x)
if self.conv_1x1 is not None:
out += self.conv_1x1(x)
for conv in self.conv_kxk:
out += conv(x)
out = self.lab(out)
if self.stride != 2:
out = self.act(out)
return out
def rep(self):
if self.is_repped:
return
kernel, bias = self._get_kernel_bias()
self.reparam_conv = Conv2D(
in_channels=self.in_channels,
out_channels=self.out_channels,
kernel_size=self.kernel_size,
stride=self.stride,
padding=self.padding,
groups=self.groups)
self.reparam_conv.weight.set_value(kernel)
self.reparam_conv.bias.set_value(bias)
self.is_repped = True
def _pad_kernel_1x1_to_kxk(self, kernel1x1, pad):
if not isinstance(kernel1x1, paddle.Tensor):
return 0
else:
return nn.functional.pad(kernel1x1, [pad, pad, pad, pad])
def _get_kernel_bias(self):
kernel_conv_1x1, bias_conv_1x1 = self._fuse_bn_tensor(self.conv_1x1)
kernel_conv_1x1 = self._pad_kernel_1x1_to_kxk(kernel_conv_1x1,
self.kernel_size // 2)
kernel_identity, bias_identity = self._fuse_bn_tensor(self.identity)
kernel_conv_kxk = 0
bias_conv_kxk = 0
for conv in self.conv_kxk:
kernel, bias = self._fuse_bn_tensor(conv)
kernel_conv_kxk += kernel
bias_conv_kxk += bias
kernel_reparam = kernel_conv_kxk + kernel_conv_1x1 + kernel_identity
bias_reparam = bias_conv_kxk + bias_conv_1x1 + bias_identity
return kernel_reparam, bias_reparam
def _fuse_bn_tensor(self, branch):
if not branch:
return 0, 0
elif isinstance(branch, ConvBNLayer):
kernel = branch.conv.weight
running_mean = branch.bn._mean
running_var = branch.bn._variance
gamma = branch.bn.weight
beta = branch.bn.bias
eps = branch.bn._epsilon
else:
assert isinstance(branch, BatchNorm2D)
if not hasattr(self, 'id_tensor'):
input_dim = self.in_channels // self.groups
kernel_value = paddle.zeros(
(self.in_channels, input_dim, self.kernel_size,
self.kernel_size),
dtype=branch.weight.dtype)
for i in range(self.in_channels):
kernel_value[i, i % input_dim, self.kernel_size // 2,
self.kernel_size // 2] = 1
self.id_tensor = kernel_value
kernel = self.id_tensor
running_mean = branch._mean
running_var = branch._variance
gamma = branch.weight
beta = branch.bias
eps = branch._epsilon
std = (running_var + eps).sqrt()
t = (gamma / std).reshape((-1, 1, 1, 1))
return kernel * t, beta - running_mean * gamma / std
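The point of rep() / _get_kernel_bias() is that, in eval mode, the single fused convolution reproduces the multi-branch output (identity BN, 1x1 and kxk branches folded into one kernel and bias). A hypothetical numerical check, not part of the PR:

```python
import paddle
from ppocr.modeling.backbones.rec_lcnetv3 import LearnableRepLayer

layer = LearnableRepLayer(16, 16, kernel_size=3, stride=1, groups=16)
layer.eval()  # BN must use running stats for the fusion to be exact
x = paddle.randn([1, 16, 48, 160])
y_branches = layer(x)
layer.rep()        # folds all branches into layer.reparam_conv
y_repped = layer(x)
print(float(paddle.abs(y_branches - y_repped).max()))  # expect ~1e-6 or smaller
```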
class SELayer(nn.Layer):
def __init__(self, channel, reduction=4, lr_mult=1.0):
super().__init__()
self.avg_pool = AdaptiveAvgPool2D(1)
self.conv1 = Conv2D(
in_channels=channel,
out_channels=channel // reduction,
kernel_size=1,
stride=1,
padding=0,
weight_attr=ParamAttr(learning_rate=lr_mult),
bias_attr=ParamAttr(learning_rate=lr_mult))
self.relu = ReLU()
self.conv2 = Conv2D(
in_channels=channel // reduction,
out_channels=channel,
kernel_size=1,
stride=1,
padding=0,
weight_attr=ParamAttr(learning_rate=lr_mult),
bias_attr=ParamAttr(learning_rate=lr_mult))
self.hardsigmoid = Hardsigmoid()
def forward(self, x):
identity = x
x = self.avg_pool(x)
x = self.conv1(x)
x = self.relu(x)
x = self.conv2(x)
x = self.hardsigmoid(x)
x = paddle.multiply(x=identity, y=x)
return x
class LCNetV3Block(nn.Layer):
def __init__(self,
in_channels,
out_channels,
stride,
dw_size,
use_se=False,
conv_kxk_num=4,
lr_mult=1.0,
lab_lr=0.1):
super().__init__()
self.use_se = use_se
self.dw_conv = LearnableRepLayer(
in_channels=in_channels,
out_channels=in_channels,
kernel_size=dw_size,
stride=stride,
groups=in_channels,
num_conv_branches=conv_kxk_num,
lr_mult=lr_mult,
lab_lr=lab_lr)
if use_se:
self.se = SELayer(in_channels, lr_mult=lr_mult)
self.pw_conv = LearnableRepLayer(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=1,
stride=1,
num_conv_branches=conv_kxk_num,
lr_mult=lr_mult,
lab_lr=lab_lr)
def forward(self, x):
x = self.dw_conv(x)
if self.use_se:
x = self.se(x)
x = self.pw_conv(x)
return x
class PPLCNetV3(nn.Layer):
def __init__(self,
scale=1.0,
conv_kxk_num=4,
lr_mult_list=[1.0, 1.0, 1.0, 1.0, 1.0, 1.0],
lab_lr=0.1,
**kwargs):
super().__init__()
self.scale = scale
self.lr_mult_list = lr_mult_list
self.net_config = NET_CONFIG
assert isinstance(self.lr_mult_list, (
list, tuple
)), "lr_mult_list should be in (list, tuple) but got {}".format(
type(self.lr_mult_list))
assert len(self.lr_mult_list
) == 6, "lr_mult_list length should be 6 but got {}".format(
len(self.lr_mult_list))
self.conv1 = ConvBNLayer(
in_channels=3,
out_channels=make_divisible(16 * scale),
kernel_size=3,
stride=2,
lr_mult=self.lr_mult_list[0])
self.blocks2 = nn.Sequential(* [
LCNetV3Block(
in_channels=make_divisible(in_c * scale),
out_channels=make_divisible(out_c * scale),
dw_size=k,
stride=s,
use_se=se,
conv_kxk_num=conv_kxk_num,
lr_mult=self.lr_mult_list[1],
lab_lr=lab_lr) for i, (k, in_c, out_c, s, se) in enumerate(
self.net_config["blocks2"])
])
self.blocks3 = nn.Sequential(* [
LCNetV3Block(
in_channels=make_divisible(in_c * scale),
out_channels=make_divisible(out_c * scale),
dw_size=k,
stride=s,
use_se=se,
conv_kxk_num=conv_kxk_num,
lr_mult=self.lr_mult_list[2],
lab_lr=lab_lr) for i, (k, in_c, out_c, s, se) in enumerate(
self.net_config["blocks3"])
])
self.blocks4 = nn.Sequential(* [
LCNetV3Block(
in_channels=make_divisible(in_c * scale),
out_channels=make_divisible(out_c * scale),
dw_size=k,
stride=s,
use_se=se,
conv_kxk_num=conv_kxk_num,
lr_mult=self.lr_mult_list[3],
lab_lr=lab_lr) for i, (k, in_c, out_c, s, se) in enumerate(
self.net_config["blocks4"])
])
self.blocks5 = nn.Sequential(* [
LCNetV3Block(
in_channels=make_divisible(in_c * scale),
out_channels=make_divisible(out_c * scale),
dw_size=k,
stride=s,
use_se=se,
conv_kxk_num=conv_kxk_num,
lr_mult=self.lr_mult_list[4],
lab_lr=lab_lr) for i, (k, in_c, out_c, s, se) in enumerate(
self.net_config["blocks5"])
])
self.blocks6 = nn.Sequential(* [
LCNetV3Block(
in_channels=make_divisible(in_c * scale),
out_channels=make_divisible(out_c * scale),
dw_size=k,
stride=s,
use_se=se,
conv_kxk_num=conv_kxk_num,
lr_mult=self.lr_mult_list[5],
lab_lr=lab_lr) for i, (k, in_c, out_c, s, se) in enumerate(
self.net_config["blocks6"])
])
self.out_channels = make_divisible(512 * scale)
def forward(self, x):
x = self.conv1(x)
x = self.blocks2(x)
x = self.blocks3(x)
x = self.blocks4(x)
x = self.blocks5(x)
x = self.blocks6(x)
if self.training:
x = F.adaptive_avg_pool2d(x, [1, 40])
else:
x = F.avg_pool2d(x, [3, 2])
return x
def LCNetv3(pretrained=False, use_ssld=False, **kwargs):
model = PPLCNetV3(scale=0.95, conv_kxk_num=4, **kwargs)
return model
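Output-shape sanity check with the 3x48x320 recognition input assumed by the config above (expected values, not measurements from the PR): in eval mode the final avg_pool2d([3, 2]) collapses the 3x80 feature map to 1x40, and out_channels is make_divisible(512 * 0.95) = 480.

```python
import paddle
from ppocr.modeling.backbones.rec_lcnetv3 import LCNetv3

backbone = LCNetv3()
backbone.eval()
feat = backbone(paddle.randn([1, 3, 48, 320]))
print(feat.shape, backbone.out_channels)  # expect [1, 480, 1, 40] and 480
```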
......@@ -108,6 +108,7 @@ class MobileNetV1Enhance(nn.Layer):
scale=0.5,
last_conv_stride=1,
last_pool_type='max',
last_pool_kernel_size=[3, 2],
**kwargs):
super().__init__()
self.scale = scale
......@@ -214,7 +215,10 @@ class MobileNetV1Enhance(nn.Layer):
self.block_list = nn.Sequential(*self.block_list)
if last_pool_type == 'avg':
self.pool = nn.AvgPool2D(kernel_size=2, stride=2, padding=0)
self.pool = nn.AvgPool2D(
kernel_size=last_pool_kernel_size,
stride=last_pool_kernel_size,
padding=0)
else:
self.pool = nn.MaxPool2D(kernel_size=2, stride=2, padding=0)
self.out_channels = int(1024 * scale)
......
......@@ -155,8 +155,9 @@ class Attention(nn.Layer):
proj_drop=0.):
super().__init__()
self.num_heads = num_heads
head_dim = dim // num_heads
self.scale = qk_scale or head_dim**-0.5
self.dim = dim
self.head_dim = dim // num_heads
self.scale = qk_scale or self.head_dim**-0.5
self.qkv = nn.Linear(dim, dim * 3, bias_attr=qkv_bias)
self.attn_drop = nn.Dropout(attn_drop)
......@@ -183,13 +184,9 @@ class Attention(nn.Layer):
self.mixer = mixer
def forward(self, x):
if self.HW is not None:
N = self.N
C = self.C
else:
_, N, C = x.shape
qkv = self.qkv(x).reshape((0, N, 3, self.num_heads, C //
self.num_heads)).transpose((2, 0, 3, 1, 4))
qkv = self.qkv(x).reshape(
(0, -1, 3, self.num_heads, self.head_dim)).transpose(
(2, 0, 3, 1, 4))
q, k, v = qkv[0] * self.scale, qkv[1], qkv[2]
attn = (q.matmul(k.transpose((0, 1, 3, 2))))
......@@ -198,7 +195,7 @@ class Attention(nn.Layer):
attn = nn.functional.softmax(attn, axis=-1)
attn = self.attn_drop(attn)
x = (attn.matmul(v)).transpose((0, 2, 1, 3)).reshape((0, N, C))
x = (attn.matmul(v)).transpose((0, 2, 1, 3)).reshape((0, -1, self.dim))
x = self.proj(x)
x = self.proj_drop(x)
return x
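Why the reshape change matters: the old code baked N and C in at build time, which breaks when the token length differs at export or inference; 0 keeps the batch dimension and -1 infers the sequence length, so the op stays shape-agnostic. A toy illustration with plain Paddle tensors:

```python
import paddle

dim, num_heads = 120, 8
qkv = paddle.randn([2, 37, 3 * dim])  # any token length works
print(qkv.reshape((0, -1, 3, num_heads, dim // num_heads)).shape)
# -> [2, 37, 3, 8, 15]
```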
......
......@@ -25,12 +25,28 @@ import paddle.nn.functional as F
from ppocr.modeling.necks.rnn import Im2Seq, EncoderWithRNN, EncoderWithFC, SequenceEncoder, EncoderWithSVTR
from .rec_ctc_head import CTCHead
from .rec_sar_head import SARHead
from .rec_nrtr_head import Transformer
class FCTranspose(nn.Layer):
def __init__(self, in_channels, out_channels, only_transpose=False):
super().__init__()
self.only_transpose = only_transpose
if not self.only_transpose:
self.fc = nn.Linear(in_channels, out_channels, bias_attr=False)
def forward(self, x):
if self.only_transpose:
return x.transpose([0, 2, 1])
else:
return self.fc(x.transpose([0, 2, 1]))
class MultiHead(nn.Layer):
def __init__(self, in_channels, out_channels_list, **kwargs):
super().__init__()
self.head_list = kwargs.pop('head_list')
self.gtc_head = 'sar'
assert len(self.head_list) >= 2
for idx, head_name in enumerate(self.head_list):
......@@ -40,12 +56,27 @@ class MultiHead(nn.Layer):
sar_args = self.head_list[idx][name]
self.sar_head = eval(name)(in_channels=in_channels, \
out_channels=out_channels_list['SARLabelDecode'], **sar_args)
elif name == 'NRTRHead':
gtc_args = self.head_list[idx][name]
max_text_length = gtc_args.get('max_text_length', 25)
nrtr_dim = gtc_args.get('nrtr_dim', 256)
num_decoder_layers = gtc_args.get('num_decoder_layers', 4)
self.before_gtc = nn.Sequential(
nn.Flatten(2), FCTranspose(in_channels, nrtr_dim))
self.gtc_head = Transformer(
d_model=nrtr_dim,
nhead=nrtr_dim // 32,
num_encoder_layers=-1,
beam_size=-1,
num_decoder_layers=num_decoder_layers,
max_len=max_text_length,
dim_feedforward=nrtr_dim * 4,
out_channels=out_channels_list['NRTRLabelDecode'])
elif name == 'CTCHead':
# ctc neck
self.encoder_reshape = Im2Seq(in_channels)
neck_args = self.head_list[idx][name]['Neck']
encoder_type = neck_args.pop('name')
self.encoder = encoder_type
self.ctc_encoder = SequenceEncoder(in_channels=in_channels, \
encoder_type=encoder_type, **neck_args)
# ctc head
......@@ -57,6 +88,7 @@ class MultiHead(nn.Layer):
'{} is not supported in MultiHead yet'.format(name))
def forward(self, x, targets=None):
ctc_encoder = self.ctc_encoder(x)
ctc_out = self.ctc_head(ctc_encoder, targets)
head_out = dict()
......@@ -68,6 +100,7 @@ class MultiHead(nn.Layer):
if self.gtc_head == 'sar':
sar_out = self.sar_head(x, targets[1:])
head_out['sar'] = sar_out
return head_out
else:
return head_out
gtc_out = self.gtc_head(self.before_gtc(x), targets[1:])
head_out['nrtr'] = gtc_out
return head_out
......@@ -47,8 +47,10 @@ class EncoderWithRNN(nn.Layer):
x, _ = self.lstm(x)
return x
class BidirectionalLSTM(nn.Layer):
def __init__(self, input_size,
def __init__(self,
input_size,
hidden_size,
output_size=None,
num_layers=1,
......@@ -58,39 +60,46 @@ class BidirectionalLSTM(nn.Layer):
with_linear=False):
super(BidirectionalLSTM, self).__init__()
self.with_linear = with_linear
self.rnn = nn.LSTM(input_size,
hidden_size,
num_layers=num_layers,
dropout=dropout,
direction=direction,
time_major=time_major)
self.rnn = nn.LSTM(
input_size,
hidden_size,
num_layers=num_layers,
dropout=dropout,
direction=direction,
time_major=time_major)
# text recognition uses this specific LSTM structure, optionally followed by a linear layer
if self.with_linear:
self.linear = nn.Linear(hidden_size * 2, output_size)
def forward(self, input_feature):
recurrent, _ = self.rnn(input_feature) # batch_size x T x input_size -> batch_size x T x (2*hidden_size)
recurrent, _ = self.rnn(
input_feature
) # batch_size x T x input_size -> batch_size x T x (2*hidden_size)
if self.with_linear:
output = self.linear(recurrent) # batch_size x T x output_size
output = self.linear(recurrent) # batch_size x T x output_size
return output
return recurrent
class EncoderWithCascadeRNN(nn.Layer):
def __init__(self, in_channels, hidden_size, out_channels, num_layers=2, with_linear=False):
def __init__(self,
in_channels,
hidden_size,
out_channels,
num_layers=2,
with_linear=False):
super(EncoderWithCascadeRNN, self).__init__()
self.out_channels = out_channels[-1]
self.encoder = nn.LayerList(
[BidirectionalLSTM(
in_channels if i == 0 else out_channels[i - 1],
hidden_size,
output_size=out_channels[i],
num_layers=1,
direction='bidirectional',
with_linear=with_linear)
for i in range(num_layers)]
)
self.encoder = nn.LayerList([
BidirectionalLSTM(
in_channels if i == 0 else out_channels[i - 1],
hidden_size,
output_size=out_channels[i],
num_layers=1,
direction='bidirectional',
with_linear=with_linear) for i in range(num_layers)
])
def forward(self, x):
for i, l in enumerate(self.encoder):
......@@ -130,12 +139,17 @@ class EncoderWithSVTR(nn.Layer):
drop_rate=0.1,
attn_drop_rate=0.1,
drop_path=0.,
kernel_size=[3, 3],
qk_scale=None):
super(EncoderWithSVTR, self).__init__()
self.depth = depth
self.use_guide = use_guide
self.conv1 = ConvBNLayer(
in_channels, in_channels // 8, padding=1, act=nn.Swish)
in_channels,
in_channels // 8,
kernel_size=kernel_size,
padding=[kernel_size[0] // 2, kernel_size[1] // 2],
act=nn.Swish)
self.conv2 = ConvBNLayer(
in_channels // 8, hidden_dims, kernel_size=1, act=nn.Swish)
......@@ -161,7 +175,11 @@ class EncoderWithSVTR(nn.Layer):
hidden_dims, in_channels, kernel_size=1, act=nn.Swish)
# last conv-nxn; its input is the concat of the block input tensor and the conv3 output tensor
self.conv4 = ConvBNLayer(
2 * in_channels, in_channels // 8, padding=1, act=nn.Swish)
2 * in_channels,
in_channels // 8,
kernel_size=kernel_size,
padding=[kernel_size[0] // 2, kernel_size[1] // 2],
act=nn.Swish)
self.conv1x1 = ConvBNLayer(
in_channels // 8, dims, kernel_size=1, act=nn.Swish)
......
......@@ -54,8 +54,12 @@ def main():
if config['PostProcess'][
'name'] == 'DistillationSARLabelDecode':
char_num = char_num - 2
if config['PostProcess'][
'name'] == 'DistillationNRTRLabelDecode':
char_num = char_num - 3
out_channels_list['CTCLabelDecode'] = char_num
out_channels_list['SARLabelDecode'] = char_num + 2
out_channels_list['NRTRLabelDecode'] = char_num + 3
config['Architecture']['Models'][key]['Head'][
'out_channels_list'] = out_channels_list
else:
......@@ -66,8 +70,11 @@ def main():
out_channels_list = {}
if config['PostProcess']['name'] == 'SARLabelDecode':
char_num = char_num - 2
if config['PostProcess']['name'] == 'NRTRLabelDecode':
char_num = char_num - 3
out_channels_list['CTCLabelDecode'] = char_num
out_channels_list['SARLabelDecode'] = char_num + 2
out_channels_list['NRTRLabelDecode'] = char_num + 3
config['Architecture']['Head'][
'out_channels_list'] = out_channels_list
else: # base rec model
......
......@@ -187,6 +187,12 @@ def export_single_model(model,
shape=[None] + infer_shape, dtype="float32")
])
if arch_config["Backbone"]["name"] == "LCNetv3":
# re-parameterize LCNetv3: fold LearnableRepLayer branches before export
for layer in model.sublayers():
if hasattr(layer, "rep") and not getattr(layer, "is_repped"):
layer.rep()
if quanter is None:
paddle.jit.save(model, save_path)
else:
......@@ -218,8 +224,12 @@ def main():
if config['PostProcess'][
'name'] == 'DistillationSARLabelDecode':
char_num = char_num - 2
if config['PostProcess'][
'name'] == 'DistillationNRTRLabelDecode':
char_num = char_num - 3
out_channels_list['CTCLabelDecode'] = char_num
out_channels_list['SARLabelDecode'] = char_num + 2
out_channels_list['NRTRLabelDecode'] = char_num + 3
config['Architecture']['Models'][key]['Head'][
'out_channels_list'] = out_channels_list
else:
......@@ -234,8 +244,11 @@ def main():
char_num = len(getattr(post_process_class, 'character'))
if config['PostProcess']['name'] == 'SARLabelDecode':
char_num = char_num - 2
if config['PostProcess']['name'] == 'NRTRLabelDecode':
char_num = char_num - 3
out_channels_list['CTCLabelDecode'] = char_num
out_channels_list['SARLabelDecode'] = char_num + 2
out_channels_list['NRTRLabelDecode'] = char_num + 3
config['Architecture']['Head'][
'out_channels_list'] = out_channels_list
else: # base rec model
......
......@@ -48,33 +48,44 @@ def main():
# build model
if hasattr(post_process_class, 'character'):
char_num = len(getattr(post_process_class, 'character'))
if config['Architecture']["algorithm"] in ["Distillation",
if config["Architecture"]["algorithm"] in ["Distillation",
]: # distillation model
for key in config['Architecture']["Models"]:
if config['Architecture']['Models'][key]['Head'][
'name'] == 'MultiHead': # for multi head
for key in config["Architecture"]["Models"]:
if config["Architecture"]["Models"][key]["Head"][
"name"] == 'MultiHead': # multi head
out_channels_list = {}
if config['PostProcess'][
'name'] == 'DistillationSARLabelDecode':
char_num = char_num - 2
if config['PostProcess'][
'name'] == 'DistillationNRTRLabelDecode':
char_num = char_num - 3
out_channels_list['CTCLabelDecode'] = char_num
out_channels_list['SARLabelDecode'] = char_num + 2
out_channels_list['NRTRLabelDecode'] = char_num + 3
config['Architecture']['Models'][key]['Head'][
'out_channels_list'] = out_channels_list
else:
config['Architecture']["Models"][key]["Head"][
'out_channels'] = char_num
config["Architecture"]["Models"][key]["Head"][
"out_channels"] = char_num
# just one final tensor needs to be exported for inference
config["Architecture"]["Models"][key][
"return_all_feats"] = False
elif config['Architecture']['Head'][
'name'] == 'MultiHead': # for multi head loss
'name'] == 'MultiHead': # multi head
out_channels_list = {}
char_num = len(getattr(post_process_class, 'character'))
if config['PostProcess']['name'] == 'SARLabelDecode':
char_num = char_num - 2
if config['PostProcess']['name'] == 'NRTRLabelDecode':
char_num = char_num - 3
out_channels_list['CTCLabelDecode'] = char_num
out_channels_list['SARLabelDecode'] = char_num + 2
out_channels_list['NRTRLabelDecode'] = char_num + 3
config['Architecture']['Head'][
'out_channels_list'] = out_channels_list
else: # base rec model
config['Architecture']["Head"]['out_channels'] = char_num
config["Architecture"]["Head"]["out_channels"] = char_num
model = build_model(config['Architecture'])
......
......@@ -80,14 +80,22 @@ def main(config, device, logger, vdl_writer):
if config['PostProcess'][
'name'] == 'DistillationSARLabelDecode':
char_num = char_num - 2
# update SARLoss params
assert list(config['Loss']['loss_config_list'][-1].keys())[
0] == 'DistillationSARLoss'
config['Loss']['loss_config_list'][-1][
'DistillationSARLoss']['ignore_index'] = char_num + 1
if config['PostProcess'][
'name'] == 'DistillationNRTRLabelDecode':
char_num = char_num - 3
out_channels_list = {}
out_channels_list['CTCLabelDecode'] = char_num
out_channels_list['SARLabelDecode'] = char_num + 2
# update SARLoss params
if list(config['Loss']['loss_config_list'][-1].keys())[
0] == 'DistillationSARLoss':
config['Loss']['loss_config_list'][-1][
'DistillationSARLoss'][
'ignore_index'] = char_num + 1
out_channels_list['SARLabelDecode'] = char_num + 2
elif list(config['Loss']['loss_config_list'][-1].keys())[
0] == 'DistillationNRTRLoss':
out_channels_list['NRTRLabelDecode'] = char_num + 3
config['Architecture']['Models'][key]['Head'][
'out_channels_list'] = out_channels_list
else:
......@@ -97,19 +105,24 @@ def main(config, device, logger, vdl_writer):
'name'] == 'MultiHead': # for multi head
if config['PostProcess']['name'] == 'SARLabelDecode':
char_num = char_num - 2
# update SARLoss params
assert list(config['Loss']['loss_config_list'][1].keys())[
0] == 'SARLoss'
if config['Loss']['loss_config_list'][1]['SARLoss'] is None:
config['Loss']['loss_config_list'][1]['SARLoss'] = {
'ignore_index': char_num + 1
}
else:
config['Loss']['loss_config_list'][1]['SARLoss'][
'ignore_index'] = char_num + 1
if config['PostProcess']['name'] == 'NRTRLabelDecode':
char_num = char_num - 3
out_channels_list = {}
out_channels_list['CTCLabelDecode'] = char_num
out_channels_list['SARLabelDecode'] = char_num + 2
# update SARLoss params
if list(config['Loss']['loss_config_list'][1].keys())[
0] == 'SARLoss':
if config['Loss']['loss_config_list'][1]['SARLoss'] is None:
config['Loss']['loss_config_list'][1]['SARLoss'] = {
'ignore_index': char_num + 1
}
else:
config['Loss']['loss_config_list'][1]['SARLoss'][
'ignore_index'] = char_num + 1
out_channels_list['SARLabelDecode'] = char_num + 2
elif list(config['Loss']['loss_config_list'][1].keys())[
0] == 'NRTRLoss':
out_channels_list['NRTRLabelDecode'] = char_num + 3
config['Architecture']['Head'][
'out_channels_list'] = out_channels_list
else: # base rec model
......