From 3fc967b6138fc0e9793a49ee8d9888d366a1fddd Mon Sep 17 00:00:00 2001 From: Guanghua Yu <742925032@qq.com> Date: Fri, 11 Feb 2022 10:51:10 +0800 Subject: [PATCH] add a new version code of PicoDet (#5170) * add a new version code of PicoDet --- configs/picodet/_base_/picodet_esnetv2.yml | 64 ++++ .../picodet/_base_/picodet_v2_416_reader.yml | 41 ++ .../picodet_v2_lcnet_1_5x_416_coco.yml | 34 ++ configs/picodet/picodet_v2_s_416_coco.yml | 46 +++ ppdet/engine/export_utils.py | 9 +- ppdet/modeling/heads/pico_head.py | 354 +++++++++++++++++- ppdet/modeling/necks/__init__.py | 2 + ppdet/modeling/necks/csp_pan.py | 35 +- ppdet/modeling/necks/es_pan.py | 212 +++++++++++ 9 files changed, 761 insertions(+), 36 deletions(-) create mode 100644 configs/picodet/_base_/picodet_esnetv2.yml create mode 100644 configs/picodet/_base_/picodet_v2_416_reader.yml create mode 100644 configs/picodet/more_config/picodet_v2_lcnet_1_5x_416_coco.yml create mode 100644 configs/picodet/picodet_v2_s_416_coco.yml create mode 100644 ppdet/modeling/necks/es_pan.py diff --git a/configs/picodet/_base_/picodet_esnetv2.yml b/configs/picodet/_base_/picodet_esnetv2.yml new file mode 100644 index 000000000..8e16d5298 --- /dev/null +++ b/configs/picodet/_base_/picodet_esnetv2.yml @@ -0,0 +1,64 @@ +architecture: PicoDet +pretrain_weights: https://paddledet.bj.bcebos.com/models/pretrained/ESNet_x1_0_pretrained.pdparams +export_post_process: False # Whether post-processing is included in the network when export model. 
+ +PicoDet: + backbone: ESNet + neck: ESPAN + head: PicoHeadV2 + +ESNet: + scale: 1.0 + feature_maps: [4, 11, 14] + act: hard_swish + channel_ratio: [0.875, 0.5, 1.0, 0.625, 0.5, 0.75, 0.625, 0.625, 0.5, 0.625, 1.0, 0.625, 0.75] + +ESPAN: + out_channels: 128 + use_depthwise: True + num_features: 4 + +PicoHeadV2: + conv_feat: + name: PicoFeat + feat_in: 128 + feat_out: 128 + num_convs: 4 + num_fpn_stride: 4 + norm_type: bn + share_cls_reg: True + use_se: True + fpn_stride: [8, 16, 32, 64] + feat_in_chan: 128 + prior_prob: 0.01 + reg_max: 7 + cell_offset: 0.5 + grid_cell_scale: 5.0 + static_assigner_epoch: 100 + use_align_head: True + static_assigner: + name: ATSSAssigner + topk: 9 + force_gt_matching: False + assigner: + name: TaskAlignedAssigner + topk: 13 + alpha: 1.0 + beta: 6.0 + loss_class: + name: VarifocalLoss + use_sigmoid: False + iou_weighted: True + loss_weight: 1.0 + loss_dfl: + name: DistributionFocalLoss + loss_weight: 0.5 + loss_bbox: + name: GIoULoss + loss_weight: 2.5 + nms: + name: MultiClassNMS + nms_top_k: 1000 + keep_top_k: 100 + score_threshold: 0.025 + nms_threshold: 0.6 diff --git a/configs/picodet/_base_/picodet_v2_416_reader.yml b/configs/picodet/_base_/picodet_v2_416_reader.yml new file mode 100644 index 000000000..0479c58ef --- /dev/null +++ b/configs/picodet/_base_/picodet_v2_416_reader.yml @@ -0,0 +1,41 @@ +worker_num: 6 +TrainReader: + sample_transforms: + - Decode: {} + - RandomCrop: {} + - RandomFlip: {prob: 0.5} + - RandomDistort: {} + batch_transforms: + - BatchRandomResize: {target_size: [352, 384, 416, 448, 480], random_size: True, random_interp: True, keep_ratio: False} + - NormalizeImage: {is_scale: true, mean: [0.485,0.456,0.406], std: [0.229, 0.224,0.225]} + - Permute: {} + - PadGT: {} + batch_size: 80 + shuffle: true + drop_last: true + + +EvalReader: + sample_transforms: + - Decode: {} + - Resize: {interp: 2, target_size: [416, 416], keep_ratio: False} + - NormalizeImage: {is_scale: true, mean: [0.485,0.456,0.406], std: 
[0.229, 0.224,0.225]} + - Permute: {} + batch_transforms: + - PadBatch: {pad_to_stride: 32} + batch_size: 8 + shuffle: false + + +TestReader: + inputs_def: + image_shape: [1, 3, 416, 416] + sample_transforms: + - Decode: {} + - Resize: {interp: 2, target_size: [416, 416], keep_ratio: False} + - NormalizeImage: {is_scale: true, mean: [0.485,0.456,0.406], std: [0.229, 0.224,0.225]} + - Permute: {} + batch_transforms: + - PadBatch: {pad_to_stride: 32} + batch_size: 1 + shuffle: false diff --git a/configs/picodet/more_config/picodet_v2_lcnet_1_5x_416_coco.yml b/configs/picodet/more_config/picodet_v2_lcnet_1_5x_416_coco.yml new file mode 100644 index 000000000..5c825fa5f --- /dev/null +++ b/configs/picodet/more_config/picodet_v2_lcnet_1_5x_416_coco.yml @@ -0,0 +1,34 @@ +_BASE_: [ + '../../datasets/coco_detection.yml', + '../../runtime.yml', + '../_base_/picodet_esnetv2.yml', + '../_base_/optimizer_300e.yml', + '../_base_/picodet_v2_416_reader.yml', +] + +pretrain_weights: https://paddledet.bj.bcebos.com/models/pretrained/LCNet_x1_5_pretrained.pdparams +weights: output/picodet_lcnet_1_5x_416_coco/model_final +find_unused_parameters: True +use_ema: true +snapshot_epoch: 10 + +PicoDet: + backbone: LCNet + neck: ESPAN + head: PicoHeadV2 + +LCNet: + scale: 1.5 + feature_maps: [3, 4, 5] + +TrainReader: + batch_size: 32 + +LearningRate: + base_lr: 0.2 + schedulers: + - !CosineDecay + max_epochs: 300 + - !LinearWarmup + start_factor: 0.1 + steps: 300 diff --git a/configs/picodet/picodet_v2_s_416_coco.yml b/configs/picodet/picodet_v2_s_416_coco.yml new file mode 100644 index 000000000..ba4623220 --- /dev/null +++ b/configs/picodet/picodet_v2_s_416_coco.yml @@ -0,0 +1,46 @@ +_BASE_: [ + '../datasets/coco_detection.yml', + '../runtime.yml', + '_base_/picodet_esnetv2.yml', + '_base_/optimizer_300e.yml', + '_base_/picodet_v2_416_reader.yml', +] + +pretrain_weights: https://paddledet.bj.bcebos.com/models/pretrained/ESNet_x0_75_pretrained.pdparams +weights: 
output/picodet_s_416_coco/model_final +find_unused_parameters: True +use_ema: true +snapshot_epoch: 10 + +ESNet: + scale: 0.75 + feature_maps: [4, 11, 14] + act: hard_swish + channel_ratio: [0.875, 0.5, 0.5, 0.5, 0.625, 0.5, 0.625, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5] + +ESPAN: + out_channels: 96 + +PicoHeadV2: + conv_feat: + name: PicoFeat + feat_in: 96 + feat_out: 96 + num_convs: 2 + num_fpn_stride: 4 + norm_type: bn + share_cls_reg: True + use_se: True + feat_in_chan: 96 + +TrainReader: + batch_size: 56 + +LearningRate: + base_lr: 0.3 + schedulers: + - !CosineDecay + max_epochs: 300 + - !LinearWarmup + start_factor: 0.1 + steps: 300 diff --git a/ppdet/engine/export_utils.py b/ppdet/engine/export_utils.py index 192b89658..496fd0773 100644 --- a/ppdet/engine/export_utils.py +++ b/ppdet/engine/export_utils.py @@ -167,12 +167,13 @@ def _dump_infer_config(config, path, image_shape, model): if infer_arch == 'PicoDet': if config.get('export_post_process', False): infer_cfg['arch'] = 'GFL' - infer_cfg['NMS'] = config['PicoHead']['nms'] + head_name = 'PicoHeadV2' if config['PicoHeadV2'] else 'PicoHead' + infer_cfg['NMS'] = config[head_name]['nms'] # In order to speed up the prediction, the threshold of nms # is adjusted here, which can be changed in infer_cfg.yml - config['PicoHead']['nms']["score_threshold"] = 0.3 - config['PicoHead']['nms']["nms_threshold"] = 0.5 - infer_cfg['fpn_stride'] = config['PicoHead']['fpn_stride'] + config[head_name]['nms']["score_threshold"] = 0.3 + config[head_name]['nms']["nms_threshold"] = 0.5 + infer_cfg['fpn_stride'] = config[head_name]['fpn_stride'] yaml.dump(infer_cfg, open(path, 'w')) logger.info("Export inference config file to {}".format(os.path.join(path))) diff --git a/ppdet/modeling/heads/pico_head.py b/ppdet/modeling/heads/pico_head.py index 022acb045..b8ee83a34 100644 --- a/ppdet/modeling/heads/pico_head.py +++ b/ppdet/modeling/heads/pico_head.py @@ -1,15 +1,15 @@ -# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. 
-# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and # limitations under the License. 
from __future__ import absolute_import @@ -23,11 +23,38 @@ import paddle.nn as nn import paddle.nn.functional as F from paddle import ParamAttr from paddle.nn.initializer import Normal, Constant +from paddle.fluid.dygraph import parallel_helper +from ppdet.modeling.ops import get_static_shape +from ..initializer import normal_ +from ..assigners.utils import generate_anchors_for_grid_cell +from ..bbox_utils import bbox_center, batch_distance2bbox, bbox2distance from ppdet.core.workspace import register from ppdet.modeling.layers import ConvNormLayer -from ppdet.modeling.bbox_utils import batch_distance2bbox from .simota_head import OTAVFLHead +from .gfl_head import Integral, GFLHead +from ppdet.modeling.necks.csp_pan import DPModule + +eps = 1e-9 + +__all__ = ['PicoHead', 'PicoHeadV2', 'PicoFeat'] + + +class PicoSE(nn.Layer): + def __init__(self, feat_channels): + super(PicoSE, self).__init__() + self.fc = nn.Conv2D(feat_channels, feat_channels, 1) + self.conv = ConvNormLayer(feat_channels, feat_channels, 1, 1) + + self._init_weights() + + def _init_weights(self): + normal_(self.fc.weight, std=0.001) + + def forward(self, feat, avg_feat): + weight = F.sigmoid(self.fc(avg_feat)) + out = self.conv(feat * weight) + return out @register @@ -40,6 +67,9 @@ class PicoFeat(nn.Layer): feat_out (int): The channel number of output Tensor. num_convs (int): The convolution number of the LiteGFLFeat. norm_type (str): Normalization type, 'bn'/'sync_bn'/'gn'. + share_cls_reg (bool): Whether to share the cls and reg output. + act (str): The activation function used in each layer. + use_se (bool): Whether to use se module. 
""" def __init__(self, @@ -49,14 +79,20 @@ class PicoFeat(nn.Layer): num_convs=2, norm_type='bn', share_cls_reg=False, - act='hard_swish'): + act='hard_swish', + use_se=False): super(PicoFeat, self).__init__() self.num_convs = num_convs self.norm_type = norm_type self.share_cls_reg = share_cls_reg self.act = act + self.use_se = use_se self.cls_convs = [] self.reg_convs = [] + if use_se: + assert share_cls_reg == True, \ + 'In the case of using se, share_cls_reg is not supported' + self.se = nn.LayerList() for stage_idx in range(num_fpn_stride): cls_subnet_convs = [] reg_subnet_convs = [] @@ -112,6 +148,8 @@ class PicoFeat(nn.Layer): reg_subnet_convs.append(reg_conv_pw) self.cls_convs.append(cls_subnet_convs) self.reg_convs.append(reg_subnet_convs) + if use_se: + self.se.append(PicoSE(feat_out)) def act_func(self, x): if self.act == "leaky_relu": @@ -126,8 +164,13 @@ class PicoFeat(nn.Layer): reg_feat = fpn_feat for i in range(len(self.cls_convs[stage_idx])): cls_feat = self.act_func(self.cls_convs[stage_idx][i](cls_feat)) + reg_feat = cls_feat if not self.share_cls_reg: reg_feat = self.act_func(self.reg_convs[stage_idx][i](reg_feat)) + if self.use_se: + avg_feat = F.adaptive_avg_pool2d(cls_feat, (1, 1)) + se_feat = self.act_func(self.se[stage_idx](cls_feat, avg_feat)) + return cls_feat, se_feat return cls_feat, reg_feat @@ -291,3 +334,286 @@ class PicoHead(OTAVFLHead): bboxes_reg_list.append(bbox_pred) return (cls_logits_list, bboxes_reg_list) + + +@register +class PicoHeadV2(GFLHead): + """ + PicoHeadV2 + Args: + conv_feat (object): Instance of 'PicoFeat' + num_classes (int): Number of classes + fpn_stride (list): The stride of each FPN Layer + prior_prob (float): Used to set the bias init for the class prediction layer + loss_class (object): Instance of VariFocalLoss. + loss_dfl (object): Instance of DistributionFocalLoss. + loss_bbox (object): Instance of bbox loss. + assigner (object): Instance of label assigner. 
+ reg_max: Max value of integral set :math: `{0, ..., reg_max}` + in QFL setting. Default: 16. + """ + __inject__ = [ + 'conv_feat', 'dgqp_module', 'loss_class', 'loss_dfl', 'loss_bbox', + 'static_assigner', 'assigner', 'nms' + ] + __shared__ = ['num_classes'] + + def __init__( + self, + conv_feat='PicoFeatV2', + dgqp_module=None, + num_classes=80, + fpn_stride=[8, 16, 32], + prior_prob=0.01, + use_align_head=True, + loss_class='VariFocalLoss', + loss_dfl='DistributionFocalLoss', + loss_bbox='GIoULoss', + static_assigner_epoch=60, + static_assigner='ATSSAssigner', + assigner='TaskAlignedAssigner', + reg_max=16, + feat_in_chan=96, + nms=None, + nms_pre=1000, + cell_offset=0, + act='hard_swish', + grid_cell_scale=5.0, ): + super(PicoHeadV2, self).__init__( + conv_feat=conv_feat, + dgqp_module=dgqp_module, + num_classes=num_classes, + fpn_stride=fpn_stride, + prior_prob=prior_prob, + loss_class=loss_class, + loss_dfl=loss_dfl, + loss_bbox=loss_bbox, + reg_max=reg_max, + feat_in_chan=feat_in_chan, + nms=nms, + nms_pre=nms_pre, + cell_offset=cell_offset, ) + self.conv_feat = conv_feat + self.num_classes = num_classes + self.fpn_stride = fpn_stride + self.prior_prob = prior_prob + self.loss_vfl = loss_class + self.loss_dfl = loss_dfl + self.loss_bbox = loss_bbox + + self.static_assigner_epoch = static_assigner_epoch + self.static_assigner = static_assigner + self.assigner = assigner + + self.reg_max = reg_max + self.feat_in_chan = feat_in_chan + self.nms = nms + self.nms_pre = nms_pre + self.cell_offset = cell_offset + self.act = act + self.grid_cell_scale = grid_cell_scale + self.use_align_head = use_align_head + self.cls_out_channels = self.num_classes + + bias_init_value = -math.log((1 - self.prior_prob) / self.prior_prob) + # Clear the super class initialization + self.gfl_head_cls = None + self.gfl_head_reg = None + self.scales_regs = None + + self.head_cls_list = [] + self.head_reg_list = [] + self.cls_align = nn.LayerList() + + for i in range(len(fpn_stride)): + 
head_cls = self.add_sublayer( + "head_cls" + str(i), + nn.Conv2D( + in_channels=self.feat_in_chan, + out_channels=self.cls_out_channels, + kernel_size=1, + stride=1, + padding=0, + weight_attr=ParamAttr(initializer=Normal( + mean=0., std=0.01)), + bias_attr=ParamAttr( + initializer=Constant(value=bias_init_value)))) + self.head_cls_list.append(head_cls) + head_reg = self.add_sublayer( + "head_reg" + str(i), + nn.Conv2D( + in_channels=self.feat_in_chan, + out_channels=4 * (self.reg_max + 1), + kernel_size=1, + stride=1, + padding=0, + weight_attr=ParamAttr(initializer=Normal( + mean=0., std=0.01)), + bias_attr=ParamAttr(initializer=Constant(value=0)))) + self.head_reg_list.append(head_reg) + if self.use_align_head: + self.cls_align.append( + DPModule( + self.feat_in_chan, + 1, + 5, + act=self.act, + use_act_in_out=False)) + + def forward(self, fpn_feats, deploy=False): + assert len(fpn_feats) == len( + self.fpn_stride + ), "The size of fpn_feats is not equal to size of fpn_stride" + anchors, num_anchors_list, stride_tensor_list = generate_anchors_for_grid_cell( + fpn_feats, self.fpn_stride, self.grid_cell_scale, self.cell_offset) + + cls_score_list, reg_list, box_list = [], [], [] + for i, fpn_feat, anchor, stride, align_cls in zip( + range(len(self.fpn_stride)), fpn_feats, anchors, + self.fpn_stride, self.cls_align): + b, _, h, w = get_static_shape(fpn_feat) + # task decomposition + conv_cls_feat, se_feat = self.conv_feat(fpn_feat, i) + cls_logit = self.head_cls_list[i](se_feat) + reg_pred = self.head_reg_list[i](se_feat) + + # cls prediction and alignment + if self.use_align_head: + cls_prob = F.sigmoid(align_cls(conv_cls_feat)) + cls_score = (F.sigmoid(cls_logit) * cls_prob + eps).sqrt() + else: + cls_score = F.sigmoid(cls_logit) + + anchor_centers = bbox_center(anchor).unsqueeze(0) / stride + anchor_centers = anchor_centers.reshape([1, h, w, 2]) + + pred_distances = self.distribution_project( + reg_pred.transpose([0, 2, 3, 1])).reshape([b, h, w, 4]) + reg_bbox = 
batch_distance2bbox( + anchor_centers, pred_distances, max_shapes=None) + if not self.training: + cls_score_list.append( + cls_score.transpose([0, 2, 3, 1]).reshape( + [b, -1, self.cls_out_channels])) + box_list.append(reg_bbox.reshape([b, -1, 4]) * stride) + else: + cls_score_list.append(cls_score.flatten(2).transpose([0, 2, 1])) + reg_list.append(reg_pred.flatten(2).transpose([0, 2, 1])) + box_list.append(reg_bbox.reshape([b, -1, 4])) + + if not self.training: + return cls_score_list, box_list + else: + cls_score_list = paddle.concat(cls_score_list, axis=1) + box_list = paddle.concat(box_list, axis=1) + reg_list = paddle.concat(reg_list, axis=1) + anchors = paddle.concat(anchors) + anchors.stop_gradient = True + stride_tensor_list = paddle.concat(stride_tensor_list) + stride_tensor_list.stop_gradient = True + return cls_score_list, reg_list, box_list, anchors, num_anchors_list, stride_tensor_list + + def get_loss(self, head_outs, gt_meta): + pred_scores, pred_regs, pred_bboxes, anchors, num_anchors_list, stride_tensor_list = head_outs + gt_labels = gt_meta['gt_class'] + gt_bboxes = gt_meta['gt_bbox'] + gt_scores = gt_meta['gt_score'] if 'gt_score' in gt_meta else None + num_imgs = gt_meta['im_id'].shape[0] + pad_gt_mask = gt_meta['pad_gt_mask'] + + centers = bbox_center(anchors) + + # label assignment + if gt_meta['epoch_id'] < self.static_assigner_epoch: + assigned_labels, assigned_bboxes, assigned_scores = self.static_assigner( + anchors, + num_anchors_list, + gt_labels, + gt_bboxes, + pad_gt_mask, + bg_index=self.num_classes, + gt_scores=gt_scores, + pred_bboxes=pred_bboxes.detach() * stride_tensor_list) + + else: + assigned_labels, assigned_bboxes, assigned_scores = self.assigner( + pred_scores.detach(), + pred_bboxes.detach() * stride_tensor_list, + centers, + num_anchors_list, + gt_labels, + gt_bboxes, + pad_gt_mask, + bg_index=self.num_classes, + gt_scores=gt_scores) + + assigned_bboxes /= stride_tensor_list + + centers_shape = centers.shape + 
flatten_centers = centers.expand( + [num_imgs, centers_shape[0], centers_shape[1]]).reshape([-1, 2]) + flatten_strides = stride_tensor_list.expand( + [num_imgs, centers_shape[0], 1]).reshape([-1, 1]) + flatten_cls_preds = pred_scores.reshape([-1, self.num_classes]) + flatten_regs = pred_regs.reshape([-1, 4 * (self.reg_max + 1)]) + flatten_bboxes = pred_bboxes.reshape([-1, 4]) + flatten_bbox_targets = assigned_bboxes.reshape([-1, 4]) + flatten_labels = assigned_labels.reshape([-1]) + flatten_assigned_scores = assigned_scores.reshape( + [-1, self.num_classes]) + + pos_inds = paddle.nonzero( + paddle.logical_and((flatten_labels >= 0), + (flatten_labels < self.num_classes)), + as_tuple=False).squeeze(1) + + num_total_pos = len(pos_inds) + + if num_total_pos > 0: + pos_bbox_targets = paddle.gather( + flatten_bbox_targets, pos_inds, axis=0) + pos_decode_bbox_pred = paddle.gather( + flatten_bboxes, pos_inds, axis=0) + pos_reg = paddle.gather(flatten_regs, pos_inds, axis=0) + pos_strides = paddle.gather(flatten_strides, pos_inds, axis=0) + pos_centers = paddle.gather( + flatten_centers, pos_inds, axis=0) / pos_strides + + weight_targets = flatten_assigned_scores.detach() + weight_targets = paddle.gather( + weight_targets.max(axis=1, keepdim=True), pos_inds, axis=0) + + pred_corners = pos_reg.reshape([-1, self.reg_max + 1]) + target_corners = bbox2distance(pos_centers, pos_bbox_targets, + self.reg_max).reshape([-1]) + # regression loss + loss_bbox = paddle.sum( + self.loss_bbox(pos_decode_bbox_pred, + pos_bbox_targets) * weight_targets) + + # dfl loss + loss_dfl = self.loss_dfl( + pred_corners, + target_corners, + weight=weight_targets.expand([-1, 4]).reshape([-1]), + avg_factor=4.0) + else: + loss_bbox = paddle.zeros([1]) + loss_dfl = paddle.zeros([1]) + + avg_factor = flatten_assigned_scores.sum() + if paddle.fluid.core.is_compiled_with_dist( + ) and parallel_helper._is_parallel_ctx_initialized(): + paddle.distributed.all_reduce(avg_factor) + avg_factor = paddle.clip( + 
avg_factor / paddle.distributed.get_world_size(), min=1) + loss_vfl = self.loss_vfl( + flatten_cls_preds, flatten_assigned_scores, avg_factor=avg_factor) + + loss_bbox = loss_bbox / avg_factor + loss_dfl = loss_dfl / avg_factor + + loss_states = dict( + loss_vfl=loss_vfl, loss_bbox=loss_bbox, loss_dfl=loss_dfl) + + return loss_states diff --git a/ppdet/modeling/necks/__init__.py b/ppdet/modeling/necks/__init__.py index d66697caf..3908012cf 100644 --- a/ppdet/modeling/necks/__init__.py +++ b/ppdet/modeling/necks/__init__.py @@ -19,6 +19,7 @@ from . import ttf_fpn from . import centernet_fpn from . import bifpn from . import csp_pan +from . import es_pan from .fpn import * from .yolo_fpn import * @@ -28,3 +29,4 @@ from .centernet_fpn import * from .blazeface_fpn import * from .bifpn import * from .csp_pan import * +from .es_pan import * diff --git a/ppdet/modeling/necks/csp_pan.py b/ppdet/modeling/necks/csp_pan.py index 7417c46ab..5c3539a95 100644 --- a/ppdet/modeling/necks/csp_pan.py +++ b/ppdet/modeling/necks/csp_pan.py @@ -19,7 +19,6 @@ import paddle import paddle.nn as nn import paddle.nn.functional as F from paddle import ParamAttr -from paddle.regularizer import L2Decay from ppdet.core.workspace import register, serializable from ..shape_spec import ShapeSpec @@ -36,8 +35,6 @@ class ConvBNLayer(nn.Layer): act='leaky_relu'): super(ConvBNLayer, self).__init__() initializer = nn.initializer.KaimingUniform() - self.act = act - assert self.act in ['leaky_relu', "hard_swish"] self.conv = nn.Conv2D( in_channels=in_channel, out_channels=out_channel, @@ -48,13 +45,14 @@ class ConvBNLayer(nn.Layer): weight_attr=ParamAttr(initializer=initializer), bias_attr=False) self.bn = nn.BatchNorm2D(out_channel) + if act == "hard_swish": + act = 'hardswish' + self.act = act def forward(self, x): x = self.bn(self.conv(x)) - if self.act == "leaky_relu": - x = F.leaky_relu(x) - elif self.act == "hard_swish": - x = F.hardswish(x) + if self.act: + x = getattr(F, self.act)(x) return x @@ 
-75,10 +73,11 @@ class DPModule(nn.Layer): out_channel=96, kernel_size=3, stride=1, - act='leaky_relu'): + act='leaky_relu', + use_act_in_out=True): super(DPModule, self).__init__() initializer = nn.initializer.KaimingUniform() - self.act = act + self.use_act_in_out = use_act_in_out self.dwconv = nn.Conv2D( in_channels=in_channel, out_channels=out_channel, @@ -98,17 +97,17 @@ class DPModule(nn.Layer): weight_attr=ParamAttr(initializer=initializer), bias_attr=False) self.bn2 = nn.BatchNorm2D(out_channel) - - def act_func(self, x): - if self.act == "leaky_relu": - x = F.leaky_relu(x) - elif self.act == "hard_swish": - x = F.hardswish(x) - return x + if act == "hard_swish": + act = 'hardswish' + self.act = act def forward(self, x): - x = self.act_func(self.bn1(self.dwconv(x))) - x = self.act_func(self.bn2(self.pwconv(x))) + x = self.bn1(self.dwconv(x)) + if self.act: + x = getattr(F, self.act)(x) + x = self.bn2(self.pwconv(x)) + if self.use_act_in_out and self.act: + x = getattr(F, self.act)(x) return x diff --git a/ppdet/modeling/necks/es_pan.py b/ppdet/modeling/necks/es_pan.py new file mode 100644 index 000000000..bc2487733 --- /dev/null +++ b/ppdet/modeling/necks/es_pan.py @@ -0,0 +1,212 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import paddle +import paddle.nn as nn +import paddle.nn.functional as F +from paddle import ParamAttr +from paddle.regularizer import L2Decay +from ppdet.core.workspace import register, serializable + +from ..shape_spec import ShapeSpec +from ..backbones.esnet import SEModule +from .csp_pan import ConvBNLayer, Channel_T, DPModule + +__all__ = ['ESPAN'] + + +class ES_Block(nn.Layer): + def __init__(self, + in_channels, + mid_channels, + out_channels, + kernel_size=5, + stride=1, + act='leaky_relu'): + super(ES_Block, self).__init__() + self._residual = ConvBNLayer( + in_channel=in_channels, + out_channel=out_channels, + kernel_size=1, + stride=1, + groups=1, + act=act) + self._conv_pw = ConvBNLayer( + in_channel=in_channels, + out_channel=mid_channels // 2, + kernel_size=1, + stride=1, + groups=1, + act=act) + self._conv_dw = ConvBNLayer( + in_channel=mid_channels // 2, + out_channel=mid_channels // 2, + kernel_size=kernel_size, + stride=stride, + groups=mid_channels // 2, + act=None) + self._se = SEModule(mid_channels) + + self._conv_linear = ConvBNLayer( + in_channel=mid_channels, + out_channel=out_channels, + kernel_size=1, + stride=1, + groups=1, + act=act) + + self._out_conv = ConvBNLayer( + in_channel=out_channels * 2, + out_channel=out_channels, + kernel_size=1, + stride=1, + groups=1, + act=act) + + def forward(self, inputs): + x1 = self._residual(inputs) + x2 = self._conv_pw(inputs) + x3 = self._conv_dw(x2) + x3 = paddle.concat([x2, x3], axis=1) + x3 = self._se(x3) + x3 = self._conv_linear(x3) + out = paddle.concat([x1, x3], axis=1) + out = self._out_conv(out) + return out + + +@register +@serializable +class ESPAN(nn.Layer): + """Path Aggregation Network with ES module. + + Args: + in_channels (List[int]): Number of input channels per scale. + out_channels (int): Number of output channels (used at each scale) + kernel_size (int): The conv2d kernel size of this Module. + num_features (int): Number of output features of CSPPAN module. 
+ num_csp_blocks (int): Number of bottlenecks in CSPLayer. Default: 1 + use_depthwise (bool): Whether to use depthwise separable convolution in + blocks. Default: True + """ + + def __init__(self, + in_channels, + out_channels, + kernel_size=5, + num_features=3, + use_depthwise=True, + act='hard_swish', + spatial_scales=[0.125, 0.0625, 0.03125]): + super(ESPAN, self).__init__() + self.conv_t = Channel_T(in_channels, out_channels, act=act) + in_channels = [out_channels] * len(spatial_scales) + self.in_channels = in_channels + self.out_channels = out_channels + self.spatial_scales = spatial_scales + self.num_features = num_features + conv_func = DPModule if use_depthwise else ConvBNLayer + + if self.num_features == 4: + self.first_top_conv = conv_func( + in_channels[0], in_channels[0], kernel_size, stride=2, act=act) + self.second_top_conv = conv_func( + in_channels[0], in_channels[0], kernel_size, stride=2, act=act) + self.spatial_scales.append(self.spatial_scales[-1] / 2) + + # build top-down blocks + self.upsample = nn.Upsample(scale_factor=2, mode='nearest') + self.top_down_blocks = nn.LayerList() + for idx in range(len(in_channels) - 1, 0, -1): + self.top_down_blocks.append( + ES_Block( + in_channels[idx - 1] * 2, + in_channels[idx - 1], + in_channels[idx - 1], + kernel_size=kernel_size, + stride=1, + act=act)) + + # build bottom-up blocks + self.downsamples = nn.LayerList() + self.bottom_up_blocks = nn.LayerList() + for idx in range(len(in_channels) - 1): + self.downsamples.append( + conv_func( + in_channels[idx], + in_channels[idx], + kernel_size=kernel_size, + stride=2, + act=act)) + self.bottom_up_blocks.append( + ES_Block( + in_channels[idx] * 2, + in_channels[idx + 1], + in_channels[idx + 1], + kernel_size=kernel_size, + stride=1, + act=act)) + + def forward(self, inputs): + """ + Args: + inputs (tuple[Tensor]): input features. + + Returns: + tuple[Tensor]: ESPAN features. 
+ """ + assert len(inputs) == len(self.in_channels) + inputs = self.conv_t(inputs) + + # top-down path + inner_outs = [inputs[-1]] + for idx in range(len(self.in_channels) - 1, 0, -1): + feat_heigh = inner_outs[0] + feat_low = inputs[idx - 1] + + upsample_feat = self.upsample(feat_heigh) + + inner_out = self.top_down_blocks[len(self.in_channels) - 1 - idx]( + paddle.concat([upsample_feat, feat_low], 1)) + inner_outs.insert(0, inner_out) + + # bottom-up path + outs = [inner_outs[0]] + for idx in range(len(self.in_channels) - 1): + feat_low = outs[-1] + feat_height = inner_outs[idx + 1] + downsample_feat = self.downsamples[idx](feat_low) + out = self.bottom_up_blocks[idx](paddle.concat( + [downsample_feat, feat_height], 1)) + outs.append(out) + + top_features = None + if self.num_features == 4: + top_features = self.first_top_conv(inputs[-1]) + top_features = top_features + self.second_top_conv(outs[-1]) + outs.append(top_features) + + return tuple(outs) + + @property + def out_shape(self): + return [ + ShapeSpec( + channels=self.out_channels, stride=1. / s) + for s in self.spatial_scales + ] + + @classmethod + def from_config(cls, cfg, input_shape): + return {'in_channels': [i.channels for i in input_shape], } -- GitLab