Unverified commit e83c3ecf, authored by shangliang Xu, committed by GitHub

[dev] add TOOD (#4259)

* [dev] add TOOD

* fix link and notes

* replace ConvNormLayer
Parent: 340da51b
# TOOD
## Introduction
[TOOD: Task-aligned One-stage Object Detection](https://arxiv.org/abs/2108.07755)
TOOD (Task-aligned One-stage Object Detection) is a one-stage object detection model. This directory reproduces the model described in the paper.
## Model Zoo
| Backbone | Model | Images/GPU | Inf time (fps) | Box AP | Config | Download |
|:------:|:--------:|:--------:|:--------------:|:------:|:------:|:--------:|
| R-50 | TOOD | 4 | --- | 42.8 | [config](https://github.com/PaddlePaddle/PaddleDetection/blob/develop/configs/tood/tood_r50_fpn_1x_coco.yml) | [model](https://paddledet.bj.bcebos.com/models/tood_r50_fpn_1x_coco.pdparams) |
**Notes:**
- TOOD is trained on the COCO train2017 dataset and evaluated on val2017; results are reported as `mAP(IoU=0.5:0.95)`.
- TOOD is trained for 12 epochs on 8 GPUs.
Multi-GPU training:
```bash
export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
python -m paddle.distributed.launch --gpus 0,1,2,3,4,5,6,7 tools/train.py -c configs/tood/tood_r50_fpn_1x_coco.yml --fleet
```
## Citations
```
@inproceedings{feng2021tood,
title={TOOD: Task-aligned One-stage Object Detection},
author={Feng, Chengjian and Zhong, Yujie and Gao, Yu and Scott, Matthew R and Huang, Weilin},
booktitle={ICCV},
year={2021}
}
```
epoch: 12
LearningRate:
  base_lr: 0.01
  schedulers:
  - !PiecewiseDecay
    gamma: 0.1
    milestones: [8, 11]
  - !LinearWarmup
    start_factor: 0.001
    steps: 500
OptimizerBuilder:
  optimizer:
    momentum: 0.9
    type: Momentum
  regularizer:
    factor: 0.0001
    type: L2
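The optimizer config above combines a piecewise decay with a linear warmup. The following is a minimal pure-Python sketch of how the learning rate might evolve under these settings. The `steps_per_epoch` argument is hypothetical, and the way warmup and decay are composed here is an assumption for illustration, not a copy of the PaddleDetection scheduler code.

```python
def learning_rate(step, steps_per_epoch, base_lr=0.01, gamma=0.1,
                  milestones=(8, 11), warmup_steps=500, start_factor=0.001):
    """Return the learning rate at a 0-based global training step."""
    if step < warmup_steps:
        # Linear warmup: the factor ramps from start_factor up to 1.0.
        alpha = step / warmup_steps
        factor = start_factor * (1 - alpha) + alpha
        return base_lr * factor
    # Piecewise decay: multiply by gamma once for every milestone epoch reached.
    epoch = step // steps_per_epoch
    passed = sum(1 for m in milestones if epoch >= m)
    return base_lr * gamma ** passed
```

With a hypothetical 1000 steps per epoch, the rate starts near `1e-5`, reaches `0.01` after 500 steps, and drops by a factor of 10 at epochs 8 and 11.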
architecture: TOOD
pretrain_weights: https://paddledet.bj.bcebos.com/models/pretrained/ResNet50_cos_pretrained.pdparams
TOOD:
  backbone: ResNet
  neck: FPN
  head: TOODHead
ResNet:
  depth: 50
  variant: b
  norm_type: bn
  freeze_at: 0
  return_idx: [1, 2, 3]
  num_stages: 4
FPN:
  out_channel: 256
  spatial_scales: [0.125, 0.0625, 0.03125]
  extra_stage: 2
  has_extra_convs: true
  use_c5: false
TOODHead:
  stacked_convs: 6
  grid_cell_scale: 8
  static_assigner_epoch: 4
  loss_weight: { class: 1.0, iou: 2.0 }
  static_assigner:
    name: ATSSAssigner
    topk: 9
  assigner:
    name: TaskAlignedAssigner
    topk: 13
    alpha: 1.0
    beta: 6.0
  nms:
    name: MultiClassNMS
    nms_top_k: 1000
    keep_top_k: 100
    score_threshold: 0.05
    nms_threshold: 0.6
worker_num: 4
TrainReader:
  sample_transforms:
  - Decode: {}
  - RandomFlip: {prob: 0.5}
  - Resize: {target_size: [800, 1333], keep_ratio: true, interp: 1}
  - NormalizeImage: {is_scale: true, mean: [0.485, 0.456, 0.406], std: [0.229, 0.224, 0.225]}
  - Permute: {}
  batch_transforms:
  - PadBatch: {pad_to_stride: 32}
  batch_size: 4
  shuffle: true
  drop_last: true
  collate_batch: false
  use_shared_memory: true
EvalReader:
  sample_transforms:
  - Decode: {}
  - Resize: {interp: 1, target_size: [800, 1333], keep_ratio: True}
  - NormalizeImage: {is_scale: true, mean: [0.485, 0.456, 0.406], std: [0.229, 0.224, 0.225]}
  - Permute: {}
  batch_transforms:
  - PadBatch: {pad_to_stride: 32}
  batch_size: 1
  shuffle: false
TestReader:
  sample_transforms:
  - Decode: {}
  - Resize: {interp: 1, target_size: [800, 1333], keep_ratio: True}
  - NormalizeImage: {is_scale: true, mean: [0.485, 0.456, 0.406], std: [0.229, 0.224, 0.225]}
  - Permute: {}
  batch_transforms:
  - PadBatch: {pad_to_stride: 32}
  batch_size: 1
  shuffle: false
_BASE_: [
'../datasets/coco_detection.yml',
'../runtime.yml',
'_base_/tood_r50_fpn.yml',
'_base_/optimizer_1x.yml',
'_base_/tood_reader.yml',
]
weights: output/tood_r50_fpn_1x_coco/model_final
find_unused_parameters: True
log_iter: 100
......@@ -352,6 +352,7 @@ class Trainer(object):
self.status['data_time'].update(time.time() - iter_tic)
self.status['step_id'] = step_id
self._compose_callback.on_step_begin(self.status)
data['epoch_id'] = epoch_id
if self.cfg.get('fp16', False):
with amp.auto_cast(enable=self.cfg.use_gpu):
......
......@@ -28,6 +28,7 @@ from . import layers
from . import reid
from . import mot
from . import transformers
from . import assigners
from .ops import *
from .backbones import *
......@@ -41,3 +42,4 @@ from .layers import *
from .reid import *
from .mot import *
from .transformers import *
from .assigners import *
......@@ -25,6 +25,7 @@ from . import gfl
from . import picodet
from . import detr
from . import sparse_rcnn
from . import tood
from .meta_arch import *
from .faster_rcnn import *
......@@ -47,3 +48,4 @@ from .gfl import *
from .picodet import *
from .detr import *
from .sparse_rcnn import *
from .tood import *
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from ppdet.core.workspace import register, create
from .meta_arch import BaseArch
__all__ = ['TOOD']
@register
class TOOD(BaseArch):
    """
    TOOD: Task-aligned One-stage Object Detection, see https://arxiv.org/abs/2108.07755
    Args:
        backbone (nn.Layer): backbone instance
        neck (nn.Layer): 'FPN' instance
        head (nn.Layer): 'TOODHead' instance
    """
    __category__ = 'architecture'

    def __init__(self, backbone, neck, head):
        super(TOOD, self).__init__()
        self.backbone = backbone
        self.neck = neck
        self.head = head

    @classmethod
    def from_config(cls, cfg, *args, **kwargs):
        backbone = create(cfg['backbone'])
        kwargs = {'input_shape': backbone.out_shape}
        neck = create(cfg['neck'], **kwargs)
        kwargs = {'input_shape': neck.out_shape}
        head = create(cfg['head'], **kwargs)
        return {
            'backbone': backbone,
            'neck': neck,
            "head": head,
        }

    def _forward(self):
        body_feats = self.backbone(self.inputs)
        fpn_feats = self.neck(body_feats)
        head_outs = self.head(fpn_feats)
        if not self.training:
            bboxes, bbox_num = self.head.post_process(
                head_outs, self.inputs['im_shape'], self.inputs['scale_factor'])
            return bboxes, bbox_num
        else:
            loss = self.head.get_loss(head_outs, self.inputs)
            return loss

    def get_loss(self):
        return self._forward()

    def get_pred(self):
        bbox_pred, bbox_num = self._forward()
        output = {'bbox': bbox_pred, 'bbox_num': bbox_num}
        return output
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from . import utils
from . import task_aligned_assigner
from . import atss_assigner
from .utils import *
from .task_aligned_assigner import *
from .atss_assigner import *
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from ppdet.core.workspace import register
from ..ops import iou_similarity
from ..bbox_utils import bbox_center
from .utils import (pad_gt, check_points_inside_bboxes, compute_max_iou_anchor,
compute_max_iou_gt)
@register
class ATSSAssigner(nn.Layer):
"""Bridging the Gap Between Anchor-based and Anchor-free Detection
via Adaptive Training Sample Selection
"""
__shared__ = ['num_classes']
def __init__(self,
topk=9,
num_classes=80,
force_gt_matching=False,
eps=1e-9):
super(ATSSAssigner, self).__init__()
self.topk = topk
self.num_classes = num_classes
self.force_gt_matching = force_gt_matching
self.eps = eps
def _gather_topk_pyramid(self, gt2anchor_distances, num_anchors_list,
pad_gt_mask):
pad_gt_mask = pad_gt_mask.tile([1, 1, self.topk]).astype(paddle.bool)
gt2anchor_distances_list = paddle.split(
gt2anchor_distances, num_anchors_list, axis=-1)
num_anchors_index = np.cumsum(num_anchors_list).tolist()
num_anchors_index = [0, ] + num_anchors_index[:-1]
is_in_topk_list = []
topk_idxs_list = []
for distances, anchors_index in zip(gt2anchor_distances_list,
num_anchors_index):
num_anchors = distances.shape[-1]
topk_metrics, topk_idxs = paddle.topk(
distances, self.topk, axis=-1, largest=False)
topk_idxs_list.append(topk_idxs + anchors_index)
topk_idxs = paddle.where(pad_gt_mask, topk_idxs,
paddle.zeros_like(topk_idxs))
is_in_topk = F.one_hot(topk_idxs, num_anchors).sum(axis=-2)
is_in_topk = paddle.where(is_in_topk > 1,
paddle.zeros_like(is_in_topk), is_in_topk)
is_in_topk_list.append(is_in_topk.astype(gt2anchor_distances.dtype))
is_in_topk_list = paddle.concat(is_in_topk_list, axis=-1)
topk_idxs_list = paddle.concat(topk_idxs_list, axis=-1)
return is_in_topk_list, topk_idxs_list
@paddle.no_grad()
def forward(self,
anchor_bboxes,
num_anchors_list,
gt_labels,
gt_bboxes,
bg_index,
gt_scores=None):
r"""The assignment is done in following steps
1. compute iou between all bbox (bbox of all pyramid levels) and gt
2. compute center distance between all bbox and gt
3. on each pyramid level, for each gt, select k bbox whose center
are closest to the gt center, so we total select k*l bbox as
candidates for each gt
4. get the corresponding iou for these candidates, and compute the
mean and std, set mean + std as the iou threshold
5. select these candidates whose iou are greater than or equal to
the threshold as positive
6. limit the positive sample's center in gt
7. if an anchor box is assigned to multiple gts, the one with the
highest iou will be selected.
Args:
anchor_bboxes (Tensor, float32): pre-defined anchors, shape(L, 4),
"xmin, ymin, xmax, ymax" format
num_anchors_list (List): num of anchors in each level
gt_labels (Tensor|List[Tensor], int64): Label of gt_bboxes, shape(B, n, 1)
gt_bboxes (Tensor|List[Tensor], float32): Ground truth bboxes, shape(B, n, 4)
bg_index (int): background index
gt_scores (Tensor|List[Tensor]|None, float32) Score of gt_bboxes,
shape(B, n, 1), if None, then it will initialize with one_hot label
Returns:
assigned_labels (Tensor): (B, L)
assigned_bboxes (Tensor): (B, L, 4)
assigned_scores (Tensor): (B, L, C)
"""
gt_labels, gt_bboxes, pad_gt_scores, pad_gt_mask = pad_gt(
gt_labels, gt_bboxes, gt_scores)
assert gt_labels.ndim == gt_bboxes.ndim and \
gt_bboxes.ndim == 3
num_anchors, _ = anchor_bboxes.shape
batch_size, num_max_boxes, _ = gt_bboxes.shape
# 1. compute iou between gt and anchor bbox, [B, n, L]
ious = iou_similarity(gt_bboxes.reshape([-1, 4]), anchor_bboxes)
ious = ious.reshape([batch_size, -1, num_anchors])
# 2. compute center distance between all anchors and gt, [B, n, L]
gt_centers = bbox_center(gt_bboxes.reshape([-1, 4])).unsqueeze(1)
anchor_centers = bbox_center(anchor_bboxes)
gt2anchor_distances = (gt_centers - anchor_centers.unsqueeze(0)) \
.norm(2, axis=-1).reshape([batch_size, -1, num_anchors])
# 3. on each pyramid level, selecting topk closest candidates
# based on the center distance, [B, n, L]
is_in_topk, topk_idxs = self._gather_topk_pyramid(
gt2anchor_distances, num_anchors_list, pad_gt_mask)
# 4. get the corresponding iou for these candidates, and compute the
# mean and std, 5. set mean + std as the iou threshold
iou_candidates = ious * is_in_topk
iou_threshold = paddle.index_sample(
iou_candidates.flatten(stop_axis=-2),
topk_idxs.flatten(stop_axis=-2))
iou_threshold = iou_threshold.reshape([batch_size, num_max_boxes, -1])
iou_threshold = iou_threshold.mean(axis=-1, keepdim=True) + \
iou_threshold.std(axis=-1, keepdim=True)
is_in_topk = paddle.where(
iou_candidates > iou_threshold.tile([1, 1, num_anchors]),
is_in_topk, paddle.zeros_like(is_in_topk))
# 6. check the positive sample's center in gt, [B, n, L]
is_in_gts = check_points_inside_bboxes(anchor_centers, gt_bboxes)
# select positive sample, [B, n, L]
mask_positive = is_in_topk * is_in_gts * pad_gt_mask
# 7. if an anchor box is assigned to multiple gts,
# the one with the highest iou will be selected.
mask_positive_sum = mask_positive.sum(axis=-2)
if mask_positive_sum.max() > 1:
mask_multiple_gts = (mask_positive_sum.unsqueeze(1) > 1).tile(
[1, num_max_boxes, 1])
is_max_iou = compute_max_iou_anchor(ious)
mask_positive = paddle.where(mask_multiple_gts, is_max_iou,
mask_positive)
mask_positive_sum = mask_positive.sum(axis=-2)
# 8. make sure every gt_bbox matches the anchor
if self.force_gt_matching:
is_max_iou = compute_max_iou_gt(ious) * pad_gt_mask
mask_max_iou = (is_max_iou.sum(-2, keepdim=True) == 1).tile(
[1, num_max_boxes, 1])
mask_positive = paddle.where(mask_max_iou, is_max_iou,
mask_positive)
mask_positive_sum = mask_positive.sum(axis=-2)
assigned_gt_index = mask_positive.argmax(axis=-2)
assert mask_positive_sum.max() == 1, \
("one anchor just assign one gt, but received not equals 1. "
"Received: %f" % mask_positive_sum.max().item())
# assigned target
batch_ind = paddle.arange(
end=batch_size, dtype=gt_labels.dtype).unsqueeze(-1)
assigned_gt_index = assigned_gt_index + batch_ind * num_max_boxes
assigned_labels = paddle.gather(
gt_labels.flatten(), assigned_gt_index.flatten(), axis=0)
assigned_labels = assigned_labels.reshape([batch_size, num_anchors])
assigned_labels = paddle.where(
mask_positive_sum > 0, assigned_labels,
paddle.full_like(assigned_labels, bg_index))
assigned_bboxes = paddle.gather(
gt_bboxes.reshape([-1, 4]), assigned_gt_index.flatten(), axis=0)
assigned_bboxes = assigned_bboxes.reshape([batch_size, num_anchors, 4])
assigned_scores = F.one_hot(assigned_labels, self.num_classes)
if gt_scores is not None:
gather_scores = paddle.gather(
pad_gt_scores.flatten(), assigned_gt_index.flatten(), axis=0)
gather_scores = gather_scores.reshape([batch_size, num_anchors])
gather_scores = paddle.where(mask_positive_sum > 0, gather_scores,
paddle.zeros_like(gather_scores))
assigned_scores *= gather_scores.unsqueeze(-1)
return assigned_labels, assigned_bboxes, assigned_scores
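Steps 4 and 5 of the ATSS assignment (an adaptive IoU threshold per ground-truth box) can be illustrated in isolation. The sketch below is plain Python for a single ground truth and is not the batched implementation above. The function name is hypothetical, and it uses the sample standard deviation, which is an assumption about the default behavior of the tensor `std` call.

```python
from statistics import mean, stdev


def atss_positive_candidates(candidate_ious):
    """Return (threshold, indices of candidates whose IoU exceeds it)."""
    # The threshold adapts to each ground-truth box: mean plus standard
    # deviation of the IoUs of its candidate anchors.
    threshold = mean(candidate_ious) + stdev(candidate_ious)
    positives = [i for i, iou in enumerate(candidate_ious) if iou > threshold]
    return threshold, positives


threshold, positives = atss_positive_candidates([0.1, 0.2, 0.3, 0.6, 0.8])
```

In this toy case only the last candidate clears the threshold. A ground truth whose candidates all have similar IoUs gets a threshold close to their mean, so the selection stays meaningful for both easy and hard objects.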
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from ppdet.core.workspace import register
from ..bbox_utils import iou_similarity
from .utils import (pad_gt, gather_topk_anchors, check_points_inside_bboxes,
compute_max_iou_anchor)
@register
class TaskAlignedAssigner(nn.Layer):
"""TOOD: Task-aligned One-stage Object Detection
"""
def __init__(self, topk=13, alpha=1.0, beta=6.0, eps=1e-9):
super(TaskAlignedAssigner, self).__init__()
self.topk = topk
self.alpha = alpha
self.beta = beta
self.eps = eps
@paddle.no_grad()
def forward(self,
pred_scores,
pred_bboxes,
anchor_points,
gt_labels,
gt_bboxes,
bg_index,
gt_scores=None):
r"""The assignment is done in following steps
1. compute alignment metric between all bbox (bbox of all pyramid levels) and gt
2. select top-k bbox as candidates for each gt
3. limit the positive sample's center in gt (because the anchor-free detector
can only predict positive distances)
4. if an anchor box is assigned to multiple gts, the one with the
highest iou will be selected.
Args:
pred_scores (Tensor, float32): predicted class probability, shape(B, L, C)
pred_bboxes (Tensor, float32): predicted bounding boxes, shape(B, L, 4)
anchor_points (Tensor, float32): pre-defined anchors, shape(L, 2), "cxcy" format
gt_labels (Tensor|List[Tensor], int64): Label of gt_bboxes, shape(B, n, 1)
gt_bboxes (Tensor|List[Tensor], float32): Ground truth bboxes, shape(B, n, 4)
bg_index (int): background index
gt_scores (Tensor|List[Tensor]|None, float32) Score of gt_bboxes,
shape(B, n, 1), if None, then it will initialize with one_hot label
Returns:
assigned_labels (Tensor): (B, L)
assigned_bboxes (Tensor): (B, L, 4)
assigned_scores (Tensor): (B, L, C)
"""
assert pred_scores.ndim == pred_bboxes.ndim
gt_labels, gt_bboxes, pad_gt_scores, pad_gt_mask = pad_gt(
gt_labels, gt_bboxes, gt_scores)
assert gt_labels.ndim == gt_bboxes.ndim and \
gt_bboxes.ndim == 3
batch_size, num_anchors, num_classes = pred_scores.shape
_, num_max_boxes, _ = gt_bboxes.shape
# compute iou between gt and pred bbox, [B, n, L]
ious = iou_similarity(gt_bboxes, pred_bboxes)
# gather pred bboxes class score
pred_scores = pred_scores.transpose([0, 2, 1])
batch_ind = paddle.arange(
end=batch_size, dtype=gt_labels.dtype).unsqueeze(-1)
gt_labels_ind = paddle.stack(
[batch_ind.tile([1, num_max_boxes]), gt_labels.squeeze(-1)],
axis=-1)
bbox_cls_scores = paddle.gather_nd(pred_scores, gt_labels_ind)
# compute alignment metrics, [B, n, L]
alignment_metrics = bbox_cls_scores.pow(self.alpha) * ious.pow(
self.beta)
# check the positive sample's center in gt, [B, n, L]
is_in_gts = check_points_inside_bboxes(anchor_points, gt_bboxes)
# select topk largest alignment metrics pred bbox as candidates
# for each gt, [B, n, L]
is_in_topk = gather_topk_anchors(
alignment_metrics * is_in_gts,
self.topk,
topk_mask=pad_gt_mask.tile([1, 1, self.topk]).astype(paddle.bool))
# select positive sample, [B, n, L]
mask_positive = is_in_topk * is_in_gts * pad_gt_mask
# if an anchor box is assigned to multiple gts,
# the one with the highest iou will be selected, [B, n, L]
mask_positive_sum = mask_positive.sum(axis=-2)
if mask_positive_sum.max() > 1:
mask_multiple_gts = (mask_positive_sum.unsqueeze(1) > 1).tile(
[1, num_max_boxes, 1])
is_max_iou = compute_max_iou_anchor(ious)
mask_positive = paddle.where(mask_multiple_gts, is_max_iou,
mask_positive)
mask_positive_sum = mask_positive.sum(axis=-2)
assigned_gt_index = mask_positive.argmax(axis=-2)
assert mask_positive_sum.max() == 1, \
("one anchor just assign one gt, but received not equals 1. "
"Received: %f" % mask_positive_sum.max().item())
# assigned target
assigned_gt_index = assigned_gt_index + batch_ind * num_max_boxes
assigned_labels = paddle.gather(
gt_labels.flatten(), assigned_gt_index.flatten(), axis=0)
assigned_labels = assigned_labels.reshape([batch_size, num_anchors])
assigned_labels = paddle.where(
mask_positive_sum > 0, assigned_labels,
paddle.full_like(assigned_labels, bg_index))
assigned_bboxes = paddle.gather(
gt_bboxes.reshape([-1, 4]), assigned_gt_index.flatten(), axis=0)
assigned_bboxes = assigned_bboxes.reshape([batch_size, num_anchors, 4])
assigned_scores = F.one_hot(assigned_labels, num_classes)
# rescale alignment metrics
alignment_metrics *= mask_positive
max_metrics_per_instance = alignment_metrics.max(axis=-1, keepdim=True)
max_ious_per_instance = (ious * mask_positive).max(axis=-1,
keepdim=True)
alignment_metrics = alignment_metrics / (
max_metrics_per_instance + self.eps) * max_ious_per_instance
alignment_metrics = alignment_metrics.max(-2).unsqueeze(-1)
assigned_scores = assigned_scores * alignment_metrics
return assigned_labels, assigned_bboxes, assigned_scores
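The core of the task-aligned assignment is the alignment metric `score ** alpha * iou ** beta`, followed by a rescaling of the metric for the selected anchors. The following pure-Python sketch handles a single ground-truth box and skips the center check and the multi-ground-truth conflict resolution. The function name and the small `topk=2` default are hypothetical and chosen only to keep the example short.

```python
def task_aligned_targets(scores, ious, topk=2, alpha=1.0, beta=6.0, eps=1e-9):
    """Pick the top-k anchors for one ground-truth box and rescale their metrics.

    scores[i]: predicted score of the ground-truth class at anchor i.
    ious[i]:   IoU between the box predicted at anchor i and the ground truth.
    """
    metrics = [s ** alpha * u ** beta for s, u in zip(scores, ious)]
    # Keep the k anchors with the largest alignment metric.
    order = sorted(range(len(metrics)), key=lambda i: metrics[i], reverse=True)
    positives = sorted(order[:topk])
    # Rescale so that the best positive anchor receives a target equal to the
    # largest IoU among the positives.
    max_metric = max(metrics[i] for i in positives)
    max_iou = max(ious[i] for i in positives)
    targets = {i: metrics[i] / (max_metric + eps) * max_iou for i in positives}
    return positives, targets


positives, targets = task_aligned_targets(
    scores=[0.9, 0.5, 0.2], ious=[0.5, 0.9, 0.8])
```

With `beta=6.0` the IoU term dominates: the anchor with the highest score (index 0) is not selected because its IoU is low.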
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import paddle
import paddle.nn.functional as F
def pad_gt(gt_labels, gt_bboxes, gt_scores=None):
    r""" Pad 0 in gt_labels and gt_bboxes.
    Args:
        gt_labels (Tensor|List[Tensor], int64): Label of gt_bboxes,
            shape is [B, n, 1] or [[n_1, 1], [n_2, 1], ...], here n = max(n_i)
        gt_bboxes (Tensor|List[Tensor], float32): Ground truth bboxes,
            shape is [B, n, 4] or [[n_1, 4], [n_2, 4], ...], here n = max(n_i)
        gt_scores (Tensor|List[Tensor]|None, float32): Score of gt_bboxes,
            shape is [B, n, 1] or [[n_1, 1], [n_2, 1], ...], here n = max(n_i)
    Returns:
        pad_gt_labels (Tensor, int64): shape[B, n, 1]
        pad_gt_bboxes (Tensor, float32): shape[B, n, 4]
        pad_gt_scores (Tensor, float32): shape[B, n, 1]
        pad_gt_mask (Tensor, float32): shape[B, n, 1], 1 means bbox, 0 means no bbox
    """
    if isinstance(gt_labels, paddle.Tensor) and isinstance(gt_bboxes,
                                                           paddle.Tensor):
        assert gt_labels.ndim == gt_bboxes.ndim and \
               gt_bboxes.ndim == 3
        pad_gt_mask = (
            gt_bboxes.sum(axis=-1, keepdim=True) > 0).astype(gt_bboxes.dtype)
        if gt_scores is None:
            gt_scores = pad_gt_mask.clone()
        assert gt_labels.ndim == gt_scores.ndim
        return gt_labels, gt_bboxes, gt_scores, pad_gt_mask
    elif isinstance(gt_labels, list) and isinstance(gt_bboxes, list):
        assert len(gt_labels) == len(gt_bboxes), \
            'The number of `gt_labels` and `gt_bboxes` is not equal. '
        num_max_boxes = max([len(a) for a in gt_bboxes])
        batch_size = len(gt_bboxes)
        # pad label and bbox
        pad_gt_labels = paddle.zeros(
            [batch_size, num_max_boxes, 1], dtype=gt_labels[0].dtype)
        pad_gt_bboxes = paddle.zeros(
            [batch_size, num_max_boxes, 4], dtype=gt_bboxes[0].dtype)
        pad_gt_scores = paddle.zeros(
            [batch_size, num_max_boxes, 1], dtype=gt_bboxes[0].dtype)
        pad_gt_mask = paddle.zeros(
            [batch_size, num_max_boxes, 1], dtype=gt_bboxes[0].dtype)
        for i, (label, bbox) in enumerate(zip(gt_labels, gt_bboxes)):
            if len(label) > 0 and len(bbox) > 0:
                pad_gt_labels[i, :len(label)] = label
                pad_gt_bboxes[i, :len(bbox)] = bbox
                pad_gt_mask[i, :len(bbox)] = 1.
                if gt_scores is not None:
                    pad_gt_scores[i, :len(gt_scores[i])] = gt_scores[i]
        if gt_scores is None:
            pad_gt_scores = pad_gt_mask.clone()
        return pad_gt_labels, pad_gt_bboxes, pad_gt_scores, pad_gt_mask
    else:
        raise ValueError('The input `gt_labels` or `gt_bboxes` is invalid! ')
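The list branch of `pad_gt` pads every image to the largest box count in the batch and records which entries are real. A minimal sketch of the same idea with nested Python lists instead of tensors is shown below. The function name `pad_ground_truth` is hypothetical, and labels and scores are left out for brevity.

```python
def pad_ground_truth(gt_bboxes):
    """Pad per-image box lists to a common length and build a validity mask."""
    num_max_boxes = max(len(boxes) for boxes in gt_bboxes)
    padded, mask = [], []
    for boxes in gt_bboxes:
        num_pad = num_max_boxes - len(boxes)
        # Real boxes first, then all-zero placeholder boxes.
        padded.append([list(b) for b in boxes] +
                      [[0.0] * 4 for _ in range(num_pad)])
        # 1.0 marks a real box, 0.0 marks padding.
        mask.append([1.0] * len(boxes) + [0.0] * num_pad)
    return padded, mask


padded, mask = pad_ground_truth(
    [[[0, 0, 10, 10], [5, 5, 20, 20]], [[1, 2, 3, 4]]])
```

The mask is what later lets the assigners ignore the zero-filled rows.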
def gather_topk_anchors(metrics, topk, largest=True, topk_mask=None, eps=1e-9):
    r"""
    Args:
        metrics (Tensor, float32): shape[B, n, L], n: num_gts, L: num_anchors
        topk (int): The number of top elements to look for along the axis.
        largest (bool) : largest is a flag, if set to true,
            algorithm will sort by descending order, otherwise sort by
            ascending order. Default: True
        topk_mask (Tensor, bool|None): shape[B, n, topk], ignore bbox mask,
            Default: None
        eps (float): Default: 1e-9
    Returns:
        is_in_topk (Tensor, float32): shape[B, n, L], value=1. means selected
    """
    num_anchors = metrics.shape[-1]
    topk_metrics, topk_idxs = paddle.topk(
        metrics, topk, axis=-1, largest=largest)
    if topk_mask is None:
        topk_mask = (topk_metrics.max(axis=-1, keepdim=True) > eps).tile(
            [1, 1, topk])
    topk_idxs = paddle.where(topk_mask, topk_idxs, paddle.zeros_like(topk_idxs))
    is_in_topk = F.one_hot(topk_idxs, num_anchors).sum(axis=-2)
    is_in_topk = paddle.where(is_in_topk > 1,
                              paddle.zeros_like(is_in_topk), is_in_topk)
    return is_in_topk.astype(metrics.dtype)
def check_points_inside_bboxes(points, bboxes, eps=1e-9):
    r"""
    Args:
        points (Tensor, float32): shape[L, 2], "xy" format, L: num_anchors
        bboxes (Tensor, float32): shape[B, n, 4], "xmin, ymin, xmax, ymax" format
        eps (float): Default: 1e-9
    Returns:
        is_in_bboxes (Tensor, float32): shape[B, n, L], value=1. means selected
    """
    points = points.unsqueeze([0, 1])
    x, y = points.chunk(2, axis=-1)
    xmin, ymin, xmax, ymax = bboxes.unsqueeze(2).chunk(4, axis=-1)
    l = x - xmin
    t = y - ymin
    r = xmax - x
    b = ymax - y
    bbox_ltrb = paddle.concat([l, t, r, b], axis=-1)
    return (bbox_ltrb.min(axis=-1) > eps).astype(bboxes.dtype)
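The inside test used by `check_points_inside_bboxes` reduces to one comparison: the smallest of the four edge distances must be positive. A scalar sketch for a single point and box follows; the function name `point_inside_box` is hypothetical.

```python
def point_inside_box(point, box, eps=1e-9):
    """Return True if the point lies strictly inside (xmin, ymin, xmax, ymax)."""
    x, y = point
    xmin, ymin, xmax, ymax = box
    # Distances to the left, top, right, and bottom edges.
    ltrb = (x - xmin, y - ymin, xmax - x, ymax - y)
    # A single negative or near-zero distance puts the point outside or on an edge.
    return min(ltrb) > eps
```

For example, `point_inside_box((5, 5), (0, 0, 10, 10))` is true, while a point on the left edge such as `(0, 5)` is rejected because its left distance is not greater than `eps`.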
def compute_max_iou_anchor(ious):
    r"""
    For each anchor, find the GT with the largest IOU.
    Args:
        ious (Tensor, float32): shape[B, n, L], n: num_gts, L: num_anchors
    Returns:
        is_max_iou (Tensor, float32): shape[B, n, L], value=1. means selected
    """
    num_max_boxes = ious.shape[-2]
    max_iou_index = ious.argmax(axis=-2)
    is_max_iou = F.one_hot(max_iou_index, num_max_boxes).transpose([0, 2, 1])
    return is_max_iou.astype(ious.dtype)
def compute_max_iou_gt(ious):
    r"""
    For each GT, find the anchor with the largest IOU.
    Args:
        ious (Tensor, float32): shape[B, n, L], n: num_gts, L: num_anchors
    Returns:
        is_max_iou (Tensor, float32): shape[B, n, L], value=1. means selected
    """
    num_anchors = ious.shape[-1]
    max_iou_index = ious.argmax(axis=-1)
    is_max_iou = F.one_hot(max_iou_index, num_anchors)
    return is_max_iou.astype(ious.dtype)
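Both assigners use `compute_max_iou_anchor` to resolve anchors that were selected by more than one ground truth: such an anchor is kept only for the ground truth with the largest IoU. The sketch below shows that rule for one image with nested lists. The function name `resolve_multiple_assignments` is hypothetical, and the loop form is chosen for readability rather than speed.

```python
def resolve_multiple_assignments(mask_positive, ious):
    """mask_positive[g][a] is 1 if anchor a is a positive for ground truth g."""
    num_gts, num_anchors = len(ious), len(ious[0])
    resolved = [row[:] for row in mask_positive]
    for a in range(num_anchors):
        if sum(mask_positive[g][a] for g in range(num_gts)) > 1:
            # Keep only the ground truth with the largest IoU for this anchor.
            best = max(range(num_gts), key=lambda g: ious[g][a])
            for g in range(num_gts):
                resolved[g][a] = 1 if g == best else 0
    return resolved


resolved = resolve_multiple_assignments(
    mask_positive=[[1, 1, 0], [1, 0, 1]],
    ious=[[0.3, 0.5, 0.1], [0.7, 0.2, 0.6]])
```

Here anchor 0 is claimed by both ground truths and ends up with the second one, whose IoU of 0.7 is larger. After this step every anchor column sums to at most 1, which is what the assertion on `mask_positive_sum` in the assigners checks.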
......@@ -645,3 +645,15 @@ def distance2bbox(points, distance, max_shape=None):
x2 = x2.clip(min=0, max=max_shape[1])
y2 = y2.clip(min=0, max=max_shape[0])
return paddle.stack([x1, y1, x2, y2], -1)
def bbox_center(boxes):
    """Get bbox centers from boxes.
    Args:
        boxes (Tensor): boxes with shape (N, 4), "xmin, ymin, xmax, ymax" format.
    Returns:
        Tensor: boxes centers with shape (N, 2), "cx, cy" format.
    """
    boxes_cx = (boxes[:, 0] + boxes[:, 2]) / 2
    boxes_cy = (boxes[:, 1] + boxes[:, 3]) / 2
    return paddle.stack([boxes_cx, boxes_cy], axis=-1)
......@@ -29,6 +29,7 @@ from . import gfl_head
from . import pico_head
from . import detr_head
from . import sparsercnn_head
from . import tood_head
from .bbox_head import *
from .mask_head import *
......@@ -47,3 +48,4 @@ from .gfl_head import *
from .pico_head import *
from .detr_head import *
from .sparsercnn_head import *
from .tood_head import *
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from paddle import ParamAttr
from paddle.nn.initializer import Constant
from ppdet.core.workspace import register
from ..initializer import normal_, constant_, bias_init_with_prob
from ppdet.modeling.bbox_utils import bbox_center
from ..losses import GIoULoss
from paddle.vision.ops import deform_conv2d
from ppdet.modeling.layers import ConvNormLayer
class ScaleReg(nn.Layer):
"""
Parameter for scaling the regression outputs.
"""
def __init__(self, init_scale=1.):
super(ScaleReg, self).__init__()
self.scale_reg = self.create_parameter(
shape=[1],
attr=ParamAttr(initializer=Constant(value=init_scale)),
dtype="float32")
def forward(self, inputs):
out = inputs * self.scale_reg
return out
class TaskDecomposition(nn.Layer):
def __init__(
self,
feat_channels,
stacked_convs,
la_down_rate=8,
norm_type='gn',
norm_groups=32, ):
super(TaskDecomposition, self).__init__()
self.feat_channels = feat_channels
self.stacked_convs = stacked_convs
self.norm_type = norm_type
self.norm_groups = norm_groups
self.in_channels = self.feat_channels * self.stacked_convs
self.la_conv1 = nn.Conv2D(self.in_channels,
self.in_channels // la_down_rate, 1)
self.la_conv2 = nn.Conv2D(self.in_channels // la_down_rate,
self.stacked_convs, 1)
self.reduction_conv = ConvNormLayer(
self.in_channels,
self.feat_channels,
filter_size=1,
stride=1,
norm_type=self.norm_type,
norm_groups=self.norm_groups)
self._init_weights()
def _init_weights(self):
normal_(self.la_conv1.weight, std=0.001)
normal_(self.la_conv2.weight, std=0.001)
def forward(self, feat, avg_feat=None):
b, _, h, w = feat.shape
if avg_feat is None:
avg_feat = F.adaptive_avg_pool2d(feat, (1, 1))
weight = F.relu(self.la_conv1(avg_feat))
weight = F.sigmoid(self.la_conv2(weight))
# here new_conv_weight = layer_attention_weight * conv_weight
# in order to save memory and FLOPs.
conv_weight = weight.reshape([b, 1, self.stacked_convs, 1]) * \
self.reduction_conv.conv.weight.reshape(
[1, self.feat_channels, self.stacked_convs, self.feat_channels])
conv_weight = conv_weight.reshape(
[b, self.feat_channels, self.in_channels])
feat = feat.reshape([b, self.in_channels, h * w])
feat = paddle.bmm(conv_weight, feat).reshape(
[b, self.feat_channels, h, w])
if self.norm_type is not None:
feat = self.reduction_conv.norm(feat)
feat = F.relu(feat)
return feat
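The comment in `TaskDecomposition.forward` notes that the layer attention is multiplied into the convolution weight rather than into the features. The toy example below checks, for one spatial position and tiny hand-picked numbers, that both orders give the same result. The helper names are hypothetical, and the column layout (all channels of layer 0, then layer 1, and so on) is assumed from the reshape in the code above.

```python
def conv1x1(weight, x):
    """Apply a 1x1 convolution at one spatial position: y = weight @ x."""
    return [sum(w * v for w, v in zip(row, x)) for row in weight]


def fold_layer_attention(weight, attention, feat_channels):
    """Scale each input column of the kernel by the attention of its layer."""
    return [[w * attention[j // feat_channels] for j, w in enumerate(row)]
            for row in weight]


feat_channels = 2
attention = [0.5, 2.0]      # one weight per stacked conv layer
x = [1.0, 2.0, 3.0, 4.0]    # channels of layer 0, then channels of layer 1
weight = [[1.0, 0.0, 1.0, 0.0], [0.0, 1.0, 0.0, 1.0]]

# Option A: scale the features, then convolve.
scaled_x = [v * attention[j // feat_channels] for j, v in enumerate(x)]
out_a = conv1x1(weight, scaled_x)
# Option B: scale the kernel once, then convolve the raw features.
out_b = conv1x1(fold_layer_attention(weight, attention, feat_channels), x)
```

Option B touches only the kernel, whose size does not depend on the feature map resolution, which is consistent with the memory and FLOPs argument in the comment.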
@register
class TOODHead(nn.Layer):
__inject__ = ['nms', 'static_assigner', 'assigner']
__shared__ = ['num_classes']
def __init__(self,
num_classes=80,
feat_channels=256,
stacked_convs=6,
fpn_strides=(8, 16, 32, 64, 128),
grid_cell_scale=8,
grid_cell_offset=0.5,
norm_type='gn',
norm_groups=32,
static_assigner_epoch=4,
use_align_head=True,
loss_weight={
'class': 1.0,
'bbox': 1.0,
'iou': 2.0,
},
nms='MultiClassNMS',
static_assigner='ATSSAssigner',
assigner='TaskAlignedAssigner'):
super(TOODHead, self).__init__()
self.num_classes = num_classes
self.feat_channels = feat_channels
self.stacked_convs = stacked_convs
self.fpn_strides = fpn_strides
self.grid_cell_scale = grid_cell_scale
self.grid_cell_offset = grid_cell_offset
self.static_assigner_epoch = static_assigner_epoch
self.use_align_head = use_align_head
self.nms = nms
self.static_assigner = static_assigner
self.assigner = assigner
self.loss_weight = loss_weight
self.giou_loss = GIoULoss()
self.inter_convs = nn.LayerList()
for i in range(self.stacked_convs):
self.inter_convs.append(
ConvNormLayer(
self.feat_channels,
self.feat_channels,
filter_size=3,
stride=1,
norm_type=norm_type,
norm_groups=norm_groups))
self.cls_decomp = TaskDecomposition(
self.feat_channels,
self.stacked_convs,
self.stacked_convs * 8,
norm_type=norm_type,
norm_groups=norm_groups)
self.reg_decomp = TaskDecomposition(
self.feat_channels,
self.stacked_convs,
self.stacked_convs * 8,
norm_type=norm_type,
norm_groups=norm_groups)
self.tood_cls = nn.Conv2D(
self.feat_channels, self.num_classes, 3, padding=1)
self.tood_reg = nn.Conv2D(self.feat_channels, 4, 3, padding=1)
if self.use_align_head:
self.cls_prob_conv1 = nn.Conv2D(self.feat_channels *
self.stacked_convs,
self.feat_channels // 4, 1)
self.cls_prob_conv2 = nn.Conv2D(
self.feat_channels // 4, 1, 3, padding=1)
self.reg_offset_conv1 = nn.Conv2D(self.feat_channels *
self.stacked_convs,
self.feat_channels // 4, 1)
self.reg_offset_conv2 = nn.Conv2D(
self.feat_channels // 4, 4 * 2, 3, padding=1)
self.scales_regs = nn.LayerList([ScaleReg() for _ in self.fpn_strides])
self._init_weights()
@classmethod
def from_config(cls, cfg, input_shape):
return {
'feat_channels': input_shape[0].channels,
'fpn_strides': [i.stride for i in input_shape],
}

    def _init_weights(self):
        bias_cls = bias_init_with_prob(0.01)
        normal_(self.tood_cls.weight, std=0.01)
        constant_(self.tood_cls.bias, bias_cls)
        normal_(self.tood_reg.weight, std=0.01)

        if self.use_align_head:
            normal_(self.cls_prob_conv1.weight, std=0.01)
            normal_(self.cls_prob_conv2.weight, std=0.01)
            constant_(self.cls_prob_conv2.bias, bias_cls)
            normal_(self.reg_offset_conv1.weight, std=0.001)
            normal_(self.reg_offset_conv2.weight, std=0.001)
            constant_(self.reg_offset_conv2.bias)

    def _generate_anchors(self, feats):
        anchors, num_anchors_list = [], []
        stride_tensor_list = []
        for feat, stride in zip(feats, self.fpn_strides):
            _, _, h, w = feat.shape
            cell_half_size = self.grid_cell_scale * stride * 0.5
            shift_x = (paddle.arange(end=w) + self.grid_cell_offset) * stride
            shift_y = (paddle.arange(end=h) + self.grid_cell_offset) * stride
            shift_y, shift_x = paddle.meshgrid(shift_y, shift_x)
            anchor = paddle.stack(
                [
                    shift_x - cell_half_size, shift_y - cell_half_size,
                    shift_x + cell_half_size, shift_y + cell_half_size
                ],
                axis=-1)
            anchors.append(anchor.reshape([-1, 4]))
            num_anchors_list.append(len(anchors[-1]))
            stride_tensor_list.append(
                paddle.full([num_anchors_list[-1], 1], stride))
        return anchors, num_anchors_list, stride_tensor_list

    @staticmethod
    def _batch_distance2bbox(points, distance, max_shapes=None):
        """Decode distance prediction to bounding box.
        Args:
            points (Tensor): [B, l, 2]
            distance (Tensor): [B, l, 4]
            max_shapes (tuple): [B, 2], "h w" format, Shape of the image.
        Returns:
            Tensor: Decoded bboxes.
        """
        x1 = points[:, :, 0] - distance[:, :, 0]
        y1 = points[:, :, 1] - distance[:, :, 1]
        x2 = points[:, :, 0] + distance[:, :, 2]
        y2 = points[:, :, 1] + distance[:, :, 3]
        bboxes = paddle.stack([x1, y1, x2, y2], -1)
        if max_shapes is not None:
            out_bboxes = []
            for bbox, max_shape in zip(bboxes, max_shapes):
                bbox[:, 0] = bbox[:, 0].clip(min=0, max=max_shape[1])
                bbox[:, 1] = bbox[:, 1].clip(min=0, max=max_shape[0])
                bbox[:, 2] = bbox[:, 2].clip(min=0, max=max_shape[1])
                bbox[:, 3] = bbox[:, 3].clip(min=0, max=max_shape[0])
                out_bboxes.append(bbox)
            out_bboxes = paddle.stack(out_bboxes)
            return out_bboxes
        return bboxes

    @staticmethod
    def _deform_sampling(feat, offset):
        """Sample the feature according to the offset.
        Args:
            feat (Tensor): Feature
            offset (Tensor): Spatial offset for feature sampling
        """
        # it is an equivalent implementation of bilinear interpolation
        # you can also use F.grid_sample instead
        c = feat.shape[1]
        weight = paddle.ones([c, 1, 1, 1])
        y = deform_conv2d(feat, offset, weight, deformable_groups=c, groups=c)
        return y

    def forward(self, feats):
        assert len(feats) == len(self.fpn_strides), \
            "The size of feats is not equal to size of fpn_strides"

        anchors, num_anchors_list, stride_tensor_list = self._generate_anchors(
            feats)
        cls_score_list, bbox_pred_list = [], []
        for feat, scale_reg, anchor, stride in zip(feats, self.scales_regs,
                                                   anchors, self.fpn_strides):
            b, _, h, w = feat.shape
            inter_feats = []
            for inter_conv in self.inter_convs:
                feat = F.relu(inter_conv(feat))
                inter_feats.append(feat)
            feat = paddle.concat(inter_feats, axis=1)

            # task decomposition
            avg_feat = F.adaptive_avg_pool2d(feat, (1, 1))
            cls_feat = self.cls_decomp(feat, avg_feat)
            reg_feat = self.reg_decomp(feat, avg_feat)

            # cls prediction and alignment
            cls_logits = self.tood_cls(cls_feat)
            if self.use_align_head:
                cls_prob = F.relu(self.cls_prob_conv1(feat))
                cls_prob = F.sigmoid(self.cls_prob_conv2(cls_prob))
                cls_score = (F.sigmoid(cls_logits) * cls_prob).sqrt()
            else:
                cls_score = F.sigmoid(cls_logits)
            cls_score_list.append(cls_score.flatten(2).transpose([0, 2, 1]))

            # reg prediction and alignment
            reg_dist = scale_reg(self.tood_reg(reg_feat).exp())
            reg_dist = reg_dist.transpose([0, 2, 3, 1]).reshape([b, -1, 4])
            anchor_centers = bbox_center(anchor).unsqueeze(0) / stride
            reg_bbox = self._batch_distance2bbox(
                anchor_centers.tile([b, 1, 1]), reg_dist)
            if self.use_align_head:
                reg_bbox = reg_bbox.reshape([b, h, w, 4]).transpose(
                    [0, 3, 1, 2])
                reg_offset = F.relu(self.reg_offset_conv1(feat))
                reg_offset = self.reg_offset_conv2(reg_offset)
                bbox_pred = self._deform_sampling(reg_bbox, reg_offset)
                bbox_pred = bbox_pred.flatten(2).transpose([0, 2, 1])
            else:
                bbox_pred = reg_bbox

            if not self.training:
                bbox_pred *= stride
            bbox_pred_list.append(bbox_pred)
        cls_score_list = paddle.concat(cls_score_list, axis=1)
        bbox_pred_list = paddle.concat(bbox_pred_list, axis=1)
        anchors = paddle.concat(anchors)
        anchors.stop_gradient = True
        stride_tensor_list = paddle.concat(stride_tensor_list).unsqueeze(0)
        stride_tensor_list.stop_gradient = True

        return cls_score_list, bbox_pred_list, anchors, num_anchors_list, stride_tensor_list

    @staticmethod
    def _focal_loss(score, label, alpha=0.25, gamma=2.0):
        weight = (score - label).pow(gamma)
        if alpha > 0:
            alpha_t = alpha * label + (1 - alpha) * (1 - label)
            weight *= alpha_t
        loss = F.binary_cross_entropy(
            score, label, weight=weight, reduction='sum')
        return loss

    def get_loss(self, head_outs, gt_meta):
        pred_scores, pred_bboxes, anchors, num_anchors_list, stride_tensor_list = head_outs
        gt_labels = gt_meta['gt_class']
        gt_bboxes = gt_meta['gt_bbox']
        # label assignment
        if gt_meta['epoch_id'] < self.static_assigner_epoch:
            assigned_labels, assigned_bboxes, assigned_scores = self.static_assigner(
                anchors,
                num_anchors_list,
                gt_labels,
                gt_bboxes,
                bg_index=self.num_classes)
            alpha_l = 0.25
        else:
            assigned_labels, assigned_bboxes, assigned_scores = self.assigner(
                pred_scores.detach(),
                pred_bboxes.detach() * stride_tensor_list,
                bbox_center(anchors),
                gt_labels,
                gt_bboxes,
                bg_index=self.num_classes)
            alpha_l = -1

        # rescale bbox
        assigned_bboxes /= stride_tensor_list
        # classification loss
        loss_cls = self._focal_loss(pred_scores, assigned_scores, alpha=alpha_l)
        # select positive samples mask
        mask_positive = (assigned_labels != self.num_classes)
        num_pos = mask_positive.astype(paddle.float32).sum()
        # bbox regression loss
        if num_pos > 0:
            bbox_mask = mask_positive.unsqueeze(-1).tile([1, 1, 4])
            pred_bboxes_pos = paddle.masked_select(pred_bboxes,
                                                   bbox_mask).reshape([-1, 4])
            assigned_bboxes_pos = paddle.masked_select(
                assigned_bboxes, bbox_mask).reshape([-1, 4])
            bbox_weight = paddle.masked_select(
                assigned_scores.sum(-1), mask_positive).unsqueeze(-1)
            # iou loss
            loss_iou = self.giou_loss(pred_bboxes_pos,
                                      assigned_bboxes_pos) * bbox_weight
            loss_iou = loss_iou.sum() / bbox_weight.sum()
            # l1 loss
            loss_l1 = F.l1_loss(pred_bboxes_pos, assigned_bboxes_pos)
        else:
            loss_iou = paddle.zeros([1])
            loss_l1 = paddle.zeros([1])

        loss_cls /= assigned_scores.sum().clip(min=1)
        loss = self.loss_weight['class'] * loss_cls + self.loss_weight[
            'iou'] * loss_iou

        return {
            'loss': loss,
            'loss_class': loss_cls,
            'loss_iou': loss_iou,
            'loss_l1': loss_l1
        }

    def post_process(self, head_outs, img_shape, scale_factor):
        pred_scores, pred_bboxes, _, _, _ = head_outs
        pred_scores = pred_scores.transpose([0, 2, 1])

        for i in range(len(pred_bboxes)):
            pred_bboxes[i, :, 0] = pred_bboxes[i, :, 0].clip(
                min=0, max=img_shape[i, 1])
            pred_bboxes[i, :, 1] = pred_bboxes[i, :, 1].clip(
                min=0, max=img_shape[i, 0])
            pred_bboxes[i, :, 2] = pred_bboxes[i, :, 2].clip(
                min=0, max=img_shape[i, 1])
            pred_bboxes[i, :, 3] = pred_bboxes[i, :, 3].clip(
                min=0, max=img_shape[i, 0])
        # scale bbox to origin
        scale_factor = scale_factor.flip([1]).tile([1, 2]).unsqueeze(1)
        pred_bboxes /= scale_factor
        bbox_pred, bbox_num, _ = self.nms(pred_bboxes, pred_scores)
        return bbox_pred, bbox_num
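
The anchor layout built by `_generate_anchors` and the distance decoding done by `_batch_distance2bbox` can be checked by hand on a tiny feature map. The following is a minimal pure-Python sketch, not the Paddle implementation: the function names `generate_anchors` and `distance2bbox` are hypothetical, `grid_cell_scale=8` is taken from the config, and `grid_cell_offset=0.5` is an assumption because the default is not visible in this part of the diff. It also works in pixel units throughout, whereas `forward` divides the anchor centers by the stride during training.

```python
def generate_anchors(h, w, stride, grid_cell_scale=8, grid_cell_offset=0.5):
    """Return one square anchor [x1, y1, x2, y2] per cell, in row-major order.

    grid_cell_offset=0.5 is an assumed default (cell centers).
    """
    half = grid_cell_scale * stride * 0.5  # half the anchor side, in pixels
    anchors = []
    for iy in range(h):
        cy = (iy + grid_cell_offset) * stride
        for ix in range(w):
            cx = (ix + grid_cell_offset) * stride
            anchors.append([cx - half, cy - half, cx + half, cy + half])
    return anchors


def distance2bbox(point, distance):
    """Decode (left, top, right, bottom) distances around a point into a box."""
    x, y = point
    left, top, right, bottom = distance
    return [x - left, y - top, x + right, y + bottom]


# A 2x3 feature map at stride 8 yields 6 anchors of side 8 * 8 = 64 pixels.
anchors = generate_anchors(h=2, w=3, stride=8)
first = anchors[0]
center = ((first[0] + first[2]) / 2, (first[1] + first[3]) / 2)
box = distance2bbox(center, (1.0, 2.0, 3.0, 4.0))
print(first)  # [-28.0, -28.0, 36.0, 36.0]
print(box)    # [3.0, 2.0, 7.0, 8.0]
```

The anchor itself is only used for label assignment; the regression branch predicts distances from the anchor center, which is why the decoded box depends on `center` and not on the anchor's side length.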
......@@ -272,6 +272,12 @@ def conv_init_(module):
uniform_(module.bias, -bound, bound)


def bias_init_with_prob(prior_prob=0.01):
    """initialize conv/fc bias value according to a given probability value."""
    bias_init = float(-np.log((1 - prior_prob) / prior_prob))
    return bias_init


@paddle.no_grad()
def reset_initialized_parameter(model, include_self=True):
    """
......
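
The effect of `bias_init_with_prob` is easier to see numerically: the returned bias is the logit of `prior_prob`, so a sigmoid applied to a layer whose weights are near zero starts out predicting roughly `prior_prob` for every class. A minimal sketch of that relationship, using `math` instead of NumPy and a hypothetical `sigmoid` helper that is not part of the diff:

```python
import math


def bias_init_with_prob(prior_prob=0.01):
    """Same formula as in the diff, with math.log in place of np.log."""
    return float(-math.log((1 - prior_prob) / prior_prob))


def sigmoid(x):
    """Hypothetical helper for this sketch only."""
    return 1.0 / (1.0 + math.exp(-x))


bias = bias_init_with_prob(0.01)
recovered = sigmoid(bias)
print(bias)       # negative, approximately -4.595
print(recovered)  # approximately 0.01
```

In `_init_weights` this value is assigned to the biases of `tood_cls` and `cls_prob_conv2`, both of which feed a sigmoid in `forward`.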