mask_head.py 8.0 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. 
#   
# Licensed under the Apache License, Version 2.0 (the "License");   
# you may not use this file except in compliance with the License.  
# You may obtain a copy of the License at   
#   
#     http://www.apache.org/licenses/LICENSE-2.0    
#   
# Unless required by applicable law or agreed to in writing, software   
# distributed under the License is distributed on an "AS IS" BASIS, 
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.  
# See the License for the specific language governing permissions and   
# limitations under the License.
F
FDInSky 已提交
14

15 16 17 18 19 20 21
import paddle
import paddle.nn.functional as F
from paddle import ParamAttr
from paddle.nn import Layer, Sequential
from paddle.nn import Conv2D, Conv2DTranspose, ReLU
from paddle.nn.initializer import KaimingNormal
from paddle.regularizer import L2Decay
F
FDInSky 已提交
22
from ppdet.core.workspace import register
23
from ppdet.modeling import ops
F
FDInSky 已提交
24 25 26 27


@register
class MaskFeat(Layer):
    """RoI feature extractor + upsampler that feeds the mask head.

    For each of ``mask_num_stages`` stages a ``Sequential`` branch is built:
    ``num_convs`` pairs of (3x3 Conv2D, ReLU) followed by a 2x2, stride-2
    ``Conv2DTranspose`` (``conv5_mask``) and a ReLU. With kernel 2, stride 2
    and no padding the transposed conv doubles the spatial size (H -> 2H).

    Args:
        mask_roi_extractor: RoI feature extractor (listed in ``__inject__``),
            called as ``mask_roi_extractor(body_feats, bboxes, spatial_scale)``.
        num_convs (int): number of 3x3 conv layers before the deconv.
        feat_in (int): channel count of the incoming RoI features.
        feat_out (int): channel count produced by every conv / deconv layer.
        mask_num_stages (int): number of independent upsample branches;
            ``forward`` picks one with ``stage``.
        share_bbox_feat (bool): when True and ``mask_index`` is given, reuse
            ``bbox_feat`` rows instead of running ``mask_roi_extractor``.
    """
    __inject__ = ['mask_roi_extractor']

    def __init__(self,
                 mask_roi_extractor=None,
                 num_convs=0,
                 feat_in=2048,
                 feat_out=256,
                 mask_num_stages=1,
                 share_bbox_feat=False):
        super(MaskFeat, self).__init__()
        self.num_convs = num_convs
        self.feat_in = feat_in
        self.feat_out = feat_out
        self.mask_roi_extractor = mask_roi_extractor
        self.mask_num_stages = mask_num_stages
        self.share_bbox_feat = share_bbox_feat
        # Plain list for stage lookup; each entry is also registered via
        # add_sublayer below so its parameters are tracked by the Layer.
        self.upsample_module = []
        fan_conv = feat_out * 3 * 3
        fan_deconv = feat_out * 2 * 2
        # The deconv consumes the last 3x3 conv's output (feat_out channels)
        # whenever there is at least one conv; only with num_convs == 0 does
        # it see the raw RoI features (feat_in channels). Using feat_in
        # unconditionally breaks configs with num_convs > 0 and
        # feat_in != feat_out.
        deconv_in = feat_out if num_convs > 0 else feat_in
        for i in range(self.mask_num_stages):
            name = 'stage_{}'.format(i)
            mask_conv = Sequential()
            for j in range(self.num_convs):
                conv_name = 'mask_inter_feat_{}'.format(j + 1)
                mask_conv.add_sublayer(
                    conv_name,
                    Conv2D(
                        in_channels=feat_in if j == 0 else feat_out,
                        out_channels=feat_out,
                        kernel_size=3,
                        padding=1,
                        weight_attr=ParamAttr(
                            initializer=KaimingNormal(fan_in=fan_conv)),
                        bias_attr=ParamAttr(
                            learning_rate=2., regularizer=L2Decay(0.))))
                mask_conv.add_sublayer(conv_name + 'act', ReLU())
            mask_conv.add_sublayer(
                'conv5_mask',
                Conv2DTranspose(
                    in_channels=deconv_in,
                    out_channels=self.feat_out,
                    kernel_size=2,
                    stride=2,
                    weight_attr=ParamAttr(
                        initializer=KaimingNormal(fan_in=fan_deconv)),
                    bias_attr=ParamAttr(
                        learning_rate=2., regularizer=L2Decay(0.))))
            mask_conv.add_sublayer('conv5_mask' + 'act', ReLU())
            upsample = self.add_sublayer(name, mask_conv)
            self.upsample_module.append(upsample)

    def forward(self,
                body_feats,
                bboxes,
                bbox_feat,
                mask_index,
                spatial_scale,
                stage=0,
                bbox_head_feat_func=None,
                mode='train'):
        """Compute upsampled mask features for the given RoIs.

        Args:
            body_feats: backbone/neck features passed to the RoI extractor.
            bboxes: RoIs passed to the RoI extractor.
            bbox_feat: bbox-head RoI features; only read when
                ``share_bbox_feat`` is True and ``mask_index`` is not None.
            mask_index: indices gathered from ``bbox_feat`` (presumably the
                foreground RoIs selected for mask training — confirm with
                the caller).
            spatial_scale: forwarded to the RoI extractor.
            stage (int): which upsample branch to run.
            bbox_head_feat_func: optional callable applied to the RoI
                features when ``share_bbox_feat`` is True and ``mode`` is
                ``'infer'``.
            mode (str): ``'train'`` or ``'infer'``.

        Returns:
            Output of ``self.upsample_module[stage]`` on the RoI features.
        """
        if self.share_bbox_feat and mask_index is not None:
            rois_feat = paddle.gather(bbox_feat, mask_index)
        else:
            # Requires mask_roi_extractor to have been provided.
            rois_feat = self.mask_roi_extractor(body_feats, bboxes,
                                                spatial_scale)
        if self.share_bbox_feat and bbox_head_feat_func is not None and mode == 'infer':
            rois_feat = bbox_head_feat_func(rois_feat)

        # upsample
        mask_feat = self.upsample_module[stage](rois_feat)
        return mask_feat
F
FDInSky 已提交
99 100 101 102


@register
class MaskHead(Layer):
    """Mask head: feature module output -> per-class 1x1 conv mask logits.

    Args:
        mask_feat: feature module (listed in ``__inject__``), called with
            ``(body_feats, bboxes, bbox_feat, mask_index, spatial_scale,
            stage, ...)`` and a ``mode`` keyword.
        feat_in (int): channels of the features produced by ``mask_feat``.
        num_classes (int): output channels of each logits conv.
        mask_num_stages (int): number of per-stage logits convs.
    """
    __shared__ = ['num_classes', 'mask_num_stages']
    __inject__ = ['mask_feat']

    def __init__(self,
                 mask_feat,
                 feat_in=256,
                 num_classes=81,
                 mask_num_stages=1):
        super(MaskHead, self).__init__()
        self.mask_feat = mask_feat
        self.feat_in = feat_in
        self.num_classes = num_classes
        self.mask_num_stages = mask_num_stages
        # Indexed by stage; every conv is also registered as a sublayer.
        self.mask_fcn_logits = []
        for stage_id in range(self.mask_num_stages):
            logits_conv = Conv2D(
                in_channels=self.feat_in,
                out_channels=self.num_classes,
                kernel_size=1,
                weight_attr=ParamAttr(
                    initializer=KaimingNormal(fan_in=self.num_classes)),
                bias_attr=ParamAttr(
                    learning_rate=2., regularizer=L2Decay(0.0)))
            registered = self.add_sublayer(
                'mask_fcn_logits_{}'.format(stage_id), logits_conv)
            self.mask_fcn_logits.append(registered)

    def forward_train(self,
                      body_feats,
                      bboxes,
                      bbox_feat,
                      mask_index,
                      spatial_scale,
                      stage=0):
        """Return raw (pre-sigmoid) mask logits for training."""
        roi_mask_feat = self.mask_feat(body_feats, bboxes, bbox_feat,
                                       mask_index, spatial_scale, stage,
                                       mode='train')
        return self.mask_fcn_logits[stage](roi_mask_feat)

    def forward_test(self,
                     scale_factor,
                     body_feats,
                     bboxes,
                     bbox_feat,
                     mask_index,
                     spatial_scale,
                     stage=0,
                     bbox_head_feat_func=None):
        """Return sigmoid mask probabilities for detected boxes.

        ``bboxes`` is unpacked as ``(bbox, bbox_num)``. Columns ``2:`` of
        ``bbox`` are multiplied by ``scale_factor[img, 0]`` of the image each
        box belongs to (presumably columns 0-1 are label/score and a single
        isotropic scale is assumed — confirm). If ``bbox`` has no rows, the
        empty ``bbox`` tensor itself is returned.
        """
        bbox, bbox_num = bboxes
        if bbox.shape[0] == 0:
            return bbox

        # Repeat each image's scale once per box belonging to that image.
        per_box_scale = [
            scale_factor[img_id, 0]
            for img_id, box_count in enumerate(bbox_num)
            for _ in range(box_count)
        ]
        per_box_scale = paddle.cast(paddle.concat(per_box_scale), 'float32')
        per_box_scale = paddle.reshape(per_box_scale, shape=[-1, 1])
        scaled_bboxes = (paddle.multiply(bbox[:, 2:], per_box_scale),
                         bbox_num)

        roi_mask_feat = self.mask_feat(body_feats, scaled_bboxes, bbox_feat,
                                       mask_index, spatial_scale, stage,
                                       bbox_head_feat_func, mode='infer')
        return F.sigmoid(self.mask_fcn_logits[stage](roi_mask_feat))

    def forward(self,
                inputs,
                body_feats,
                bboxes,
                bbox_feat,
                mask_index,
                spatial_scale,
                bbox_head_feat_func=None,
                stage=0):
        """Dispatch on ``inputs['mode']``: logits when 'train', else probs."""
        if inputs['mode'] == 'train':
            return self.forward_train(body_feats, bboxes, bbox_feat,
                                      mask_index, spatial_scale, stage)
        return self.forward_test(inputs['scale_factor'], body_feats, bboxes,
                                 bbox_feat, mask_index, spatial_scale, stage,
                                 bbox_head_feat_func)

    def get_loss(self, mask_head_out, mask_target):
        """Sum of normalized sigmoid cross-entropy; label -1 is ignored.

        Returns:
            dict with the single key ``'loss_mask'``.
        """
        logits = paddle.flatten(mask_head_out, start_axis=1, stop_axis=-1)
        label = paddle.cast(x=mask_target, dtype='float32')
        label.stop_gradient = True
        per_elem_loss = ops.sigmoid_cross_entropy_with_logits(
            input=logits, label=label, ignore_index=-1, normalize=True)
        return {'loss_mask': paddle.sum(per_elem_loss)}