From e333d629ed204c738aa07dbf9e305464a5122270 Mon Sep 17 00:00:00 2001 From: wangguanzhong Date: Fri, 15 Jan 2021 16:44:30 +0800 Subject: [PATCH] Add faster rcnn fpn hrnet 1x & 2x (#2055) * add hrnet, test=develop * add comment, test=dygraph * move config file to hrnet, test=dygraph * add 2x model * add hrnet README * update latest config structure, test=dygraph --- dygraph/configs/hrnet/README.md | 34 + .../hrnet/_base_/faster_rcnn_hrnetv2p_w18.yml | 90 +++ .../faster_rcnn_hrnetv2p_w18_1x_coco.yml | 20 + .../faster_rcnn_hrnetv2p_w18_2x_coco.yml | 23 + dygraph/ppdet/engine/trainer.py | 2 +- dygraph/ppdet/modeling/backbones/__init__.py | 2 + dygraph/ppdet/modeling/backbones/darknet.py | 14 + dygraph/ppdet/modeling/backbones/hrnet.py | 668 ++++++++++++++++++ dygraph/ppdet/modeling/necks/__init__.py | 2 + dygraph/ppdet/modeling/necks/hrfpn.py | 110 +++ dygraph/ppdet/utils/checkpoint.py | 5 +- 11 files changed, 968 insertions(+), 2 deletions(-) create mode 100644 dygraph/configs/hrnet/README.md create mode 100644 dygraph/configs/hrnet/_base_/faster_rcnn_hrnetv2p_w18.yml create mode 100644 dygraph/configs/hrnet/faster_rcnn_hrnetv2p_w18_1x_coco.yml create mode 100644 dygraph/configs/hrnet/faster_rcnn_hrnetv2p_w18_2x_coco.yml create mode 100644 dygraph/ppdet/modeling/backbones/hrnet.py create mode 100644 dygraph/ppdet/modeling/necks/hrfpn.py diff --git a/dygraph/configs/hrnet/README.md b/dygraph/configs/hrnet/README.md new file mode 100644 index 000000000..5b8cb5cab --- /dev/null +++ b/dygraph/configs/hrnet/README.md @@ -0,0 +1,34 @@ +# High-resolution networks (HRNets) for object detection + +## Introduction + +- Deep High-Resolution Representation Learning for Human Pose Estimation: [https://arxiv.org/abs/1902.09212](https://arxiv.org/abs/1902.09212) + +``` +@inproceedings{SunXLW19, + title={Deep High-Resolution Representation Learning for Human Pose Estimation}, + author={Ke Sun and Bin Xiao and Dong Liu and Jingdong Wang}, + booktitle={CVPR}, + year={2019} +} +``` + +- High-Resolution Representations for Labeling Pixels and Regions: [https://arxiv.org/abs/1904.04514](https://arxiv.org/abs/1904.04514) + +``` +@article{SunZJCXLMWLW19, + title={High-Resolution Representations for Labeling Pixels and Regions}, + author={Ke Sun and Yang Zhao and Borui Jiang and Tianheng Cheng and Bin Xiao + and Dong Liu and Yadong Mu and Xinggang Wang and Wenyu Liu and Jingdong Wang}, + journal = {CoRR}, + volume = {abs/1904.04514}, + year={2019} +} +``` + +## Model Zoo + +| Backbone | Type | deformable Conv | Image/gpu | Lr schd | Inf time (fps) | Box AP | Mask AP | Download | Configs | +| :---------------------- | :------------- | :---: | :-------: | :-----: | :------------: | :----: | :-----: | :----------------------------------------------------------: | :-----: | +| HRNetV2p_W18 | Faster | False | 2 | 1x | - | 35.7 | - | [model](https://paddlemodels.bj.bcebos.com/object_detection/dygraph/faster_rcnn_hrnetv2p_w18_1x_coco.pdparams) | [config](https://github.com/PaddlePaddle/PaddleDetection/tree/master/dygraph/configs/hrnet/faster_rcnn_hrnetv2p_w18_1x_coco.yml) | +| HRNetV2p_W18 | Faster | False | 2 | 2x | - | 37.7 | - | [model](https://paddlemodels.bj.bcebos.com/object_detection/dygraph/faster_rcnn_hrnetv2p_w18_2x_coco.pdparams) | [config](https://github.com/PaddlePaddle/PaddleDetection/tree/master/dygraph/configs/hrnet/faster_rcnn_hrnetv2p_w18_2x_coco.yml) | diff --git a/dygraph/configs/hrnet/_base_/faster_rcnn_hrnetv2p_w18.yml b/dygraph/configs/hrnet/_base_/faster_rcnn_hrnetv2p_w18.yml new file mode 100644 
index 000000000..0f6fb8f11 --- /dev/null +++ b/dygraph/configs/hrnet/_base_/faster_rcnn_hrnetv2p_w18.yml @@ -0,0 +1,90 @@ +architecture: FasterRCNN +pretrain_weights: https://paddle-imagenet-models-name.bj.bcebos.com/HRNet_W18_C_pretrained.tar +weights: output/faster_rcnn_hrnetv2p_w18_1x_coco/model_final +load_static_weights: True + +# Model Architecture +FasterRCNN: + # model anchor info flow + anchor: Anchor + proposal: Proposal + # model feat info flow + backbone: HRNet + neck: HRFPN + rpn_head: RPNHead + bbox_head: BBoxHead + # post process + bbox_post_process: BBoxPostProcess + +HRNet: + width: 18 + freeze_at: 0 + return_idx: [0, 1, 2, 3] + +HRFPN: + out_channel: 256 + share_conv: false + +RPNHead: + rpn_feat: + name: RPNFeat + feat_in: 256 + feat_out: 256 + anchor_per_position: 3 + rpn_channel: 256 + +Anchor: + anchor_generator: + name: AnchorGeneratorRPN + aspect_ratios: [0.5, 1.0, 2.0] + anchor_start_size: 32 + stride: [4., 4.] + anchor_target_generator: + name: AnchorTargetGeneratorRPN + batch_size_per_im: 256 + fg_fraction: 0.5 + negative_overlap: 0.3 + positive_overlap: 0.7 + straddle_thresh: 0.0 + +Proposal: + proposal_generator: + name: ProposalGenerator + min_size: 0.0 + nms_thresh: 0.7 + train_pre_nms_top_n: 2000 + train_post_nms_top_n: 2000 + infer_pre_nms_top_n: 1000 + infer_post_nms_top_n: 1000 + proposal_target_generator: + name: ProposalTargetGenerator + batch_size_per_im: 512 + bbox_reg_weights: [0.1, 0.1, 0.2, 0.2] + bg_thresh_hi: [0.5,] + bg_thresh_lo: [0.0,] + fg_thresh: [0.5,] + fg_fraction: 0.25 + +BBoxHead: + bbox_feat: + name: BBoxFeat + roi_extractor: + name: RoIAlign + resolution: 7 + sampling_ratio: 2 + head_feat: + name: TwoFCHead + in_dim: 256 + mlp_dim: 1024 + in_feat: 1024 + +BBoxPostProcess: + decode: + name: RCNNBox + num_classes: 81 + batch_size: 1 + nms: + name: MultiClassNMS + keep_top_k: 100 + score_threshold: 0.05 + nms_threshold: 0.5 diff --git a/dygraph/configs/hrnet/faster_rcnn_hrnetv2p_w18_1x_coco.yml b/dygraph/configs/hrnet/faster_rcnn_hrnetv2p_w18_1x_coco.yml new file mode 100644 index 000000000..f68bac5fe --- /dev/null +++ b/dygraph/configs/hrnet/faster_rcnn_hrnetv2p_w18_1x_coco.yml @@ -0,0 +1,20 @@ +_BASE_: [ + '../datasets/coco_detection.yml', + './_base_/faster_rcnn_hrnetv2p_w18.yml', + '../faster_rcnn/_base_/optimizer_1x.yml', + '../faster_rcnn/_base_/faster_fpn_reader.yml', + '../runtime.yml', +] + +LearningRate: + base_lr: 0.02 + schedulers: + - !PiecewiseDecay + gamma: 0.1 + milestones: [8, 11] + - !LinearWarmup + start_factor: 0.1 + steps: 1000 + +TrainReader: + batch_size: 2 diff --git a/dygraph/configs/hrnet/faster_rcnn_hrnetv2p_w18_2x_coco.yml b/dygraph/configs/hrnet/faster_rcnn_hrnetv2p_w18_2x_coco.yml new file mode 100644 index 000000000..73d9dc885 --- /dev/null +++ b/dygraph/configs/hrnet/faster_rcnn_hrnetv2p_w18_2x_coco.yml @@ -0,0 +1,23 @@ +_BASE_: [ + '../datasets/coco_detection.yml', + './_base_/faster_rcnn_hrnetv2p_w18.yml', + '../faster_rcnn/_base_/optimizer_1x.yml', + '../faster_rcnn/_base_/faster_fpn_reader.yml', + '../runtime.yml', +] + +weights: output/faster_rcnn_hrnetv2p_w18_2x_coco/model_final +epoch: 24 + +LearningRate: + base_lr: 0.02 + schedulers: + - !PiecewiseDecay + gamma: 0.1 + milestones: [16, 22] + - !LinearWarmup + start_factor: 0.1 + steps: 1000 + +TrainReader: + batch_size: 2 diff --git a/dygraph/ppdet/engine/trainer.py b/dygraph/ppdet/engine/trainer.py index e6f2bab5c..691c92c6a 100644 --- a/dygraph/ppdet/engine/trainer.py +++ b/dygraph/ppdet/engine/trainer.py @@ -156,6 +156,7 @@ class Trainer(object):
def train(self): assert self.mode == 'train', "Model not in 'train' mode" + self.model.train() # if no given weights loaded, load backbone pretrain weights as default if not self._weights_loaded: @@ -184,7 +185,6 @@ class Trainer(object): self._compose_callback.on_step_begin(self.status) # model forward - self.model.train() outputs = self.model(data) loss = outputs['loss'] diff --git a/dygraph/ppdet/modeling/backbones/__init__.py b/dygraph/ppdet/modeling/backbones/__init__.py index b156de4ae..07964090b 100644 --- a/dygraph/ppdet/modeling/backbones/__init__.py +++ b/dygraph/ppdet/modeling/backbones/__init__.py @@ -3,9 +3,11 @@ from . import resnet from . import darknet from . import mobilenet_v1 from . import mobilenet_v3 +from . import hrnet from .vgg import * from .resnet import * from .darknet import * from .mobilenet_v1 import * from .mobilenet_v3 import * +from .hrnet import * diff --git a/dygraph/ppdet/modeling/backbones/darknet.py b/dygraph/ppdet/modeling/backbones/darknet.py index fc4debb40..d08811ad0 100755 --- a/dygraph/ppdet/modeling/backbones/darknet.py +++ b/dygraph/ppdet/modeling/backbones/darknet.py @@ -1,3 +1,17 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import paddle import paddle.nn as nn import paddle.nn.functional as F diff --git a/dygraph/ppdet/modeling/backbones/hrnet.py b/dygraph/ppdet/modeling/backbones/hrnet.py new file mode 100644 index 000000000..71715a4b5 --- /dev/null +++ b/dygraph/ppdet/modeling/backbones/hrnet.py @@ -0,0 +1,668 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import paddle +import paddle.nn as nn +import paddle.nn.functional as F +from paddle.regularizer import L2Decay +from paddle import ParamAttr +from paddle.nn.initializer import Normal +from numbers import Integral +import math + +from ppdet.core.workspace import register, serializable + +__all__ = ['HRNet'] + + +class ConvNormLayer(nn.Layer): + def __init__(self, + ch_in, + ch_out, + filter_size, + stride=1, + norm_type='bn', + norm_groups=32, + use_dcn=False, + norm_decay=0., + freeze_norm=False, + act=None, + name=None): + super(ConvNormLayer, self).__init__() + assert norm_type in ['bn', 'sync_bn', 'gn'] + + self.act = act + self.conv = nn.Conv2D( + in_channels=ch_in, + out_channels=ch_out, + kernel_size=filter_size, + stride=stride, + padding=(filter_size - 1) // 2, + groups=1, + weight_attr=ParamAttr( + name=name + "_weights", initializer=Normal( + mean=0., std=0.01)), + bias_attr=False) + + norm_lr = 0. if freeze_norm else 1. + + norm_name = name + '_bn' + param_attr = ParamAttr( + name=norm_name + "_scale", + learning_rate=norm_lr, + regularizer=L2Decay(norm_decay)) + bias_attr = ParamAttr( + name=norm_name + "_offset", + learning_rate=norm_lr, + regularizer=L2Decay(norm_decay)) + global_stats = True if freeze_norm else False + if norm_type in ['bn', 'sync_bn']: + self.norm = nn.BatchNorm( + ch_out, + param_attr=param_attr, + bias_attr=bias_attr, + use_global_stats=global_stats, + moving_mean_name=norm_name + '_mean', + moving_variance_name=norm_name + '_variance') + elif norm_type == 'gn': + self.norm = nn.GroupNorm( + num_groups=norm_groups, + num_channels=ch_out, + weight_attr=param_attr, + bias_attr=bias_attr) + norm_params = self.norm.parameters() + if freeze_norm: + for param in norm_params: + param.stop_gradient = True + + def forward(self, inputs): + out = self.conv(inputs) + out = self.norm(out) + + if self.act == 'relu': + out = F.relu(out) + return out + + +class Layer1(nn.Layer): + def __init__(self, num_channels, has_se=False, freeze_norm=True, name=None): + super(Layer1, self).__init__() + + self.bottleneck_block_list = [] + + for i in range(4): + bottleneck_block = self.add_sublayer( + "block_{}_{}".format(name, i + 1), + BottleneckBlock( + num_channels=num_channels if i == 0 else 256, + num_filters=64, + has_se=has_se, + stride=1, + downsample=True if i == 0 else False, + freeze_norm=freeze_norm, + name=name + '_' + str(i + 1))) + self.bottleneck_block_list.append(bottleneck_block) + + def forward(self, input): + conv = input + for block_func in self.bottleneck_block_list: + conv = block_func(conv) + return conv + + +class TransitionLayer(nn.Layer): + def __init__(self, in_channels, out_channels, freeze_norm=True, name=None): + super(TransitionLayer, self).__init__() + + num_in = len(in_channels) + num_out = len(out_channels) + out = [] + self.conv_bn_func_list = [] + for i in range(num_out): + residual = None + if i < num_in: + if in_channels[i] != out_channels[i]: + residual = self.add_sublayer( + "transition_{}_layer_{}".format(name, i + 1), + ConvNormLayer( + ch_in=in_channels[i], + ch_out=out_channels[i], + filter_size=3, + freeze_norm=freeze_norm, + act='relu', + name=name + '_layer_' + str(i + 1))) + else: + residual = self.add_sublayer( + "transition_{}_layer_{}".format(name, i + 1), + ConvNormLayer( + ch_in=in_channels[-1], + ch_out=out_channels[i], + filter_size=3, + stride=2, + freeze_norm=freeze_norm, + act='relu', + name=name + '_layer_' + str(i + 1))) + self.conv_bn_func_list.append(residual) + + def forward(self, input): + outs = [] + for idx, 
conv_bn_func in enumerate(self.conv_bn_func_list): + if conv_bn_func is None: + outs.append(input[idx]) + else: + if idx < len(input): + outs.append(conv_bn_func(input[idx])) + else: + outs.append(conv_bn_func(input[-1])) + return outs + + +class Branches(nn.Layer): + def __init__(self, + block_num, + in_channels, + out_channels, + has_se=False, + freeze_norm=True, + name=None): + super(Branches, self).__init__() + + self.basic_block_list = [] + for i in range(len(out_channels)): + self.basic_block_list.append([]) + for j in range(block_num): + in_ch = in_channels[i] if j == 0 else out_channels[i] + basic_block_func = self.add_sublayer( + "bb_{}_branch_layer_{}_{}".format(name, i + 1, j + 1), + BasicBlock( + num_channels=in_ch, + num_filters=out_channels[i], + has_se=has_se, + freeze_norm=freeze_norm, + name=name + '_branch_layer_' + str(i + 1) + '_' + + str(j + 1))) + self.basic_block_list[i].append(basic_block_func) + + def forward(self, inputs): + outs = [] + for idx, input in enumerate(inputs): + conv = input + basic_block_list = self.basic_block_list[idx] + for basic_block_func in basic_block_list: + conv = basic_block_func(conv) + outs.append(conv) + return outs + + +class BottleneckBlock(nn.Layer): + def __init__(self, + num_channels, + num_filters, + has_se, + stride=1, + downsample=False, + freeze_norm=True, + name=None): + super(BottleneckBlock, self).__init__() + + self.has_se = has_se + self.downsample = downsample + + self.conv1 = ConvNormLayer( + ch_in=num_channels, + ch_out=num_filters, + filter_size=1, + freeze_norm=freeze_norm, + act="relu", + name=name + "_conv1") + self.conv2 = ConvNormLayer( + ch_in=num_filters, + ch_out=num_filters, + filter_size=3, + stride=stride, + freeze_norm=freeze_norm, + act="relu", + name=name + "_conv2") + self.conv3 = ConvNormLayer( + ch_in=num_filters, + ch_out=num_filters * 4, + filter_size=1, + freeze_norm=freeze_norm, + act=None, + name=name + "_conv3") + + if self.downsample: + self.conv_down = ConvNormLayer( + ch_in=num_channels, + ch_out=num_filters * 4, + filter_size=1, + freeze_norm=freeze_norm, + act=None, + name=name + "_downsample") + + if self.has_se: + self.se = SELayer( + num_channels=num_filters * 4, + num_filters=num_filters * 4, + reduction_ratio=16, + name='fc' + name) + + def forward(self, input): + residual = input + conv1 = self.conv1(input) + conv2 = self.conv2(conv1) + conv3 = self.conv3(conv2) + + if self.downsample: + residual = self.conv_down(input) + + if self.has_se: + conv3 = self.se(conv3) + + y = paddle.add(x=residual, y=conv3) + y = F.relu(y) + return y + + +class BasicBlock(nn.Layer): + def __init__(self, + num_channels, + num_filters, + stride=1, + has_se=False, + downsample=False, + freeze_norm=True, + name=None): + super(BasicBlock, self).__init__() + + self.has_se = has_se + self.downsample = downsample + self.conv1 = ConvNormLayer( + ch_in=num_channels, + ch_out=num_filters, + filter_size=3, + freeze_norm=freeze_norm, + stride=stride, + act="relu", + name=name + "_conv1") + self.conv2 = ConvNormLayer( + ch_in=num_filters, + ch_out=num_filters, + filter_size=3, + freeze_norm=freeze_norm, + stride=1, + act=None, + name=name + "_conv2") + + if self.downsample: + self.conv_down = ConvNormLayer( + ch_in=num_channels, + ch_out=num_filters * 4, + filter_size=1, + freeze_norm=freeze_norm, + act=None, + name=name + "_downsample") + + if self.has_se: + self.se = SELayer( + num_channels=num_filters, + num_filters=num_filters, + reduction_ratio=16, + name='fc' + name) + + def forward(self, input): + residual = input 
+ conv1 = self.conv1(input) + conv2 = self.conv2(conv1) + + if self.downsample: + residual = self.conv_down(input) + + if self.has_se: + conv2 = self.se(conv2) + + y = paddle.add(x=residual, y=conv2) + y = F.relu(y) + return y + + +class SELayer(nn.Layer): + def __init__(self, num_channels, num_filters, reduction_ratio, name=None): + super(SELayer, self).__init__() + + self.pool2d_gap = nn.AdaptiveAvgPool2D(1) + + self._num_channels = num_channels + + med_ch = int(num_channels / reduction_ratio) + stdv = 1.0 / math.sqrt(num_channels * 1.0) + self.squeeze = nn.Linear( + num_channels, + med_ch, + weight_attr=ParamAttr( + initializer=nn.initializer.Uniform(-stdv, stdv), name=name + "_sqz_weights"), + bias_attr=ParamAttr(name=name + '_sqz_offset')) + + stdv = 1.0 / math.sqrt(med_ch * 1.0) + self.excitation = nn.Linear( + med_ch, + num_filters, + weight_attr=ParamAttr( + initializer=nn.initializer.Uniform(-stdv, stdv), name=name + "_exc_weights"), + bias_attr=ParamAttr(name=name + '_exc_offset')) + + def forward(self, input): + pool = self.pool2d_gap(input) + pool = paddle.squeeze(pool, axis=[2, 3]) + squeeze = self.squeeze(pool) + squeeze = F.relu(squeeze) + excitation = self.excitation(squeeze) + excitation = F.sigmoid(excitation) + excitation = paddle.unsqueeze(excitation, axis=[2, 3]) + out = input * excitation + return out + + +class Stage(nn.Layer): + def __init__(self, + num_channels, + num_modules, + num_filters, + has_se=False, + freeze_norm=True, + multi_scale_output=True, + name=None): + super(Stage, self).__init__() + + self._num_modules = num_modules + self.stage_func_list = [] + for i in range(num_modules): + if i == num_modules - 1 and not multi_scale_output: + stage_func = self.add_sublayer( + "stage_{}_{}".format(name, i + 1), + HighResolutionModule( + num_channels=num_channels, + num_filters=num_filters, + has_se=has_se, + freeze_norm=freeze_norm, + multi_scale_output=False, + name=name + '_' + str(i + 1))) + else: + stage_func = self.add_sublayer( + "stage_{}_{}".format(name, i + 1), + HighResolutionModule( + num_channels=num_channels, + num_filters=num_filters, + has_se=has_se, + freeze_norm=freeze_norm, + name=name + '_' + str(i + 1))) + + self.stage_func_list.append(stage_func) + + def forward(self, input): + out = input + for idx in range(self._num_modules): + out = self.stage_func_list[idx](out) + return out + + +class HighResolutionModule(nn.Layer): + def __init__(self, + num_channels, + num_filters, + has_se=False, + multi_scale_output=True, + freeze_norm=True, + name=None): + super(HighResolutionModule, self).__init__() + self.branches_func = Branches( + block_num=4, + in_channels=num_channels, + out_channels=num_filters, + has_se=has_se, + freeze_norm=freeze_norm, + name=name) + + self.fuse_func = FuseLayers( + in_channels=num_filters, + out_channels=num_filters, + multi_scale_output=multi_scale_output, + freeze_norm=freeze_norm, + name=name) + + def forward(self, input): + out = self.branches_func(input) + out = self.fuse_func(out) + return out + + +class FuseLayers(nn.Layer): + def __init__(self, + in_channels, + out_channels, + multi_scale_output=True, + freeze_norm=True, + name=None): + super(FuseLayers, self).__init__() + + self._actual_ch = len(in_channels) if multi_scale_output else 1 + self._in_channels = in_channels + + self.residual_func_list = [] + for i in range(self._actual_ch): + for j in range(len(in_channels)): + residual_func = None + if j > i: + residual_func = self.add_sublayer( + "residual_{}_layer_{}_{}".format(name, i + 1, j + 1), + ConvNormLayer( + ch_in=in_channels[j],
ch_out=out_channels[i], + filter_size=1, + stride=1, + act=None, + freeze_norm=freeze_norm, + name=name + '_layer_' + str(i + 1) + '_' + + str(j + 1))) + self.residual_func_list.append(residual_func) + elif j < i: + pre_num_filters = in_channels[j] + for k in range(i - j): + if k == i - j - 1: + residual_func = self.add_sublayer( + "residual_{}_layer_{}_{}_{}".format( + name, i + 1, j + 1, k + 1), + ConvNormLayer( + ch_in=pre_num_filters, + ch_out=out_channels[i], + filter_size=3, + stride=2, + freeze_norm=freeze_norm, + act=None, + name=name + '_layer_' + str(i + 1) + '_' + + str(j + 1) + '_' + str(k + 1))) + pre_num_filters = out_channels[i] + else: + residual_func = self.add_sublayer( + "residual_{}_layer_{}_{}_{}".format( + name, i + 1, j + 1, k + 1), + ConvNormLayer( + ch_in=pre_num_filters, + ch_out=out_channels[j], + filter_size=3, + stride=2, + freeze_norm=freeze_norm, + act="relu", + name=name + '_layer_' + str(i + 1) + '_' + + str(j + 1) + '_' + str(k + 1))) + pre_num_filters = out_channels[j] + self.residual_func_list.append(residual_func) + + def forward(self, input): + outs = [] + residual_func_idx = 0 + for i in range(self._actual_ch): + residual = input[i] + for j in range(len(self._in_channels)): + if j > i: + y = self.residual_func_list[residual_func_idx](input[j]) + residual_func_idx += 1 + y = F.interpolate(y, scale_factor=2**(j - i)) + residual = paddle.add(x=residual, y=y) + elif j < i: + y = input[j] + for k in range(i - j): + y = self.residual_func_list[residual_func_idx](y) + residual_func_idx += 1 + + residual = paddle.add(x=residual, y=y) + residual = F.relu(residual) + outs.append(residual) + + return outs + + +@register +class HRNet(nn.Layer): + """ + HRNet, see https://arxiv.org/abs/1908.07919 + + Args: + width (int): the width of HRNet + has_se (bool): whether to add SE block for each stage + freeze_at (int): the stage to freeze + freeze_norm (bool): whether to freeze norm in HRNet + return_idx (List): the stage to return + """ + + def __init__(self, + width=18, + has_se=False, + freeze_at=0, + freeze_norm=True, + norm_decay=0., + return_idx=[0, 1, 2, 3]): + super(HRNet, self).__init__() + + self.width = width + self.has_se = has_se + if isinstance(return_idx, Integral): + return_idx = [return_idx] + + assert len(return_idx) > 0, "need one or more return index" + self.freeze_at = freeze_at + self.return_idx = return_idx + + self.channels = { + 18: [[18, 36], [18, 36, 72], [18, 36, 72, 144]], + 30: [[30, 60], [30, 60, 120], [30, 60, 120, 240]], + 32: [[32, 64], [32, 64, 128], [32, 64, 128, 256]], + 40: [[40, 80], [40, 80, 160], [40, 80, 160, 320]], + 44: [[44, 88], [44, 88, 176], [44, 88, 176, 352]], + 48: [[48, 96], [48, 96, 192], [48, 96, 192, 384]], + 60: [[60, 120], [60, 120, 240], [60, 120, 240, 480]], + 64: [[64, 128], [64, 128, 256], [64, 128, 256, 512]] + } + + channels_2, channels_3, channels_4 = self.channels[width] + num_modules_2, num_modules_3, num_modules_4 = 1, 4, 3 + + self.conv_layer1_1 = ConvNormLayer( + ch_in=3, + ch_out=64, + filter_size=3, + stride=2, + freeze_norm=freeze_norm, + act='relu', + name="layer1_1") + + self.conv_layer1_2 = ConvNormLayer( + ch_in=64, + ch_out=64, + filter_size=3, + stride=2, + freeze_norm=freeze_norm, + act='relu', + name="layer1_2") + + self.la1 = Layer1( + num_channels=64, + has_se=has_se, + freeze_norm=freeze_norm, + name="layer2") + + self.tr1 = TransitionLayer( + in_channels=[256], + out_channels=channels_2, + freeze_norm=freeze_norm, + name="tr1") + + self.st2 = Stage( + num_channels=channels_2, + 
num_modules=num_modules_2, + num_filters=channels_2, + has_se=self.has_se, + freeze_norm=freeze_norm, + name="st2") + + self.tr2 = TransitionLayer( + in_channels=channels_2, + out_channels=channels_3, + freeze_norm=freeze_norm, + name="tr2") + + self.st3 = Stage( + num_channels=channels_3, + num_modules=num_modules_3, + num_filters=channels_3, + has_se=self.has_se, + freeze_norm=freeze_norm, + name="st3") + + self.tr3 = TransitionLayer( + in_channels=channels_3, + out_channels=channels_4, + freeze_norm=freeze_norm, + name="tr3") + self.st4 = Stage( + num_channels=channels_4, + num_modules=num_modules_4, + num_filters=channels_4, + has_se=self.has_se, + freeze_norm=freeze_norm, + name="st4") + + def forward(self, inputs): + x = inputs['image'] + conv1 = self.conv_layer1_1(x) + conv2 = self.conv_layer1_2(conv1) + + la1 = self.la1(conv2) + tr1 = self.tr1([la1]) + st2 = self.st2(tr1) + tr2 = self.tr2(st2) + + st3 = self.st3(tr2) + tr3 = self.tr3(st3) + + st4 = self.st4(tr3) + + res = [] + for i, layer in enumerate(st4): + if i == self.freeze_at: + layer.stop_gradient = True + if i in self.return_idx: + res.append(layer) + + return res diff --git a/dygraph/ppdet/modeling/necks/__init__.py b/dygraph/ppdet/modeling/necks/__init__.py index 0b61c3292..01288ca6b 100644 --- a/dygraph/ppdet/modeling/necks/__init__.py +++ b/dygraph/ppdet/modeling/necks/__init__.py @@ -14,6 +14,8 @@ from . import fpn from . import yolo_fpn +from . import hrfpn from .fpn import * from .yolo_fpn import * +from .hrfpn import * diff --git a/dygraph/ppdet/modeling/necks/hrfpn.py b/dygraph/ppdet/modeling/necks/hrfpn.py new file mode 100644 index 000000000..f06b3cace --- /dev/null +++ b/dygraph/ppdet/modeling/necks/hrfpn.py @@ -0,0 +1,110 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import paddle +import paddle.nn.functional as F +from paddle import ParamAttr +import paddle.nn as nn +from paddle.regularizer import L2Decay +from ppdet.core.workspace import register, serializable + +__all__ = ['HRFPN'] + + +@register +class HRFPN(nn.Layer): + """ + Args: + in_channel (int): number of input feature channels from backbone + out_channel (int): number of output feature channels + share_conv (bool): whether to share conv for different layers' reduction + spatial_scale (list): feature map scaling factor + """ + + def __init__( + self, + in_channel=270, + out_channel=256, + share_conv=False, + spatial_scale=[1. / 4, 1. / 8, 1. / 16, 1. / 32, 1. 
/ 64], ): + super(HRFPN, self).__init__() + self.in_channel = in_channel + self.out_channel = out_channel + self.share_conv = share_conv + self.spatial_scale = spatial_scale + + self.reduction = nn.Conv2D( + in_channels=in_channel, + out_channels=out_channel, + kernel_size=1, + weight_attr=ParamAttr(name='hrfpn_reduction_weights'), + bias_attr=False) + self.num_out = len(self.spatial_scale) + if share_conv: + self.fpn_conv = nn.Conv2D( + in_channels=out_channel, + out_channels=out_channel, + kernel_size=3, + padding=1, + weight_attr=ParamAttr(name='fpn_conv_weights'), + bias_attr=False) + else: + self.fpn_conv = [] + for i in range(self.num_out): + conv_name = "fpn_conv_" + str(i) + conv = self.add_sublayer( + conv_name, + nn.Conv2D( + in_channels=out_channel, + out_channels=out_channel, + kernel_size=3, + padding=1, + weight_attr=ParamAttr(name=conv_name + "_weights"), + bias_attr=False)) + self.fpn_conv.append(conv) + + def forward(self, body_feats): + num_backbone_stages = len(body_feats) + + outs = [] + outs.append(body_feats[0]) + + # resize + for i in range(1, num_backbone_stages): + resized = F.interpolate( + body_feats[i], scale_factor=2**i, mode='bilinear') + outs.append(resized) + + # concat + out = paddle.concat(outs, axis=1) + assert out.shape[ + 1] == self.in_channel, 'in_channel should be {}, but received {}'.format( + self.in_channel, out.shape[1]) + + # reduction + out = self.reduction(out) + + # conv + outs = [out] + for i in range(1, self.num_out): + outs.append(F.avg_pool2d(out, kernel_size=2**i, stride=2**i)) + outputs = [] + + for i in range(self.num_out): + conv_func = self.fpn_conv if self.share_conv else self.fpn_conv[i] + conv = conv_func(outs[i]) + outputs.append(conv) + + fpn_feat = [outputs[k] for k in range(self.num_out)] + return fpn_feat, self.spatial_scale diff --git a/dygraph/ppdet/utils/checkpoint.py b/dygraph/ppdet/utils/checkpoint.py index 05ce8171b..f96804f6b 100644 --- a/dygraph/ppdet/utils/checkpoint.py +++ b/dygraph/ppdet/utils/checkpoint.py @@ -95,7 +95,7 @@ def load_weight(model, weight, optimizer=None): last_epoch = 0 if optimizer is not None and os.path.exists(path + '.pdopt'): optim_state_dict = paddle.load(path + '.pdopt') - # to slove resume bug, will it be fixed in paddle 2.0 + # to solve resume bug, will it be fixed in paddle 2.0 for key in optimizer.state_dict().keys(): if not key in optim_state_dict.keys(): optim_state_dict[key] = optimizer.state_dict()[key] @@ -132,6 +132,9 @@ def load_pretrain_weight(model, weight_name, pre_state_dict[weight_name].shape)) param_state_dict[key] = pre_state_dict[weight_name] else: + if 'backbone' in key: + logger.info('Missing weight: {}, structure name: {}'.format( + weight_name, key)) param_state_dict[key] = model_dict[key] model.set_dict(param_state_dict) return -- GitLab
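As a quick sanity check of how the new backbone and neck compose, the sketch below (illustrative only, not part of the patch) instantiates them directly and verifies that the four HRNetV2p-W18 feature maps (18 + 36 + 72 + 144 = 270 channels at strides 4/8/16/32) match HRFPN's default `in_channel=270`. It assumes the dygraph `ppdet` package from this repository is importable; the dummy input size is an arbitrary choice.

```python
# Illustrative smoke test; assumes the dygraph ppdet package added by this patch is on PYTHONPATH.
import paddle
from ppdet.modeling.backbones.hrnet import HRNet
from ppdet.modeling.necks.hrfpn import HRFPN

backbone = HRNet(width=18, freeze_at=0, return_idx=[0, 1, 2, 3])
neck = HRFPN(in_channel=270, out_channel=256, share_conv=False)
backbone.eval()
neck.eval()

# HRNet.forward expects the dict-style input the detector normally feeds it.
image = paddle.randn([1, 3, 512, 512])
body_feats = backbone({'image': image})
print([f.shape for f in body_feats])  # 18/36/72/144 channels at strides 4/8/16/32

# HRFPN upsamples and concatenates the branches (18 + 36 + 72 + 144 = 270 channels),
# reduces them to out_channel, then pools out five pyramid levels (1/4 ... 1/64).
fpn_feats, spatial_scale = neck(body_feats)
print([f.shape for f in fpn_feats], spatial_scale)
```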