mask_head.py 9.3 KB
Newer Older
Q
qingqing01 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. 
#   
# Licensed under the Apache License, Version 2.0 (the "License");   
# you may not use this file except in compliance with the License.  
# You may obtain a copy of the License at   
#   
#     http://www.apache.org/licenses/LICENSE-2.0    
#   
# Unless required by applicable law or agreed to in writing, software   
# distributed under the License is distributed on an "AS IS" BASIS, 
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.  
# See the License for the specific language governing permissions and   
# limitations under the License.

import paddle
16
import paddle.nn as nn
Q
qingqing01 已提交
17 18
import paddle.nn.functional as F
from paddle.nn.initializer import KaimingNormal
19 20

from ppdet.core.workspace import register, create
F
Feng Ni 已提交
21
from ppdet.modeling.layers import ConvNormLayer
22
from .roi_extractor import RoIAlign
Q
qingqing01 已提交
23 24


25 26
@register
class MaskFeat(nn.Layer):
    """
    Feature extraction in Mask head

    Builds a sequential stack of ``num_convs`` 3x3 conv + ReLU layers followed
    by one 2x2, stride-2 transposed conv + ReLU.

    Args:
        in_channel (int): Input channels
        out_channel (int): Output channels
        num_convs (int): The number of conv layers, default 4
        norm_type (string | None): Norm type, default None. In this class
            only 'gn' switches the conv stack to ConvNormLayer; every other
            value (including None) builds plain nn.Conv2D layers.
    """

    def __init__(self,
                 in_channel=256,
                 out_channel=256,
                 num_convs=4,
                 norm_type=None):
        super(MaskFeat, self).__init__()
        self.num_convs = num_convs
        self.in_channel = in_channel
        self.out_channel = out_channel
        self.norm_type = norm_type
        # fan_in values handed to KaimingNormal for the 3x3 convs and the
        # 2x2 transposed conv respectively.
        fan_conv = out_channel * 3 * 3
        fan_deconv = out_channel * 2 * 2

        mask_conv = nn.Sequential()
        for i in range(self.num_convs):
            conv_name = 'mask_inter_feat_{}'.format(i + 1)
            # Only the first conv sees the input feature width; later convs
            # consume the previous conv's output.
            ch_in = in_channel if i == 0 else out_channel
            if norm_type == 'gn':
                conv = ConvNormLayer(
                    ch_in=ch_in,
                    ch_out=out_channel,
                    filter_size=3,
                    stride=1,
                    norm_type=self.norm_type,
                    initializer=KaimingNormal(fan_in=fan_conv),
                    skip_quant=True)
            else:
                conv = nn.Conv2D(
                    in_channels=ch_in,
                    out_channels=out_channel,
                    kernel_size=3,
                    padding=1,
                    weight_attr=paddle.ParamAttr(
                        initializer=KaimingNormal(fan_in=fan_conv)))
                conv.skip_quant = True
            mask_conv.add_sublayer(conv_name, conv)
            mask_conv.add_sublayer(conv_name + 'act', nn.ReLU())

        # The transposed conv is fed by the conv stack, so its input width is
        # out_channel whenever at least one conv precedes it. Using in_channel
        # unconditionally breaks configs where in_channel != out_channel.
        deconv_in_channel = out_channel if num_convs > 0 else in_channel
        mask_conv.add_sublayer(
            'conv5_mask',
            nn.Conv2DTranspose(
                in_channels=deconv_in_channel,
                out_channels=self.out_channel,
                kernel_size=2,
                stride=2,
                weight_attr=paddle.ParamAttr(
                    initializer=KaimingNormal(fan_in=fan_deconv))))
        mask_conv.add_sublayer('conv5_maskact', nn.ReLU())
        self.upsample = mask_conv

    @classmethod
    def from_config(cls, cfg, input_shape):
        """Derive ``in_channel`` from ``input_shape``.

        If ``input_shape`` is a list or tuple, only its first element is
        used; its ``channels`` attribute becomes ``in_channel``.
        """
        if isinstance(input_shape, (list, tuple)):
            input_shape = input_shape[0]
        return {'in_channel': input_shape.channels, }

    def out_channels(self):
        """Return the number of output channels (``out_channel``)."""
        return self.out_channel

    def forward(self, feats):
        """Run ``feats`` through the conv stack and the transposed conv."""
        return self.upsample(feats)
Q
qingqing01 已提交
102 103 104


@register
class MaskHead(nn.Layer):
    """
    RCNN mask head

    Args:
        head (nn.Layer): Extract feature in mask head
        roi_extractor (object): The module of RoI Extractor
        mask_assigner (object): The module of Mask Assigner,
            label and sample the mask
        num_classes (int): The number of classes
        share_bbox_feat (bool): Whether to share the feature from bbox head,
            default false
    """
    __shared__ = ['num_classes']
    __inject__ = ['mask_assigner']

    def __init__(self,
                 head,
                 roi_extractor=RoIAlign().__dict__,
                 mask_assigner='MaskAssigner',
                 num_classes=80,
                 share_bbox_feat=False):
        super(MaskHead, self).__init__()
        self.num_classes = num_classes

        # A dict is treated as keyword arguments for a new RoIAlign; anything
        # else is stored and used as-is.
        self.roi_extractor = roi_extractor
        if isinstance(roi_extractor, dict):
            self.roi_extractor = RoIAlign(**roi_extractor)
        self.head = head
        self.in_channels = head.out_channels()
        self.mask_assigner = mask_assigner
        self.share_bbox_feat = share_bbox_feat
        # Initialised to None and not referenced again within this class.
        self.bbox_head = None

        # 1x1 conv producing one mask logit map per class.
        self.mask_fcn_logits = nn.Conv2D(
            in_channels=self.in_channels,
            out_channels=self.num_classes,
            kernel_size=1,
            weight_attr=paddle.ParamAttr(initializer=KaimingNormal(
                fan_in=self.num_classes)))
        self.mask_fcn_logits.skip_quant = True

    @classmethod
    def from_config(cls, cfg, input_shape):
        """Build the ``roi_extractor`` kwargs and the ``head`` from ``cfg``.

        ``cfg['roi_extractor']`` must be a dict; it is updated in place with
        the result of ``RoIAlign.from_config``. ``cfg['head']`` is created
        with ``input_shape`` passed through.
        """
        roi_pooler = cfg['roi_extractor']
        assert isinstance(roi_pooler, dict)
        kwargs = RoIAlign.from_config(cfg, input_shape)
        roi_pooler.update(kwargs)
        kwargs = {'input_shape': input_shape}
        head = create(cfg['head'], **kwargs)
        return {
            'roi_extractor': roi_pooler,
            'head': head,
        }

    def get_loss(self, mask_logits, mask_label, mask_target, mask_weight):
        """Compute the mask loss on the logit channel of each RoI's class.

        ``mask_label`` is one-hot encoded over ``num_classes`` and expanded
        to the shape of ``mask_logits``; the logits at its non-zero positions
        are gathered and reshaped to
        ``[mask_logits.shape[0], mask_logits.shape[2], mask_logits.shape[3]]``.
        The result is compared with ``mask_target`` (cast to float32) using
        binary cross entropy with logits, weighted by ``mask_weight`` and
        reduced with "mean".
        """
        mask_label = F.one_hot(mask_label, self.num_classes).unsqueeze([2, 3])
        mask_label = paddle.expand_as(mask_label, mask_logits)
        mask_label.stop_gradient = True
        mask_pred = paddle.gather_nd(mask_logits, paddle.nonzero(mask_label))
        shape = mask_logits.shape
        mask_pred = paddle.reshape(mask_pred, [shape[0], shape[2], shape[3]])

        mask_target = mask_target.cast('float32')
        mask_weight = mask_weight.unsqueeze([1, 2])
        loss_mask = F.binary_cross_entropy_with_logits(
            mask_pred, mask_target, weight=mask_weight, reduction="mean")
        return loss_mask

    def forward_train(self, body_feats, rois, rois_num, inputs, targets,
                      bbox_feat):
        """
        body_feats (list[Tensor]): Multi-level backbone features
        rois (list[Tensor]): Proposals for each batch with shape [N, 4]
        rois_num (Tensor): The number of proposals for each batch
        inputs (dict): ground truth info
        targets (tuple): Three items; the first and third are passed to the
            mask assigner, the second is ignored
        bbox_feat (Tensor): Feature gathered by ``mask_index`` when
            ``share_bbox_feat`` is True; unused otherwise

        Returns a dict with the single key 'loss_mask'.
        """
        tgt_labels, _, tgt_gt_inds = targets
        assigned = self.mask_assigner(rois, tgt_labels, tgt_gt_inds, inputs)
        (rois, rois_num, tgt_classes, tgt_masks, mask_index,
         tgt_weights) = assigned

        if self.share_bbox_feat:
            rois_feat = paddle.gather(bbox_feat, mask_index)
        else:
            rois_feat = self.roi_extractor(body_feats, rois, rois_num)
        mask_feat = self.head(rois_feat)
        mask_logits = self.mask_fcn_logits(mask_feat)

        loss_mask = self.get_loss(mask_logits, tgt_classes, tgt_masks,
                                  tgt_weights)
        return {'loss_mask': loss_mask}

    def forward_test(self,
                     body_feats,
                     rois,
                     rois_num,
                     scale_factor,
                     feat_func=None):
        """
        body_feats (list[Tensor]): Multi-level backbone features
        rois (Tensor): Prediction from bbox head with shape [N, 6]
        rois_num (Tensor): The number of prediction for each batch
        scale_factor (Tensor): The scale factor from origin size to input size
            (accepted for interface compatibility; not read in this method)
        feat_func (callable | None): Applied to the RoI features when
            ``share_bbox_feat`` is True; must not be None in that case

        Returns the sigmoid of the mask logits for each RoI's class, or a
        [1, 1, 1, 1] tensor filled with -1 when ``rois`` is empty.
        """
        if rois.shape[0] == 0:
            mask_out = paddle.full([1, 1, 1, 1], -1)
        else:
            # Column 0 is read as the class label; columns 2 onward are
            # passed to the RoI extractor as the boxes.
            bbox = [rois[:, 2:]]
            labels = rois[:, 0].cast('int32')
            rois_feat = self.roi_extractor(body_feats, bbox, rois_num)
            if self.share_bbox_feat:
                assert feat_func is not None, \
                    'feat_func is required when share_bbox_feat is True'
                rois_feat = feat_func(rois_feat)

            mask_feat = self.head(rois_feat)
            mask_logit = self.mask_fcn_logits(mask_feat)
            mask_num_class = mask_logit.shape[1]
            if mask_num_class == 1:
                mask_out = F.sigmoid(mask_logit)
            else:
                num_masks = mask_logit.shape[0]
                mask_out = []
                # TODO: need to optimize gather
                for i in range(num_masks):
                    # Keep a leading axis of size 1 so the per-RoI results
                    # can be concatenated along axis 0 afterwards.
                    pred_masks = paddle.unsqueeze(
                        mask_logit[i, :, :, :], axis=0)
                    mask = paddle.gather(pred_masks, labels[i], axis=1)
                    mask_out.append(mask)
                mask_out = F.sigmoid(paddle.concat(mask_out))
        return mask_out

    def forward(self,
                body_feats,
                rois,
                rois_num,
                inputs,
                targets=None,
                bbox_feat=None,
                feat_func=None):
        """Dispatch to ``forward_train`` or ``forward_test``.

        When ``self.training`` is true, returns the loss dict from
        ``forward_train``. Otherwise reads ``inputs['scale_factor']`` and
        returns the mask predictions from ``forward_test``.
        """
        if self.training:
            return self.forward_train(body_feats, rois, rois_num, inputs,
                                      targets, bbox_feat)
        else:
            im_scale = inputs['scale_factor']
            return self.forward_test(body_feats, rois, rois_num, im_scale,
                                     feat_func)