# detr_loss.py
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from ppdet.core.workspace import register
from .iou_loss import GIoULoss
from ..transformers import bbox_cxcywh_to_xyxy, sigmoid_focal_loss

__all__ = ['DETRLoss', 'DINOLoss']


# Set-prediction loss for DETR-style detectors: classification, L1 + GIoU box
# regression, optional mask + dice terms, and optional losses on every
# intermediate (auxiliary) decoder layer.
@register
class DETRLoss(nn.Layer):
    __shared__ = ['num_classes', 'use_focal_loss']
    __inject__ = ['matcher']

    def __init__(self,
                 num_classes=80,
                 matcher='HungarianMatcher',
                 loss_coeff={
                     'class': 1,
                     'bbox': 5,
                     'giou': 2,
                     'no_object': 0.1,
                     'mask': 1,
                     'dice': 1
                 },
                 aux_loss=True,
                 use_focal_loss=False):
        r"""
        Args:
            num_classes (int): The number of classes.
            matcher (HungarianMatcher): It computes an assignment between the targets
                and the predictions of the network.
            loss_coeff (dict): The coefficient of loss. The dict is copied, so
                neither the caller's dict nor the default literal is modified.
            aux_loss (bool): If 'aux_loss = True', loss at each decoder layer are to be used.
            use_focal_loss (bool): Use focal loss or not.
        """
        super(DETRLoss, self).__init__()
        self.num_classes = num_classes

        self.matcher = matcher
        self.aux_loss = aux_loss
        self.use_focal_loss = use_focal_loss

        # Work on a copy: the default dict literal is shared by every call of
        # __init__, and replacing its 'class' entry with a tensor would leak
        # into every DETRLoss/DINOLoss constructed afterwards.
        self.loss_coeff = dict(loss_coeff)
        if not self.use_focal_loss:
            # Cross-entropy class weights over num_classes + 1 labels; the
            # last label is the background class, weighted by 'no_object'.
            class_weight = paddle.full([num_classes + 1], loss_coeff['class'])
            class_weight[-1] = loss_coeff['no_object']
            self.loss_coeff['class'] = class_weight
        self.giou_loss = GIoULoss()

    def _get_loss_class(self,
                        logits,
                        gt_class,
                        match_indices,
                        bg_index,
                        num_gts,
                        postfix=""):
        """Classification loss for one decoder layer.

        Args:
            logits (Tensor|None): [b, query, num_classes] class logits.
            gt_class (list[Tensor]): per-image labels, each [n, 1].
            match_indices (list[tuple]): per-image (query_idx, gt_idx) pairs.
            bg_index (int): label given to queries without a match.
            num_gts (Tensor): normalizer computed by ``forward``.
            postfix (str): suffix appended to the loss name.

        Returns:
            dict: {"loss_class" + postfix: Tensor}; a zero tensor when
            ``logits`` is None.
        """
        name_class = "loss_class" + postfix
        if logits is None:
            return {name_class: paddle.zeros([1])}
        # Every query starts as background; matched queries get their gt label.
        target_label = paddle.full(logits.shape[:2], bg_index, dtype='int64')
        bs, num_query_objects = target_label.shape
        if sum(len(a) for a in gt_class) > 0:
            index, updates = self._get_index_updates(num_query_objects,
                                                     gt_class, match_indices)
            target_label = paddle.scatter(
                target_label.reshape([-1, 1]), index, updates.astype('int64'))
            target_label = target_label.reshape([bs, num_query_objects])
        if self.use_focal_loss:
            # One-hot over num_classes + 1 labels, then drop the background
            # column so unmatched queries get an all-zero target.
            target_label = F.one_hot(target_label,
                                     self.num_classes + 1)[..., :-1]
            loss = self.loss_coeff['class'] * sigmoid_focal_loss(
                logits, target_label, num_gts / num_query_objects)
        else:
            loss = F.cross_entropy(
                logits, target_label, weight=self.loss_coeff['class'])
        return {name_class: loss}

    def _get_loss_bbox(self, boxes, gt_bbox, match_indices, num_gts,
                       postfix=""):
        """L1 and GIoU box losses for one decoder layer.

        Args:
            boxes (Tensor|None): [b, query, 4] predicted boxes in cx, cy, w, h.
            gt_bbox (list[Tensor]): per-image gt boxes, each [n, 4], same format.
            match_indices (list[tuple]): per-image (query_idx, gt_idx) pairs.
            num_gts (Tensor): normalizer computed by ``forward``.
            postfix (str): suffix appended to the loss names.

        Returns:
            dict: {"loss_bbox" + postfix, "loss_giou" + postfix}; zeros when
            ``boxes`` is None or the batch has no gt boxes.
        """
        name_bbox = "loss_bbox" + postfix
        name_giou = "loss_giou" + postfix
        if boxes is None or sum(len(a) for a in gt_bbox) == 0:
            return {name_bbox: paddle.zeros([1]), name_giou: paddle.zeros([1])}

        src_bbox, target_bbox = self._get_src_target_assign(boxes, gt_bbox,
                                                            match_indices)
        loss_bbox = F.l1_loss(src_bbox, target_bbox, reduction='sum') / num_gts
        loss_giou = self.giou_loss(
            bbox_cxcywh_to_xyxy(src_bbox), bbox_cxcywh_to_xyxy(target_bbox))
        loss_giou = loss_giou.sum() / num_gts
        return {
            name_bbox: self.loss_coeff['bbox'] * loss_bbox,
            name_giou: self.loss_coeff['giou'] * loss_giou
        }

    def _get_loss_mask(self, masks, gt_mask, match_indices, num_gts,
                       postfix=""):
        """Sigmoid-focal and dice mask losses.

        Args:
            masks (Tensor|None): [b, query, h, w] predicted mask logits.
            gt_mask (list[Tensor]): per-image gt masks, each [n, H, W].
            match_indices (list[tuple]): per-image (query_idx, gt_idx) pairs.
            num_gts (Tensor|float): normalizer; ``forward`` passes a [1] tensor.
            postfix (str): suffix appended to the loss names.

        Returns:
            dict: {"loss_mask" + postfix, "loss_dice" + postfix}.
        """
        name_mask = "loss_mask" + postfix
        name_dice = "loss_dice" + postfix
        if masks is None or sum(len(a) for a in gt_mask) == 0:
            return {name_mask: paddle.zeros([1]), name_dice: paddle.zeros([1])}

        src_masks, target_masks = self._get_src_target_assign(masks, gt_mask,
                                                              match_indices)
        # Resize predictions to the gt resolution ([N, h, w] -> [N, H, W]).
        src_masks = F.interpolate(
            src_masks.unsqueeze(0),
            size=target_masks.shape[-2:],
            mode="bilinear")[0]
        target_masks = target_masks.astype(src_masks.dtype)
        # F.sigmoid_focal_loss expects a 1-D [1] normalizer tensor. forward
        # already supplies one, so only plain numbers are wrapped here.
        if isinstance(num_gts, paddle.Tensor):
            normalizer = num_gts
        else:
            normalizer = paddle.to_tensor([num_gts], dtype='float32')
        return {
            name_mask: self.loss_coeff['mask'] * F.sigmoid_focal_loss(
                src_masks, target_masks, normalizer),
            name_dice: self.loss_coeff['dice'] * self._dice_loss(
                src_masks, target_masks, num_gts)
        }

    def _dice_loss(self, inputs, targets, num_gts):
        """Dice loss with +1 smoothing on sigmoid probabilities.

        Per-mask losses are summed and divided by ``num_gts``.
        """
        inputs = F.sigmoid(inputs)
        inputs = inputs.flatten(1)
        targets = targets.flatten(1)
        numerator = 2 * (inputs * targets).sum(1)
        denominator = inputs.sum(-1) + targets.sum(-1)
        loss = 1 - (numerator + 1) / (denominator + 1)
        return loss.sum() / num_gts

    def _get_loss_aux(self,
                      boxes,
                      logits,
                      gt_bbox,
                      gt_class,
                      bg_index,
                      num_gts,
                      match_indices=None,
                      postfix=""):
        """Class and box losses summed over the auxiliary decoder layers.

        Args:
            boxes (Tensor|None): [l - 1, b, query, 4] intermediate-layer boxes.
            logits (Tensor|None): [l - 1, b, query, num_classes] logits.
            gt_bbox, gt_class (list[Tensor]): per-image targets.
            bg_index (int): background label.
            num_gts (Tensor): normalizer computed by ``forward``.
            match_indices (list|None): if given, this fixed assignment is used
                for every layer; otherwise the matcher runs once per layer on
                that layer's own predictions.
            postfix (str): suffix appended to the loss names.

        Returns:
            dict: {"loss_class_aux", "loss_bbox_aux", "loss_giou_aux"} + postfix.
        """
        name_class = "loss_class_aux" + postfix
        name_bbox = "loss_bbox_aux" + postfix
        name_giou = "loss_giou_aux" + postfix
        zero_losses = lambda: {
            name_class: paddle.zeros([1]),
            name_bbox: paddle.zeros([1]),
            name_giou: paddle.zeros([1])
        }
        if boxes is None and logits is None:
            return zero_losses()

        loss_class, loss_bbox, loss_giou = [], [], []
        for aux_boxes, aux_logits in zip(boxes, logits):
            if match_indices is None:
                # Each layer is matched against its own predictions; inputs
                # are detached as in forward's final-layer matching.
                layer_indices = self.matcher(aux_boxes.detach(),
                                             aux_logits.detach(), gt_bbox,
                                             gt_class)
            else:
                layer_indices = match_indices
            loss_class.append(
                self._get_loss_class(aux_logits, gt_class, layer_indices,
                                     bg_index, num_gts, postfix)['loss_class' +
                                                                 postfix])
            loss_ = self._get_loss_bbox(aux_boxes, gt_bbox, layer_indices,
                                        num_gts, postfix)
            loss_bbox.append(loss_['loss_bbox' + postfix])
            loss_giou.append(loss_['loss_giou' + postfix])

        if not loss_class:
            # No auxiliary layers (single-layer decoder): add_n([]) would fail.
            return zero_losses()
        return {
            name_class: paddle.add_n(loss_class),
            name_bbox: paddle.add_n(loss_bbox),
            name_giou: paddle.add_n(loss_giou)
        }

    @staticmethod
    def _gather_rows(t, idx):
        """Rows ``idx`` of ``t`` along axis 0, or an empty [0, *t.shape[1:]]
        tensor of the same dtype when ``idx`` is empty."""
        if len(idx) > 0:
            return paddle.gather(t, idx, axis=0)
        return paddle.zeros([0] + list(t.shape[1:]), dtype=t.dtype)

    def _get_index_updates(self, num_query_objects, target, match_indices):
        """Flatten matched query indices and gather the matched targets.

        Returns:
            tuple: (src_idx, target_assign). ``src_idx`` indexes the flattened
            [b * num_query_objects] query axis. ``target_assign`` holds the
            matched rows of each ``target`` tensor, concatenated over images.
        """
        batch_idx = paddle.concat([
            paddle.full_like(src, i) for i, (src, _) in enumerate(match_indices)
        ])
        src_idx = paddle.concat([src for (src, _) in match_indices])
        # Offset every image's query indices into the flattened batch axis.
        # This is a non-in-place add, so caller-owned index tensors are never
        # modified.
        src_idx = src_idx + batch_idx * num_query_objects
        target_assign = paddle.concat([
            self._gather_rows(t, dst)
            for t, (_, dst) in zip(target, match_indices)
        ])
        return src_idx, target_assign

    def _get_src_target_assign(self, src, target, match_indices):
        """Gather matched (prediction, target) rows over the batch.

        Images without matches contribute empty slices shaped like the rest
        of their tensor. Concatenation therefore works for boxes
        ([query, 4]) and masks ([query, h, w]) alike.
        """
        src_assign = paddle.concat([
            self._gather_rows(t, I) for t, (I, _) in zip(src, match_indices)
        ])
        target_assign = paddle.concat([
            self._gather_rows(t, J) for t, (_, J) in zip(target, match_indices)
        ])
        return src_assign, target_assign

    def forward(self,
                boxes,
                logits,
                gt_bbox,
                gt_class,
                masks=None,
                gt_mask=None,
                postfix="",
                **kwargs):
        r"""
        Args:
            boxes (Tensor|None): [l, b, query, 4]
            logits (Tensor|None): [l, b, query, num_classes]
            gt_bbox (List(Tensor)): list[[n, 4]]
            gt_class (List(Tensor)): list[[n, 1]]
            masks (Tensor, optional): [b, query, h, w]
            gt_mask (List(Tensor), optional): list[[n, H, W]]
            postfix (str): postfix of loss name
            **kwargs: optional ``match_indices`` (a fixed assignment used
                instead of running the matcher, also for the auxiliary
                layers) and ``dn_num_group`` (multiplier applied to the
                gt-count normalizer, default 1).

        Returns:
            dict: loss name -> Tensor.
        """
        if "match_indices" in kwargs:
            match_indices = kwargs["match_indices"]
        else:
            match_indices = self.matcher(boxes[-1].detach(),
                                         logits[-1].detach(), gt_bbox, gt_class)

        # Normalizer: number of gt boxes, averaged over ranks, clipped to >= 1.
        num_gts = sum(len(a) for a in gt_bbox)
        num_gts = paddle.to_tensor([num_gts], dtype="float32")
        world_size = paddle.distributed.get_world_size()
        if world_size > 1:
            paddle.distributed.all_reduce(num_gts)
            num_gts /= world_size
        num_gts = paddle.clip(num_gts, min=1.) * kwargs.get("dn_num_group", 1.)

        total_loss = dict()
        total_loss.update(
            self._get_loss_class(logits[-1] if logits is not None else None,
                                 gt_class, match_indices, self.num_classes,
                                 num_gts, postfix))
        total_loss.update(
            self._get_loss_bbox(boxes[-1] if boxes is not None else None,
                                gt_bbox, match_indices, num_gts, postfix))
        if masks is not None and gt_mask is not None:
            total_loss.update(
                self._get_loss_mask(masks, gt_mask, match_indices, num_gts,
                                    postfix))

        if self.aux_loss:
            # Without an externally supplied assignment, every auxiliary
            # layer is matched on its own predictions.
            if "match_indices" not in kwargs:
                match_indices = None
            total_loss.update(
                self._get_loss_aux(
                    boxes[:-1] if boxes is not None else None,
                    logits[:-1] if logits is not None else None,
                    gt_bbox, gt_class, self.num_classes, num_gts,
                    match_indices, postfix))

        return total_loss


# DETR losses on the matched queries plus the same losses on DINO's
# denoising queries, whose query->gt assignment is fixed by dn_meta.
@register
class DINOLoss(DETRLoss):
    def forward(self,
                boxes,
                logits,
                gt_bbox,
                gt_class,
                masks=None,
                gt_mask=None,
                postfix="",
                dn_out_bboxes=None,
                dn_out_logits=None,
                dn_meta=None,
                **kwargs):
        r"""
        Args:
            boxes (Tensor|None): [l, b, query, 4]
            logits (Tensor|None): [l, b, query, num_classes]
            gt_bbox (List(Tensor)): list[[n, 4]]
            gt_class (List(Tensor)): list[[n, 1]]
            masks (Tensor, optional): [b, query, h, w]; forwarded to the
                matched-query loss.
            gt_mask (List(Tensor), optional): list[[n, H, W]]
            postfix (str): postfix of loss name. Denoising losses use
                "_dn" + postfix.
            dn_out_bboxes (Tensor|None): box predictions of denoising queries.
            dn_out_logits (Tensor|None): class logits of denoising queries.
            dn_meta (dict|None): provides "dn_positive_idx" (per-image query
                indices of the positive denoising queries) and
                "dn_num_group" (number of denoising groups). When None, the
                denoising losses are zeros.
            **kwargs: accepted for interface compatibility; unused.

        Returns:
            dict: loss name -> Tensor.
        """
        total_loss = super(DINOLoss, self).forward(
            boxes,
            logits,
            gt_bbox,
            gt_class,
            masks=masks,
            gt_mask=gt_mask,
            postfix=postfix)

        # denoising training loss
        dn_match_indices, dn_num_group = self._get_dn_match_indices(gt_class,
                                                                    dn_meta)
        dn_loss = super(DINOLoss, self).forward(
            dn_out_bboxes,
            dn_out_logits,
            gt_bbox,
            gt_class,
            postfix="_dn" + postfix,
            match_indices=dn_match_indices,
            dn_num_group=dn_num_group)
        total_loss.update(dn_loss)

        return total_loss

    @staticmethod
    def _get_dn_match_indices(gt_class, dn_meta):
        """Build the fixed query->gt assignment for the denoising queries.

        Returns:
            tuple: (match_indices, dn_num_group). It is (None, 1.) when
            ``dn_meta`` is None. Otherwise ``match_indices`` pairs, per image,
            ``dn_positive_idx[i]`` with the gt indices 0..n-1 tiled once per
            denoising group.
        """
        if dn_meta is None:
            return None, 1.

        dn_positive_idx, dn_num_group = \
            dn_meta["dn_positive_idx"], dn_meta["dn_num_group"]
        assert len(gt_class) == len(dn_positive_idx)

        dn_match_indices = []
        for i, gt in enumerate(gt_class):
            num_gt = len(gt)
            if num_gt > 0:
                gt_idx = paddle.arange(end=num_gt, dtype="int64")
                gt_idx = gt_idx.unsqueeze(0).tile([dn_num_group, 1]).flatten()
                assert len(gt_idx) == len(dn_positive_idx[i])
                dn_match_indices.append((dn_positive_idx[i], gt_idx))
            else:
                dn_match_indices.append((paddle.zeros(
                    [0], dtype="int64"), paddle.zeros(
                        [0], dtype="int64")))
        return dn_match_indices, dn_num_group