ppyoloe.py 8.9 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. 
#   
# Licensed under the Apache License, Version 2.0 (the "License");   
# you may not use this file except in compliance with the License.  
# You may obtain a copy of the License at   
#   
#     http://www.apache.org/licenses/LICENSE-2.0    
#   
# Unless required by applicable law or agreed to in writing, software   
# distributed under the License is distributed on an "AS IS" BASIS, 
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.  
# See the License for the specific language governing permissions and   
# limitations under the License.

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

19
import copy
20
import paddle
21 22 23
from ppdet.core.workspace import register, create
from .meta_arch import BaseArch

__all__ = ['PPYOLOE', 'PPYOLOEWithAuxHead']

# PP-YOLOE and PP-YOLOE+ are recommended to use this architecture, especially
# when using distillation or an auxiliary head.
# PP-YOLOE and PP-YOLOE+ can also use the same architecture as YOLOv3 in
# yolo.py when not using distillation or an auxiliary head.


@register
class PPYOLOE(BaseArch):
    """
    PPYOLOE network, see https://arxiv.org/abs/2203.16250

    Args:
        backbone (nn.Layer): backbone instance
        neck (nn.Layer): neck instance
        yolo_head (nn.Layer): detection head instance
        post_process (object): `BBoxPostProcess` instance; if None, the
            head's own `post_process` method is used at inference time
        ssod_loss (object): `SSODPPYOLOELoss` instance, only used for
            semi-supervised detection (ssod)
        for_distill (bool): whether the model is used for distillation
        feat_distill_place (str): which features to expose for
            distillation, one of 'backbone_feats' or 'neck_feats'
        for_mot (bool): whether to return other features for multi-object
            tracking models, default False in pure object detection models.
    """

    __category__ = 'architecture'
    __shared__ = ['for_distill']
    __inject__ = ['post_process', 'ssod_loss']

    def __init__(self,
                 backbone='CSPResNet',
                 neck='CustomCSPPAN',
                 yolo_head='PPYOLOEHead',
                 post_process='BBoxPostProcess',
                 ssod_loss='SSODPPYOLOELoss',
                 for_distill=False,
                 feat_distill_place='neck_feats',
                 for_mot=False):
        super(PPYOLOE, self).__init__()
        self.backbone = backbone
        self.neck = neck
        self.yolo_head = yolo_head
        self.post_process = post_process
        self.for_mot = for_mot

        # for ssod (semi-det); refreshed on every forward pass from
        # inputs['is_teacher']
        self.is_teacher = False
        self.ssod_loss = ssod_loss

        # distillation
        self.for_distill = for_distill
        self.feat_distill_place = feat_distill_place
        if for_distill:
            assert feat_distill_place in ['backbone_feats', 'neck_feats'], (
                "feat_distill_place must be 'backbone_feats' or "
                "'neck_feats', got {!r}".format(feat_distill_place))

    @classmethod
    def from_config(cls, cfg, *args, **kwargs):
        """Build backbone, neck and head, feeding each stage the previous
        stage's ``out_shape`` as its ``input_shape``."""
        backbone = create(cfg['backbone'])

        kwargs = {'input_shape': backbone.out_shape}
        neck = create(cfg['neck'], **kwargs)

        kwargs = {'input_shape': neck.out_shape}
        yolo_head = create(cfg['yolo_head'], **kwargs)

        return {
            'backbone': backbone,
            'neck': neck,
            "yolo_head": yolo_head,
        }

    def _forward(self):
        """Run the network on ``self.inputs``.

        Returns:
            In training mode, or when ``inputs['is_teacher']`` is truthy
            (semi-det teacher), whatever ``yolo_head(neck_feats, inputs)``
            returns (the losses). Otherwise a dict with keys 'bbox',
            'bbox_num' and 'cam_data'.

        Raises:
            ValueError: if distillation is enabled with an unsupported
                ``feat_distill_place``.
        """
        body_feats = self.backbone(self.inputs)
        neck_feats = self.neck(body_feats, self.for_mot)

        self.is_teacher = self.inputs.get('is_teacher', False)  # for semi-det
        if self.training or self.is_teacher:
            yolo_losses = self.yolo_head(neck_feats, self.inputs)

            if self.for_distill:
                # store the selected features on the head; presumably read
                # later by the distillation loss
                if self.feat_distill_place == 'backbone_feats':
                    self.yolo_head.distill_pairs['backbone_feats'] = body_feats
                elif self.feat_distill_place == 'neck_feats':
                    self.yolo_head.distill_pairs['neck_feats'] = neck_feats
                else:
                    raise ValueError(
                        "Unsupported feat_distill_place: {!r}".format(
                            self.feat_distill_place))
            return yolo_losses

        yolo_head_outs = self.yolo_head(neck_feats)
        # record bbox scores and indexes before nms (data for cam)
        cam_data = {'scores': yolo_head_outs[0]}
        if self.post_process is not None:
            bbox, bbox_num, before_nms_indexes = self.post_process(
                yolo_head_outs, self.yolo_head.mask_anchors,
                self.inputs['im_shape'], self.inputs['scale_factor'])
        else:
            bbox, bbox_num, before_nms_indexes = self.yolo_head.post_process(
                yolo_head_outs, self.inputs['scale_factor'])
        cam_data['before_nms_indexes'] = before_nms_indexes
        return {'bbox': bbox, 'bbox_num': bbox_num, 'cam_data': cam_data}

    def get_loss(self):
        """Return the training losses (forward pass in training mode)."""
        return self._forward()

    def get_pred(self):
        """Return the prediction dict (forward pass in eval mode)."""
        return self._forward()

    def get_loss_keys(self):
        """Return the names of the loss terms of this architecture."""
        return ['loss_cls', 'loss_iou', 'loss_dfl', 'loss_contrast']

    def get_ssod_loss(self, student_head_outs, teacher_head_outs, train_cfg):
        """Compute the semi-supervised loss between student and teacher head
        outputs using the injected ``ssod_loss`` module."""
        ssod_losses = self.ssod_loss(student_head_outs, teacher_head_outs,
                                     train_cfg)
        return ssod_losses


@register
class PPYOLOEWithAuxHead(BaseArch):
    """PPYOLOE network trained with an additional auxiliary neck and head.

    The auxiliary branch only runs in training mode; inference uses the main
    neck and head exactly like `PPYOLOE`.
    """
    __category__ = 'architecture'
    __inject__ = ['post_process']

    def __init__(self,
                 backbone='CSPResNet',
                 neck='CustomCSPPAN',
                 yolo_head='PPYOLOEHead',
                 aux_head='SimpleConvHead',
                 post_process='BBoxPostProcess',
                 for_mot=False,
                 detach_epoch=5):
        """
        PPYOLOE network, see https://arxiv.org/abs/2203.16250

        Args:
            backbone (nn.Layer): backbone instance
            neck (nn.Layer): neck instance; a deep copy of it is used as the
                auxiliary neck
            yolo_head (nn.Layer): main detection head instance
            aux_head (nn.Layer): auxiliary head instance, only used in
                training on the concatenation of main and auxiliary neck
                features
            post_process (object): `BBoxPostProcess` instance; if None, the
                head's own `post_process` method is used at inference time
            for_mot (bool): whether to return other features for multi-object
                tracking models, default False in pure object detection
                models.
            detach_epoch (int): once ``inputs['epoch_id'] >= detach_epoch``,
                the inputs of the auxiliary branch are detached, so its
                gradients no longer reach the backbone or the main neck.
        """
        super(PPYOLOEWithAuxHead, self).__init__()
        self.backbone = backbone
        self.neck = neck
        # independent copy of the neck (separate parameters) for the aux branch
        self.aux_neck = copy.deepcopy(self.neck)

        self.yolo_head = yolo_head
        self.aux_head = aux_head
        self.post_process = post_process
        self.for_mot = for_mot
        self.detach_epoch = detach_epoch

    @classmethod
    def from_config(cls, cfg, *args, **kwargs):
        """Build backbone, neck, main head and aux head, feeding each stage
        the previous stage's ``out_shape`` as its ``input_shape``.

        The auxiliary neck is not built here: ``__init__`` derives it by
        deep-copying ``neck``.
        """
        # backbone
        backbone = create(cfg['backbone'])

        # fpn
        kwargs = {'input_shape': backbone.out_shape}
        neck = create(cfg['neck'], **kwargs)

        # head
        kwargs = {'input_shape': neck.out_shape}
        yolo_head = create(cfg['yolo_head'], **kwargs)
        aux_head = create(cfg['aux_head'], **kwargs)

        return {
            'backbone': backbone,
            'neck': neck,
            "yolo_head": yolo_head,
            'aux_head': aux_head,
        }

    def _forward(self):
        """Run the network on ``self.inputs``.

        Returns:
            In training mode, whatever ``yolo_head`` returns when given the
            auxiliary predictions via ``aux_pred`` (the losses). Otherwise
            a dict with keys 'bbox', 'bbox_num' and 'cam_data'.
        """
        body_feats = self.backbone(self.inputs)
        neck_feats = self.neck(body_feats, self.for_mot)

        if self.training:
            if self.inputs['epoch_id'] >= self.detach_epoch:
                # stop aux-branch gradients from flowing into the backbone
                # and the main neck
                aux_neck_feats = self.aux_neck(
                    [f.detach() for f in body_feats])
                main_feats = [f.detach() for f in neck_feats]
            else:
                aux_neck_feats = self.aux_neck(body_feats)
                main_feats = neck_feats
            # a list (not a one-shot generator) so the aux head may safely
            # iterate or index it
            dual_neck_feats = [
                paddle.concat([f, aux_f], axis=1)
                for f, aux_f in zip(main_feats, aux_neck_feats)
            ]
            aux_cls_scores, aux_bbox_preds = self.aux_head(dual_neck_feats)
            loss = self.yolo_head(
                neck_feats,
                self.inputs,
                aux_pred=[aux_cls_scores, aux_bbox_preds])
            return loss

        yolo_head_outs = self.yolo_head(neck_feats)
        # record bbox scores and indexes before nms (data for cam)
        cam_data = {'scores': yolo_head_outs[0]}
        if self.post_process is not None:
            bbox, bbox_num, before_nms_indexes = self.post_process(
                yolo_head_outs, self.yolo_head.mask_anchors,
                self.inputs['im_shape'], self.inputs['scale_factor'])
        else:
            bbox, bbox_num, before_nms_indexes = self.yolo_head.post_process(
                yolo_head_outs, self.inputs['scale_factor'])
        cam_data['before_nms_indexes'] = before_nms_indexes
        return {'bbox': bbox, 'bbox_num': bbox_num, 'cam_data': cam_data}

    def get_loss(self):
        """Return the training losses (forward pass in training mode)."""
        return self._forward()

    def get_pred(self):
        """Return the prediction dict (forward pass in eval mode)."""
        return self._forward()