diff --git a/configs/rec/PP-OCRv4/ch_PP-OCRv4_rec_hgnet.yml b/configs/rec/PP-OCRv4/ch_PP-OCRv4_rec_hgnet.yml new file mode 100644 index 0000000000000000000000000000000000000000..e5ed916a2fc89c2bcf2a37815c197aaecd78c3de --- /dev/null +++ b/configs/rec/PP-OCRv4/ch_PP-OCRv4_rec_hgnet.yml @@ -0,0 +1,131 @@ +Global: + debug: false + use_gpu: true + epoch_num: 200 + log_smooth_window: 20 + print_batch_step: 10 + save_model_dir: ./output/rec_ppocr_v4_hgnet + save_epoch_step: 10 + eval_batch_step: [0, 2000] + cal_metric_during_train: true + pretrained_model: + checkpoints: + save_inference_dir: + use_visualdl: false + infer_img: doc/imgs_words/ch/word_1.jpg + character_dict_path: ppocr/utils/ppocr_keys_v1.txt + max_text_length: &max_text_length 25 + infer_mode: false + use_space_char: true + distributed: true + save_res_path: ./output/rec/predicts_ppocrv3.txt + + +Optimizer: + name: Adam + beta1: 0.9 + beta2: 0.999 + lr: + name: Cosine + learning_rate: 0.001 + warmup_epoch: 5 + regularizer: + name: L2 + factor: 3.0e-05 + + +Architecture: + model_type: rec + algorithm: SVTR_HGNet + Transform: + Backbone: + name: PPHGNet_small + Head: + name: MultiHead + head_list: + - CTCHead: + Neck: + name: svtr + dims: 120 + depth: 2 + hidden_dims: 120 + kernel_size: [1, 3] + use_guide: True + Head: + fc_decay: 0.00001 + - NRTRHead: + nrtr_dim: 384 + max_text_length: *max_text_length + +Loss: + name: MultiLoss + loss_config_list: + - CTCLoss: + - NRTRLoss: + +PostProcess: + name: CTCLabelDecode + +Metric: + name: RecMetric + main_indicator: acc + +Train: + dataset: + name: SimpleDataSet + data_dir: ./train_data/ + ext_op_transform_idx: 1 + label_file_list: + - ./train_data/train_list.txt + transforms: + - DecodeImage: + img_mode: BGR + channel_first: false + - RecConAug: + prob: 0.5 + ext_data_num: 2 + image_shape: [48, 320, 3] + max_text_length: *max_text_length + - RecAug: + - MultiLabelEncode: + gtc_encode: NRTRLabelEncode + - RecResizeImg: + image_shape: [3, 48, 320] + - KeepKeys: + keep_keys: + - image + - label_ctc + - label_gtc + - length + - valid_ratio + loader: + shuffle: true + batch_size_per_card: 128 + drop_last: true + num_workers: 4 +Eval: + dataset: + name: SimpleDataSet + data_dir: ./train_data + label_file_list: + - ./train_data/val_list.txt + transforms: + - DecodeImage: + img_mode: BGR + channel_first: false + - MultiLabelEncode: + gtc_encode: NRTRLabelEncode + - RecResizeImg: + image_shape: [3, 48, 320] + - KeepKeys: + keep_keys: + - image + - label_ctc + - label_gtc + - length + - valid_ratio + loader: + shuffle: false + drop_last: false + batch_size_per_card: 128 + num_workers: 4 diff --git a/ppocr/modeling/backbones/__init__.py b/ppocr/modeling/backbones/__init__.py index 63aef33019ec04a5f2a0f6027e189be2864fd5b5..2e8f4c05c8dd08fe21eaed22f55c4c81c6131b5f 100755 --- a/ppocr/modeling/backbones/__init__.py +++ b/ppocr/modeling/backbones/__init__.py @@ -46,11 +46,12 @@ def build_backbone(config, model_type): from .rec_densenet import DenseNet from .rec_shallow_cnn import ShallowCNN from .rec_lcnetv3 import LCNetv3 + from .rec_hgnet import PPHGNet_small support_dict = [ 'MobileNetV1Enhance', 'MobileNetV3', 'ResNet', 'ResNetFPN', 'MTB', 'ResNet31', 'ResNet45', 'ResNet_ASTER', 'MicroNet', 'EfficientNetb3_PREN', 'SVTRNet', 'ViTSTR', 'ResNet32', 'ResNetRFL', - 'DenseNet', 'ShallowCNN', 'LCNetv3' + 'DenseNet', 'ShallowCNN', 'LCNetv3', 'PPHGNet_small' ] elif model_type == 'e2e': from .e2e_resnet_vd_pg import ResNet diff --git a/ppocr/modeling/backbones/rec_hgnet.py 
b/ppocr/modeling/backbones/rec_hgnet.py new file mode 100644 index 0000000000000000000000000000000000000000..c4f2d7bdee1cd5b9fc37e98f20f8b7637901ef87 --- /dev/null +++ b/ppocr/modeling/backbones/rec_hgnet.py @@ -0,0 +1,314 @@ +# copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import paddle +import paddle.nn as nn +import paddle.nn.functional as F +from paddle.nn.initializer import KaimingNormal, Constant +from paddle.nn import Conv2D, BatchNorm2D, ReLU, AdaptiveAvgPool2D, MaxPool2D +from paddle.regularizer import L2Decay +from paddle import ParamAttr + +kaiming_normal_ = KaimingNormal() +zeros_ = Constant(value=0.) +ones_ = Constant(value=1.) + + +class ConvBNAct(nn.Layer): + def __init__(self, + in_channels, + out_channels, + kernel_size, + stride, + groups=1, + use_act=True): + super().__init__() + self.use_act = use_act + self.conv = Conv2D( + in_channels, + out_channels, + kernel_size, + stride, + padding=(kernel_size - 1) // 2, + groups=groups, + bias_attr=False) + self.bn = BatchNorm2D( + out_channels, + weight_attr=ParamAttr(regularizer=L2Decay(0.0)), + bias_attr=ParamAttr(regularizer=L2Decay(0.0))) + if self.use_act: + self.act = ReLU() + + def forward(self, x): + x = self.conv(x) + x = self.bn(x) + if self.use_act: + x = self.act(x) + return x + + +class ESEModule(nn.Layer): + def __init__(self, channels): + super().__init__() + self.avg_pool = AdaptiveAvgPool2D(1) + self.conv = Conv2D( + in_channels=channels, + out_channels=channels, + kernel_size=1, + stride=1, + padding=0) + self.sigmoid = nn.Sigmoid() + + def forward(self, x): + identity = x + x = self.avg_pool(x) + x = self.conv(x) + x = self.sigmoid(x) + return paddle.multiply(x=identity, y=x) + + +class HG_Block(nn.Layer): + def __init__( + self, + in_channels, + mid_channels, + out_channels, + layer_num, + identity=False, ): + super().__init__() + self.identity = identity + + self.layers = nn.LayerList() + self.layers.append( + ConvBNAct( + in_channels=in_channels, + out_channels=mid_channels, + kernel_size=3, + stride=1)) + for _ in range(layer_num - 1): + self.layers.append( + ConvBNAct( + in_channels=mid_channels, + out_channels=mid_channels, + kernel_size=3, + stride=1)) + + # feature aggregation + total_channels = in_channels + layer_num * mid_channels + self.aggregation_conv = ConvBNAct( + in_channels=total_channels, + out_channels=out_channels, + kernel_size=1, + stride=1) + self.att = ESEModule(out_channels) + + def forward(self, x): + identity = x + output = [] + output.append(x) + for layer in self.layers: + x = layer(x) + output.append(x) + x = paddle.concat(output, axis=1) + x = self.aggregation_conv(x) + x = self.att(x) + if self.identity: + x += identity + return x + + +class HG_Stage(nn.Layer): + def __init__(self, + in_channels, + mid_channels, + out_channels, + block_num, + layer_num, + downsample=True, + stride=[2, 1]): + super().__init__() + self.downsample = downsample + if downsample: + self.downsample = ConvBNAct( + in_channels=in_channels, + 
out_channels=in_channels, + kernel_size=3, + stride=stride, + groups=in_channels, + use_act=False) + + blocks_list = [] + blocks_list.append( + HG_Block( + in_channels, + mid_channels, + out_channels, + layer_num, + identity=False)) + for _ in range(block_num - 1): + blocks_list.append( + HG_Block( + out_channels, + mid_channels, + out_channels, + layer_num, + identity=True)) + self.blocks = nn.Sequential(*blocks_list) + + def forward(self, x): + if self.downsample: + x = self.downsample(x) + x = self.blocks(x) + return x + + +class PPHGNet(nn.Layer): + """ + PPHGNet + Args: + stem_channels: list. Stem channel list of PPHGNet. + stage_config: dict. The configuration of each stage of PPHGNet. such as the number of channels, stride, etc. + layer_num: int. Number of layers of HG_Block. + use_last_conv: boolean. Whether to use a 1x1 convolutional layer before the classification layer. + class_expand: int=2048. Number of channels for the last 1x1 convolutional layer. + dropout_prob: float. Parameters of dropout, 0.0 means dropout is not used. + class_num: int=1000. The number of classes. + Returns: + model: nn.Layer. Specific PPHGNet model depends on args. + """ + + def __init__(self, stem_channels, stage_config, layer_num, in_channels=3): + super().__init__() + + # stem + stem_channels.insert(0, in_channels) + self.stem = nn.Sequential(* [ + ConvBNAct( + in_channels=stem_channels[i], + out_channels=stem_channels[i + 1], + kernel_size=3, + stride=2 if i == 0 else 1) for i in range( + len(stem_channels) - 1) + ]) + + # stages + self.stages = nn.LayerList() + for k in stage_config: + in_channels, mid_channels, out_channels, block_num, downsample, stride = stage_config[ + k] + self.stages.append( + HG_Stage(in_channels, mid_channels, out_channels, block_num, + layer_num, downsample, stride)) + + self.out_channels = stage_config["stage4"][2] + self._init_weights() + + def _init_weights(self): + for m in self.sublayers(): + if isinstance(m, nn.Conv2D): + kaiming_normal_(m.weight) + elif isinstance(m, (nn.BatchNorm2D)): + ones_(m.weight) + zeros_(m.bias) + elif isinstance(m, nn.Linear): + zeros_(m.bias) + + def forward(self, x): + x = self.stem(x) + for stage in self.stages: + x = stage(x) + if self.training: + x = F.adaptive_avg_pool2d(x, [1, 40]) + else: + x = F.avg_pool2d(x, [3, 2]) + return x + + +def PPHGNet_tiny(pretrained=False, use_ssld=False, **kwargs): + """ + PPHGNet_tiny + Args: + pretrained: bool=False or str. If `True` load pretrained parameters, `False` otherwise. + If str, means the path of the pretrained model. + use_ssld: bool=False. Whether using distillation pretrained model when pretrained=True. + Returns: + model: nn.Layer. Specific `PPHGNet_tiny` model depends on args. + """ + stage_config = { + # in_channels, mid_channels, out_channels, blocks, downsample + "stage1": [96, 96, 224, 1, False, [2, 1]], + "stage2": [224, 128, 448, 1, True, [1, 2]], + "stage3": [448, 160, 512, 2, True, [2, 1]], + "stage4": [512, 192, 768, 1, True, [2, 1]], + } + + model = PPHGNet( + stem_channels=[48, 48, 96], + stage_config=stage_config, + layer_num=5, + **kwargs) + return model + + +def PPHGNet_small(pretrained=False, use_ssld=False, **kwargs): + """ + PPHGNet_small + Args: + pretrained: bool=False or str. If `True` load pretrained parameters, `False` otherwise. + If str, means the path of the pretrained model. + use_ssld: bool=False. Whether using distillation pretrained model when pretrained=True. + Returns: + model: nn.Layer. Specific `PPHGNet_small` model depends on args. 
+ """ + stage_config = { + # in_channels, mid_channels, out_channels, blocks, downsample + "stage1": [128, 128, 256, 1, True, [2, 1]], + "stage2": [256, 160, 512, 1, True, [1, 2]], + "stage3": [512, 192, 768, 2, True, [2, 1]], + "stage4": [768, 224, 1024, 1, True, [2, 1]], + } + + model = PPHGNet( + stem_channels=[64, 64, 128], + stage_config=stage_config, + layer_num=6, + **kwargs) + return model + + +def PPHGNet_base(pretrained=False, use_ssld=True, **kwargs): + """ + PPHGNet_base + Args: + pretrained: bool=False or str. If `True` load pretrained parameters, `False` otherwise. + If str, means the path of the pretrained model. + use_ssld: bool=False. Whether using distillation pretrained model when pretrained=True. + Returns: + model: nn.Layer. Specific `PPHGNet_base` model depends on args. + """ + stage_config = { + # in_channels, mid_channels, out_channels, blocks, downsample + "stage1": [160, 192, 320, 1, False, [2, 1]], + "stage2": [320, 224, 640, 2, True, [1, 2]], + "stage3": [640, 256, 960, 3, True, [2, 1]], + "stage4": [960, 288, 1280, 2, True, [2, 1]], + } + + model = PPHGNet( + stem_channels=[96, 96, 160], + stage_config=stage_config, + layer_num=7, + dropout_prob=0.2, + **kwargs) + return model diff --git a/tools/eval.py b/tools/eval.py index d30ca57d8242080944b2563be60175a427893329..75b6941d27874b2cfb50f9f97a80457ae1d6b310 100755 --- a/tools/eval.py +++ b/tools/eval.py @@ -83,7 +83,7 @@ def main(): model = build_model(config['Architecture']) extra_input_models = [ "SRN", "NRTR", "SAR", "SEED", "SVTR", "SVTR_LCNet", "VisionLAN", - "RobustScanner" + "RobustScanner", "SVTR_HGNet" ] extra_input = False if config['Architecture']['algorithm'] == 'Distillation': diff --git a/tools/export_model.py b/tools/export_model.py index 76f5db49a1395e9fce48b225461afd4cda798ee3..45c9bba650eef3559e8421b1a294db0e197ff96f 100755 --- a/tools/export_model.py +++ b/tools/export_model.py @@ -62,7 +62,7 @@ def export_single_model(model, shape=[None], dtype="float32")] ] model = to_static(model, input_spec=other_shape) - elif arch_config["algorithm"] == "SVTR_LCNet": + elif arch_config["algorithm"] in ["SVTR_LCNet", "SVTR_HGNet"]: other_shape = [ paddle.static.InputSpec( shape=[None, 3, 48, -1], dtype="float32"), diff --git a/tools/infer/utility.py b/tools/infer/utility.py index 959373cd79315bc16521beb70c66b6c3556e6963..53a7d5ee54897e0503e473ca8ec04453cfb70e56 100644 --- a/tools/infer/utility.py +++ b/tools/infer/utility.py @@ -291,7 +291,9 @@ def create_predictor(args, mode, logger): def get_output_tensors(args, mode, predictor): output_names = predictor.get_output_names() output_tensors = [] - if mode == "rec" and args.rec_algorithm in ["CRNN", "SVTR_LCNet"]: + if mode == "rec" and args.rec_algorithm in [ + "CRNN", "SVTR_LCNet", "SVTR_HGNet" + ]: output_name = 'softmax_0.tmp_0' if output_name in output_names: return [predictor.get_output_handle(output_name)] diff --git a/tools/infer_rec.py b/tools/infer_rec.py index 14eeb46bce04aa9ea64995f0d29a748c1483edf7..80986ccdebb5b0e91cb843933c1a0ee6914ca671 100755 --- a/tools/infer_rec.py +++ b/tools/infer_rec.py @@ -68,9 +68,6 @@ def main(): else: config["Architecture"]["Models"][key]["Head"][ "out_channels"] = char_num - # just one final tensor needs to exported for inference - config["Architecture"]["Models"][key][ - "return_all_feats"] = False elif config['Architecture']['Head'][ 'name'] == 'MultiHead': # multi head out_channels_list = {} @@ -86,7 +83,6 @@ def main(): 'out_channels_list'] = out_channels_list else: # base rec model 
config["Architecture"]["Head"]["out_channels"] = char_num - model = build_model(config['Architecture']) load_model(config, model) diff --git a/tools/program.py b/tools/program.py index 2761e1273392c953fdbe6b1e0cc8dca92533b6df..53983b0b27eff7cc79af1742e6d704f9d54267c8 100755 --- a/tools/program.py +++ b/tools/program.py @@ -230,7 +230,7 @@ def train(config, use_srn = config['Architecture']['algorithm'] == "SRN" extra_input_models = [ "SRN", "NRTR", "SAR", "SEED", "SVTR", "SVTR_LCNet", "SPIN", "VisionLAN", - "RobustScanner", "RFL", 'DRRG', 'SATRN' + "RobustScanner", "RFL", 'DRRG', 'SATRN', 'SVTR_HGNet' ] extra_input = False if config['Architecture']['algorithm'] == 'Distillation': @@ -654,7 +654,7 @@ def preprocess(is_train=False): 'SEED', 'SDMGR', 'LayoutXLM', 'LayoutLM', 'LayoutLMv2', 'PREN', 'FCE', 'SVTR', 'SVTR_LCNet', 'ViTSTR', 'ABINet', 'DB++', 'TableMaster', 'SPIN', 'VisionLAN', 'Gestalt', 'SLANet', 'RobustScanner', 'CT', 'RFL', 'DRRG', - 'CAN', 'Telescope', 'SATRN' + 'CAN', 'Telescope', 'SATRN', 'SVTR_HGNet' ] if use_xpu:
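A quick way to sanity-check the new backbone after applying this patch is to push a dummy recognition-sized input through it. The sketch below is an illustrative smoke test, not part of the patch: it assumes PaddlePaddle is installed and the PaddleOCR repository root is on PYTHONPATH, and it reuses the 3x48x320 input shape from the RecResizeImg step of the config above.

    import paddle
    from ppocr.modeling.backbones.rec_hgnet import PPHGNet_small

    # Build the backbone added in rec_hgnet.py; in_channels defaults to 3.
    backbone = PPHGNet_small()
    backbone.eval()  # eval branch pools with F.avg_pool2d(x, [3, 2])

    # Dummy batch matching the config's RecResizeImg image_shape: [3, 48, 320].
    x = paddle.randn([1, 3, 48, 320])
    y = backbone(x)

    # The stem and stage strides reduce 48x320 to 3x80, and the final pooling
    # yields a 1x40 map, so the features handed to the head should be
    # [1, 1024, 1, 40], with backbone.out_channels == 1024 (stage4 output).
    print(y.shape, backbone.out_channels)

If the printed shape is [1, 1024, 1, 40] and out_channels is 1024, the SVTR_HGNet config above should connect the backbone to the MultiHead (CTC + NRTR) without shape mismatches.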