diff --git a/configs/tood/README.md b/configs/tood/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..f23b1844ab5e1124f39dfd2bf4335b2f401b9db0
--- /dev/null
+++ b/configs/tood/README.md
@@ -0,0 +1,35 @@
+# TOOD
+
+## Introduction
+
+[TOOD: Task-aligned One-stage Object Detection](https://arxiv.org/abs/2108.07755)
+
+TOOD is a task-aligned one-stage object detector. This directory reproduces the model described in the paper.
+
+
+## Model Zoo
+
+| Backbone | Model | Images/GPU | Inf time (fps) | Box AP | Config | Download |
+|:------:|:--------:|:--------:|:--------------:|:------:|:------:|:--------:|
+| R-50 | TOOD | 4 | --- | 42.8 | [config](https://github.com/PaddlePaddle/PaddleDetection/blob/develop/configs/tood/tood_r50_fpn_1x_coco.yml) | [model](https://paddledet.bj.bcebos.com/models/tood_r50_fpn_1x_coco.pdparams) |
+
+**Notes:**
+
+- TOOD is trained on the COCO train2017 dataset and evaluated on val2017; Box AP is reported as `mAP(IoU=0.5:0.95)`.
+- TOOD is trained for 12 epochs on 8 GPUs.
+
+Multi-GPU training:
+```bash
+export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
+python -m paddle.distributed.launch --gpus 0,1,2,3,4,5,6,7 tools/train.py -c configs/tood/tood_r50_fpn_1x_coco.yml --fleet
+```
+
+## Citations
+```
+@inproceedings{feng2021tood,
+  title={TOOD: Task-aligned One-stage Object Detection},
+  author={Feng, Chengjian and Zhong, Yujie and Gao, Yu and Scott, Matthew R and Huang, Weilin},
+  booktitle={ICCV},
+  year={2021}
+}
+```
diff --git a/configs/tood/_base_/optimizer_1x.yml b/configs/tood/_base_/optimizer_1x.yml
new file mode 100644
index 0000000000000000000000000000000000000000..39c54ac805031619debf9b31119afa86b3ead857
--- /dev/null
+++ b/configs/tood/_base_/optimizer_1x.yml
@@ -0,0 +1,19 @@
+epoch: 12
+
+LearningRate:
+  base_lr: 0.01
+  schedulers:
+  - !PiecewiseDecay
+    gamma: 0.1
+    milestones: [8, 11]
+  - !LinearWarmup
+    start_factor: 0.001
+    steps: 500
+
+OptimizerBuilder:
+  optimizer:
+    momentum: 0.9
+    type: Momentum
+  regularizer:
+    factor: 0.0001
+    type: L2
diff --git a/configs/tood/_base_/tood_r50_fpn.yml b/configs/tood/_base_/tood_r50_fpn.yml
new file mode 100644
index 0000000000000000000000000000000000000000..0cb8575b09beb8ba4d0e20d2512bdac5b34ecaf1
--- /dev/null
+++ b/configs/tood/_base_/tood_r50_fpn.yml
@@ -0,0 +1,42 @@
+architecture: TOOD
+pretrain_weights: https://paddledet.bj.bcebos.com/models/pretrained/ResNet50_cos_pretrained.pdparams
+
+TOOD:
+  backbone: ResNet
+  neck: FPN
+  head: TOODHead
+
+ResNet:
+  depth: 50
+  variant: b
+  norm_type: bn
+  freeze_at: 0
+  return_idx: [1, 2, 3]
+  num_stages: 4
+
+FPN:
+  out_channel: 256
+  spatial_scales: [0.125, 0.0625, 0.03125]
+  extra_stage: 2
+  has_extra_convs: true
+  use_c5: false
+
+TOODHead:
+  stacked_convs: 6
+  grid_cell_scale: 8
+  static_assigner_epoch: 4
+  loss_weight: { class: 1.0, iou: 2.0 }
+  static_assigner:
+    name: ATSSAssigner
+    topk: 9
+  assigner:
+    name: TaskAlignedAssigner
+    topk: 13
+    alpha: 1.0
+    beta: 6.0
+  nms:
+    name: MultiClassNMS
+    nms_top_k: 1000
+    keep_top_k: 100
+    score_threshold: 0.05
+    nms_threshold: 0.6
diff --git a/configs/tood/_base_/tood_reader.yml b/configs/tood/_base_/tood_reader.yml
new file mode 100644
index 0000000000000000000000000000000000000000..cda3cb80db4144fe3469c3844b3b698e1357e539
--- /dev/null
+++ b/configs/tood/_base_/tood_reader.yml
@@ -0,0 +1,39 @@
+worker_num: 4
+TrainReader:
+  sample_transforms:
+  - Decode: {}
+  - RandomFlip: {prob: 0.5}
+  - Resize: {target_size: [800, 1333], keep_ratio: true, interp: 1}
+  - NormalizeImage: {is_scale: true, mean: [0.485, 0.456, 0.406], std: [0.229, 0.224, 0.225]}
+ 
- Permute: {} + batch_transforms: + - PadBatch: {pad_to_stride: 32} + batch_size: 4 + shuffle: true + drop_last: true + collate_batch: false + use_shared_memory: true + + +EvalReader: + sample_transforms: + - Decode: {} + - Resize: {interp: 1, target_size: [800, 1333], keep_ratio: True} + - NormalizeImage: {is_scale: true, mean: [0.485, 0.456, 0.406], std: [0.229, 0.224, 0.225]} + - Permute: {} + batch_transforms: + - PadBatch: {pad_to_stride: 32} + batch_size: 1 + shuffle: false + + +TestReader: + sample_transforms: + - Decode: {} + - Resize: {interp: 1, target_size: [800, 1333], keep_ratio: True} + - NormalizeImage: {is_scale: true, mean: [0.485, 0.456, 0.406], std: [0.229, 0.224, 0.225]} + - Permute: {} + batch_transforms: + - PadBatch: {pad_to_stride: 32} + batch_size: 1 + shuffle: false diff --git a/configs/tood/tood_r50_fpn_1x_coco.yml b/configs/tood/tood_r50_fpn_1x_coco.yml new file mode 100644 index 0000000000000000000000000000000000000000..3d05c9884ea12013ea7b599d9c04c81abd709f40 --- /dev/null +++ b/configs/tood/tood_r50_fpn_1x_coco.yml @@ -0,0 +1,11 @@ +_BASE_: [ + '../datasets/coco_detection.yml', + '../runtime.yml', + '_base_/tood_r50_fpn.yml', + '_base_/optimizer_1x.yml', + '_base_/tood_reader.yml', +] + +weights: output/tood_r50_fpn_1x_coco/model_final +find_unused_parameters: True +log_iter: 100 diff --git a/ppdet/engine/trainer.py b/ppdet/engine/trainer.py index 85490ae4b0dbf39e485c31550e8692a6656f0942..2d73e845ee7f0994dbd1f588d971c6359260ecc2 100644 --- a/ppdet/engine/trainer.py +++ b/ppdet/engine/trainer.py @@ -352,6 +352,7 @@ class Trainer(object): self.status['data_time'].update(time.time() - iter_tic) self.status['step_id'] = step_id self._compose_callback.on_step_begin(self.status) + data['epoch_id'] = epoch_id if self.cfg.get('fp16', False): with amp.auto_cast(enable=self.cfg.use_gpu): diff --git a/ppdet/modeling/__init__.py b/ppdet/modeling/__init__.py index 5e4c26120e548a8d90ab9d8a2cb7c7bd4a6deee2..cdcb5d1bf08d813257dc577366de2efa9da9add7 100644 --- a/ppdet/modeling/__init__.py +++ b/ppdet/modeling/__init__.py @@ -28,6 +28,7 @@ from . import layers from . import reid from . import mot from . import transformers +from . import assigners from .ops import * from .backbones import * @@ -41,3 +42,4 @@ from .layers import * from .reid import * from .mot import * from .transformers import * +from .assigners import * diff --git a/ppdet/modeling/architectures/__init__.py b/ppdet/modeling/architectures/__init__.py index 278d72000157b0049e9e0e46b174da843eb99471..b5feb06d8aaa7cd1c99d3473d6ea64ba68feef0a 100644 --- a/ppdet/modeling/architectures/__init__.py +++ b/ppdet/modeling/architectures/__init__.py @@ -25,6 +25,7 @@ from . import gfl from . import picodet from . import detr from . import sparse_rcnn +from . import tood from .meta_arch import * from .faster_rcnn import * @@ -47,3 +48,4 @@ from .gfl import * from .picodet import * from .detr import * from .sparse_rcnn import * +from .tood import * diff --git a/ppdet/modeling/architectures/tood.py b/ppdet/modeling/architectures/tood.py new file mode 100644 index 0000000000000000000000000000000000000000..157ec6f3a581a1a4f14b915553c397213c29dcd2 --- /dev/null +++ b/ppdet/modeling/architectures/tood.py @@ -0,0 +1,77 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from ppdet.core.workspace import register, create +from .meta_arch import BaseArch + +__all__ = ['TOOD'] + + +@register +class TOOD(BaseArch): + """ + TOOD: Task-aligned One-stage Object Detection, see https://arxiv.org/abs/2108.07755 + Args: + backbone (nn.Layer): backbone instance + neck (nn.Layer): 'FPN' instance + head (nn.Layer): 'TOODHead' instance + """ + + __category__ = 'architecture' + + def __init__(self, backbone, neck, head): + super(TOOD, self).__init__() + self.backbone = backbone + self.neck = neck + self.head = head + + @classmethod + def from_config(cls, cfg, *args, **kwargs): + backbone = create(cfg['backbone']) + + kwargs = {'input_shape': backbone.out_shape} + neck = create(cfg['neck'], **kwargs) + + kwargs = {'input_shape': neck.out_shape} + head = create(cfg['head'], **kwargs) + + return { + 'backbone': backbone, + 'neck': neck, + "head": head, + } + + def _forward(self): + body_feats = self.backbone(self.inputs) + fpn_feats = self.neck(body_feats) + head_outs = self.head(fpn_feats) + if not self.training: + bboxes, bbox_num = self.head.post_process( + head_outs, self.inputs['im_shape'], self.inputs['scale_factor']) + return bboxes, bbox_num + else: + loss = self.head.get_loss(head_outs, self.inputs) + return loss + + def get_loss(self): + return self._forward() + + def get_pred(self): + bbox_pred, bbox_num = self._forward() + output = {'bbox': bbox_pred, 'bbox_num': bbox_num} + return output diff --git a/ppdet/modeling/assigners/__init__.py b/ppdet/modeling/assigners/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..108ac7ecde6f80ec14a43eceb2e875b90e6a82a8 --- /dev/null +++ b/ppdet/modeling/assigners/__init__.py @@ -0,0 +1,21 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from . import utils +from . import task_aligned_assigner +from . import atss_assigner + +from .utils import * +from .task_aligned_assigner import * +from .atss_assigner import * diff --git a/ppdet/modeling/assigners/atss_assigner.py b/ppdet/modeling/assigners/atss_assigner.py new file mode 100644 index 0000000000000000000000000000000000000000..8d8e555ae024d75f4bdfe26cbf39e7f954f1a5b0 --- /dev/null +++ b/ppdet/modeling/assigners/atss_assigner.py @@ -0,0 +1,198 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+import paddle
+import paddle.nn as nn
+import paddle.nn.functional as F
+
+from ppdet.core.workspace import register
+from ..ops import iou_similarity
+from ..bbox_utils import bbox_center
+from .utils import (pad_gt, check_points_inside_bboxes, compute_max_iou_anchor,
+                    compute_max_iou_gt)
+
+
+@register
+class ATSSAssigner(nn.Layer):
+    """Bridging the Gap Between Anchor-based and Anchor-free Detection
+    via Adaptive Training Sample Selection
+    """
+    __shared__ = ['num_classes']
+
+    def __init__(self,
+                 topk=9,
+                 num_classes=80,
+                 force_gt_matching=False,
+                 eps=1e-9):
+        super(ATSSAssigner, self).__init__()
+        self.topk = topk
+        self.num_classes = num_classes
+        self.force_gt_matching = force_gt_matching
+        self.eps = eps
+
+    def _gather_topk_pyramid(self, gt2anchor_distances, num_anchors_list,
+                             pad_gt_mask):
+        pad_gt_mask = pad_gt_mask.tile([1, 1, self.topk]).astype(paddle.bool)
+        gt2anchor_distances_list = paddle.split(
+            gt2anchor_distances, num_anchors_list, axis=-1)
+        num_anchors_index = np.cumsum(num_anchors_list).tolist()
+        num_anchors_index = [0, ] + num_anchors_index[:-1]
+        is_in_topk_list = []
+        topk_idxs_list = []
+        for distances, anchors_index in zip(gt2anchor_distances_list,
+                                            num_anchors_index):
+            num_anchors = distances.shape[-1]
+            topk_metrics, topk_idxs = paddle.topk(
+                distances, self.topk, axis=-1, largest=False)
+            topk_idxs_list.append(topk_idxs + anchors_index)
+            topk_idxs = paddle.where(pad_gt_mask, topk_idxs,
+                                     paddle.zeros_like(topk_idxs))
+            is_in_topk = F.one_hot(topk_idxs, num_anchors).sum(axis=-2)
+            is_in_topk = paddle.where(is_in_topk > 1,
+                                      paddle.zeros_like(is_in_topk), is_in_topk)
+            is_in_topk_list.append(is_in_topk.astype(gt2anchor_distances.dtype))
+        is_in_topk_list = paddle.concat(is_in_topk_list, axis=-1)
+        topk_idxs_list = paddle.concat(topk_idxs_list, axis=-1)
+        return is_in_topk_list, topk_idxs_list
+
+    @paddle.no_grad()
+    def forward(self,
+                anchor_bboxes,
+                num_anchors_list,
+                gt_labels,
+                gt_bboxes,
+                bg_index,
+                gt_scores=None):
+        r"""The assignment is done in the following steps
+        1. compute the iou between all bboxes (over all pyramid levels) and gt
+        2. compute the center distance between all bboxes and gt
+        3. on each pyramid level, for each gt, select the k bboxes whose
+           centers are closest to the gt center, so k*l bboxes in total are
+           selected as candidates for each gt
+        4. get the corresponding iou for these candidates, compute their
+           mean and std, and set mean + std as the iou threshold
+        5. select the candidates whose iou is greater than or equal to the
+           threshold as positives
+        6. restrict the positive samples' centers to lie inside the gt
+        7. if an anchor box is assigned to multiple gts, the one with the
+           highest iou is selected.
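+        Note: steps 4-5 give each gt an adaptive iou threshold (the mean plus
+        the standard deviation of its candidate ious), so no fixed global
+        threshold hyper-parameter is needed.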
+        Args:
+            anchor_bboxes (Tensor, float32): pre-defined anchors, shape(L, 4),
+                "xmin, ymin, xmax, ymax" format
+            num_anchors_list (List): num of anchors in each level
+            gt_labels (Tensor|List[Tensor], int64): Label of gt_bboxes, shape(B, n, 1)
+            gt_bboxes (Tensor|List[Tensor], float32): Ground truth bboxes, shape(B, n, 4)
+            bg_index (int): background index
+            gt_scores (Tensor|List[Tensor]|None, float32): Score of gt_bboxes,
+                shape(B, n, 1); if None, it is initialized from the one-hot label
+        Returns:
+            assigned_labels (Tensor): (B, L)
+            assigned_bboxes (Tensor): (B, L, 4)
+            assigned_scores (Tensor): (B, L, C)
+        """
+        gt_labels, gt_bboxes, pad_gt_scores, pad_gt_mask = pad_gt(
+            gt_labels, gt_bboxes, gt_scores)
+        assert gt_labels.ndim == gt_bboxes.ndim and \
+            gt_bboxes.ndim == 3
+
+        num_anchors, _ = anchor_bboxes.shape
+        batch_size, num_max_boxes, _ = gt_bboxes.shape
+
+        # 1. compute iou between gt and anchor bbox, [B, n, L]
+        ious = iou_similarity(gt_bboxes.reshape([-1, 4]), anchor_bboxes)
+        ious = ious.reshape([batch_size, -1, num_anchors])
+
+        # 2. compute center distance between all anchors and gt, [B, n, L]
+        gt_centers = bbox_center(gt_bboxes.reshape([-1, 4])).unsqueeze(1)
+        anchor_centers = bbox_center(anchor_bboxes)
+        gt2anchor_distances = (gt_centers - anchor_centers.unsqueeze(0)) \
+            .norm(2, axis=-1).reshape([batch_size, -1, num_anchors])
+
+        # 3. on each pyramid level, select the topk closest candidates
+        #    based on the center distance, [B, n, L]
+        is_in_topk, topk_idxs = self._gather_topk_pyramid(
+            gt2anchor_distances, num_anchors_list, pad_gt_mask)
+
+        # 4. get the corresponding iou for these candidates and compute the
+        #    mean and std, 5. set mean + std as the iou threshold
+        iou_candidates = ious * is_in_topk
+        iou_threshold = paddle.index_sample(
+            iou_candidates.flatten(stop_axis=-2),
+            topk_idxs.flatten(stop_axis=-2))
+        iou_threshold = iou_threshold.reshape([batch_size, num_max_boxes, -1])
+        iou_threshold = iou_threshold.mean(axis=-1, keepdim=True) + \
+            iou_threshold.std(axis=-1, keepdim=True)
+        is_in_topk = paddle.where(
+            iou_candidates > iou_threshold.tile([1, 1, num_anchors]),
+            is_in_topk, paddle.zeros_like(is_in_topk))
+
+        # 6. check that the positive samples' centers are inside the gt, [B, n, L]
+        is_in_gts = check_points_inside_bboxes(anchor_centers, gt_bboxes)
+
+        # select positive samples, [B, n, L]
+        mask_positive = is_in_topk * is_in_gts * pad_gt_mask
+
+        # 7. if an anchor box is assigned to multiple gts,
+        #    the one with the highest iou is selected.
+        mask_positive_sum = mask_positive.sum(axis=-2)
+        if mask_positive_sum.max() > 1:
+            mask_multiple_gts = (mask_positive_sum.unsqueeze(1) > 1).tile(
+                [1, num_max_boxes, 1])
+            is_max_iou = compute_max_iou_anchor(ious)
+            mask_positive = paddle.where(mask_multiple_gts, is_max_iou,
+                                         mask_positive)
+            mask_positive_sum = mask_positive.sum(axis=-2)
+        # 8. make sure every gt_bbox matches an anchor
+        if self.force_gt_matching:
+            is_max_iou = compute_max_iou_gt(ious) * pad_gt_mask
+            mask_max_iou = (is_max_iou.sum(-2, keepdim=True) == 1).tile(
+                [1, num_max_boxes, 1])
+            mask_positive = paddle.where(mask_max_iou, is_max_iou,
+                                         mask_positive)
+            mask_positive_sum = mask_positive.sum(axis=-2)
+        assigned_gt_index = mask_positive.argmax(axis=-2)
+        assert mask_positive_sum.max() == 1, \
+            ("each anchor should be assigned to at most one gt, but received a value not equal to 1. 
" + "Received: %f" % mask_positive_sum.max().item()) + + # assigned target + batch_ind = paddle.arange( + end=batch_size, dtype=gt_labels.dtype).unsqueeze(-1) + assigned_gt_index = assigned_gt_index + batch_ind * num_max_boxes + assigned_labels = paddle.gather( + gt_labels.flatten(), assigned_gt_index.flatten(), axis=0) + assigned_labels = assigned_labels.reshape([batch_size, num_anchors]) + assigned_labels = paddle.where( + mask_positive_sum > 0, assigned_labels, + paddle.full_like(assigned_labels, bg_index)) + + assigned_bboxes = paddle.gather( + gt_bboxes.reshape([-1, 4]), assigned_gt_index.flatten(), axis=0) + assigned_bboxes = assigned_bboxes.reshape([batch_size, num_anchors, 4]) + + assigned_scores = F.one_hot(assigned_labels, self.num_classes) + if gt_scores is not None: + gather_scores = paddle.gather( + pad_gt_scores.flatten(), assigned_gt_index.flatten(), axis=0) + gather_scores = gather_scores.reshape([batch_size, num_anchors]) + gather_scores = paddle.where(mask_positive_sum > 0, gather_scores, + paddle.zeros_like(gather_scores)) + assigned_scores *= gather_scores.unsqueeze(-1) + + return assigned_labels, assigned_bboxes, assigned_scores diff --git a/ppdet/modeling/assigners/task_aligned_assigner.py b/ppdet/modeling/assigners/task_aligned_assigner.py new file mode 100644 index 0000000000000000000000000000000000000000..ed43c4936bbf3f3ea8b5ebad0c283d7d46238b0e --- /dev/null +++ b/ppdet/modeling/assigners/task_aligned_assigner.py @@ -0,0 +1,147 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import paddle +import paddle.nn as nn +import paddle.nn.functional as F + +from ppdet.core.workspace import register +from ..bbox_utils import iou_similarity +from .utils import (pad_gt, gather_topk_anchors, check_points_inside_bboxes, + compute_max_iou_anchor) + + +@register +class TaskAlignedAssigner(nn.Layer): + """TOOD: Task-aligned One-stage Object Detection + """ + + def __init__(self, topk=13, alpha=1.0, beta=6.0, eps=1e-9): + super(TaskAlignedAssigner, self).__init__() + self.topk = topk + self.alpha = alpha + self.beta = beta + self.eps = eps + + @paddle.no_grad() + def forward(self, + pred_scores, + pred_bboxes, + anchor_points, + gt_labels, + gt_bboxes, + bg_index, + gt_scores=None): + r"""The assignment is done in following steps + 1. compute alignment metric between all bbox (bbox of all pyramid levels) and gt + 2. select top-k bbox as candidates for each gt + 3. limit the positive sample's center in gt (because the anchor-free detector + only can predict positive distance) + 4. if an anchor box is assigned to multiple gts, the one with the + highest iou will be selected. 
+        Args:
+            pred_scores (Tensor, float32): predicted class probability, shape(B, L, C)
+            pred_bboxes (Tensor, float32): predicted bounding boxes, shape(B, L, 4)
+            anchor_points (Tensor, float32): pre-defined anchors, shape(L, 2), "cxcy" format
+            gt_labels (Tensor|List[Tensor], int64): Label of gt_bboxes, shape(B, n, 1)
+            gt_bboxes (Tensor|List[Tensor], float32): Ground truth bboxes, shape(B, n, 4)
+            bg_index (int): background index
+            gt_scores (Tensor|List[Tensor]|None, float32): Score of gt_bboxes,
+                shape(B, n, 1); if None, it is initialized from the one-hot label
+        Returns:
+            assigned_labels (Tensor): (B, L)
+            assigned_bboxes (Tensor): (B, L, 4)
+            assigned_scores (Tensor): (B, L, C)
+        """
+        assert pred_scores.ndim == pred_bboxes.ndim
+
+        gt_labels, gt_bboxes, pad_gt_scores, pad_gt_mask = pad_gt(
+            gt_labels, gt_bboxes, gt_scores)
+        assert gt_labels.ndim == gt_bboxes.ndim and \
+            gt_bboxes.ndim == 3
+
+        batch_size, num_anchors, num_classes = pred_scores.shape
+        _, num_max_boxes, _ = gt_bboxes.shape
+
+        # compute iou between gt and pred bboxes, [B, n, L]
+        ious = iou_similarity(gt_bboxes, pred_bboxes)
+        # gather the class scores of the pred bboxes
+        pred_scores = pred_scores.transpose([0, 2, 1])
+        batch_ind = paddle.arange(
+            end=batch_size, dtype=gt_labels.dtype).unsqueeze(-1)
+        gt_labels_ind = paddle.stack(
+            [batch_ind.tile([1, num_max_boxes]), gt_labels.squeeze(-1)],
+            axis=-1)
+        bbox_cls_scores = paddle.gather_nd(pred_scores, gt_labels_ind)
+        # compute the alignment metric, [B, n, L]
+        alignment_metrics = bbox_cls_scores.pow(self.alpha) * ious.pow(
+            self.beta)
+
+        # check that the positive samples' centers are inside the gt, [B, n, L]
+        is_in_gts = check_points_inside_bboxes(anchor_points, gt_bboxes)
+
+        # select the topk pred bboxes with the largest alignment metrics
+        # as candidates for each gt, [B, n, L]
+        is_in_topk = gather_topk_anchors(
+            alignment_metrics * is_in_gts,
+            self.topk,
+            topk_mask=pad_gt_mask.tile([1, 1, self.topk]).astype(paddle.bool))
+
+        # select positive samples, [B, n, L]
+        mask_positive = is_in_topk * is_in_gts * pad_gt_mask
+
+        # if an anchor box is assigned to multiple gts,
+        # the one with the highest iou is selected, [B, n, L]
+        mask_positive_sum = mask_positive.sum(axis=-2)
+        if mask_positive_sum.max() > 1:
+            mask_multiple_gts = (mask_positive_sum.unsqueeze(1) > 1).tile(
+                [1, num_max_boxes, 1])
+            is_max_iou = compute_max_iou_anchor(ious)
+            mask_positive = paddle.where(mask_multiple_gts, is_max_iou,
+                                         mask_positive)
+            mask_positive_sum = mask_positive.sum(axis=-2)
+        assigned_gt_index = mask_positive.argmax(axis=-2)
+        assert mask_positive_sum.max() == 1, \
+            ("each anchor should be assigned to at most one gt, but received a value not equal to 1. 
" + "Received: %f" % mask_positive_sum.max().item()) + + # assigned target + assigned_gt_index = assigned_gt_index + batch_ind * num_max_boxes + assigned_labels = paddle.gather( + gt_labels.flatten(), assigned_gt_index.flatten(), axis=0) + assigned_labels = assigned_labels.reshape([batch_size, num_anchors]) + assigned_labels = paddle.where( + mask_positive_sum > 0, assigned_labels, + paddle.full_like(assigned_labels, bg_index)) + + assigned_bboxes = paddle.gather( + gt_bboxes.reshape([-1, 4]), assigned_gt_index.flatten(), axis=0) + assigned_bboxes = assigned_bboxes.reshape([batch_size, num_anchors, 4]) + + assigned_scores = F.one_hot(assigned_labels, num_classes) + # rescale alignment metrics + alignment_metrics *= mask_positive + max_metrics_per_instance = alignment_metrics.max(axis=-1, keepdim=True) + max_ious_per_instance = (ious * mask_positive).max(axis=-1, + keepdim=True) + alignment_metrics = alignment_metrics / ( + max_metrics_per_instance + self.eps) * max_ious_per_instance + alignment_metrics = alignment_metrics.max(-2).unsqueeze(-1) + assigned_scores = assigned_scores * alignment_metrics + + return assigned_labels, assigned_bboxes, assigned_scores diff --git a/ppdet/modeling/assigners/utils.py b/ppdet/modeling/assigners/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..3448d9d8ae5825bfab2814970e8c52b7ba54548b --- /dev/null +++ b/ppdet/modeling/assigners/utils.py @@ -0,0 +1,149 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import paddle +import paddle.nn.functional as F + + +def pad_gt(gt_labels, gt_bboxes, gt_scores=None): + r""" Pad 0 in gt_labels and gt_bboxes. + Args: + gt_labels (Tensor|List[Tensor], int64): Label of gt_bboxes, + shape is [B, n, 1] or [[n_1, 1], [n_2, 1], ...], here n = sum(n_i) + gt_bboxes (Tensor|List[Tensor], float32): Ground truth bboxes, + shape is [B, n, 4] or [[n_1, 4], [n_2, 4], ...], here n = sum(n_i) + gt_scores (Tensor|List[Tensor]|None, float32): Score of gt_bboxes, + shape is [B, n, 1] or [[n_1, 4], [n_2, 4], ...], here n = sum(n_i) + Returns: + pad_gt_labels (Tensor, int64): shape[B, n, 1] + pad_gt_bboxes (Tensor, float32): shape[B, n, 4] + pad_gt_scores (Tensor, float32): shape[B, n, 1] + pad_gt_mask (Tensor, float32): shape[B, n, 1], 1 means bbox, 0 means no bbox + """ + if isinstance(gt_labels, paddle.Tensor) and isinstance(gt_bboxes, + paddle.Tensor): + assert gt_labels.ndim == gt_bboxes.ndim and \ + gt_bboxes.ndim == 3 + pad_gt_mask = ( + gt_bboxes.sum(axis=-1, keepdim=True) > 0).astype(gt_bboxes.dtype) + if gt_scores is None: + gt_scores = pad_gt_mask.clone() + assert gt_labels.ndim == gt_scores.ndim + + return gt_labels, gt_bboxes, gt_scores, pad_gt_mask + elif isinstance(gt_labels, list) and isinstance(gt_bboxes, list): + assert len(gt_labels) == len(gt_bboxes), \ + 'The number of `gt_labels` and `gt_bboxes` is not equal. 
' + num_max_boxes = max([len(a) for a in gt_bboxes]) + batch_size = len(gt_bboxes) + # pad label and bbox + pad_gt_labels = paddle.zeros( + [batch_size, num_max_boxes, 1], dtype=gt_labels[0].dtype) + pad_gt_bboxes = paddle.zeros( + [batch_size, num_max_boxes, 4], dtype=gt_bboxes[0].dtype) + pad_gt_scores = paddle.zeros( + [batch_size, num_max_boxes, 1], dtype=gt_bboxes[0].dtype) + pad_gt_mask = paddle.zeros( + [batch_size, num_max_boxes, 1], dtype=gt_bboxes[0].dtype) + for i, (label, bbox) in enumerate(zip(gt_labels, gt_bboxes)): + if len(label) > 0 and len(bbox) > 0: + pad_gt_labels[i, :len(label)] = label + pad_gt_bboxes[i, :len(bbox)] = bbox + pad_gt_mask[i, :len(bbox)] = 1. + if gt_scores is not None: + pad_gt_scores[i, :len(gt_scores[i])] = gt_scores[i] + if gt_scores is None: + pad_gt_scores = pad_gt_mask.clone() + return pad_gt_labels, pad_gt_bboxes, pad_gt_scores, pad_gt_mask + else: + raise ValueError('The input `gt_labels` or `gt_bboxes` is invalid! ') + + +def gather_topk_anchors(metrics, topk, largest=True, topk_mask=None, eps=1e-9): + r""" + Args: + metrics (Tensor, float32): shape[B, n, L], n: num_gts, L: num_anchors + topk (int): The number of top elements to look for along the axis. + largest (bool) : largest is a flag, if set to true, + algorithm will sort by descending order, otherwise sort by + ascending order. Default: True + topk_mask (Tensor, bool|None): shape[B, n, topk], ignore bbox mask, + Default: None + eps (float): Default: 1e-9 + Returns: + is_in_topk (Tensor, float32): shape[B, n, L], value=1. means selected + """ + num_anchors = metrics.shape[-1] + topk_metrics, topk_idxs = paddle.topk( + metrics, topk, axis=-1, largest=largest) + if topk_mask is None: + topk_mask = (topk_metrics.max(axis=-1, keepdim=True) > eps).tile( + [1, 1, topk]) + topk_idxs = paddle.where(topk_mask, topk_idxs, paddle.zeros_like(topk_idxs)) + is_in_topk = F.one_hot(topk_idxs, num_anchors).sum(axis=-2) + is_in_topk = paddle.where(is_in_topk > 1, + paddle.zeros_like(is_in_topk), is_in_topk) + return is_in_topk.astype(metrics.dtype) + + +def check_points_inside_bboxes(points, bboxes, eps=1e-9): + r""" + Args: + points (Tensor, float32): shape[L, 2], "xy" format, L: num_anchors + bboxes (Tensor, float32): shape[B, n, 4], "xmin, ymin, xmax, ymax" format + eps (float): Default: 1e-9 + Returns: + is_in_bboxes (Tensor, float32): shape[B, n, L], value=1. means selected + """ + points = points.unsqueeze([0, 1]) + x, y = points.chunk(2, axis=-1) + xmin, ymin, xmax, ymax = bboxes.unsqueeze(2).chunk(4, axis=-1) + l = x - xmin + t = y - ymin + r = xmax - x + b = ymax - y + bbox_ltrb = paddle.concat([l, t, r, b], axis=-1) + return (bbox_ltrb.min(axis=-1) > eps).astype(bboxes.dtype) + + +def compute_max_iou_anchor(ious): + r""" + For each anchor, find the GT with the largest IOU. + Args: + ious (Tensor, float32): shape[B, n, L], n: num_gts, L: num_anchors + Returns: + is_max_iou (Tensor, float32): shape[B, n, L], value=1. means selected + """ + num_max_boxes = ious.shape[-2] + max_iou_index = ious.argmax(axis=-2) + is_max_iou = F.one_hot(max_iou_index, num_max_boxes).transpose([0, 2, 1]) + return is_max_iou.astype(ious.dtype) + + +def compute_max_iou_gt(ious): + r""" + For each GT, find the anchor with the largest IOU. + Args: + ious (Tensor, float32): shape[B, n, L], n: num_gts, L: num_anchors + Returns: + is_max_iou (Tensor, float32): shape[B, n, L], value=1. 
means selected + """ + num_anchors = ious.shape[-1] + max_iou_index = ious.argmax(axis=-1) + is_max_iou = F.one_hot(max_iou_index, num_anchors) + return is_max_iou.astype(ious.dtype) diff --git a/ppdet/modeling/bbox_utils.py b/ppdet/modeling/bbox_utils.py index df8eda94d9aada710e3e6a9c934c0ffd3e082f21..05d7cef2fc9a04b690fd79b4b84c744bb28ec2b3 100644 --- a/ppdet/modeling/bbox_utils.py +++ b/ppdet/modeling/bbox_utils.py @@ -645,3 +645,15 @@ def distance2bbox(points, distance, max_shape=None): x2 = x2.clip(min=0, max=max_shape[1]) y2 = y2.clip(min=0, max=max_shape[0]) return paddle.stack([x1, y1, x2, y2], -1) + + +def bbox_center(boxes): + """Get bbox centers from boxes. + Args: + boxes (Tensor): boxes with shape (N, 4), "xmin, ymin, xmax, ymax" format. + Returns: + Tensor: boxes centers with shape (N, 2), "cx, cy" format. + """ + boxes_cx = (boxes[:, 0] + boxes[:, 2]) / 2 + boxes_cy = (boxes[:, 1] + boxes[:, 3]) / 2 + return paddle.stack([boxes_cx, boxes_cy], axis=-1) diff --git a/ppdet/modeling/heads/__init__.py b/ppdet/modeling/heads/__init__.py index dd6b1dcc2249e65750b7cbaa9ae9cae09486921a..55b9d907dd7b0e1a6260a559ba4ca85b03de3cc1 100644 --- a/ppdet/modeling/heads/__init__.py +++ b/ppdet/modeling/heads/__init__.py @@ -29,6 +29,7 @@ from . import gfl_head from . import pico_head from . import detr_head from . import sparsercnn_head +from . import tood_head from .bbox_head import * from .mask_head import * @@ -47,3 +48,4 @@ from .gfl_head import * from .pico_head import * from .detr_head import * from .sparsercnn_head import * +from .tood_head import * diff --git a/ppdet/modeling/heads/tood_head.py b/ppdet/modeling/heads/tood_head.py new file mode 100644 index 0000000000000000000000000000000000000000..b244dfdc00eeffdab6af10b46afe57095fb4f3d7 --- /dev/null +++ b/ppdet/modeling/heads/tood_head.py @@ -0,0 +1,418 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import paddle +import paddle.nn as nn +import paddle.nn.functional as F +from paddle import ParamAttr +from paddle.nn.initializer import Constant + +from ppdet.core.workspace import register +from ..initializer import normal_, constant_, bias_init_with_prob +from ppdet.modeling.bbox_utils import bbox_center +from ..losses import GIoULoss +from paddle.vision.ops import deform_conv2d +from ppdet.modeling.layers import ConvNormLayer + + +class ScaleReg(nn.Layer): + """ + Parameter for scaling the regression outputs. 
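+    One instance is created per FPN level (see TOODHead.scales_regs), so
+    each level learns its own multiplier for the regressed distances.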
+ """ + + def __init__(self, init_scale=1.): + super(ScaleReg, self).__init__() + self.scale_reg = self.create_parameter( + shape=[1], + attr=ParamAttr(initializer=Constant(value=init_scale)), + dtype="float32") + + def forward(self, inputs): + out = inputs * self.scale_reg + return out + + +class TaskDecomposition(nn.Layer): + def __init__( + self, + feat_channels, + stacked_convs, + la_down_rate=8, + norm_type='gn', + norm_groups=32, ): + super(TaskDecomposition, self).__init__() + self.feat_channels = feat_channels + self.stacked_convs = stacked_convs + self.norm_type = norm_type + self.norm_groups = norm_groups + self.in_channels = self.feat_channels * self.stacked_convs + self.la_conv1 = nn.Conv2D(self.in_channels, + self.in_channels // la_down_rate, 1) + self.la_conv2 = nn.Conv2D(self.in_channels // la_down_rate, + self.stacked_convs, 1) + + self.reduction_conv = ConvNormLayer( + self.in_channels, + self.feat_channels, + filter_size=1, + stride=1, + norm_type=self.norm_type, + norm_groups=self.norm_groups) + + self._init_weights() + + def _init_weights(self): + normal_(self.la_conv1.weight, std=0.001) + normal_(self.la_conv2.weight, std=0.001) + + def forward(self, feat, avg_feat=None): + b, _, h, w = feat.shape + if avg_feat is None: + avg_feat = F.adaptive_avg_pool2d(feat, (1, 1)) + weight = F.relu(self.la_conv1(avg_feat)) + weight = F.sigmoid(self.la_conv2(weight)) + + # here new_conv_weight = layer_attention_weight * conv_weight + # in order to save memory and FLOPs. + conv_weight = weight.reshape([b, 1, self.stacked_convs, 1]) * \ + self.reduction_conv.conv.weight.reshape( + [1, self.feat_channels, self.stacked_convs, self.feat_channels]) + conv_weight = conv_weight.reshape( + [b, self.feat_channels, self.in_channels]) + feat = feat.reshape([b, self.in_channels, h * w]) + feat = paddle.bmm(conv_weight, feat).reshape( + [b, self.feat_channels, h, w]) + if self.norm_type is not None: + feat = self.reduction_conv.norm(feat) + feat = F.relu(feat) + return feat + + +@register +class TOODHead(nn.Layer): + __inject__ = ['nms', 'static_assigner', 'assigner'] + __shared__ = ['num_classes'] + + def __init__(self, + num_classes=80, + feat_channels=256, + stacked_convs=6, + fpn_strides=(8, 16, 32, 64, 128), + grid_cell_scale=8, + grid_cell_offset=0.5, + norm_type='gn', + norm_groups=32, + static_assigner_epoch=4, + use_align_head=True, + loss_weight={ + 'class': 1.0, + 'bbox': 1.0, + 'iou': 2.0, + }, + nms='MultiClassNMS', + static_assigner='ATSSAssigner', + assigner='TaskAlignedAssigner'): + super(TOODHead, self).__init__() + self.num_classes = num_classes + self.feat_channels = feat_channels + self.stacked_convs = stacked_convs + self.fpn_strides = fpn_strides + self.grid_cell_scale = grid_cell_scale + self.grid_cell_offset = grid_cell_offset + self.static_assigner_epoch = static_assigner_epoch + self.use_align_head = use_align_head + self.nms = nms + self.static_assigner = static_assigner + self.assigner = assigner + self.loss_weight = loss_weight + self.giou_loss = GIoULoss() + + self.inter_convs = nn.LayerList() + for i in range(self.stacked_convs): + self.inter_convs.append( + ConvNormLayer( + self.feat_channels, + self.feat_channels, + filter_size=3, + stride=1, + norm_type=norm_type, + norm_groups=norm_groups)) + + self.cls_decomp = TaskDecomposition( + self.feat_channels, + self.stacked_convs, + self.stacked_convs * 8, + norm_type=norm_type, + norm_groups=norm_groups) + self.reg_decomp = TaskDecomposition( + self.feat_channels, + self.stacked_convs, + self.stacked_convs * 8, + 
norm_type=norm_type, + norm_groups=norm_groups) + + self.tood_cls = nn.Conv2D( + self.feat_channels, self.num_classes, 3, padding=1) + self.tood_reg = nn.Conv2D(self.feat_channels, 4, 3, padding=1) + + if self.use_align_head: + self.cls_prob_conv1 = nn.Conv2D(self.feat_channels * + self.stacked_convs, + self.feat_channels // 4, 1) + self.cls_prob_conv2 = nn.Conv2D( + self.feat_channels // 4, 1, 3, padding=1) + self.reg_offset_conv1 = nn.Conv2D(self.feat_channels * + self.stacked_convs, + self.feat_channels // 4, 1) + self.reg_offset_conv2 = nn.Conv2D( + self.feat_channels // 4, 4 * 2, 3, padding=1) + + self.scales_regs = nn.LayerList([ScaleReg() for _ in self.fpn_strides]) + + self._init_weights() + + @classmethod + def from_config(cls, cfg, input_shape): + return { + 'feat_channels': input_shape[0].channels, + 'fpn_strides': [i.stride for i in input_shape], + } + + def _init_weights(self): + bias_cls = bias_init_with_prob(0.01) + normal_(self.tood_cls.weight, std=0.01) + constant_(self.tood_cls.bias, bias_cls) + normal_(self.tood_reg.weight, std=0.01) + + if self.use_align_head: + normal_(self.cls_prob_conv1.weight, std=0.01) + normal_(self.cls_prob_conv2.weight, std=0.01) + constant_(self.cls_prob_conv2.bias, bias_cls) + normal_(self.reg_offset_conv1.weight, std=0.001) + normal_(self.reg_offset_conv2.weight, std=0.001) + constant_(self.reg_offset_conv2.bias) + + def _generate_anchors(self, feats): + anchors, num_anchors_list = [], [] + stride_tensor_list = [] + for feat, stride in zip(feats, self.fpn_strides): + _, _, h, w = feat.shape + cell_half_size = self.grid_cell_scale * stride * 0.5 + shift_x = (paddle.arange(end=w) + self.grid_cell_offset) * stride + shift_y = (paddle.arange(end=h) + self.grid_cell_offset) * stride + shift_y, shift_x = paddle.meshgrid(shift_y, shift_x) + anchor = paddle.stack( + [ + shift_x - cell_half_size, shift_y - cell_half_size, + shift_x + cell_half_size, shift_y + cell_half_size + ], + axis=-1) + anchors.append(anchor.reshape([-1, 4])) + num_anchors_list.append(len(anchors[-1])) + stride_tensor_list.append( + paddle.full([num_anchors_list[-1], 1], stride)) + return anchors, num_anchors_list, stride_tensor_list + + @staticmethod + def _batch_distance2bbox(points, distance, max_shapes=None): + """Decode distance prediction to bounding box. + Args: + points (Tensor): [B, l, 2] + distance (Tensor): [B, l, 4] + max_shapes (tuple): [B, 2], "h w" format, Shape of the image. + Returns: + Tensor: Decoded bboxes. + """ + x1 = points[:, :, 0] - distance[:, :, 0] + y1 = points[:, :, 1] - distance[:, :, 1] + x2 = points[:, :, 0] + distance[:, :, 2] + y2 = points[:, :, 1] + distance[:, :, 3] + bboxes = paddle.stack([x1, y1, x2, y2], -1) + if max_shapes is not None: + out_bboxes = [] + for bbox, max_shape in zip(bboxes, max_shapes): + bbox[:, 0] = bbox[:, 0].clip(min=0, max=max_shape[1]) + bbox[:, 1] = bbox[:, 1].clip(min=0, max=max_shape[0]) + bbox[:, 2] = bbox[:, 2].clip(min=0, max=max_shape[1]) + bbox[:, 3] = bbox[:, 3].clip(min=0, max=max_shape[0]) + out_bboxes.append(bbox) + out_bboxes = paddle.stack(out_bboxes) + return out_bboxes + return bboxes + + @staticmethod + def _deform_sampling(feat, offset): + """ Sampling the feature according to offset. 
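+        Since the 1x1 depthwise weight below is fixed to all ones, the
+        deformable conv reduces to bilinearly sampling each channel of
+        `feat` at its predicted offset location.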
+        Args:
+            feat (Tensor): input feature map
+            offset (Tensor): spatial offsets for feature sampling
+        """
+        # it is an equivalent implementation of bilinear interpolation;
+        # you can also use F.grid_sample instead
+        c = feat.shape[1]
+        weight = paddle.ones([c, 1, 1, 1])
+        y = deform_conv2d(feat, offset, weight, deformable_groups=c, groups=c)
+        return y
+
+    def forward(self, feats):
+        assert len(feats) == len(self.fpn_strides), \
+            "The size of feats is not equal to the size of fpn_strides"
+
+        anchors, num_anchors_list, stride_tensor_list = self._generate_anchors(
+            feats)
+        cls_score_list, bbox_pred_list = [], []
+        for feat, scale_reg, anchor, stride in zip(feats, self.scales_regs,
+                                                   anchors, self.fpn_strides):
+            b, _, h, w = feat.shape
+            inter_feats = []
+            for inter_conv in self.inter_convs:
+                feat = F.relu(inter_conv(feat))
+                inter_feats.append(feat)
+            feat = paddle.concat(inter_feats, axis=1)
+
+            # task decomposition
+            avg_feat = F.adaptive_avg_pool2d(feat, (1, 1))
+            cls_feat = self.cls_decomp(feat, avg_feat)
+            reg_feat = self.reg_decomp(feat, avg_feat)
+
+            # cls prediction and alignment
+            cls_logits = self.tood_cls(cls_feat)
+            if self.use_align_head:
+                cls_prob = F.relu(self.cls_prob_conv1(feat))
+                cls_prob = F.sigmoid(self.cls_prob_conv2(cls_prob))
+                cls_score = (F.sigmoid(cls_logits) * cls_prob).sqrt()
+            else:
+                cls_score = F.sigmoid(cls_logits)
+            cls_score_list.append(cls_score.flatten(2).transpose([0, 2, 1]))
+
+            # reg prediction and alignment
+            reg_dist = scale_reg(self.tood_reg(reg_feat).exp())
+            reg_dist = reg_dist.transpose([0, 2, 3, 1]).reshape([b, -1, 4])
+            anchor_centers = bbox_center(anchor).unsqueeze(0) / stride
+            reg_bbox = self._batch_distance2bbox(
+                anchor_centers.tile([b, 1, 1]), reg_dist)
+            if self.use_align_head:
+                reg_bbox = reg_bbox.reshape([b, h, w, 4]).transpose(
+                    [0, 3, 1, 2])
+                reg_offset = F.relu(self.reg_offset_conv1(feat))
+                reg_offset = self.reg_offset_conv2(reg_offset)
+                bbox_pred = self._deform_sampling(reg_bbox, reg_offset)
+                bbox_pred = bbox_pred.flatten(2).transpose([0, 2, 1])
+            else:
+                bbox_pred = reg_bbox
+
+            if not self.training:
+                bbox_pred *= stride
+            bbox_pred_list.append(bbox_pred)
+        cls_score_list = paddle.concat(cls_score_list, axis=1)
+        bbox_pred_list = paddle.concat(bbox_pred_list, axis=1)
+        anchors = paddle.concat(anchors)
+        anchors.stop_gradient = True
+        stride_tensor_list = paddle.concat(stride_tensor_list).unsqueeze(0)
+        stride_tensor_list.stop_gradient = True
+
+        return cls_score_list, bbox_pred_list, anchors, num_anchors_list, stride_tensor_list
+
+    @staticmethod
+    def _focal_loss(score, label, alpha=0.25, gamma=2.0):
+        weight = (score - label).pow(gamma)
+        if alpha > 0:
+            alpha_t = alpha * label + (1 - alpha) * (1 - label)
+            weight *= alpha_t
+        loss = F.binary_cross_entropy(
+            score, label, weight=weight, reduction='sum')
+        return loss
+
+    def get_loss(self, head_outs, gt_meta):
+        pred_scores, pred_bboxes, anchors, num_anchors_list, stride_tensor_list = head_outs
+        gt_labels = gt_meta['gt_class']
+        gt_bboxes = gt_meta['gt_bbox']
+        # label assignment
+        if gt_meta['epoch_id'] < self.static_assigner_epoch:
+            assigned_labels, assigned_bboxes, assigned_scores = self.static_assigner(
+                anchors,
+                num_anchors_list,
+                gt_labels,
+                gt_bboxes,
+                bg_index=self.num_classes)
+            alpha_l = 0.25
+        else:
+            assigned_labels, assigned_bboxes, assigned_scores = self.assigner(
+                pred_scores.detach(),
+                pred_bboxes.detach() * stride_tensor_list,
+                bbox_center(anchors),
+                gt_labels,
+                gt_bboxes,
+                bg_index=self.num_classes)
+            alpha_l = -1
+
+        # rescale bbox
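+        # predicted boxes are stride-normalized (anchor centers are divided
+        # by stride in forward), so scale the assigned boxes to match
+ 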
assigned_bboxes /= stride_tensor_list + # classification loss + loss_cls = self._focal_loss(pred_scores, assigned_scores, alpha=alpha_l) + # select positive samples mask + mask_positive = (assigned_labels != self.num_classes) + num_pos = mask_positive.astype(paddle.float32).sum() + # bbox regression loss + if num_pos > 0: + bbox_mask = mask_positive.unsqueeze(-1).tile([1, 1, 4]) + pred_bboxes_pos = paddle.masked_select(pred_bboxes, + bbox_mask).reshape([-1, 4]) + assigned_bboxes_pos = paddle.masked_select( + assigned_bboxes, bbox_mask).reshape([-1, 4]) + bbox_weight = paddle.masked_select( + assigned_scores.sum(-1), mask_positive).unsqueeze(-1) + # iou loss + loss_iou = self.giou_loss(pred_bboxes_pos, + assigned_bboxes_pos) * bbox_weight + loss_iou = loss_iou.sum() / bbox_weight.sum() + # l1 loss + loss_l1 = F.l1_loss(pred_bboxes_pos, assigned_bboxes_pos) + else: + loss_iou = paddle.zeros([1]) + loss_l1 = paddle.zeros([1]) + + loss_cls /= assigned_scores.sum().clip(min=1) + loss = self.loss_weight['class'] * loss_cls + self.loss_weight[ + 'iou'] * loss_iou + + return { + 'loss': loss, + 'loss_class': loss_cls, + 'loss_iou': loss_iou, + 'loss_l1': loss_l1 + } + + def post_process(self, head_outs, img_shape, scale_factor): + pred_scores, pred_bboxes, _, _, _ = head_outs + pred_scores = pred_scores.transpose([0, 2, 1]) + + for i in range(len(pred_bboxes)): + pred_bboxes[i, :, 0] = pred_bboxes[i, :, 0].clip( + min=0, max=img_shape[i, 1]) + pred_bboxes[i, :, 1] = pred_bboxes[i, :, 1].clip( + min=0, max=img_shape[i, 0]) + pred_bboxes[i, :, 2] = pred_bboxes[i, :, 2].clip( + min=0, max=img_shape[i, 1]) + pred_bboxes[i, :, 3] = pred_bboxes[i, :, 3].clip( + min=0, max=img_shape[i, 0]) + # scale bbox to origin + scale_factor = scale_factor.flip([1]).tile([1, 2]).unsqueeze(1) + pred_bboxes /= scale_factor + bbox_pred, bbox_num, _ = self.nms(pred_bboxes, pred_scores) + return bbox_pred, bbox_num diff --git a/ppdet/modeling/initializer.py b/ppdet/modeling/initializer.py index 6171270e47a895e7ab5636220cb853c1a738a9b6..ae0bc4287dabe8ded379063a0282a1a99096ed22 100644 --- a/ppdet/modeling/initializer.py +++ b/ppdet/modeling/initializer.py @@ -272,6 +272,12 @@ def conv_init_(module): uniform_(module.bias, -bound, bound) +def bias_init_with_prob(prior_prob=0.01): + """initialize conv/fc bias value according to a given probability value.""" + bias_init = float(-np.log((1 - prior_prob) / prior_prob)) + return bias_init + + @paddle.no_grad() def reset_initialized_parameter(model, include_self=True): """