bbox_head.py 10.8 KB
Newer Older
Q
qingqing01 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. 
#   
# Licensed under the Apache License, Version 2.0 (the "License");   
# you may not use this file except in compliance with the License.  
# You may obtain a copy of the License at   
#   
#     http://www.apache.org/licenses/LICENSE-2.0    
#   
# Unless required by applicable law or agreed to in writing, software   
# distributed under the License is distributed on an "AS IS" BASIS, 
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.  
# See the License for the specific language governing permissions and   
# limitations under the License.

import paddle
import paddle.nn as nn
import paddle.nn.functional as F
F
Feng Ni 已提交
18
from paddle.nn.initializer import Normal, XavierUniform, KaimingNormal
Q
qingqing01 已提交
19
from paddle.regularizer import L2Decay
20 21

from ppdet.core.workspace import register, create
Q
qingqing01 已提交
22 23
from ppdet.modeling import ops

24 25 26
from .roi_extractor import RoIAlign
from ..shape_spec import ShapeSpec
from ..bbox_utils import bbox2delta
F
Feng Ni 已提交
27 28 29
from ppdet.modeling.layers import ConvNormLayer

__all__ = ['TwoFCHead', 'XConvNormHead', 'BBoxHead']
30

Q
qingqing01 已提交
31 32 33

@register
class TwoFCHead(nn.Layer):
    """
    RCNN bbox head made of two fully connected layers (fc6, fc7), each
    followed by a ReLU.

    Args:
        in_dim (int): num of channels of the input rois_feat
        mlp_dim (int): num of output channels of both fc layers
        resolution (int): spatial resolution of the rois_feat
    """

    def __init__(self, in_dim=256, mlp_dim=1024, resolution=7):
        super(TwoFCHead, self).__init__()
        self.in_dim = in_dim
        self.mlp_dim = mlp_dim

        # forward() flattens rois_feat before fc6, so fc6 consumes
        # in_dim * resolution * resolution features; the same value is used
        # as fan_out for the Xavier initializer.
        flat_dim = in_dim * resolution * resolution
        self.fc6 = nn.Linear(
            flat_dim,
            mlp_dim,
            weight_attr=paddle.ParamAttr(
                initializer=XavierUniform(fan_out=flat_dim)))
        self.fc7 = nn.Linear(
            mlp_dim,
            mlp_dim,
            weight_attr=paddle.ParamAttr(initializer=XavierUniform()))

    @classmethod
    def from_config(cls, cfg, input_shape):
        # input_shape may be a single ShapeSpec or a list/tuple of them;
        # only the first entry's channel count is used.
        if isinstance(input_shape, (list, tuple)):
            spec = input_shape[0]
        else:
            spec = input_shape
        return {'in_dim': spec.channels}

    @property
    def out_shape(self):
        """Output spec: a single entry with `mlp_dim` channels."""
        return [ShapeSpec(channels=self.mlp_dim)]

    def forward(self, rois_feat):
        """Flatten every RoI feature and apply fc6 -> ReLU -> fc7 -> ReLU."""
        x = paddle.flatten(rois_feat, start_axis=1, stop_axis=-1)
        x = F.relu(self.fc6(x))
        return F.relu(self.fc7(x))
Q
qingqing01 已提交
67 68


F
Feng Ni 已提交
69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148
@register
class XConvNormHead(nn.Layer):
    """
    RCNN bbox head with several 3x3 convolution layers followed by one fc
    layer.

    Args:
        in_dim (int): num of channels for the input rois_feat
        num_convs (int): num of convolution layers for the rcnn bbox head
        conv_dim (int): num of channels for the conv layers
        mlp_dim (int): num of channels for the fc layer
        resolution (int): resolution of the rois_feat
        norm_type (str): norm type, 'gn' by default
        freeze_norm (bool): whether to freeze the norm
        stage_name (str): prefix for sublayer names, used in
            CascadeXConvNormHead, '' by default
    """
    __shared__ = ['norm_type', 'freeze_norm']

    def __init__(self,
                 in_dim=256,
                 num_convs=4,
                 conv_dim=256,
                 mlp_dim=1024,
                 resolution=7,
                 norm_type='gn',
                 freeze_norm=False,
                 stage_name=''):
        super(XConvNormHead, self).__init__()
        self.in_dim = in_dim
        self.num_convs = num_convs
        self.conv_dim = conv_dim
        self.mlp_dim = mlp_dim
        self.norm_type = norm_type
        self.freeze_norm = freeze_norm

        # A single Kaiming initializer instance is shared by every conv.
        # NOTE(review): fan_in is conv_dim * 3 * 3 for all layers, including
        # the first one whose real input has in_dim channels; the two only
        # agree when in_dim == conv_dim.
        conv_init = KaimingNormal(fan_in=conv_dim * 3 * 3)
        self.bbox_head_convs = []
        for idx in range(num_convs):
            layer_name = stage_name + 'bbox_head_conv{}'.format(idx)
            conv = ConvNormLayer(
                ch_in=conv_dim if idx > 0 else in_dim,
                ch_out=conv_dim,
                filter_size=3,
                stride=1,
                norm_type=norm_type,
                norm_name=layer_name + '_norm',
                freeze_norm=freeze_norm,
                initializer=conv_init,
                name=layer_name)
            # add_sublayer registers the conv as a named child layer; the
            # plain list keeps the same objects in order for forward().
            self.bbox_head_convs.append(self.add_sublayer(layer_name, conv))

        # The conv output is flattened before fc6.
        fc_in = conv_dim * resolution * resolution
        self.fc6 = nn.Linear(
            fc_in,
            mlp_dim,
            weight_attr=paddle.ParamAttr(
                initializer=XavierUniform(fan_out=fc_in)),
            bias_attr=paddle.ParamAttr(
                learning_rate=2., regularizer=L2Decay(0.)))

    @classmethod
    def from_config(cls, cfg, input_shape):
        # input_shape may be a single ShapeSpec or a list/tuple of them;
        # only the first entry's channel count is used.
        if isinstance(input_shape, (list, tuple)):
            spec = input_shape[0]
        else:
            spec = input_shape
        return {'in_dim': spec.channels}

    @property
    def out_shape(self):
        """Output spec: a single entry with `mlp_dim` channels."""
        return [ShapeSpec(channels=self.mlp_dim)]

    def forward(self, rois_feat):
        """Apply each conv + ReLU, flatten, then fc6 + ReLU."""
        x = rois_feat
        for conv in self.bbox_head_convs:
            x = F.relu(conv(x))
        x = paddle.flatten(x, start_axis=1, stop_axis=-1)
        return F.relu(self.fc6(x))


Q
qingqing01 已提交
149 150
@register
class BBoxHead(nn.Layer):
    """
    RCNN bbox head: extracts RoI features, passes them through `head`, and
    predicts per-RoI class scores and box regression deltas.

    Args:
        head (nn.Layer): Extract feature in bbox head
        in_channel (int): Channels of the feature produced by `head`, i.e.
            the input size of the score/delta fc layers
        roi_extractor (object): The module of RoI Extractor (a RoIAlign
            instance, or a dict of RoIAlign constructor arguments)
        bbox_assigner (object): The module of Box Assigner, label and sample
            the box.
        with_pool (bool): Whether to use pooling for the RoI feature.
        num_classes (int): The number of classes
        bbox_weight (List[float]): The weight to get the decode box
    """
    __shared__ = ['num_classes']
    __inject__ = ['bbox_assigner']

    def __init__(self,
                 head,
                 in_channel,
                 roi_extractor=RoIAlign().__dict__,
                 bbox_assigner='BboxAssigner',
                 with_pool=False,
                 num_classes=80,
                 bbox_weight=[10., 10., 5., 5.]):
        # NOTE: the dict/list defaults above are only read, never mutated,
        # by this class.
        super(BBoxHead, self).__init__()
        self.head = head
        self.roi_extractor = roi_extractor
        if isinstance(roi_extractor, dict):
            self.roi_extractor = RoIAlign(**roi_extractor)
        self.bbox_assigner = bbox_assigner

        self.with_pool = with_pool
        self.num_classes = num_classes
        self.bbox_weight = bbox_weight

        # num_classes + 1 score outputs: one extra slot beyond the
        # foreground classes (presumably background).
        self.bbox_score = nn.Linear(
            in_channel,
            self.num_classes + 1,
            weight_attr=paddle.ParamAttr(initializer=Normal(
                mean=0.0, std=0.01)))

        # 4 deltas per foreground class (class-specific regression).
        self.bbox_delta = nn.Linear(
            in_channel,
            4 * self.num_classes,
            weight_attr=paddle.ParamAttr(initializer=Normal(
                mean=0.0, std=0.001)))
        # Kept for backward compatibility; not used elsewhere in this class.
        self.assigned_label = None
        # Filled by forward() in training mode; initialized here so the
        # getters return None instead of raising before the first training
        # step.
        self.assigned_rois = None
        self.assigned_targets = None

    @classmethod
    def from_config(cls, cfg, input_shape):
        roi_pooler = cfg['roi_extractor']
        assert isinstance(roi_pooler, dict)
        # Work on a copy so the config entry (possibly the shared default
        # dict of __init__) is not mutated in place.
        roi_pooler = dict(roi_pooler)
        roi_pooler.update(RoIAlign.from_config(cfg, input_shape))
        head = create(cfg['head'], input_shape=input_shape)
        return {
            'roi_extractor': roi_pooler,
            'head': head,
            'in_channel': head.out_shape[0].channels
        }

    def forward(self, body_feats=None, rois=None, rois_num=None, inputs=None):
        """
        Args:
            body_feats (list[Tensor]): Feature maps from backbone
            rois (list[Tensor]): RoIs generated from RPN module
            rois_num (Tensor): The number of RoIs in each image
            inputs (dict{Tensor}): The ground-truth of image

        Returns:
            In training mode: (loss dict from get_loss, bbox_feat).
            Otherwise: ((deltas, softmax scores) from get_prediction,
            self.head).
        """
        if self.training:
            # In training, the assigner replaces the proposals with its
            # sampled RoIs and returns their targets; both are cached for
            # later retrieval through the getters.
            rois, rois_num, targets = self.bbox_assigner(rois, rois_num, inputs)
            self.assigned_rois = (rois, rois_num)
            self.assigned_targets = targets

        rois_feat = self.roi_extractor(body_feats, rois, rois_num)
        bbox_feat = self.head(rois_feat)
        if self.with_pool:
            feat = F.adaptive_avg_pool2d(bbox_feat, output_size=1)
            feat = paddle.squeeze(feat, axis=[2, 3])
        else:
            feat = bbox_feat
        scores = self.bbox_score(feat)
        deltas = self.bbox_delta(feat)

        if self.training:
            loss = self.get_loss(scores, deltas, targets, rois,
                                 self.bbox_weight)
            return loss, bbox_feat
        else:
            pred = self.get_prediction(scores, deltas)
            return pred, self.head

    def get_loss(self, scores, deltas, targets, rois, bbox_weight):
        """
        Compute the classification and regression losses of the bbox head.

        Args:
            scores (Tensor): scores from bbox head outputs
            deltas (Tensor): deltas from bbox head outputs
            targets (list[List[Tensor]]): bbox targets containing tgt_labels,
                tgt_bboxes and tgt_gt_inds
            rois (List[Tensor]): RoIs generated in each batch
            bbox_weight (List[float]): weights passed to bbox2delta

        Returns:
            dict with 'loss_bbox_cls' (mean cross entropy) and
            'loss_bbox_reg' (L1 loss over foreground RoIs divided by the
            number of sampled RoIs; zero when there is no foreground RoI).
        """
        cls_name = 'loss_bbox_cls'
        reg_name = 'loss_bbox_reg'
        loss_bbox = {}

        # TODO: better pass args
        tgt_labels, tgt_bboxes, _tgt_gt_inds = targets

        # bbox cls
        tgt_labels = paddle.concat(tgt_labels) if len(
            tgt_labels) > 1 else tgt_labels[0]
        tgt_labels = tgt_labels.cast('int64')
        tgt_labels.stop_gradient = True
        loss_bbox[cls_name] = F.cross_entropy(
            input=scores, label=tgt_labels, reduction='mean')

        # bbox reg
        cls_agnostic_bbox_reg = deltas.shape[1] == 4

        # Only RoIs labelled in [0, num_classes) are regressed; other labels
        # (presumably background) are excluded.
        fg_inds = paddle.nonzero(
            paddle.logical_and(tgt_labels >= 0, tgt_labels <
                               self.num_classes)).flatten()

        if fg_inds.numel() == 0:
            # Without foreground RoIs there is nothing to regress and the
            # gathers below would receive an empty index; report a zero
            # regression loss instead.
            loss_bbox[reg_name] = paddle.zeros([1], dtype='float32')
            return loss_bbox

        if cls_agnostic_bbox_reg:
            reg_delta = paddle.gather(deltas, fg_inds)
        else:
            fg_gt_classes = paddle.gather(tgt_labels, fg_inds)

            # Build (row, col) index pairs that pick, for each foreground
            # RoI, the 4 delta columns belonging to its target class.
            reg_row_inds = paddle.arange(fg_gt_classes.shape[0]).unsqueeze(1)
            reg_row_inds = paddle.tile(reg_row_inds, [1, 4]).reshape([-1, 1])

            reg_col_inds = 4 * fg_gt_classes.unsqueeze(1) + paddle.arange(4)
            reg_col_inds = reg_col_inds.reshape([-1, 1])
            reg_inds = paddle.concat([reg_row_inds, reg_col_inds], axis=1)

            reg_delta = paddle.gather(deltas, fg_inds)
            reg_delta = paddle.gather_nd(reg_delta, reg_inds).reshape([-1, 4])

        rois = paddle.concat(rois) if len(rois) > 1 else rois[0]
        tgt_bboxes = paddle.concat(tgt_bboxes) if len(
            tgt_bboxes) > 1 else tgt_bboxes[0]

        reg_target = bbox2delta(rois, tgt_bboxes, bbox_weight)
        reg_target = paddle.gather(reg_target, fg_inds)
        reg_target.stop_gradient = True

        # Sum of absolute errors over foreground RoIs, normalized by the
        # total number of sampled RoIs (tgt_labels covers all of them).
        loss_bbox[reg_name] = paddle.abs(reg_delta - reg_target).sum(
        ) / tgt_labels.shape[0]

        return loss_bbox

    def get_prediction(self, score, delta):
        """Return the raw deltas and the softmax of the class scores."""
        bbox_prob = F.softmax(score)
        return delta, bbox_prob

    def get_head(self):
        """Return the feature-extraction head layer."""
        return self.head

    def get_assigned_targets(self):
        """Return the targets cached by the last training forward (or None)."""
        return self.assigned_targets

    def get_assigned_rois(self):
        """Return (rois, rois_num) cached by the last training forward (or None)."""
        return self.assigned_rois