queryinst_loss.py 6.3 KB
Newer Older
U
ucsk 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import paddle
import paddle.nn.functional as F

from ppdet.core.workspace import register
from ppdet.modeling.losses.iou_loss import GIoULoss
from .sparsercnn_loss import HungarianMatcher

__all__ = ['QueryInstLoss']


@register
class QueryInstLoss(object):
    """Training losses for a QueryInst-style query-based instance head.

    Provides a sigmoid focal classification loss, L1 + GIoU box losses and a
    dice mask loss. Every loss method divides its summed loss by
    ``avg_factor`` and scales it by the matching entry of
    ``self.loss_weights``.

    Args:
        num_classes (int): number of foreground classes. The value
            ``num_classes`` itself is used as the label of unmatched queries.
        focal_loss_alpha (float): ``alpha`` of the sigmoid focal loss.
        focal_loss_gamma (float): ``gamma`` of the sigmoid focal loss.
        class_weight (float): weight applied to ``loss_cls``.
        l1_weight (float): weight applied to ``loss_bbox``.
        giou_weight (float): weight applied to ``loss_giou``.
        mask_weight (float): weight applied to ``loss_mask``.
    """
    __shared__ = ['num_classes']

    def __init__(self,
                 num_classes=80,
                 focal_loss_alpha=0.25,
                 focal_loss_gamma=2.0,
                 class_weight=2.0,
                 l1_weight=5.0,
                 giou_weight=2.0,
                 mask_weight=8.0):
        super(QueryInstLoss, self).__init__()

        self.num_classes = num_classes
        self.focal_loss_alpha = focal_loss_alpha
        self.focal_loss_gamma = focal_loss_gamma
        self.loss_weights = {
            "loss_cls": class_weight,
            "loss_bbox": l1_weight,
            "loss_giou": giou_weight,
            "loss_mask": mask_weight
        }
        self.giou_loss = GIoULoss(eps=1e-6, reduction='sum')

        # Hungarian matcher built with the same focal parameters and cost
        # weights. No method of this class calls it; presumably the owning
        # head uses it to produce ``indices`` -- confirm with the caller.
        self.matcher = HungarianMatcher(focal_loss_alpha, focal_loss_gamma,
                                        class_weight, l1_weight, giou_weight)

    def loss_classes(self, class_logits, targets, indices, avg_factor):
        """Sigmoid focal classification loss over every query.

        Args:
            class_logits (Tensor): class logits whose first two dims are
                flattened together (presumably [batch, queries, num_classes]
                -- TODO confirm).
            targets (list[dict]): per-image targets; reads ``'labels'``.
            indices (list[tuple]): per-image ``(src_idx, tgt_idx)`` pairs
                linking query indices to target indices.
            avg_factor: normaliser the summed loss is divided by.

        Returns:
            dict: ``{'loss_cls': weighted loss}``.
        """
        batch_size, num_queries = class_logits.shape[:2]
        # Unmatched queries keep label ``num_classes``. That value has no
        # column in the one-hot target below, so their target row is all
        # zeros (pure background).
        tgt_labels = paddle.full(
            [batch_size * num_queries], self.num_classes, dtype='int32')

        if sum(len(v['labels']) for v in targets) > 0:
            tgt_classes = paddle.concat([
                paddle.gather(
                    tgt['labels'], tgt_idx, axis=0)
                for tgt, (_, tgt_idx) in zip(targets, indices)
            ])
            batch_idx, src_idx = self._get_src_permutation_idx(indices)
            # Write every matched label with one scatter on the flattened
            # [batch * queries] layout, instead of one host-synchronising
            # element assignment per match.
            flat_idx = batch_idx * num_queries + src_idx
            tgt_labels = paddle.scatter(
                tgt_labels,
                flat_idx,
                paddle.cast(tgt_classes, 'int32'),
                overwrite=True)

        tgt_labels = tgt_labels.unsqueeze(-1)

        tgt_labels_onehot = paddle.cast(
            tgt_labels == paddle.arange(0, self.num_classes), dtype='float32')
        tgt_labels_onehot.stop_gradient = True

        src_logits = class_logits.flatten(0, 1)

        loss_cls = F.sigmoid_focal_loss(
            src_logits,
            tgt_labels_onehot,
            alpha=self.focal_loss_alpha,
            gamma=self.focal_loss_gamma,
            reduction='sum') / avg_factor
        losses = {'loss_cls': loss_cls * self.loss_weights['loss_cls']}
        return losses

    def loss_bboxes(self, bbox_pred, targets, indices, avg_factor):
        """L1 and GIoU losses between matched predicted and target boxes.

        The L1 loss uses boxes divided by each target's ``'img_whwh_tgt'``.
        The GIoU loss uses the unnormalised boxes.

        Args:
            bbox_pred: per-image predicted boxes, indexed by ``src_idx``.
            targets (list[dict]): per-image targets; reads ``'boxes'`` and
                ``'img_whwh_tgt'``.
            indices (list[tuple]): per-image ``(src_idx, tgt_idx)`` pairs.
            avg_factor: normaliser the summed losses are divided by.

        Returns:
            dict: ``{'loss_bbox': ..., 'loss_giou': ...}`` (both weighted).
        """
        bboxes = paddle.concat([
            paddle.gather(
                src, src_idx, axis=0)
            for src, (src_idx, _) in zip(bbox_pred, indices)
        ])

        tgt_bboxes = paddle.concat([
            paddle.gather(
                tgt['boxes'], tgt_idx, axis=0)
            for tgt, (_, tgt_idx) in zip(targets, indices)
        ])
        tgt_bboxes.stop_gradient = True

        # Assumes 'img_whwh_tgt' already has one row per target, in the same
        # per-image grouping as the gathered boxes -- confirm with the caller.
        im_shapes = paddle.concat([tgt['img_whwh_tgt'] for tgt in targets])
        bboxes_norm = bboxes / im_shapes
        tgt_bboxes_norm = tgt_bboxes / im_shapes

        loss_giou = self.giou_loss(bboxes, tgt_bboxes) / avg_factor
        loss_bbox = F.l1_loss(
            bboxes_norm, tgt_bboxes_norm, reduction='sum') / avg_factor
        losses = {
            'loss_bbox': loss_bbox * self.loss_weights['loss_bbox'],
            'loss_giou': loss_giou * self.loss_weights['loss_giou']
        }
        return losses

    def loss_masks(self, pos_bbox_pred, mask_logits, targets, indices,
                   avg_factor):
        """Dice loss between predicted masks and RoI-aligned gt masks.

        Each image's matched gt masks are cropped with ``roi_align`` at that
        image's clipped positive boxes, resized to the spatial size of
        ``mask_logits``, and binarised at 0.5. The predicted mask for each
        positive is the ``mask_logits`` channel of its gt class, passed
        through a sigmoid.

        Args:
            pos_bbox_pred: per-image positive boxes, one row per matched
                target. Assumed to be (x1, y1, x2, y2) -- even columns are
                clipped to the mask width, odd columns to its height. They
                are not modified.
            mask_logits (Tensor): [num_pos, num_classes, H, W] mask logits
                (rank and channel layout implied by the class-channel
                gather below).
            targets (list[dict]): per-image targets; reads ``'gt_segm'`` and
                ``'labels'``.
            indices (list[tuple]): per-image ``(src_idx, tgt_idx)`` pairs.
            avg_factor: normaliser the summed loss is divided by.

        Returns:
            dict: ``{'loss_mask': weighted loss}``.
        """
        tgt_masks = []
        for i, (tgt, (_, tgt_idx)) in enumerate(zip(targets, indices)):
            if len(tgt_idx) == 0:
                continue
            gt_segm = paddle.gather(
                tgt['gt_segm'], tgt_idx, axis=0).unsqueeze(1)
            mask_h, mask_w = gt_segm.shape[2], gt_segm.shape[3]
            # Clip a detached copy. Slice-assigning into pos_bbox_pred[i]
            # directly would mutate the caller's predicted boxes in place.
            boxes = pos_bbox_pred[i].detach().clone()
            boxes[:, 0::2] = paddle.clip(boxes[:, 0::2], min=0, max=mask_w)
            boxes[:, 1::2] = paddle.clip(boxes[:, 1::2], min=0, max=mask_h)
            # One RoI per gt mask: RoI k is applied to gt_segm[k] only.
            boxes_num = paddle.to_tensor([1] * len(boxes), dtype='int32')
            gt_mask = paddle.vision.ops.roi_align(
                gt_segm,
                boxes,
                boxes_num,
                output_size=mask_logits.shape[-2:],
                aligned=True)
            tgt_masks.append(gt_mask)

        if not tgt_masks:
            # No image has a matched gt mask. paddle.concat would fail on an
            # empty list, so return a zero loss that stays attached to the
            # graph through mask_logits.
            return {'loss_mask': paddle.sum(mask_logits) * 0.}

        tgt_masks = paddle.concat(tgt_masks).squeeze(1)
        tgt_masks = paddle.cast(tgt_masks >= 0.5, dtype='float32')
        tgt_masks.stop_gradient = True

        tgt_labels = paddle.concat([
            paddle.gather(
                tgt['labels'], tgt_idx, axis=0)
            for tgt, (_, tgt_idx) in zip(targets, indices)
        ])

        # Pick channel tgt_labels[n] of row n: gather_nd with (row, class)
        # index pairs yields [num_pos, H, W] directly. This avoids
        # materialising an N x C x H x W one-hot tensor and calling nonzero.
        row_idx = paddle.arange(tgt_labels.shape[0], dtype='int64')
        gather_idx = paddle.stack(
            [row_idx, paddle.cast(tgt_labels, 'int64')], axis=1)
        src_masks = F.sigmoid(paddle.gather_nd(mask_logits, gather_idx))

        X = src_masks.flatten(1)
        Y = tgt_masks.flatten(1)
        inter = paddle.sum(X * Y, 1)
        union = paddle.sum(X * X, 1) + paddle.sum(Y * Y, 1)
        # Dice coefficient per mask; 2e-5 guards against a zero denominator.
        dice = (2 * inter) / (union + 2e-5)

        loss_mask = (1 - dice).sum() / avg_factor
        losses = {'loss_mask': loss_mask * self.loss_weights['loss_mask']}
        return losses

    @staticmethod
    def _get_src_permutation_idx(indices):
        """Flatten per-image matches into parallel (batch, query) indices.

        Args:
            indices (list[tuple]): per-image ``(src_idx, tgt_idx)`` pairs.

        Returns:
            tuple: ``(batch_idx, src_idx)``. ``batch_idx[k]`` is the image
            index of the k-th matched query and ``src_idx[k]`` is its query
            index.
        """
        batch_idx = paddle.concat(
            [paddle.full_like(src, i) for i, (src, _) in enumerate(indices)])
        src_idx = paddle.concat([src for (src, _) in indices])
        return batch_idx, src_idx