# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import math
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from paddle import ParamAttr
from paddle.nn.initializer import Normal, Constant
from ppdet.modeling.bbox_utils import bbox2delta, delta2bbox
from ppdet.modeling.heads.fcos_head import FCOSFeat

from ppdet.core.workspace import register

__all__ = ['RetinaHead']


@register
class RetinaFeat(FCOSFeat):
    """We use FCOSFeat to construct conv layers in RetinaNet.
    We rename FCOSFeat to RetinaFeat to avoid confusion.
    """
    pass


@register
class RetinaHead(nn.Layer):
    """Used in RetinaNet proposed in paper https://arxiv.org/pdf/1708.02002.pdf
    """
    __shared__ = ['num_classes']
    __inject__ = [
        'conv_feat', 'anchor_generator', 'bbox_assigner', 'loss_class',
        'loss_bbox', 'nms'
    ]

    def __init__(self,
                 num_classes=80,
                 conv_feat='RetinaFeat',
                 anchor_generator='RetinaAnchorGenerator',
                 bbox_assigner='MaxIoUAssigner',
                 loss_class='FocalLoss',
                 loss_bbox='SmoothL1Loss',
                 nms='MultiClassNMS',
                 prior_prob=0.01,
                 nms_pre=1000,
                 weights=[1., 1., 1., 1.]):
        super(RetinaHead, self).__init__()
        self.num_classes = num_classes
        self.conv_feat = conv_feat
        self.anchor_generator = anchor_generator
        self.bbox_assigner = bbox_assigner
        self.loss_class = loss_class
        self.loss_bbox = loss_bbox
        self.nms = nms
        self.nms_pre = nms_pre
        self.weights = weights

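        # initialize the classification bias so that sigmoid(logit) ~= prior_prob
        # at the start of training (the "prior" trick from the RetinaNet paper)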
        bias_init_value = -math.log((1 - prior_prob) / prior_prob)
        num_anchors = self.anchor_generator.num_anchors
        self.retina_cls = nn.Conv2D(
            in_channels=self.conv_feat.feat_out,
            out_channels=self.num_classes * num_anchors,
            kernel_size=3,
            stride=1,
            padding=1,
            weight_attr=ParamAttr(initializer=Normal(
                mean=0.0, std=0.01)),
            bias_attr=ParamAttr(initializer=Constant(value=bias_init_value)))
        self.retina_reg = nn.Conv2D(
            in_channels=self.conv_feat.feat_out,
            out_channels=4 * num_anchors,
            kernel_size=3,
            stride=1,
            padding=1,
            weight_attr=ParamAttr(initializer=Normal(
                mean=0.0, std=0.01)),
            bias_attr=ParamAttr(initializer=Constant(value=0)))

    def forward(self, neck_feats, targets=None):
        cls_logits_list = []
        bboxes_reg_list = []
        for neck_feat in neck_feats:
            conv_cls_feat, conv_reg_feat = self.conv_feat(neck_feat)
            cls_logits = self.retina_cls(conv_cls_feat)
            bbox_reg = self.retina_reg(conv_reg_feat)
            cls_logits_list.append(cls_logits)
            bboxes_reg_list.append(bbox_reg)

        if self.training:
            return self.get_loss([cls_logits_list, bboxes_reg_list], targets)
        else:
            return [cls_logits_list, bboxes_reg_list]

    def get_loss(self, head_outputs, targets):
        """Here we calculate loss for a batch of images.
        We assign anchors to gts in each image and gather all the assigned
        positive and negative samples. Then the loss is calculated on the
        gathered samples.
        """
        cls_logits_list, bboxes_reg_list = head_outputs
        anchors = self.anchor_generator(cls_logits_list)
        anchors = paddle.concat(anchors)
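        # anchors: [total_num_anchors, 4], all FPN levels concatenated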

        # matches: index of the matched gt box for each anchor
        # match_labels: -1 (ignore), 0 (negative) or 1 (positive)
        matches_list, match_labels_list = [], []
        # assign anchors to gts, no sampling is involved
        for gt_bbox in targets['gt_bbox']:
            matches, match_labels = self.bbox_assigner(anchors, gt_bbox)
            matches_list.append(matches)
            match_labels_list.append(match_labels)

        # reshape network outputs
        cls_logits = [
            _.transpose([0, 2, 3, 1]).reshape([0, -1, self.num_classes])
            for _ in cls_logits_list
        ]
        bboxes_reg = [
            _.transpose([0, 2, 3, 1]).reshape([0, -1, 4])
            for _ in bboxes_reg_list
        ]
        cls_logits = paddle.concat(cls_logits, axis=1)
        bboxes_reg = paddle.concat(bboxes_reg, axis=1)

        cls_pred_list, cls_tar_list = [], []
        reg_pred_list, reg_tar_list = [], []
        # find and gather preds and targets in each image
        for matches, match_labels, cls_logit, bbox_reg, gt_bbox, gt_class in \
            zip(matches_list, match_labels_list, cls_logits, bboxes_reg,
                targets['gt_bbox'], targets['gt_class']):
            pos_mask = (match_labels == 1)
            neg_mask = (match_labels == 0)
            chosen_mask = paddle.logical_or(pos_mask, neg_mask)
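            # anchors marked ignore (-1) are excluded from both losses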

            gt_class = gt_class.reshape([-1])
            bg_class = paddle.to_tensor(
                [self.num_classes], dtype=gt_class.dtype)
            # a trick to assign num_classes to negative targets
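            # (the background class is appended to gt_class, and every negative
            # anchor's match index is redirected to that appended entry)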
            gt_class = paddle.concat([gt_class, bg_class], axis=-1)
            matches = paddle.where(neg_mask,
                                   paddle.full_like(matches, gt_class.size - 1),
                                   matches)

            cls_pred = cls_logit[chosen_mask]
            cls_tar = gt_class[matches[chosen_mask]]
            reg_pred = bbox_reg[pos_mask].reshape([-1, 4])
            reg_tar = gt_bbox[matches[pos_mask]].reshape([-1, 4])
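            # encode the matched gt boxes as regression deltas w.r.t. their
            # positive anchors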
            reg_tar = bbox2delta(anchors[pos_mask], reg_tar, self.weights)
            cls_pred_list.append(cls_pred)
            cls_tar_list.append(cls_tar)
            reg_pred_list.append(reg_pred)
            reg_tar_list.append(reg_tar)
        cls_pred = paddle.concat(cls_pred_list)
        cls_tar = paddle.concat(cls_tar_list)
        reg_pred = paddle.concat(reg_pred_list)
        reg_tar = paddle.concat(reg_tar_list)

        avg_factor = max(1.0, reg_pred.shape[0])
        cls_loss = self.loss_class(
            cls_pred, cls_tar, reduction='sum') / avg_factor

        if reg_pred.shape[0] == 0:
            reg_loss = paddle.zeros([1])
            reg_loss.stop_gradient = False
        else:
            reg_loss = self.loss_bbox(
                reg_pred, reg_tar, reduction='sum') / avg_factor

        loss = cls_loss + reg_loss
        out_dict = {
            'loss_cls': cls_loss,
            'loss_reg': reg_loss,
            'loss': loss,
        }
        return out_dict

    def get_bboxes_single(self,
                          anchors,
                          cls_scores_list,
                          bbox_preds_list,
                          im_shape,
                          scale_factor,
                          rescale=True):
        assert len(cls_scores_list) == len(bbox_preds_list)
        mlvl_bboxes = []
        mlvl_scores = []
        for anchor, cls_score, bbox_pred in zip(anchors, cls_scores_list,
                                                bbox_preds_list):
            cls_score = cls_score.reshape([-1, self.num_classes])
            bbox_pred = bbox_pred.reshape([-1, 4])
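            # keep only the `nms_pre` highest-scoring anchors of this level
            # before decoding, to bound the cost of delta2bbox and NMS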
            if self.nms_pre is not None and cls_score.shape[0] > self.nms_pre:
                max_score = cls_score.max(axis=1)
                _, topk_inds = max_score.topk(self.nms_pre)
                bbox_pred = bbox_pred.gather(topk_inds)
                anchor = anchor.gather(topk_inds)
                cls_score = cls_score.gather(topk_inds)
            bbox_pred = delta2bbox(bbox_pred, anchor, self.weights).squeeze()
            mlvl_bboxes.append(bbox_pred)
            mlvl_scores.append(F.sigmoid(cls_score))
        mlvl_bboxes = paddle.concat(mlvl_bboxes)
        mlvl_bboxes = paddle.squeeze(mlvl_bboxes)
        if rescale:
            mlvl_bboxes = mlvl_bboxes / paddle.concat(
                [scale_factor[::-1], scale_factor[::-1]])
        mlvl_scores = paddle.concat(mlvl_scores)
        mlvl_scores = mlvl_scores.transpose([1, 0])
        return mlvl_bboxes, mlvl_scores

    def decode(self, anchors, cls_logits, bboxes_reg, im_shape, scale_factor):
        batch_bboxes = []
        batch_scores = []
        for img_id in range(cls_logits[0].shape[0]):
            num_lvls = len(cls_logits)
            cls_scores_list = [cls_logits[i][img_id] for i in range(num_lvls)]
            bbox_preds_list = [bboxes_reg[i][img_id] for i in range(num_lvls)]
            bboxes, scores = self.get_bboxes_single(
                anchors, cls_scores_list, bbox_preds_list, im_shape[img_id],
                scale_factor[img_id])
            batch_bboxes.append(bboxes)
            batch_scores.append(scores)
        batch_bboxes = paddle.stack(batch_bboxes, axis=0)
        batch_scores = paddle.stack(batch_scores, axis=0)
        return batch_bboxes, batch_scores

    def post_process(self, head_outputs, im_shape, scale_factor):
        cls_logits_list, bboxes_reg_list = head_outputs
        anchors = self.anchor_generator(cls_logits_list)
        cls_logits = [_.transpose([0, 2, 3, 1]) for _ in cls_logits_list]
        bboxes_reg = [_.transpose([0, 2, 3, 1]) for _ in bboxes_reg_list]
        bboxes, scores = self.decode(anchors, cls_logits, bboxes_reg, im_shape,
                                     scale_factor)

        bbox_pred, bbox_num, _ = self.nms(bboxes, scores)
        return bbox_pred, bbox_num
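

# Minimal usage sketch (illustrative only). In PaddleDetection these modules are
# normally assembled from YAML configs via ppdet.core.workspace rather than built
# by hand; `head`, `fpn_feats`, `targets`, `im_shape` and `scale_factor` below are
# assumed to come from the enclosing architecture (e.g. the RetinaNet model):
#
#     # training: forward() returns the loss dict from get_loss()
#     loss_dict = head(fpn_feats, targets)   # {'loss_cls', 'loss_reg', 'loss'}
#
#     # inference: forward() returns raw per-level outputs, which post_process()
#     # decodes, rescales to the original image and filters with NMS
#     head_outs = head(fpn_feats)
#     bbox_pred, bbox_num = head.post_process(head_outs, im_shape, scale_factor)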