# mask_rcnn.py
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from paddle import fluid
from ppdet.core.workspace import register
from ppdet.utils.data_structure import BufferDict
from .meta_arch import BaseArch

# Public API of this module.
__all__ = ['MaskRCNN']


@register
class MaskRCNN(BaseArch):
    """Mask R-CNN architecture assembled from registry-injected components.

    Each name listed in ``__inject__`` is supplied by the
    ``ppdet.core.workspace`` registry and stored on the instance as-is.
    Components communicate through ``self.gbd``: every stage is called with
    it and its output is merged back into it. ``self.gbd`` is not created in
    this class -- presumably ``BaseArch`` sets it up (likely a
    ``BufferDict``); confirm against ``meta_arch.BaseArch``.
    """
    __category__ = 'architecture'
    __inject__ = [
        'anchor',
        'proposal',
        'mask',
        'backbone',
        'rpn_head',
        'bbox_head',
        'mask_head',
    ]

    def __init__(self, anchor, proposal, mask, backbone, rpn_head, bbox_head,
                 mask_head, *args, **kwargs):
        """Store the injected sub-modules.

        Args:
            anchor: anchor component; called with ``self.gbd``.
            proposal: proposal component; called with ``self.gbd`` and, in
                infer mode, its ``post_process`` produces final boxes.
            mask: mask component; called with ``self.gbd`` and, in infer
                mode, its ``post_process`` produces final masks.
            backbone: feature extractor; called with ``self.gbd``.
            rpn_head: RPN head; provides ``loss`` returning (cls, reg).
            bbox_head: box head; provides ``loss`` returning (cls, reg).
            mask_head: mask head; provides ``loss`` returning one value.
            *args, **kwargs: forwarded unchanged to ``BaseArch.__init__``.
        """
        super(MaskRCNN, self).__init__(*args, **kwargs)
        self.anchor = anchor
        self.proposal = proposal
        self.mask = mask
        self.backbone = backbone
        self.rpn_head = rpn_head
        self.bbox_head = bbox_head
        self.mask_head = mask_head

    def model_arch(self):
        """Run the forward pass, accumulating every stage's output in gbd.

        Order: backbone -> RPN head -> anchor -> proposal -> bbox head ->
        (infer only: proposal post-processing) -> mask -> mask head ->
        (infer only: mask post-processing). The branch is selected by
        ``self.gbd['mode'] == 'infer'``.
        """
        # Backbone
        bb_out = self.backbone(self.gbd)
        self.gbd.update(bb_out)

        # RPN
        rpn_head_out = self.rpn_head(self.gbd)
        self.gbd.update(rpn_head_out)

        # Anchor
        anchor_out = self.anchor(self.gbd)
        self.gbd.update(anchor_out)

        # Proposal BBox. 'stage' is written before calling the proposal
        # module; presumably it reads this index (single stage here) --
        # confirm in the proposal implementation.
        self.gbd['stage'] = 0
        proposal_out = self.proposal(self.gbd)
        self.gbd.update({'proposal_0': proposal_out})

        # BBox Head
        bboxhead_out = self.bbox_head(self.gbd)
        self.gbd.update({'bbox_head_0': bboxhead_out})

        if self.gbd['mode'] == 'infer':
            bbox_out = self.proposal.post_process(self.gbd)
            self.gbd.update(bbox_out)

        # Mask
        mask_out = self.mask(self.gbd)
        self.gbd.update(mask_out)

        # Mask Head
        mask_head_out = self.mask_head(self.gbd)
        self.gbd.update(mask_head_out)

        if self.gbd['mode'] == 'infer':
            mask_out = self.mask.post_process(self.gbd)
            self.gbd.update(mask_out)

    def loss(self):
        """Collect the head losses and their sum.

        Returns:
            dict with the summed ``'loss'`` plus the individual
            ``'loss_rpn_cls'``, ``'loss_rpn_reg'``, ``'loss_bbox_cls'``,
            ``'loss_bbox_reg'`` and ``'loss_mask'`` terms.
        """
        rpn_cls_loss, rpn_reg_loss = self.rpn_head.loss(self.gbd)
        bbox_cls_loss, bbox_reg_loss = self.bbox_head.loss(self.gbd)
        mask_loss = self.mask_head.loss(self.gbd)
        losses = [
            rpn_cls_loss, rpn_reg_loss, bbox_cls_loss, bbox_reg_loss, mask_loss
        ]
        loss = fluid.layers.sum(losses)
        out = {
            'loss': loss,
            'loss_rpn_cls': rpn_cls_loss,
            'loss_rpn_reg': rpn_reg_loss,
            'loss_bbox_cls': bbox_cls_loss,
            'loss_bbox_reg': bbox_reg_loss,
            'loss_mask': mask_loss
        }
        return out

    def infer(self):
        """Return the post-processed predictions converted via ``.numpy()``.

        Reads ``predicted_bbox``, ``predicted_bbox_nums``, ``predicted_mask``
        and ``im_id`` from ``self.gbd`` (presumably written by the infer-mode
        post-processing in ``model_arch``).

        Returns:
            dict with keys ``'bbox'``, ``'bbox_nums'``, ``'mask'`` and
            ``'im_id'``.
        """
        outs = {
            'bbox': self.gbd['predicted_bbox'].numpy(),
            'bbox_nums': self.gbd['predicted_bbox_nums'].numpy(),
            'mask': self.gbd['predicted_mask'].numpy(),
            'im_id': self.gbd['im_id'].numpy()
        }
        # Fixed: previously returned the undefined name `inputs` (NameError).
        return outs