mask_head.py 7.9 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. 
#   
# Licensed under the Apache License, Version 2.0 (the "License");   
# you may not use this file except in compliance with the License.  
# You may obtain a copy of the License at   
#   
#     http://www.apache.org/licenses/LICENSE-2.0    
#   
# Unless required by applicable law or agreed to in writing, software   
# distributed under the License is distributed on an "AS IS" BASIS, 
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.  
# See the License for the specific language governing permissions and   
# limitations under the License.
F
FDInSky 已提交
14

15 16 17 18 19 20 21
import paddle
import paddle.nn.functional as F
from paddle import ParamAttr
from paddle.nn import Layer, Sequential
from paddle.nn import Conv2D, Conv2DTranspose, ReLU
from paddle.nn.initializer import KaimingNormal
from paddle.regularizer import L2Decay
F
FDInSky 已提交
22
from ppdet.core.workspace import register
23
from ppdet.modeling import ops
F
FDInSky 已提交
24 25 26 27


@register
class MaskFeat(Layer):
    """Mask-branch feature tower: RoI features -> 3x3 convs -> 2x deconv.

    One independent tower (``Sequential``) is built per stage and registered
    as sublayer ``stage_{i}``. Each tower stacks ``num_convs`` 3x3 conv + ReLU
    pairs followed by a stride-2, kernel-2 ``Conv2DTranspose`` + ReLU.

    Args:
        mask_roi_extractor: injected callable, invoked as
            ``mask_roi_extractor(body_feats, bboxes, spatial_scale)`` to pool
            per-RoI features whenever bbox features are not reused.
        num_convs (int): number of 3x3 conv + ReLU layers before the deconv.
        feat_in (int): channel count of the incoming RoI features.
        feat_out (int): channel count produced by every conv and the deconv.
        mask_num_stages (int): number of towers to build (one per stage).
        share_bbox_feat (bool): when True and ``mask_index`` is provided,
            reuse rows of ``bbox_feat`` selected by ``mask_index`` instead of
            running the RoI extractor.
    """
    __inject__ = ['mask_roi_extractor']

    def __init__(self,
                 mask_roi_extractor=None,
                 num_convs=0,
                 feat_in=2048,
                 feat_out=256,
                 mask_num_stages=1,
                 share_bbox_feat=False):
        super(MaskFeat, self).__init__()
        self.num_convs = num_convs
        self.feat_in = feat_in
        self.feat_out = feat_out
        self.mask_roi_extractor = mask_roi_extractor
        self.mask_num_stages = mask_num_stages
        self.share_bbox_feat = share_bbox_feat
        self.upsample_module = []
        fan_conv = feat_out * 3 * 3
        fan_deconv = feat_out * 2 * 2
        # The deconv consumes the last 3x3 conv's output (feat_out channels)
        # whenever convs are stacked; only with num_convs == 0 does it see the
        # raw feat_in-channel RoI features directly.
        deconv_in = feat_out if self.num_convs > 0 else feat_in
        for i in range(self.mask_num_stages):
            name = 'stage_{}'.format(i)
            mask_conv = Sequential()
            for j in range(self.num_convs):
                conv_name = 'mask_inter_feat_{}'.format(j + 1)
                mask_conv.add_sublayer(
                    conv_name,
                    Conv2D(
                        in_channels=feat_in if j == 0 else feat_out,
                        out_channels=feat_out,
                        kernel_size=3,
                        padding=1,
                        weight_attr=ParamAttr(
                            initializer=KaimingNormal(fan_in=fan_conv)),
                        # biases: 2x learning-rate multiplier, no weight decay
                        bias_attr=ParamAttr(
                            learning_rate=2., regularizer=L2Decay(0.))))
                mask_conv.add_sublayer(conv_name + 'act', ReLU())
            mask_conv.add_sublayer(
                'conv5_mask',
                Conv2DTranspose(
                    in_channels=deconv_in,
                    out_channels=self.feat_out,
                    kernel_size=2,
                    stride=2,
                    weight_attr=ParamAttr(
                        initializer=KaimingNormal(fan_in=fan_deconv)),
                    bias_attr=ParamAttr(
                        learning_rate=2., regularizer=L2Decay(0.))))
            mask_conv.add_sublayer('conv5_mask' + 'act', ReLU())
            upsample = self.add_sublayer(name, mask_conv)
            self.upsample_module.append(upsample)

    def forward(self,
                body_feats,
                bboxes,
                bbox_feat,
                mask_index,
                spatial_scale,
                stage=0,
                bbox_head_feat_func=None,
                mode='train'):
        """Compute upsampled mask features with the tower for ``stage``.

        Args:
            body_feats: backbone/neck features, forwarded to the RoI extractor.
            bboxes: boxes forwarded to the RoI extractor (format defined by
                the extractor -- not visible here).
            bbox_feat: features from the bbox branch; gathered by
                ``mask_index`` when ``share_bbox_feat`` is set.
            mask_index: indices into ``bbox_feat`` (presumably a tensor of
                foreground RoI indices -- TODO confirm), or None to force the
                RoI-extractor path.
            spatial_scale: forwarded to the RoI extractor.
            stage (int): which tower in ``upsample_module`` to use.
            bbox_head_feat_func: optional callable applied to the gathered
                features, only when ``mode == 'infer'``.
            mode (str): 'train' or 'infer'.

        Returns:
            Output of ``upsample_module[stage]`` applied to the RoI features.
        """
        # Compare against None explicitly: Paddle will not evaluate a
        # multi-element tensor as a bool, and a one-element tensor holding
        # index 0 would be falsy and wrongly route to the extractor path.
        if self.share_bbox_feat and mask_index is not None:
            rois_feat = paddle.gather(bbox_feat, mask_index)
            if bbox_head_feat_func is not None and mode == 'infer':
                rois_feat = bbox_head_feat_func(rois_feat)
        else:
            rois_feat = self.mask_roi_extractor(body_feats, bboxes,
                                                spatial_scale)
        # upsample
        mask_feat = self.upsample_module[stage](rois_feat)
        return mask_feat
F
FDInSky 已提交
98 99 100 101


@register
class MaskHead(Layer):
    """Mask head: per-stage 1x1 conv turning mask features into logits.

    Feature computation is delegated to the injected ``mask_feat`` module;
    this layer owns one ``Conv2D`` (kernel 1, ``num_classes`` outputs) per
    stage, registered as ``mask_fcn_logits_{i}``, plus the mask loss.

    Args:
        mask_feat: injected feature module, called with
            ``(body_feats, bboxes, bbox_feat, mask_index, spatial_scale,
            stage[, bbox_head_feat_func], mode=...)``.
        feat_in (int): channel count of the features ``mask_feat`` returns.
        num_classes (int): number of logit channels (one mask per class).
        mask_num_stages (int): number of logit convs to build.
    """
    __shared__ = ['num_classes', 'mask_num_stages']
    __inject__ = ['mask_feat']

    def __init__(self,
                 mask_feat,
                 feat_in=256,
                 num_classes=81,
                 mask_num_stages=1):
        super(MaskHead, self).__init__()
        self.mask_feat = mask_feat
        self.feat_in = feat_in
        self.num_classes = num_classes
        self.mask_num_stages = mask_num_stages
        self.mask_fcn_logits = []
        for stage_idx in range(self.mask_num_stages):
            logit_conv = Conv2D(
                in_channels=self.feat_in,
                out_channels=self.num_classes,
                kernel_size=1,
                weight_attr=ParamAttr(
                    initializer=KaimingNormal(fan_in=self.num_classes)),
                # bias: 2x learning-rate multiplier, no weight decay
                bias_attr=ParamAttr(
                    learning_rate=2., regularizer=L2Decay(0.0)))
            registered = self.add_sublayer(
                'mask_fcn_logits_{}'.format(stage_idx), logit_conv)
            self.mask_fcn_logits.append(registered)

    def forward_train(self,
                      body_feats,
                      bboxes,
                      bbox_feat,
                      mask_index,
                      spatial_scale,
                      stage=0):
        """Return raw (pre-sigmoid) mask logits for the given stage."""
        feats = self.mask_feat(body_feats, bboxes, bbox_feat, mask_index,
                               spatial_scale, stage, mode='train')
        return self.mask_fcn_logits[stage](feats)

    def forward_test(self,
                     scale_factor,
                     body_feats,
                     bboxes,
                     bbox_feat,
                     mask_index,
                     spatial_scale,
                     stage=0,
                     bbox_head_feat_func=None):
        """Return sigmoid mask probabilities for the detected boxes.

        ``bboxes`` is a ``(bbox, bbox_num)`` pair. Columns ``2:`` of ``bbox``
        are multiplied by a per-box factor, ``scale_factor[img, 0]``, before
        feature extraction (presumably mapping boxes back to network-input
        resolution -- TODO confirm). When ``bbox`` has zero rows it is
        returned unchanged as a placeholder.
        """
        bbox, bbox_num = bboxes
        if bbox.shape[0] == 0:
            return bbox

        # Repeat image img_idx's scale value box_count times, in image order,
        # so every box row gets the scale of the image it belongs to.
        per_box_scale = [
            scale_factor[img_idx, 0]
            for img_idx, box_count in enumerate(bbox_num)
            for _ in range(box_count)
        ]
        per_box_scale = paddle.cast(paddle.concat(per_box_scale), 'float32')
        per_box_scale = paddle.reshape(per_box_scale, shape=[-1, 1])

        scaled_bboxes = (paddle.multiply(bbox[:, 2:], per_box_scale),
                         bbox_num)
        feats = self.mask_feat(body_feats, scaled_bboxes, bbox_feat,
                               mask_index, spatial_scale, stage,
                               bbox_head_feat_func, mode='infer')
        return F.sigmoid(self.mask_fcn_logits[stage](feats))

    def forward(self,
                inputs,
                body_feats,
                bboxes,
                bbox_feat,
                mask_index,
                spatial_scale,
                bbox_head_feat_func=None,
                stage=0):
        """Dispatch on ``inputs['mode']``: logits in train, probs otherwise."""
        if inputs['mode'] != 'train':
            return self.forward_test(inputs['scale_factor'], body_feats,
                                     bboxes, bbox_feat, mask_index,
                                     spatial_scale, stage, bbox_head_feat_func)
        return self.forward_train(body_feats, bboxes, bbox_feat, mask_index,
                                  spatial_scale, stage)

    def get_loss(self, mask_head_out, mask_target):
        """Summed sigmoid cross-entropy between logits and targets.

        Logits are flattened from axis 1 onward; targets are cast to float32
        and detached. Target entries equal to -1 are ignored, and the loss op
        is called with ``normalize=True``.

        Returns:
            dict: ``{'loss_mask': <scalar loss>}``.
        """
        flat_logits = paddle.flatten(mask_head_out, start_axis=1, stop_axis=-1)
        target = paddle.cast(x=mask_target, dtype='float32')
        target.stop_gradient = True
        per_elem_loss = ops.sigmoid_cross_entropy_with_logits(
            input=flat_logits,
            label=target,
            ignore_index=-1,
            normalize=True)
        return {'loss_mask': paddle.sum(per_elem_loss)}