# mask_rcnn.py: MaskRCNN architecture definition.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from paddle import fluid
from ppdet.core.workspace import register
from .meta_arch import BaseArch
__all__ = ['MaskRCNN']


@register
class MaskRCNN(BaseArch):
    """Mask R-CNN detector composed from injected sub-modules.

    The forward computation happens in ``model_arch``, which caches its
    intermediate outputs on ``self``. ``loss`` and ``infer`` only read
    those cached outputs, so ``model_arch`` must run before either of them.

    Args:
        anchor: Anchor module. Called on the RPN features to produce the
            anchors, and used by ``loss`` to build the RPN loss inputs via
            ``generate_loss_inputs``.
        proposal: Proposal module. Called to produce RoIs; exposes
            ``get_targets()`` (used by the bbox loss and the mask branch)
            and ``post_process()`` (used to refine boxes in ``'infer'``
            mode).
        mask: Mask module. Called outside ``'infer'`` mode to select the
            RoIs fed to the mask head; exposes ``get_targets()`` (mask
            loss) and ``post_process()`` (mask outputs in ``infer``).
        backbone: Feature extractor applied to ``self.inputs``.
        rpn_head: RPN head; returns ``(rpn_feat, rpn_head_out)`` and
            provides ``loss()``.
        bbox_head: Box head; returns ``(bbox_feat, bbox_head_out)`` and
            provides ``loss()``.
        mask_head: Mask head; provides ``loss()``.
        neck: Optional neck applied to the backbone features. When it is
            ``None`` the backbone features are used as-is and
            ``spatial_scale`` is passed to the heads as ``None``.
    """
    __category__ = 'architecture'
    # Constructor arguments presumably resolved by ppdet's workspace/config
    # system (see ``@register``) and injected as module instances.
    __inject__ = [
        'anchor',
        'proposal',
        'mask',
        'backbone',
        'neck',
        'rpn_head',
        'bbox_head',
        'mask_head',
    ]

    def __init__(self,
                 anchor,
                 proposal,
                 mask,
                 backbone,
                 rpn_head,
                 bbox_head,
                 mask_head,
                 neck=None):
        super(MaskRCNN, self).__init__()
        self.anchor = anchor
        self.proposal = proposal
        self.mask = mask
        self.backbone = backbone
        self.neck = neck
        self.rpn_head = rpn_head
        self.bbox_head = bbox_head
        self.mask_head = mask_head

    def model_arch(self):
        """Run the forward pass on ``self.inputs`` and cache the results.

        Sets ``rpn_head_out``, ``anchor_out``, ``bbox_head_out``,
        ``bboxes`` and ``mask_head_out`` on ``self``. ``self.inputs`` is
        presumably populated by ``BaseArch`` before this is called.
        """
        # Backbone
        body_feats = self.backbone(self.inputs)
        spatial_scale = None

        # Neck (optional); it also supplies the spatial scale that is
        # passed on to the bbox and mask heads.
        if self.neck is not None:
            body_feats, spatial_scale = self.neck(body_feats)

        # RPN
        # rpn_head returns two lists, rpn_feat and rpn_head_out:
        # - each element of rpn_feat is the RPN feature of one level; the
        #   list has length 1 when no neck is applied;
        # - each element of rpn_head_out is (rpn_rois_score, rpn_rois_delta).
        rpn_feat, self.rpn_head_out = self.rpn_head(self.inputs, body_feats)

        # Anchor
        # anchor_out is a list whose elements are (anchor, anchor_var).
        self.anchor_out = self.anchor(rpn_feat)

        # Proposal RoIs; targets are also computed here when training.
        rois = self.proposal(self.inputs, self.rpn_head_out, self.anchor_out)

        # BBox head
        bbox_feat, self.bbox_head_out = self.bbox_head(body_feats, rois,
                                                       spatial_scale)

        rois_has_mask_int32 = None
        if self.inputs['mode'] == 'infer':
            # Inference: refine the boxes with the bbox_head output; these
            # refined boxes feed the mask head.
            self.bboxes = self.proposal.post_process(self.inputs,
                                                     self.bbox_head_out, rois)
        else:
            # Training: the mask module selects the RoIs for the mask
            # branch, using the first element of the proposal targets.
            bbox_targets = self.proposal.get_targets()[0]
            self.bboxes, rois_has_mask_int32 = self.mask(self.inputs, rois,
                                                         bbox_targets)

        # Mask head
        self.mask_head_out = self.mask_head(self.inputs, body_feats,
                                            self.bboxes, bbox_feat,
                                            rois_has_mask_int32, spatial_scale)

    def loss(self):
        """Compute the losses from the outputs cached by ``model_arch``.

        Presumably requires ``model_arch`` to have run outside ``'infer'``
        mode so that the proposal and mask targets exist — confirm against
        the target modules.

        Returns:
            dict: the entries returned by the RPN, bbox and mask heads'
            ``loss()`` methods, plus their sum under the key ``'loss'``.
        """
        loss = {}

        # RPN loss
        rpn_loss_inputs = self.anchor.generate_loss_inputs(
            self.inputs, self.rpn_head_out, self.anchor_out)
        loss_rpn = self.rpn_head.loss(rpn_loss_inputs)
        loss.update(loss_rpn)

        # BBox loss
        bbox_targets = self.proposal.get_targets()
        loss_bbox = self.bbox_head.loss(self.bbox_head_out, bbox_targets)
        loss.update(loss_bbox)

        # Mask loss
        mask_targets = self.mask.get_targets()
        loss_mask = self.mask_head.loss(self.mask_head_out, mask_targets)
        loss.update(loss_mask)

        # Sum the individual terms before 'loss' itself is inserted.
        total_loss = fluid.layers.sums(list(loss.values()))
        loss.update({'loss': total_loss})
        return loss

    def infer(self):
        """Collect inference outputs as numpy arrays.

        ``self.bboxes`` is unpacked as ``(bbox, bbox_num)``, which presumably
        requires ``model_arch`` to have run in ``'infer'`` mode.

        Returns:
            dict: ``'bbox'``, ``'bbox_num'`` and ``'im_id'`` converted with
            ``.numpy()``, updated with the result of ``mask.post_process``.
        """
        mask = self.mask.post_process(self.bboxes, self.mask_head_out,
                                      self.inputs['im_info'])
        bbox, bbox_num = self.bboxes
        output = {
            'bbox': bbox.numpy(),
            'bbox_num': bbox_num.numpy(),
            'im_id': self.inputs['im_id'].numpy()
        }
        output.update(mask)
        return output