未验证 提交 2a98d40b 编写于 作者: T topduke 提交者: GitHub

Add v4rec hgnet (#9768)

* v4rec code

* v4rec add nrtrloss

* Add V4rec backbone file

* Add V4Rec config file.

* Fix V4rec reparameters when export_model

* convert lvnetv3

* fix codestyle

* fix infer_rec v4rec

* add v4rec hgnet

* add v4rec hgnet config

* add svtr_hgnet

* fix bugs in infer_rec and hgnet
上级 ce657457
Global:
debug: false
use_gpu: true
epoch_num: 200
log_smooth_window: 20
print_batch_step: 10
save_model_dir: ./output/rec_ppocr_v4_hgnet
save_epoch_step: 10
eval_batch_step: [0, 2000]
cal_metric_during_train: true
pretrained_model:
checkpoints:
save_inference_dir:
use_visualdl: false
infer_img: doc/imgs_words/ch/word_1.jpg
character_dict_path: ppocr/utils/ppocr_keys_v1.txt
max_text_length: &max_text_length 25
infer_mode: false
use_space_char: true
distributed: true
save_res_path: ./output/rec/predicts_ppocrv3.txt
Optimizer:
name: Adam
beta1: 0.9
beta2: 0.999
lr:
name: Cosine
learning_rate: 0.001
warmup_epoch: 5
regularizer:
name: L2
factor: 3.0e-05
Architecture:
model_type: rec
algorithm: SVTR_HGNet
Transform:
Backbone:
name: PPHGNet_small
Head:
name: MultiHead
head_list:
- CTCHead:
Neck:
name: svtr
dims: 120
depth: 2
hidden_dims: 120
kernel_size: [1, 3]
use_guide: True
Head:
fc_decay: 0.00001
- NRTRHead:
nrtr_dim: 384
max_text_length: *max_text_length
Loss:
name: MultiLoss
loss_config_list:
- CTCLoss:
- NRTRLoss:
PostProcess:
name: CTCLabelDecode
Metric:
name: RecMetric
main_indicator: acc
Train:
dataset:
name: SimpleDataSet
data_dir: ./train_data/
ext_op_transform_idx: 1
label_file_list:
- ./train_data/train_list.txt
transforms:
- DecodeImage:
img_mode: BGR
channel_first: false
- RecConAug:
prob: 0.5
ext_data_num: 2
image_shape: [48, 320, 3]
max_text_length: *max_text_length
- RecAug:
- MultiLabelEncode:
gtc_encode: NRTRLabelEncode
- RecResizeImg:
image_shape: [3, 48, 320]
- KeepKeys:
keep_keys:
- image
- label_ctc
- label_gtc
- length
- valid_ratio
loader:
shuffle: true
batch_size_per_card: 128
drop_last: true
num_workers: 4
Eval:
dataset:
name: SimpleDataSet
data_dir: ./train_data
label_file_list:
- ./train_data/val_list.txt
transforms:
- DecodeImage:
img_mode: BGR
channel_first: false
- MultiLabelEncode:
gtc_encode: NRTRLabelEncode
- RecResizeImg:
image_shape: [3, 48, 320]
- KeepKeys:
keep_keys:
- image
- label_ctc
- label_gtc
- length
- valid_ratio
loader:
shuffle: false
drop_last: false
batch_size_per_card: 128
num_workers: 4
......@@ -46,11 +46,12 @@ def build_backbone(config, model_type):
from .rec_densenet import DenseNet
from .rec_shallow_cnn import ShallowCNN
from .rec_lcnetv3 import LCNetv3
from .rec_hgnet import PPHGNet_small
support_dict = [
'MobileNetV1Enhance', 'MobileNetV3', 'ResNet', 'ResNetFPN', 'MTB',
'ResNet31', 'ResNet45', 'ResNet_ASTER', 'MicroNet',
'EfficientNetb3_PREN', 'SVTRNet', 'ViTSTR', 'ResNet32', 'ResNetRFL',
'DenseNet', 'ShallowCNN', 'LCNetv3'
'DenseNet', 'ShallowCNN', 'LCNetv3', 'PPHGNet_small'
]
elif model_type == 'e2e':
from .e2e_resnet_vd_pg import ResNet
......
# copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from paddle.nn.initializer import KaimingNormal, Constant
from paddle.nn import Conv2D, BatchNorm2D, ReLU, AdaptiveAvgPool2D, MaxPool2D
from paddle.regularizer import L2Decay
from paddle import ParamAttr
kaiming_normal_ = KaimingNormal()
zeros_ = Constant(value=0.)
ones_ = Constant(value=1.)
class ConvBNAct(nn.Layer):
def __init__(self,
in_channels,
out_channels,
kernel_size,
stride,
groups=1,
use_act=True):
super().__init__()
self.use_act = use_act
self.conv = Conv2D(
in_channels,
out_channels,
kernel_size,
stride,
padding=(kernel_size - 1) // 2,
groups=groups,
bias_attr=False)
self.bn = BatchNorm2D(
out_channels,
weight_attr=ParamAttr(regularizer=L2Decay(0.0)),
bias_attr=ParamAttr(regularizer=L2Decay(0.0)))
if self.use_act:
self.act = ReLU()
def forward(self, x):
x = self.conv(x)
x = self.bn(x)
if self.use_act:
x = self.act(x)
return x
class ESEModule(nn.Layer):
def __init__(self, channels):
super().__init__()
self.avg_pool = AdaptiveAvgPool2D(1)
self.conv = Conv2D(
in_channels=channels,
out_channels=channels,
kernel_size=1,
stride=1,
padding=0)
self.sigmoid = nn.Sigmoid()
def forward(self, x):
identity = x
x = self.avg_pool(x)
x = self.conv(x)
x = self.sigmoid(x)
return paddle.multiply(x=identity, y=x)
class HG_Block(nn.Layer):
def __init__(
self,
in_channels,
mid_channels,
out_channels,
layer_num,
identity=False, ):
super().__init__()
self.identity = identity
self.layers = nn.LayerList()
self.layers.append(
ConvBNAct(
in_channels=in_channels,
out_channels=mid_channels,
kernel_size=3,
stride=1))
for _ in range(layer_num - 1):
self.layers.append(
ConvBNAct(
in_channels=mid_channels,
out_channels=mid_channels,
kernel_size=3,
stride=1))
# feature aggregation
total_channels = in_channels + layer_num * mid_channels
self.aggregation_conv = ConvBNAct(
in_channels=total_channels,
out_channels=out_channels,
kernel_size=1,
stride=1)
self.att = ESEModule(out_channels)
def forward(self, x):
identity = x
output = []
output.append(x)
for layer in self.layers:
x = layer(x)
output.append(x)
x = paddle.concat(output, axis=1)
x = self.aggregation_conv(x)
x = self.att(x)
if self.identity:
x += identity
return x
class HG_Stage(nn.Layer):
def __init__(self,
in_channels,
mid_channels,
out_channels,
block_num,
layer_num,
downsample=True,
stride=[2, 1]):
super().__init__()
self.downsample = downsample
if downsample:
self.downsample = ConvBNAct(
in_channels=in_channels,
out_channels=in_channels,
kernel_size=3,
stride=stride,
groups=in_channels,
use_act=False)
blocks_list = []
blocks_list.append(
HG_Block(
in_channels,
mid_channels,
out_channels,
layer_num,
identity=False))
for _ in range(block_num - 1):
blocks_list.append(
HG_Block(
out_channels,
mid_channels,
out_channels,
layer_num,
identity=True))
self.blocks = nn.Sequential(*blocks_list)
def forward(self, x):
if self.downsample:
x = self.downsample(x)
x = self.blocks(x)
return x
class PPHGNet(nn.Layer):
"""
PPHGNet
Args:
stem_channels: list. Stem channel list of PPHGNet.
stage_config: dict. The configuration of each stage of PPHGNet. such as the number of channels, stride, etc.
layer_num: int. Number of layers of HG_Block.
use_last_conv: boolean. Whether to use a 1x1 convolutional layer before the classification layer.
class_expand: int=2048. Number of channels for the last 1x1 convolutional layer.
dropout_prob: float. Parameters of dropout, 0.0 means dropout is not used.
class_num: int=1000. The number of classes.
Returns:
model: nn.Layer. Specific PPHGNet model depends on args.
"""
def __init__(self, stem_channels, stage_config, layer_num, in_channels=3):
super().__init__()
# stem
stem_channels.insert(0, in_channels)
self.stem = nn.Sequential(* [
ConvBNAct(
in_channels=stem_channels[i],
out_channels=stem_channels[i + 1],
kernel_size=3,
stride=2 if i == 0 else 1) for i in range(
len(stem_channels) - 1)
])
# stages
self.stages = nn.LayerList()
for k in stage_config:
in_channels, mid_channels, out_channels, block_num, downsample, stride = stage_config[
k]
self.stages.append(
HG_Stage(in_channels, mid_channels, out_channels, block_num,
layer_num, downsample, stride))
self.out_channels = stage_config["stage4"][2]
self._init_weights()
def _init_weights(self):
for m in self.sublayers():
if isinstance(m, nn.Conv2D):
kaiming_normal_(m.weight)
elif isinstance(m, (nn.BatchNorm2D)):
ones_(m.weight)
zeros_(m.bias)
elif isinstance(m, nn.Linear):
zeros_(m.bias)
def forward(self, x):
x = self.stem(x)
for stage in self.stages:
x = stage(x)
if self.training:
x = F.adaptive_avg_pool2d(x, [1, 40])
else:
x = F.avg_pool2d(x, [3, 2])
return x
def PPHGNet_tiny(pretrained=False, use_ssld=False, **kwargs):
"""
PPHGNet_tiny
Args:
pretrained: bool=False or str. If `True` load pretrained parameters, `False` otherwise.
If str, means the path of the pretrained model.
use_ssld: bool=False. Whether using distillation pretrained model when pretrained=True.
Returns:
model: nn.Layer. Specific `PPHGNet_tiny` model depends on args.
"""
stage_config = {
# in_channels, mid_channels, out_channels, blocks, downsample
"stage1": [96, 96, 224, 1, False, [2, 1]],
"stage2": [224, 128, 448, 1, True, [1, 2]],
"stage3": [448, 160, 512, 2, True, [2, 1]],
"stage4": [512, 192, 768, 1, True, [2, 1]],
}
model = PPHGNet(
stem_channels=[48, 48, 96],
stage_config=stage_config,
layer_num=5,
**kwargs)
return model
def PPHGNet_small(pretrained=False, use_ssld=False, **kwargs):
"""
PPHGNet_small
Args:
pretrained: bool=False or str. If `True` load pretrained parameters, `False` otherwise.
If str, means the path of the pretrained model.
use_ssld: bool=False. Whether using distillation pretrained model when pretrained=True.
Returns:
model: nn.Layer. Specific `PPHGNet_small` model depends on args.
"""
stage_config = {
# in_channels, mid_channels, out_channels, blocks, downsample
"stage1": [128, 128, 256, 1, True, [2, 1]],
"stage2": [256, 160, 512, 1, True, [1, 2]],
"stage3": [512, 192, 768, 2, True, [2, 1]],
"stage4": [768, 224, 1024, 1, True, [2, 1]],
}
model = PPHGNet(
stem_channels=[64, 64, 128],
stage_config=stage_config,
layer_num=6,
**kwargs)
return model
def PPHGNet_base(pretrained=False, use_ssld=True, **kwargs):
"""
PPHGNet_base
Args:
pretrained: bool=False or str. If `True` load pretrained parameters, `False` otherwise.
If str, means the path of the pretrained model.
use_ssld: bool=False. Whether using distillation pretrained model when pretrained=True.
Returns:
model: nn.Layer. Specific `PPHGNet_base` model depends on args.
"""
stage_config = {
# in_channels, mid_channels, out_channels, blocks, downsample
"stage1": [160, 192, 320, 1, False, [2, 1]],
"stage2": [320, 224, 640, 2, True, [1, 2]],
"stage3": [640, 256, 960, 3, True, [2, 1]],
"stage4": [960, 288, 1280, 2, True, [2, 1]],
}
model = PPHGNet(
stem_channels=[96, 96, 160],
stage_config=stage_config,
layer_num=7,
dropout_prob=0.2,
**kwargs)
return model
......@@ -83,7 +83,7 @@ def main():
model = build_model(config['Architecture'])
extra_input_models = [
"SRN", "NRTR", "SAR", "SEED", "SVTR", "SVTR_LCNet", "VisionLAN",
"RobustScanner"
"RobustScanner", "SVTR_HGNet"
]
extra_input = False
if config['Architecture']['algorithm'] == 'Distillation':
......
......@@ -62,7 +62,7 @@ def export_single_model(model,
shape=[None], dtype="float32")]
]
model = to_static(model, input_spec=other_shape)
elif arch_config["algorithm"] == "SVTR_LCNet":
elif arch_config["algorithm"] in ["SVTR_LCNet", "SVTR_HGNet"]:
other_shape = [
paddle.static.InputSpec(
shape=[None, 3, 48, -1], dtype="float32"),
......
......@@ -291,7 +291,9 @@ def create_predictor(args, mode, logger):
def get_output_tensors(args, mode, predictor):
output_names = predictor.get_output_names()
output_tensors = []
if mode == "rec" and args.rec_algorithm in ["CRNN", "SVTR_LCNet"]:
if mode == "rec" and args.rec_algorithm in [
"CRNN", "SVTR_LCNet", "SVTR_HGNet"
]:
output_name = 'softmax_0.tmp_0'
if output_name in output_names:
return [predictor.get_output_handle(output_name)]
......
......@@ -68,9 +68,6 @@ def main():
else:
config["Architecture"]["Models"][key]["Head"][
"out_channels"] = char_num
# just one final tensor needs to exported for inference
config["Architecture"]["Models"][key][
"return_all_feats"] = False
elif config['Architecture']['Head'][
'name'] == 'MultiHead': # multi head
out_channels_list = {}
......@@ -86,7 +83,6 @@ def main():
'out_channels_list'] = out_channels_list
else: # base rec model
config["Architecture"]["Head"]["out_channels"] = char_num
model = build_model(config['Architecture'])
load_model(config, model)
......
......@@ -230,7 +230,7 @@ def train(config,
use_srn = config['Architecture']['algorithm'] == "SRN"
extra_input_models = [
"SRN", "NRTR", "SAR", "SEED", "SVTR", "SVTR_LCNet", "SPIN", "VisionLAN",
"RobustScanner", "RFL", 'DRRG', 'SATRN'
"RobustScanner", "RFL", 'DRRG', 'SATRN', 'SVTR_HGNet'
]
extra_input = False
if config['Architecture']['algorithm'] == 'Distillation':
......@@ -654,7 +654,7 @@ def preprocess(is_train=False):
'SEED', 'SDMGR', 'LayoutXLM', 'LayoutLM', 'LayoutLMv2', 'PREN', 'FCE',
'SVTR', 'SVTR_LCNet', 'ViTSTR', 'ABINet', 'DB++', 'TableMaster', 'SPIN',
'VisionLAN', 'Gestalt', 'SLANet', 'RobustScanner', 'CT', 'RFL', 'DRRG',
'CAN', 'Telescope', 'SATRN'
'CAN', 'Telescope', 'SATRN', 'SVTR_HGNet'
]
if use_xpu:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册