mask_rcnn.py 4.9 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. 
#   
# Licensed under the Apache License, Version 2.0 (the "License");   
# you may not use this file except in compliance with the License.  
# You may obtain a copy of the License at   
#   
#     http://www.apache.org/licenses/LICENSE-2.0    
#   
# Unless required by applicable law or agreed to in writing, software   
# distributed under the License is distributed on an "AS IS" BASIS, 
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.  
# See the License for the specific language governing permissions and   
# limitations under the License.

F
FDInSky 已提交
15 16 17 18
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

19
import paddle
F
FDInSky 已提交
20 21
from ppdet.core.workspace import register
from .meta_arch import BaseArch
22

F
FDInSky 已提交
23 24 25 26 27 28 29 30 31 32 33
__all__ = ['MaskRCNN']


@register
class MaskRCNN(BaseArch):
    """Mask R-CNN architecture assembled from injected components.

    ``model_arch`` runs backbone -> optional neck -> RPN head -> anchors ->
    proposals -> bbox head -> mask branch -> mask head, caching the
    intermediate outputs on ``self``. ``get_loss`` (training) and
    ``get_pred`` (inference) then consume those cached outputs.
    """

    __category__ = 'architecture'
    # Constructor arguments listed for injection by the ppdet workspace
    # (``register`` comes from ``ppdet.core.workspace``).
    __inject__ = [
        'anchor',
        'proposal',
        'mask',
        'backbone',
        'neck',
        'rpn_head',
        'bbox_head',
        'mask_head',
        'bbox_post_process',
        'mask_post_process',
    ]

    def __init__(self,
                 anchor,
                 proposal,
                 mask,
                 backbone,
                 rpn_head,
                 bbox_head,
                 mask_head,
                 bbox_post_process,
                 mask_post_process,
                 neck=None):
        """Keep a reference to every component on the instance.

        Args:
            anchor: anchor generator; applied to the RPN features and also
                used to assemble the RPN loss inputs.
            proposal: RoI proposal generator; also exposes ``get_targets``.
            mask: training-time mask branch; selects RoIs for the mask head
                and exposes ``get_targets`` for the mask loss.
            backbone: feature extractor applied to ``self.inputs``.
            rpn_head: region proposal head.
            bbox_head: box head; also provides ``get_prediction``.
            mask_head: mask head.
            bbox_post_process: refines bbox head output at inference time.
            mask_post_process: produces the mask part of the prediction dict.
            neck: optional neck returning ``(features, spatial_scale)``;
                skipped when ``None``.
        """
        super(MaskRCNN, self).__init__()
        # NOTE(review): keep this assignment order; if BaseArch is a paddle
        # Layer it likely also fixes the sub-layer registration order.
        self.anchor = anchor
        self.proposal = proposal
        self.mask = mask
        self.backbone = backbone
        self.neck = neck
        self.rpn_head = rpn_head
        self.bbox_head = bbox_head
        self.mask_head = mask_head
        self.bbox_post_process = bbox_post_process
        self.mask_post_process = mask_post_process

    def model_arch(self):
        """Run the forward graph and cache intermediate outputs on ``self``.

        Reads ``self.inputs`` (including its ``'mode'`` entry). Sets
        ``rpn_head_out``, ``anchor_out``, ``bbox_head_out``, ``bboxes`` and
        ``mask_head_out``.
        """
        inputs = self.inputs

        # Backbone features. The optional neck also yields the spatial scale
        # that is later handed to the bbox and mask heads.
        feats = self.backbone(inputs)
        scale = None
        if self.neck is not None:
            feats, scale = self.neck(feats)

        # RPN: ``rpn_feat`` has one entry per feature level (a single entry
        # without a neck). Each ``rpn_head_out`` entry is
        # (rpn_rois_score, rpn_rois_delta).
        rpn_feat, self.rpn_head_out = self.rpn_head(inputs, feats)

        # Anchors: a list whose entries are (anchor, anchor_var).
        self.anchor_out = self.anchor(rpn_feat)

        # Proposal RoIs; targets are also computed here when training.
        rois = self.proposal(inputs, self.rpn_head_out, self.anchor_out)

        bbox_feat, self.bbox_head_out = self.bbox_head(feats, rois, scale)

        if inputs['mode'] != 'infer':
            # Training: the mask branch updates the boxes fed to the mask
            # head, using the first proposal target.
            bbox_target = self.proposal.get_targets()[0]
            self.bboxes, has_mask_int32 = self.mask(inputs, rois, bbox_target)
        else:
            # Inference: decode boxes from the bbox head output, then refine
            # them with the image info.
            bbox_pred, decoded = self.bbox_head.get_prediction(
                self.bbox_head_out, rois)
            self.bboxes = self.bbox_post_process(bbox_pred, decoded,
                                                 inputs['im_info'])
            has_mask_int32 = None

        self.mask_head_out = self.mask_head(inputs, feats, self.bboxes,
                                            bbox_feat, has_mask_int32, scale)

    def get_loss(self):
        """Gather the RPN, bbox and mask losses and their total.

        Returns:
            dict: every entry returned by the three heads' ``get_loss``,
            plus ``'loss'``, the ``paddle.add_n`` sum of those entries.
        """
        losses = {}

        # RPN loss, from inputs assembled by the anchor generator.
        rpn_inputs = self.anchor.generate_loss_inputs(
            self.inputs, self.rpn_head_out, self.anchor_out)
        losses.update(self.rpn_head.get_loss(rpn_inputs))

        # BBox loss against the proposal targets.
        losses.update(
            self.bbox_head.get_loss(self.bbox_head_out,
                                    self.proposal.get_targets()))

        # Mask loss against the mask-branch targets.
        losses.update(
            self.mask_head.get_loss(self.mask_head_out,
                                    self.mask.get_targets()))

        # The right-hand side is evaluated before 'loss' is inserted, so the
        # total covers only the per-head entries.
        losses['loss'] = paddle.add_n(list(losses.values()))
        return losses

    def get_pred(self):
        """Build the inference output.

        Returns:
            dict: ``'bbox'``, ``'bbox_num'`` and ``'im_id'`` converted with
            ``.numpy()``, then updated with the ``mask_post_process`` result.
        """
        im_info = self.inputs['im_info']
        mask_out = self.mask_post_process(self.bboxes, self.mask_head_out,
                                          im_info)

        bbox, bbox_num = self.bboxes
        result = {
            'bbox': bbox.numpy(),
            'bbox_num': bbox_num.numpy(),
            'im_id': self.inputs['im_id'].numpy(),
        }
        result.update(mask_out)
        return result