mask_head.py 7.4 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. 
#   
# Licensed under the Apache License, Version 2.0 (the "License");   
# you may not use this file except in compliance with the License.  
# You may obtain a copy of the License at   
#   
#     http://www.apache.org/licenses/LICENSE-2.0    
#   
# Unless required by applicable law or agreed to in writing, software   
# distributed under the License is distributed on an "AS IS" BASIS, 
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.  
# See the License for the specific language governing permissions and   
# limitations under the License.
F
FDInSky 已提交
14

15 16 17 18 19 20 21
import paddle
import paddle.nn.functional as F
from paddle import ParamAttr
from paddle.nn import Layer, Sequential
from paddle.nn import Conv2D, Conv2DTranspose, ReLU
from paddle.nn.initializer import KaimingNormal
from paddle.regularizer import L2Decay
F
FDInSky 已提交
22
from ppdet.core.workspace import register
23
from ppdet.modeling import ops
F
FDInSky 已提交
24 25 26 27


@register
class MaskFeat(Layer):
    """Per-stage feature extractor that feeds the mask head.

    For each of ``mask_num_stages`` stages a ``Sequential`` is built from
    ``num_convs`` 3x3 convolutions (each followed by ReLU) and one final
    2x2, stride-2 transposed convolution (followed by ReLU).

    Args:
        mask_roi_extractor: Injected callable, invoked as
            ``mask_roi_extractor(body_feats, bboxes, spatial_scale)`` when
            ``share_bbox_feat`` is False.
        num_convs (int): Number of 3x3 conv + ReLU pairs in each stage.
        feat_in (int): Channel count of the features entering a stage.
        feat_out (int): Channel count produced by every layer of a stage.
        mask_num_stages (int): Number of independent stages to build.
        share_bbox_feat (bool): When True, ``forward`` gathers rows of
            ``bbox_feat`` instead of calling ``mask_roi_extractor``.
    """
    __inject__ = ['mask_roi_extractor']

    def __init__(self,
                 mask_roi_extractor,
                 num_convs=1,
                 feat_in=2048,
                 feat_out=256,
                 mask_num_stages=1,
                 share_bbox_feat=False):
        super(MaskFeat, self).__init__()
        self.num_convs = num_convs
        self.feat_in = feat_in
        self.feat_out = feat_out
        self.mask_roi_extractor = mask_roi_extractor
        self.mask_num_stages = mask_num_stages
        self.share_bbox_feat = share_bbox_feat
        self.upsample_module = []
        # fan_in values handed to KaimingNormal for the 3x3 convs and the
        # 2x2 transposed conv respectively.
        fan_conv = feat_out * 3 * 3
        fan_deconv = feat_out * 2 * 2
        # The transposed conv consumes the output of the last 3x3 conv
        # (feat_out channels) when at least one conv precedes it; only when
        # there are no convs does it see the raw feat_in-channel input.
        deconv_in = feat_out if self.num_convs > 0 else feat_in
        for i in range(self.mask_num_stages):
            name = 'stage_{}'.format(i)
            mask_conv = Sequential()
            for j in range(self.num_convs):
                conv_name = 'mask_inter_feat_{}'.format(j + 1)
                mask_conv.add_sublayer(
                    conv_name,
                    Conv2D(
                        # Only the first conv sees the input channel count.
                        in_channels=feat_in if j == 0 else feat_out,
                        out_channels=feat_out,
                        kernel_size=3,
                        padding=1,
                        weight_attr=ParamAttr(
                            initializer=KaimingNormal(fan_in=fan_conv)),
                        bias_attr=ParamAttr(
                            learning_rate=2., regularizer=L2Decay(0.))))
                mask_conv.add_sublayer(conv_name + 'act', ReLU())
            mask_conv.add_sublayer(
                'conv5_mask',
                Conv2DTranspose(
                    in_channels=deconv_in,
                    out_channels=self.feat_out,
                    kernel_size=2,
                    stride=2,
                    weight_attr=ParamAttr(
                        initializer=KaimingNormal(fan_in=fan_deconv)),
                    bias_attr=ParamAttr(
                        learning_rate=2., regularizer=L2Decay(0.))))
            mask_conv.add_sublayer('conv5_mask' + 'act', ReLU())
            # add_sublayer registers the stage on this Layer; the plain list
            # keeps the stages addressable by integer index in forward().
            upsample = self.add_sublayer(name, mask_conv)
            self.upsample_module.append(upsample)

    def forward(self,
                body_feats,
                bboxes,
                bbox_feat,
                mask_index,
                spatial_scale,
                stage=0):
        """Produce the mask features for the requested stage.

        Args:
            body_feats: Passed to ``mask_roi_extractor`` when bbox features
                are not shared.
            bboxes: Passed to ``mask_roi_extractor`` when bbox features are
                not shared.
            bbox_feat: Source of RoI features when ``share_bbox_feat`` is
                True; rows are selected with ``mask_index``.
            mask_index: Index argument given to ``paddle.gather``.
            spatial_scale: Passed to ``mask_roi_extractor``.
            stage (int): Index into the list of per-stage modules.

        Returns:
            Output of the selected stage applied to the RoI features.
        """
        if self.share_bbox_feat:
            rois_feat = paddle.gather(bbox_feat, mask_index)
        else:
            rois_feat = self.mask_roi_extractor(body_feats, bboxes,
                                                spatial_scale)
        # Run the conv stack and the transposed conv of the chosen stage.
        mask_feat = self.upsample_module[stage](rois_feat)
        return mask_feat
F
FDInSky 已提交
94 95 96 97


@register
class MaskHead(Layer):
    """Mask prediction head built on top of an injected ``mask_feat`` layer.

    One 1x1 convolution per stage maps the ``feat_in``-channel mask features
    to ``num_classes`` output channels.

    Args:
        mask_feat: Injected layer, called as ``mask_feat(body_feats, bboxes,
            bbox_feat, mask_index, spatial_scale, stage)``.
        feat_in (int): Input channel count of each logits convolution.
        num_classes (int): Output channel count of each logits convolution.
        mask_num_stages (int): Number of logits convolutions to create.
    """
    __shared__ = ['num_classes', 'mask_num_stages']
    __inject__ = ['mask_feat']

    def __init__(self,
                 mask_feat,
                 feat_in=256,
                 num_classes=81,
                 mask_num_stages=1):
        super(MaskHead, self).__init__()
        self.mask_feat = mask_feat
        self.feat_in = feat_in
        self.num_classes = num_classes
        self.mask_num_stages = mask_num_stages
        self.mask_fcn_logits = []
        for stage_idx in range(self.mask_num_stages):
            logits_conv = Conv2D(
                in_channels=self.feat_in,
                out_channels=self.num_classes,
                kernel_size=1,
                weight_attr=ParamAttr(
                    initializer=KaimingNormal(fan_in=self.num_classes)),
                bias_attr=ParamAttr(
                    learning_rate=2., regularizer=L2Decay(0.0)))
            # Register the conv as a sublayer and keep it indexable by stage.
            registered = self.add_sublayer(
                'mask_fcn_logits_{}'.format(stage_idx), logits_conv)
            self.mask_fcn_logits.append(registered)

    def forward_train(self,
                      body_feats,
                      bboxes,
                      bbox_feat,
                      mask_index,
                      spatial_scale,
                      stage=0):
        """Return the raw mask logits of ``stage`` (no sigmoid applied)."""
        stage_feat = self.mask_feat(body_feats, bboxes, bbox_feat,
                                    mask_index, spatial_scale, stage)
        return self.mask_fcn_logits[stage](stage_feat)

    def forward_test(self,
                     scale_factor,
                     body_feats,
                     bboxes,
                     bbox_feat,
                     mask_index,
                     spatial_scale,
                     stage=0):
        """Return sigmoid mask outputs for the given boxes.

        ``bboxes`` is unpacked as ``(bbox, bbox_num)``. When ``bbox`` has no
        rows it is returned unchanged. Otherwise columns ``2:`` of ``bbox``
        are multiplied row-wise by ``scale_factor[idx, 0]`` of the image each
        box belongs to before the mask features are computed.
        """
        bbox, bbox_num = bboxes
        if bbox.shape[0] == 0:
            # Nothing to predict on: hand the empty tensor straight back.
            return bbox

        # Repeat each image's scale factor once per box of that image.
        per_box_scale = [
            scale_factor[img_idx, 0]
            for img_idx, box_count in enumerate(bbox_num)
            for _ in range(box_count)
        ]
        per_box_scale = paddle.cast(paddle.concat(per_box_scale), 'float32')
        scaled_bbox = paddle.multiply(bbox[:, 2:], per_box_scale, axis=0)

        stage_feat = self.mask_feat(body_feats, (scaled_bbox, bbox_num),
                                    bbox_feat, mask_index, spatial_scale,
                                    stage)
        stage_logits = self.mask_fcn_logits[stage](stage_feat)
        return F.sigmoid(stage_logits)

    def forward(self,
                inputs,
                body_feats,
                bboxes,
                bbox_feat,
                mask_index,
                spatial_scale,
                stage=0):
        """Dispatch on ``inputs['mode']``.

        ``'train'`` goes to ``forward_train``; any other value goes to
        ``forward_test`` with ``inputs['scale_factor']``.
        """
        if inputs['mode'] == 'train':
            return self.forward_train(body_feats, bboxes, bbox_feat,
                                      mask_index, spatial_scale, stage)

        scale_factor = inputs['scale_factor']
        return self.forward_test(scale_factor, body_feats, bboxes, bbox_feat,
                                 mask_index, spatial_scale, stage)

    def get_loss(self, mask_head_out, mask_target):
        """Compute the mask loss and return it as ``{'loss_mask': value}``.

        The head output is flattened from axis 1 onward, the target is cast
        to float32 and excluded from gradient computation, and the summed
        result of ``ops.sigmoid_cross_entropy_with_logits`` (called with
        ``ignore_index=-1`` and ``normalize=True``) is returned.
        """
        flat_logits = paddle.flatten(
            mask_head_out, start_axis=1, stop_axis=-1)
        float_target = paddle.cast(x=mask_target, dtype='float32')
        float_target.stop_gradient = True
        elementwise_loss = ops.sigmoid_cross_entropy_with_logits(
            input=flat_logits,
            label=float_target,
            ignore_index=-1,
            normalize=True)
        return {'loss_mask': paddle.sum(elementwise_loss)}