mask_head.py 9.4 KB
Newer Older
Q
qingqing01 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. 
#   
# Licensed under the Apache License, Version 2.0 (the "License");   
# you may not use this file except in compliance with the License.  
# You may obtain a copy of the License at   
#   
#     http://www.apache.org/licenses/LICENSE-2.0    
#   
# Unless required by applicable law or agreed to in writing, software   
# distributed under the License is distributed on an "AS IS" BASIS, 
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.  
# See the License for the specific language governing permissions and   
# limitations under the License.

import paddle
16
import paddle.nn as nn
Q
qingqing01 已提交
17 18
import paddle.nn.functional as F
from paddle.nn.initializer import KaimingNormal
19 20

from ppdet.core.workspace import register, create
F
Feng Ni 已提交
21
from ppdet.modeling.layers import ConvNormLayer
22
from .roi_extractor import RoIAlign
23
from ..cls_utils import _get_class_default_kwargs
Q
qingqing01 已提交
24 25


26 27
@register
class MaskFeat(nn.Layer):
    """
    Feature extraction in Mask head

    Builds ``num_convs`` 3x3 conv + ReLU layers followed by a 2x2, stride-2
    transposed conv + ReLU, all held in ``self.upsample``.

    Args:
        in_channel (int): Input channels, default 256
        out_channel (int): Output channels, default 256
        num_convs (int): The number of conv layers, default 4
        norm_type (string | None): Norm type. Only 'gn' builds the conv
            stack from ConvNormLayer (using this norm_type); any other
            value, including None, builds plain Conv2D layers without
            normalization. Default None
    """

    def __init__(self,
                 in_channel=256,
                 out_channel=256,
                 num_convs=4,
                 norm_type=None):
        super(MaskFeat, self).__init__()
        self.num_convs = num_convs
        self.in_channel = in_channel
        self.out_channel = out_channel
        self.norm_type = norm_type
        # fan_in values passed to the KaimingNormal initializers below
        fan_conv = out_channel * 3 * 3
        fan_deconv = out_channel * 2 * 2

        mask_conv = nn.Sequential()
        for i in range(self.num_convs):
            conv_name = 'mask_inter_feat_{}'.format(i + 1)
            # The first conv consumes the input feature; later convs consume
            # the previous conv's out_channel-wide output.
            ch_in = in_channel if i == 0 else out_channel
            if norm_type == 'gn':
                conv = ConvNormLayer(
                    ch_in=ch_in,
                    ch_out=out_channel,
                    filter_size=3,
                    stride=1,
                    norm_type=self.norm_type,
                    initializer=KaimingNormal(fan_in=fan_conv),
                    skip_quant=True)
            else:
                conv = nn.Conv2D(
                    in_channels=ch_in,
                    out_channels=out_channel,
                    kernel_size=3,
                    padding=1,
                    weight_attr=paddle.ParamAttr(
                        initializer=KaimingNormal(fan_in=fan_conv)))
                conv.skip_quant = True
            mask_conv.add_sublayer(conv_name, conv)
            mask_conv.add_sublayer(conv_name + 'act', nn.ReLU())

        # The transposed conv is fed by the conv stack, whose output has
        # out_channel channels; only when the stack is empty does it see the
        # raw input with in_channel channels.
        deconv_in = out_channel if num_convs > 0 else in_channel
        mask_conv.add_sublayer(
            'conv5_mask',
            nn.Conv2DTranspose(
                in_channels=deconv_in,
                out_channels=self.out_channel,
                kernel_size=2,
                stride=2,
                weight_attr=paddle.ParamAttr(
                    initializer=KaimingNormal(fan_in=fan_deconv))))
        mask_conv.add_sublayer('conv5_mask' + 'act', nn.ReLU())
        self.upsample = mask_conv

    @classmethod
    def from_config(cls, cfg, input_shape):
        """Derive ``in_channel`` from the (first) input shape's channels."""
        if isinstance(input_shape, (list, tuple)):
            input_shape = input_shape[0]
        return {'in_channel': input_shape.channels, }

    def out_channels(self):
        """Return the number of channels produced by ``forward``."""
        return self.out_channel

    def forward(self, feats):
        """Run the conv stack and the 2x transposed-conv upsample."""
        return self.upsample(feats)
Q
qingqing01 已提交
103 104 105


@register
class MaskHead(nn.Layer):
    """
    RCNN mask head

    Args:
        head (nn.Layer): Extract feature in mask head
        roi_extractor (object): The module of RoI Extractor; a dict is
            treated as keyword arguments for RoIAlign
        mask_assigner (object): The module of Mask Assigner,
            label and sample the mask
        num_classes (int): The number of classes
        share_bbox_feat (bool): Whether to share the feature from bbox head,
            default false
        export_onnx (bool): When true, forward_test does not short-circuit
            to a placeholder output on empty rois, default false
    """
    __shared__ = ['num_classes', 'export_onnx']
    __inject__ = ['mask_assigner']

    def __init__(self,
                 head,
                 roi_extractor=_get_class_default_kwargs(RoIAlign),
                 mask_assigner='MaskAssigner',
                 num_classes=80,
                 share_bbox_feat=False,
                 export_onnx=False):
        super(MaskHead, self).__init__()
        self.num_classes = num_classes
        self.export_onnx = export_onnx

        self.roi_extractor = roi_extractor
        if isinstance(roi_extractor, dict):
            self.roi_extractor = RoIAlign(**roi_extractor)
        self.head = head
        self.in_channels = head.out_channels()
        self.mask_assigner = mask_assigner
        self.share_bbox_feat = share_bbox_feat
        # Not assigned anywhere in this class; presumably set by the owning
        # architecture — TODO confirm.
        self.bbox_head = None

        # 1x1 conv mapping head features to one mask logit map per class.
        self.mask_fcn_logits = nn.Conv2D(
            in_channels=self.in_channels,
            out_channels=self.num_classes,
            kernel_size=1,
            weight_attr=paddle.ParamAttr(initializer=KaimingNormal(
                fan_in=self.num_classes)))
        self.mask_fcn_logits.skip_quant = True

    @classmethod
    def from_config(cls, cfg, input_shape):
        """Build RoIAlign kwargs and the feature head from the config."""
        roi_pooler = cfg['roi_extractor']
        assert isinstance(roi_pooler, dict)
        # Update a copy so the caller's cfg entry is not mutated in place.
        roi_pooler = dict(roi_pooler)
        roi_pooler.update(RoIAlign.from_config(cfg, input_shape))
        head = create(cfg['head'], input_shape=input_shape)
        return {
            'roi_extractor': roi_pooler,
            'head': head,
        }

    def get_loss(self, mask_logits, mask_label, mask_target, mask_weight):
        """
        Binary cross-entropy mask loss on each RoI's target-class channel.

        mask_logits (Tensor): Logits shaped [N, C, H, W] (rank fixed by the
            unsqueeze/reshape below)
        mask_label (Tensor): Class index per RoI, one-hot encoded here
        mask_target (Tensor): Binary mask targets, cast to float32
        mask_weight (Tensor): Per-RoI weight, broadcast over H and W
        """
        # One-hot the labels and expand to the logits' shape so that
        # nonzero() picks exactly the target-class channel of every RoI.
        mask_label = F.one_hot(mask_label, self.num_classes).unsqueeze([2, 3])
        mask_label = paddle.expand_as(mask_label, mask_logits)
        mask_label.stop_gradient = True
        mask_pred = paddle.gather_nd(mask_logits, paddle.nonzero(mask_label))
        shape = mask_logits.shape
        mask_pred = paddle.reshape(mask_pred, [shape[0], shape[2], shape[3]])

        mask_target = mask_target.cast('float32')
        mask_weight = mask_weight.unsqueeze([1, 2])
        loss_mask = F.binary_cross_entropy_with_logits(
            mask_pred, mask_target, weight=mask_weight, reduction="mean")
        return loss_mask

    def forward_train(self, body_feats, rois, rois_num, inputs, targets,
                      bbox_feat):
        """
        body_feats (list[Tensor]): Multi-level backbone features
        rois (list[Tensor]): Proposals for each batch with shape [N, 4]
        rois_num (Tensor): The number of proposals for each batch
        inputs (dict): ground truth info
        targets (tuple): (labels, _, gt_inds) from the bbox head assigner;
            the middle element is unused here
        bbox_feat (Tensor | None): RoI features from the bbox head, only
            used when share_bbox_feat is true

        Returns a dict with the single key 'loss_mask'.
        """
        tgt_labels, _, tgt_gt_inds = targets
        rois, rois_num, tgt_classes, tgt_masks, mask_index, tgt_weights = self.mask_assigner(
            rois, tgt_labels, tgt_gt_inds, inputs)

        if self.share_bbox_feat:
            # Reuse the bbox-head features of the RoIs selected for masks.
            rois_feat = paddle.gather(bbox_feat, mask_index)
        else:
            rois_feat = self.roi_extractor(body_feats, rois, rois_num)
        mask_feat = self.head(rois_feat)
        mask_logits = self.mask_fcn_logits(mask_feat)

        loss_mask = self.get_loss(mask_logits, tgt_classes, tgt_masks,
                                  tgt_weights)
        return {'loss_mask': loss_mask}

    def forward_test(self,
                     body_feats,
                     rois,
                     rois_num,
                     scale_factor,
                     feat_func=None):
        """
        body_feats (list[Tensor]): Multi-level backbone features
        rois (Tensor): Prediction from bbox head with shape [N, 6]; column 0
            is read as the class label and columns 2: as the box
        rois_num (Tensor): The number of prediction for each batch
        scale_factor (Tensor): The scale factor from origin size to input size
            (not used in this method)
        feat_func (callable | None): Applied to the RoI features when
            share_bbox_feat is true; required in that case

        Returns sigmoid mask probabilities of each RoI's predicted class,
        shaped [N, H, W], or a [1, 1, 1] tensor filled with -1 when there are
        no rois and export_onnx is false.
        """
        if not self.export_onnx and rois.shape[0] == 0:
            mask_out = paddle.full([1, 1, 1], -1)
        else:
            bbox = [rois[:, 2:]]
            labels = rois[:, 0].cast('int32')
            rois_feat = self.roi_extractor(body_feats, bbox, rois_num)
            if self.share_bbox_feat:
                assert feat_func is not None
                rois_feat = feat_func(rois_feat)

            mask_feat = self.head(rois_feat)
            mask_logit = self.mask_fcn_logits(mask_feat)
            if self.num_classes == 1:
                mask_out = F.sigmoid(mask_logit)[:, 0, :, :]
            else:
                # Select, for every RoI, the logit map of its predicted class.
                num_masks = paddle.shape(mask_logit)[0]
                index = paddle.arange(num_masks).cast('int32')
                mask_out = mask_logit[index, labels]
                mask_out_shape = paddle.shape(mask_out)
                mask_out = paddle.reshape(mask_out, [
                    paddle.shape(index), mask_out_shape[-2], mask_out_shape[-1]
                ])
                mask_out = F.sigmoid(mask_out)
        return mask_out

    def forward(self,
                body_feats,
                rois,
                rois_num,
                inputs,
                targets=None,
                bbox_feat=None,
                feat_func=None):
        """
        Dispatch to forward_train in training mode (returns the loss dict),
        otherwise to forward_test using inputs['scale_factor'].
        """
        if self.training:
            return self.forward_train(body_feats, rois, rois_num, inputs,
                                      targets, bbox_feat)
        else:
            im_scale = inputs['scale_factor']
            return self.forward_test(body_feats, rois, rois_num, im_scale,
                                     feat_func)