未验证 提交 fa250ff1 编写于 作者: F Feng Ni 提交者: GitHub

refine RetinaNet codes (#5797)

上级 bbb72659
# Focal Loss for Dense Object Detection
## Introduction
We reproduce RetinaNet proposed in paper Focal Loss for Dense Object Detection.
# RetinaNet (Focal Loss for Dense Object Detection)
## Model Zoo
| Backbone | Model | mstrain | imgs/GPU | lr schedule | FPS | Box AP | download | config |
| ------------ | --------- | ------- | -------- | ----------- | --- | ------ | ----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | ------------------------------------- |
| ResNet50-FPN | RetinaNet | Yes | 4 | 1x | --- | 37.5 | [model](https://bj.bcebos.com/v1/paddledet/models/retinanet_r50_fpn_mstrain_1x_coco.pdparams)\|[log](https://bj.bcebos.com/v1/paddledet/logs/retinanet_r50_fpn_mstrain_1x_coco.log) | retinanet_r50_fpn_mstrain_1x_coco.yml |
| Backbone | Model | imgs/GPU | lr schedule | FPS | Box AP | download | config |
| ------------ | --------- | -------- | ----------- | --- | ------ | ---------- | ----------- |
| ResNet50-FPN | RetinaNet | 2 | 1x | --- | 37.5 | [model](https://bj.bcebos.com/v1/paddledet/models/retinanet_r50_fpn_1x_coco.pdparams) | [config](./retinanet_r50_fpn_1x_coco.yml) |
**Notes:**
- All above models are trained on COCO train2017 with 4 GPUs and evaludated on val2017. Box AP=`mAP(IoU=0.5:0.95)`.
- All above models are trained on COCO train2017 with 8 GPUs and evaludated on val2017. Box AP=`mAP(IoU=0.5:0.95)`.
- Config `configs/retinanet/retinanet_r50_fpn_1x_coco.yml` is for 8 GPUs and `configs/retinanet/retinanet_r50_fpn_mstrain_1x_coco.yml` is for 4 GPUs (mind the difference of train batch size).
## Citation
......
......@@ -22,10 +22,6 @@ FPN:
use_c5: false
RetinaHead:
num_classes: 80
prior_prob: 0.01
nms_pre: 1000
decode_reg_out: false
conv_feat:
name: RetinaFeat
feat_in: 256
......@@ -44,10 +40,6 @@ RetinaHead:
positive_overlap: 0.5
negative_overlap: 0.4
allow_low_quality: true
bbox_coder:
name: DeltaBBoxCoder
norm_mean: [0.0, 0.0, 0.0, 0.0]
norm_std: [1.0, 1.0, 1.0, 1.0]
loss_class:
name: FocalLoss
gamma: 2.0
......
worker_num: 2
TrainReader:
sample_transforms:
- Decode: {}
- RandomFlip: {prob: 0.5}
- Resize: {target_size: [800, 1333], keep_ratio: true, interp: 1}
- NormalizeImage: {is_scale: true, mean: [0.485, 0.456, 0.406], std: [0.229, 0.224, 0.225]}
- Permute: {}
- Decode: {}
- RandomResize: {target_size: [[640, 1333], [672, 1333], [704, 1333], [736, 1333], [768, 1333], [800, 1333]], keep_ratio: True, interp: 1}
- RandomFlip: {}
- NormalizeImage: {is_scale: True, mean: [0.485, 0.456, 0.406], std: [0.229, 0.224, 0.225]}
- Permute: {}
batch_transforms:
- PadBatch: {pad_to_stride: 32}
- PadBatch: {pad_to_stride: 32}
batch_size: 2
shuffle: true
drop_last: true
use_process: true
collate_batch: false
shuffle: True
drop_last: True
collate_batch: False
EvalReader:
sample_transforms:
- Decode: {}
- Resize: {target_size: [800, 1333], keep_ratio: true, interp: 1}
- NormalizeImage: {is_scale: true, mean: [0.485, 0.456, 0.406], std: [0.229, 0.224, 0.225]}
- Permute: {}
- Decode: {}
- Resize: {target_size: [800, 1333], keep_ratio: True, interp: 1}
- NormalizeImage: {is_scale: True, mean: [0.485, 0.456, 0.406], std: [0.229, 0.224, 0.225]}
- Permute: {}
batch_transforms:
- PadBatch: {pad_to_stride: 32}
batch_size: 2
shuffle: false
- PadBatch: {pad_to_stride: 32}
batch_size: 8
TestReader:
sample_transforms:
- Decode: {}
- Resize: {target_size: [800, 1333], keep_ratio: true, interp: 1}
- NormalizeImage: {is_scale: true, mean: [0.485, 0.456, 0.406], std: [0.229, 0.224, 0.225]}
- Permute: {}
- Decode: {}
- Resize: {target_size: [800, 1333], keep_ratio: True, interp: 1}
- NormalizeImage: {is_scale: True, mean: [0.485, 0.456, 0.406], std: [0.229, 0.224, 0.225]}
- Permute: {}
batch_transforms:
- PadBatch: {pad_to_stride: 32}
- PadBatch: {pad_to_stride: 32}
batch_size: 1
shuffle: false
......@@ -7,4 +7,3 @@ _BASE_: [
]
weights: output/retinanet_r50_fpn_1x_coco/model_final
find_unused_parameters: true
\ No newline at end of file
_BASE_: [
'../datasets/coco_detection.yml',
'../runtime.yml',
'_base_/retinanet_r50_fpn.yml',
'_base_/optimizer_1x.yml',
'_base_/retinanet_reader.yml'
]
worker_num: 4
TrainReader:
batch_size: 4
sample_transforms:
- Decode: {}
- RandomFlip: {prob: 0.5}
- RandomResize: {target_size: [[640, 1333], [672, 1333], [704, 1333], [736, 1333], [768, 1333], [800, 1333]], keep_ratio: true, interp: 1}
- NormalizeImage: {is_scale: true, mean: [0.485, 0.456, 0.406], std: [0.229, 0.224, 0.225]}
- Permute: {}
weights: output/retinanet_r50_fpn_mstrain_1x_coco/model_final
find_unused_parameters: true
\ No newline at end of file
......@@ -29,7 +29,6 @@ from . import reid
from . import mot
from . import transformers
from . import assigners
from . import coders
from .ops import *
from .backbones import *
......@@ -44,4 +43,3 @@ from .reid import *
from .mot import *
from .transformers import *
from .assigners import *
from .coders import *
......@@ -22,14 +22,12 @@ import paddle
__all__ = ['RetinaNet']
@register
class RetinaNet(BaseArch):
__category__ = 'architecture'
def __init__(self,
backbone,
neck,
head):
def __init__(self, backbone, neck, head):
super(RetinaNet, self).__init__()
self.backbone = backbone
self.neck = neck
......@@ -38,35 +36,33 @@ class RetinaNet(BaseArch):
@classmethod
def from_config(cls, cfg, *args, **kwargs):
backbone = create(cfg['backbone'])
kwargs = {'input_shape': backbone.out_shape}
neck = create(cfg['neck'], **kwargs)
head = create(cfg['head'])
kwargs = {'input_shape': neck.out_shape}
head = create(cfg['head'], **kwargs)
return {
'backbone': backbone,
'neck': neck,
'head': head}
'head': head,
}
def _forward(self):
body_feats = self.backbone(self.inputs)
neck_feats = self.neck(body_feats)
head_outs = self.head(neck_feats)
if not self.training:
im_shape = self.inputs['im_shape']
scale_factor = self.inputs['scale_factor']
bboxes, bbox_num = self.head.post_process(head_outs, im_shape, scale_factor)
return bboxes, bbox_num
return head_outs
if self.training:
return self.head(neck_feats, self.inputs)
else:
head_outs = self.head(neck_feats)
bbox, bbox_num = self.head.post_process(
head_outs, self.inputs['im_shape'], self.inputs['scale_factor'])
return {'bbox': bbox, 'bbox_num': bbox_num}
def get_loss(self):
loss = dict()
head_outs = self._forward()
loss_retina = self.head.get_loss(head_outs, self.inputs)
loss.update(loss_retina)
total_loss = paddle.add_n(list(loss.values()))
loss.update(loss=total_loss)
return loss
return self._forward()
def get_pred(self):
bbox_pred, bbox_num = self._forward()
output = dict(bbox=bbox_pred, bbox_num=bbox_num)
return output
return self._forward()
from .delta_bbox_coder import DeltaBBoxCoder
import paddle
import numpy as np
from ppdet.core.workspace import register
from ppdet.modeling.bbox_utils import delta2bbox_v2, bbox2delta_v2
__all__ = ['DeltaBBoxCoder']
@register
class DeltaBBoxCoder:
"""Encode bboxes in terms of delta/offset of a reference bbox.
Args:
norm_mean (list[float]): the mean to normalize delta
norm_std (list[float]): the std to normalize delta
wh_ratio_clip (float): to clip delta wh of decoded bboxes
ctr_clip (float or None): whether to clip delta xy of decoded bboxes
"""
def __init__(self,
norm_mean=[0.0, 0.0, 0.0, 0.0],
norm_std=[1., 1., 1., 1.],
wh_ratio_clip=16/1000.0,
ctr_clip=None):
self.norm_mean = norm_mean
self.norm_std = norm_std
self.wh_ratio_clip = wh_ratio_clip
self.ctr_clip = ctr_clip
def encode(self, bboxes, tar_bboxes):
return bbox2delta_v2(
bboxes, tar_bboxes, means=self.norm_mean, stds=self.norm_std)
def decode(self, bboxes, deltas, max_shape=None):
return delta2bbox_v2(
bboxes,
deltas,
max_shape=max_shape,
wh_ratio_clip=self.wh_ratio_clip,
ctr_clip=self.ctr_clip,
means=self.norm_mean,
stds=self.norm_std)
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -16,17 +16,20 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import math, paddle
import math
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from paddle import ParamAttr
from paddle.nn.initializer import Normal, Constant
from ppdet.modeling.proposal_generator import AnchorGenerator
from ppdet.core.workspace import register
from ppdet.modeling.bbox_utils import bbox2delta, delta2bbox
from ppdet.modeling.heads.fcos_head import FCOSFeat
from ppdet.core.workspace import register
__all__ = ['RetinaHead']
@register
class RetinaFeat(FCOSFeat):
"""We use FCOSFeat to construct conv layers in RetinaNet.
......@@ -34,72 +37,49 @@ class RetinaFeat(FCOSFeat):
"""
pass
@register
class RetinaAnchorGenerator(AnchorGenerator):
def __init__(self,
octave_base_scale=4,
scales_per_octave=3,
aspect_ratios=[0.5, 1.0, 2.0],
strides=[8.0, 16.0, 32.0, 64.0, 128.0],
variance=[1.0, 1.0, 1.0, 1.0],
offset=0.0):
anchor_sizes = []
for s in strides:
anchor_sizes.append([
s * octave_base_scale * 2**(i/scales_per_octave) \
for i in range(scales_per_octave)])
super(RetinaAnchorGenerator, self).__init__(
anchor_sizes=anchor_sizes,
aspect_ratios=aspect_ratios,
strides=strides,
variance=variance,
offset=offset)
@register
class RetinaHead(nn.Layer):
"""Used in RetinaNet proposed in paper https://arxiv.org/pdf/1708.02002.pdf
"""
__shared__ = ['num_classes']
__inject__ = [
'conv_feat', 'anchor_generator', 'bbox_assigner',
'bbox_coder', 'loss_class', 'loss_bbox', 'nms']
'conv_feat', 'anchor_generator', 'bbox_assigner', 'loss_class',
'loss_bbox', 'nms'
]
def __init__(self,
num_classes=80,
conv_feat='RetinaFeat',
anchor_generator='RetinaAnchorGenerator',
bbox_assigner='MaxIoUAssigner',
loss_class='FocalLoss',
loss_bbox='SmoothL1Loss',
nms='MultiClassNMS',
prior_prob=0.01,
decode_reg_out=False,
conv_feat=None,
anchor_generator=None,
bbox_assigner=None,
bbox_coder=None,
loss_class=None,
loss_bbox=None,
nms_pre=1000,
nms=None):
weights=[1., 1., 1., 1.]):
super(RetinaHead, self).__init__()
self.num_classes = num_classes
self.prior_prob = prior_prob
# allow RetinaNet to use IoU based losses.
self.decode_reg_out = decode_reg_out
self.conv_feat = conv_feat
self.anchor_generator = anchor_generator
self.bbox_assigner = bbox_assigner
self.bbox_coder = bbox_coder
self.loss_class = loss_class
self.loss_bbox = loss_bbox
self.nms_pre = nms_pre
self.nms = nms
self.cls_out_channels = num_classes
self.init_layers()
self.nms_pre = nms_pre
self.weights = weights
def init_layers(self):
bias_init_value = -math.log((1 - self.prior_prob) / self.prior_prob)
bias_init_value = -math.log((1 - prior_prob) / prior_prob)
num_anchors = self.anchor_generator.num_anchors
self.retina_cls = nn.Conv2D(
in_channels=self.conv_feat.feat_out,
out_channels=self.cls_out_channels * num_anchors,
out_channels=self.num_classes * num_anchors,
kernel_size=3,
stride=1,
padding=1,
weight_attr=ParamAttr(initializer=Normal(mean=0.0, std=0.01)),
weight_attr=ParamAttr(initializer=Normal(
mean=0.0, std=0.01)),
bias_attr=ParamAttr(initializer=Constant(value=bias_init_value)))
self.retina_reg = nn.Conv2D(
in_channels=self.conv_feat.feat_out,
......@@ -107,10 +87,11 @@ class RetinaHead(nn.Layer):
kernel_size=3,
stride=1,
padding=1,
weight_attr=ParamAttr(initializer=Normal(mean=0.0, std=0.01)),
weight_attr=ParamAttr(initializer=Normal(
mean=0.0, std=0.01)),
bias_attr=ParamAttr(initializer=Constant(value=0)))
def forward(self, neck_feats):
def forward(self, neck_feats, targets=None):
cls_logits_list = []
bboxes_reg_list = []
for neck_feat in neck_feats:
......@@ -119,33 +100,40 @@ class RetinaHead(nn.Layer):
bbox_reg = self.retina_reg(conv_reg_feat)
cls_logits_list.append(cls_logits)
bboxes_reg_list.append(bbox_reg)
return (cls_logits_list, bboxes_reg_list)
def get_loss(self, head_outputs, meta):
if self.training:
return self.get_loss([cls_logits_list, bboxes_reg_list], targets)
else:
return [cls_logits_list, bboxes_reg_list]
def get_loss(self, head_outputs, targets):
"""Here we calculate loss for a batch of images.
We assign anchors to gts in each image and gather all the assigned
postive and negative samples. Then loss is calculated on the gathered
samples.
"""
cls_logits, bboxes_reg = head_outputs
# we use the same anchor for all images
anchors = self.anchor_generator(cls_logits)
cls_logits_list, bboxes_reg_list = head_outputs
anchors = self.anchor_generator(cls_logits_list)
anchors = paddle.concat(anchors)
# matches: contain gt_inds
# match_labels: -1(ignore), 0(neg) or 1(pos)
matches_list, match_labels_list = [], []
# assign anchors to gts, no sampling is involved
for gt_bbox in meta['gt_bbox']:
for gt_bbox in targets['gt_bbox']:
matches, match_labels = self.bbox_assigner(anchors, gt_bbox)
matches_list.append(matches)
match_labels_list.append(match_labels)
# reshape network outputs
cls_logits = [_.transpose([0, 2, 3, 1]) for _ in cls_logits]
cls_logits = [_.reshape([0, -1, self.cls_out_channels]) \
for _ in cls_logits]
bboxes_reg = [_.transpose([0, 2, 3, 1]) for _ in bboxes_reg]
bboxes_reg = [_.reshape([0, -1, 4]) for _ in bboxes_reg]
cls_logits = [
_.transpose([0, 2, 3, 1]).reshape([0, -1, self.num_classes])
for _ in cls_logits_list
]
bboxes_reg = [
_.transpose([0, 2, 3, 1]).reshape([0, -1, 4])
for _ in bboxes_reg_list
]
cls_logits = paddle.concat(cls_logits, axis=1)
bboxes_reg = paddle.concat(bboxes_reg, axis=1)
......@@ -154,7 +142,7 @@ class RetinaHead(nn.Layer):
# find and gather preds and targets in each image
for matches, match_labels, cls_logit, bbox_reg, gt_bbox, gt_class in \
zip(matches_list, match_labels_list, cls_logits, bboxes_reg,
meta['gt_bbox'], meta['gt_class']):
targets['gt_bbox'], targets['gt_class']):
pos_mask = (match_labels == 1)
neg_mask = (match_labels == 0)
chosen_mask = paddle.logical_or(pos_mask, neg_mask)
......@@ -163,59 +151,65 @@ class RetinaHead(nn.Layer):
bg_class = paddle.to_tensor(
[self.num_classes], dtype=gt_class.dtype)
# a trick to assign num_classes to negative targets
gt_class = paddle.concat([gt_class, bg_class])
matches = paddle.where(
neg_mask, paddle.full_like(matches, gt_class.size-1), matches)
gt_class = paddle.concat([gt_class, bg_class], axis=-1)
matches = paddle.where(neg_mask,
paddle.full_like(matches, gt_class.size - 1),
matches)
cls_pred = cls_logit[chosen_mask]
cls_tar = gt_class[matches[chosen_mask]]
cls_tar = gt_class[matches[chosen_mask]]
reg_pred = bbox_reg[pos_mask].reshape([-1, 4])
reg_tar = gt_bbox[matches[pos_mask]].reshape([-1, 4])
if self.decode_reg_out:
reg_pred = self.bbox_coder.decode(
anchors[pos_mask], reg_pred)
else:
reg_tar = self.bbox_coder.encode(anchors[pos_mask], reg_tar)
reg_tar = bbox2delta(anchors[pos_mask], reg_tar, self.weights)
cls_pred_list.append(cls_pred)
cls_tar_list.append(cls_tar)
reg_pred_list.append(reg_pred)
reg_tar_list.append(reg_tar)
cls_pred = paddle.concat(cls_pred_list)
cls_tar = paddle.concat(cls_tar_list)
cls_tar = paddle.concat(cls_tar_list)
reg_pred = paddle.concat(reg_pred_list)
reg_tar = paddle.concat(reg_tar_list)
reg_tar = paddle.concat(reg_tar_list)
avg_factor = max(1.0, reg_pred.shape[0])
cls_loss = self.loss_class(
cls_pred, cls_tar, reduction='sum')/avg_factor
if reg_pred.size == 0:
reg_loss = bboxes_reg[0][0].sum() * 0
cls_pred, cls_tar, reduction='sum') / avg_factor
if reg_pred.shape[0] == 0:
reg_loss = paddle.zeros([1])
reg_loss.stop_gradient = False
else:
reg_loss = self.loss_bbox(
reg_pred, reg_tar, reduction='sum')/avg_factor
return dict(loss_cls=cls_loss, loss_reg=reg_loss)
reg_pred, reg_tar, reduction='sum') / avg_factor
loss = cls_loss + reg_loss
out_dict = {
'loss_cls': cls_loss,
'loss_reg': reg_loss,
'loss': loss,
}
return out_dict
def get_bboxes_single(self,
anchors,
cls_scores,
bbox_preds,
cls_scores_list,
bbox_preds_list,
im_shape,
scale_factor,
rescale=True):
assert len(cls_scores) == len(bbox_preds)
assert len(cls_scores_list) == len(bbox_preds_list)
mlvl_bboxes = []
mlvl_scores = []
for anchor, cls_score, bbox_pred in zip(anchors, cls_scores, bbox_preds):
for anchor, cls_score, bbox_pred in zip(anchors, cls_scores_list,
bbox_preds_list):
cls_score = cls_score.reshape([-1, self.num_classes])
bbox_pred = bbox_pred.reshape([-1, 4])
if self.nms_pre is not None and cls_score.shape[0] > self.nms_pre:
max_score = cls_score.max(axis=1)
_, topk_inds = max_score.topk(self.nms_pre)
bbox_pred = bbox_pred.gather(topk_inds)
anchor = anchor.gather(topk_inds)
anchor = anchor.gather(topk_inds)
cls_score = cls_score.gather(topk_inds)
bbox_pred = self.bbox_coder.decode(
anchor, bbox_pred, max_shape=im_shape)
bbox_pred = bbox_pred.squeeze()
bbox_pred = delta2bbox(bbox_pred, anchor, self.weights).squeeze()
mlvl_bboxes.append(bbox_pred)
mlvl_scores.append(F.sigmoid(cls_score))
mlvl_bboxes = paddle.concat(mlvl_bboxes)
......@@ -227,18 +221,15 @@ class RetinaHead(nn.Layer):
mlvl_scores = mlvl_scores.transpose([1, 0])
return mlvl_bboxes, mlvl_scores
def decode(self, anchors, cls_scores, bbox_preds, im_shape, scale_factor):
def decode(self, anchors, cls_logits, bboxes_reg, im_shape, scale_factor):
batch_bboxes = []
batch_scores = []
for img_id in range(cls_scores[0].shape[0]):
num_lvls = len(cls_scores)
cls_score_list = [cls_scores[i][img_id] for i in range(num_lvls)]
bbox_pred_list = [bbox_preds[i][img_id] for i in range(num_lvls)]
for img_id in range(cls_logits[0].shape[0]):
num_lvls = len(cls_logits)
cls_scores_list = [cls_logits[i][img_id] for i in range(num_lvls)]
bbox_preds_list = [bboxes_reg[i][img_id] for i in range(num_lvls)]
bboxes, scores = self.get_bboxes_single(
anchors,
cls_score_list,
bbox_pred_list,
im_shape[img_id],
anchors, cls_scores_list, bbox_preds_list, im_shape[img_id],
scale_factor[img_id])
batch_bboxes.append(bboxes)
batch_scores.append(scores)
......@@ -247,11 +238,12 @@ class RetinaHead(nn.Layer):
return batch_bboxes, batch_scores
def post_process(self, head_outputs, im_shape, scale_factor):
cls_scores, bbox_preds = head_outputs
anchors = self.anchor_generator(cls_scores)
cls_scores = [_.transpose([0, 2, 3, 1]) for _ in cls_scores]
bbox_preds = [_.transpose([0, 2, 3, 1]) for _ in bbox_preds]
bboxes, scores = self.decode(
anchors, cls_scores, bbox_preds, im_shape, scale_factor)
cls_logits_list, bboxes_reg_list = head_outputs
anchors = self.anchor_generator(cls_logits_list)
cls_logits = [_.transpose([0, 2, 3, 1]) for _ in cls_logits_list]
bboxes_reg = [_.transpose([0, 2, 3, 1]) for _ in bboxes_reg_list]
bboxes, scores = self.decode(anchors, cls_logits, bboxes_reg, im_shape,
scale_factor)
bbox_pred, bbox_num, _ = self.nms(bboxes, scores)
return bbox_pred, bbox_num
......@@ -22,6 +22,8 @@ import paddle.nn as nn
from ppdet.core.workspace import register
__all__ = ['AnchorGenerator', 'RetinaAnchorGenerator']
@register
class AnchorGenerator(nn.Layer):
......@@ -129,3 +131,25 @@ class AnchorGenerator(nn.Layer):
For FPN models, `num_anchors` on every feature map is the same.
"""
return len(self.cell_anchors[0])
@register
class RetinaAnchorGenerator(AnchorGenerator):
def __init__(self,
octave_base_scale=4,
scales_per_octave=3,
aspect_ratios=[0.5, 1.0, 2.0],
strides=[8.0, 16.0, 32.0, 64.0, 128.0],
variance=[1.0, 1.0, 1.0, 1.0],
offset=0.0):
anchor_sizes = []
for s in strides:
anchor_sizes.append([
s * octave_base_scale * 2**(i/scales_per_octave) \
for i in range(scales_per_octave)])
super(RetinaAnchorGenerator, self).__init__(
anchor_sizes=anchor_sizes,
aspect_ratios=aspect_ratios,
strides=strides,
variance=variance,
offset=offset)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册