mask_rcnn.py 7.0 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from paddle import fluid
from ppdet.core.workspace import register

__all__ = ['MaskRCNN']


@register
class MaskRCNN(object):
    """
    Mask R-CNN architecture, see https://arxiv.org/abs/1703.06870
    Args:
        backbone (object): backbone instance
        rpn_head (object): `RPNhead` instance
        bbox_assigner (object): `BBoxAssigner` instance
        roi_extractor (object): ROI extractor instance
        bbox_head (object): `BBoxHead` instance
        mask_assigner (object): `MaskAssigner` instance
        mask_head (object): `MaskHead` instance
        fpn (object): feature pyramid network instance
    """

    __category__ = 'architecture'
    __inject__ = [
        'backbone', 'rpn_head', 'bbox_assigner', 'roi_extractor', 'bbox_head',
        'mask_assigner', 'mask_head', 'fpn'
    ]

    def __init__(self,
                 backbone,
                 rpn_head,
                 bbox_head='BBoxHead',
                 bbox_assigner='BBoxAssigner',
                 roi_extractor='RoIAlign',
                 mask_assigner='MaskAssigner',
                 mask_head='MaskHead',
                 fpn=None):
        super(MaskRCNN, self).__init__()
        self.backbone = backbone
        self.rpn_head = rpn_head
        self.bbox_assigner = bbox_assigner
        self.roi_extractor = roi_extractor
        self.bbox_head = bbox_head
        self.mask_assigner = mask_assigner
        self.mask_head = mask_head
        self.fpn = fpn

65
    def build(self, feed_vars, mode='train'):
66
        im = feed_vars['image']
67 68
        assert mode in ['train', 'test'], \
            "only 'train' and 'test' mode is supported"
69
        if mode == 'train':
70 71 72
            required_fields = [
                'gt_label', 'gt_box', 'gt_mask', 'is_crowd', 'im_info'
            ]
73 74 75
        else:
            required_fields = ['im_shape', 'im_info']
        for var in required_fields:
76 77
            assert var in feed_vars, \
                "{} has no {} field".format(feed_vars, var)
78 79 80 81 82 83 84 85
        im_info = feed_vars['im_info']

        body_feats = self.backbone(im)

        # FPN
        if self.fpn is not None:
            body_feats, spatial_scale = self.fpn.get_output(body_feats)

86 87
        # RPN proposals
        rois = self.rpn_head.get_proposals(body_feats, im_info, mode=mode)
88

89 90 91 92 93 94 95 96 97 98 99 100 101
        if mode == 'train':
            rpn_loss = self.rpn_head.get_loss(im_info, feed_vars['gt_box'],
                                              feed_vars['is_crowd'])

            outs = self.bbox_assigner(
                rpn_rois=rois,
                gt_classes=feed_vars['gt_label'],
                is_crowd=feed_vars['is_crowd'],
                gt_boxes=feed_vars['gt_box'],
                im_info=feed_vars['im_info'])
            rois = outs[0]
            labels_int32 = outs[1]

102 103 104 105 106 107
            if self.fpn is None:
                last_feat = body_feats[list(body_feats.keys())[-1]]
                roi_feat = self.roi_extractor(last_feat, rois)
            else:
                roi_feat = self.roi_extractor(body_feats, rois, spatial_scale)

108 109 110 111 112 113 114 115 116 117 118 119 120 121
            loss = self.bbox_head.get_loss(roi_feat, labels_int32, *outs[2:])
            loss.update(rpn_loss)

            mask_rois, roi_has_mask_int32, mask_int32 = self.mask_assigner(
                rois=rois,
                gt_classes=feed_vars['gt_label'],
                is_crowd=feed_vars['is_crowd'],
                gt_segms=feed_vars['gt_mask'],
                im_info=feed_vars['im_info'],
                labels_int32=labels_int32)
            if self.fpn is None:
                bbox_head_feat = self.bbox_head.get_head_feat()
                feat = fluid.layers.gather(bbox_head_feat, roi_has_mask_int32)
            else:
122 123
                feat = self.roi_extractor(
                    body_feats, mask_rois, spatial_scale, is_mask=True)
124 125 126 127 128 129 130 131

            mask_loss = self.mask_head.get_loss(feat, mask_int32)
            loss.update(mask_loss)

            total_loss = fluid.layers.sum(list(loss.values()))
            loss.update({'loss': total_loss})
            return loss

132
        else:
133 134 135 136 137 138 139

            if self.fpn is None:
                last_feat = body_feats[list(body_feats.keys())[-1]]
                roi_feat = self.roi_extractor(last_feat, rois)
            else:
                roi_feat = self.roi_extractor(body_feats, rois, spatial_scale)

140 141 142 143 144 145 146 147 148 149 150 151
            bbox_pred = self.bbox_head.get_prediction(roi_feat, rois, im_info,
                                                      feed_vars['im_shape'])
            bbox_pred = bbox_pred['bbox']

            # share weight
            bbox_shape = fluid.layers.shape(bbox_pred)
            bbox_size = fluid.layers.reduce_prod(bbox_shape)
            bbox_size = fluid.layers.reshape(bbox_size, [1, 1])
            size = fluid.layers.fill_constant([1, 1], value=6, dtype='int32')
            cond = fluid.layers.less_than(x=bbox_size, y=size)

            mask_pred = fluid.layers.create_global_var(
152 153 154 155 156
                shape=[1],
                value=0.0,
                dtype='float32',
                persistable=False,
                name='mask_pred')
157 158 159 160 161

            with fluid.layers.control_flow.Switch() as switch:
                with switch.case(cond):
                    fluid.layers.assign(input=bbox_pred, output=mask_pred)
                with switch.default():
162 163
                    bbox = fluid.layers.slice(
                        bbox_pred, [1], starts=[2], ends=[6])
164 165 166 167 168 169 170 171 172 173

                    im_scale = fluid.layers.slice(
                        im_info, [1], starts=[2], ends=[3])
                    im_scale = fluid.layers.sequence_expand(im_scale, bbox)

                    mask_rois = bbox * im_scale
                    if self.fpn is None:
                        mask_feat = self.roi_extractor(last_feat, mask_rois)
                        mask_feat = self.bbox_head.get_head_feat(mask_feat)
                    else:
174 175
                        mask_feat = self.roi_extractor(
                            body_feats, mask_rois, spatial_scale, is_mask=True)
176 177 178 179

                    mask_out = self.mask_head.get_prediction(mask_feat, bbox)
                    fluid.layers.assign(input=mask_out, output=mask_pred)
            return {'bbox': bbox_pred, 'mask': mask_pred}
180

181 182
    def train(self, feed_vars):
        return self.build(feed_vars, 'train')
183

184 185
    def eval(self, feed_vars):
        return self.build(feed_vars, 'test')
186 187

    def test(self, feed_vars):
188
        return self.build(feed_vars, 'test')