From ae621055724ca27122f37357b75cf607049b5b58 Mon Sep 17 00:00:00 2001
From: Blake
Date: Wed, 26 Jan 2022 16:51:41 +0800
Subject: [PATCH] add implementation of RetinaNet (#5140)

* add implementation of RetinaNet

* add README.md and model zoo

* rename FCOSFeat -> RetinaFeat

* add mstrain model to model zoo

* refactor DeltaBBoxCoder

* update link for model and log

---
 configs/retinanet/README.md                   |  28 ++
 configs/retinanet/_base_/optimizer_1x.yml     |  19 ++
 .../retinanet/_base_/retinanet_r50_fpn.yml    |  65 +++++
 configs/retinanet/_base_/retinanet_reader.yml |  39 +++
 .../retinanet/retinanet_r50_fpn_1x_coco.yml   |  10 +
 .../retinanet_r50_fpn_mstrain_1x_coco.yml     |  20 ++
 ppdet/modeling/__init__.py                    |   2 +
 ppdet/modeling/architectures/__init__.py      |   2 +
 ppdet/modeling/architectures/retinanet.py     |  72 +++++
 ppdet/modeling/assigners/__init__.py          |   2 +
 ppdet/modeling/assigners/max_iou_assigner.py  |  52 ++++
 ppdet/modeling/bbox_utils.py                  |  90 ++++++
 ppdet/modeling/coders/__init__.py             |   1 +
 ppdet/modeling/coders/delta_bbox_coder.py     |  40 +++
 ppdet/modeling/heads/__init__.py              |   2 +
 ppdet/modeling/heads/fcos_head.py             |   2 +
 ppdet/modeling/heads/retina_head.py           | 257 ++++++++++++++++++
 ppdet/modeling/layers.py                      |   7 +-
 ppdet/modeling/losses/__init__.py             |   4 +
 ppdet/modeling/losses/focal_loss.py           |  61 +++++
 ppdet/modeling/losses/smooth_l1_loss.py       |  60 ++++
 21 files changed, 833 insertions(+), 2 deletions(-)
 create mode 100644 configs/retinanet/README.md
 create mode 100644 configs/retinanet/_base_/optimizer_1x.yml
 create mode 100644 configs/retinanet/_base_/retinanet_r50_fpn.yml
 create mode 100644 configs/retinanet/_base_/retinanet_reader.yml
 create mode 100644 configs/retinanet/retinanet_r50_fpn_1x_coco.yml
 create mode 100644 configs/retinanet/retinanet_r50_fpn_mstrain_1x_coco.yml
 create mode 100644 ppdet/modeling/architectures/retinanet.py
 create mode 100644 ppdet/modeling/assigners/max_iou_assigner.py
 create mode 100644 ppdet/modeling/coders/__init__.py
 create mode 100644 ppdet/modeling/coders/delta_bbox_coder.py
 create mode 100644 ppdet/modeling/heads/retina_head.py
 create mode 100644 ppdet/modeling/losses/focal_loss.py
 create mode 100644 ppdet/modeling/losses/smooth_l1_loss.py

diff --git a/configs/retinanet/README.md b/configs/retinanet/README.md
new file mode 100644
index 000000000..bfa281321
--- /dev/null
+++ b/configs/retinanet/README.md
@@ -0,0 +1,28 @@
+# Focal Loss for Dense Object Detection
+
+## Introduction
+
+We reproduce RetinaNet, proposed in the paper *Focal Loss for Dense Object Detection*.
+
+## Model Zoo
+
+| Backbone | Model | mstrain | imgs/GPU | lr schedule | FPS | Box AP | download | config |
+| ------------ | --------- | ------- | -------- | ----------- | --- | ------ | -------- | ------------------------------------- |
+| ResNet50-FPN | RetinaNet | Yes | 4 | 1x | --- | 37.5 | [model](https://bj.bcebos.com/v1/paddledet/models/retinanet_r50_fpn_mstrain_1x_coco.pdparams)\|[log](https://bj.bcebos.com/v1/paddledet/logs/retinanet_r50_fpn_mstrain_1x_coco.log) | retinanet_r50_fpn_mstrain_1x_coco.yml |
+
+**Notes:**
+
+- All models above are trained on COCO train2017 with 4 GPUs and evaluated on val2017. Box AP=`mAP(IoU=0.5:0.95)`.
+
+- Config `configs/retinanet/retinanet_r50_fpn_1x_coco.yml` is for 8 GPUs and `configs/retinanet/retinanet_r50_fpn_mstrain_1x_coco.yml` is for 4 GPUs; the per-GPU batch size differs (2 vs. 4), so the total train batch size is 16 in both cases.
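+
+## Training
+
+A sketch of a typical launch command, assuming PaddleDetection's standard `tools/train.py` entry point (adjust the GPU list to your machine):
+
+```shell
+# train the multi-scale 1x model on 4 GPUs (total batch size 4 x 4 = 16)
+python -m paddle.distributed.launch --gpus 0,1,2,3 \
+    tools/train.py -c configs/retinanet/retinanet_r50_fpn_mstrain_1x_coco.yml --eval
+```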
+
+## Citation
+
+```latex
+@inproceedings{lin2017focal,
+  title={Focal loss for dense object detection},
+  author={Lin, Tsung-Yi and Goyal, Priya and Girshick, Ross and He, Kaiming and Doll{\'a}r, Piotr},
+  booktitle={Proceedings of the IEEE international conference on computer vision},
+  year={2017}
+}
+```
diff --git a/configs/retinanet/_base_/optimizer_1x.yml b/configs/retinanet/_base_/optimizer_1x.yml
new file mode 100644
index 000000000..39c54ac80
--- /dev/null
+++ b/configs/retinanet/_base_/optimizer_1x.yml
@@ -0,0 +1,19 @@
+epoch: 12
+
+LearningRate:
+  base_lr: 0.01
+  schedulers:
+  - !PiecewiseDecay
+    gamma: 0.1
+    milestones: [8, 11]
+  - !LinearWarmup
+    start_factor: 0.001
+    steps: 500
+
+OptimizerBuilder:
+  optimizer:
+    momentum: 0.9
+    type: Momentum
+  regularizer:
+    factor: 0.0001
+    type: L2
diff --git a/configs/retinanet/_base_/retinanet_r50_fpn.yml b/configs/retinanet/_base_/retinanet_r50_fpn.yml
new file mode 100644
index 000000000..156a17fea
--- /dev/null
+++ b/configs/retinanet/_base_/retinanet_r50_fpn.yml
@@ -0,0 +1,65 @@
+architecture: RetinaNet
+pretrain_weights: https://paddledet.bj.bcebos.com/models/pretrained/ResNet50_cos_pretrained.pdparams
+
+RetinaNet:
+  backbone: ResNet
+  neck: FPN
+  head: RetinaHead
+
+ResNet:
+  depth: 50
+  variant: b
+  norm_type: bn
+  freeze_at: 0
+  return_idx: [1,2,3]
+  num_stages: 4
+
+FPN:
+  out_channel: 256
+  spatial_scales: [0.125, 0.0625, 0.03125]
+  extra_stage: 2
+  has_extra_convs: true
+  use_c5: false
+
+RetinaHead:
+  num_classes: 80
+  prior_prob: 0.01
+  nms_pre: 1000
+  decode_reg_out: false
+  conv_feat:
+    name: RetinaFeat
+    feat_in: 256
+    feat_out: 256
+    num_convs: 4
+    norm_type: null
+    use_dcn: false
+  anchor_generator:
+    name: RetinaAnchorGenerator
+    octave_base_scale: 4
+    scales_per_octave: 3
+    aspect_ratios: [0.5, 1.0, 2.0]
+    strides: [8.0, 16.0, 32.0, 64.0, 128.0]
+  bbox_assigner:
+    name: MaxIoUAssigner
+    positive_overlap: 0.5
+    negative_overlap: 0.4
+    allow_low_quality: true
+  bbox_coder:
+    name: DeltaBBoxCoder
+    norm_mean: [0.0, 0.0, 0.0, 0.0]
+    norm_std: [1.0, 1.0, 1.0, 1.0]
+  loss_class:
+    name: FocalLoss
+    gamma: 2.0
+    alpha: 0.25
+    loss_weight: 1.0
+  loss_bbox:
+    name: SmoothL1Loss
+    beta: 0.0
+    loss_weight: 1.0
+  nms:
+    name: MultiClassNMS
+    nms_top_k: 1000
+    keep_top_k: 100
+    score_threshold: 0.05
+    nms_threshold: 0.5
diff --git a/configs/retinanet/_base_/retinanet_reader.yml b/configs/retinanet/_base_/retinanet_reader.yml
new file mode 100644
index 000000000..8cf31aa5e
--- /dev/null
+++ b/configs/retinanet/_base_/retinanet_reader.yml
@@ -0,0 +1,39 @@
+worker_num: 2
+TrainReader:
+  sample_transforms:
+  - Decode: {}
+  - RandomFlip: {prob: 0.5}
+  - Resize: {target_size: [800, 1333], keep_ratio: true, interp: 1}
+  - NormalizeImage: {is_scale: true, mean: [0.485, 0.456, 0.406], std: [0.229, 0.224, 0.225]}
+  - Permute: {}
+  batch_transforms:
+  - PadBatch: {pad_to_stride: 32}
+  batch_size: 2
+  shuffle: true
+  drop_last: true
+  use_process: true
+  collate_batch: false
+
+
+EvalReader:
+  sample_transforms:
+  - Decode: {}
+  - Resize: {target_size: [800, 1333], keep_ratio: true, interp: 1}
+  - NormalizeImage: {is_scale: true, mean: [0.485, 0.456, 0.406], std: [0.229, 0.224, 0.225]}
+  - Permute: {}
+  batch_transforms:
+  - PadBatch: {pad_to_stride: 32}
+  batch_size: 2
+  shuffle: false
+
+
+TestReader:
+  sample_transforms:
+  - Decode: {}
+  - Resize: {target_size: [800, 1333], keep_ratio: true, interp: 1}
+  - NormalizeImage: {is_scale: true, mean: [0.485, 0.456, 0.406], std: [0.229, 0.224, 0.225]}
+  - Permute: {}
+  batch_transforms:
+  - PadBatch: {pad_to_stride: 32}
+  batch_size: 1
+  shuffle: false
diff --git a/configs/retinanet/retinanet_r50_fpn_1x_coco.yml b/configs/retinanet/retinanet_r50_fpn_1x_coco.yml
new file mode 100644
index 000000000..bb2c5a404
--- /dev/null
+++ b/configs/retinanet/retinanet_r50_fpn_1x_coco.yml
@@ -0,0 +1,10 @@
+_BASE_: [
+  '../datasets/coco_detection.yml',
+  '../runtime.yml',
+  '_base_/retinanet_r50_fpn.yml',
+  '_base_/optimizer_1x.yml',
+  '_base_/retinanet_reader.yml'
+]
+
+weights: output/retinanet_r50_fpn_1x_coco/model_final
+find_unused_parameters: true
\ No newline at end of file
diff --git a/configs/retinanet/retinanet_r50_fpn_mstrain_1x_coco.yml b/configs/retinanet/retinanet_r50_fpn_mstrain_1x_coco.yml
new file mode 100644
index 000000000..ef4023d22
--- /dev/null
+++ b/configs/retinanet/retinanet_r50_fpn_mstrain_1x_coco.yml
@@ -0,0 +1,20 @@
+_BASE_: [
+  '../datasets/coco_detection.yml',
+  '../runtime.yml',
+  '_base_/retinanet_r50_fpn.yml',
+  '_base_/optimizer_1x.yml',
+  '_base_/retinanet_reader.yml'
+]
+
+worker_num: 4
+TrainReader:
+  batch_size: 4
+  sample_transforms:
+  - Decode: {}
+  - RandomFlip: {prob: 0.5}
+  - RandomResize: {target_size: [[640, 1333], [672, 1333], [704, 1333], [736, 1333], [768, 1333], [800, 1333]], keep_ratio: true, interp: 1}
+  - NormalizeImage: {is_scale: true, mean: [0.485, 0.456, 0.406], std: [0.229, 0.224, 0.225]}
+  - Permute: {}
+
+weights: output/retinanet_r50_fpn_mstrain_1x_coco/model_final
+find_unused_parameters: true
\ No newline at end of file
diff --git a/ppdet/modeling/__init__.py b/ppdet/modeling/__init__.py
index cdcb5d1bf..88a9a3570 100644
--- a/ppdet/modeling/__init__.py
+++ b/ppdet/modeling/__init__.py
@@ -29,6 +29,7 @@ from . import reid
 from . import mot
 from . import transformers
 from . import assigners
+from . import coders
 
 from .ops import *
 from .backbones import *
@@ -43,3 +44,4 @@ from .reid import *
 from .mot import *
 from .transformers import *
 from .assigners import *
+from .coders import *
diff --git a/ppdet/modeling/architectures/__init__.py b/ppdet/modeling/architectures/__init__.py
index b5feb06d8..30aecac61 100644
--- a/ppdet/modeling/architectures/__init__.py
+++ b/ppdet/modeling/architectures/__init__.py
@@ -26,6 +26,7 @@ from . import picodet
 from . import detr
 from . import sparse_rcnn
 from . import tood
+from . import retinanet
 
 from .meta_arch import *
 from .faster_rcnn import *
@@ -49,3 +50,4 @@ from .picodet import *
 from .detr import *
 from .sparse_rcnn import *
 from .tood import *
+from .retinanet import *
diff --git a/ppdet/modeling/architectures/retinanet.py b/ppdet/modeling/architectures/retinanet.py
new file mode 100644
index 000000000..5e9ce2de4
--- /dev/null
+++ b/ppdet/modeling/architectures/retinanet.py
@@ -0,0 +1,72 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from ppdet.core.workspace import register, create
+from .meta_arch import BaseArch
+import paddle
+
+__all__ = ['RetinaNet']
+
+
+@register
+class RetinaNet(BaseArch):
+    __category__ = 'architecture'
+
+    def __init__(self,
+                 backbone,
+                 neck,
+                 head):
+        super(RetinaNet, self).__init__()
+        self.backbone = backbone
+        self.neck = neck
+        self.head = head
+
+    @classmethod
+    def from_config(cls, cfg, *args, **kwargs):
+        backbone = create(cfg['backbone'])
+        kwargs = {'input_shape': backbone.out_shape}
+        neck = create(cfg['neck'], **kwargs)
+        head = create(cfg['head'])
+        return {
+            'backbone': backbone,
+            'neck': neck,
+            'head': head}
+
+    def _forward(self):
+        body_feats = self.backbone(self.inputs)
+        neck_feats = self.neck(body_feats)
+        head_outs = self.head(neck_feats)
+        if not self.training:
+            im_shape = self.inputs['im_shape']
+            scale_factor = self.inputs['scale_factor']
+            bboxes, bbox_num = self.head.post_process(head_outs, im_shape, scale_factor)
+            return bboxes, bbox_num
+        return head_outs
+
+    def get_loss(self):
+        loss = dict()
+        head_outs = self._forward()
+        loss_retina = self.head.get_loss(head_outs, self.inputs)
+        loss.update(loss_retina)
+        total_loss = paddle.add_n(list(loss.values()))
+        loss.update(loss=total_loss)
+        return loss
+
+    def get_pred(self):
+        bbox_pred, bbox_num = self._forward()
+        output = dict(bbox=bbox_pred, bbox_num=bbox_num)
+        return output
diff --git a/ppdet/modeling/assigners/__init__.py b/ppdet/modeling/assigners/__init__.py
index be5bb04d3..f82266b92 100644
--- a/ppdet/modeling/assigners/__init__.py
+++ b/ppdet/modeling/assigners/__init__.py
@@ -16,8 +16,10 @@ from . import utils
 from . import task_aligned_assigner
 from . import atss_assigner
 from . import simota_assigner
+from . import max_iou_assigner
 
 from .utils import *
 from .task_aligned_assigner import *
 from .atss_assigner import *
 from .simota_assigner import *
+from .max_iou_assigner import *
diff --git a/ppdet/modeling/assigners/max_iou_assigner.py b/ppdet/modeling/assigners/max_iou_assigner.py
new file mode 100644
index 000000000..98a4fdf8c
--- /dev/null
+++ b/ppdet/modeling/assigners/max_iou_assigner.py
@@ -0,0 +1,52 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from ppdet.core.workspace import register
+from ppdet.modeling.proposal_generator.target import label_box
+
+__all__ = ['MaxIoUAssigner']
+
+
+@register
+class MaxIoUAssigner(object):
+    """A standard bbox assigner based on max IoU, using ppdet's label_box
+    as the backend.
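+    Anchors whose highest IoU with any gt reaches positive_overlap are
+    labeled positive, anchors whose highest IoU is below negative_overlap
+    are labeled negative, and the rest are ignored; with allow_low_quality,
+    each gt's best-matching anchor is also kept as positive even when its
+    IoU is below the threshold.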
+    Args:
+        positive_overlap (float): IoU threshold for defining positive samples
+        negative_overlap (float): IoU threshold for defining negative samples
+        allow_low_quality (bool): whether to keep a gt's best-matching anchor
+            as positive even when it poorly overlaps with the candidate bboxes
+    """
+    def __init__(self,
+                 positive_overlap,
+                 negative_overlap,
+                 allow_low_quality=True):
+        self.positive_overlap = positive_overlap
+        self.negative_overlap = negative_overlap
+        self.allow_low_quality = allow_low_quality
+
+    def __call__(self, bboxes, gt_bboxes):
+        matches, match_labels = label_box(
+            bboxes,
+            gt_bboxes,
+            positive_overlap=self.positive_overlap,
+            negative_overlap=self.negative_overlap,
+            allow_low_quality=self.allow_low_quality,
+            ignore_thresh=-1,
+            is_crowd=None,
+            assign_on_cpu=False)
+        return matches, match_labels
diff --git a/ppdet/modeling/bbox_utils.py b/ppdet/modeling/bbox_utils.py
index cc3d7bb05..49a5d46fc 100644
--- a/ppdet/modeling/bbox_utils.py
+++ b/ppdet/modeling/bbox_utils.py
@@ -775,3 +775,93 @@ def batch_distance2bbox(points, distance, max_shapes=None):
         out_bbox = paddle.where(out_bbox > 0, out_bbox,
                                 paddle.zeros_like(out_bbox))
     return out_bbox
+
+
+def delta2bbox_v2(rois,
+                  deltas,
+                  means=(0.0, 0.0, 0.0, 0.0),
+                  stds=(1.0, 1.0, 1.0, 1.0),
+                  max_shape=None,
+                  wh_ratio_clip=16.0 / 1000.0,
+                  ctr_clip=None):
+    """Transform network outputs (deltas) into bboxes.
+    Based on https://github.com/open-mmlab/mmdetection/blob/master/mmdet/core/
+    bbox/coder/delta_xywh_bbox_coder.py
+    Args:
+        rois (Tensor): shape [..., 4], base bboxes, typical examples include
+            anchors and rois
+        deltas (Tensor): shape [..., 4], offsets relative to the base bboxes
+        means (list[float]): the mean that was used to normalize deltas,
+            must be of size 4
+        stds (list[float]): the std that was used to normalize deltas,
+            must be of size 4
+        max_shape (list[float] or None): height and width of the image, used
+            to clip bboxes if not None
+        wh_ratio_clip (float): threshold used to clip the delta wh of
+            decoded bboxes
+        ctr_clip (float or None): if not None, the delta xy of decoded bboxes
+            is clipped to [-ctr_clip, ctr_clip]
+    """
+    if rois.size == 0:
+        return paddle.empty_like(rois)
+    means = paddle.to_tensor(means)
+    stds = paddle.to_tensor(stds)
+    deltas = deltas * stds + means
+
+    dxy = deltas[..., :2]
+    dwh = deltas[..., 2:]
+
+    pxy = (rois[..., :2] + rois[..., 2:]) * 0.5
+    pwh = rois[..., 2:] - rois[..., :2]
+    dxy_wh = pwh * dxy
+
+    max_ratio = np.abs(np.log(wh_ratio_clip))
+    if ctr_clip is not None:
+        dxy_wh = paddle.clip(dxy_wh, max=ctr_clip, min=-ctr_clip)
+        dwh = paddle.clip(dwh, max=max_ratio)
+    else:
+        dwh = dwh.clip(min=-max_ratio, max=max_ratio)
+
+    gxy = pxy + dxy_wh
+    gwh = pwh * dwh.exp()
+    x1y1 = gxy - (gwh * 0.5)
+    x2y2 = gxy + (gwh * 0.5)
+    bboxes = paddle.concat([x1y1, x2y2], axis=-1)
+    if max_shape is not None:
+        bboxes[..., 0::2] = bboxes[..., 0::2].clip(min=0, max=max_shape[1])
+        bboxes[..., 1::2] = bboxes[..., 1::2].clip(min=0, max=max_shape[0])
+    return bboxes
+
+
+def bbox2delta_v2(src_boxes,
+                  tgt_boxes,
+                  means=(0.0, 0.0, 0.0, 0.0),
+                  stds=(1.0, 1.0, 1.0, 1.0)):
+    """Encode bboxes into deltas.
+    Modified from ppdet.modeling.bbox_utils.bbox2delta.
+    Args:
+        src_boxes (Tensor[..., 4]): base bboxes
+        tgt_boxes (Tensor[..., 4]): target bboxes
+        means (list[float]): the mean that will be used to normalize deltas,
+            must be of size 4
+        stds (list[float]): the std that will be used to normalize deltas,
+            must be of size 4
+    """
+    if src_boxes.size == 0:
+        return paddle.empty_like(src_boxes)
+    src_w = src_boxes[..., 2] - src_boxes[..., 0]
+    src_h = src_boxes[..., 3] - src_boxes[..., 1]
+    src_ctr_x = src_boxes[..., 0] + 0.5 * src_w
+    src_ctr_y = src_boxes[..., 1] + 0.5 * src_h
+
+    tgt_w = tgt_boxes[..., 2] - tgt_boxes[..., 0]
+    tgt_h = tgt_boxes[..., 3] - tgt_boxes[..., 1]
+    tgt_ctr_x = tgt_boxes[..., 0] + 0.5 * tgt_w
+    tgt_ctr_y = tgt_boxes[..., 1] + 0.5 * tgt_h
+
+    dx = (tgt_ctr_x - src_ctr_x) / src_w
+    dy = (tgt_ctr_y - src_ctr_y) / src_h
+    dw = paddle.log(tgt_w / src_w)
+    dh = paddle.log(tgt_h / src_h)
+
+    # stack on the last axis so inputs of shape [..., 4] stay [..., 4]
+    deltas = paddle.stack((dx, dy, dw, dh), axis=-1)
+    means = paddle.to_tensor(means, place=src_boxes.place)
+    stds = paddle.to_tensor(stds, place=src_boxes.place)
+    deltas = (deltas - means) / stds
+    return deltas
diff --git a/ppdet/modeling/coders/__init__.py b/ppdet/modeling/coders/__init__.py
new file mode 100644
index 000000000..7726bb36c
--- /dev/null
+++ b/ppdet/modeling/coders/__init__.py
@@ -0,0 +1 @@
+from .delta_bbox_coder import DeltaBBoxCoder
diff --git a/ppdet/modeling/coders/delta_bbox_coder.py b/ppdet/modeling/coders/delta_bbox_coder.py
new file mode 100644
index 000000000..0c53ea349
--- /dev/null
+++ b/ppdet/modeling/coders/delta_bbox_coder.py
@@ -0,0 +1,40 @@
+import paddle
+import numpy as np
+from ppdet.core.workspace import register
+from ppdet.modeling.bbox_utils import delta2bbox_v2, bbox2delta_v2
+
+__all__ = ['DeltaBBoxCoder']
+
+
+@register
+class DeltaBBoxCoder:
+    """Encode bboxes as deltas/offsets relative to a reference bbox.
+    Args:
+        norm_mean (list[float]): the mean used to normalize deltas
+        norm_std (list[float]): the std used to normalize deltas
+        wh_ratio_clip (float): threshold used to clip the delta wh of
+            decoded bboxes
+        ctr_clip (float or None): if not None, the delta xy of decoded
+            bboxes is clipped to [-ctr_clip, ctr_clip]
+    """
+    def __init__(self,
+                 norm_mean=[0.0, 0.0, 0.0, 0.0],
+                 norm_std=[1., 1., 1., 1.],
+                 wh_ratio_clip=16 / 1000.0,
+                 ctr_clip=None):
+        self.norm_mean = norm_mean
+        self.norm_std = norm_std
+        self.wh_ratio_clip = wh_ratio_clip
+        self.ctr_clip = ctr_clip
+
+    def encode(self, bboxes, tar_bboxes):
+        return bbox2delta_v2(
+            bboxes, tar_bboxes, means=self.norm_mean, stds=self.norm_std)
+
+    def decode(self, bboxes, deltas, max_shape=None):
+        return delta2bbox_v2(
+            bboxes,
+            deltas,
+            max_shape=max_shape,
+            wh_ratio_clip=self.wh_ratio_clip,
+            ctr_clip=self.ctr_clip,
+            means=self.norm_mean,
+            stds=self.norm_std)
diff --git a/ppdet/modeling/heads/__init__.py b/ppdet/modeling/heads/__init__.py
index b6b928608..6272f2977 100644
--- a/ppdet/modeling/heads/__init__.py
+++ b/ppdet/modeling/heads/__init__.py
@@ -31,6 +31,7 @@ from . import pico_head
 from . import detr_head
 from . import sparsercnn_head
 from . import tood_head
+from . import retina_head
 
 from .bbox_head import *
 from .mask_head import *
@@ -51,3 +52,4 @@ from .pico_head import *
 from .detr_head import *
 from .sparsercnn_head import *
 from .tood_head import *
+from .retina_head import *
diff --git a/ppdet/modeling/heads/fcos_head.py b/ppdet/modeling/heads/fcos_head.py
index 1d61feed6..758b1bf8b 100644
--- a/ppdet/modeling/heads/fcos_head.py
+++ b/ppdet/modeling/heads/fcos_head.py
@@ -64,6 +64,8 @@ class FCOSFeat(nn.Layer):
                  norm_type='bn',
                  use_dcn=False):
         super(FCOSFeat, self).__init__()
+        self.feat_in = feat_in
+        self.feat_out = feat_out
         self.num_convs = num_convs
         self.norm_type = norm_type
         self.cls_subnet_convs = []
diff --git a/ppdet/modeling/heads/retina_head.py b/ppdet/modeling/heads/retina_head.py
new file mode 100644
index 000000000..e8f5cbd0a
--- /dev/null
+++ b/ppdet/modeling/heads/retina_head.py
@@ -0,0 +1,257 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import math
+
+import paddle
+import paddle.nn as nn
+import paddle.nn.functional as F
+from paddle import ParamAttr
+from paddle.nn.initializer import Normal, Constant
+from ppdet.modeling.proposal_generator import AnchorGenerator
+from ppdet.core.workspace import register
+from ppdet.modeling.heads.fcos_head import FCOSFeat
+
+__all__ = ['RetinaHead']
+
+
+@register
+class RetinaFeat(FCOSFeat):
+    """We use FCOSFeat to construct the conv layers in RetinaNet.
+    It is renamed to RetinaFeat to avoid confusion.
+    """
+    pass
+
+
+@register
+class RetinaAnchorGenerator(AnchorGenerator):
+    def __init__(self,
+                 octave_base_scale=4,
+                 scales_per_octave=3,
+                 aspect_ratios=[0.5, 1.0, 2.0],
+                 strides=[8.0, 16.0, 32.0, 64.0, 128.0],
+                 variance=[1.0, 1.0, 1.0, 1.0],
+                 offset=0.0):
+        anchor_sizes = []
+        for s in strides:
+            anchor_sizes.append([
+                s * octave_base_scale * 2**(i / scales_per_octave)
+                for i in range(scales_per_octave)])
+        super(RetinaAnchorGenerator, self).__init__(
+            anchor_sizes=anchor_sizes,
+            aspect_ratios=aspect_ratios,
+            strides=strides,
+            variance=variance,
+            offset=offset)
+
+
+@register
+class RetinaHead(nn.Layer):
+    """The head used in RetinaNet, proposed in
+    https://arxiv.org/pdf/1708.02002.pdf
+    """
+    __inject__ = [
+        'conv_feat', 'anchor_generator', 'bbox_assigner',
+        'bbox_coder', 'loss_class', 'loss_bbox', 'nms']
+
+    def __init__(self,
+                 num_classes=80,
+                 prior_prob=0.01,
+                 decode_reg_out=False,
+                 conv_feat=None,
+                 anchor_generator=None,
+                 bbox_assigner=None,
+                 bbox_coder=None,
+                 loss_class=None,
+                 loss_bbox=None,
+                 nms_pre=1000,
+                 nms=None):
+        super(RetinaHead, self).__init__()
+        self.num_classes = num_classes
+        self.prior_prob = prior_prob
+        # allow RetinaNet to use IoU-based losses.
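+        # When decode_reg_out is True, the raw regression output is decoded
+        # into boxes before the loss is computed, so box-level (IoU-style)
+        # losses can consume it; when False, the targets are encoded into
+        # deltas and the loss is computed in delta space (see get_loss below).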
+        self.decode_reg_out = decode_reg_out
+        self.conv_feat = conv_feat
+        self.anchor_generator = anchor_generator
+        self.bbox_assigner = bbox_assigner
+        self.bbox_coder = bbox_coder
+        self.loss_class = loss_class
+        self.loss_bbox = loss_bbox
+        self.nms_pre = nms_pre
+        self.nms = nms
+        self.cls_out_channels = num_classes
+        self.init_layers()
+
+    def init_layers(self):
+        bias_init_value = -math.log((1 - self.prior_prob) / self.prior_prob)
+        num_anchors = self.anchor_generator.num_anchors
+        self.retina_cls = nn.Conv2D(
+            in_channels=self.conv_feat.feat_out,
+            out_channels=self.cls_out_channels * num_anchors,
+            kernel_size=3,
+            stride=1,
+            padding=1,
+            weight_attr=ParamAttr(initializer=Normal(mean=0.0, std=0.01)),
+            bias_attr=ParamAttr(initializer=Constant(value=bias_init_value)))
+        self.retina_reg = nn.Conv2D(
+            in_channels=self.conv_feat.feat_out,
+            out_channels=4 * num_anchors,
+            kernel_size=3,
+            stride=1,
+            padding=1,
+            weight_attr=ParamAttr(initializer=Normal(mean=0.0, std=0.01)),
+            bias_attr=ParamAttr(initializer=Constant(value=0)))
+
+    def forward(self, neck_feats):
+        cls_logits_list = []
+        bboxes_reg_list = []
+        for neck_feat in neck_feats:
+            conv_cls_feat, conv_reg_feat = self.conv_feat(neck_feat)
+            cls_logits = self.retina_cls(conv_cls_feat)
+            bbox_reg = self.retina_reg(conv_reg_feat)
+            cls_logits_list.append(cls_logits)
+            bboxes_reg_list.append(bbox_reg)
+        return (cls_logits_list, bboxes_reg_list)
+
+    def get_loss(self, head_outputs, meta):
+        """Calculate the loss for a batch of images.
+        We assign anchors to gts in each image, gather all the assigned
+        positive and negative samples, and then compute the loss on the
+        gathered samples.
+        """
+        cls_logits, bboxes_reg = head_outputs
+        # the same anchors are used for all images
+        anchors = self.anchor_generator(cls_logits)
+        anchors = paddle.concat(anchors)
+
+        # matches: gt indices
+        # match_labels: -1 (ignore), 0 (negative) or 1 (positive)
+        matches_list, match_labels_list = [], []
+        # assign anchors to gts; no sampling is involved
+        for gt_bbox in meta['gt_bbox']:
+            matches, match_labels = self.bbox_assigner(anchors, gt_bbox)
+            matches_list.append(matches)
+            match_labels_list.append(match_labels)
+        # reshape network outputs
+        cls_logits = [_.transpose([0, 2, 3, 1]) for _ in cls_logits]
+        cls_logits = [_.reshape([0, -1, self.cls_out_channels])
+                      for _ in cls_logits]
+        bboxes_reg = [_.transpose([0, 2, 3, 1]) for _ in bboxes_reg]
+        bboxes_reg = [_.reshape([0, -1, 4]) for _ in bboxes_reg]
+        cls_logits = paddle.concat(cls_logits, axis=1)
+        bboxes_reg = paddle.concat(bboxes_reg, axis=1)
+
+        cls_pred_list, cls_tar_list = [], []
+        reg_pred_list, reg_tar_list = [], []
+        # gather the preds and targets in each image
+        for matches, match_labels, cls_logit, bbox_reg, gt_bbox, gt_class in \
+            zip(matches_list, match_labels_list, cls_logits, bboxes_reg,
+                meta['gt_bbox'], meta['gt_class']):
+            pos_mask = (match_labels == 1)
+            neg_mask = (match_labels == 0)
+            chosen_mask = paddle.logical_or(pos_mask, neg_mask)
+
+            gt_class = gt_class.reshape([-1])
+            bg_class = paddle.to_tensor(
+                [self.num_classes], dtype=gt_class.dtype)
+            # a trick to assign num_classes to negative targets
+            gt_class = paddle.concat([gt_class, bg_class])
+            matches = paddle.where(
+                neg_mask, paddle.full_like(matches, gt_class.size - 1),
+                matches)
+
+            cls_pred = cls_logit[chosen_mask]
+            cls_tar = gt_class[matches[chosen_mask]]
+            reg_pred = bbox_reg[pos_mask].reshape([-1, 4])
+            reg_tar = gt_bbox[matches[pos_mask]].reshape([-1, 4])
+            if self.decode_reg_out:
+                reg_pred = self.bbox_coder.decode(
+                    anchors[pos_mask], reg_pred)
+            else:
+                reg_tar = self.bbox_coder.encode(anchors[pos_mask], reg_tar)
+            cls_pred_list.append(cls_pred)
+            cls_tar_list.append(cls_tar)
+            reg_pred_list.append(reg_pred)
+            reg_tar_list.append(reg_tar)
+        cls_pred = paddle.concat(cls_pred_list)
+        cls_tar = paddle.concat(cls_tar_list)
+        reg_pred = paddle.concat(reg_pred_list)
+        reg_tar = paddle.concat(reg_tar_list)
+        avg_factor = max(1.0, reg_pred.shape[0])
+        cls_loss = self.loss_class(
+            cls_pred, cls_tar, reduction='sum') / avg_factor
+        if reg_pred.size == 0:
+            reg_loss = bboxes_reg[0][0].sum() * 0
+        else:
+            reg_loss = self.loss_bbox(
+                reg_pred, reg_tar, reduction='sum') / avg_factor
+        return dict(loss_cls=cls_loss, loss_reg=reg_loss)
+
+    def get_bboxes_single(self,
+                          anchors,
+                          cls_scores,
+                          bbox_preds,
+                          im_shape,
+                          scale_factor,
+                          rescale=True):
+        assert len(cls_scores) == len(bbox_preds)
+        mlvl_bboxes = []
+        mlvl_scores = []
+        for anchor, cls_score, bbox_pred in zip(anchors, cls_scores,
+                                                bbox_preds):
+            cls_score = cls_score.reshape([-1, self.num_classes])
+            bbox_pred = bbox_pred.reshape([-1, 4])
+            if self.nms_pre is not None and cls_score.shape[0] > self.nms_pre:
+                max_score = cls_score.max(axis=1)
+                _, topk_inds = max_score.topk(self.nms_pre)
+                bbox_pred = bbox_pred.gather(topk_inds)
+                anchor = anchor.gather(topk_inds)
+                cls_score = cls_score.gather(topk_inds)
+            bbox_pred = self.bbox_coder.decode(
+                anchor, bbox_pred, max_shape=im_shape)
+            bbox_pred = bbox_pred.squeeze()
+            mlvl_bboxes.append(bbox_pred)
+            mlvl_scores.append(F.sigmoid(cls_score))
+        mlvl_bboxes = paddle.concat(mlvl_bboxes)
+        mlvl_bboxes = paddle.squeeze(mlvl_bboxes)
+        if rescale:
+            mlvl_bboxes = mlvl_bboxes / paddle.concat(
+                [scale_factor[::-1], scale_factor[::-1]])
+        mlvl_scores = paddle.concat(mlvl_scores)
+        mlvl_scores = mlvl_scores.transpose([1, 0])
+        return mlvl_bboxes, mlvl_scores
+
+    def decode(self, anchors, cls_scores, bbox_preds, im_shape, scale_factor):
+        batch_bboxes = []
+        batch_scores = []
+        for img_id in range(cls_scores[0].shape[0]):
+            num_lvls = len(cls_scores)
+            cls_score_list = [cls_scores[i][img_id] for i in range(num_lvls)]
+            bbox_pred_list = [bbox_preds[i][img_id] for i in range(num_lvls)]
+            bboxes, scores = self.get_bboxes_single(
+                anchors,
+                cls_score_list,
+                bbox_pred_list,
+                im_shape[img_id],
+                scale_factor[img_id])
+            batch_bboxes.append(bboxes)
+            batch_scores.append(scores)
+        batch_bboxes = paddle.stack(batch_bboxes, axis=0)
+        batch_scores = paddle.stack(batch_scores, axis=0)
+        return batch_bboxes, batch_scores
+
+    def post_process(self, head_outputs, im_shape, scale_factor):
+        cls_scores, bbox_preds = head_outputs
+        anchors = self.anchor_generator(cls_scores)
+        cls_scores = [_.transpose([0, 2, 3, 1]) for _ in cls_scores]
+        bbox_preds = [_.transpose([0, 2, 3, 1]) for _ in bbox_preds]
+        bboxes, scores = self.decode(
+            anchors, cls_scores, bbox_preds, im_shape, scale_factor)
+        bbox_pred, bbox_num, _ = self.nms(bboxes, scores)
+        return bbox_pred, bbox_num
diff --git a/ppdet/modeling/layers.py b/ppdet/modeling/layers.py
index 894fa3c8f..5af677eec 100644
--- a/ppdet/modeling/layers.py
+++ b/ppdet/modeling/layers.py
@@ -128,7 +128,7 @@ class ConvNormLayer(nn.Layer):
                  dcn_lr_scale=2.,
                  dcn_regularizer=L2Decay(0.)):
         super(ConvNormLayer, self).__init__()
-        assert norm_type in ['bn', 'sync_bn', 'gn']
+        assert norm_type in ['bn', 'sync_bn', 'gn', None]
 
         if bias_on:
             bias_attr = ParamAttr(
@@ -183,10 +183,13 @@
                 num_channels=ch_out,
                 weight_attr=param_attr,
                 bias_attr=bias_attr)
+        else:
+            self.norm = None
 
     def forward(self, inputs):
         out = self.conv(inputs)
-        out = self.norm(out)
+        if self.norm is not None:
+            out = self.norm(out)
         return out
 
 
diff --git a/ppdet/modeling/losses/__init__.py b/ppdet/modeling/losses/__init__.py
index 83389c08e..94eff1f17 100644
--- a/ppdet/modeling/losses/__init__.py
+++ b/ppdet/modeling/losses/__init__.py
@@ -25,6 +25,8 @@ from . import fairmot_loss
 from . import gfocal_loss
 from . import detr_loss
 from . import sparsercnn_loss
+from . import focal_loss
+from . import smooth_l1_loss
 
 from .yolo_loss import *
 from .iou_aware_loss import *
@@ -39,3 +41,5 @@ from .fairmot_loss import *
 from .gfocal_loss import *
 from .detr_loss import *
 from .sparsercnn_loss import *
+from .focal_loss import *
+from .smooth_l1_loss import *
diff --git a/ppdet/modeling/losses/focal_loss.py b/ppdet/modeling/losses/focal_loss.py
new file mode 100644
index 000000000..083e1dd3d
--- /dev/null
+++ b/ppdet/modeling/losses/focal_loss.py
@@ -0,0 +1,61 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import paddle
+import paddle.nn.functional as F
+import paddle.nn as nn
+from ppdet.core.workspace import register
+
+__all__ = ['FocalLoss']
+
+
+@register
+class FocalLoss(nn.Layer):
+    """A wrapper around paddle.nn.functional.sigmoid_focal_loss.
+    Args:
+        use_sigmoid (bool): currently only use_sigmoid=True is supported
+        alpha (float): parameter alpha in Focal Loss
+        gamma (float): parameter gamma in Focal Loss
+        loss_weight (float): the final loss is multiplied by this factor
+    """
+    def __init__(self,
+                 use_sigmoid=True,
+                 alpha=0.25,
+                 gamma=2.0,
+                 loss_weight=1.0):
+        super(FocalLoss, self).__init__()
+        assert use_sigmoid is True, \
+            'Focal Loss only supports sigmoid at the moment'
+        self.use_sigmoid = use_sigmoid
+        self.alpha = alpha
+        self.gamma = gamma
+        self.loss_weight = loss_weight
+
+    def forward(self, pred, target, reduction='none'):
+        """Forward function.
+        Args:
+            pred (Tensor): class prediction logits, of shape (N, num_classes)
+            target (Tensor): target class labels, of shape (N, ); the value
+                num_classes denotes the background class
+            reduction (str): how to reduce the loss, one of (none, sum, mean)
+        """
+        num_classes = pred.shape[1]
+        # one-hot over num_classes + 1 bins, then drop the background bin so
+        # that negative samples get an all-zero target vector
+        target = F.one_hot(target, num_classes + 1).cast(pred.dtype)
+        target = target[:, :-1].detach()
+        loss = F.sigmoid_focal_loss(
+            pred, target, alpha=self.alpha, gamma=self.gamma,
+            reduction=reduction)
+        return loss * self.loss_weight
diff --git a/ppdet/modeling/losses/smooth_l1_loss.py b/ppdet/modeling/losses/smooth_l1_loss.py
new file mode 100644
index 000000000..f89c28f66
--- /dev/null
+++ b/ppdet/modeling/losses/smooth_l1_loss.py
@@ -0,0 +1,60 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import paddle
+import paddle.nn as nn
+import paddle.nn.functional as F
+from ppdet.core.workspace import register
+
+__all__ = ['SmoothL1Loss']
+
+
+@register
+class SmoothL1Loss(nn.Layer):
+    """Smooth L1 Loss.
+    Args:
+        beta (float): controls the smooth region; the loss degenerates to
+            L1 Loss when beta=0.0
+        loss_weight (float): the final loss is multiplied by this factor
+    """
+    def __init__(self,
+                 beta=1.0,
+                 loss_weight=1.0):
+        super(SmoothL1Loss, self).__init__()
+        assert beta >= 0
+        self.beta = beta
+        self.loss_weight = loss_weight
+
+    def forward(self, pred, target, reduction='none'):
+        """Forward function, based on fvcore.
+        Args:
+            pred (Tensor): prediction tensor
+            target (Tensor): target tensor; must have the same shape as pred
+            reduction (str): how to reduce the loss, one of (none, sum, mean)
+        """
+        assert reduction in ('none', 'sum', 'mean')
+        target = target.detach()
+        if self.beta < 1e-5:
+            loss = paddle.abs(pred - target)
+        else:
+            n = paddle.abs(pred - target)
+            cond = n < self.beta
+            loss = paddle.where(cond, 0.5 * n**2 / self.beta,
+                                n - 0.5 * self.beta)
+        if reduction == 'mean':
+            loss = loss.mean() if loss.size > 0 else 0.0 * loss.sum()
+        elif reduction == 'sum':
+            loss = loss.sum()
+        return loss * self.loss_weight
-- 
GitLab