# simota_assigner.py
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# The code is based on:
# https://github.com/open-mmlab/mmdetection/blob/master/mmdet/core/bbox/assigners/sim_ota_assigner.py

import paddle
import numpy as np
import paddle.nn.functional as F

from ppdet.modeling.losses.varifocal_loss import varifocal_loss
from ppdet.modeling.bbox_utils import batch_bbox_overlaps
from ppdet.core.workspace import register


@register
class SimOTAAssigner(object):
    """Computes matching between predictions and ground truth with SimOTA.

    For every ground truth, the priors whose centers fall inside the gt box
    or inside a ``center_radius * stride`` window around the gt center are
    the matching candidates. A cost matrix (classification cost + regression
    cost + a large penalty for candidates that are not inside both regions)
    is built over those candidates. Each gt is then matched to its
    ``dynamic_k`` cheapest candidates, where ``dynamic_k`` comes from the sum
    of the gt's top-k candidate IoUs. A prior matched to several gts keeps
    only its cheapest gt.

    Args:
        center_radius (int | float, optional): Ground truth center size
            to judge whether a prior is in center. Default 2.5.
        candidate_topk (int, optional): The candidate top-k which used to
            get top-k ious to calculate dynamic-k. Default 10.
        iou_weight (int | float, optional): The scale factor for regression
            iou cost. Default 3.0.
        cls_weight (int | float, optional): The scale factor for classification
            cost. Default 1.0.
        num_classes (int): The num_classes of dataset. It is also used as the
            label of unmatched (background) priors. Default 80.
        use_vfl (bool): Whether to use varifocal_loss when calculating the
            cost matrix. Default True.
    """
    __shared__ = ['num_classes']

    def __init__(self,
                 center_radius=2.5,
                 candidate_topk=10,
                 iou_weight=3.0,
                 cls_weight=1.0,
                 num_classes=80,
                 use_vfl=True):
        self.center_radius = center_radius
        self.candidate_topk = candidate_topk
        self.iou_weight = iou_weight
        self.cls_weight = cls_weight
        self.num_classes = num_classes
        self.use_vfl = use_vfl

    def get_in_gt_and_in_center_info(self, flatten_center_and_stride,
                                     gt_bboxes):
        """Find the priors that are matching candidates for the gts.

        Every prior center is tested against every gt in two ways:
        * Is it strictly inside the gt box?
        * Is it strictly inside a window of half-size
          ``center_radius * stride`` around the gt center?

        Args:
            flatten_center_and_stride (Tensor): Per-prior
                ``[x, y, stride_x, stride_y]`` in columns 0-3.
            gt_bboxes (Tensor): Gt boxes as ``[x1, y1, x2, y2]``,
                shape [num_gt, 4].

        Returns:
            is_in_gts_or_centers_all (Tensor): bool, [num_priors]. True when
                the prior passes either test for at least one gt.
            is_in_gts_or_centers_all_inds (Tensor): Indices of those priors,
                [num_valid].
            is_in_gts_and_centers (Tensor): bool, [num_valid, num_gt]. True
                when the prior passes both tests for that gt.
        """
        num_gt = gt_bboxes.shape[0]

        flatten_x = flatten_center_and_stride[:, 0].unsqueeze(1).tile(
            [1, num_gt])
        flatten_y = flatten_center_and_stride[:, 1].unsqueeze(1).tile(
            [1, num_gt])
        flatten_stride_x = flatten_center_and_stride[:, 2].unsqueeze(1).tile(
            [1, num_gt])
        flatten_stride_y = flatten_center_and_stride[:, 3].unsqueeze(1).tile(
            [1, num_gt])

        # is prior centers in gt bboxes, shape: [n_center, n_gt]
        l_ = flatten_x - gt_bboxes[:, 0]
        t_ = flatten_y - gt_bboxes[:, 1]
        r_ = gt_bboxes[:, 2] - flatten_x
        b_ = gt_bboxes[:, 3] - flatten_y

        deltas = paddle.stack([l_, t_, r_, b_], axis=1)
        is_in_gts = deltas.min(axis=1) > 0
        is_in_gts_all = is_in_gts.sum(axis=1) > 0

        # is prior centers in gt centers
        gt_center_xs = (gt_bboxes[:, 0] + gt_bboxes[:, 2]) / 2.0
        gt_center_ys = (gt_bboxes[:, 1] + gt_bboxes[:, 3]) / 2.0
        ct_bound_l = gt_center_xs - self.center_radius * flatten_stride_x
        ct_bound_t = gt_center_ys - self.center_radius * flatten_stride_y
        ct_bound_r = gt_center_xs + self.center_radius * flatten_stride_x
        ct_bound_b = gt_center_ys + self.center_radius * flatten_stride_y

        cl_ = flatten_x - ct_bound_l
        ct_ = flatten_y - ct_bound_t
        cr_ = ct_bound_r - flatten_x
        cb_ = ct_bound_b - flatten_y

        ct_deltas = paddle.stack([cl_, ct_, cr_, cb_], axis=1)
        is_in_cts = ct_deltas.min(axis=1) > 0
        is_in_cts_all = is_in_cts.sum(axis=1) > 0

        # in any of gts or gt centers, shape: [n_center]
        is_in_gts_or_centers_all = paddle.logical_or(is_in_gts_all,
                                                     is_in_cts_all)

        is_in_gts_or_centers_all_inds = paddle.nonzero(
            is_in_gts_or_centers_all).squeeze(1)

        # both in gts and gt centers, shape: [num_fg, num_gt]
        # NOTE(review): the bool masks are gathered as int and cast back,
        # presumably because paddle.gather did not accept bool inputs in the
        # targeted Paddle version.
        is_in_gts_and_centers = paddle.logical_and(
            paddle.gather(
                is_in_gts.cast('int'), is_in_gts_or_centers_all_inds,
                axis=0).cast('bool'),
            paddle.gather(
                is_in_cts.cast('int'), is_in_gts_or_centers_all_inds,
                axis=0).cast('bool'))
        return is_in_gts_or_centers_all, is_in_gts_or_centers_all_inds, is_in_gts_and_centers

    def dynamic_k_matching(self, cost_matrix, pairwise_ious, num_gt):
        """Match each gt to its ``dynamic_k`` lowest-cost candidates.

        Args:
            cost_matrix (Tensor): Matching cost, [num_valid, num_gt].
            pairwise_ious (Tensor): IoU between candidate predictions and
                gts, [num_valid, num_gt].
            num_gt (int): Number of gts, i.e. columns of the matrices.

        Returns:
            match_gt_inds_to_fg (np.ndarray): Index of the gt matched to each
                foreground candidate, in candidate order.
            match_fg_mask_inmatrix (np.ndarray): bool, [num_valid]. True for
                candidates matched to some gt.
        """
        cost_np = cost_matrix.numpy()
        match_matrix = np.zeros_like(cost_np)
        # select candidate topk ious for dynamic-k calculation
        topk_ious, _ = paddle.topk(
            pairwise_ious,
            min(self.candidate_topk, pairwise_ious.shape[0]),
            axis=0)
        # dynamic k of each gt: its summed top-k ious cast to int, at least 1
        dynamic_ks = paddle.clip(topk_ious.sum(0).cast('int'), min=1)
        for gt_idx in range(num_gt):
            _, pos_idx = paddle.topk(
                cost_matrix[:, gt_idx], k=dynamic_ks[gt_idx], largest=False)
            # match_matrix[:, gt_idx] is a numpy view, so this writes through
            match_matrix[:, gt_idx][pos_idx.numpy()] = 1.0

        # Free the temporaries early. pos_idx is left alone because it is
        # unbound when num_gt == 0.
        del topk_ious, dynamic_ks

        # A prior matched to two or more gts keeps only its lowest-cost gt.
        extra_match_gts_mask = match_matrix.sum(1) > 1
        if extra_match_gts_mask.sum() > 0:
            cost_argmin = np.argmin(cost_np[extra_match_gts_mask, :], axis=1)
            match_matrix[extra_match_gts_mask, :] *= 0.0
            match_matrix[extra_match_gts_mask, cost_argmin] = 1.0
        # get foreground mask
        match_fg_mask_inmatrix = match_matrix.sum(1) > 0
        match_gt_inds_to_fg = match_matrix[match_fg_mask_inmatrix, :].argmax(1)

        return match_gt_inds_to_fg, match_fg_mask_inmatrix

    def get_sample(self, assign_gt_inds, gt_bboxes):
        """Split priors into positive and negative samples.

        Args:
            assign_gt_inds (np.ndarray): Per-prior assignment. 0 marks a
                negative; ``gt index + 1`` marks a positive.
            gt_bboxes (np.ndarray): Gt boxes, [num_gt, 4]. A lower-rank
                array is reshaped to [-1, 4].

        Returns:
            pos_inds (np.ndarray): Indices of positive priors.
            neg_inds (np.ndarray): Indices of negative priors.
            pos_gt_bboxes (np.ndarray): Gt box of each positive, [num_pos, 4].
            pos_assigned_gt_inds (np.ndarray): 0-based gt index of each
                positive.
        """
        pos_inds = np.unique(np.nonzero(assign_gt_inds > 0)[0])
        neg_inds = np.unique(np.nonzero(assign_gt_inds == 0)[0])
        pos_assigned_gt_inds = assign_gt_inds[pos_inds] - 1

        if gt_bboxes.size == 0:
            # hack for index error case: with no gt there can be no positive
            assert pos_assigned_gt_inds.size == 0
            pos_gt_bboxes = np.empty_like(gt_bboxes).reshape(-1, 4)
        else:
            if len(gt_bboxes.shape) < 2:
                # ndarray.resize is in-place, returns None and rejects -1,
                # so reshape is required here.
                gt_bboxes = gt_bboxes.reshape(-1, 4)
            pos_gt_bboxes = gt_bboxes[pos_assigned_gt_inds, :]
        return pos_inds, neg_inds, pos_gt_bboxes, pos_assigned_gt_inds

    def _get_cost_matrix(self, valid_cls_pred_scores, valid_flatten_bboxes,
                         gt_bboxes, gt_labels, pairwise_ious,
                         is_in_boxes_and_center, eps):
        """Build the SimOTA cost matrix, shape [num_valid, num_gt]."""
        num_valid_bboxes = valid_flatten_bboxes.shape[0]
        num_gt = gt_bboxes.shape[0]
        # Candidates not inside both the gt box and the gt center region get
        # a huge penalty, so they rank behind candidates inside both.
        center_prior_cost = paddle.logical_not(is_in_boxes_and_center).cast(
            'float32') * 100000000

        if self.use_vfl:
            # Rows are (prior i, gt j) pairs flattened as i * num_gt + j.
            gt_vfl_labels = gt_labels.squeeze(-1).unsqueeze(0).tile(
                [num_valid_bboxes, 1]).reshape([-1])
            valid_pred_scores = valid_cls_pred_scores.unsqueeze(1).tile(
                [1, num_gt, 1]).reshape([-1, self.num_classes])
            # Soft target: the gt's class holds the pair IoU, the others 0.
            vfl_score = np.zeros(valid_pred_scores.shape)
            vfl_score[np.arange(0, vfl_score.shape[0]), gt_vfl_labels.numpy(
            )] = pairwise_ious.reshape([-1])
            vfl_score = paddle.to_tensor(vfl_score)
            losses_vfl = varifocal_loss(
                valid_pred_scores, vfl_score,
                use_sigmoid=False).reshape([num_valid_bboxes, num_gt])
            # NOTE(review): used directly as a cost, so this assumes
            # batch_bbox_overlaps(mode='giou') returns a GIoU *loss* (lower is
            # better) -- confirm against ppdet.modeling.bbox_utils.
            losses_giou = batch_bbox_overlaps(
                valid_flatten_bboxes, gt_bboxes, mode='giou')
            return (losses_vfl * self.cls_weight +
                    losses_giou * self.iou_weight + center_prior_cost)

        iou_cost = -paddle.log(pairwise_ious + eps)
        gt_onehot_label = (F.one_hot(
            gt_labels.squeeze(-1).cast(paddle.int64),
            valid_cls_pred_scores.shape[-1]).cast('float32').unsqueeze(0)
                           .tile([num_valid_bboxes, 1, 1]))
        valid_pred_scores = valid_cls_pred_scores.unsqueeze(1).tile(
            [1, num_gt, 1])
        cls_cost = F.binary_cross_entropy(
            valid_pred_scores, gt_onehot_label, reduction='none').sum(-1)
        return (cls_cost * self.cls_weight + iou_cost * self.iou_weight +
                center_prior_cost)

    def __call__(self,
                 flatten_cls_pred_scores,
                 flatten_center_and_stride,
                 flatten_bboxes,
                 gt_bboxes,
                 gt_labels,
                 eps=1e-7):
        """Assign gt to priors using SimOTA.

        Args:
            flatten_cls_pred_scores (Tensor): Class scores of all priors,
                [num_priors, num_classes]. No sigmoid is applied here, so they
                are treated as probabilities.
            flatten_center_and_stride (Tensor): Per-prior
                ``[x, y, stride_x, stride_y]`` in columns 0-3.
            flatten_bboxes (Tensor): Predicted boxes of all priors,
                [num_priors, 4].
            gt_bboxes (Tensor): Gt boxes as ``[x1, y1, x2, y2]``,
                shape [num_gt, 4].
            gt_labels (Tensor): Gt class ids. Presumably [num_gt, 1], since
                the code squeezes the last axis.
            eps (float): Added to the IoUs before ``log`` in the non-VFL IoU
                cost.

        Returns:
            tuple: ``(pos_num, label, label_weight, bbox_target)``:
            * ``pos_num`` is 0 when there is no gt or no prior. Otherwise it
              is ``max(num_positives, 1)``.
            * ``label`` is an int64 numpy array, [num_priors]. It holds
              ``num_classes`` for background priors.
            * ``label_weight`` is a float32 numpy array, [num_priors].
            * ``bbox_target`` is a numpy array holding the matched gt box for
              positives and zeros elsewhere.
        """
        num_gt = gt_bboxes.shape[0]
        num_bboxes = flatten_bboxes.shape[0]

        if num_gt == 0 or num_bboxes == 0:
            # No ground truth or boxes
            label = np.ones([num_bboxes], dtype=np.int64) * self.num_classes
            label_weight = np.ones([num_bboxes], dtype=np.float32)
            bbox_target = np.zeros_like(flatten_center_and_stride)
            return 0, label, label_weight, bbox_target

        (is_in_gts_or_centers_all, is_in_gts_or_centers_all_inds,
         is_in_boxes_and_center) = self.get_in_gt_and_in_center_info(
             flatten_center_and_stride, gt_bboxes)

        num_valid_bboxes = is_in_gts_or_centers_all_inds.shape[0]
        if num_valid_bboxes > 0:
            # bboxes and scores to calculate matrix
            valid_flatten_bboxes = flatten_bboxes[
                is_in_gts_or_centers_all_inds]
            valid_cls_pred_scores = flatten_cls_pred_scores[
                is_in_gts_or_centers_all_inds]
            pairwise_ious = batch_bbox_overlaps(
                valid_flatten_bboxes, gt_bboxes)  # [num_points, num_gts]
            cost_matrix = self._get_cost_matrix(
                valid_cls_pred_scores, valid_flatten_bboxes, gt_bboxes,
                gt_labels, pairwise_ious, is_in_boxes_and_center, eps)
            match_gt_inds_to_fg, match_fg_mask_inmatrix = \
                self.dynamic_k_matching(cost_matrix, pairwise_ious, num_gt)
        else:
            # No prior lies in any gt box or gt center region. There is no
            # candidate to match, so every prior ends up negative. Skipping
            # the matching avoids paddle.topk with k == 0.
            match_gt_inds_to_fg = np.zeros([0], dtype=np.int64)
            match_fg_mask_inmatrix = np.zeros([0], dtype=np.bool_)

        # sample and assign results
        assigned_gt_inds = np.zeros([num_bboxes], dtype=np.int64)
        match_fg_mask_inall = np.zeros_like(assigned_gt_inds)
        match_fg_mask_inall[is_in_gts_or_centers_all.numpy(
        )] = match_fg_mask_inmatrix
        # 0 = negative, k + 1 = matched to gt k
        assigned_gt_inds[match_fg_mask_inall.astype(
            np.bool_)] = match_gt_inds_to_fg + 1

        pos_inds, neg_inds, pos_gt_bboxes, pos_assigned_gt_inds \
            = self.get_sample(assigned_gt_inds, gt_bboxes.numpy())

        bbox_target = np.zeros_like(flatten_bboxes)
        bbox_weight = np.zeros_like(flatten_bboxes)
        label = np.ones([num_bboxes], dtype=np.int64) * self.num_classes
        label_weight = np.zeros([num_bboxes], dtype=np.float32)

        if len(pos_inds) > 0:
            gt_labels = gt_labels.numpy()
            pos_bbox_targets = pos_gt_bboxes
            bbox_target[pos_inds, :] = pos_bbox_targets
            bbox_weight[pos_inds, :] = 1.0
            if not np.any(gt_labels):
                # every gt is class 0
                label[pos_inds] = 0
            else:
                label[pos_inds] = gt_labels.squeeze(-1)[pos_assigned_gt_inds]

            label_weight[pos_inds] = 1.0
        if len(neg_inds) > 0:
            label_weight[neg_inds] = 1.0

        pos_num = max(pos_inds.size, 1)

        return pos_num, label, label_weight, bbox_target