From d57862883e5e5d4dec9722abd8245a76ef49a479 Mon Sep 17 00:00:00 2001
From: JYChen
Date: Wed, 4 Aug 2021 10:21:14 +0800
Subject: [PATCH] Add Lite-HRNet (#3793)

* add LiteHRNet backbone and config .yml files

* verify that the lite_18 network's params and accuracy match the original
  model; 1. fix DARK post-processing defaulting to ON, 2. += is not an
  in-place add

* add new keypoint model Lite-HRNet

* 1. add a description of network_type; 2. use channel_shuffle from ops.py

* use Normal initializer for conv2d

* add network type description

---
 .../lite_hrnet/lite_hrnet_18_256x192_coco.yml | 140 +++
 .../lite_hrnet/lite_hrnet_30_256x192_coco.yml | 140 +++
 .../modeling/architectures/keypoint_hrnet.py  |   9 +-
 ppdet/modeling/backbones/__init__.py          |   2 +
 ppdet/modeling/backbones/lite_hrnet.py        | 886 ++++++++++++++++++
 ppdet/modeling/backbones/shufflenet_v2.py     |  17 +-
 ppdet/modeling/losses/keypoint_loss.py        |   9 +-
 ppdet/modeling/ops.py                         |  12 +
 8 files changed, 1191 insertions(+), 24 deletions(-)
 create mode 100644 configs/keypoint/lite_hrnet/lite_hrnet_18_256x192_coco.yml
 create mode 100644 configs/keypoint/lite_hrnet/lite_hrnet_30_256x192_coco.yml
 create mode 100644 ppdet/modeling/backbones/lite_hrnet.py

diff --git a/configs/keypoint/lite_hrnet/lite_hrnet_18_256x192_coco.yml b/configs/keypoint/lite_hrnet/lite_hrnet_18_256x192_coco.yml
new file mode 100644
index 000000000..266408246
--- /dev/null
+++ b/configs/keypoint/lite_hrnet/lite_hrnet_18_256x192_coco.yml
@@ -0,0 +1,140 @@
+use_gpu: true
+log_iter: 5
+save_dir: output
+snapshot_epoch: 10
+weights: output/lite_hrnet_18_256x192_coco/model_final
+epoch: 210
+num_joints: &num_joints 17
+pixel_std: &pixel_std 200
+metric: KeyPointTopDownCOCOEval
+num_classes: 1
+train_height: &train_height 256
+train_width: &train_width 192
+trainsize: &trainsize [*train_width, *train_height]
+hmsize: &hmsize [48, 64]
+flip_perm: &flip_perm [[1, 2], [3, 4], [5, 6], [7, 8], [9, 10], [11, 12], [13, 14], [15, 16]]
+
+
+#####model
+architecture: TopDownHRNet
+
+TopDownHRNet:
+  backbone: LiteHRNet
+  post_process: HRNetPostProcess
+  flip_perm: *flip_perm
+  num_joints: *num_joints
+  width: &width 40
+  loss: KeyPointMSELoss
+  use_dark: false
+
+LiteHRNet:
+  network_type: lite_18
+  freeze_at: -1
+  freeze_norm: false
+  return_idx: [0]
+
+KeyPointMSELoss:
+  use_target_weight: true
+  loss_scale: 1.0
+
+#####optimizer
+LearningRate:
+  base_lr: 0.002
+  schedulers:
+  - !PiecewiseDecay
+    milestones: [170, 200]
+    gamma: 0.1
+  - !LinearWarmup
+    start_factor: 0.001
+    steps: 500
+
+OptimizerBuilder:
+  optimizer:
+    type: Adam
+  regularizer:
+    factor: 0.0
+    type: L2
+
+
+#####data
+TrainDataset:
+  !KeypointTopDownCocoDataset
+    image_dir: train2017
+    anno_path: annotations/person_keypoints_train2017.json
+    dataset_dir: dataset/coco
+    num_joints: *num_joints
+    trainsize: *trainsize
+    pixel_std: *pixel_std
+    use_gt_bbox: True
+
+
+EvalDataset:
+  !KeypointTopDownCocoDataset
+    image_dir: val2017
+    anno_path: annotations/person_keypoints_val2017.json
+    dataset_dir: dataset/coco
+    num_joints: *num_joints
+    trainsize: *trainsize
+    pixel_std: *pixel_std
+    use_gt_bbox: True
+    image_thre: 0.0
+
+
+TestDataset:
+  !ImageFolder
+    anno_path: dataset/coco/keypoint_imagelist.txt
+
+worker_num: 2
+global_mean: &global_mean [0.485, 0.456, 0.406]
+global_std: &global_std [0.229, 0.224, 0.225]
+TrainReader:
+  sample_transforms:
+    - RandomFlipHalfBodyTransform:
+        scale: 0.25
+        rot: 30
+        num_joints_half_body: 8
+        prob_half_body: 0.3
+        pixel_std: *pixel_std
+        trainsize: *trainsize
+        upper_body_ids: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
+        flip_pairs: *flip_perm
+    - TopDownAffine:
+        trainsize: *trainsize
+    - ToHeatmapsTopDown:
+        hmsize: *hmsize
+        sigma: 2
+  batch_transforms:
+    - NormalizeImage:
+        mean: *global_mean
+        std: *global_std
+        is_scale: true
+    - Permute: {}
+  batch_size: 64
+  shuffle: true
+  drop_last: false
+
+EvalReader:
+  sample_transforms:
+    - TopDownAffine:
+        trainsize: *trainsize
+  batch_transforms:
+    - NormalizeImage:
+        mean: *global_mean
+        std: *global_std
+        is_scale: true
+    - Permute: {}
+  batch_size: 16
+
+TestReader:
+  inputs_def:
+    image_shape: [3, *train_height, *train_width]
+  sample_transforms:
+    - Decode: {}
+    - TopDownEvalAffine:
+        trainsize: *trainsize
+    - NormalizeImage:
+        mean: *global_mean
+        std: *global_std
+        is_scale: true
+    - Permute: {}
+  batch_size: 1
diff --git a/configs/keypoint/lite_hrnet/lite_hrnet_30_256x192_coco.yml b/configs/keypoint/lite_hrnet/lite_hrnet_30_256x192_coco.yml
new file mode 100644
index 000000000..118ba3604
--- /dev/null
+++ b/configs/keypoint/lite_hrnet/lite_hrnet_30_256x192_coco.yml
@@ -0,0 +1,140 @@
+use_gpu: true
+log_iter: 5
+save_dir: output
+snapshot_epoch: 10
+weights: output/lite_hrnet_30_256x192_coco/model_final
+epoch: 210
+num_joints: &num_joints 17
+pixel_std: &pixel_std 200
+metric: KeyPointTopDownCOCOEval
+num_classes: 1
+train_height: &train_height 256
+train_width: &train_width 192
+trainsize: &trainsize [*train_width, *train_height]
+hmsize: &hmsize [48, 64]
+flip_perm: &flip_perm [[1, 2], [3, 4], [5, 6], [7, 8], [9, 10], [11, 12], [13, 14], [15, 16]]
+
+
+#####model
+architecture: TopDownHRNet
+
+TopDownHRNet:
+  backbone: LiteHRNet
+  post_process: HRNetPostProcess
+  flip_perm: *flip_perm
+  num_joints: *num_joints
+  width: &width 40
+  loss: KeyPointMSELoss
+  use_dark: false
+
+LiteHRNet:
+  network_type: lite_30
+  freeze_at: -1
+  freeze_norm: false
+  return_idx: [0]
+
+KeyPointMSELoss:
+  use_target_weight: true
+  loss_scale: 1.0
+
+#####optimizer
+LearningRate:
+  base_lr: 0.002
+  schedulers:
+  - !PiecewiseDecay
+    milestones: [170, 200]
+    gamma: 0.1
+  - !LinearWarmup
+    start_factor: 0.001
+    steps: 500
+
+OptimizerBuilder:
+  optimizer:
+    type: Adam
+  regularizer:
+    factor: 0.0
+    type: L2
+
+
+#####data
+TrainDataset:
+  !KeypointTopDownCocoDataset
+    image_dir: train2017
+    anno_path: annotations/person_keypoints_train2017.json
+    dataset_dir: dataset/coco
+    num_joints: *num_joints
+    trainsize: *trainsize
+    pixel_std: *pixel_std
+    use_gt_bbox: True
+
+
+EvalDataset:
+  !KeypointTopDownCocoDataset
+    image_dir: val2017
+    anno_path: annotations/person_keypoints_val2017.json
+    dataset_dir: dataset/coco
+    num_joints: *num_joints
+    trainsize: *trainsize
+    pixel_std: *pixel_std
+    use_gt_bbox: True
+    image_thre: 0.0
+
+
+TestDataset:
+  !ImageFolder
+    anno_path: dataset/coco/keypoint_imagelist.txt
+
+worker_num: 4
+global_mean: &global_mean [0.485, 0.456, 0.406]
+global_std: &global_std [0.229, 0.224, 0.225]
+TrainReader:
+  sample_transforms:
+    - RandomFlipHalfBodyTransform:
+        scale: 0.25
+        rot: 30
+        num_joints_half_body: 8
+        prob_half_body: 0.3
+        pixel_std: *pixel_std
+        trainsize: *trainsize
+        upper_body_ids: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
+        flip_pairs: *flip_perm
+    - TopDownAffine:
+        trainsize: *trainsize
+    - ToHeatmapsTopDown:
+        hmsize: *hmsize
+        sigma: 2
+  batch_transforms:
+    - NormalizeImage:
+        mean: *global_mean
+        std: *global_std
+        is_scale: true
+    - Permute: {}
+  batch_size: 64
+  shuffle: true
+  drop_last: false
+
+EvalReader:
+  sample_transforms:
+    - TopDownAffine:
+        trainsize: *trainsize
+  batch_transforms:
+    - NormalizeImage:
+        mean: *global_mean
+        std: *global_std
+        is_scale: true
+    - Permute: {}
+  batch_size: 16
+
+TestReader:
+  inputs_def:
+    image_shape: [3, *train_height, *train_width]
+  sample_transforms:
+    - Decode: {}
+    - TopDownEvalAffine:
+        trainsize: *trainsize
+    - NormalizeImage:
+        mean: *global_mean
+        std: *global_std
+        is_scale: true
+    - Permute: {}
+  batch_size: 1
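A note on the geometry shared by both configs: hmsize is trainsize downsampled
by the usual top-down heatmap stride of 4 ([192, 256] -> [48, 64], width
first). A minimal sanity check, as a sketch (the helper below is hypothetical,
not part of this patch; the stride of 4 is inferred from these two configs):

    def heatmap_size(trainsize, stride=4):
        # trainsize is [w, h]; the heatmap keeps the same aspect ratio
        w, h = trainsize
        assert w % stride == 0 and h % stride == 0
        return [w // stride, h // stride]

    assert heatmap_size([192, 256]) == [48, 64]  # matches hmsize above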
diff --git a/ppdet/modeling/architectures/keypoint_hrnet.py b/ppdet/modeling/architectures/keypoint_hrnet.py
index 3cacf3a53..64c269a7a 100644
--- a/ppdet/modeling/architectures/keypoint_hrnet.py
+++ b/ppdet/modeling/architectures/keypoint_hrnet.py
@@ -41,18 +41,20 @@ class TopDownHRNet(BaseArch):
                  post_process='HRNetPostProcess',
                  flip_perm=None,
                  flip=True,
-                 shift_heatmap=True):
+                 shift_heatmap=True,
+                 use_dark=True):
         """
-        HRNnet network, see https://arxiv.org/abs/1902.09212
+        HRNet network, see https://arxiv.org/abs/1902.09212
 
         Args:
             backbone (nn.Layer): backbone instance
             post_process (object): `HRNetPostProcess` instance
             flip_perm (list): The left-right joints exchange order list
+            use_dark (bool): Whether to use DARK in post processing
         """
         super(TopDownHRNet, self).__init__()
         self.backbone = backbone
-        self.post_process = HRNetPostProcess()
+        self.post_process = HRNetPostProcess(use_dark)
         self.loss = loss
         self.flip_perm = flip_perm
         self.flip = flip
@@ -218,7 +220,6 @@ class HRNetPostProcess(object):
             preds: numpy.ndarray([batch_size, num_joints, 2]), keypoints coords
             maxvals: numpy.ndarray([batch_size, num_joints, 1]), the maximum confidence of the keypoints
         """
-
         coords, maxvals = self.get_max_preds(heatmaps)
 
         heatmap_height = heatmaps.shape[2]
diff --git a/ppdet/modeling/backbones/__init__.py b/ppdet/modeling/backbones/__init__.py
index c6e1c0c2c..58fd37138 100644
--- a/ppdet/modeling/backbones/__init__.py
+++ b/ppdet/modeling/backbones/__init__.py
@@ -18,6 +18,7 @@ from . import darknet
 from . import mobilenet_v1
 from . import mobilenet_v3
 from . import hrnet
+from . import lite_hrnet
 from . import blazenet
 from . import ghostnet
 from . import senet
@@ -31,6 +32,7 @@ from .darknet import *
 from .mobilenet_v1 import *
 from .mobilenet_v3 import *
 from .hrnet import *
+from .lite_hrnet import *
 from .blazenet import *
 from .ghostnet import *
 from .senet import *
diff --git a/ppdet/modeling/backbones/lite_hrnet.py b/ppdet/modeling/backbones/lite_hrnet.py
new file mode 100644
index 000000000..be32c132d
--- /dev/null
+++ b/ppdet/modeling/backbones/lite_hrnet.py
@@ -0,0 +1,886 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import paddle
+import paddle.nn as nn
+import paddle.nn.functional as F
+
+from numbers import Integral
+from paddle import ParamAttr
+from paddle.regularizer import L2Decay
+from paddle.nn.initializer import Normal, Constant
+from ppdet.core.workspace import register
+from ppdet.modeling.shape_spec import ShapeSpec
+from ppdet.modeling.ops import channel_shuffle
+from .. import layers as L
+
+__all__ = ['LiteHRNet']
+
+
+class ConvNormLayer(nn.Layer):
+    def __init__(self,
+                 ch_in,
+                 ch_out,
+                 filter_size,
+                 stride=1,
+                 groups=1,
+                 norm_type=None,
+                 norm_groups=32,
+                 norm_decay=0.,
+                 freeze_norm=False,
+                 act=None):
+        super(ConvNormLayer, self).__init__()
+        self.act = act
+        norm_lr = 0. if freeze_norm else 1.
+        if norm_type is not None:
+            # note: do not wrap the assert in a tuple, it would always be truthy
+            assert norm_type in ['bn', 'sync_bn', 'gn'], \
+                "norm_type should be one of ['bn', 'sync_bn', 'gn'], but got {}".format(
+                    norm_type)
+            param_attr = ParamAttr(
+                initializer=Constant(1.0),
+                learning_rate=norm_lr,
+                regularizer=L2Decay(norm_decay), )
+            bias_attr = ParamAttr(
+                learning_rate=norm_lr, regularizer=L2Decay(norm_decay))
+            global_stats = True if freeze_norm else False
+            if norm_type in ['bn', 'sync_bn']:
+                self.norm = nn.BatchNorm(
+                    ch_out,
+                    param_attr=param_attr,
+                    bias_attr=bias_attr,
+                    use_global_stats=global_stats, )
+            elif norm_type == 'gn':
+                self.norm = nn.GroupNorm(
+                    num_groups=norm_groups,
+                    num_channels=ch_out,
+                    weight_attr=param_attr,
+                    bias_attr=bias_attr)
+            norm_params = self.norm.parameters()
+            if freeze_norm:
+                for param in norm_params:
+                    param.stop_gradient = True
+            conv_bias_attr = False
+        else:
+            conv_bias_attr = True
+            self.norm = None
+
+        self.conv = nn.Conv2D(
+            in_channels=ch_in,
+            out_channels=ch_out,
+            kernel_size=filter_size,
+            stride=stride,
+            padding=(filter_size - 1) // 2,
+            groups=groups,
+            weight_attr=ParamAttr(initializer=Normal(
+                mean=0., std=0.001)),
+            bias_attr=conv_bias_attr)
+
+    def forward(self, inputs):
+        out = self.conv(inputs)
+        if self.norm is not None:
+            out = self.norm(out)
+
+        if self.act == 'relu':
+            out = F.relu(out)
+        elif self.act == 'sigmoid':
+            out = F.sigmoid(out)
+        return out
+
+
+class DepthWiseSeparableConvNormLayer(nn.Layer):
+    def __init__(self,
+                 ch_in,
+                 ch_out,
+                 filter_size,
+                 stride=1,
+                 dw_norm_type=None,
+                 pw_norm_type=None,
+                 norm_decay=0.,
+                 freeze_norm=False,
+                 dw_act=None,
+                 pw_act=None):
+        super(DepthWiseSeparableConvNormLayer, self).__init__()
+        self.depthwise_conv = ConvNormLayer(
+            ch_in=ch_in,
+            ch_out=ch_in,
+            filter_size=filter_size,
+            stride=stride,
+            groups=ch_in,
+            norm_type=dw_norm_type,
+            act=dw_act,
+            norm_decay=norm_decay,
+            freeze_norm=freeze_norm, )
+        self.pointwise_conv = ConvNormLayer(
+            ch_in=ch_in,
+            ch_out=ch_out,
+            filter_size=1,
+            stride=1,
+            norm_type=pw_norm_type,
+            act=pw_act,
+            norm_decay=norm_decay,
+            freeze_norm=freeze_norm, )
+
+    def forward(self, x):
+        x = self.depthwise_conv(x)
+        x = self.pointwise_conv(x)
+        return x
+
+
+class CrossResolutionWeightingModule(nn.Layer):
+    def __init__(self,
+                 channels,
+                 ratio=16,
+                 norm_type='bn',
+                 freeze_norm=False,
+                 norm_decay=0.):
+        super(CrossResolutionWeightingModule, self).__init__()
+        self.channels = channels
+        total_channel = sum(channels)
+        self.conv1 = ConvNormLayer(
+            ch_in=total_channel,
+            ch_out=total_channel // ratio,
+            filter_size=1,
+            stride=1,
+            norm_type=norm_type,
+            act='relu',
+            freeze_norm=freeze_norm,
+            norm_decay=norm_decay)
+        self.conv2 = ConvNormLayer(
+            ch_in=total_channel // ratio,
+            ch_out=total_channel,
+            filter_size=1,
+            stride=1,
+            norm_type=norm_type,
+            act='sigmoid',
+            freeze_norm=freeze_norm,
+            norm_decay=norm_decay)
+
+    def forward(self, x):
+        mini_size = x[-1].shape[-2:]
+        out = [F.adaptive_avg_pool2d(s, mini_size) for s in x[:-1]] + [x[-1]]
+        out = paddle.concat(out, 1)
+        out = self.conv1(out)
+        out = self.conv2(out)
+        out = paddle.split(out, self.channels, 1)
+        out = [
+            s * F.interpolate(
+                a, s.shape[-2:], mode='nearest') for s, a in zip(x, out)
+        ]
+        return out
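CrossResolutionWeightingModule above is the core of conditional channel
weighting: every branch is average-pooled down to the smallest branch's
resolution, squeezed through two 1x1 convs (relu, then sigmoid), split back
per branch, upsampled, and multiplied into the original features. A
stripped-down functional sketch of that dataflow, with raw conv weights in
place of ConvNormLayer and made-up channel counts (ratio 8, matching the
reduce_ratios in the lite_18/lite_30 configs):

    import paddle
    import paddle.nn.functional as F

    # two branches, e.g. 40 channels at 64x48 and 80 channels at 32x24
    x = [paddle.randn([1, 40, 64, 48]), paddle.randn([1, 80, 32, 24])]

    mini_size = x[-1].shape[-2:]  # the smallest resolution, 32x24
    out = [F.adaptive_avg_pool2d(s, mini_size) for s in x[:-1]] + [x[-1]]
    out = paddle.concat(out, axis=1)  # 120 channels in total

    w1 = paddle.randn([120 // 8, 120, 1, 1])  # stand-in for conv1
    w2 = paddle.randn([120, 120 // 8, 1, 1])  # stand-in for conv2
    out = F.sigmoid(F.conv2d(F.relu(F.conv2d(out, w1)), w2))

    weights = paddle.split(out, [40, 80], axis=1)
    out = [s * F.interpolate(a, s.shape[-2:], mode='nearest')
           for s, a in zip(x, weights)]
    assert [list(o.shape) for o in out] == [[1, 40, 64, 48], [1, 80, 32, 24]]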
+
+
+class SpatialWeightingModule(nn.Layer):
+    def __init__(self, in_channel, ratio=16, freeze_norm=False, norm_decay=0.):
+        super(SpatialWeightingModule, self).__init__()
+        self.global_avgpooling = nn.AdaptiveAvgPool2D(1)
+        self.conv1 = ConvNormLayer(
+            ch_in=in_channel,
+            ch_out=in_channel // ratio,
+            filter_size=1,
+            stride=1,
+            act='relu',
+            freeze_norm=freeze_norm,
+            norm_decay=norm_decay)
+        self.conv2 = ConvNormLayer(
+            ch_in=in_channel // ratio,
+            ch_out=in_channel,
+            filter_size=1,
+            stride=1,
+            act='sigmoid',
+            freeze_norm=freeze_norm,
+            norm_decay=norm_decay)
+
+    def forward(self, x):
+        out = self.global_avgpooling(x)
+        out = self.conv1(out)
+        out = self.conv2(out)
+        return x * out
+
+
+class ConditionalChannelWeightingBlock(nn.Layer):
+    def __init__(self,
+                 in_channels,
+                 stride,
+                 reduce_ratio,
+                 norm_type='bn',
+                 freeze_norm=False,
+                 norm_decay=0.):
+        super(ConditionalChannelWeightingBlock, self).__init__()
+        assert stride in [1, 2]
+        branch_channels = [channel // 2 for channel in in_channels]
+
+        self.cross_resolution_weighting = CrossResolutionWeightingModule(
+            branch_channels,
+            ratio=reduce_ratio,
+            norm_type=norm_type,
+            freeze_norm=freeze_norm,
+            norm_decay=norm_decay)
+        self.depthwise_convs = nn.LayerList([
+            ConvNormLayer(
+                channel,
+                channel,
+                filter_size=3,
+                stride=stride,
+                groups=channel,
+                norm_type=norm_type,
+                freeze_norm=freeze_norm,
+                norm_decay=norm_decay) for channel in branch_channels
+        ])
+
+        self.spatial_weighting = nn.LayerList([
+            SpatialWeightingModule(
+                channel,
+                ratio=4,
+                freeze_norm=freeze_norm,
+                norm_decay=norm_decay) for channel in branch_channels
+        ])
+
+    def forward(self, x):
+        x = [s.chunk(2, axis=1) for s in x]
+        x1 = [s[0] for s in x]
+        x2 = [s[1] for s in x]
+
+        x2 = self.cross_resolution_weighting(x2)
+        x2 = [dw(s) for s, dw in zip(x2, self.depthwise_convs)]
+        x2 = [sw(s) for s, sw in zip(x2, self.spatial_weighting)]
+
+        out = [paddle.concat([s1, s2], axis=1) for s1, s2 in zip(x1, x2)]
+        out = [channel_shuffle(s, groups=2) for s in out]
+        return out
+
+
+class ShuffleUnit(nn.Layer):
+    def __init__(self,
+                 in_channel,
+                 out_channel,
+                 stride,
+                 norm_type='bn',
+                 freeze_norm=False,
+                 norm_decay=0.):
+        super(ShuffleUnit, self).__init__()
+        branch_channel = out_channel // 2
+        # fixed: assign the argument to the attribute (was `stride = self.stride`,
+        # which reads an attribute that does not exist yet)
+        self.stride = stride
+        if self.stride == 1:
+            assert in_channel == branch_channel * 2, \
+                "when stride=1, in_channel {} should equal to branch_channel*2 {}".format(
+                    in_channel, branch_channel * 2)
+        if stride > 1:
+            self.branch1 = nn.Sequential(
+                ConvNormLayer(
+                    ch_in=in_channel,
+                    ch_out=in_channel,
+                    filter_size=3,
+                    stride=self.stride,
+                    groups=in_channel,
+                    norm_type=norm_type,
+                    freeze_norm=freeze_norm,
+                    norm_decay=norm_decay),
+                ConvNormLayer(
+                    ch_in=in_channel,
+                    ch_out=branch_channel,
+                    filter_size=1,
+                    stride=1,
+                    norm_type=norm_type,
+                    act='relu',
+                    freeze_norm=freeze_norm,
+                    norm_decay=norm_decay), )
+        self.branch2 = nn.Sequential(
+            ConvNormLayer(
+                ch_in=branch_channel if stride == 1 else in_channel,
+                ch_out=branch_channel,
+                filter_size=1,
+                stride=1,
+                norm_type=norm_type,
+                act='relu',
+                freeze_norm=freeze_norm,
+                norm_decay=norm_decay),
+            ConvNormLayer(
+                ch_in=branch_channel,
+                ch_out=branch_channel,
+                filter_size=3,
+                stride=self.stride,
+                groups=branch_channel,
+                norm_type=norm_type,
+                freeze_norm=freeze_norm,
+                norm_decay=norm_decay),
+            ConvNormLayer(
+                ch_in=branch_channel,
+                ch_out=branch_channel,
+                filter_size=1,
+                stride=1,
+                norm_type=norm_type,
+                act='relu',
+                freeze_norm=freeze_norm,
+                norm_decay=norm_decay), )
+
+    def forward(self, x):
+        if self.stride > 1:
+            x1 = self.branch1(x)
+            x2 = self.branch2(x)
+        else:
+            x1, x2 = x.chunk(2, axis=1)
+            x2 = self.branch2(x2)
+        out = paddle.concat([x1, x2], axis=1)
+        out = channel_shuffle(out, groups=2)
+        return out
+
+
+class IterativeHead(nn.Layer):
+    def __init__(self,
+                 in_channels,
+                 norm_type='bn',
+                 freeze_norm=False,
+                 norm_decay=0.):
+        super(IterativeHead, self).__init__()
+        num_branches = len(in_channels)
+        self.in_channels = in_channels[::-1]
+
+        projects = []
+        for i in range(num_branches):
+            if i != num_branches - 1:
+                projects.append(
+                    DepthWiseSeparableConvNormLayer(
+                        ch_in=self.in_channels[i],
+                        ch_out=self.in_channels[i + 1],
+                        filter_size=3,
+                        stride=1,
+                        dw_act=None,
+                        pw_act='relu',
+                        dw_norm_type=norm_type,
+                        pw_norm_type=norm_type,
+                        freeze_norm=freeze_norm,
+                        norm_decay=norm_decay))
+            else:
+                projects.append(
+                    DepthWiseSeparableConvNormLayer(
+                        ch_in=self.in_channels[i],
+                        ch_out=self.in_channels[i],
+                        filter_size=3,
+                        stride=1,
+                        dw_act=None,
+                        pw_act='relu',
+                        dw_norm_type=norm_type,
+                        pw_norm_type=norm_type,
+                        freeze_norm=freeze_norm,
+                        norm_decay=norm_decay))
+        self.projects = nn.LayerList(projects)
+
+    def forward(self, x):
+        x = x[::-1]
+        y = []
+        last_x = None
+        for i, s in enumerate(x):
+            if last_x is not None:
+                last_x = F.interpolate(
+                    last_x,
+                    size=s.shape[-2:],
+                    mode='bilinear',
+                    align_corners=True)
+                s = s + last_x
+            s = self.projects[i](s)
+            y.append(s)
+            last_x = s
+
+        return y[::-1]
+
+
+class Stem(nn.Layer):
+    def __init__(self,
+                 in_channel,
+                 stem_channel,
+                 out_channel,
+                 expand_ratio,
+                 norm_type='bn',
+                 freeze_norm=False,
+                 norm_decay=0.):
+        super(Stem, self).__init__()
+        self.conv1 = ConvNormLayer(
+            in_channel,
+            stem_channel,
+            filter_size=3,
+            stride=2,
+            norm_type=norm_type,
+            act='relu',
+            freeze_norm=freeze_norm,
+            norm_decay=norm_decay)
+        mid_channel = int(round(stem_channel * expand_ratio))
+        branch_channel = stem_channel // 2
+        if stem_channel == out_channel:
+            inc_channel = out_channel - branch_channel
+        else:
+            inc_channel = out_channel - stem_channel
+        self.branch1 = nn.Sequential(
+            ConvNormLayer(
+                ch_in=branch_channel,
+                ch_out=branch_channel,
+                filter_size=3,
+                stride=2,
+                groups=branch_channel,
+                norm_type=norm_type,
+                freeze_norm=freeze_norm,
+                norm_decay=norm_decay),
+            ConvNormLayer(
+                ch_in=branch_channel,
+                ch_out=inc_channel,
+                filter_size=1,
+                stride=1,
+                norm_type=norm_type,
+                act='relu',
+                freeze_norm=freeze_norm,
+                norm_decay=norm_decay), )
+        self.expand_conv = ConvNormLayer(
+            ch_in=branch_channel,
+            ch_out=mid_channel,
+            filter_size=1,
+            stride=1,
+            norm_type=norm_type,
+            act='relu',
+            freeze_norm=freeze_norm,
+            norm_decay=norm_decay)
+        self.depthwise_conv = ConvNormLayer(
+            ch_in=mid_channel,
+            ch_out=mid_channel,
+            filter_size=3,
+            stride=2,
+            groups=mid_channel,
+            norm_type=norm_type,
+            freeze_norm=freeze_norm,
+            norm_decay=norm_decay)
+        self.linear_conv = ConvNormLayer(
+            ch_in=mid_channel,
+            ch_out=branch_channel
+            if stem_channel == out_channel else stem_channel,
+            filter_size=1,
+            stride=1,
+            norm_type=norm_type,
+            act='relu',
+            freeze_norm=freeze_norm,
+            norm_decay=norm_decay)
+
+    def forward(self, x):
+        x = self.conv1(x)
+        x1, x2 = x.chunk(2, axis=1)
+        x1 = self.branch1(x1)
+        x2 = self.expand_conv(x2)
+        x2 = self.depthwise_conv(x2)
+        x2 = self.linear_conv(x2)
+        out = paddle.concat([x1, x2], axis=1)
+        out = channel_shuffle(out, groups=2)
+
+        return out
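The stem halves the resolution twice (the stride-2 3x3 conv, then a stride-2
depthwise conv on each half of the channel split), so the first stage already
runs at stride 4. A quick smoke-test sketch, assuming the patch is applied and
using the 256x192 input size from the configs:

    import paddle
    from ppdet.modeling.backbones.lite_hrnet import Stem

    stem = Stem(in_channel=3, stem_channel=32, out_channel=32, expand_ratio=1)
    y = stem(paddle.randn([1, 3, 256, 192]))
    print(y.shape)  # expected [1, 32, 64, 48]: 256x192 -> 128x96 -> 64x48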
+
+
+class LiteHRNetModule(nn.Layer):
+    def __init__(self,
+                 num_branches,
+                 num_blocks,
+                 in_channels,
+                 reduce_ratio,
+                 module_type,
+                 multiscale_output=False,
+                 with_fuse=True,
+                 norm_type='bn',
+                 freeze_norm=False,
+                 norm_decay=0.):
+        super(LiteHRNetModule, self).__init__()
+        assert num_branches == len(in_channels), \
+            "num_branches {} should equal to num_in_channels {}".format(
+                num_branches, len(in_channels))
+        assert module_type in ['LITE', 'NAIVE'], \
+            "module_type should be one of ['LITE', 'NAIVE']"
+        self.num_branches = num_branches
+        self.in_channels = in_channels
+        self.multiscale_output = multiscale_output
+        self.with_fuse = with_fuse
+        self.norm_type = norm_type
+        self.module_type = module_type
+
+        if self.module_type == 'LITE':
+            self.layers = self._make_weighting_blocks(
+                num_blocks,
+                reduce_ratio,
+                freeze_norm=freeze_norm,
+                norm_decay=norm_decay)
+        elif self.module_type == 'NAIVE':
+            self.layers = self._make_naive_branches(
+                num_branches,
+                num_blocks,
+                freeze_norm=freeze_norm,
+                norm_decay=norm_decay)
+
+        if self.with_fuse:
+            self.fuse_layers = self._make_fuse_layers(
+                freeze_norm=freeze_norm, norm_decay=norm_decay)
+            self.relu = nn.ReLU()
+
+    def _make_weighting_blocks(self,
+                               num_blocks,
+                               reduce_ratio,
+                               stride=1,
+                               freeze_norm=False,
+                               norm_decay=0.):
+        layers = []
+        for i in range(num_blocks):
+            layers.append(
+                ConditionalChannelWeightingBlock(
+                    self.in_channels,
+                    stride=stride,
+                    reduce_ratio=reduce_ratio,
+                    norm_type=self.norm_type,
+                    freeze_norm=freeze_norm,
+                    norm_decay=norm_decay))
+        return nn.Sequential(*layers)
+
+    def _make_naive_branches(self,
+                             num_branches,
+                             num_blocks,
+                             freeze_norm=False,
+                             norm_decay=0.):
+        branches = []
+        for branch_idx in range(num_branches):
+            layers = []
+            for i in range(num_blocks):
+                layers.append(
+                    ShuffleUnit(
+                        self.in_channels[branch_idx],
+                        self.in_channels[branch_idx],
+                        stride=1,
+                        norm_type=self.norm_type,
+                        freeze_norm=freeze_norm,
+                        norm_decay=norm_decay))
+            branches.append(nn.Sequential(*layers))
+        return nn.LayerList(branches)
+
+    def _make_fuse_layers(self, freeze_norm=False, norm_decay=0.):
+        if self.num_branches == 1:
+            return None
+        fuse_layers = []
+        num_out_branches = self.num_branches if self.multiscale_output else 1
+        for i in range(num_out_branches):
+            fuse_layer = []
+            for j in range(self.num_branches):
+                if j > i:
+                    fuse_layer.append(
+                        nn.Sequential(
+                            L.Conv2d(
+                                self.in_channels[j],
+                                self.in_channels[i],
+                                kernel_size=1,
+                                stride=1,
+                                padding=0,
+                                bias=False, ),
+                            nn.BatchNorm(self.in_channels[i]),
+                            nn.Upsample(
+                                scale_factor=2**(j - i), mode='nearest')))
+                elif j == i:
+                    fuse_layer.append(None)
+                else:
+                    conv_downsamples = []
+                    for k in range(i - j):
+                        if k == i - j - 1:
+                            conv_downsamples.append(
+                                nn.Sequential(
+                                    L.Conv2d(
+                                        self.in_channels[j],
+                                        self.in_channels[j],
+                                        kernel_size=3,
+                                        stride=2,
+                                        padding=1,
+                                        groups=self.in_channels[j],
+                                        bias=False, ),
+                                    nn.BatchNorm(self.in_channels[j]),
+                                    L.Conv2d(
+                                        self.in_channels[j],
+                                        self.in_channels[i],
+                                        kernel_size=1,
+                                        stride=1,
+                                        padding=0,
+                                        bias=False, ),
+                                    nn.BatchNorm(self.in_channels[i])))
+                        else:
+                            conv_downsamples.append(
+                                nn.Sequential(
+                                    L.Conv2d(
+                                        self.in_channels[j],
+                                        self.in_channels[j],
+                                        kernel_size=3,
+                                        stride=2,
+                                        padding=1,
+                                        groups=self.in_channels[j],
+                                        bias=False, ),
+                                    nn.BatchNorm(self.in_channels[j]),
+                                    L.Conv2d(
+                                        self.in_channels[j],
+                                        self.in_channels[j],
+                                        kernel_size=1,
+                                        stride=1,
+                                        padding=0,
+                                        bias=False, ),
+                                    nn.BatchNorm(self.in_channels[j]),
+                                    nn.ReLU()))
+
+                    fuse_layer.append(nn.Sequential(*conv_downsamples))
+            fuse_layers.append(nn.LayerList(fuse_layer))
+
+        return nn.LayerList(fuse_layers)
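The fuse layers follow HRNet's bookkeeping: branch j runs at stride 4 * 2**j,
so feeding branch j into output branch i takes a 2**(j - i) nearest-neighbor
upsample when j > i and a chain of (i - j) stride-2 depthwise-separable
downsamples when j < i, exactly as built above. A tiny sketch of that rule
(the helper is illustrative only, not part of the patch):

    def fuse_plan(num_branches):
        plan = {}
        for i in range(num_branches):      # output branch
            for j in range(num_branches):  # input branch
                if j > i:
                    plan[(j, i)] = 'upsample x{}'.format(2**(j - i))
                elif j < i:
                    plan[(j, i)] = '{} stride-2 downsample(s)'.format(i - j)
        return plan

    print(fuse_plan(3)[(2, 0)])  # upsample x4
    print(fuse_plan(3)[(0, 2)])  # 2 stride-2 downsample(s)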
+
+    def forward(self, x):
+        if self.num_branches == 1:
+            return [self.layers[0](x[0])]
+        if self.module_type == 'LITE':
+            out = self.layers(x)
+        elif self.module_type == 'NAIVE':
+            for i in range(self.num_branches):
+                # fixed: index the per-branch Sequential; a LayerList is not callable
+                x[i] = self.layers[i](x[i])
+            out = x
+        if self.with_fuse:
+            out_fuse = []
+            for i in range(len(self.fuse_layers)):
+                y = out[0] if i == 0 else self.fuse_layers[i][0](out[0])
+                for j in range(self.num_branches):
+                    if i == j:
+                        y += out[j]
+                    else:
+                        y += self.fuse_layers[i][j](out[j])
+                if i == 0:
+                    out[i] = y
+                out_fuse.append(self.relu(y))
+            out = out_fuse
+        elif not self.multiscale_output:
+            out = [out[0]]
+        return out
+
+
+@register
+class LiteHRNet(nn.Layer):
+    """
+    @inproceedings{Yulitehrnet21,
+      title={Lite-HRNet: A Lightweight High-Resolution Network},
+      author={Yu, Changqian and Xiao, Bin and Gao, Changxin and Yuan, Lu and Zhang, Lei and Sang, Nong and Wang, Jingdong},
+      booktitle={CVPR},
+      year={2021}
+    }
+
+    Args:
+        network_type (str): the network_type should be one of ["lite_18", "lite_30", "naive", "wider_naive"],
+            "naive": simply combining the shuffle block in ShuffleNet and the high-resolution design pattern in HRNet.
+            "wider_naive": naive network with wider channels in each block.
+            "lite_18": Lite-HRNet-18, which replaces the pointwise convolution in a shuffle block by conditional channel weighting.
+            "lite_30": Lite-HRNet-30, with more blocks compared with Lite-HRNet-18.
+        freeze_at (int): the stage to freeze
+        freeze_norm (bool): whether to freeze norm in HRNet
+        norm_decay (float): weight decay for normalization layer weights
+        return_idx (List): the stage to return
+    """
+
+    def __init__(self,
+                 network_type,
+                 freeze_at=0,
+                 freeze_norm=True,
+                 norm_decay=0.,
+                 return_idx=[0, 1, 2, 3]):
+        super(LiteHRNet, self).__init__()
+        if isinstance(return_idx, Integral):
+            return_idx = [return_idx]
+        assert network_type in ["lite_18", "lite_30", "naive", "wider_naive"], \
+            "the network_type should be one of [lite_18, lite_30, naive, wider_naive]"
+        assert len(return_idx) > 0, "need one or more return index"
+        self.freeze_at = freeze_at
+        self.freeze_norm = freeze_norm
+        self.norm_decay = norm_decay
+        self.return_idx = return_idx
+        self.norm_type = 'bn'
+
+        self.module_configs = {
+            "lite_18": {
+                "num_modules": [2, 4, 2],
+                "num_branches": [2, 3, 4],
+                "num_blocks": [2, 2, 2],
+                "module_type": ["LITE", "LITE", "LITE"],
+                "reduce_ratios": [8, 8, 8],
+                "num_channels": [[40, 80], [40, 80, 160], [40, 80, 160, 320]],
+            },
+            "lite_30": {
+                "num_modules": [3, 8, 3],
+                "num_branches": [2, 3, 4],
+                "num_blocks": [2, 2, 2],
+                "module_type": ["LITE", "LITE", "LITE"],
+                "reduce_ratios": [8, 8, 8],
+                "num_channels": [[40, 80], [40, 80, 160], [40, 80, 160, 320]],
+            },
+            "naive": {
+                "num_modules": [2, 4, 2],
+                "num_branches": [2, 3, 4],
+                "num_blocks": [2, 2, 2],
+                "module_type": ["NAIVE", "NAIVE", "NAIVE"],
+                "reduce_ratios": [1, 1, 1],
+                "num_channels": [[30, 60], [30, 60, 120], [30, 60, 120, 240]],
+            },
+            "wider_naive": {
+                "num_modules": [2, 4, 2],
+                "num_branches": [2, 3, 4],
+                "num_blocks": [2, 2, 2],
+                "module_type": ["NAIVE", "NAIVE", "NAIVE"],
+                "reduce_ratios": [1, 1, 1],
+                "num_channels": [[40, 80], [40, 80, 160], [40, 80, 160, 320]],
+            },
+        }
+
+        self.stages_config = self.module_configs[network_type]
+
+        self.stem = Stem(3, 32, 32, 1)
+        num_channels_pre_layer = [32]
+        for stage_idx in range(3):
+            num_channels = self.stages_config["num_channels"][stage_idx]
+            setattr(self, 'transition{}'.format(stage_idx),
+                    self._make_transition_layer(num_channels_pre_layer,
+                                                num_channels, self.freeze_norm,
+                                                self.norm_decay))
+            stage, num_channels_pre_layer = self._make_stage(
+                self.stages_config, stage_idx, num_channels, True,
+                self.freeze_norm, self.norm_decay)
+            setattr(self, 'stage{}'.format(stage_idx), stage)
+        self.head_layer = IterativeHead(num_channels_pre_layer, 'bn',
+                                        self.freeze_norm, self.norm_decay)
+        # bookkeeping for out_shape: the iterative head maps branch channels
+        # [c0, c1, c2, c3] -> [c0, c0, c1, c2], and branch i runs at stride
+        # 4 * 2**i (stem is stride 4, each extra branch halves the resolution)
+        self._out_channels = [num_channels_pre_layer[0]
+                              ] + num_channels_pre_layer[:-1]
+        self._out_strides = [
+            4 * (2**i) for i in range(len(num_channels_pre_layer))
+        ]
+
+    def _make_transition_layer(self,
+                               num_channels_pre_layer,
+                               num_channels_cur_layer,
+                               freeze_norm=False,
+                               norm_decay=0.):
+        num_branches_pre = len(num_channels_pre_layer)
+        num_branches_cur = len(num_channels_cur_layer)
+        transition_layers = []
+        for i in range(num_branches_cur):
+            if i < num_branches_pre:
+                if num_channels_cur_layer[i] != num_channels_pre_layer[i]:
+                    transition_layers.append(
+                        nn.Sequential(
+                            L.Conv2d(
+                                num_channels_pre_layer[i],
+                                num_channels_pre_layer[i],
+                                kernel_size=3,
+                                stride=1,
+                                padding=1,
+                                groups=num_channels_pre_layer[i],
+                                bias=False),
+                            nn.BatchNorm(num_channels_pre_layer[i]),
+                            L.Conv2d(
+                                num_channels_pre_layer[i],
+                                num_channels_cur_layer[i],
+                                kernel_size=1,
+                                stride=1,
+                                padding=0,
+                                bias=False, ),
+                            nn.BatchNorm(num_channels_cur_layer[i]),
+                            nn.ReLU()))
+                else:
+                    transition_layers.append(None)
+            else:
+                conv_downsamples = []
+                for j in range(i + 1 - num_branches_pre):
+                    conv_downsamples.append(
+                        nn.Sequential(
+                            L.Conv2d(
+                                num_channels_pre_layer[-1],
+                                num_channels_pre_layer[-1],
+                                groups=num_channels_pre_layer[-1],
+                                kernel_size=3,
+                                stride=2,
+                                padding=1,
+                                bias=False, ),
+                            nn.BatchNorm(num_channels_pre_layer[-1]),
+                            L.Conv2d(
+                                num_channels_pre_layer[-1],
+                                num_channels_cur_layer[i]
+                                if j == i - num_branches_pre else
+                                num_channels_pre_layer[-1],
+                                kernel_size=1,
+                                stride=1,
+                                padding=0,
+                                bias=False, ),
+                            nn.BatchNorm(num_channels_cur_layer[i]
+                                         if j == i - num_branches_pre else
+                                         num_channels_pre_layer[-1]),
+                            nn.ReLU()))
+                transition_layers.append(nn.Sequential(*conv_downsamples))
+        return nn.LayerList(transition_layers)
+
+    def _make_stage(self,
+                    stages_config,
+                    stage_idx,
+                    in_channels,
+                    multiscale_output,
+                    freeze_norm=False,
+                    norm_decay=0.):
+        num_modules = stages_config["num_modules"][stage_idx]
+        num_branches = stages_config["num_branches"][stage_idx]
+        num_blocks = stages_config["num_blocks"][stage_idx]
+        reduce_ratio = stages_config['reduce_ratios'][stage_idx]
+        module_type = stages_config['module_type'][stage_idx]
+
+        modules = []
+        for i in range(num_modules):
+            if not multiscale_output and i == num_modules - 1:
+                reset_multiscale_output = False
+            else:
+                reset_multiscale_output = True
+            modules.append(
+                LiteHRNetModule(
+                    num_branches,
+                    num_blocks,
+                    in_channels,
+                    reduce_ratio,
+                    module_type,
+                    multiscale_output=reset_multiscale_output,
+                    with_fuse=True,
+                    freeze_norm=freeze_norm,
+                    norm_decay=norm_decay))
+            in_channels = modules[-1].in_channels
+        return nn.Sequential(*modules), in_channels
+
+    def forward(self, inputs):
+        x = inputs['image']
+        x = self.stem(x)
+        y_list = [x]
+        for stage_idx in range(3):
+            x_list = []
+            transition = getattr(self, 'transition{}'.format(stage_idx))
+            for j in range(self.stages_config["num_branches"][stage_idx]):
+                if transition[j] is not None:
+                    if j >= len(y_list):
+                        x_list.append(transition[j](y_list[-1]))
+                    else:
+                        x_list.append(transition[j](y_list[j]))
+                else:
+                    x_list.append(y_list[j])
+            y_list = getattr(self, 'stage{}'.format(stage_idx))(x_list)
+        x = self.head_layer(y_list)
+        res = []
+        for i, layer in enumerate(x):
+            if i == self.freeze_at:
+                layer.stop_gradient = True
+            if i in self.return_idx:
+                res.append(layer)
+        return res
+
+    @property
+    def out_shape(self):
+        return [
+            ShapeSpec(
+                channels=self._out_channels[i], stride=self._out_strides[i])
+            for i in self.return_idx
+        ]
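As a smoke test for the backbone as registered here (a sketch assuming the
patch is applied; the expected shape follows from the lite_18 channel table
and the stride-4 stem):

    import paddle
    from ppdet.modeling.backbones.lite_hrnet import LiteHRNet

    model = LiteHRNet(
        network_type='lite_18', freeze_at=-1, freeze_norm=False, return_idx=[0])
    feats = model({'image': paddle.randn([1, 3, 256, 192])})
    print([list(f.shape) for f in feats])  # expected: [[1, 40, 64, 48]]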
diff --git a/ppdet/modeling/backbones/shufflenet_v2.py b/ppdet/modeling/backbones/shufflenet_v2.py
index 75cd6e38d..cd7ea2d9b 100644
--- a/ppdet/modeling/backbones/shufflenet_v2.py
+++ b/ppdet/modeling/backbones/shufflenet_v2.py
@@ -25,26 +25,11 @@ from paddle.nn.initializer import KaimingNormal
 from ppdet.core.workspace import register, serializable
 from numbers import Integral
 from ..shape_spec import ShapeSpec
+from ppdet.modeling.ops import channel_shuffle
 
 __all__ = ['ShuffleNetV2']
 
-
-def channel_shuffle(x, groups):
-    batch_size, num_channels, height, width = x.shape[0:4]
-    channels_per_group = num_channels // groups
-
-    # reshape
-    x = paddle.reshape(
-        x=x, shape=[batch_size, groups, channels_per_group, height, width])
-
-    # transpose
-    x = paddle.transpose(x=x, perm=[0, 2, 1, 3, 4])
-
-    # flatten
-    x = paddle.reshape(x=x, shape=[batch_size, num_channels, height, width])
-    return x
-
-
 class ConvBNLayer(nn.Layer):
     def __init__(self,
                  in_channels,
diff --git a/ppdet/modeling/losses/keypoint_loss.py b/ppdet/modeling/losses/keypoint_loss.py
index 62b5a54a2..9c3c113db 100644
--- a/ppdet/modeling/losses/keypoint_loss.py
+++ b/ppdet/modeling/losses/keypoint_loss.py
@@ -29,7 +29,7 @@ __all__ = ['HrHRNetLoss', 'KeyPointMSELoss']
 @register
 @serializable
 class KeyPointMSELoss(nn.Layer):
-    def __init__(self, use_target_weight=True):
+    def __init__(self, use_target_weight=True, loss_scale=0.5):
         """
         KeyPointMSELoss layer
 
@@ -39,6 +39,7 @@ class KeyPointMSELoss(nn.Layer):
         super(KeyPointMSELoss, self).__init__()
         self.criterion = nn.MSELoss(reduction='mean')
         self.use_target_weight = use_target_weight
+        self.loss_scale = loss_scale
 
     def forward(self, output, records):
         target = records['target']
@@ -50,16 +51,16 @@ class KeyPointMSELoss(nn.Layer):
         heatmaps_gt = target.reshape(
             (batch_size, num_joints, -1)).split(num_joints, 1)
         loss = 0
-
         for idx in range(num_joints):
             heatmap_pred = heatmaps_pred[idx].squeeze()
             heatmap_gt = heatmaps_gt[idx].squeeze()
             if self.use_target_weight:
-                loss += 0.5 * self.criterion(
+                loss += self.loss_scale * self.criterion(
                     heatmap_pred.multiply(target_weight[:, idx]),
                     heatmap_gt.multiply(target_weight[:, idx]))
             else:
-                loss += 0.5 * self.criterion(heatmap_pred, heatmap_gt)
+                loss += self.loss_scale * self.criterion(heatmap_pred,
+                                                         heatmap_gt)
         keypoint_losses = dict()
         keypoint_losses['loss'] = loss / num_joints
         return keypoint_losses
diff --git a/ppdet/modeling/ops.py b/ppdet/modeling/ops.py
index 6fa4d8ad3..ba4321e6b 100644
--- a/ppdet/modeling/ops.py
+++ b/ppdet/modeling/ops.py
@@ -1588,3 +1588,15 @@ def smooth_l1(input, label, inside_weight=None, outside_weight=None,
     out = paddle.reshape(out, shape=[out.shape[0], -1])
     out = paddle.sum(out, axis=1)
     return out
+
+
+def channel_shuffle(x, groups):
+    batch_size, num_channels, height, width = x.shape[0:4]
+    # a plain assert, not a tuple (a tuple would always evaluate truthy)
+    assert num_channels % groups == 0, \
+        'num_channels should be divisible by groups'
+    channels_per_group = num_channels // groups
+    x = paddle.reshape(
+        x=x, shape=[batch_size, groups, channels_per_group, height, width])
+    x = paddle.transpose(x=x, perm=[0, 2, 1, 3, 4])
+    x = paddle.reshape(x=x, shape=[batch_size, num_channels, height, width])
+    return x
-- 
GitLab
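For reference, the channel_shuffle moved into ops.py interleaves channels
across groups; a tiny check of the expected permutation:

    import paddle
    from ppdet.modeling.ops import channel_shuffle

    # channels [0, 1, 2, 3, 4, 5] with groups=2 become [0, 3, 1, 4, 2, 5]
    x = paddle.arange(6, dtype='float32').reshape([1, 6, 1, 1])
    y = channel_shuffle(x, groups=2)
    print(y.flatten().tolist())  # [0.0, 3.0, 1.0, 4.0, 2.0, 5.0]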