# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. 
#   
# Licensed under the Apache License, Version 2.0 (the "License");   
# you may not use this file except in compliance with the License.  
# You may obtain a copy of the License at   
#   
#     http://www.apache.org/licenses/LICENSE-2.0    
#   
# Unless required by applicable law or agreed to in writing, software   
# distributed under the License is distributed on an "AS IS" BASIS, 
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.  
# See the License for the specific language governing permissions and   
# limitations under the License.

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import copy

import paddle
import paddle.nn.functional as F
from ppdet.core.workspace import register, create
from .meta_arch import BaseArch
from ..ssod_utils import QFLv2
from ..losses import GIoULoss

# Public architectures exported by this module.
__all__ = [
    'PPYOLOE',
    'PPYOLOEWithAuxHead',
]

# PP-YOLOE and PP-YOLOE+ are recommended to use this architecture, especially
# when using distillation or an aux head.
# PP-YOLOE and PP-YOLOE+ can also use the same architecture as YOLOv3 in
# yolo.py when not using distillation or an aux head.


@register
class PPYOLOE(BaseArch):
    """PP-YOLOE / PP-YOLOE+ architecture: backbone -> neck -> yolo_head.

    Besides plain detection, it can expose intermediate features for
    feature distillation (`for_distill`). It also provides the
    semi-supervised distillation losses used by semi-det training
    (`get_ssod_distill_loss`).
    """
    __category__ = 'architecture'
    __shared__ = ['for_distill']
    __inject__ = ['post_process']

    def __init__(self,
                 backbone='CSPResNet',
                 neck='CustomCSPPAN',
                 yolo_head='PPYOLOEHead',
                 post_process='BBoxPostProcess',
                 for_distill=False,
                 feat_distill_place='neck_feats',
                 for_mot=False):
        """
        PPYOLOE network, see https://arxiv.org/abs/2203.16250

        Args:
            backbone (nn.Layer): backbone instance
            neck (nn.Layer): neck instance
            yolo_head (nn.Layer): anchor_head instance
            post_process (object): `BBoxPostProcess` instance
            for_distill (bool): whether to store intermediate features in
                `yolo_head.distill_pairs` during training, for feature
                distillation. default False.
            feat_distill_place (str): which features to store when
                `for_distill` is True, one of 'backbone_feats' or
                'neck_feats'. default 'neck_feats'.
            for_mot (bool): whether return other features for multi-object tracking
                models, default False in pure object detection models.
        """
        super(PPYOLOE, self).__init__()
        self.backbone = backbone
        self.neck = neck
        self.yolo_head = yolo_head
        self.post_process = post_process
        self.for_mot = for_mot

        # semi-det
        self.is_teacher = False

        # distill
        self.for_distill = for_distill
        self.feat_distill_place = feat_distill_place
        if for_distill:
            assert feat_distill_place in ['backbone_feats', 'neck_feats']

    @classmethod
    def from_config(cls, cfg, *args, **kwargs):
        # backbone
        backbone = create(cfg['backbone'])

        # fpn: built from the backbone's output shapes
        kwargs = {'input_shape': backbone.out_shape}
        neck = create(cfg['neck'], **kwargs)

        # head: built from the neck's output shapes
        kwargs = {'input_shape': neck.out_shape}
        yolo_head = create(cfg['yolo_head'], **kwargs)

        return {
            'backbone': backbone,
            'neck': neck,
            "yolo_head": yolo_head,
        }

    def _forward(self):
        """Run the network.

        Returns the head's losses in training mode, or when the inputs
        mark this model as a semi-det teacher. Otherwise it returns a dict
        with 'bbox', 'bbox_num' and 'cam_data'.
        """
        body_feats = self.backbone(self.inputs)
        neck_feats = self.neck(body_feats, self.for_mot)

        self.is_teacher = self.inputs.get('is_teacher', False)  # for semi-det
        if self.training or self.is_teacher:
            yolo_losses = self.yolo_head(neck_feats, self.inputs)

            if self.for_distill:
                if self.feat_distill_place == 'backbone_feats':
                    self.yolo_head.distill_pairs['backbone_feats'] = body_feats
                elif self.feat_distill_place == 'neck_feats':
                    self.yolo_head.distill_pairs['neck_feats'] = neck_feats
                else:
                    raise ValueError(
                        "feat_distill_place must be 'backbone_feats' or "
                        "'neck_feats', got {!r}".format(
                            self.feat_distill_place))
            return yolo_losses
        else:
            cam_data = {}  # record bbox scores and index before nms
            yolo_head_outs = self.yolo_head(neck_feats)
            cam_data['scores'] = yolo_head_outs[0]

            if self.post_process is not None:
                bbox, bbox_num, before_nms_indexes = self.post_process(
                    yolo_head_outs, self.yolo_head.mask_anchors,
                    self.inputs['im_shape'], self.inputs['scale_factor'])
            else:
                bbox, bbox_num, before_nms_indexes = self.yolo_head.post_process(
                    yolo_head_outs, self.inputs['scale_factor'])
            # data for cam
            cam_data['before_nms_indexes'] = before_nms_indexes
            output = {'bbox': bbox, 'bbox_num': bbox_num, 'cam_data': cam_data}

            return output

    def get_loss(self):
        return self._forward()

    def get_pred(self):
        return self._forward()

    def get_loss_keys(self):
        return ['loss_cls', 'loss_iou', 'loss_dfl', 'loss_contrast']

    def get_ssod_distill_loss(self, student_head_outs, teacher_head_outs,
                              train_cfg):
        """Compute the semi-supervised (semi-det) distillation losses.

        The top `ratio` fraction of positions is selected by the teacher's
        max class probability. On that region the student is distilled with
        a QFLv2 classification loss, a GIoU box loss, a DFL soft
        cross-entropy and a contrastive loss. The contrastive loss uses a
        pseudo-label graph, optionally smoothed with a memory bank of past
        teacher probabilities.

        Args:
            student_head_outs (tuple): (probs, deltas, dfl) from the student
                head. probs are already sigmoid-activated and unpack as
                (bs, l, nc).
            teacher_head_outs (tuple): same layout, from the teacher head.
            train_cfg (dict): must contain 'curr_iter', 'st_iter' and
                'contrast_loss' (a dict with optional 'temperature',
                'alpha' and 'smooth_iter'). It may contain 'ratio'
                (default 0.01).

        Returns:
            dict: 'distill_loss_cls', 'distill_loss_iou',
                'distill_loss_dfl', 'distill_loss_contrast' and 'fg_sum'.
        """
        # for semi-det distill
        # student_probs: already sigmoid
        student_probs, student_deltas, student_dfl = student_head_outs
        teacher_probs, teacher_deltas, teacher_dfl = teacher_head_outs
        bs, l, nc = student_probs.shape[:]
        reg_channels = self.yolo_head.reg_channels
        student_probs = student_probs.reshape([-1, nc])
        teacher_probs = teacher_probs.reshape([-1, nc])
        student_deltas = student_deltas.reshape([-1, 4])
        teacher_deltas = teacher_deltas.reshape([-1, 4])
        student_dfl = student_dfl.reshape([-1, 4, reg_channels])
        teacher_dfl = teacher_dfl.reshape([-1, 4, reg_channels])

        # fraction of positions kept as the distillation region
        ratio = train_cfg.get('ratio', 0.01)

        # for contrast loss
        curr_iter = train_cfg['curr_iter']
        st_iter = train_cfg['st_iter']
        if curr_iter == st_iter + 1 or not hasattr(self, 'queue_probs'):
            # Start of semi-det training: (re)build the memory bank. It is
            # also built when missing (e.g. training resumed after st_iter),
            # because the queue accesses below would otherwise raise
            # AttributeError.
            self.queue_ptr = 0
            self.queue_size = int(bs * l * ratio)
            self.queue_feats = paddle.zeros([self.queue_size, nc])
            self.queue_probs = paddle.zeros([self.queue_size, nc])
        contrast_loss_cfg = train_cfg['contrast_loss']
        temperature = contrast_loss_cfg.get('temperature', 0.2)
        alpha = contrast_loss_cfg.get('alpha', 0.9)
        smooth_iter = contrast_loss_cfg.get('smooth_iter', 100) + st_iter

        with paddle.no_grad():
            # Region Selection: mark the count_num positions with the highest
            # teacher max-class probability
            count_num = int(teacher_probs.shape[0] * ratio)
            max_vals = paddle.max(teacher_probs, 1)
            sorted_vals, sorted_inds = paddle.topk(max_vals,
                                                   teacher_probs.shape[0])
            mask = paddle.zeros_like(max_vals)
            mask[sorted_inds[:count_num]] = 1.
            fg_num = sorted_vals[:count_num].sum()
            b_mask = mask > 0.

            # pseudo-label probabilities for the contrastive graph
            probs = teacher_probs[b_mask].detach()
            if curr_iter > smooth_iter:  # memory-smoothing
                A = paddle.exp(
                    paddle.mm(teacher_probs[b_mask], self.queue_probs.t()) /
                    temperature)
                A = A / A.sum(1, keepdim=True)
                probs = alpha * probs + (1 - alpha) * paddle.mm(
                    A, self.queue_probs)
            n = student_probs[b_mask].shape[0]
            # update memory bank
            # NOTE(review): assumes queue_ptr + n <= queue_size. This holds
            # while bs * l (hence count_num) stays constant across iterations.
            self.queue_feats[self.queue_ptr:self.queue_ptr +
                             n, :] = teacher_probs[b_mask].detach()
            self.queue_probs[self.queue_ptr:self.queue_ptr +
                             n, :] = teacher_probs[b_mask].detach()
            self.queue_ptr = (self.queue_ptr + n) % self.queue_size

        # Embedding similarity. It uses the configured temperature, matching
        # the memory-smoothing step above; the default 0.2 equals the former
        # constant.
        sim = paddle.exp(
            paddle.mm(student_probs[b_mask], teacher_probs[b_mask].t()) /
            temperature)
        sim_probs = sim / sim.sum(1, keepdim=True)
        # pseudo-label graph with self-loop
        Q = paddle.mm(probs, probs.t())
        Q.fill_diagonal_(1)
        pos_mask = (Q >= 0.5).astype('float32')
        Q = Q * pos_mask
        Q = Q / Q.sum(1, keepdim=True)
        # contrastive loss
        loss_contrast = -(paddle.log(sim_probs + 1e-7) * Q).sum(1)
        loss_contrast = loss_contrast.mean()

        # distill_loss_cls
        loss_cls = QFLv2(
            student_probs, teacher_probs, weight=mask, reduction="sum") / fg_num

        # distill_loss_iou: first two delta components are negated before
        # GIoU, as in the original formulation
        inputs = paddle.concat(
            (-student_deltas[b_mask][..., :2], student_deltas[b_mask][..., 2:]),
            -1)
        targets = paddle.concat(
            (-teacher_deltas[b_mask][..., :2], teacher_deltas[b_mask][..., 2:]),
            -1)
        iou_loss = GIoULoss(reduction='mean')
        loss_iou = iou_loss(inputs, targets)

        # distill_loss_dfl
        loss_dfl = F.cross_entropy(
            student_dfl[b_mask].reshape([-1, reg_channels]),
            teacher_dfl[b_mask].reshape([-1, reg_channels]),
            soft_label=True,
            reduction='mean')

        return {
            "distill_loss_cls": loss_cls,
            "distill_loss_iou": loss_iou,
            "distill_loss_dfl": loss_dfl,
            "distill_loss_contrast": loss_contrast,
            "fg_sum": fg_num,
        }


@register
class PPYOLOEWithAuxHead(BaseArch):
    __category__ = 'architecture'
    __inject__ = ['post_process']

    def __init__(self,
                 backbone='CSPResNet',
                 neck='CustomCSPPAN',
                 yolo_head='PPYOLOEHead',
                 aux_head='SimpleConvHead',
                 post_process='BBoxPostProcess',
                 for_mot=False,
                 detach_epoch=5):
        """
        PPYOLOE network with an auxiliary head, see https://arxiv.org/abs/2203.16250

        Args:
            backbone (nn.Layer): backbone instance
            neck (nn.Layer): neck instance; a deep copy of it is used as
                the auxiliary neck
            yolo_head (nn.Layer): anchor_head instance
            aux_head (nn.Layer): auxiliary head instance. During training it
                receives the channel-wise concatenation of neck and aux-neck
                features, and its predictions are passed to `yolo_head` as
                `aux_pred`.
            post_process (object): `BBoxPostProcess` instance
            for_mot (bool): whether return other features for multi-object tracking
                models, default False in pure object detection models.
            detach_epoch (int): from this epoch on, the aux branch receives
                detached backbone and neck features. Its gradients then no
                longer reach the backbone or the main neck. default 5.
        """
        super(PPYOLOEWithAuxHead, self).__init__()
        self.backbone = backbone
        self.neck = neck
        # independent copy of the neck, trained by the aux branch only
        self.aux_neck = copy.deepcopy(self.neck)

        self.yolo_head = yolo_head
        self.aux_head = aux_head
        self.post_process = post_process
        self.for_mot = for_mot
        self.detach_epoch = detach_epoch

    @classmethod
    def from_config(cls, cfg, *args, **kwargs):
        # backbone
        backbone = create(cfg['backbone'])

        # fpn (the aux neck is deep-copied from it in __init__)
        kwargs = {'input_shape': backbone.out_shape}
        neck = create(cfg['neck'], **kwargs)

        # head: both heads are built from the neck's output shapes
        kwargs = {'input_shape': neck.out_shape}
        yolo_head = create(cfg['yolo_head'], **kwargs)
        aux_head = create(cfg['aux_head'], **kwargs)

        return {
            'backbone': backbone,
            'neck': neck,
            "yolo_head": yolo_head,
            'aux_head': aux_head,
        }

    def _forward(self):
        """Run the network.

        In training mode it returns the head's losses, computed with the aux
        head's predictions. Otherwise it returns a dict with 'bbox',
        'bbox_num' and 'cam_data'.
        """
        body_feats = self.backbone(self.inputs)
        neck_feats = self.neck(body_feats, self.for_mot)

        if self.training:
            if self.inputs['epoch_id'] >= self.detach_epoch:
                # stop aux-branch gradients reaching backbone and main neck
                aux_neck_feats = self.aux_neck([f.detach() for f in body_feats])
                main_feats = [f.detach() for f in neck_feats]
            else:
                aux_neck_feats = self.aux_neck(body_feats)
                main_feats = neck_feats
            # materialized as a list so aux_head may index or re-iterate it
            dual_neck_feats = [
                paddle.concat([f, aux_f], axis=1)
                for f, aux_f in zip(main_feats, aux_neck_feats)
            ]
            aux_cls_scores, aux_bbox_preds = self.aux_head(dual_neck_feats)
            loss = self.yolo_head(
                neck_feats,
                self.inputs,
                aux_pred=[aux_cls_scores, aux_bbox_preds])
            return loss
        else:
            cam_data = {}  # record bbox scores and index before nms
            yolo_head_outs = self.yolo_head(neck_feats)
            cam_data['scores'] = yolo_head_outs[0]

            if self.post_process is not None:
                bbox, bbox_num, before_nms_indexes = self.post_process(
                    yolo_head_outs, self.yolo_head.mask_anchors,
                    self.inputs['im_shape'], self.inputs['scale_factor'])
            else:
                bbox, bbox_num, before_nms_indexes = self.yolo_head.post_process(
                    yolo_head_outs, self.inputs['scale_factor'])
            # data for cam
            cam_data['before_nms_indexes'] = before_nms_indexes
            output = {'bbox': bbox, 'bbox_num': bbox_num, 'cam_data': cam_data}

            return output

    def get_loss(self):
        return self._forward()

    def get_pred(self):
        return self._forward()