cascade_rcnn.py 5.5 KB
Newer Older
W
wangguanzhong 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. 
#   
# Licensed under the Apache License, Version 2.0 (the "License");   
# you may not use this file except in compliance with the License.  
# You may obtain a copy of the License at   
#   
#     http://www.apache.org/licenses/LICENSE-2.0    
#   
# Unless required by applicable law or agreed to in writing, software   
# distributed under the License is distributed on an "AS IS" BASIS, 
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.  
# See the License for the specific language governing permissions and   
# limitations under the License.

15 16 17 18
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

W
wangguanzhong 已提交
19
import paddle
20 21 22 23 24 25 26 27 28
from ppdet.core.workspace import register
from .meta_arch import BaseArch

__all__ = ['CascadeRCNN']


@register
class CascadeRCNN(BaseArch):
    """Cascade R-CNN architecture with an optional mask branch.

    The RPN produces proposals which are then processed by ``roi_stages``
    successive RoI stages: at every stage ``proposal`` receives the RoIs and
    box-head output of the previous stage, and ``bbox_head`` is called with
    the stage index. When ``mask`` is provided, a mask head is additionally
    run on the RoIs of the last stage.

    Args:
        anchor: anchor generator; also builds the RPN loss inputs.
        proposal: proposal / target generator, called once per cascade stage.
        backbone: feature extractor applied to ``self.inputs``.
        rpn_head: region proposal network head.
        bbox_head: cascade box head; also combines the per-stage outputs at
            inference time via ``get_cascade_prediction``.
        bbox_post_process: post-processor whose result is stored in
            ``self.bboxes`` in ``'infer'`` mode.
        neck: optional feature neck; when present it also returns the
            spatial scale(s) forwarded to the RoI heads.
        mask: optional mask target generator; its presence enables the mask
            branch.
        mask_head: mask head, used only when ``mask`` is set.
        mask_post_process: mask post-processor (stored, not used by this
            class directly).
        roi_stages: number of cascade stages (shared config key).
    """
    __category__ = 'architecture'
    __shared__ = ['roi_stages']
    __inject__ = [
        'anchor',
        'proposal',
        'mask',
        'backbone',
        'neck',
        'rpn_head',
        'bbox_head',
        'mask_head',
        'bbox_post_process',
        'mask_post_process',
    ]

    def __init__(self,
                 anchor,
                 proposal,
                 backbone,
                 rpn_head,
                 bbox_head,
                 bbox_post_process,
                 neck=None,
                 mask=None,
                 mask_head=None,
                 mask_post_process=None,
                 roi_stages=3):
        super(CascadeRCNN, self).__init__()
        self.anchor = anchor
        self.proposal = proposal
        self.backbone = backbone
        self.rpn_head = rpn_head
        self.bbox_head = bbox_head
        self.bbox_post_process = bbox_post_process
        self.neck = neck
        self.mask = mask
        self.mask_head = mask_head
        self.mask_post_process = mask_post_process
        self.roi_stages = roi_stages
        self.with_mask = mask is not None

    def model_arch(self, ):
        """Run the forward pass and cache its outputs on ``self``.

        Always sets ``rpn_head_out``, ``anchor_out`` and ``bbox_head_list``.
        In ``'infer'`` mode it also sets ``bboxes``. When the mask branch is
        enabled it sets ``mask_head_out``, and in ``'train'`` mode it
        overwrites ``bboxes`` with the first output of ``self.mask``.
        """
        # Backbone
        body_feats = self.backbone(self.inputs)

        # Neck. Initialize spatial_scale so the no-neck configuration (the
        # declared default, neck=None) no longer fails with UnboundLocalError.
        # NOTE(review): with no neck the RoI heads receive spatial_scale=None;
        # confirm bbox_head / mask_head handle that case.
        spatial_scale = None
        if self.neck is not None:
            body_feats, spatial_scale = self.neck(body_feats)

        # RPN
        # rpn_head returns two lists: rpn_feat and rpn_head_out.
        # Each element of rpn_feat holds the RPN feature of one level; its
        # length is 1 when no neck is applied.
        # Each element of rpn_head_out is (rpn_rois_score, rpn_rois_delta).
        rpn_feat, self.rpn_head_out = self.rpn_head(self.inputs, body_feats)

        # Anchor
        # anchor_out is a list; each element is (anchor, anchor_var).
        self.anchor_out = self.anchor(rpn_feat)

        # Cascade RoI stages. Each stage's proposals are built from the
        # previous stage's RoIs and box-head output (both None for stage 0);
        # targets are computed inside `proposal` when training.
        rois = None
        bbox_head_out = None
        max_overlap = None
        self.bbox_head_list = []
        rois_list = []
        for stage in range(self.roi_stages):
            # Proposal BBox
            rois = self.proposal(
                self.inputs,
                self.rpn_head_out,
                self.anchor_out,
                stage,
                rois,
                bbox_head_out,
                max_overlap=max_overlap)
            rois_list.append(rois)
            max_overlap = self.proposal.get_max_overlap()

            # BBox Head; bbox_feat of the last stage feeds the mask head.
            bbox_feat, bbox_head_out, _ = self.bbox_head(body_feats, rois,
                                                         spatial_scale, stage)
            self.bbox_head_list.append(bbox_head_out)

        mode = self.inputs['mode']
        if mode == 'infer':
            bbox_pred, bboxes = self.bbox_head.get_cascade_prediction(
                self.bbox_head_list, rois_list)
            self.bboxes = self.bbox_post_process(
                bbox_pred,
                bboxes,
                self.inputs['im_shape'],
                self.inputs['scale_factor'],
                var_weight=3.)

        if self.with_mask:
            rois_has_mask_int32 = None
            if mode == 'train':
                # Use the targets of the last cascade stage only.
                bbox_targets = self.proposal.get_targets()[-1]
                self.bboxes, rois_has_mask_int32 = self.mask(
                    self.inputs, rois_list[-1], bbox_targets)
            # Mask Head
            # NOTE(review): in modes other than 'train' and 'infer',
            # self.bboxes is not assigned by this method; confirm such modes
            # never run with the mask branch enabled.
            self.mask_head_out = self.mask_head(
                self.inputs, body_feats, self.bboxes, bbox_feat,
                rois_has_mask_int32, spatial_scale)

    def get_loss(self, ):
        """Compute the training losses from the outputs of ``model_arch``.

        Returns:
            dict: the individual RPN, cascade bbox and (when enabled) mask
            loss terms, plus ``'loss'``, their sum computed with
            ``paddle.add_n``.
        """
        loss = {}

        # RPN loss
        rpn_loss_inputs = self.anchor.generate_loss_inputs(
            self.inputs, self.rpn_head_out, self.anchor_out)
        loss_rpn = self.rpn_head.get_loss(rpn_loss_inputs)
        loss.update(loss_rpn)

        # BBox loss over all cascade stages
        bbox_targets_list = self.proposal.get_targets()
        loss_bbox = self.bbox_head.get_loss(self.bbox_head_list,
                                            bbox_targets_list)
        loss.update(loss_bbox)

        if self.with_mask:
            # Mask loss
            mask_targets = self.mask.get_targets()
            loss_mask = self.mask_head.get_loss(self.mask_head_out,
                                                mask_targets)
            loss.update(loss_mask)

        total_loss = paddle.add_n(list(loss.values()))
        loss.update({'loss': total_loss})
        return loss

    def get_pred(self):
        """Return the predictions produced by ``model_arch``.

        Returns:
            dict: ``'bbox'`` and ``'bbox_num'`` unpacked from
            ``self.bboxes``, merged with ``self.mask_head_out`` when the mask
            branch is enabled.
        """
        bbox, bbox_num = self.bboxes
        output = {
            'bbox': bbox,
            'bbox_num': bbox_num,
        }
        if self.with_mask:
            output.update(self.mask_head_out)
        return output