diff --git a/configs/retinanet/README.md b/configs/retinanet/README.md
index bfa281321ed0b2509a423390b26db15fabc5caf4..8a9ec59b9ca377576c0f0b5f8ec0d01ff7684a17 100644
--- a/configs/retinanet/README.md
+++ b/configs/retinanet/README.md
@@ -1,20 +1,14 @@
-# Focal Loss for Dense Object Detection
-
-## Introduction
-
-We reproduce RetinaNet proposed in paper Focal Loss for Dense Object Detection.
+# RetinaNet (Focal Loss for Dense Object Detection)
 
 ## Model Zoo
 
-| Backbone     | Model     | mstrain | imgs/GPU | lr schedule | FPS | Box AP | download | config |
-| ------------ | --------- | ------- | -------- | ----------- | --- | ------ | -------- | ------ |
-| ResNet50-FPN | RetinaNet | Yes     | 4        | 1x          | --- | 37.5   | [model](https://bj.bcebos.com/v1/paddledet/models/retinanet_r50_fpn_mstrain_1x_coco.pdparams)\|[log](https://bj.bcebos.com/v1/paddledet/logs/retinanet_r50_fpn_mstrain_1x_coco.log) | retinanet_r50_fpn_mstrain_1x_coco.yml |
-
+| Backbone     | Model     | imgs/GPU | lr schedule | FPS | Box AP | download | config |
+| ------------ | --------- | -------- | ----------- | --- | ------ | -------- | ------ |
+| ResNet50-FPN | RetinaNet | 2        | 1x          | --- | 37.5   | [model](https://bj.bcebos.com/v1/paddledet/models/retinanet_r50_fpn_1x_coco.pdparams) | [config](./retinanet_r50_fpn_1x_coco.yml) |
 
 **Notes:**
 
-- All above models are trained on COCO train2017 with 4 GPUs and evaludated on val2017. Box AP=`mAP(IoU=0.5:0.95)`.
+- All models above are trained on COCO train2017 with 8 GPUs and evaluated on val2017. Box AP=`mAP(IoU=0.5:0.95)`.
-- Config `configs/retinanet/retinanet_r50_fpn_1x_coco.yml` is for 8 GPUs and `configs/retinanet/retinanet_r50_fpn_mstrain_1x_coco.yml` is for 4 GPUs (mind the difference of train batch size).
 
 ## Citation
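RetinaNet's one-stage design relies on the focal loss to down-weight easy background anchors. As a reference for the `FocalLoss` settings in the base config below (the config shows `gamma: 2.0`; `alpha=0.25` is the paper's default), here is a minimal sketch of the sigmoid focal loss. The function name and the one-hot target layout are illustrative, not PaddleDetection's actual `FocalLoss` layer:

```python
import paddle
import paddle.nn.functional as F

def focal_loss_sketch(logits, targets_onehot, alpha=0.25, gamma=2.0):
    # logits, targets_onehot: [num_anchors, num_classes]
    p = F.sigmoid(logits)
    ce = F.binary_cross_entropy(p, targets_onehot, reduction='none')
    # probability assigned to the true class of each anchor
    p_t = p * targets_onehot + (1 - p) * (1 - targets_onehot)
    alpha_t = alpha * targets_onehot + (1 - alpha) * (1 - targets_onehot)
    # (1 - p_t)**gamma shrinks the loss of well-classified anchors
    return (alpha_t * (1 - p_t) ** gamma * ce).sum()
```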
diff --git a/configs/retinanet/_base_/retinanet_r50_fpn.yml b/configs/retinanet/_base_/retinanet_r50_fpn.yml
index 156a17fea84119322c4e34b5e58b37e47cadcb63..fb2d767aed5bd383f312ce79e4e39e3710c3cb9c 100644
--- a/configs/retinanet/_base_/retinanet_r50_fpn.yml
+++ b/configs/retinanet/_base_/retinanet_r50_fpn.yml
@@ -22,10 +22,6 @@ FPN:
   use_c5: false
 
 RetinaHead:
-  num_classes: 80
-  prior_prob: 0.01
-  nms_pre: 1000
-  decode_reg_out: false
   conv_feat:
     name: RetinaFeat
     feat_in: 256
@@ -44,10 +40,6 @@ RetinaHead:
     positive_overlap: 0.5
     negative_overlap: 0.4
     allow_low_quality: true
-  bbox_coder:
-    name: DeltaBBoxCoder
-    norm_mean: [0.0, 0.0, 0.0, 0.0]
-    norm_std: [1.0, 1.0, 1.0, 1.0]
   loss_class:
     name: FocalLoss
     gamma: 2.0
diff --git a/configs/retinanet/_base_/retinanet_reader.yml b/configs/retinanet/_base_/retinanet_reader.yml
index 8cf31aa5ecdb903ce50e6c48ca7fb8429f3d776b..1f686b4d7f06f143106491e9b8fe3957a40927c2 100644
--- a/configs/retinanet/_base_/retinanet_reader.yml
+++ b/configs/retinanet/_base_/retinanet_reader.yml
@@ -1,39 +1,36 @@
 worker_num: 2
 
 TrainReader:
   sample_transforms:
-  - Decode: {}
-  - RandomFlip: {prob: 0.5}
-  - Resize: {target_size: [800, 1333], keep_ratio: true, interp: 1}
-  - NormalizeImage: {is_scale: true, mean: [0.485, 0.456, 0.406], std: [0.229, 0.224, 0.225]}
-  - Permute: {}
+    - Decode: {}
+    - RandomResize: {target_size: [[640, 1333], [672, 1333], [704, 1333], [736, 1333], [768, 1333], [800, 1333]], keep_ratio: True, interp: 1}
+    - RandomFlip: {}
+    - NormalizeImage: {is_scale: True, mean: [0.485, 0.456, 0.406], std: [0.229, 0.224, 0.225]}
+    - Permute: {}
   batch_transforms:
-  - PadBatch: {pad_to_stride: 32}
+    - PadBatch: {pad_to_stride: 32}
   batch_size: 2
-  shuffle: true
-  drop_last: true
-  use_process: true
-  collate_batch: false
+  shuffle: True
+  drop_last: True
+  collate_batch: False
 
 EvalReader:
   sample_transforms:
-  - Decode: {}
-  - Resize: {target_size: [800, 1333], keep_ratio: true, interp: 1}
-  - NormalizeImage: {is_scale: true, mean: [0.485, 0.456, 0.406], std: [0.229, 0.224, 0.225]}
-  - Permute: {}
+    - Decode: {}
+    - Resize: {target_size: [800, 1333], keep_ratio: True, interp: 1}
+    - NormalizeImage: {is_scale: True, mean: [0.485, 0.456, 0.406], std: [0.229, 0.224, 0.225]}
+    - Permute: {}
   batch_transforms:
-  - PadBatch: {pad_to_stride: 32}
-  batch_size: 2
-  shuffle: false
+    - PadBatch: {pad_to_stride: 32}
+  batch_size: 8
 
 TestReader:
   sample_transforms:
-  - Decode: {}
-  - Resize: {target_size: [800, 1333], keep_ratio: true, interp: 1}
-  - NormalizeImage: {is_scale: true, mean: [0.485, 0.456, 0.406], std: [0.229, 0.224, 0.225]}
-  - Permute: {}
+    - Decode: {}
+    - Resize: {target_size: [800, 1333], keep_ratio: True, interp: 1}
+    - NormalizeImage: {is_scale: True, mean: [0.485, 0.456, 0.406], std: [0.229, 0.224, 0.225]}
+    - Permute: {}
   batch_transforms:
-  - PadBatch: {pad_to_stride: 32}
+    - PadBatch: {pad_to_stride: 32}
   batch_size: 1
-  shuffle: false
diff --git a/configs/retinanet/retinanet_r50_fpn_1x_coco.yml b/configs/retinanet/retinanet_r50_fpn_1x_coco.yml
index bb2c5a404033691650b99430649cd512a81a91be..cb6d342baeb428547d42f417acda02e8c90e39da 100644
--- a/configs/retinanet/retinanet_r50_fpn_1x_coco.yml
+++ b/configs/retinanet/retinanet_r50_fpn_1x_coco.yml
@@ -7,4 +7,3 @@ _BASE_: [
 ]
 
 weights: output/retinanet_r50_fpn_1x_coco/model_final
-find_unused_parameters: true
\ No newline at end of file
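The reader change above swaps the fixed 800-pixel `Resize` for a multi-scale `RandomResize` in `TrainReader`, which is what lets the old mstrain variant be folded into the default config. Roughly, the transform draws one (short-side, long-side-cap) pair per sample and rescales with aspect ratio preserved; a hedged sketch of that size selection (`random_resize_shape` is illustrative, not ppdet's actual transform):

```python
import random

def random_resize_shape(im_w, im_h,
                        target_sizes=((640, 1333), (672, 1333), (704, 1333),
                                      (736, 1333), (768, 1333), (800, 1333))):
    # pick a fresh training scale for every sample (keep_ratio=True style)
    short, long_cap = random.choice(target_sizes)
    # scale so the short side hits `short` unless the long side would exceed the cap
    scale = min(short / min(im_w, im_h), long_cap / max(im_w, im_h))
    return int(round(im_w * scale)), int(round(im_h * scale))
```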
diff --git a/configs/retinanet/retinanet_r50_fpn_mstrain_1x_coco.yml b/configs/retinanet/retinanet_r50_fpn_mstrain_1x_coco.yml
deleted file mode 100644
index ef4023d2284941e6df255dd4e403f88e0d2d1513..0000000000000000000000000000000000000000
--- a/configs/retinanet/retinanet_r50_fpn_mstrain_1x_coco.yml
+++ /dev/null
@@ -1,20 +0,0 @@
-_BASE_: [
-  '../datasets/coco_detection.yml',
-  '../runtime.yml',
-  '_base_/retinanet_r50_fpn.yml',
-  '_base_/optimizer_1x.yml',
-  '_base_/retinanet_reader.yml'
-]
-
-worker_num: 4
-TrainReader:
-  batch_size: 4
-  sample_transforms:
-  - Decode: {}
-  - RandomFlip: {prob: 0.5}
-  - RandomResize: {target_size: [[640, 1333], [672, 1333], [704, 1333], [736, 1333], [768, 1333], [800, 1333]], keep_ratio: true, interp: 1}
-  - NormalizeImage: {is_scale: true, mean: [0.485, 0.456, 0.406], std: [0.229, 0.224, 0.225]}
-  - Permute: {}
-
-weights: output/retinanet_r50_fpn_mstrain_1x_coco/model_final
-find_unused_parameters: true
\ No newline at end of file
diff --git a/ppdet/modeling/__init__.py b/ppdet/modeling/__init__.py
index 88a9a3570477f55e8f7fbfeae4fd84271a76256d..cdcb5d1bf08d813257dc577366de2efa9da9add7 100644
--- a/ppdet/modeling/__init__.py
+++ b/ppdet/modeling/__init__.py
@@ -29,7 +29,6 @@ from . import reid
 from . import mot
 from . import transformers
 from . import assigners
-from . import coders
 
 from .ops import *
 from .backbones import *
@@ -44,4 +43,3 @@ from .reid import *
 from .mot import *
 from .transformers import *
 from .assigners import *
-from .coders import *
diff --git a/ppdet/modeling/architectures/retinanet.py b/ppdet/modeling/architectures/retinanet.py
index 5e9ce2de4b045abae60434cedb30976ba3398e9d..e774430a03dfebf74c1e91138ed57f2ee52f1c9d 100644
--- a/ppdet/modeling/architectures/retinanet.py
+++ b/ppdet/modeling/architectures/retinanet.py
@@ -22,14 +22,12 @@ import paddle
 
 __all__ = ['RetinaNet']
 
+
 @register
 class RetinaNet(BaseArch):
     __category__ = 'architecture'
 
-    def __init__(self,
-                 backbone,
-                 neck,
-                 head):
+    def __init__(self, backbone, neck, head):
         super(RetinaNet, self).__init__()
         self.backbone = backbone
         self.neck = neck
@@ -38,35 +36,33 @@ class RetinaNet(BaseArch):
     @classmethod
     def from_config(cls, cfg, *args, **kwargs):
         backbone = create(cfg['backbone'])
+
         kwargs = {'input_shape': backbone.out_shape}
         neck = create(cfg['neck'], **kwargs)
-        head = create(cfg['head'])
+
+        kwargs = {'input_shape': neck.out_shape}
+        head = create(cfg['head'], **kwargs)
+
         return {
             'backbone': backbone,
             'neck': neck,
-            'head': head}
+            'head': head,
+        }
 
     def _forward(self):
         body_feats = self.backbone(self.inputs)
         neck_feats = self.neck(body_feats)
-        head_outs = self.head(neck_feats)
-        if not self.training:
-            im_shape = self.inputs['im_shape']
-            scale_factor = self.inputs['scale_factor']
-            bboxes, bbox_num = self.head.post_process(head_outs, im_shape, scale_factor)
-            return bboxes, bbox_num
-        return head_outs
+
+        if self.training:
+            return self.head(neck_feats, self.inputs)
+        else:
+            head_outs = self.head(neck_feats)
+            bbox, bbox_num = self.head.post_process(
+                head_outs, self.inputs['im_shape'], self.inputs['scale_factor'])
+            return {'bbox': bbox, 'bbox_num': bbox_num}
 
     def get_loss(self):
-        loss = dict()
-        head_outs = self._forward()
-        loss_retina = self.head.get_loss(head_outs, self.inputs)
-        loss.update(loss_retina)
-        total_loss = paddle.add_n(list(loss.values()))
-        loss.update(loss=total_loss)
-        return loss
+        return self._forward()
 
     def get_pred(self):
-        bbox_pred, bbox_num = self._forward()
-        output = dict(bbox=bbox_pred, bbox_num=bbox_num)
-        return output
+        return self._forward()
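With this refactor, `get_loss` and `get_pred` both delegate to `_forward`: in training the head assembles and returns the loss dict itself (including a summed `loss` key, see the head changes later in this diff), and in eval the architecture returns NMS'd boxes. A sketch of what a caller sees, assuming `model` and `inputs` come from the usual ppdet trainer loop, and assuming PaddleDetection's usual `[class_id, score, x1, y1, x2, y2]` box layout:

```python
outs = model(inputs)               # BaseArch.forward routes on model.training
if model.training:
    outs['loss'].backward()        # total loss summed inside RetinaHead.get_loss
else:
    bboxes = outs['bbox']          # assumed [M, 6]: class_id, score, x1, y1, x2, y2
    per_image = outs['bbox_num']   # how many of the M boxes belong to each image
```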
diff --git a/ppdet/modeling/coders/__init__.py b/ppdet/modeling/coders/__init__.py
deleted file mode 100644
index 7726bb36cb06430b7bccd64ab89c8ef626e47790..0000000000000000000000000000000000000000
--- a/ppdet/modeling/coders/__init__.py
+++ /dev/null
@@ -1 +0,0 @@
-from .delta_bbox_coder import DeltaBBoxCoder
diff --git a/ppdet/modeling/coders/delta_bbox_coder.py b/ppdet/modeling/coders/delta_bbox_coder.py
deleted file mode 100644
index 0c53ea349eed4799cba164c4544051cb45d60385..0000000000000000000000000000000000000000
--- a/ppdet/modeling/coders/delta_bbox_coder.py
+++ /dev/null
@@ -1,40 +0,0 @@
-import paddle
-import numpy as np
-from ppdet.core.workspace import register
-from ppdet.modeling.bbox_utils import delta2bbox_v2, bbox2delta_v2
-
-__all__ = ['DeltaBBoxCoder']
-
-
-@register
-class DeltaBBoxCoder:
-    """Encode bboxes in terms of delta/offset of a reference bbox.
-    Args:
-        norm_mean (list[float]): the mean to normalize delta
-        norm_std (list[float]): the std to normalize delta
-        wh_ratio_clip (float): to clip delta wh of decoded bboxes
-        ctr_clip (float or None): whether to clip delta xy of decoded bboxes
-    """
-    def __init__(self,
-                 norm_mean=[0.0, 0.0, 0.0, 0.0],
-                 norm_std=[1., 1., 1., 1.],
-                 wh_ratio_clip=16/1000.0,
-                 ctr_clip=None):
-        self.norm_mean = norm_mean
-        self.norm_std = norm_std
-        self.wh_ratio_clip = wh_ratio_clip
-        self.ctr_clip = ctr_clip
-
-    def encode(self, bboxes, tar_bboxes):
-        return bbox2delta_v2(
-            bboxes, tar_bboxes, means=self.norm_mean, stds=self.norm_std)
-
-    def decode(self, bboxes, deltas, max_shape=None):
-        return delta2bbox_v2(
-            bboxes,
-            deltas,
-            max_shape=max_shape,
-            wh_ratio_clip=self.wh_ratio_clip,
-            ctr_clip=self.ctr_clip,
-            means=self.norm_mean,
-            stds=self.norm_std)
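The deleted `DeltaBBoxCoder` is replaced below by direct calls to `bbox2delta`/`delta2bbox` from `ppdet.modeling.bbox_utils`, with `weights=[1., 1., 1., 1.]` standing in for the coder's unit `norm_mean`/`norm_std` defaults. A numpy sketch of the encoding direction, assuming the standard R-CNN delta parameterization with multiplicative weights (`bbox2delta_sketch` is illustrative, not ppdet's function):

```python
import numpy as np

def bbox2delta_sketch(anchors, gt_boxes, weights=(1.0, 1.0, 1.0, 1.0)):
    # anchors, gt_boxes: [N, 4] in (x1, y1, x2, y2); returns [N, 4] (dx, dy, dw, dh)
    aw = anchors[:, 2] - anchors[:, 0]
    ah = anchors[:, 3] - anchors[:, 1]
    ax = anchors[:, 0] + 0.5 * aw
    ay = anchors[:, 1] + 0.5 * ah
    gw = gt_boxes[:, 2] - gt_boxes[:, 0]
    gh = gt_boxes[:, 3] - gt_boxes[:, 1]
    gx = gt_boxes[:, 0] + 0.5 * gw
    gy = gt_boxes[:, 1] + 0.5 * gh
    wx, wy, ww, wh = weights
    # center offsets are normalized by the anchor size; sizes are log-ratios
    return np.stack([wx * (gx - ax) / aw, wy * (gy - ay) / ah,
                     ww * np.log(gw / aw), wh * np.log(gh / ah)], axis=1)
```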
diff --git a/ppdet/modeling/heads/retina_head.py b/ppdet/modeling/heads/retina_head.py
index e8f5cbd0ac194d5adcaa0893cf12f0ffaa0161e9..8705e86febb30d06fcbbd06187a76548450c9600 100644
--- a/ppdet/modeling/heads/retina_head.py
+++ b/ppdet/modeling/heads/retina_head.py
@@ -1,4 +1,4 @@
-# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -16,17 +16,20 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
-import math, paddle
+import math
+import paddle
 import paddle.nn as nn
 import paddle.nn.functional as F
 from paddle import ParamAttr
 from paddle.nn.initializer import Normal, Constant
 
-from ppdet.modeling.proposal_generator import AnchorGenerator
-from ppdet.core.workspace import register
+from ppdet.modeling.bbox_utils import bbox2delta, delta2bbox
 from ppdet.modeling.heads.fcos_head import FCOSFeat
+from ppdet.core.workspace import register
+
 __all__ = ['RetinaHead']
 
+
 @register
 class RetinaFeat(FCOSFeat):
     """We use FCOSFeat to construct conv layers in RetinaNet.
@@ -34,72 +37,49 @@ class RetinaFeat(FCOSFeat):
     """
     pass
 
-@register
-class RetinaAnchorGenerator(AnchorGenerator):
-    def __init__(self,
-                 octave_base_scale=4,
-                 scales_per_octave=3,
-                 aspect_ratios=[0.5, 1.0, 2.0],
-                 strides=[8.0, 16.0, 32.0, 64.0, 128.0],
-                 variance=[1.0, 1.0, 1.0, 1.0],
-                 offset=0.0):
-        anchor_sizes = []
-        for s in strides:
-            anchor_sizes.append([
-                s * octave_base_scale * 2**(i/scales_per_octave) \
-                for i in range(scales_per_octave)])
-        super(RetinaAnchorGenerator, self).__init__(
-            anchor_sizes=anchor_sizes,
-            aspect_ratios=aspect_ratios,
-            strides=strides,
-            variance=variance,
-            offset=offset)
 
+
 @register
 class RetinaHead(nn.Layer):
     """Used in RetinaNet proposed in paper https://arxiv.org/pdf/1708.02002.pdf
     """
+    __shared__ = ['num_classes']
     __inject__ = [
-        'conv_feat', 'anchor_generator', 'bbox_assigner',
-        'bbox_coder', 'loss_class', 'loss_bbox', 'nms']
+        'conv_feat', 'anchor_generator', 'bbox_assigner', 'loss_class',
+        'loss_bbox', 'nms'
+    ]
+
     def __init__(self,
                  num_classes=80,
+                 conv_feat='RetinaFeat',
+                 anchor_generator='RetinaAnchorGenerator',
+                 bbox_assigner='MaxIoUAssigner',
+                 loss_class='FocalLoss',
+                 loss_bbox='SmoothL1Loss',
+                 nms='MultiClassNMS',
                  prior_prob=0.01,
-                 decode_reg_out=False,
-                 conv_feat=None,
-                 anchor_generator=None,
-                 bbox_assigner=None,
-                 bbox_coder=None,
-                 loss_class=None,
-                 loss_bbox=None,
                  nms_pre=1000,
-                 nms=None):
+                 weights=[1., 1., 1., 1.]):
         super(RetinaHead, self).__init__()
         self.num_classes = num_classes
-        self.prior_prob = prior_prob
-        # allow RetinaNet to use IoU based losses.
-        self.decode_reg_out = decode_reg_out
         self.conv_feat = conv_feat
         self.anchor_generator = anchor_generator
         self.bbox_assigner = bbox_assigner
-        self.bbox_coder = bbox_coder
         self.loss_class = loss_class
         self.loss_bbox = loss_bbox
-        self.nms_pre = nms_pre
         self.nms = nms
-        self.cls_out_channels = num_classes
-        self.init_layers()
+        self.nms_pre = nms_pre
+        self.weights = weights
 
-    def init_layers(self):
-        bias_init_value = -math.log((1 - self.prior_prob) / self.prior_prob)
+        bias_init_value = -math.log((1 - prior_prob) / prior_prob)
         num_anchors = self.anchor_generator.num_anchors
         self.retina_cls = nn.Conv2D(
             in_channels=self.conv_feat.feat_out,
-            out_channels=self.cls_out_channels * num_anchors,
+            out_channels=self.num_classes * num_anchors,
             kernel_size=3,
             stride=1,
             padding=1,
-            weight_attr=ParamAttr(initializer=Normal(mean=0.0, std=0.01)),
+            weight_attr=ParamAttr(initializer=Normal(
+                mean=0.0, std=0.01)),
             bias_attr=ParamAttr(initializer=Constant(value=bias_init_value)))
         self.retina_reg = nn.Conv2D(
             in_channels=self.conv_feat.feat_out,
@@ -107,10 +87,11 @@ class RetinaHead(nn.Layer):
             kernel_size=3,
             stride=1,
             padding=1,
-            weight_attr=ParamAttr(initializer=Normal(mean=0.0, std=0.01)),
+            weight_attr=ParamAttr(initializer=Normal(
+                mean=0.0, std=0.01)),
             bias_attr=ParamAttr(initializer=Constant(value=0)))
 
-    def forward(self, neck_feats):
+    def forward(self, neck_feats, targets=None):
         cls_logits_list = []
         bboxes_reg_list = []
         for neck_feat in neck_feats:
@@ -119,33 +100,40 @@ class RetinaHead(nn.Layer):
             bbox_reg = self.retina_reg(conv_reg_feat)
             cls_logits_list.append(cls_logits)
             bboxes_reg_list.append(bbox_reg)
-        return (cls_logits_list, bboxes_reg_list)
 
-    def get_loss(self, head_outputs, meta):
+        if self.training:
+            return self.get_loss([cls_logits_list, bboxes_reg_list], targets)
+        else:
+            return [cls_logits_list, bboxes_reg_list]
+
+    def get_loss(self, head_outputs, targets):
         """Here we calculate loss for a batch of images.
         We assign anchors to gts in each image and gather all the assigned
         positive and negative samples. Then loss is calculated on the gathered
         samples.
         """
-        cls_logits, bboxes_reg = head_outputs
-        # we use the same anchor for all images
-        anchors = self.anchor_generator(cls_logits)
+        cls_logits_list, bboxes_reg_list = head_outputs
+        anchors = self.anchor_generator(cls_logits_list)
         anchors = paddle.concat(anchors)
 
         # matches: contain gt_inds
         # match_labels: -1(ignore), 0(neg) or 1(pos)
         matches_list, match_labels_list = [], []
         # assign anchors to gts, no sampling is involved
-        for gt_bbox in meta['gt_bbox']:
+        for gt_bbox in targets['gt_bbox']:
             matches, match_labels = self.bbox_assigner(anchors, gt_bbox)
             matches_list.append(matches)
             match_labels_list.append(match_labels)
+
         # reshape network outputs
-        cls_logits = [_.transpose([0, 2, 3, 1]) for _ in cls_logits]
-        cls_logits = [_.reshape([0, -1, self.cls_out_channels]) \
-                      for _ in cls_logits]
-        bboxes_reg = [_.transpose([0, 2, 3, 1]) for _ in bboxes_reg]
-        bboxes_reg = [_.reshape([0, -1, 4]) for _ in bboxes_reg]
+        cls_logits = [
+            _.transpose([0, 2, 3, 1]).reshape([0, -1, self.num_classes])
+            for _ in cls_logits_list
+        ]
+        bboxes_reg = [
+            _.transpose([0, 2, 3, 1]).reshape([0, -1, 4])
+            for _ in bboxes_reg_list
+        ]
         cls_logits = paddle.concat(cls_logits, axis=1)
         bboxes_reg = paddle.concat(bboxes_reg, axis=1)
 
@@ -154,7 +142,7 @@ class RetinaHead(nn.Layer):
         # find and gather preds and targets in each image
         for matches, match_labels, cls_logit, bbox_reg, gt_bbox, gt_class in \
             zip(matches_list, match_labels_list, cls_logits, bboxes_reg,
-                meta['gt_bbox'], meta['gt_class']):
+                targets['gt_bbox'], targets['gt_class']):
             pos_mask = (match_labels == 1)
             neg_mask = (match_labels == 0)
             chosen_mask = paddle.logical_or(pos_mask, neg_mask)
@@ -163,59 +151,65 @@ class RetinaHead(nn.Layer):
             bg_class = paddle.to_tensor(
                 [self.num_classes], dtype=gt_class.dtype)
             # a trick to assign num_classes to negative targets
-            gt_class = paddle.concat([gt_class, bg_class])
-            matches = paddle.where(
-                neg_mask, paddle.full_like(matches, gt_class.size-1), matches)
+            gt_class = paddle.concat([gt_class, bg_class], axis=-1)
+            matches = paddle.where(neg_mask,
                                    paddle.full_like(matches, gt_class.size - 1),
+                                   matches)
             cls_pred = cls_logit[chosen_mask]
-            cls_tar  = gt_class[matches[chosen_mask]]
+            cls_tar = gt_class[matches[chosen_mask]]
             reg_pred = bbox_reg[pos_mask].reshape([-1, 4])
             reg_tar = gt_bbox[matches[pos_mask]].reshape([-1, 4])
-            if self.decode_reg_out:
-                reg_pred = self.bbox_coder.decode(
-                    anchors[pos_mask], reg_pred)
-            else:
-                reg_tar = self.bbox_coder.encode(anchors[pos_mask], reg_tar)
+            reg_tar = bbox2delta(anchors[pos_mask], reg_tar, self.weights)
             cls_pred_list.append(cls_pred)
             cls_tar_list.append(cls_tar)
             reg_pred_list.append(reg_pred)
             reg_tar_list.append(reg_tar)
         cls_pred = paddle.concat(cls_pred_list)
-        cls_tar  = paddle.concat(cls_tar_list)
+        cls_tar = paddle.concat(cls_tar_list)
         reg_pred = paddle.concat(reg_pred_list)
-        reg_tar  = paddle.concat(reg_tar_list)
+        reg_tar = paddle.concat(reg_tar_list)
+
         avg_factor = max(1.0, reg_pred.shape[0])
         cls_loss = self.loss_class(
-            cls_pred, cls_tar, reduction='sum')/avg_factor
-        if reg_pred.size == 0:
-            reg_loss = bboxes_reg[0][0].sum() * 0
+            cls_pred, cls_tar, reduction='sum') / avg_factor
+
+        if reg_pred.shape[0] == 0:
+            reg_loss = paddle.zeros([1])
+            reg_loss.stop_gradient = False
         else:
             reg_loss = self.loss_bbox(
-                reg_pred, reg_tar, reduction='sum')/avg_factor
-        return dict(loss_cls=cls_loss, loss_reg=reg_loss)
+                reg_pred, reg_tar, reduction='sum') / avg_factor
+
+        loss = cls_loss + reg_loss
+        out_dict = {
+            'loss_cls': cls_loss,
+            'loss_reg': reg_loss,
+            'loss': loss,
+        }
+        return out_dict
 
     def get_bboxes_single(self,
                           anchors,
-                          cls_scores,
-                          bbox_preds,
+                          cls_scores_list,
+                          bbox_preds_list,
                           im_shape,
                           scale_factor,
                           rescale=True):
-        assert len(cls_scores) == len(bbox_preds)
+        assert len(cls_scores_list) == len(bbox_preds_list)
         mlvl_bboxes = []
         mlvl_scores = []
-        for anchor, cls_score, bbox_pred in zip(anchors, cls_scores, bbox_preds):
+        for anchor, cls_score, bbox_pred in zip(anchors, cls_scores_list,
+                                                bbox_preds_list):
             cls_score = cls_score.reshape([-1, self.num_classes])
             bbox_pred = bbox_pred.reshape([-1, 4])
             if self.nms_pre is not None and cls_score.shape[0] > self.nms_pre:
                 max_score = cls_score.max(axis=1)
                 _, topk_inds = max_score.topk(self.nms_pre)
                 bbox_pred = bbox_pred.gather(topk_inds)
-                anchor   = anchor.gather(topk_inds)
+                anchor = anchor.gather(topk_inds)
                 cls_score = cls_score.gather(topk_inds)
-            bbox_pred = self.bbox_coder.decode(
-                anchor, bbox_pred, max_shape=im_shape)
-            bbox_pred = bbox_pred.squeeze()
+            bbox_pred = delta2bbox(bbox_pred, anchor, self.weights).squeeze()
             mlvl_bboxes.append(bbox_pred)
             mlvl_scores.append(F.sigmoid(cls_score))
         mlvl_bboxes = paddle.concat(mlvl_bboxes)
@@ -227,18 +221,15 @@ class RetinaHead(nn.Layer):
         mlvl_scores = mlvl_scores.transpose([1, 0])
         return mlvl_bboxes, mlvl_scores
 
-    def decode(self, anchors, cls_scores, bbox_preds, im_shape, scale_factor):
+    def decode(self, anchors, cls_logits, bboxes_reg, im_shape, scale_factor):
         batch_bboxes = []
         batch_scores = []
-        for img_id in range(cls_scores[0].shape[0]):
-            num_lvls = len(cls_scores)
-            cls_score_list = [cls_scores[i][img_id] for i in range(num_lvls)]
-            bbox_pred_list = [bbox_preds[i][img_id] for i in range(num_lvls)]
+        for img_id in range(cls_logits[0].shape[0]):
+            num_lvls = len(cls_logits)
+            cls_scores_list = [cls_logits[i][img_id] for i in range(num_lvls)]
+            bbox_preds_list = [bboxes_reg[i][img_id] for i in range(num_lvls)]
             bboxes, scores = self.get_bboxes_single(
-                anchors,
-                cls_score_list,
-                bbox_pred_list,
-                im_shape[img_id],
+                anchors, cls_scores_list, bbox_preds_list, im_shape[img_id],
                 scale_factor[img_id])
             batch_bboxes.append(bboxes)
             batch_scores.append(scores)
@@ -247,11 +238,12 @@ class RetinaHead(nn.Layer):
         return batch_bboxes, batch_scores
 
     def post_process(self, head_outputs, im_shape, scale_factor):
-        cls_scores, bbox_preds = head_outputs
-        anchors = self.anchor_generator(cls_scores)
-        cls_scores = [_.transpose([0, 2, 3, 1]) for _ in cls_scores]
-        bbox_preds = [_.transpose([0, 2, 3, 1]) for _ in bbox_preds]
-        bboxes, scores = self.decode(
-            anchors, cls_scores, bbox_preds, im_shape, scale_factor)
+        cls_logits_list, bboxes_reg_list = head_outputs
+        anchors = self.anchor_generator(cls_logits_list)
+        cls_logits = [_.transpose([0, 2, 3, 1]) for _ in cls_logits_list]
+        bboxes_reg = [_.transpose([0, 2, 3, 1]) for _ in bboxes_reg_list]
+        bboxes, scores = self.decode(anchors, cls_logits, bboxes_reg, im_shape,
+                                     scale_factor)
+
         bbox_pred, bbox_num, _ = self.nms(bboxes, scores)
         return bbox_pred, bbox_num
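For completeness, the decode direction that `get_bboxes_single` now routes through `delta2bbox`: a numpy sketch that is the inverse of `bbox2delta_sketch` above (the weights divide here). Note it reflects only the delta arithmetic; the removed coder's `max_shape` clipping to the image window is not part of the new call, so this sketch omits it too:

```python
import numpy as np

def delta2bbox_sketch(deltas, anchors, weights=(1.0, 1.0, 1.0, 1.0)):
    # deltas, anchors: [N, 4]; returns boxes [N, 4] in (x1, y1, x2, y2)
    aw = anchors[:, 2] - anchors[:, 0]
    ah = anchors[:, 3] - anchors[:, 1]
    ax = anchors[:, 0] + 0.5 * aw
    ay = anchors[:, 1] + 0.5 * ah
    wx, wy, ww, wh = weights
    cx = deltas[:, 0] / wx * aw + ax   # shift the anchor center
    cy = deltas[:, 1] / wy * ah + ay
    w = np.exp(deltas[:, 2] / ww) * aw  # scale the anchor size
    h = np.exp(deltas[:, 3] / wh) * ah
    return np.stack([cx - 0.5 * w, cy - 0.5 * h,
                     cx + 0.5 * w, cy + 0.5 * h], axis=1)
```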
diff --git a/ppdet/modeling/proposal_generator/anchor_generator.py b/ppdet/modeling/proposal_generator/anchor_generator.py
index 34f03c0ef084d1976f7f6879caf3e25b1f67d7de..94fd346002562fd772a21f525f7ad4f50f4c4680 100644
--- a/ppdet/modeling/proposal_generator/anchor_generator.py
+++ b/ppdet/modeling/proposal_generator/anchor_generator.py
@@ -22,6 +22,8 @@ import paddle.nn as nn
 
 from ppdet.core.workspace import register
 
+__all__ = ['AnchorGenerator', 'RetinaAnchorGenerator']
+
 
 @register
 class AnchorGenerator(nn.Layer):
@@ -129,3 +131,25 @@ class AnchorGenerator(nn.Layer):
            For FPN models, `num_anchors` on every feature map is the same.
         """
         return len(self.cell_anchors[0])
+
+
+@register
+class RetinaAnchorGenerator(AnchorGenerator):
+    def __init__(self,
+                 octave_base_scale=4,
+                 scales_per_octave=3,
+                 aspect_ratios=[0.5, 1.0, 2.0],
+                 strides=[8.0, 16.0, 32.0, 64.0, 128.0],
+                 variance=[1.0, 1.0, 1.0, 1.0],
+                 offset=0.0):
+        anchor_sizes = []
+        for s in strides:
+            anchor_sizes.append([
+                s * octave_base_scale * 2**(i/scales_per_octave) \
+                for i in range(scales_per_octave)])
+        super(RetinaAnchorGenerator, self).__init__(
+            anchor_sizes=anchor_sizes,
+            aspect_ratios=aspect_ratios,
+            strides=strides,
+            variance=variance,
+            offset=offset)
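As a sanity check on the generator above: with `octave_base_scale=4` and `scales_per_octave=3`, each FPN stride contributes three anchor sizes, which the three aspect ratios expand to 9 anchors per location, matching the RetinaNet paper. A snippet that just prints what the constructor computes:

```python
# e.g. stride 8 -> [32.0, 40.3, 50.8]; stride 128 -> [512.0, 645.1, 812.7]
octave_base_scale, scales_per_octave = 4, 3
for s in [8.0, 16.0, 32.0, 64.0, 128.0]:
    sizes = [s * octave_base_scale * 2 ** (i / scales_per_octave)
             for i in range(scales_per_octave)]
    print(int(s), [round(v, 1) for v in sizes])
```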