未验证 提交 3dafbd94 编写于 作者: F Feng Ni 提交者: GitHub

[Dygraph] add FCOS (#2012)

* fcos infer

* fcos training

* fix stgraph fcosloss

* fcos doc

* fix dy_fcos

* add comment

* train eval deploy fcos

* remove fcos config, add comment

* fix iter_tic
Co-authored-by: Nnemonameless <nemonameless@qq.com>
Co-authored-by: Ngeorgenemo <nifeng@pku.edu.cn>
上级 5afa83e5
......@@ -10,6 +10,7 @@
- Cascade RCNN
- YOLOv3
- SSD
- FCOS
- SOLOv2
扩展特性:
......
# FCOS for Object Detection
## Introduction
FCOS (Fully Convolutional One-Stage Object Detection) is a fast anchor-free object detection framework with strong performance. We reproduced the model of the paper, and improved and optimized the accuracy of the FCOS.
**Highlights:**
- Training Time: The training time of the model of `fcos_r50_fpn_1x` on Tesla v100 with 8 GPU is only 8.5 hours.
## Model Zoo
| 骨架网络 | 网络类型 | 每张GPU图片个数 | 学习率策略 |推理时间(fps) | Box AP | 下载 | 配置文件 |
| :-------------- | :------------- | :-----: | :-----: | :------------: | :-----: | :-----------------------------------------------------: | :-----: |
| ResNet50-FPN | FCOS | 2 | 1x | ---- | 39.6 | [下载链接](https://paddlemodels.bj.bcebos.com/object_detection/dygraph/fcos_r50_fpn_1x_coco.pdparams) | [配置文件](https://github.com/PaddlePaddle/PaddleDetection/tree/dygraph/configs/fcos/fcos_r50_fpn_1x_coco.yml) |
**Notes:**
- FCOS is trained on COCO train2017 dataset and evaluated on val2017 results of `mAP(IoU=0.5:0.95)`.
## Citations
```
@inproceedings{tian2019fcos,
title = {{FCOS}: Fully Convolutional One-Stage Object Detection},
author = {Tian, Zhi and Shen, Chunhua and Chen, Hao and He, Tong},
booktitle = {Proc. Int. Conf. Computer Vision (ICCV)},
year = {2019}
}
```
architecture: FCOS
pretrain_weights: https://paddle-imagenet-models-name.bj.bcebos.com/ResNet50_cos_pretrained.tar
load_static_weights: True
FCOS:
backbone: ResNet
neck: FPN
fcos_head: FCOSHead
fcos_post_process: FCOSPostProcess
ResNet:
# index 0 stands for res2
depth: 50
norm_type: bn
freeze_at: 0
return_idx: [1,2,3]
num_stages: 4
FPN:
in_channels: [256, 512, 1024, 2048]
out_channel: 256
min_level: 1
max_level: 5
spatial_scale: [0.125, 0.0625, 0.03125]
has_extra_convs: true
use_c5: false
FCOSHead:
fcos_feat:
name: FCOSFeat
feat_in: 256
feat_out: 256
num_convs: 4
norm_type: "gn"
use_dcn: false
num_classes: 80
fpn_stride: [8, 16, 32, 64, 128]
prior_prob: 0.01
fcos_loss: FCOSLoss
norm_reg_targets: true
centerness_on_reg: true
FCOSLoss:
loss_alpha: 0.25
loss_gamma: 2.0
iou_loss_type: "giou"
reg_weights: 1.0
FCOSPostProcess:
decode:
name: FCOSBox
num_classes: 80
batch_size: 1
nms:
name: MultiClassNMS
nms_top_k: 1000
keep_top_k: 100
score_threshold: 0.025
nms_threshold: 0.6
background_label: -1
worker_num: 2
TrainReader:
sample_transforms:
- DecodeOp: {}
- RandomFlipOp: {prob: 0.5}
- NormalizeImageOp: {is_scale: true, mean: [0.485,0.456,0.406], std: [0.229, 0.224,0.225]}
- ResizeImage: {target_size: 800, max_size: 1333, interp: 1, use_cv2: true}
- PermuteOp: {}
batch_transforms:
- PadBatchOp: {pad_to_stride: 128}
- Gt2FCOSTarget:
object_sizes_boundary: [64, 128, 256, 512]
center_sampling_radius: 1.5
downsample_ratios: [8, 16, 32, 64, 128]
norm_reg_targets: True
batch_size: 2
shuffle: true
drop_last: true
EvalReader:
sample_transforms:
- DecodeOp: {}
- NormalizeImageOp: {is_scale: true, mean: [0.485,0.456,0.406], std: [0.229, 0.224,0.225]}
- ResizeOp: {interp: 1, target_size: [800, 1333], keep_ratio: True}
- PermuteOp: {}
batch_transforms:
- PadBatchOp: {pad_to_stride: 128}
batch_size: 1
shuffle: false
TestReader:
sample_transforms:
- DecodeOp: {}
- NormalizeImageOp: {is_scale: true, mean: [0.485,0.456,0.406], std: [0.229, 0.224,0.225]}
- ResizeOp: {interp: 1, target_size: [800, 1333], keep_ratio: True}
- PermuteOp: {}
batch_transforms:
- PadBatchOp: {pad_to_stride: 128}
batch_size: 1
shuffle: false
epoch: 12
LearningRate:
base_lr: 0.01
schedulers:
- !PiecewiseDecay
gamma: 0.1
milestones: [8, 11]
- !LinearWarmup
start_factor: 0.3333333333333333
steps: 500
OptimizerBuilder:
optimizer:
momentum: 0.9
type: Momentum
regularizer:
factor: 0.0001
type: L2
_BASE_: [
'../_base_/datasets/coco_detection.yml',
'../_base_/runtime.yml',
'_base_/fcos_r50_fpn.yml',
'_base_/optimizer_1x.yml',
'_base_/fcos_reader.yml',
]
weights: output/fcos_r50_fpn_1x_coco/model_final
......@@ -33,6 +33,7 @@ SUPPORT_MODELS = {
'YOLO',
'RCNN',
'SSD',
'FCOS',
'SOLOv2',
}
......
......@@ -80,6 +80,10 @@ Paddle提供基于ImageNet的骨架网络预训练模型。所有预训练模型
**注意:** SSD使用4GPU训练,训练240个epoch
### FCOS
请参考[fcos](https://github.com/PaddlePaddle/PaddleDetection/tree/dygraph/configs/fcos/)
### SOLOv2
请参考[solov2](https://github.com/PaddlePaddle/PaddleDetection/tree/dygraph/configs/solov2/)
......@@ -99,7 +99,7 @@ def _load_config_with_base(file_path):
return file_cfg
WITHOUT_BACKGROUND_ARCHS = ['YOLOv3']
WITHOUT_BACKGROUND_ARCHS = ['YOLOv3', 'FCOS']
def _parse_with_background():
......
......@@ -505,6 +505,10 @@ class Gt2FCOSTargetOp(BaseOperator):
labels_by_level[lvl], newshape=[grid_h, grid_w, 1])
sample['centerness{}'.format(lvl)] = np.reshape(
ctn_targets_by_level[lvl], newshape=[grid_h, grid_w, 1])
sample.pop('is_crowd')
sample.pop('gt_class')
sample.pop('gt_bbox')
return samples
......
......@@ -438,12 +438,11 @@ class Gt2FCOSTarget(BaseOperator):
"object_sizes_of_interest', and 'downsample_ratios' should have same length."
for sample in samples:
# im, gt_bbox, gt_class, gt_score = sample
# im, gt_bbox, gt_class = sample
im = sample['image']
im_info = sample['im_info']
bboxes = sample['gt_bbox']
gt_class = sample['gt_class']
gt_score = sample['gt_score']
bboxes[:, [0, 2]] = bboxes[:, [0, 2]] * np.floor(im_info[1]) / \
np.floor(im_info[1] / im_info[2])
bboxes[:, [1, 3]] = bboxes[:, [1, 3]] * np.floor(im_info[0]) / \
......@@ -535,6 +534,10 @@ class Gt2FCOSTarget(BaseOperator):
labels_by_level[lvl], newshape=[grid_h, grid_w, 1])
sample['centerness{}'.format(lvl)] = np.reshape(
ctn_targets_by_level[lvl], newshape=[grid_h, grid_w, 1])
sample.pop('is_crowd')
sample.pop('gt_class')
sample.pop('gt_bbox')
return samples
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
......@@ -201,7 +201,7 @@ class Trainer(object):
self.status['batch_time'].update(time.time() - iter_tic)
self._compose_callback.on_step_end(self.status)
iter_tic = time.time()
self._compose_callback.on_epoch_end(self.status)
def evaluate(self):
......@@ -244,7 +244,7 @@ class Trainer(object):
clsid2catid, catid2name = get_categories(self.cfg.metric, anno_file,
with_background)
# Run Infer
# Run Infer
for step_id, data in enumerate(loader):
self.status['step_id'] = step_id
# forward
......
......@@ -11,6 +11,7 @@ from . import mask_rcnn
from . import yolo
from . import cascade_rcnn
from . import ssd
from . import fcos
from . import solov2
from .meta_arch import *
......@@ -19,4 +20,5 @@ from .mask_rcnn import *
from .yolo import *
from .cascade_rcnn import *
from .ssd import *
from .fcos import *
from .solov2 import *
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import paddle
from ppdet.core.workspace import register
from .meta_arch import BaseArch
__all__ = ['FCOS']
@register
class FCOS(BaseArch):
__category__ = 'architecture'
__inject__ = [
'backbone',
'neck',
'fcos_head',
'fcos_post_process',
]
def __init__(self,
backbone,
neck,
fcos_head='FCOSHead',
fcos_post_process='FCOSPostProcess'):
super(FCOS, self).__init__()
self.backbone = backbone
self.neck = neck
self.fcos_head = fcos_head
self.fcos_post_process = fcos_post_process
def model_arch(self, ):
body_feats = self.backbone(self.inputs)
fpn_feats, spatial_scale = self.neck(body_feats)
self.fcos_head_outs = self.fcos_head(fpn_feats, self.training)
if not self.training:
self.bboxes = self.fcos_post_process(self.fcos_head_outs,
self.inputs['scale_factor'])
def get_loss(self, ):
loss = {}
tag_labels, tag_bboxes, tag_centerness = [], [], []
for i in range(len(self.fcos_head.fpn_stride)):
# reg_target, labels, scores, centerness
k_lbl = 'labels{}'.format(i)
if k_lbl in self.inputs:
tag_labels.append(self.inputs[k_lbl])
k_box = 'reg_target{}'.format(i)
if k_box in self.inputs:
tag_bboxes.append(self.inputs[k_box])
k_ctn = 'centerness{}'.format(i)
if k_ctn in self.inputs:
tag_centerness.append(self.inputs[k_ctn])
loss_fcos = self.fcos_head.get_loss(self.fcos_head_outs, tag_labels,
tag_bboxes, tag_centerness)
loss.update(loss_fcos)
total_loss = paddle.add_n(list(loss.values()))
loss.update({'loss': total_loss})
return loss
def get_pred(self):
bbox, bbox_num = self.bboxes
output = {
'bbox': bbox,
'bbox_num': bbox_num,
}
return output
......@@ -18,6 +18,7 @@ from . import mask_head
from . import yolo_head
from . import roi_extractor
from . import ssd_head
from . import fcos_head
from . import solov2_head
from .rpn_head import *
......@@ -26,4 +27,5 @@ from .mask_head import *
from .yolo_head import *
from .roi_extractor import *
from .ssd_head import *
from .fcos_head import *
from .solov2_head import *
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import math
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from paddle import ParamAttr
from paddle.nn.initializer import Normal, Constant
from ppdet.core.workspace import register
from ppdet.modeling.layers import ConvNormLayer
class ScaleReg(nn.Layer):
def __init__(self):
super(ScaleReg, self).__init__()
self.scale_reg = self.create_parameter(
shape=[1],
attr=ParamAttr(initializer=Constant(value=1.)),
dtype="float32")
def forward(self, inputs):
out = inputs * self.scale_reg
return out
@register
class FCOSFeat(nn.Layer):
"""
FCOSFeat of FCOS
Args:
feat_in (int): The channel number of input Tensor.
feat_out (int): The channel number of output Tensor.
num_convs (int): The convolution number of the FCOSFeat.
norm_type (str): Normalization type, 'bn'/'sync_bn'/'gn'.
use_dcn (bool): Whether to use dcn in tower or not.
"""
def __init__(self,
feat_in=256,
feat_out=256,
num_convs=4,
norm_type='bn',
use_dcn=False):
super(FCOSFeat, self).__init__()
self.num_convs = num_convs
self.norm_type = norm_type
self.cls_subnet_convs = []
self.reg_subnet_convs = []
for i in range(self.num_convs):
in_c = feat_in if i == 0 else feat_out
cls_conv_name = 'fcos_head_cls_tower_conv_{}'.format(i)
cls_conv = self.add_sublayer(
cls_conv_name,
ConvNormLayer(
ch_in=in_c,
ch_out=feat_out,
filter_size=3,
stride=1,
norm_type=norm_type,
use_dcn=use_dcn,
norm_name=cls_conv_name + '_norm',
bias_on=True,
lr_scale=2.,
name=cls_conv_name))
self.cls_subnet_convs.append(cls_conv)
reg_conv_name = 'fcos_head_reg_tower_conv_{}'.format(i)
reg_conv = self.add_sublayer(
reg_conv_name,
ConvNormLayer(
ch_in=in_c,
ch_out=feat_out,
filter_size=3,
stride=1,
norm_type=norm_type,
use_dcn=use_dcn,
norm_name=reg_conv_name + '_norm',
bias_on=True,
lr_scale=2.,
name=reg_conv_name))
self.reg_subnet_convs.append(reg_conv)
def forward(self, fpn_feat):
cls_feat = fpn_feat
reg_feat = fpn_feat
for i in range(self.num_convs):
cls_feat = F.relu(self.cls_subnet_convs[i](cls_feat))
reg_feat = F.relu(self.reg_subnet_convs[i](reg_feat))
return cls_feat, reg_feat
@register
class FCOSHead(nn.Layer):
"""
FCOSHead
Args:
num_classes(int): Number of classes
fpn_stride(list): The stride of each FPN Layer
prior_prob(float): Used to set the bias init for the class prediction layer
fcos_loss(object): Instance of 'FCOSLoss'
norm_reg_targets(bool): Normalization the regression target if true
centerness_on_reg(bool): The prediction of centerness on regression or clssification branch
"""
__inject__ = ['fcos_feat', 'fcos_loss']
__shared__ = ['num_classes']
def __init__(self,
fcos_feat,
num_classes=80,
fpn_stride=[8, 16, 32, 64, 128],
prior_prob=0.01,
fcos_loss='FCOSLoss',
norm_reg_targets=True,
centerness_on_reg=True):
super(FCOSHead, self).__init__()
self.fcos_feat = fcos_feat
self.num_classes = num_classes
self.fpn_stride = fpn_stride
self.prior_prob = prior_prob
self.fcos_loss = fcos_loss
self.norm_reg_targets = norm_reg_targets
self.centerness_on_reg = centerness_on_reg
conv_cls_name = "fcos_head_cls"
bias_init_value = -math.log((1 - self.prior_prob) / self.prior_prob)
self.fcos_head_cls = self.add_sublayer(
conv_cls_name,
nn.Conv2D(
in_channels=256,
out_channels=self.num_classes,
kernel_size=3,
stride=1,
padding=1,
weight_attr=ParamAttr(
name=conv_cls_name + "_weights",
initializer=Normal(
mean=0., std=0.01)),
bias_attr=ParamAttr(
name=conv_cls_name + "_bias",
initializer=Constant(value=bias_init_value))))
conv_reg_name = "fcos_head_reg"
self.fcos_head_reg = self.add_sublayer(
conv_reg_name,
nn.Conv2D(
in_channels=256,
out_channels=4,
kernel_size=3,
stride=1,
padding=1,
weight_attr=ParamAttr(
name=conv_reg_name + "_weights",
initializer=Normal(
mean=0., std=0.01)),
bias_attr=ParamAttr(
name=conv_reg_name + "_bias",
initializer=Constant(value=0))))
conv_centerness_name = "fcos_head_centerness"
self.fcos_head_centerness = self.add_sublayer(
conv_centerness_name,
nn.Conv2D(
in_channels=256,
out_channels=1,
kernel_size=3,
stride=1,
padding=1,
weight_attr=ParamAttr(
name=conv_centerness_name + "_weights",
initializer=Normal(
mean=0., std=0.01)),
bias_attr=ParamAttr(
name=conv_centerness_name + "_bias",
initializer=Constant(value=0))))
self.scales_regs = []
for i in range(len(self.fpn_stride)):
lvl = int(math.log(int(self.fpn_stride[i]), 2))
feat_name = 'p{}_feat'.format(lvl)
scale_reg = self.add_sublayer(feat_name, ScaleReg())
self.scales_regs.append(scale_reg)
def _compute_locatioins_by_level(self, fpn_stride, feature):
shape_fm = feature.shape
h, w = shape_fm[2], shape_fm[3]
shift_x = paddle.arange(0, w * fpn_stride, fpn_stride)
shift_y = paddle.arange(0, h * fpn_stride, fpn_stride)
shift_x = paddle.unsqueeze(shift_x, axis=0)
shift_y = paddle.unsqueeze(shift_y, axis=1)
shift_x = paddle.expand_as(shift_x, feature[0, 0, :, :])
shift_y = paddle.expand_as(shift_y, feature[0, 0, :, :])
shift_x.stop_gradient = True
shift_y.stop_gradient = True
shift_x = paddle.reshape(shift_x, shape=[-1])
shift_y = paddle.reshape(shift_y, shape=[-1])
location = paddle.stack([shift_x, shift_y], axis=-1) + fpn_stride / 2
location.stop_gradient = True
return location
def forward(self, fpn_feats, is_training):
assert len(fpn_feats) == len(
self.fpn_stride
), "The size of fpn_feats is not equal to size of fpn_stride"
cls_logits_list = []
bboxes_reg_list = []
centerness_list = []
for scale_reg, fpn_stride, fpn_feat in zip(self.scales_regs,
self.fpn_stride, fpn_feats):
fcos_cls_feat, fcos_reg_feat = self.fcos_feat(fpn_feat)
cls_logits = self.fcos_head_cls(fcos_cls_feat)
bbox_reg = scale_reg(self.fcos_head_reg(fcos_reg_feat))
if self.centerness_on_reg:
centerness = self.fcos_head_centerness(fcos_reg_feat)
else:
centerness = self.fcos_head_centerness(fcos_cls_feat)
if self.norm_reg_targets:
bbox_reg = F.relu(bbox_reg)
if not is_training:
bbox_reg = bbox_reg * fpn_stride
else:
bbox_reg = paddle.exp(bbox_reg)
cls_logits_list.append(cls_logits)
bboxes_reg_list.append(bbox_reg)
centerness_list.append(centerness)
if not is_training:
locations_list = []
for fpn_stride, feature in zip(self.fpn_stride, fpn_feats):
location = self._compute_locatioins_by_level(fpn_stride,
feature)
locations_list.append(location)
return locations_list, cls_logits_list, bboxes_reg_list, centerness_list
else:
return cls_logits_list, bboxes_reg_list, centerness_list
def get_loss(self, fcos_head_outs, tag_labels, tag_bboxes, tag_centerness):
cls_logits, bboxes_reg, centerness = fcos_head_outs
return self.fcos_loss(cls_logits, bboxes_reg, centerness, tag_labels,
tag_bboxes, tag_centerness)
......@@ -48,11 +48,21 @@ class ConvNormLayer(nn.Layer):
norm_groups=32,
use_dcn=False,
norm_name=None,
bias_on=False,
lr_scale=1.,
name=None):
super(ConvNormLayer, self).__init__()
assert norm_type in ['bn', 'sync_bn', 'gn']
self.conv = Conv2D(
if bias_on:
bias_attr = ParamAttr(
name=name + "_bias",
initializer=Constant(value=0.),
learning_rate=lr_scale)
else:
bias_attr = False
self.conv = nn.Conv2D(
in_channels=ch_in,
out_channels=ch_out,
kernel_size=filter_size,
......@@ -64,7 +74,7 @@ class ConvNormLayer(nn.Layer):
initializer=Normal(
mean=0., std=0.01),
learning_rate=1.),
bias_attr=False)
bias_attr=bias_attr)
param_attr = ParamAttr(
name=norm_name + "_scale",
......@@ -75,10 +85,10 @@ class ConvNormLayer(nn.Layer):
learning_rate=1.,
regularizer=L2Decay(0.))
if norm_type in ['bn', 'sync_bn']:
self.norm = BatchNorm2D(
self.norm = nn.BatchNorm2D(
ch_out, weight_attr=param_attr, bias_attr=bias_attr)
elif norm_type == 'gn':
self.norm = GroupNorm(
self.norm = nn.GroupNorm(
num_groups=norm_groups,
num_channels=ch_out,
weight_attr=param_attr,
......@@ -713,6 +723,92 @@ class AnchorGrid(object):
@register
@serializable
class FCOSBox(object):
__shared__ = ['num_classes', 'batch_size']
def __init__(self, num_classes=80, batch_size=1):
super(FCOSBox, self).__init__()
self.num_classes = num_classes
self.batch_size = batch_size
def _merge_hw(self, inputs, ch_type="channel_first"):
"""
Args:
inputs (Variables): Feature map whose H and W will be merged into one dimension
ch_type (str): channel_first / channel_last
Return:
new_shape (Variables): The new shape after h and w merged into one dimension
"""
shape_ = paddle.shape(inputs)
bs, ch, hi, wi = shape_[0], shape_[1], shape_[2], shape_[3]
img_size = hi * wi
img_size.stop_gradient = True
if ch_type == "channel_first":
new_shape = paddle.concat([bs, ch, img_size])
elif ch_type == "channel_last":
new_shape = paddle.concat([bs, img_size, ch])
else:
raise KeyError("Wrong ch_type %s" % ch_type)
new_shape.stop_gradient = True
return new_shape
def _postprocessing_by_level(self, locations, box_cls, box_reg, box_ctn,
scale_factor):
"""
Args:
locations (Variables): anchor points for current layer, [H*W, 2]
box_cls (Variables): categories prediction, [N, C, H, W], C is the number of classes
box_reg (Variables): bounding box prediction, [N, 4, H, W]
box_ctn (Variables): centerness prediction, [N, 1, H, W]
scale_factor (Variables): [h_scale, w_scale] for input images
Return:
box_cls_ch_last (Variables): score for each category, in [N, C, M]
C is the number of classes and M is the number of anchor points
box_reg_decoding (Variables): decoded bounding box, in [N, M, 4]
last dimension is [x1, y1, x2, y2]
"""
act_shape_cls = self._merge_hw(box_cls)
box_cls_ch_last = paddle.reshape(x=box_cls, shape=act_shape_cls)
box_cls_ch_last = F.sigmoid(box_cls_ch_last)
act_shape_reg = self._merge_hw(box_reg)
box_reg_ch_last = paddle.reshape(x=box_reg, shape=act_shape_reg)
box_reg_ch_last = paddle.transpose(box_reg_ch_last, perm=[0, 2, 1])
box_reg_decoding = paddle.stack(
[
locations[:, 0] - box_reg_ch_last[:, :, 0],
locations[:, 1] - box_reg_ch_last[:, :, 1],
locations[:, 0] + box_reg_ch_last[:, :, 2],
locations[:, 1] + box_reg_ch_last[:, :, 3]
],
axis=1)
box_reg_decoding = paddle.transpose(box_reg_decoding, perm=[0, 2, 1])
act_shape_ctn = self._merge_hw(box_ctn)
box_ctn_ch_last = paddle.reshape(x=box_ctn, shape=act_shape_ctn)
box_ctn_ch_last = F.sigmoid(box_ctn_ch_last)
# recover the location to original image
im_scale = paddle.concat([scale_factor, scale_factor], axis=1)
box_reg_decoding = box_reg_decoding / im_scale
box_cls_ch_last = box_cls_ch_last * box_ctn_ch_last
return box_cls_ch_last, box_reg_decoding
def __call__(self, locations, cls_logits, bboxes_reg, centerness,
scale_factor):
pred_boxes_ = []
pred_scores_ = []
for pts, cls, box, ctn in zip(locations, cls_logits, bboxes_reg,
centerness):
pred_scores_lvl, pred_boxes_lvl = self._postprocessing_by_level(
pts, cls, box, ctn, scale_factor)
pred_boxes_.append(pred_boxes_lvl)
pred_scores_.append(pred_scores_lvl)
pred_boxes = paddle.concat(pred_boxes_, axis=1)
pred_scores = paddle.concat(pred_scores_, axis=2)
return pred_boxes, pred_scores
class MaskMatrixNMS(object):
"""
Matrix NMS for multi-class masks.
......
......@@ -16,10 +16,12 @@ from . import yolo_loss
from . import iou_aware_loss
from . import iou_loss
from . import ssd_loss
from . import fcos_loss
from . import solov2_loss
from .yolo_loss import *
from .iou_aware_loss import *
from .iou_loss import *
from .ssd_loss import *
from .fcos_loss import *
from .solov2_loss import *
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from ppdet.core.workspace import register
INF = 1e8
__all__ = ['FCOSLoss']
def flatten_tensor(inputs, channel_first=False):
"""
Flatten a Tensor
Args:
inputs (Tensor): 4-D Tensor with shape [N, C, H, W] or [N, H, W, C]
channel_first(bool): if true the dimension order of
Tensor is [N, C, H, W], otherwise is [N, H, W, C]
Return:
input_channel_last (Tensor): The flattened Tensor in channel_last style
"""
if channel_first:
input_channel_last = paddle.transpose(inputs, perm=[0, 2, 3, 1])
else:
input_channel_last = inputs
output_channel_last = paddle.flatten(
input_channel_last, start_axis=0, stop_axis=2) # [N*H*W, C]
return output_channel_last
def sigmoid_cross_entropy_with_logits_loss(inputs,
label,
ignore_index=-100,
normalize=False):
output = F.binary_cross_entropy_with_logits(inputs, label, reduction='none')
mask_tensor = paddle.cast(label != ignore_index, 'float32')
output = paddle.multiply(output, mask_tensor)
if normalize:
sum_valid_mask = paddle.sum(mask_tensor)
output = output / sum_valid_mask
return output
@register
class FCOSLoss(nn.Layer):
"""
FCOSLoss
Args:
loss_alpha (float): alpha in focal loss
loss_gamma (float): gamma in focal loss
iou_loss_type(str): location loss type, IoU/GIoU/LINEAR_IoU
reg_weights(float): weight for location loss
"""
def __init__(self,
loss_alpha=0.25,
loss_gamma=2.0,
iou_loss_type="giou",
reg_weights=1.0):
super(FCOSLoss, self).__init__()
self.loss_alpha = loss_alpha
self.loss_gamma = loss_gamma
self.iou_loss_type = iou_loss_type
self.reg_weights = reg_weights
def __iou_loss(self, pred, targets, positive_mask, weights=None):
"""
Calculate the loss for location prediction
Args:
pred (Tensor): bounding boxes prediction
targets (Tensor): targets for positive samples
positive_mask (Tensor): mask of positive samples
weights (Tensor): weights for each positive samples
Return:
loss (Tensor): location loss
"""
plw = pred[:, 0] * positive_mask
pth = pred[:, 1] * positive_mask
prw = pred[:, 2] * positive_mask
pbh = pred[:, 3] * positive_mask
tlw = targets[:, 0] * positive_mask
tth = targets[:, 1] * positive_mask
trw = targets[:, 2] * positive_mask
tbh = targets[:, 3] * positive_mask
tlw.stop_gradient = True
trw.stop_gradient = True
tth.stop_gradient = True
tbh.stop_gradient = True
ilw = paddle.minimum(plw, tlw)
irw = paddle.minimum(prw, trw)
ith = paddle.minimum(pth, tth)
ibh = paddle.minimum(pbh, tbh)
clw = paddle.maximum(plw, tlw)
crw = paddle.maximum(prw, trw)
cth = paddle.maximum(pth, tth)
cbh = paddle.maximum(pbh, tbh)
area_predict = (plw + prw) * (pth + pbh)
area_target = (tlw + trw) * (tth + tbh)
area_inter = (ilw + irw) * (ith + ibh)
ious = (area_inter + 1.0) / (
area_predict + area_target - area_inter + 1.0)
ious = ious * positive_mask
if self.iou_loss_type.lower() == "linear_iou":
loss = 1.0 - ious
elif self.iou_loss_type.lower() == "giou":
area_uniou = area_predict + area_target - area_inter
area_circum = (clw + crw) * (cth + cbh) + 1e-7
giou = ious - (area_circum - area_uniou) / area_circum
loss = 1.0 - giou
elif self.iou_loss_type.lower() == "iou":
loss = 0.0 - paddle.log(ious)
else:
raise KeyError
if weights is not None:
loss = loss * weights
return loss
def forward(self, cls_logits, bboxes_reg, centerness, tag_labels,
tag_bboxes, tag_center):
"""
Calculate the loss for classification, location and centerness
Args:
cls_logits (list): list of Tensor, which is predicted
score for all anchor points with shape [N, M, C]
bboxes_reg (list): list of Tensor, which is predicted
offsets for all anchor points with shape [N, M, 4]
centerness (list): list of Tensor, which is predicted
centerness for all anchor points with shape [N, M, 1]
tag_labels (list): list of Tensor, which is category
targets for each anchor point
tag_bboxes (list): list of Tensor, which is bounding
boxes targets for positive samples
tag_center (list): list of Tensor, which is centerness
targets for positive samples
Return:
loss (dict): loss composed by classification loss, bounding box
"""
cls_logits_flatten_list = []
bboxes_reg_flatten_list = []
centerness_flatten_list = []
tag_labels_flatten_list = []
tag_bboxes_flatten_list = []
tag_center_flatten_list = []
num_lvl = len(cls_logits)
for lvl in range(num_lvl):
cls_logits_flatten_list.append(
flatten_tensor(cls_logits[lvl], True))
bboxes_reg_flatten_list.append(
flatten_tensor(bboxes_reg[lvl], True))
centerness_flatten_list.append(
flatten_tensor(centerness[lvl], True))
tag_labels_flatten_list.append(
flatten_tensor(tag_labels[lvl], False))
tag_bboxes_flatten_list.append(
flatten_tensor(tag_bboxes[lvl], False))
tag_center_flatten_list.append(
flatten_tensor(tag_center[lvl], False))
cls_logits_flatten = paddle.concat(cls_logits_flatten_list, axis=0)
bboxes_reg_flatten = paddle.concat(bboxes_reg_flatten_list, axis=0)
centerness_flatten = paddle.concat(centerness_flatten_list, axis=0)
tag_labels_flatten = paddle.concat(tag_labels_flatten_list, axis=0)
tag_bboxes_flatten = paddle.concat(tag_bboxes_flatten_list, axis=0)
tag_center_flatten = paddle.concat(tag_center_flatten_list, axis=0)
tag_labels_flatten.stop_gradient = True
tag_bboxes_flatten.stop_gradient = True
tag_center_flatten.stop_gradient = True
mask_positive_bool = tag_labels_flatten > 0
mask_positive_bool.stop_gradient = True
mask_positive_float = paddle.cast(mask_positive_bool, dtype="float32")
mask_positive_float.stop_gradient = True
num_positive_fp32 = paddle.sum(mask_positive_float)
num_positive_fp32.stop_gradient = True
num_positive_int32 = paddle.cast(num_positive_fp32, dtype="int32")
num_positive_int32 = num_positive_int32 * 0 + 1
num_positive_int32.stop_gradient = True
normalize_sum = paddle.sum(tag_center_flatten * mask_positive_float)
normalize_sum.stop_gradient = True
# 1. cls_logits: sigmoid_focal_loss
# expand onehot labels
num_classes = cls_logits_flatten.shape[-1]
tag_labels_flatten = paddle.squeeze(tag_labels_flatten, axis=-1)
tag_labels_flatten_bin = F.one_hot(
tag_labels_flatten, num_classes=1 + num_classes)
tag_labels_flatten_bin = tag_labels_flatten_bin[:, 1:]
# sigmoid_focal_loss
cls_loss = F.sigmoid_focal_loss(
cls_logits_flatten, tag_labels_flatten_bin) / num_positive_fp32
# 2. bboxes_reg: giou_loss
mask_positive_float = paddle.squeeze(mask_positive_float, axis=-1)
tag_center_flatten = paddle.squeeze(tag_center_flatten, axis=-1)
reg_loss = self.__iou_loss(
bboxes_reg_flatten,
tag_bboxes_flatten,
mask_positive_float,
weights=tag_center_flatten)
reg_loss = reg_loss * mask_positive_float / normalize_sum
# 3. centerness: sigmoid_cross_entropy_with_logits_loss
centerness_flatten = paddle.squeeze(centerness_flatten, axis=-1)
ctn_loss = sigmoid_cross_entropy_with_logits_loss(centerness_flatten,
tag_center_flatten)
ctn_loss = ctn_loss * mask_positive_float / num_positive_fp32
loss_all = {
"loss_centerness": paddle.sum(ctn_loss),
"loss_cls": paddle.sum(cls_loss),
"loss_box": paddle.sum(reg_loss)
}
return loss_all
......@@ -31,17 +31,28 @@ class FPN(Layer):
out_channel,
min_level=0,
max_level=4,
spatial_scale=[0.25, 0.125, 0.0625, 0.03125]):
spatial_scale=[0.25, 0.125, 0.0625, 0.03125],
has_extra_convs=False,
use_c5=True,
relu_before_extra_convs=True):
super(FPN, self).__init__()
self.min_level = min_level
self.max_level = max_level
self.spatial_scale = spatial_scale
self.has_extra_convs = has_extra_convs
self.use_c5 = use_c5
self.relu_before_extra_convs = relu_before_extra_convs
self.lateral_convs = []
self.fpn_convs = []
fan = out_channel * 3 * 3
for i in range(min_level, max_level):
self.num_backbone_stages = len(spatial_scale)
self.num_outs = self.max_level - self.min_level + 1
self.highest_backbone_level = self.min_level + self.num_backbone_stages - 1
for i in range(self.min_level, self.highest_backbone_level + 1):
if i == 3:
lateral_name = 'fpn_inner_res5_sum'
else:
......@@ -73,25 +84,69 @@ class FPN(Layer):
learning_rate=2., regularizer=L2Decay(0.))))
self.fpn_convs.append(fpn_conv)
# add extra conv levels for RetinaNet(use_c5)/FCOS(use_p5)
if self.has_extra_convs and self.num_outs > self.num_backbone_stages:
for lvl in range(self.highest_backbone_level + 1,
self.max_level + 1): # P6 P7 ...
if lvl == self.highest_backbone_level + 1 and self.use_c5:
in_c = in_channels[self.highest_backbone_level]
else:
in_c = out_channel
extra_fpn_name = 'fpn_{}'.format(lvl + 2)
extra_fpn_conv = self.add_sublayer(
extra_fpn_name,
Conv2D(
in_channels=in_c,
out_channels=out_channel,
kernel_size=3,
stride=2,
padding=1,
weight_attr=ParamAttr(
initializer=XavierUniform(fan_out=fan)),
bias_attr=ParamAttr(
learning_rate=2., regularizer=L2Decay(0.))))
self.fpn_convs.append(extra_fpn_conv)
def forward(self, body_feats):
laterals = []
for lvl in range(self.min_level, self.max_level):
laterals.append(self.lateral_convs[lvl](body_feats[lvl]))
used_backbone_levels = len(self.spatial_scale)
for i in range(used_backbone_levels):
laterals.append(self.lateral_convs[i](body_feats[i]))
for i in range(self.min_level + 1, self.max_level):
lvl = self.max_level + self.min_level - i
used_backbone_levels = len(self.spatial_scale)
for i in range(used_backbone_levels - 1):
idx = used_backbone_levels - 1 - i
upsample = F.interpolate(
laterals[lvl],
laterals[idx],
scale_factor=2.,
mode='nearest', )
laterals[lvl - 1] = laterals[lvl - 1] + upsample
laterals[idx - 1] += upsample
fpn_output = []
for lvl in range(self.min_level, self.max_level):
fpn_output.append(self.fpn_convs[lvl](laterals[lvl]))
for lvl in range(self.min_level, self.highest_backbone_level + 1):
i = lvl - self.min_level
fpn_output.append(self.fpn_convs[i](laterals[i]))
extension = F.max_pool2d(fpn_output[-1], 1, stride=2)
spatial_scale = self.spatial_scale + [self.spatial_scale[-1] * 0.5]
fpn_output.append(extension)
return fpn_output, spatial_scale
spatial_scales = self.spatial_scale
if self.num_outs > len(fpn_output):
# use max pool to get more levels on top of outputs (Faster R-CNN, Mask R-CNN)
if not self.has_extra_convs:
fpn_output.append(F.max_pool2d(fpn_output[-1], 1, stride=2))
spatial_scales = spatial_scales + [spatial_scales[-1] * 0.5]
# add extra conv levels for RetinaNet(use_c5)/FCOS(use_p5)
else:
if self.use_c5:
extra_source = body_feats[-1]
else:
extra_source = fpn_output[-1]
fpn_output.append(self.fpn_convs[used_backbone_levels](
extra_source))
spatial_scales = spatial_scales + [spatial_scales[-1] * 0.5]
for i in range(used_backbone_levels + 1, self.num_outs):
if self.relu_before_extra_convs:
fpn_output.append(self.fpn_convs[i](F.relu(fpn_output[
-1])))
else:
fpn_output.append(self.fpn_convs[i](fpn_output[-1]))
spatial_scales = spatial_scales + [spatial_scales[-1] * 0.5]
return fpn_output, spatial_scales
......@@ -40,3 +40,20 @@ class MaskPostProcess(object):
self.mask_resolution, self.binary_thresh)
mask = {'mask': mask}
return mask
@register
class FCOSPostProcess(object):
__inject__ = ['decode', 'nms']
def __init__(self, decode=None, nms=None):
super(FCOSPostProcess, self).__init__()
self.decode = decode
self.nms = nms
def __call__(self, fcos_head_outs, scale_factor):
locations, cls_logits, bboxes_reg, centerness = fcos_head_outs
bboxes, score = self.decode(locations, cls_logits, bboxes_reg,
centerness, scale_factor)
bbox_pred, bbox_num, _ = self.nms(bboxes, score)
return bbox_pred, bbox_num
......@@ -83,7 +83,7 @@ class TrainingStats(object):
for k, v in extras.items():
stats[k] = v
for k, v in self.meters.items():
stats[k] = format(v.median, '.4f')
stats[k] = format(v.median, '.6f')
return stats
......
......@@ -197,10 +197,9 @@ class FCOSLoss(object):
reg_loss = fluid.layers.elementwise_mul(
reg_loss, mask_positive_float, axis=0) / normalize_sum
ctn_loss = fluid.layers.sigmoid_cross_entropy_with_logits(
x=centerness_flatten,
label=tag_center_flatten) * mask_positive_float / num_positive_fp32
x=centerness_flatten, label=tag_center_flatten)
ctn_loss = fluid.layers.elementwise_mul(
ctn_loss, mask_positive_float, axis=0) / normalize_sum
ctn_loss, mask_positive_float, axis=0) / num_positive_fp32
loss_all = {
"loss_centerness": fluid.layers.reduce_sum(ctn_loss),
"loss_cls": fluid.layers.reduce_sum(cls_loss),
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册