# fcos.py
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. 
#   
# Licensed under the Apache License, Version 2.0 (the "License");   
# you may not use this file except in compliance with the License.  
# You may obtain a copy of the License at   
#   
#     http://www.apache.org/licenses/LICENSE-2.0    
#   
# Unless required by applicable law or agreed to in writing, software   
# distributed under the License is distributed on an "AS IS" BASIS, 
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.  
# See the License for the specific language governing permissions and   
# limitations under the License.

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import paddle
import paddle.nn.functional as F

from ppdet.core.workspace import register, create

from .meta_arch import BaseArch
from ..ssod_utils import permute_to_N_HWA_K, QFLv2
from ..losses import GIoULoss

# Public names exported by a star-import of this module.
__all__ = ["FCOS"]


@register
class FCOS(BaseArch):
    """
    FCOS network, see https://arxiv.org/abs/1904.01355

    Besides plain detection, the model can act as a teacher (selected per
    call via ``inputs['is_teacher']``) and exposes ``get_distill_loss`` for
    dense student/teacher distillation on the head outputs.

    Args:
        backbone (object): backbone instance
        neck (object): 'FPN' instance
        fcos_head (object): 'FCOSHead' instance
    """

    __category__ = 'architecture'

    def __init__(self, backbone, neck='FPN', fcos_head='FCOSHead'):
        super(FCOS, self).__init__()
        self.backbone = backbone
        self.neck = neck
        self.fcos_head = fcos_head
        # Refreshed from self.inputs on every _forward call.
        self.is_teacher = False

    @classmethod
    def from_config(cls, cfg, *args, **kwargs):
        """Instantiate backbone, neck and head from ``cfg``.

        Each module's ``out_shape`` is passed as the next module's
        ``input_shape`` so the channel layout is wired automatically.

        Returns:
            dict: constructor kwargs ``backbone``, ``neck``, ``fcos_head``.
        """
        backbone = create(cfg['backbone'])

        kwargs = {'input_shape': backbone.out_shape}
        neck = create(cfg['neck'], **kwargs)

        kwargs = {'input_shape': neck.out_shape}
        fcos_head = create(cfg['fcos_head'], **kwargs)

        return {
            'backbone': backbone,
            'neck': neck,
            "fcos_head": fcos_head,
        }

    def _forward(self):
        """Run backbone -> neck -> head on ``self.inputs``.

        When training, or when ``self.inputs`` holds a truthy
        ``'is_teacher'`` entry, the head is called with ``self.inputs`` and
        its result is returned unchanged (what the head returns in teacher
        mode is decided by ``fcos_head`` and is not visible here).
        Otherwise the raw head outputs are post-processed with
        ``self.inputs['scale_factor']`` and returned as
        ``{'bbox': bbox_pred, 'bbox_num': bbox_num}``.
        """
        body_feats = self.backbone(self.inputs)
        fpn_feats = self.neck(body_feats)

        # The flag comes from the inputs, not from construction, so the same
        # module can switch role between calls.
        self.is_teacher = self.inputs.get('is_teacher', False)
        if self.training or self.is_teacher:
            losses = self.fcos_head(fpn_feats, self.inputs)
            return losses
        else:
            fcos_head_outs = self.fcos_head(fpn_feats)
            bbox_pred, bbox_num = self.fcos_head.post_process(
                fcos_head_outs, self.inputs['scale_factor'])
            return {'bbox': bbox_pred, 'bbox_num': bbox_num}

    def get_loss(self):
        """Forward pass in loss mode (see ``_forward``)."""
        return self._forward()

    def get_pred(self):
        """Forward pass in prediction mode (see ``_forward``)."""
        return self._forward()

    def get_loss_keys(self):
        """Names of the loss terms this architecture reports."""
        return ['loss_cls', 'loss_box', 'loss_quality']

    @staticmethod
    def _flatten_levels(level_feats, channels):
        """Merge per-level feature maps into one 2-D tensor.

        Each element of ``level_feats`` is a rank-4 tensor whose axis 1 holds
        ``channels`` values (NCHW). Each level is moved to channels-last,
        reshaped to ``[-1, channels]`` and all levels are concatenated along
        axis 0, giving one row per spatial location across all levels.
        """
        return paddle.concat(
            [
                feat.transpose([0, 2, 3, 1]).reshape([-1, channels])
                for feat in level_feats
            ],
            axis=0)

    def get_distill_loss(self,
                         fcos_head_outs,
                         teacher_fcos_head_outs,
                         ratio=0.01):
        """Dense distillation loss between student and teacher head outputs.

        Args:
            fcos_head_outs (tuple): student ``(logits, deltas, quality)``;
                each item is a sequence of per-level NCHW tensors with
                ``num_classes``, 4 and 1 channels respectively.
            teacher_fcos_head_outs (tuple): teacher outputs, same layout.
            ratio (float): fraction of all locations (across levels) treated
                as foreground, chosen by the teacher's highest class
                probability. At least one location is always selected.

        Returns:
            dict: ``distill_loss_cls`` (QFLv2 on sigmoid scores, weighted by
            the selection mask and normalised by ``fg_sum``),
            ``distill_loss_box`` (mean GIoU over selected locations),
            ``distill_loss_quality`` (mean BCE over selected locations) and
            ``fg_sum`` (sum of the teacher's max probabilities over the
            selected locations).
        """
        student_logits, student_deltas, student_quality = fcos_head_outs
        teacher_logits, teacher_deltas, teacher_quality = teacher_fcos_head_outs
        nc = student_logits[0].shape[1]

        student_logits = self._flatten_levels(student_logits, nc)
        teacher_logits = self._flatten_levels(teacher_logits, nc)
        student_deltas = self._flatten_levels(student_deltas, 4)
        teacher_deltas = self._flatten_levels(teacher_deltas, 4)
        student_quality = self._flatten_levels(student_quality, 1)
        teacher_quality = self._flatten_levels(teacher_quality, 1)

        with paddle.no_grad():
            # Region selection: keep the `ratio` fraction of locations where
            # the teacher is most confident in any class. Clamp to [1, N]:
            # with zero selected locations fg_num would be 0 (division by
            # zero below) and the masked losses would average over nothing.
            num_locations = teacher_logits.shape[0]
            count_num = min(max(int(num_locations * ratio), 1), num_locations)
            teacher_probs = F.sigmoid(teacher_logits)
            max_vals = paddle.max(teacher_probs, 1)
            # Only the top `count_num` scores are needed, not a full sort.
            top_vals, top_inds = paddle.topk(max_vals, count_num)
            mask = paddle.zeros_like(max_vals)
            mask[top_inds] = 1.
            fg_num = top_vals.sum()
            b_mask = mask > 0

        # distill_loss_cls
        loss_logits = QFLv2(
            F.sigmoid(student_logits),
            teacher_probs,
            weight=mask,
            reduction="sum") / fg_num

        # distill_loss_box: the first two delta components are negated so
        # that (presumably left/top/right/bottom) distances from a location
        # become box corners relative to that location, as GIoU expects.
        student_fg_deltas = student_deltas[b_mask]
        teacher_fg_deltas = teacher_deltas[b_mask]
        inputs = paddle.concat(
            (-student_fg_deltas[..., :2], student_fg_deltas[..., 2:]),
            axis=-1)
        targets = paddle.concat(
            (-teacher_fg_deltas[..., :2], teacher_fg_deltas[..., 2:]),
            axis=-1)
        iou_loss = GIoULoss(reduction='mean')
        loss_deltas = iou_loss(inputs, targets)

        # distill_loss_quality: BCE(sigmoid(student), sigmoid(teacher)),
        # computed on the student's logits for numerical stability.
        loss_quality = F.binary_cross_entropy_with_logits(
            student_quality[b_mask],
            F.sigmoid(teacher_quality[b_mask]),
            reduction='mean')

        return {
            "distill_loss_cls": loss_logits,
            "distill_loss_box": loss_deltas,
            "distill_loss_quality": loss_quality,
            "fg_sum": fg_num,
        }