提交 0002349d 编写于 作者: z37757's avatar z37757

add text recognition algorithm rflearning

上级 6a8a0eeb
Global:
use_gpu: True
epoch_num: 6
log_smooth_window: 20
print_batch_step: 50
save_model_dir: ./output/rec/rec_resnet_rfl_att/
save_epoch_step: 1
# evaluation is run every 5000 iterations, starting from iteration 0
eval_batch_step: [0, 5000]
cal_metric_during_train: True
pretrained_model: ./pretrain_models/rec_resnet_rfl_visual/best_accuracy.pdparams
checkpoints:
save_inference_dir:
use_visualdl: False
infer_img: doc/imgs_words_en/word_10.png
# for data or label process
character_dict_path:
max_text_length: 25
infer_mode: False
use_space_char: False
save_res_path: ./output/rec/rec_resnet_rfl.txt
Optimizer:
name: AdamW
beta1: 0.9
beta2: 0.999
weight_decay: 0.0
clip_norm_global: 5.0
lr:
name: Piecewise
decay_epochs : [3, 4, 5]
values : [0.001, 0.0003, 0.00009, 0.000027]
Architecture:
model_type: rec
algorithm: RFL
in_channels: 1
Transform:
name: TPS
num_fiducial: 20
loc_lr: 1.0
model_name: large
Backbone:
name: ResNetRFL
use_cnt: True
use_seq: True
Neck:
name: RFAdaptor
use_v2s: True
use_s2v: True
Head:
name: RFLHead
in_channels: 512
hidden_size: 256
batch_max_legnth: 25
out_channels: 38
use_cnt: True
use_seq: True
Loss:
name: RFLLoss
# ignore_index: 0
PostProcess:
name: RFLLabelDecode
Metric:
name: RecMetric
main_indicator: acc
Train:
dataset:
name: LMDBDataSet
data_dir: ./train_data/rfl_dataset2/training
transforms:
- DecodeImage: # load image
img_mode: BGR
channel_first: False
- RFLLabelEncode: # Class handling label
- RFLRecResizeImg:
image_shape: [1, 32, 100]
padding: false
interpolation: 2
- KeepKeys:
keep_keys: ['image', 'label', 'length', 'cnt_label'] # dataloader will return list in this order
loader:
shuffle: True
batch_size_per_card: 64
drop_last: True
num_workers: 8
Eval:
dataset:
name: LMDBDataSet
data_dir: ./train_data/rfl_dataset2/evaluation
transforms:
- DecodeImage: # load image
img_mode: BGR
channel_first: False
- RFLLabelEncode: # Class handling label
- RFLRecResizeImg:
image_shape: [1, 32, 100]
padding: false
interpolation: 2
- KeepKeys:
keep_keys: ['image', 'label', 'length', 'cnt_label'] # dataloader will return list in this order
loader:
shuffle: False
drop_last: False
batch_size_per_card: 256
num_workers: 8
Global:
use_gpu: True
epoch_num: 6
log_smooth_window: 20
print_batch_step: 50
save_model_dir: ./output/rec/rec_resnet_rfl_visual/
save_epoch_step: 1
# evaluation is run every 5000 iterations, starting from iteration 0
eval_batch_step: [0, 5000]
cal_metric_during_train: False
pretrained_model:
checkpoints:
save_inference_dir:
use_visualdl: False
infer_img: doc/imgs_words_en/word_10.png
# for data or label process
character_dict_path:
max_text_length: 25
infer_mode: False
use_space_char: False
save_res_path: ./output/rec/rec_resnet_rfl_visual.txt
Optimizer:
name: AdamW
beta1: 0.9
beta2: 0.999
weight_decay: 0.0
clip_norm_global: 5.0
lr:
name: Piecewise
decay_epochs : [3, 4, 5]
values : [0.001, 0.0003, 0.00009, 0.000027]
Architecture:
model_type: rec
algorithm: RFL
in_channels: 1
Transform:
name: TPS
num_fiducial: 20
loc_lr: 1.0
model_name: large
Backbone:
name: ResNetRFL
use_cnt: True
use_seq: False
Neck:
name: RFAdaptor
use_v2s: False
use_s2v: False
Head:
name: RFLHead
in_channels: 512
hidden_size: 256
batch_max_legnth: 25
out_channels: 38
use_cnt: True
use_seq: False
Loss:
name: RFLLoss
PostProcess:
name: RFLLabelDecode
Metric:
name: CNTMetric
main_indicator: acc
Train:
dataset:
name: LMDBDataSet
data_dir: ./train_data/rfl_dataset2/training
transforms:
- DecodeImage: # load image
img_mode: BGR
channel_first: False
- RFLLabelEncode: # Class handling label
- RFLRecResizeImg:
image_shape: [1, 32, 100]
padding: false
interpolation: 2
- KeepKeys:
keep_keys: ['image', 'label', 'length', 'cnt_label'] # dataloader will return list in this order
loader:
shuffle: True
batch_size_per_card: 64
drop_last: True
num_workers: 8
Eval:
dataset:
name: LMDBDataSet
data_dir: ./train_data/rfl_dataset2/evaluation
transforms:
- DecodeImage: # load image
img_mode: BGR
channel_first: False
- RFLLabelEncode: # Class handling label
- RFLRecResizeImg:
image_shape: [1, 32, 100]
padding: false
interpolation: 2
- KeepKeys:
keep_keys: ['image', 'label', 'length', 'cnt_label'] # dataloader will return list in this order
loader:
shuffle: False
drop_last: False
batch_size_per_card: 256
num_workers: 8
......@@ -26,7 +26,8 @@ from .make_pse_gt import MakePseGt
from .rec_img_aug import BaseDataAugmentation, RecAug, RecConAug, RecResizeImg, ClsResizeImg, \
SRNRecResizeImg, GrayRecResizeImg, SARRecResizeImg, PRENResizeImg, \
ABINetRecResizeImg, SVTRRecResizeImg, ABINetRecAug, VLRecResizeImg, SPINRecResizeImg, RobustScannerRecResizeImg
ABINetRecResizeImg, SVTRRecResizeImg, ABINetRecAug, VLRecResizeImg, SPINRecResizeImg, RobustScannerRecResizeImg, \
RFLRecResizeImg
from .ssl_img_aug import SSLRotateResize
from .randaugment import RandAugment
from .copy_paste import CopyPaste
......
......@@ -488,6 +488,62 @@ class AttnLabelEncode(BaseRecLabelEncode):
return idx
class RFLLabelEncode(BaseRecLabelEncode):
    """Encode a text label into the targets used by the RFL recognizer.

    Besides the padded index sequence, a per-class character count vector
    (``cnt_label``) is produced for the counting branch.
    """

    def __init__(self,
                 max_text_length,
                 character_dict_path=None,
                 use_space_char=False,
                 **kwargs):
        # Extra config keys arrive through **kwargs and are ignored here.
        super().__init__(max_text_length, character_dict_path, use_space_char)

    def add_special_char(self, dict_character):
        """Wrap the dictionary with "sos" (first) and "eos" (last)."""
        self.beg_str = "sos"
        self.end_str = "eos"
        return [self.beg_str] + dict_character + [self.end_str]

    def encode_cnt(self, text):
        """Count how often every class index occurs in ``text``.

        Args:
            text: iterable of integer class indices.

        Returns:
            np.ndarray with ``len(self.character)`` float entries.
        """
        counts = [0.0] * len(self.character)
        for class_index in text:
            counts[class_index] += 1
        return np.array(counts)

    def __call__(self, data):
        """Fill ``data`` with 'label', 'length' and 'cnt_label'.

        Returns None when the label cannot be encoded or does not fit into
        ``max_text_len`` once the start and end tokens are added.
        """
        encoded = self.encode(data['label'])
        if encoded is None or len(encoded) >= self.max_text_len:
            return None

        cnt_label = self.encode_cnt(encoded)
        data['length'] = np.array(len(encoded))

        # Layout: start token (index 0), characters, end token (last index),
        # then padding with index 0.
        eos_index = len(self.character) - 1
        padding = [0] * (self.max_text_len - len(encoded) - 2)
        sequence = [0] + encoded + [eos_index] + padding
        # A label of exactly max_text_len - 1 characters overflows by one.
        if len(sequence) != self.max_text_len:
            return None

        data['label'] = np.array(sequence)
        data['cnt_label'] = cnt_label
        return data

    def get_ignored_tokens(self):
        """Return ``[beg_idx, end_idx]``, the indices of "sos" and "eos"."""
        return [
            self.get_beg_end_flag_idx("beg"),
            self.get_beg_end_flag_idx("end"),
        ]

    def get_beg_end_flag_idx(self, beg_or_end):
        """Look up the dictionary index of the "beg" or "end" flag."""
        if beg_or_end == "beg":
            flag = self.beg_str
        elif beg_or_end == "end":
            flag = self.end_str
        else:
            assert False, "Unsupport type %s in get_beg_end_flag_idx" \
                % beg_or_end
        return np.array(self.dict[flag])
class SEEDLabelEncode(BaseRecLabelEncode):
""" Convert between text-label and text-index """
......
......@@ -237,6 +237,33 @@ class VLRecResizeImg(object):
return data
class RFLRecResizeImg(object):
    """Convert a BGR image to grayscale and resize/normalize it for RFL.

    Args:
        image_shape: target shape handed to ``resize_norm_img``.
        padding (bool): forwarded to ``resize_norm_img``.
        interpolation (int): 0 nearest, 1 linear, 2 cubic, 3 area.
    """

    def __init__(self, image_shape, padding=True, interpolation=1, **kwargs):
        self.image_shape = image_shape
        self.padding = padding
        self.interpolation = interpolation

        # Position in this tuple is the config code for the OpenCV flag.
        cv2_flags = (cv2.INTER_NEAREST, cv2.INTER_LINEAR, cv2.INTER_CUBIC,
                     cv2.INTER_AREA)
        for code, flag in enumerate(cv2_flags):
            if self.interpolation == code:
                self.interpolation = flag
                break
        else:
            raise Exception("Unsupported interpolation type !!!")

    def __call__(self, data):
        """Replace ``data['image']`` and add ``data['valid_ratio']``."""
        gray = cv2.cvtColor(data['image'], cv2.COLOR_BGR2GRAY)
        data['image'], data['valid_ratio'] = resize_norm_img(
            gray, self.image_shape, self.padding, self.interpolation)
        return data
class SRNRecResizeImg(object):
def __init__(self, image_shape, num_heads, max_text_length, **kwargs):
self.image_shape = image_shape
......@@ -414,8 +441,13 @@ class SVTRRecResizeImg(object):
data['valid_ratio'] = valid_ratio
return data
class RobustScannerRecResizeImg(object):
def __init__(self, image_shape, max_text_length, width_downsample_ratio=0.25, **kwargs):
def __init__(self,
image_shape,
max_text_length,
width_downsample_ratio=0.25,
**kwargs):
self.image_shape = image_shape
self.width_downsample_ratio = width_downsample_ratio
self.max_text_length = max_text_length
......@@ -432,6 +464,7 @@ class RobustScannerRecResizeImg(object):
data['word_positons'] = word_positons
return data
def resize_norm_img_sar(img, image_shape, width_downsample_ratio=0.25):
imgC, imgH, imgW_min, imgW_max = image_shape
h = img.shape[0]
......@@ -467,13 +500,16 @@ def resize_norm_img_sar(img, image_shape, width_downsample_ratio=0.25):
return padding_im, resize_shape, pad_shape, valid_ratio
def resize_norm_img(img, image_shape, padding=True):
def resize_norm_img(img,
image_shape,
padding=True,
interpolation=cv2.INTER_LINEAR):
imgC, imgH, imgW = image_shape
h = img.shape[0]
w = img.shape[1]
if not padding:
resized_image = cv2.resize(
img, (imgW, imgH), interpolation=cv2.INTER_LINEAR)
img, (imgW, imgH), interpolation=interpolation)
resized_w = imgW
else:
ratio = w / float(h)
......
......@@ -38,6 +38,7 @@ from .rec_pren_loss import PRENLoss
from .rec_multi_loss import MultiLoss
from .rec_vl_loss import VLLoss
from .rec_spin_att_loss import SPINAttentionLoss
from .rec_rfl_loss import RFLLoss
# cls loss
from .cls_loss import ClsLoss
......@@ -69,7 +70,7 @@ def build_loss(config):
'CELoss', 'TableAttentionLoss', 'SARLoss', 'AsterLoss', 'SDMGRLoss',
'VQASerTokenLayoutLMLoss', 'LossFromOutput', 'PRENLoss', 'MultiLoss',
'TableMasterLoss', 'SPINAttentionLoss', 'VLLoss', 'StrokeFocusLoss',
'SLALoss', 'CTLoss'
'SLALoss', 'CTLoss', 'RFLLoss'
]
config = copy.deepcopy(config)
module_name = config.pop('name')
......
# copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import paddle
from paddle import nn
from .basic_loss import CELoss, DistanceLoss
class RFLLoss(nn.Layer):
    """Combined loss for the RFL counting and sequence branches.

    The counting output is compared to the character-count label with a
    mean-squared error; the sequence output is compared to the label with
    cross-entropy. A branch whose prediction is None is skipped.

    Args:
        ignore_index (int): target value ignored by the cross-entropy term.
        **kwargs: forwarded unchanged to ``nn.MSELoss``.
    """

    def __init__(self, ignore_index=-100, **kwargs):
        super().__init__()
        self.cnt_loss = nn.MSELoss(**kwargs)
        self.seq_loss = nn.CrossEntropyLoss(ignore_index=ignore_index)

    def forward(self, predicts, batch):
        """Compute the enabled loss terms.

        Args:
            predicts: indexable; ``predicts[0]`` is the counting output and
                ``predicts[1]`` the sequence output (either may be None).
            batch: indexable as [image, label, length, cnt_label].

        Returns:
            dict holding 'cnt_loss' and/or 'seq_loss' plus their sum under
            'loss'. The same dict is kept on ``self.total_loss``.
        """
        self.total_loss = {}
        total_loss = 0.0

        cnt_outputs = predicts[0]
        if cnt_outputs is not None:
            cnt_loss = self.cnt_loss(cnt_outputs,
                                     paddle.cast(batch[3], paddle.float32))
            self.total_loss['cnt_loss'] = cnt_loss
            total_loss += cnt_loss

        seq_outputs = predicts[1]
        if seq_outputs is not None:
            targets = batch[1].astype("int64")
            assert len(targets.shape) == len(list(seq_outputs.shape)) - 1, \
                "The target's shape and inputs's shape is [N, d] and [N, num_steps]"

            # Align prediction step t with label position t + 1: drop the
            # last prediction step and the leading label entry.
            inputs = seq_outputs[:, :-1, :]
            targets = targets[:, 1:]

            inputs = paddle.reshape(inputs, [-1, inputs.shape[-1]])
            targets = paddle.reshape(targets, [-1])
            seq_loss = self.seq_loss(inputs, targets)
            self.total_loss['seq_loss'] = seq_loss
            total_loss += seq_loss

        self.total_loss['loss'] = total_loss
        return self.total_loss
......@@ -22,7 +22,7 @@ import copy
__all__ = ["build_metric"]
from .det_metric import DetMetric, DetFCEMetric
from .rec_metric import RecMetric
from .rec_metric import RecMetric, CNTMetric
from .cls_metric import ClsMetric
from .e2e_metric import E2EMetric
from .distillation_metric import DistillationMetric
......@@ -38,7 +38,7 @@ def build_metric(config):
support_dict = [
"DetMetric", "DetFCEMetric", "RecMetric", "ClsMetric", "E2EMetric",
"DistillationMetric", "TableMetric", 'KIEMetric', 'VQASerTokenMetric',
'VQAReTokenMetric', 'SRMetric', 'CTMetric'
'VQAReTokenMetric', 'SRMetric', 'CTMetric', 'CNTMetric'
]
config = copy.deepcopy(config)
......
......@@ -16,7 +16,6 @@ from rapidfuzz.distance import Levenshtein
import string
class RecMetric(object):
def __init__(self,
main_indicator='acc',
......@@ -74,3 +73,42 @@ class RecMetric(object):
self.correct_num = 0
self.all_num = 0
self.norm_edit_dis = 0
class CNTMetric(object):
    """Running exact-match accuracy over (prediction, label) pairs."""

    def __init__(self, main_indicator='acc', **kwargs):
        self.main_indicator = main_indicator
        # Keeps the accuracy ratio defined when no sample has been seen.
        self.eps = 1e-5
        self.reset()

    def _normalize_text(self, text):
        """Keep only ASCII letters and digits of ``text`` and lowercase it."""
        allowed = string.digits + string.ascii_letters
        return ''.join(ch for ch in text if ch in allowed).lower()

    def __call__(self, pred_label, *args, **kwargs):
        """Accumulate one batch and return that batch's accuracy.

        Args:
            pred_label: pair ``(preds, labels)`` of equally ordered iterables.

        Returns:
            dict ``{'acc': batch_accuracy}``.
        """
        preds, labels = pred_label
        pairs = list(zip(preds, labels))
        hits = sum(1 for pred, target in pairs if pred == target)
        seen = len(pairs)

        self.correct_num += hits
        self.all_num += seen
        return {'acc': hits / (seen + self.eps), }

    def get_metric(self):
        """
        return metrics {
                 'acc': 0,
            }
        The accumulated counters are cleared afterwards.
        """
        result = {'acc': 1.0 * self.correct_num / (self.all_num + self.eps)}
        self.reset()
        return result

    def reset(self):
        """Clear the accumulated counters."""
        self.correct_num = 0
        self.all_num = 0
......@@ -42,10 +42,11 @@ def build_backbone(config, model_type):
from .rec_efficientb3_pren import EfficientNetb3_PREN
from .rec_svtrnet import SVTRNet
from .rec_vitstr import ViTSTR
from .rec_resnet_rfl import ResNetRFL
support_dict = [
'MobileNetV1Enhance', 'MobileNetV3', 'ResNet', 'ResNetFPN', 'MTB',
'ResNet31', 'ResNet45', 'ResNet_ASTER', 'MicroNet',
'EfficientNetb3_PREN', 'SVTRNet', 'ViTSTR', 'ResNet32'
'EfficientNetb3_PREN', 'SVTRNet', 'ViTSTR', 'ResNet32', 'ResNetRFL'
]
elif model_type == 'e2e':
from .e2e_resnet_vd_pg import ResNet
......
# copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This code is refer from:
https://github.com/hikopensource/DAVAR-Lab-OCR/blob/main/davarocr/davar_rcg/models/backbones/ResNetRFL.py
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import paddle
import paddle.nn as nn
from paddle.nn.initializer import TruncatedNormal, Constant, Normal, KaimingNormal
kaiming_init_ = KaimingNormal()
zeros_ = Constant(value=0.)
ones_ = Constant(value=1.)
class BasicBlock(nn.Layer):
    """Res-net Basic Block"""
    expansion = 1

    def __init__(self,
                 inplanes,
                 planes,
                 stride=1,
                 downsample=None,
                 norm_type='BN',
                 **kwargs):
        """
        Args:
            inplanes (int): input channel
            planes (int): channels of the middle feature
            stride (int): stride of the first convolution
            downsample (nn.Layer|None): layer applied to the input to build
                the shortcut; None keeps the input as the shortcut
            norm_type (str): accepted for compatibility, not used here
            **kwargs (None): backup parameter
        """
        super(BasicBlock, self).__init__()
        # The stride must be applied on the main path as well: the shortcut
        # built by the callers is strided, so ignoring it here would make the
        # two tensors disagree in size at the residual addition whenever
        # stride != 1. With stride == 1 nothing changes.
        self.conv1 = self._conv3x3(inplanes, planes, stride)
        self.bn1 = nn.BatchNorm(planes)
        self.conv2 = self._conv3x3(planes, planes)
        self.bn2 = nn.BatchNorm(planes)
        self.relu = nn.ReLU()
        self.downsample = downsample
        self.stride = stride

    def _conv3x3(self, in_planes, out_planes, stride=1):
        """Build a bias-free 3x3 convolution with padding 1."""
        return nn.Conv2D(
            in_planes,
            out_planes,
            kernel_size=3,
            stride=stride,
            padding=1,
            bias_attr=False)

    def forward(self, x):
        """Apply conv-bn-relu-conv-bn, add the shortcut, then relu."""
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            residual = self.downsample(x)
        out += residual
        out = self.relu(out)

        return out
class ResNetRFL(nn.Layer):
    """RFL backbone: a shared stem followed by two parallel branches.

    The counting ("visual", ``v_`` prefixed) branch and the sequence branch
    have the same layer layout and each can be switched off.
    """

    def __init__(self,
                 in_channels,
                 out_channels=512,
                 use_cnt=True,
                 use_seq=True):
        """
        Args:
            in_channels (int): input channel
            out_channels (int): output channel
            use_cnt (bool): build the counting branch
            use_seq (bool): build the sequence branch
        """
        super(ResNetRFL, self).__init__()
        assert use_cnt or use_seq
        self.use_cnt, self.use_seq = use_cnt, use_seq
        # Forward out_channels so the stem's output width (out_channels / 2)
        # matches the input width the branches below are built for. The stem
        # default is 512, so the default configuration is unchanged.
        self.backbone = RFLBase(in_channels, out_channels)

        self.out_channels = out_channels
        self.out_channels_block = [
            int(self.out_channels / 4), int(self.out_channels / 2),
            self.out_channels, self.out_channels
        ]
        block = BasicBlock
        layers = [1, 2, 5, 3]
        # _make_layer reads and updates self.inplanes as it stacks blocks.
        self.inplanes = int(self.out_channels // 2)

        self.relu = nn.ReLU()
        if self.use_seq:
            self.maxpool3 = nn.MaxPool2D(
                kernel_size=2, stride=(2, 1), padding=(0, 1))
            self.layer3 = self._make_layer(
                block, self.out_channels_block[2], layers[2], stride=1)
            self.conv3 = nn.Conv2D(
                self.out_channels_block[2],
                self.out_channels_block[2],
                kernel_size=3,
                stride=1,
                padding=1,
                bias_attr=False)
            self.bn3 = nn.BatchNorm(self.out_channels_block[2])

            self.layer4 = self._make_layer(
                block, self.out_channels_block[3], layers[3], stride=1)
            self.conv4_1 = nn.Conv2D(
                self.out_channels_block[3],
                self.out_channels_block[3],
                kernel_size=2,
                stride=(2, 1),
                padding=(0, 1),
                bias_attr=False)
            self.bn4_1 = nn.BatchNorm(self.out_channels_block[3])
            self.conv4_2 = nn.Conv2D(
                self.out_channels_block[3],
                self.out_channels_block[3],
                kernel_size=2,
                stride=1,
                padding=0,
                bias_attr=False)
            self.bn4_2 = nn.BatchNorm(self.out_channels_block[3])

        if self.use_cnt:
            # Restart from the stem width; the sequence branch may have
            # advanced self.inplanes.
            self.inplanes = int(self.out_channels // 2)
            self.v_maxpool3 = nn.MaxPool2D(
                kernel_size=2, stride=(2, 1), padding=(0, 1))
            self.v_layer3 = self._make_layer(
                block, self.out_channels_block[2], layers[2], stride=1)
            self.v_conv3 = nn.Conv2D(
                self.out_channels_block[2],
                self.out_channels_block[2],
                kernel_size=3,
                stride=1,
                padding=1,
                bias_attr=False)
            self.v_bn3 = nn.BatchNorm(self.out_channels_block[2])

            self.v_layer4 = self._make_layer(
                block, self.out_channels_block[3], layers[3], stride=1)
            self.v_conv4_1 = nn.Conv2D(
                self.out_channels_block[3],
                self.out_channels_block[3],
                kernel_size=2,
                stride=(2, 1),
                padding=(0, 1),
                bias_attr=False)
            self.v_bn4_1 = nn.BatchNorm(self.out_channels_block[3])
            self.v_conv4_2 = nn.Conv2D(
                self.out_channels_block[3],
                self.out_channels_block[3],
                kernel_size=2,
                stride=1,
                padding=0,
                bias_attr=False)
            self.v_bn4_2 = nn.BatchNorm(self.out_channels_block[3])

    def _make_layer(self, block, planes, blocks, stride=1):
        """Stack ``blocks`` residual blocks producing ``planes`` channels.

        A 1x1 conv + batch-norm shortcut is added to the first block when the
        stride or the channel count changes.
        """
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2D(
                    self.inplanes,
                    planes * block.expansion,
                    kernel_size=1,
                    stride=stride,
                    bias_attr=False),
                nn.BatchNorm(planes * block.expansion), )

        layers = list()
        layers.append(block(self.inplanes, planes, stride, downsample))
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes))

        return nn.Sequential(*layers)

    def forward(self, inputs):
        """Return ``[visual_feature, sequence_feature]``.

        The entry of a disabled branch is None.
        """
        x_1 = self.backbone(inputs)

        if self.use_cnt:
            v_x = self.v_maxpool3(x_1)
            v_x = self.v_layer3(v_x)
            v_x = self.v_conv3(v_x)
            v_x = self.v_bn3(v_x)
            visual_feature_2 = self.relu(v_x)

            v_x = self.v_layer4(visual_feature_2)
            v_x = self.v_conv4_1(v_x)
            v_x = self.v_bn4_1(v_x)
            v_x = self.relu(v_x)
            v_x = self.v_conv4_2(v_x)
            v_x = self.v_bn4_2(v_x)
            visual_feature_3 = self.relu(v_x)
        else:
            visual_feature_3 = None

        if self.use_seq:
            x = self.maxpool3(x_1)
            x = self.layer3(x)
            x = self.conv3(x)
            x = self.bn3(x)
            x_2 = self.relu(x)

            x = self.layer4(x_2)
            x = self.conv4_1(x)
            x = self.bn4_1(x)
            x = self.relu(x)
            x = self.conv4_2(x)
            x = self.bn4_2(x)
            x_3 = self.relu(x)
        else:
            x_3 = None

        return [visual_feature_3, x_3]
class ResNetBase(nn.Layer):
    """Shared stem: two 3x3 convs, then two (pool, residual stage, conv) groups.

    Args:
        in_channels (int): channels of the input image.
        out_channels (int): reference width; the stem itself ends at
            ``out_channels / 2`` channels.
        block: residual block class exposing an ``expansion`` attribute.
        layers (list[int]): blocks per stage; only the first two are used.
    """

    def __init__(self, in_channels, out_channels, block, layers):
        super(ResNetBase, self).__init__()

        self.out_channels_block = [
            int(out_channels / 4), int(out_channels / 2), out_channels,
            out_channels
        ]
        stage1_width, stage2_width = self.out_channels_block[:2]
        stem_width = int(out_channels / 16)

        # _make_layer reads and updates self.inplanes as it stacks blocks.
        self.inplanes = int(out_channels / 8)

        self.conv0_1 = nn.Conv2D(
            in_channels, stem_width,
            kernel_size=3, stride=1, padding=1, bias_attr=False)
        self.bn0_1 = nn.BatchNorm(stem_width)
        self.conv0_2 = nn.Conv2D(
            stem_width, self.inplanes,
            kernel_size=3, stride=1, padding=1, bias_attr=False)
        self.bn0_2 = nn.BatchNorm(self.inplanes)
        self.relu = nn.ReLU()

        self.maxpool1 = nn.MaxPool2D(kernel_size=2, stride=2, padding=0)
        self.layer1 = self._make_layer(block, stage1_width, layers[0])
        self.conv1 = nn.Conv2D(
            stage1_width, stage1_width,
            kernel_size=3, stride=1, padding=1, bias_attr=False)
        self.bn1 = nn.BatchNorm(stage1_width)

        self.maxpool2 = nn.MaxPool2D(kernel_size=2, stride=2, padding=0)
        self.layer2 = self._make_layer(
            block, stage2_width, layers[1], stride=1)
        self.conv2 = nn.Conv2D(
            stage2_width, stage2_width,
            kernel_size=3, stride=1, padding=1, bias_attr=False)
        self.bn2 = nn.BatchNorm(stage2_width)

    def _make_layer(self, block, planes, blocks, stride=1):
        """Stack ``blocks`` residual blocks producing ``planes`` channels."""
        out_planes = planes * block.expansion

        shortcut = None
        if stride != 1 or self.inplanes != out_planes:
            # Project the input so it can be added to the block output.
            shortcut = nn.Sequential(
                nn.Conv2D(
                    self.inplanes,
                    out_planes,
                    kernel_size=1,
                    stride=stride,
                    bias_attr=False),
                nn.BatchNorm(out_planes), )

        stage = [block(self.inplanes, planes, stride, shortcut)]
        self.inplanes = planes * block.expansion
        stage.extend(block(self.inplanes, planes) for _ in range(1, blocks))
        return nn.Sequential(*stage)

    def forward(self, x):
        """Run the stem; each of the two max-pools uses kernel 2, stride 2."""
        x = self.relu(self.bn0_1(self.conv0_1(x)))
        x = self.relu(self.bn0_2(self.conv0_2(x)))

        x = self.layer1(self.maxpool1(x))
        x = self.relu(self.bn1(self.conv1(x)))

        x = self.layer2(self.maxpool2(x))
        x = self.relu(self.bn2(self.conv2(x)))
        return x
class RFLBase(nn.Layer):
    """Reciprocal feature learning shared backbone network.

    Thin wrapper that builds a ``ResNetBase`` from ``BasicBlock`` with the
    block counts [1, 2, 5, 3] and exposes it as ``ConvNet``.
    """

    def __init__(self, in_channels, out_channels=512):
        super().__init__()
        blocks_per_stage = [1, 2, 5, 3]
        self.ConvNet = ResNetBase(in_channels, out_channels, BasicBlock,
                                  blocks_per_stage)

    def forward(self, inputs):
        """Delegate to the wrapped stem."""
        features = self.ConvNet(inputs)
        return features
......@@ -38,6 +38,7 @@ def build_head(config):
from .rec_abinet_head import ABINetHead
from .rec_robustscanner_head import RobustScannerHead
from .rec_visionlan_head import VLHead
from .rec_rfl_head import RFLHead
# cls head
from .cls_head import ClsHead
......@@ -53,7 +54,7 @@ def build_head(config):
'ClsHead', 'AttentionHead', 'SRNHead', 'PGHead', 'Transformer',
'TableAttentionHead', 'SARHead', 'AsterHead', 'SDMGRHead', 'PRENHead',
'MultiHead', 'ABINetHead', 'TableMasterHead', 'SPINAttentionHead',
'VLHead', 'SLAHead', 'RobustScannerHead', 'CT_Head'
'VLHead', 'SLAHead', 'RobustScannerHead', 'CT_Head', 'RFLHead'
]
#table head
......
......@@ -149,6 +149,8 @@ class AttentionLSTM(nn.Layer):
else:
targets = paddle.zeros(shape=[batch_size], dtype="int32")
probs = None
char_onehots = None
alpha = None
for i in range(num_steps):
char_onehots = self._char_to_onehot(
......
# copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This code is refer from:
https://github.com/hikopensource/DAVAR-Lab-OCR/blob/main/davarocr/davar_rcg/models/sequence_heads/counting_head.py
"""
import paddle
import paddle.nn as nn
from paddle.nn.initializer import TruncatedNormal, Constant, Normal, KaimingNormal
from .rec_att_head import AttentionLSTM
kaiming_init_ = KaimingNormal()
zeros_ = Constant(value=0.)
ones_ = Constant(value=1.)
class CNTHead(nn.Layer):
    """Counting head: predicts one value per class from the visual feature.

    Args:
        embed_size (int): channel count of the incoming feature.
        encode_length (int): number of spatial positions (h * w) expected.
        out_channels (int): size of the prediction vector.
    """

    def __init__(self,
                 embed_size=512,
                 encode_length=26,
                 out_channels=38,
                 **kwargs):
        super().__init__()
        self.out_channels = out_channels

        self.Wv_fusion = nn.Linear(embed_size, embed_size, bias_attr=False)
        # Consumes all positions at once, so h * w must equal encode_length.
        self.Prediction_visual = nn.Linear(encode_length * embed_size,
                                           self.out_channels)

    def forward(self, visual_feature):
        """Map a (b, c, h, w) feature to a (b, out_channels) prediction."""
        batch, channels, height, width = visual_feature.shape
        # (b, c, h, w) -> (b, h*w, c)
        positions = visual_feature.reshape(
            [batch, channels, height * width]).transpose([0, 2, 1])

        fused = self.Wv_fusion(positions)
        batch, steps, dim = fused.shape

        # Flatten positions and channels, then predict from the whole feature.
        flattened = fused.reshape([batch, steps * dim])
        return self.Prediction_visual(flattened)
class RFLHead(nn.Layer):
    """RFL head combining a counting head and an attention sequence head.

    Args:
        in_channels (int): channel count of the incoming features.
        hidden_size (int): hidden size handed to ``AttentionLSTM``.
        batch_max_legnth (int): maximum text length (name kept as spelled,
            it is part of the public interface).
        out_channels (int): number of classes.
        use_cnt (bool): build the counting head.
        use_seq (bool): build the sequence head.
        **kwargs: forwarded to both sub-heads.
    """

    def __init__(self,
                 in_channels=512,
                 hidden_size=256,
                 batch_max_legnth=25,
                 out_channels=38,
                 use_cnt=True,
                 use_seq=True,
                 **kwargs):
        super().__init__()
        assert use_cnt or use_seq
        self.use_cnt, self.use_seq = use_cnt, use_seq

        if self.use_cnt:
            self.cnt_head = CNTHead(
                embed_size=in_channels,
                encode_length=batch_max_legnth + 1,
                out_channels=out_channels,
                **kwargs)
        if self.use_seq:
            self.seq_head = AttentionLSTM(
                in_channels=in_channels,
                out_channels=out_channels,
                hidden_size=hidden_size,
                **kwargs)

        self.batch_max_legnth = batch_max_legnth
        self.num_class = out_channels
        self.apply(self.init_weights)

    def init_weights(self, m):
        """Kaiming-init Linear weights and zero their bias if present."""
        if not isinstance(m, nn.Linear):
            return
        kaiming_init_(m.weight)
        if m.bias is not None:
            zeros_(m.bias)

    def forward(self, x, targets=None):
        """Return ``(cnt_outputs, seq_outputs)``; a disabled head yields None.

        Args:
            x: pair ``(cnt_inputs, seq_inputs)``.
            targets: indexable; ``targets[0]`` is handed to the sequence head
                in training mode and ignored otherwise.
        """
        cnt_inputs, seq_inputs = x

        cnt_outputs = self.cnt_head(cnt_inputs) if self.use_cnt else None

        seq_outputs = None
        if self.use_seq:
            seq_targets = targets[0] if self.training else None
            seq_outputs = self.seq_head(seq_inputs, seq_targets,
                                        self.batch_max_legnth)

        return cnt_outputs, seq_outputs
......@@ -27,9 +27,11 @@ def build_neck(config):
from .pren_fpn import PRENFPN
from .csp_pan import CSPPAN
from .ct_fpn import CTFPN
from .rf_adaptor import RFAdaptor
support_dict = [
'FPN', 'FCEFPN', 'LKPAN', 'DBFPN', 'RSEFPN', 'EASTFPN', 'SASTFPN',
'SequenceEncoder', 'PGFPN', 'TableFPN', 'PRENFPN', 'CSPPAN', 'CTFPN'
'SequenceEncoder', 'PGFPN', 'TableFPN', 'PRENFPN', 'CSPPAN', 'CTFPN',
'RFAdaptor'
]
module_name = config.pop('name')
......
# copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This code is refer from:
https://github.com/hikopensource/DAVAR-Lab-OCR/blob/main/davarocr/davar_rcg/models/connects/single_block/RFAdaptor.py
"""
import paddle
import paddle.nn as nn
from paddle.nn.initializer import TruncatedNormal, Constant, Normal, KaimingNormal
kaiming_init_ = KaimingNormal()
zeros_ = Constant(value=0.)
ones_ = Constant(value=1.)
class S2VAdaptor(nn.Layer):
    """ Semantic to Visual adaptation module"""

    def __init__(self, in_channels=512):
        super().__init__()
        self.in_channels = in_channels

        # Channel attention: linear map over channels, batch-norm, relu.
        self.channel_inter = nn.Linear(
            self.in_channels, self.in_channels, bias_attr=False)
        self.channel_bn = nn.BatchNorm1D(self.in_channels)
        self.channel_act = nn.ReLU()

        self.apply(self.init_weights)

    def init_weights(self, m):
        """Initialise Conv2D and batch-norm sublayers; others are untouched."""
        if isinstance(m, nn.Conv2D):
            kaiming_init_(m.weight)
            if m.bias is not None:
                zeros_(m.bias)
            return
        if isinstance(m, (nn.BatchNorm, nn.BatchNorm2D, nn.BatchNorm1D)):
            zeros_(m.bias)
            ones_(m.weight)

    def forward(self, semantic):
        """Scale ``semantic`` by an attention map computed from itself.

        ``semantic`` is squeezed on axis 2, so that axis is expected to be 1.
        """
        # (b, c, 1, w) -> (b, w, c) so the linear layer mixes channels
        per_position = semantic.squeeze(2).transpose([0, 2, 1])
        # back to (b, c, w) for BatchNorm1D
        attention = self.channel_inter(per_position).transpose([0, 2, 1])
        attention = self.channel_act(self.channel_bn(attention))

        # Re-insert the singleton axis and weight the original feature.
        return semantic * attention.unsqueeze(-2)
class V2SAdaptor(nn.Layer):
    """ Visual to Semantic adaptation module"""

    def __init__(self, in_channels=512, return_mask=False):
        super().__init__()
        self.in_channels = in_channels
        # When True, forward also returns the attention before unsqueezing.
        self.return_mask = return_mask

        # Channel attention: linear map over channels, batch-norm, relu.
        self.channel_inter = nn.Linear(
            self.in_channels, self.in_channels, bias_attr=False)
        self.channel_bn = nn.BatchNorm1D(self.in_channels)
        self.channel_act = nn.ReLU()

    def forward(self, visual):
        """Compute an attention map from ``visual``.

        ``visual`` is squeezed on axis 2, so that axis is expected to be 1.

        Returns:
            The attention with a singleton axis inserted before the last
            axis, or ``(that, attention)`` when ``return_mask`` is set.
        """
        # (b, c, 1, w) -> (b, w, c) so the linear layer mixes channels
        per_position = visual.squeeze(2).transpose([0, 2, 1])
        # back to (b, c, w) for BatchNorm1D
        attention = self.channel_inter(per_position).transpose([0, 2, 1])
        attention = self.channel_act(self.channel_bn(attention))

        # size alignment: (b, c, w) -> (b, c, 1, w)
        aligned = attention.unsqueeze(-2)
        return (aligned, attention) if self.return_mask else aligned
class RFAdaptor(nn.Layer):
    """Neck that lets the visual and the sequence feature modulate each other.

    Args:
        in_channels (int): channel count of both features.
        use_v2s (bool): build the visual-to-semantic adaptor.
        use_s2v (bool): build the semantic-to-visual adaptor.
        **kwargs: forwarded to every adaptor that is built.
    """

    def __init__(self, in_channels=512, use_v2s=True, use_s2v=True, **kwargs):
        super().__init__()
        self.neck_v2s = (V2SAdaptor(in_channels=in_channels, **kwargs)
                         if use_v2s is True else None)
        self.neck_s2v = (S2VAdaptor(in_channels=in_channels, **kwargs)
                         if use_s2v is True else None)
        self.out_channels = in_channels

    def forward(self, x):
        """Return ``(visual_feature, sequence_feature)`` after adaptation.

        Args:
            x: pair ``(visual_feature, rcg_feature)``; a feature may be None
                when the adaptor consuming it is disabled.
        """
        visual_feature, rcg_feature = x

        if visual_feature is not None:
            # Collapse the spatial grid into one row: (b, c, h, w) -> (b, c, 1, h*w)
            n, c, h, w = visual_feature.shape
            visual_feature = visual_feature.reshape([n, c, 1, h * w])

        # Both adaptors read the un-adapted inputs.
        v_rcg_feature = rcg_feature
        if self.neck_v2s is not None:
            v_rcg_feature = rcg_feature * self.neck_v2s(visual_feature)

        v_visual_feature = visual_feature
        if self.neck_s2v is not None:
            v_visual_feature = visual_feature + self.neck_s2v(rcg_feature)

        if v_rcg_feature is not None:
            # (b, c, h, w) -> (b, c, 1, h*w) -> (b, c, h*w) -> (b, h*w, c)
            n, c, h, w = v_rcg_feature.shape
            v_rcg_feature = v_rcg_feature.reshape([n, c, 1, h * w])
            v_rcg_feature = v_rcg_feature.squeeze(2).transpose([0, 2, 1])

        return v_visual_feature, v_rcg_feature
......@@ -53,6 +53,9 @@ def build_optimizer(config, epochs, step_each_epoch, model):
if 'clip_norm' in config:
clip_norm = config.pop('clip_norm')
grad_clip = paddle.nn.ClipGradByNorm(clip_norm=clip_norm)
elif 'clip_norm_global' in config:
clip_norm = config.pop('clip_norm_global')
grad_clip = paddle.nn.ClipGradByGlobalNorm(clip_norm=clip_norm)
else:
grad_clip = None
optim = getattr(optimizer, optim_name)(learning_rate=lr,
......
......@@ -28,7 +28,7 @@ from .fce_postprocess import FCEPostProcess
from .rec_postprocess import CTCLabelDecode, AttnLabelDecode, SRNLabelDecode, \
DistillationCTCLabelDecode, NRTRLabelDecode, SARLabelDecode, \
SEEDLabelDecode, PRENLabelDecode, ViTSTRLabelDecode, ABINetLabelDecode, \
SPINLabelDecode, VLLabelDecode
SPINLabelDecode, VLLabelDecode, RFLLabelDecode
from .cls_postprocess import ClsPostProcess
from .pg_postprocess import PGPostProcess
from .vqa_token_ser_layoutlm_postprocess import VQASerTokenLayoutLMPostProcess, DistillationSerPostProcess
......@@ -49,7 +49,7 @@ def build_post_process(config, global_config=None):
'DistillationSARLabelDecode', 'ViTSTRLabelDecode', 'ABINetLabelDecode',
'TableMasterLabelDecode', 'SPINLabelDecode',
'DistillationSerPostProcess', 'DistillationRePostProcess',
'VLLabelDecode', 'PicoDetPostProcess', 'CTPostProcess'
'VLLabelDecode', 'PicoDetPostProcess', 'CTPostProcess', 'RFLLabelDecode'
]
if config['name'] == 'PSEPostProcess':
......
......@@ -242,6 +242,92 @@ class AttnLabelDecode(BaseRecLabelDecode):
return idx
class RFLLabelDecode(BaseRecLabelDecode):
    """Convert between text-label and text-index for the RFL recognizer.

    The character table is wrapped with a leading "sos" token and a trailing
    "eos" token (see ``add_special_char``); both are stripped when decoding.
    """

    def __init__(self, character_dict_path=None, use_space_char=False,
                 **kwargs):
        super(RFLLabelDecode, self).__init__(character_dict_path,
                                             use_space_char)

    def add_special_char(self, dict_character):
        """Return ``dict_character`` with "sos" prepended and "eos" appended.

        Also records both marker strings on the instance so that
        ``get_beg_end_flag_idx`` can look them up later.
        """
        self.beg_str = "sos"
        self.end_str = "eos"
        return [self.beg_str] + dict_character + [self.end_str]

    def decode(self, text_index, text_prob=None, is_remove_duplicate=False):
        """Convert batches of text indices into (text, confidence) pairs.

        Args:
            text_index: indexable batch of index sequences; only ``len()``
                and ``[]`` access are used on it.
            text_prob: optional per-position probabilities laid out like
                ``text_index``; when omitted every kept position counts as 1.
            is_remove_duplicate: when True, a position equal to the one
                directly before it is skipped.

        Returns:
            A list with one ``(text, mean_confidence)`` tuple per sequence.
            The confidence is 0 when no character was decoded.
        """
        result_list = []
        beg_idx, end_idx = (int(tok) for tok in self.get_ignored_tokens())
        # Index-based loops are kept on purpose: they only require len() and
        # [] from the inputs, which is all the original code relied on.
        for batch_idx in range(len(text_index)):
            row = text_index[batch_idx]
            char_list = []
            conf_list = []
            for idx in range(len(row)):
                cur = int(row[idx])
                # The end token terminates the sequence. This test must come
                # before the begin-token skip; previously the end token was
                # also treated as "ignored", so the break was never reached
                # and anything after "eos" leaked into the decoded text.
                if cur == end_idx:
                    break
                if cur == beg_idx:
                    continue
                if is_remove_duplicate:
                    # only for predict
                    if idx > 0 and row[idx - 1] == row[idx]:
                        continue
                char_list.append(self.character[cur])
                if text_prob is not None:
                    conf_list.append(text_prob[batch_idx][idx])
                else:
                    conf_list.append(1)
            if not conf_list:
                # np.mean([]) would yield nan and emit a RuntimeWarning.
                conf_list = [0]
            text = ''.join(char_list)
            result_list.append((text, np.mean(conf_list).tolist()))
        return result_list

    def __call__(self, preds, label=None, *args, **kwargs):
        """Decode model output, and the label too when one is given.

        ``preds`` is unpacked as ``(cnt_pred, seq_pred)``.

        * When ``seq_pred`` is not None it is arg-maxed over axis 2 and
          decoded to text; the result is ``text`` or ``(text, label_text)``.
        * Otherwise each entry of ``cnt_pred`` is summed and rounded to an
          integer length; the result is ``cnt_length`` or
          ``(cnt_length, label_lengths)``.
          NOTE(review): presumably ``cnt_pred`` holds per-character counts --
          confirm against the model head.
        """
        cnt_pred, preds = preds
        if preds is not None:
            if isinstance(preds, paddle.Tensor):
                preds = preds.numpy()
            preds_idx = preds.argmax(axis=2)
            preds_prob = preds.max(axis=2)
            text = self.decode(preds_idx, preds_prob, is_remove_duplicate=False)
            if label is None:
                return text
            label = self.decode(label, is_remove_duplicate=False)
            return text, label

        cnt_length = [round(paddle.sum(lens).item()) for lens in cnt_pred]
        if label is None:
            return cnt_length
        label = self.decode(label, is_remove_duplicate=False)
        length = [len(res[0]) for res in label]
        return cnt_length, length

    def get_ignored_tokens(self):
        """Return ``[beg_idx, end_idx]``, the indices of "sos" and "eos"."""
        beg_idx = self.get_beg_end_flag_idx("beg")
        end_idx = self.get_beg_end_flag_idx("end")
        return [beg_idx, end_idx]

    def get_beg_end_flag_idx(self, beg_or_end):
        """Look up the index of the begin ("beg") or end ("end") marker.

        The index is read from ``self.dict`` (assumed to be populated by the
        base class -- TODO confirm) and returned wrapped in ``np.array``.

        Raises:
            AssertionError: if ``beg_or_end`` is neither "beg" nor "end".
        """
        if beg_or_end == "beg":
            return np.array(self.dict[self.beg_str])
        if beg_or_end == "end":
            return np.array(self.dict[self.end_str])
        # Raised explicitly (rather than via ``assert False``) so the check
        # is not stripped when Python runs with -O.
        raise AssertionError("unsupport type %s in get_beg_end_flag_idx" %
                             beg_or_end)
class SEEDLabelDecode(BaseRecLabelDecode):
""" Convert between text-label and text-index """
......
Global:
use_gpu: True
epoch_num: 6
log_smooth_window: 20
print_batch_step: 50
save_model_dir: ./output/rec/rec_resnet_rfl/
save_epoch_step: 1
# evaluation is run every 5000 iterations, starting from iteration 0
eval_batch_step: [0, 5000]
cal_metric_during_train: False
pretrained_model:
checkpoints:
save_inference_dir:
use_visualdl: False
infer_img: doc/imgs_words_en/word_10.png
# for data or label process
character_dict_path:
max_text_length: 25
infer_mode: False
use_space_char: False
save_res_path: ./output/rec/rec_resnet_rfl.txt
Optimizer:
name: AdamW
beta1: 0.9
beta2: 0.999
weight_decay: 0.0
clip_norm_global: 5.0
lr:
name: Piecewise
decay_epochs : [3, 4, 5]
values : [0.001, 0.0003, 0.00009, 0.000027]
Architecture:
model_type: rec
algorithm: RFL
in_channels: 1
Transform:
name: TPS
num_fiducial: 20
loc_lr: 1.0
model_name: large
Backbone:
name: ResNetRFL
use_cnt: True
use_seq: True
Neck:
name: RFAdaptor
use_v2s: True
use_s2v: True
Head:
name: RFLHead
in_channels: 512
hidden_size: 256
batch_max_legnth: 25
out_channels: 38
use_cnt: True
use_seq: True
Loss:
name: RFLLoss
PostProcess:
name: RFLLabelDecode
Metric:
name: RecMetric
main_indicator: acc
Train:
dataset:
name: SimpleDataSet
data_dir: ./train_data/ic15_data/
label_file_list: ["./train_data/ic15_data/rec_gt_train.txt"]
transforms:
- DecodeImage: # load image
img_mode: BGR
channel_first: False
- RFLLabelEncode: # Class handling label
- RFLRecResizeImg:
image_shape: [1, 32, 100]
interpolation: 2
- KeepKeys:
keep_keys: ['image', 'label', 'length', 'cnt_label'] # dataloader will return list in this order
loader:
shuffle: True
batch_size_per_card: 64
drop_last: True
num_workers: 8
Eval:
dataset:
name: SimpleDataSet
data_dir: ./train_data/ic15_data
label_file_list: ["./train_data/ic15_data/rec_gt_test.txt"]
transforms:
- DecodeImage: # load image
img_mode: BGR
channel_first: False
- RFLLabelEncode: # Class handling label
- RFLRecResizeImg:
image_shape: [1, 32, 100]
interpolation: 2
- KeepKeys:
keep_keys: ['image', 'label', 'length', 'cnt_label'] # dataloader will return list in this order
loader:
shuffle: False
drop_last: False
batch_size_per_card: 256
num_workers: 8
===========================train_params===========================
model_name:rec_resnet_rfl
python:python3.7
gpu_list:0|0,1
Global.use_gpu:True|True
Global.auto_cast:null
Global.epoch_num:lite_train_lite_infer=2|whole_train_whole_infer=300
Global.save_model_dir:./output/
Train.loader.batch_size_per_card:lite_train_lite_infer=16|whole_train_whole_infer=64
Global.pretrained_model:null
train_model_name:latest
train_infer_img_dir:./inference/rec_inference
null:null
##
trainer:norm_train
norm_train:tools/train.py -c test_tipc/configs/rec_resnet_rfl/rec_resnet_rfl.yml -o
pact_train:null
fpgm_train:null
distill_train:null
null:null
null:null
##
===========================eval_params===========================
eval:tools/eval.py -c test_tipc/configs/rec_resnet_rfl/rec_resnet_rfl.yml -o
null:null
##
===========================infer_params===========================
Global.save_inference_dir:./output/
Global.checkpoints:
norm_export:tools/export_model.py -c test_tipc/configs/rec_resnet_rfl/rec_resnet_rfl.yml -o
quant_export:null
fpgm_export:null
distill_export:null
export1:null
export2:null
##
train_model:./inference/rec_resnet_rfl_train/best_accuracy
infer_export:tools/export_model.py -c test_tipc/configs/rec_resnet_rfl/rec_resnet_rfl.yml -o
infer_quant:False
inference:tools/infer/predict_rec.py --rec_image_shape="1,32,100" --rec_algorithm="RFL" --min_subgraph_size=5
--use_gpu:True|False
--enable_mkldnn:False
--cpu_threads:6
--rec_batch_num:1
--use_tensorrt:False
--precision:fp32
--rec_model_dir:
--image_dir:./inference/rec_inference
--save_log_path:./test/output/
--benchmark:True
null:null
===========================infer_benchmark_params==========================
random_infer_input:[{float32,[1,32,100]}]
......@@ -99,7 +99,7 @@ def export_single_model(model,
]
# print([None, 3, 32, 128])
model = to_static(model, input_spec=other_shape)
elif arch_config["algorithm"] in ["NRTR", "SPIN"]:
elif arch_config["algorithm"] in ["NRTR", "SPIN", 'RFL']:
other_shape = [
paddle.static.InputSpec(
shape=[None, 1, 32, 100], dtype="float32"),
......
......@@ -100,6 +100,12 @@ class TextRecognizer(object):
"use_space_char": args.use_space_char,
"rm_symbol": True
}
elif self.rec_algorithm == 'RFL':
postprocess_params = {
'name': 'RFLLabelDecode',
"character_dict_path": None,
"use_space_char": args.use_space_char
}
self.postprocess_op = build_post_process(postprocess_params)
self.predictor, self.input_tensor, self.output_tensors, self.config = \
utility.create_predictor(args, 'rec', logger)
......@@ -143,6 +149,16 @@ class TextRecognizer(object):
else:
norm_img = norm_img.astype(np.float32) / 128. - 1.
return norm_img
elif self.rec_algorithm == 'RFL':
img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
resized_image = cv2.resize(
img, (imgW, imgH), interpolation=cv2.INTER_CUBIC)
resized_image = resized_image.astype('float32')
resized_image = resized_image / 255
resized_image = resized_image[np.newaxis, :]
resized_image -= 0.5
resized_image /= 0.5
return resized_image
assert imgC == img.shape[2]
imgW = int((imgH * max_wh_ratio))
......
......@@ -217,7 +217,7 @@ def train(config,
use_srn = config['Architecture']['algorithm'] == "SRN"
extra_input_models = [
"SRN", "NRTR", "SAR", "SEED", "SVTR", "SPIN", "VisionLAN",
"RobustScanner"
"RobustScanner", "RFL"
]
extra_input = False
if config['Architecture']['algorithm'] == 'Distillation':
......@@ -625,7 +625,7 @@ def preprocess(is_train=False):
'CLS', 'PGNet', 'Distillation', 'NRTR', 'TableAttn', 'SAR', 'PSE',
'SEED', 'SDMGR', 'LayoutXLM', 'LayoutLM', 'LayoutLMv2', 'PREN', 'FCE',
'SVTR', 'ViTSTR', 'ABINet', 'DB++', 'TableMaster', 'SPIN', 'VisionLAN',
'Gestalt', 'SLANet', 'RobustScanner', 'CT'
'Gestalt', 'SLANet', 'RobustScanner', 'CT', 'RFL'
]
if use_xpu:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册