# detr.py
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import paddle
from .meta_arch import BaseArch
from ppdet.core.workspace import register, create

__all__ = ['DETR']
# Deformable DETR and DINO use the same architecture as DETR.


@register
class DETR(BaseArch):
    """Generic DETR-style detector: backbone -> transformer -> DETR head.

    The concrete backbone / transformer / head are built from the config in
    ``from_config``; the module-level note states that Deformable DETR and
    DINO use this same architecture.

    Args:
        backbone: Feature extractor, called as ``backbone(inputs)``. Its
            ``out_shape`` is forwarded to the transformer and head in
            ``from_config``.
        transformer: Transformer module (config name defaults to
            ``'DETRTransformer'``), called as
            ``transformer(body_feats, pad_mask, inputs)``.
        detr_head: Detection head (config name defaults to ``'DETRHead'``).
            In training it is expected to return a dict of losses; in eval
            it returns the raw predictions consumed by ``post_process``.
        post_process: Post-processor (config name defaults to
            ``'DETRBBoxPostProcess'``), called as
            ``post_process(preds, im_shape, scale_factor)`` and expected to
            return ``(bbox, bbox_num)``.
        exclude_post_process: If True, eval skips ``post_process`` and
            returns the raw ``(bboxes, logits)`` from the head.
    """
    __category__ = 'architecture'
    # NOTE(review): presumably these are resolved/shared by the ppdet
    # workspace that ``register`` hooks into -- confirm in ppdet.core.workspace.
    __inject__ = ['post_process']
    __shared__ = ['exclude_post_process']

    def __init__(self,
                 backbone,
                 transformer='DETRTransformer',
                 detr_head='DETRHead',
                 post_process='DETRBBoxPostProcess',
                 exclude_post_process=False):
        super(DETR, self).__init__()
        self.backbone = backbone
        self.transformer = transformer
        self.detr_head = detr_head
        self.post_process = post_process
        self.exclude_post_process = exclude_post_process

    @classmethod
    def from_config(cls, cfg, *args, **kwargs):
        """Build the sub-modules from ``cfg`` and return ``__init__`` kwargs.

        The transformer receives the backbone's ``out_shape``; the head
        additionally receives the transformer's ``hidden_dim`` and ``nhead``
        so the three components agree on feature dimensions.
        ``post_process`` is not built here (it is listed in ``__inject__``).
        Extra ``*args`` / ``**kwargs`` passed by the caller are ignored.
        """
        # backbone
        backbone = create(cfg['backbone'])
        # transformer
        transformer_kwargs = {'input_shape': backbone.out_shape}
        transformer = create(cfg['transformer'], **transformer_kwargs)
        # head
        head_kwargs = {
            'hidden_dim': transformer.hidden_dim,
            'nhead': transformer.nhead,
            'input_shape': backbone.out_shape
        }
        detr_head = create(cfg['detr_head'], **head_kwargs)

        return {
            'backbone': backbone,
            'transformer': transformer,
            'detr_head': detr_head,
        }

    def _forward(self):
        """Run backbone, transformer and head on ``self.inputs``.

        Returns:
            Training: the head's loss dict plus a ``'loss'`` entry holding the
            sum of every entry whose key does not contain ``'log'``.
            Eval with ``exclude_post_process``: ``(bboxes, logits)``.
            Eval otherwise: ``{'bbox': bbox, 'bbox_num': bbox_num}``.
        """
        # Backbone
        body_feats = self.backbone(self.inputs)

        # Transformer
        # pad_mask is optional; None is passed through when absent.
        pad_mask = self.inputs.get('pad_mask', None)
        out_transformer = self.transformer(body_feats, pad_mask, self.inputs)

        # DETR Head
        if self.training:
            detr_losses = self.detr_head(out_transformer, body_feats,
                                         self.inputs)
            # Keys containing 'log' are left out of the total; presumably
            # they are logging-only metrics -- confirm against the head.
            detr_losses.update({
                'loss': paddle.add_n(
                    [v for k, v in detr_losses.items() if 'log' not in k])
            })
            return detr_losses
        else:
            preds = self.detr_head(out_transformer, body_feats)
            if self.exclude_post_process:
                # The head's eval output is unpacked as a 3-tuple; the third
                # element (named masks originally) is not returned.
                bboxes, logits, _ = preds
                return bboxes, logits
            else:
                bbox, bbox_num = self.post_process(
                    preds, self.inputs['im_shape'], self.inputs['scale_factor'])
                output = {'bbox': bbox, 'bbox_num': bbox_num}
                return output

    def get_loss(self):
        """Return the training loss dict (``_forward`` in training mode)."""
        return self._forward()

    def get_pred(self):
        """Return eval predictions (``_forward`` in eval mode)."""
        return self._forward()