# mask_rcnn.py
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from paddle import fluid
from ppdet.core.workspace import register
from .meta_arch import BaseArch
# Public API of this module.
__all__ = ['MaskRCNN']


@register
class MaskRCNN(BaseArch):
    """Mask R-CNN architecture assembled from injected sub-modules.

    Every stage listed in ``__inject__`` is supplied through the ``ppdet``
    workspace; this class only wires them together:

        backbone -> (optional) neck -> rpn_head -> anchor -> proposal
        -> bbox_head -> bbox_post_process ('infer' mode) / mask (otherwise)
        -> mask_head

    ``model_arch`` runs the forward pass and caches intermediate results on
    ``self``. ``get_loss`` and ``get_pred`` read those cached results, so
    ``model_arch`` must run first.
    """
    __category__ = 'architecture'
    __inject__ = [
        'anchor',
        'proposal',
        'mask',
        'backbone',
        'neck',
        'rpn_head',
        'bbox_head',
        'mask_head',
        'bbox_post_process',
        'mask_post_process',
    ]

    def __init__(self,
                 anchor,
                 proposal,
                 mask,
                 backbone,
                 rpn_head,
                 bbox_head,
                 mask_head,
                 bbox_post_process,
                 mask_post_process,
                 neck=None):
        """Store the injected sub-modules.

        Args:
            anchor: builds anchors from RPN features and the RPN loss inputs
                (``generate_loss_inputs``).
            proposal: turns RPN outputs into RoIs; exposes ``get_targets``.
            mask: produces the mask-branch boxes and ``rois_has_mask_int32``
                outside 'infer' mode; exposes ``get_targets``.
            backbone: feature extractor applied to ``self.inputs``.
            rpn_head: RPN head; provides ``get_loss``.
            bbox_head: box head; provides ``get_prediction`` and ``get_loss``.
            mask_head: mask head; provides ``get_loss``.
            bbox_post_process: refines predicted boxes in 'infer' mode.
            mask_post_process: produces the mask entries merged into the
                prediction dict returned by ``get_pred``.
            neck: optional feature neck. When ``None``, backbone features are
                used directly and ``spatial_scale`` is passed on as ``None``.
        """
        super(MaskRCNN, self).__init__()
        self.anchor = anchor
        self.proposal = proposal
        self.mask = mask
        self.backbone = backbone
        self.neck = neck
        self.rpn_head = rpn_head
        self.bbox_head = bbox_head
        self.mask_head = mask_head
        self.bbox_post_process = bbox_post_process
        self.mask_post_process = mask_post_process

    def model_arch(self):
        """Run the Mask R-CNN forward pass on ``self.inputs``.

        Nothing is returned. Results are cached on ``self`` for
        ``get_loss``/``get_pred``: ``rpn_head_out``, ``anchor_out``,
        ``bbox_head_out``, ``bboxes`` and ``mask_head_out``.

        ``self.inputs`` is indexed with the keys ``'mode'`` and, in 'infer'
        mode, ``'im_info'``.

        NOTE(review): ``self.inputs`` is not assigned in this class. It is
        presumably populated by ``BaseArch`` before this method runs; confirm.
        """
        # Backbone
        body_feats = self.backbone(self.inputs)
        spatial_scale = None

        # Neck (optional). Without a neck, spatial_scale stays None and is
        # forwarded unchanged to the bbox and mask heads.
        if self.neck is not None:
            body_feats, spatial_scale = self.neck(body_feats)

        # RPN
        # rpn_head returns two lists: rpn_feat and rpn_head_out.
        # Each element of rpn_feat is the RPN feature of one level; the list
        # has length 1 when no neck is applied.
        # Each element of rpn_head_out is (rpn_rois_score, rpn_rois_delta).
        rpn_feat, self.rpn_head_out = self.rpn_head(self.inputs, body_feats)

        # Anchor
        # anchor_out is a list; each element is (anchor, anchor_var).
        self.anchor_out = self.anchor(rpn_feat)

        # Proposal RoI
        # Targets are also computed here when training.
        rois = self.proposal(self.inputs, self.rpn_head_out, self.anchor_out)

        # BBox Head
        bbox_feat, self.bbox_head_out = self.bbox_head(body_feats, rois,
                                                       spatial_scale)

        # Only filled outside 'infer' mode; the mask head receives None
        # during inference.
        rois_has_mask_int32 = None
        if self.inputs['mode'] == 'infer':
            bbox_pred, bboxes = self.bbox_head.get_prediction(
                self.bbox_head_out, rois)
            # Refine boxes with the bbox_head output at the test stage.
            self.bboxes = self.bbox_post_process(bbox_pred, bboxes,
                                                 self.inputs['im_info'])
        else:
            # Proposal RoIs for the mask branch; bboxes are updated at the
            # training stage only.
            bbox_targets = self.proposal.get_targets()[0]
            self.bboxes, rois_has_mask_int32 = self.mask(self.inputs, rois,
                                                         bbox_targets)

        # Mask Head
        self.mask_head_out = self.mask_head(self.inputs, body_feats,
                                            self.bboxes, bbox_feat,
                                            rois_has_mask_int32, spatial_scale)

    def get_loss(self):
        """Compute training losses from the outputs cached by ``model_arch``.

        Returns:
            dict: the entries returned by the RPN, bbox and mask heads'
            ``get_loss``, plus ``'loss'``, which is their combined total
            computed with ``fluid.layers.sums``.
        """
        loss = {}

        # RPN loss
        rpn_loss_inputs = self.anchor.generate_loss_inputs(
            self.inputs, self.rpn_head_out, self.anchor_out)
        loss.update(self.rpn_head.get_loss(rpn_loss_inputs))

        # BBox loss
        bbox_targets = self.proposal.get_targets()
        loss.update(self.bbox_head.get_loss(self.bbox_head_out, bbox_targets))

        # Mask loss
        mask_targets = self.mask.get_targets()
        loss.update(self.mask_head.get_loss(self.mask_head_out, mask_targets))

        total_loss = fluid.layers.sums(list(loss.values()))
        loss.update({'loss': total_loss})
        return loss

    def get_pred(self):
        """Assemble inference outputs from the results cached by ``model_arch``.

        Intended to run after ``model_arch`` in 'infer' mode. In that mode
        ``self.bboxes`` is the ``(bbox, bbox_num)`` pair returned by
        ``bbox_post_process``.

        Returns:
            dict: ``'bbox'``, ``'bbox_num'`` and ``'im_id'`` converted with
            ``.numpy()``, updated with the entries returned by
            ``mask_post_process``.
        """
        mask = self.mask_post_process(self.bboxes, self.mask_head_out,
                                      self.inputs['im_info'])
        bbox, bbox_num = self.bboxes
        output = {
            'bbox': bbox.numpy(),
            'bbox_num': bbox_num.numpy(),
            'im_id': self.inputs['im_id'].numpy()
        }
        output.update(mask)
        return output