未验证 提交 ae621055 编写于 作者: B Blake 提交者: GitHub

add implementation of RetinaNet (#5140)

* add implementation of RetinaNet

* * add README.md and model zoo
* rename FOCSFeat -> RetianFeat
* add mstrain model to model zoo
* refactor DeltaBBoxCoder

* update link for model and log
上级 1dcec15b
# Focal Loss for Dense Object Detection
## Introduction
We reproduce RetinaNet proposed in paper Focal Loss for Dense Object Detection.
## Model Zoo
| Backbone | Model | mstrain | imgs/GPU | lr schedule | FPS | Box AP | download | config |
| ------------ | --------- | ------- | -------- | ----------- | --- | ------ | ----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | ------------------------------------- |
| ResNet50-FPN | RetinaNet | Yes | 4 | 1x | --- | 37.5 | [model](https://bj.bcebos.com/v1/paddledet/models/retinanet_r50_fpn_mstrain_1x_coco.pdparams)\|[log](https://bj.bcebos.com/v1/paddledet/logs/retinanet_r50_fpn_mstrain_1x_coco.log) | retinanet_r50_fpn_mstrain_1x_coco.yml |
**Notes:**
- All above models are trained on COCO train2017 with 4 GPUs and evaludated on val2017. Box AP=`mAP(IoU=0.5:0.95)`.
- Config `configs/retinanet/retinanet_r50_fpn_1x_coco.yml` is for 8 GPUs and `configs/retinanet/retinanet_r50_fpn_mstrain_1x_coco.yml` is for 4 GPUs (mind the difference of train batch size).
## Citation
```latex
@inproceedings{lin2017focal,
title={Focal loss for dense object detection},
author={Lin, Tsung-Yi and Goyal, Priya and Girshick, Ross and He, Kaiming and Doll{\'a}r, Piotr},
booktitle={Proceedings of the IEEE international conference on computer vision},
year={2017}
}
```
epoch: 12
LearningRate:
base_lr: 0.01
schedulers:
- !PiecewiseDecay
gamma: 0.1
milestones: [8, 11]
- !LinearWarmup
start_factor: 0.001
steps: 500
OptimizerBuilder:
optimizer:
momentum: 0.9
type: Momentum
regularizer:
factor: 0.0001
type: L2
architecture: RetinaNet
pretrain_weights: https://paddledet.bj.bcebos.com/models/pretrained/ResNet50_cos_pretrained.pdparams
RetinaNet:
backbone: ResNet
neck: FPN
head: RetinaHead
ResNet:
depth: 50
variant: b
norm_type: bn
freeze_at: 0
return_idx: [1,2,3]
num_stages: 4
FPN:
out_channel: 256
spatial_scales: [0.125, 0.0625, 0.03125]
extra_stage: 2
has_extra_convs: true
use_c5: false
RetinaHead:
num_classes: 80
prior_prob: 0.01
nms_pre: 1000
decode_reg_out: false
conv_feat:
name: RetinaFeat
feat_in: 256
feat_out: 256
num_convs: 4
norm_type: null
use_dcn: false
anchor_generator:
name: RetinaAnchorGenerator
octave_base_scale: 4
scales_per_octave: 3
aspect_ratios: [0.5, 1.0, 2.0]
strides: [8.0, 16.0, 32.0, 64.0, 128.0]
bbox_assigner:
name: MaxIoUAssigner
positive_overlap: 0.5
negative_overlap: 0.4
allow_low_quality: true
bbox_coder:
name: DeltaBBoxCoder
norm_mean: [0.0, 0.0, 0.0, 0.0]
norm_std: [1.0, 1.0, 1.0, 1.0]
loss_class:
name: FocalLoss
gamma: 2.0
alpha: 0.25
loss_weight: 1.0
loss_bbox:
name: SmoothL1Loss
beta: 0.0
loss_weight: 1.0
nms:
name: MultiClassNMS
nms_top_k: 1000
keep_top_k: 100
score_threshold: 0.05
nms_threshold: 0.5
worker_num: 2
TrainReader:
sample_transforms:
- Decode: {}
- RandomFlip: {prob: 0.5}
- Resize: {target_size: [800, 1333], keep_ratio: true, interp: 1}
- NormalizeImage: {is_scale: true, mean: [0.485, 0.456, 0.406], std: [0.229, 0.224, 0.225]}
- Permute: {}
batch_transforms:
- PadBatch: {pad_to_stride: 32}
batch_size: 2
shuffle: true
drop_last: true
use_process: true
collate_batch: false
EvalReader:
sample_transforms:
- Decode: {}
- Resize: {target_size: [800, 1333], keep_ratio: true, interp: 1}
- NormalizeImage: {is_scale: true, mean: [0.485, 0.456, 0.406], std: [0.229, 0.224, 0.225]}
- Permute: {}
batch_transforms:
- PadBatch: {pad_to_stride: 32}
batch_size: 2
shuffle: false
TestReader:
sample_transforms:
- Decode: {}
- Resize: {target_size: [800, 1333], keep_ratio: true, interp: 1}
- NormalizeImage: {is_scale: true, mean: [0.485, 0.456, 0.406], std: [0.229, 0.224, 0.225]}
- Permute: {}
batch_transforms:
- PadBatch: {pad_to_stride: 32}
batch_size: 1
shuffle: false
_BASE_: [
'../datasets/coco_detection.yml',
'../runtime.yml',
'_base_/retinanet_r50_fpn.yml',
'_base_/optimizer_1x.yml',
'_base_/retinanet_reader.yml'
]
weights: output/retinanet_r50_fpn_1x_coco/model_final
find_unused_parameters: true
\ No newline at end of file
_BASE_: [
'../datasets/coco_detection.yml',
'../runtime.yml',
'_base_/retinanet_r50_fpn.yml',
'_base_/optimizer_1x.yml',
'_base_/retinanet_reader.yml'
]
worker_num: 4
TrainReader:
batch_size: 4
sample_transforms:
- Decode: {}
- RandomFlip: {prob: 0.5}
- RandomResize: {target_size: [[640, 1333], [672, 1333], [704, 1333], [736, 1333], [768, 1333], [800, 1333]], keep_ratio: true, interp: 1}
- NormalizeImage: {is_scale: true, mean: [0.485, 0.456, 0.406], std: [0.229, 0.224, 0.225]}
- Permute: {}
weights: output/retinanet_r50_fpn_mstrain_1x_coco/model_final
find_unused_parameters: true
\ No newline at end of file
......@@ -29,6 +29,7 @@ from . import reid
from . import mot
from . import transformers
from . import assigners
from . import coders
from .ops import *
from .backbones import *
......@@ -43,3 +44,4 @@ from .reid import *
from .mot import *
from .transformers import *
from .assigners import *
from .coders import *
......@@ -26,6 +26,7 @@ from . import picodet
from . import detr
from . import sparse_rcnn
from . import tood
from . import retinanet
from .meta_arch import *
from .faster_rcnn import *
......@@ -49,3 +50,4 @@ from .picodet import *
from .detr import *
from .sparse_rcnn import *
from .tood import *
from .retinanet import *
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from ppdet.core.workspace import register, create
from .meta_arch import BaseArch
import paddle
__all__ = ['RetinaNet']
@register
class RetinaNet(BaseArch):
__category__ = 'architecture'
def __init__(self,
backbone,
neck,
head):
super(RetinaNet, self).__init__()
self.backbone = backbone
self.neck = neck
self.head = head
@classmethod
def from_config(cls, cfg, *args, **kwargs):
backbone = create(cfg['backbone'])
kwargs = {'input_shape': backbone.out_shape}
neck = create(cfg['neck'], **kwargs)
head = create(cfg['head'])
return {
'backbone': backbone,
'neck': neck,
'head': head}
def _forward(self):
body_feats = self.backbone(self.inputs)
neck_feats = self.neck(body_feats)
head_outs = self.head(neck_feats)
if not self.training:
im_shape = self.inputs['im_shape']
scale_factor = self.inputs['scale_factor']
bboxes, bbox_num = self.head.post_process(head_outs, im_shape, scale_factor)
return bboxes, bbox_num
return head_outs
def get_loss(self):
loss = dict()
head_outs = self._forward()
loss_retina = self.head.get_loss(head_outs, self.inputs)
loss.update(loss_retina)
total_loss = paddle.add_n(list(loss.values()))
loss.update(loss=total_loss)
return loss
def get_pred(self):
bbox_pred, bbox_num = self._forward()
output = dict(bbox=bbox_pred, bbox_num=bbox_num)
return output
......@@ -16,8 +16,10 @@ from . import utils
from . import task_aligned_assigner
from . import atss_assigner
from . import simota_assigner
from . import max_iou_assigner
from .utils import *
from .task_aligned_assigner import *
from .atss_assigner import *
from .simota_assigner import *
from .max_iou_assigner import *
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from ppdet.core.workspace import register
from ppdet.modeling.proposal_generator.target import label_box
__all__ = ['MaxIoUAssigner']
@register
class MaxIoUAssigner(object):
"""a standard bbox assigner based on max IoU, use ppdet's label_box
as backend.
Args:
positive_overlap (float): threshold for defining positive samples
negative_overlap (float): threshold for denining negative samples
allow_low_quality (bool): whether to lower IoU thr if a GT poorly
overlaps with candidate bboxes
"""
def __init__(self,
positive_overlap,
negative_overlap,
allow_low_quality=True):
self.positive_overlap = positive_overlap
self.negative_overlap = negative_overlap
self.allow_low_quality = allow_low_quality
def __call__(self, bboxes, gt_bboxes):
matches, match_labels = label_box(
bboxes,
gt_bboxes,
positive_overlap=self.positive_overlap,
negative_overlap=self.negative_overlap,
allow_low_quality=self.allow_low_quality,
ignore_thresh=-1,
is_crowd=None,
assign_on_cpu=False)
return matches, match_labels
......@@ -775,3 +775,93 @@ def batch_distance2bbox(points, distance, max_shapes=None):
out_bbox = paddle.where(out_bbox > 0, out_bbox,
paddle.zeros_like(out_bbox))
return out_bbox
def delta2bbox_v2(rois,
deltas,
means=(0.0, 0.0, 0.0, 0.0),
stds=(1.0, 1.0, 1.0, 1.0),
max_shape=None,
wh_ratio_clip=16.0/1000.0,
ctr_clip=None):
"""Transform network output(delta) to bboxes.
Based on https://github.com/open-mmlab/mmdetection/blob/master/mmdet/core/
bbox/coder/delta_xywh_bbox_coder.py
Args:
rois (Tensor): shape [..., 4], base bboxes, typical examples include
anchor and rois
deltas (Tensor): shape [..., 4], offset relative to base bboxes
means (list[float]): the mean that was used to normalize deltas,
must be of size 4
stds (list[float]): the std that was used to normalize deltas,
must be of size 4
max_shape (list[float] or None): height and width of image, will be
used to clip bboxes if not None
wh_ratio_clip (float): to clip delta wh of decoded bboxes
ctr_clip (float or None): whether to clip delta xy of decoded bboxes
"""
if rois.size == 0:
return paddle.empty_like(rois)
means = paddle.to_tensor(means)
stds = paddle.to_tensor(stds)
deltas = deltas * stds + means
dxy = deltas[..., :2]
dwh = deltas[..., 2:]
pxy = (rois[..., :2] + rois[..., 2:]) * 0.5
pwh = rois[..., 2:] - rois[..., :2]
dxy_wh = pwh * dxy
max_ratio = np.abs(np.log(wh_ratio_clip))
if ctr_clip is not None:
dxy_wh = paddle.clip(dxy_wh, max=ctr_clip, min=-ctr_clip)
dwh = paddle.clip(dwh, max=max_ratio)
else:
dwh = dwh.clip(min=-max_ratio, max=max_ratio)
gxy = pxy + dxy_wh
gwh = pwh * dwh.exp()
x1y1 = gxy - (gwh * 0.5)
x2y2 = gxy + (gwh * 0.5)
bboxes = paddle.concat([x1y1, x2y2], axis=-1)
if max_shape is not None:
bboxes[..., 0::2] = bboxes[..., 0::2].clip(min=0, max=max_shape[1])
bboxes[..., 1::2] = bboxes[..., 1::2].clip(min=0, max=max_shape[0])
return bboxes
def bbox2delta_v2(src_boxes,
tgt_boxes,
means=(0.0, 0.0, 0.0, 0.0),
stds=(1.0, 1.0, 1.0, 1.0)):
"""Encode bboxes to deltas.
Modified from ppdet.modeling.bbox_utils.bbox2delta.
Args:
src_boxes (Tensor[..., 4]): base bboxes
tgt_boxes (Tensor[..., 4]): target bboxes
means (list[float]): the mean that will be used to normalize delta
stds (list[float]): the std that will be used to normalize delta
"""
if src_boxes.size == 0:
return paddle.empty_like(src_boxes)
src_w = src_boxes[..., 2] - src_boxes[..., 0]
src_h = src_boxes[..., 3] - src_boxes[..., 1]
src_ctr_x = src_boxes[..., 0] + 0.5 * src_w
src_ctr_y = src_boxes[..., 1] + 0.5 * src_h
tgt_w = tgt_boxes[..., 2] - tgt_boxes[..., 0]
tgt_h = tgt_boxes[..., 3] - tgt_boxes[..., 1]
tgt_ctr_x = tgt_boxes[..., 0] + 0.5 * tgt_w
tgt_ctr_y = tgt_boxes[..., 1] + 0.5 * tgt_h
dx = (tgt_ctr_x - src_ctr_x) / src_w
dy = (tgt_ctr_y - src_ctr_y) / src_h
dw = paddle.log(tgt_w / src_w)
dh = paddle.log(tgt_h / src_h)
deltas = paddle.stack((dx, dy, dw, dh), axis=1) # [n, 4]
means = paddle.to_tensor(means, place=src_boxes.place)
stds = paddle.to_tensor(stds, place=src_boxes.place)
deltas = (deltas - means) / stds
return deltas
from .delta_bbox_coder import DeltaBBoxCoder
import paddle
import numpy as np
from ppdet.core.workspace import register
from ppdet.modeling.bbox_utils import delta2bbox_v2, bbox2delta_v2
__all__ = ['DeltaBBoxCoder']
@register
class DeltaBBoxCoder:
"""Encode bboxes in terms of delta/offset of a reference bbox.
Args:
norm_mean (list[float]): the mean to normalize delta
norm_std (list[float]): the std to normalize delta
wh_ratio_clip (float): to clip delta wh of decoded bboxes
ctr_clip (float or None): whether to clip delta xy of decoded bboxes
"""
def __init__(self,
norm_mean=[0.0, 0.0, 0.0, 0.0],
norm_std=[1., 1., 1., 1.],
wh_ratio_clip=16/1000.0,
ctr_clip=None):
self.norm_mean = norm_mean
self.norm_std = norm_std
self.wh_ratio_clip = wh_ratio_clip
self.ctr_clip = ctr_clip
def encode(self, bboxes, tar_bboxes):
return bbox2delta_v2(
bboxes, tar_bboxes, means=self.norm_mean, stds=self.norm_std)
def decode(self, bboxes, deltas, max_shape=None):
return delta2bbox_v2(
bboxes,
deltas,
max_shape=max_shape,
wh_ratio_clip=self.wh_ratio_clip,
ctr_clip=self.ctr_clip,
means=self.norm_mean,
stds=self.norm_std)
......@@ -31,6 +31,7 @@ from . import pico_head
from . import detr_head
from . import sparsercnn_head
from . import tood_head
from . import retina_head
from .bbox_head import *
from .mask_head import *
......@@ -51,3 +52,4 @@ from .pico_head import *
from .detr_head import *
from .sparsercnn_head import *
from .tood_head import *
from .retina_head import *
......@@ -64,6 +64,8 @@ class FCOSFeat(nn.Layer):
norm_type='bn',
use_dcn=False):
super(FCOSFeat, self).__init__()
self.feat_in = feat_in
self.feat_out = feat_out
self.num_convs = num_convs
self.norm_type = norm_type
self.cls_subnet_convs = []
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import math, paddle
import paddle.nn as nn
import paddle.nn.functional as F
from paddle import ParamAttr
from paddle.nn.initializer import Normal, Constant
from ppdet.modeling.proposal_generator import AnchorGenerator
from ppdet.core.workspace import register
from ppdet.modeling.heads.fcos_head import FCOSFeat
__all__ = ['RetinaHead']
@register
class RetinaFeat(FCOSFeat):
"""We use FCOSFeat to construct conv layers in RetinaNet.
We rename FCOSFeat to RetinaFeat to avoid confusion.
"""
pass
@register
class RetinaAnchorGenerator(AnchorGenerator):
def __init__(self,
octave_base_scale=4,
scales_per_octave=3,
aspect_ratios=[0.5, 1.0, 2.0],
strides=[8.0, 16.0, 32.0, 64.0, 128.0],
variance=[1.0, 1.0, 1.0, 1.0],
offset=0.0):
anchor_sizes = []
for s in strides:
anchor_sizes.append([
s * octave_base_scale * 2**(i/scales_per_octave) \
for i in range(scales_per_octave)])
super(RetinaAnchorGenerator, self).__init__(
anchor_sizes=anchor_sizes,
aspect_ratios=aspect_ratios,
strides=strides,
variance=variance,
offset=offset)
@register
class RetinaHead(nn.Layer):
"""Used in RetinaNet proposed in paper https://arxiv.org/pdf/1708.02002.pdf
"""
__inject__ = [
'conv_feat', 'anchor_generator', 'bbox_assigner',
'bbox_coder', 'loss_class', 'loss_bbox', 'nms']
def __init__(self,
num_classes=80,
prior_prob=0.01,
decode_reg_out=False,
conv_feat=None,
anchor_generator=None,
bbox_assigner=None,
bbox_coder=None,
loss_class=None,
loss_bbox=None,
nms_pre=1000,
nms=None):
super(RetinaHead, self).__init__()
self.num_classes = num_classes
self.prior_prob = prior_prob
# allow RetinaNet to use IoU based losses.
self.decode_reg_out = decode_reg_out
self.conv_feat = conv_feat
self.anchor_generator = anchor_generator
self.bbox_assigner = bbox_assigner
self.bbox_coder = bbox_coder
self.loss_class = loss_class
self.loss_bbox = loss_bbox
self.nms_pre = nms_pre
self.nms = nms
self.cls_out_channels = num_classes
self.init_layers()
def init_layers(self):
bias_init_value = -math.log((1 - self.prior_prob) / self.prior_prob)
num_anchors = self.anchor_generator.num_anchors
self.retina_cls = nn.Conv2D(
in_channels=self.conv_feat.feat_out,
out_channels=self.cls_out_channels * num_anchors,
kernel_size=3,
stride=1,
padding=1,
weight_attr=ParamAttr(initializer=Normal(mean=0.0, std=0.01)),
bias_attr=ParamAttr(initializer=Constant(value=bias_init_value)))
self.retina_reg = nn.Conv2D(
in_channels=self.conv_feat.feat_out,
out_channels=4 * num_anchors,
kernel_size=3,
stride=1,
padding=1,
weight_attr=ParamAttr(initializer=Normal(mean=0.0, std=0.01)),
bias_attr=ParamAttr(initializer=Constant(value=0)))
def forward(self, neck_feats):
cls_logits_list = []
bboxes_reg_list = []
for neck_feat in neck_feats:
conv_cls_feat, conv_reg_feat = self.conv_feat(neck_feat)
cls_logits = self.retina_cls(conv_cls_feat)
bbox_reg = self.retina_reg(conv_reg_feat)
cls_logits_list.append(cls_logits)
bboxes_reg_list.append(bbox_reg)
return (cls_logits_list, bboxes_reg_list)
def get_loss(self, head_outputs, meta):
"""Here we calculate loss for a batch of images.
We assign anchors to gts in each image and gather all the assigned
postive and negative samples. Then loss is calculated on the gathered
samples.
"""
cls_logits, bboxes_reg = head_outputs
# we use the same anchor for all images
anchors = self.anchor_generator(cls_logits)
anchors = paddle.concat(anchors)
# matches: contain gt_inds
# match_labels: -1(ignore), 0(neg) or 1(pos)
matches_list, match_labels_list = [], []
# assign anchors to gts, no sampling is involved
for gt_bbox in meta['gt_bbox']:
matches, match_labels = self.bbox_assigner(anchors, gt_bbox)
matches_list.append(matches)
match_labels_list.append(match_labels)
# reshape network outputs
cls_logits = [_.transpose([0, 2, 3, 1]) for _ in cls_logits]
cls_logits = [_.reshape([0, -1, self.cls_out_channels]) \
for _ in cls_logits]
bboxes_reg = [_.transpose([0, 2, 3, 1]) for _ in bboxes_reg]
bboxes_reg = [_.reshape([0, -1, 4]) for _ in bboxes_reg]
cls_logits = paddle.concat(cls_logits, axis=1)
bboxes_reg = paddle.concat(bboxes_reg, axis=1)
cls_pred_list, cls_tar_list = [], []
reg_pred_list, reg_tar_list = [], []
# find and gather preds and targets in each image
for matches, match_labels, cls_logit, bbox_reg, gt_bbox, gt_class in \
zip(matches_list, match_labels_list, cls_logits, bboxes_reg,
meta['gt_bbox'], meta['gt_class']):
pos_mask = (match_labels == 1)
neg_mask = (match_labels == 0)
chosen_mask = paddle.logical_or(pos_mask, neg_mask)
gt_class = gt_class.reshape([-1])
bg_class = paddle.to_tensor(
[self.num_classes], dtype=gt_class.dtype)
# a trick to assign num_classes to negative targets
gt_class = paddle.concat([gt_class, bg_class])
matches = paddle.where(
neg_mask, paddle.full_like(matches, gt_class.size-1), matches)
cls_pred = cls_logit[chosen_mask]
cls_tar = gt_class[matches[chosen_mask]]
reg_pred = bbox_reg[pos_mask].reshape([-1, 4])
reg_tar = gt_bbox[matches[pos_mask]].reshape([-1, 4])
if self.decode_reg_out:
reg_pred = self.bbox_coder.decode(
anchors[pos_mask], reg_pred)
else:
reg_tar = self.bbox_coder.encode(anchors[pos_mask], reg_tar)
cls_pred_list.append(cls_pred)
cls_tar_list.append(cls_tar)
reg_pred_list.append(reg_pred)
reg_tar_list.append(reg_tar)
cls_pred = paddle.concat(cls_pred_list)
cls_tar = paddle.concat(cls_tar_list)
reg_pred = paddle.concat(reg_pred_list)
reg_tar = paddle.concat(reg_tar_list)
avg_factor = max(1.0, reg_pred.shape[0])
cls_loss = self.loss_class(
cls_pred, cls_tar, reduction='sum')/avg_factor
if reg_pred.size == 0:
reg_loss = bboxes_reg[0][0].sum() * 0
else:
reg_loss = self.loss_bbox(
reg_pred, reg_tar, reduction='sum')/avg_factor
return dict(loss_cls=cls_loss, loss_reg=reg_loss)
def get_bboxes_single(self,
anchors,
cls_scores,
bbox_preds,
im_shape,
scale_factor,
rescale=True):
assert len(cls_scores) == len(bbox_preds)
mlvl_bboxes = []
mlvl_scores = []
for anchor, cls_score, bbox_pred in zip(anchors, cls_scores, bbox_preds):
cls_score = cls_score.reshape([-1, self.num_classes])
bbox_pred = bbox_pred.reshape([-1, 4])
if self.nms_pre is not None and cls_score.shape[0] > self.nms_pre:
max_score = cls_score.max(axis=1)
_, topk_inds = max_score.topk(self.nms_pre)
bbox_pred = bbox_pred.gather(topk_inds)
anchor = anchor.gather(topk_inds)
cls_score = cls_score.gather(topk_inds)
bbox_pred = self.bbox_coder.decode(
anchor, bbox_pred, max_shape=im_shape)
bbox_pred = bbox_pred.squeeze()
mlvl_bboxes.append(bbox_pred)
mlvl_scores.append(F.sigmoid(cls_score))
mlvl_bboxes = paddle.concat(mlvl_bboxes)
mlvl_bboxes = paddle.squeeze(mlvl_bboxes)
if rescale:
mlvl_bboxes = mlvl_bboxes / paddle.concat(
[scale_factor[::-1], scale_factor[::-1]])
mlvl_scores = paddle.concat(mlvl_scores)
mlvl_scores = mlvl_scores.transpose([1, 0])
return mlvl_bboxes, mlvl_scores
def decode(self, anchors, cls_scores, bbox_preds, im_shape, scale_factor):
batch_bboxes = []
batch_scores = []
for img_id in range(cls_scores[0].shape[0]):
num_lvls = len(cls_scores)
cls_score_list = [cls_scores[i][img_id] for i in range(num_lvls)]
bbox_pred_list = [bbox_preds[i][img_id] for i in range(num_lvls)]
bboxes, scores = self.get_bboxes_single(
anchors,
cls_score_list,
bbox_pred_list,
im_shape[img_id],
scale_factor[img_id])
batch_bboxes.append(bboxes)
batch_scores.append(scores)
batch_bboxes = paddle.stack(batch_bboxes, axis=0)
batch_scores = paddle.stack(batch_scores, axis=0)
return batch_bboxes, batch_scores
def post_process(self, head_outputs, im_shape, scale_factor):
cls_scores, bbox_preds = head_outputs
anchors = self.anchor_generator(cls_scores)
cls_scores = [_.transpose([0, 2, 3, 1]) for _ in cls_scores]
bbox_preds = [_.transpose([0, 2, 3, 1]) for _ in bbox_preds]
bboxes, scores = self.decode(
anchors, cls_scores, bbox_preds, im_shape, scale_factor)
bbox_pred, bbox_num, _ = self.nms(bboxes, scores)
return bbox_pred, bbox_num
......@@ -128,7 +128,7 @@ class ConvNormLayer(nn.Layer):
dcn_lr_scale=2.,
dcn_regularizer=L2Decay(0.)):
super(ConvNormLayer, self).__init__()
assert norm_type in ['bn', 'sync_bn', 'gn']
assert norm_type in ['bn', 'sync_bn', 'gn', None]
if bias_on:
bias_attr = ParamAttr(
......@@ -183,9 +183,12 @@ class ConvNormLayer(nn.Layer):
num_channels=ch_out,
weight_attr=param_attr,
bias_attr=bias_attr)
else:
self.norm = None
def forward(self, inputs):
out = self.conv(inputs)
if self.norm is not None:
out = self.norm(out)
return out
......
......@@ -25,6 +25,8 @@ from . import fairmot_loss
from . import gfocal_loss
from . import detr_loss
from . import sparsercnn_loss
from . import focal_loss
from . import smooth_l1_loss
from .yolo_loss import *
from .iou_aware_loss import *
......@@ -39,3 +41,5 @@ from .fairmot_loss import *
from .gfocal_loss import *
from .detr_loss import *
from .sparsercnn_loss import *
from .focal_loss import *
from .smooth_l1_loss import *
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import paddle
import paddle.nn.functional as F
import paddle.nn as nn
from ppdet.core.workspace import register
__all__ = ['FocalLoss']
@register
class FocalLoss(nn.Layer):
"""A wrapper around paddle.nn.functional.sigmoid_focal_loss.
Args:
use_sigmoid (bool): currently only support use_sigmoid=True
alpha (float): parameter alpha in Focal Loss
gamma (float): parameter gamma in Focal Loss
loss_weight (float): final loss will be multiplied by this
"""
def __init__(self,
use_sigmoid=True,
alpha=0.25,
gamma=2.0,
loss_weight=1.0):
super(FocalLoss, self).__init__()
assert use_sigmoid == True, \
'Focal Loss only supports sigmoid at the moment'
self.use_sigmoid = use_sigmoid
self.alpha = alpha
self.gamma = gamma
self.loss_weight = loss_weight
def forward(self, pred, target, reduction='none'):
"""forward function.
Args:
pred (Tensor): logits of class prediction, of shape (N, num_classes)
target (Tensor): target class label, of shape (N, )
reduction (str): the way to reduce loss, one of (none, sum, mean)
"""
num_classes = pred.shape[1]
target = F.one_hot(target, num_classes+1).cast(pred.dtype)
target = target[:, :-1].detach()
loss = F.sigmoid_focal_loss(
pred, target, alpha=self.alpha, gamma=self.gamma,
reduction=reduction)
return loss * self.loss_weight
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from ppdet.core.workspace import register
__all__ = ['SmoothL1Loss']
@register
class SmoothL1Loss(nn.Layer):
"""Smooth L1 Loss.
Args:
beta (float): controls smooth region, it becomes L1 Loss when beta=0.0
loss_weight (float): the final loss will be multiplied by this
"""
def __init__(self,
beta=1.0,
loss_weight=1.0):
super(SmoothL1Loss, self).__init__()
assert beta >= 0
self.beta = beta
self.loss_weight = loss_weight
def forward(self, pred, target, reduction='none'):
"""forward function, based on fvcore.
Args:
pred (Tensor): prediction tensor
target (Tensor): target tensor, pred.shape must be the same as target.shape
reduction (str): the way to reduce loss, one of (none, sum, mean)
"""
assert reduction in ('none', 'sum', 'mean')
target = target.detach()
if self.beta < 1e-5:
loss = paddle.abs(pred - target)
else:
n = paddle.abs(pred - target)
cond = n < self.beta
loss = paddle.where(cond, 0.5 * n ** 2 / self.beta, n - 0.5 * self.beta)
if reduction == 'mean':
loss = loss.mean() if loss.size > 0 else 0.0 * loss.sum()
elif reduction == 'sum':
loss = loss.sum()
return loss * self.loss_weight
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册