# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#   
# Licensed under the Apache License, Version 2.0 (the "License");   
# you may not use this file except in compliance with the License.  
# You may obtain a copy of the License at   
#   
#     http://www.apache.org/licenses/LICENSE-2.0    
#   
# Unless required by applicable law or agreed to in writing, software   
# distributed under the License is distributed on an "AS IS" BASIS, 
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.  
# See the License for the specific language governing permissions and   
# limitations under the License.

import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from paddle.nn.initializer import Normal, XavierUniform, KaimingNormal
from paddle.regularizer import L2Decay

from ppdet.core.workspace import register, create
from ppdet.modeling import ops

from .roi_extractor import RoIAlign
from ..shape_spec import ShapeSpec
from ..bbox_utils import bbox2delta
from ppdet.modeling.layers import ConvNormLayer

__all__ = ['TwoFCHead', 'XConvNormHead', 'BBoxHead']


@register
class TwoFCHead(nn.Layer):
    """
    RCNN bbox head with two fc layers to extract feature

    Args:
        in_channel (int): Input channel which can be derived by from_config
        out_channel (int): Output channel
        resolution (int): Resolution of input feature map, default 7
    """

    def __init__(self, in_channel=256, out_channel=1024, resolution=7):
        super(TwoFCHead, self).__init__()
        self.in_channel = in_channel
        self.out_channel = out_channel

        # Width of one flattened RoI feature. The same value is used as the
        # input size of fc6 and as the `fan_out` given to its initializer.
        fan = in_channel * resolution * resolution
        self.fc6 = nn.Linear(
            fan,
            out_channel,
            weight_attr=paddle.ParamAttr(
                initializer=XavierUniform(fan_out=fan)))

        self.fc7 = nn.Linear(
            out_channel,
            out_channel,
            weight_attr=paddle.ParamAttr(initializer=XavierUniform()))

    @classmethod
    def from_config(cls, cfg, input_shape):
        """
        Build the constructor kwargs that depend on the input shape.

        Args:
            cfg: Unused here; kept for the common `from_config` signature.
            input_shape: A shape object exposing `.channels`, or a
                list/tuple of them, in which case the first entry is used.

        Returns:
            dict: `{'in_channel': <channels of the input shape>}`
        """
        s = input_shape
        s = s[0] if isinstance(s, (list, tuple)) else s
        return {'in_channel': s.channels}

    @property
    def out_shape(self):
        """Single-element list holding a ShapeSpec with `out_channel`."""
        return [ShapeSpec(channels=self.out_channel, )]

    def forward(self, rois_feat):
        """
        Flatten every axis after the first, then apply fc6 -> ReLU -> fc7
        -> ReLU.

        Args:
            rois_feat (Tensor): RoI features; the flattened size per row
                must equal `in_channel * resolution * resolution`.

        Returns:
            Tensor: Output of the second fc layer after ReLU.
        """
        flat_feat = paddle.flatten(rois_feat, start_axis=1, stop_axis=-1)
        fc6 = F.relu(self.fc6(flat_feat))
        fc7 = F.relu(self.fc7(fc6))
        return fc7


@register
class XConvNormHead(nn.Layer):
    """
    RCNN bbox head with several convolution layers

    Args:
        in_channel (int): Input channels which can be derived by from_config
        num_convs (int): The number of conv layers
        conv_dim (int): The number of channels for the conv layers
        out_channel (int): Output channels
        resolution (int): Resolution of input feature map
        norm_type (string): Norm type, bn, gn, sync_bn are available,
            default `gn`
        freeze_norm (bool): Whether to freeze the norm
        stage_name (string): Prefix name for conv layer,  '' by default
    """
    # The docstring must be the first statement of the class body to become
    # `__doc__`; the config attributes therefore follow it.
    __shared__ = ['norm_type', 'freeze_norm']

    def __init__(self,
                 in_channel=256,
                 num_convs=4,
                 conv_dim=256,
                 out_channel=1024,
                 resolution=7,
                 norm_type='gn',
                 freeze_norm=False,
                 stage_name=''):
        super(XConvNormHead, self).__init__()
        self.in_channel = in_channel
        self.num_convs = num_convs
        self.conv_dim = conv_dim
        self.out_channel = out_channel
        self.norm_type = norm_type
        self.freeze_norm = freeze_norm

        # Plain Python list used only to keep the call order in forward();
        # each layer is registered on this Layer through add_sublayer().
        self.bbox_head_convs = []
        # fan_in of a 3x3 conv with `conv_dim` input channels; the same
        # initializer object is shared by every conv in the stack.
        fan = conv_dim * 3 * 3
        initializer = KaimingNormal(fan_in=fan)
        for i in range(self.num_convs):
            # Only the first conv sees the RoI feature channels.
            in_c = in_channel if i == 0 else conv_dim
            head_conv_name = stage_name + 'bbox_head_conv{}'.format(i)
            head_conv = self.add_sublayer(
                head_conv_name,
                ConvNormLayer(
                    ch_in=in_c,
                    ch_out=conv_dim,
                    filter_size=3,
                    stride=1,
                    norm_type=self.norm_type,
                    norm_name=head_conv_name + '_norm',
                    freeze_norm=self.freeze_norm,
                    initializer=initializer,
                    name=head_conv_name))
            self.bbox_head_convs.append(head_conv)

        # Width of the flattened conv output fed to fc6; also passed as the
        # `fan_out` of its weight initializer.
        fan = conv_dim * resolution * resolution
        self.fc6 = nn.Linear(
            fan,
            out_channel,
            weight_attr=paddle.ParamAttr(
                initializer=XavierUniform(fan_out=fan)),
            bias_attr=paddle.ParamAttr(
                learning_rate=2., regularizer=L2Decay(0.)))

    @classmethod
    def from_config(cls, cfg, input_shape):
        """
        Build the constructor kwargs that depend on the input shape.

        Args:
            cfg: Unused here; kept for the common `from_config` signature.
            input_shape: A shape object exposing `.channels`, or a
                list/tuple of them, in which case the first entry is used.

        Returns:
            dict: `{'in_channel': <channels of the input shape>}`
        """
        s = input_shape
        s = s[0] if isinstance(s, (list, tuple)) else s
        return {'in_channel': s.channels}

    @property
    def out_shape(self):
        """Single-element list holding a ShapeSpec with `out_channel`."""
        return [ShapeSpec(channels=self.out_channel, )]

    def forward(self, rois_feat):
        """
        Run the conv stack (each conv followed by ReLU), flatten every axis
        after the first, then apply fc6 and ReLU.

        Args:
            rois_feat (Tensor): RoI features fed to the first conv layer.

        Returns:
            Tensor: Output of fc6 after ReLU.
        """
        for head_conv in self.bbox_head_convs:
            rois_feat = F.relu(head_conv(rois_feat))
        rois_feat = paddle.flatten(rois_feat, start_axis=1, stop_axis=-1)
        fc6 = F.relu(self.fc6(rois_feat))
        return fc6


@register
class BBoxHead(nn.Layer):
    """
    RCNN bbox head

    Args:
        head (nn.Layer): Extract feature in bbox head
        in_channel (int): Input channel after RoI extractor
        roi_extractor (object): The module of RoI Extractor
        bbox_assigner (object): The module of Box Assigner, label and sample the
            box.
        with_pool (bool): Whether to use pooling for the RoI feature.
        num_classes (int): The number of classes
        bbox_weight (List[float]): The weight to get the decode box
    """
    # The docstring must be the first statement of the class body to become
    # `__doc__`; the config attributes therefore follow it.
    __shared__ = ['num_classes']
    __inject__ = ['bbox_assigner']

    # NOTE(review): `RoIAlign().__dict__` and the `bbox_weight` list are
    # created once, when the class body is evaluated, and are shared by every
    # call that relies on the defaults. They are kept unchanged so the
    # constructor interface stays exactly as before.
    def __init__(self,
                 head,
                 in_channel,
                 roi_extractor=RoIAlign().__dict__,
                 bbox_assigner='BboxAssigner',
                 with_pool=False,
                 num_classes=80,
                 bbox_weight=[10., 10., 5., 5.]):
        super(BBoxHead, self).__init__()
        self.head = head
        self.roi_extractor = roi_extractor
        if isinstance(roi_extractor, dict):
            # A dict is treated as keyword arguments for RoIAlign.
            self.roi_extractor = RoIAlign(**roi_extractor)
        self.bbox_assigner = bbox_assigner

        self.with_pool = with_pool
        self.num_classes = num_classes
        self.bbox_weight = bbox_weight

        # One extra output compared with `num_classes`.
        self.bbox_score = nn.Linear(
            in_channel,
            self.num_classes + 1,
            weight_attr=paddle.ParamAttr(initializer=Normal(
                mean=0.0, std=0.01)))

        # Four regression outputs per class.
        self.bbox_delta = nn.Linear(
            in_channel,
            4 * self.num_classes,
            weight_attr=paddle.ParamAttr(initializer=Normal(
                mean=0.0, std=0.001)))

        # `assigned_label` is never written again in this class; it is kept
        # so any external reader of the attribute keeps working.
        self.assigned_label = None
        self.assigned_rois = None
        # forward() stores the assigner output here in training mode;
        # initialising it lets get_assigned_targets() return None instead of
        # raising AttributeError before the first training step.
        self.assigned_targets = None

    @classmethod
    def from_config(cls, cfg, input_shape):
        """
        Build the constructor kwargs that depend on the input shape.

        Args:
            cfg (dict): Must contain 'roi_extractor' (a dict, updated in
                place with the result of `RoIAlign.from_config`) and 'head'.
            input_shape: Forwarded to `RoIAlign.from_config` and to the
                creation of the head.

        Returns:
            dict: 'roi_extractor', 'head' and 'in_channel' entries, where
                'in_channel' is the channel count of the head's first
                output shape.
        """
        roi_pooler = cfg['roi_extractor']
        assert isinstance(roi_pooler, dict)
        kwargs = RoIAlign.from_config(cfg, input_shape)
        roi_pooler.update(kwargs)
        kwargs = {'input_shape': input_shape}
        head = create(cfg['head'], **kwargs)
        return {
            'roi_extractor': roi_pooler,
            'head': head,
            'in_channel': head.out_shape[0].channels
        }

    def forward(self, body_feats=None, rois=None, rois_num=None, inputs=None):
        """
        Args:
            body_feats (list[Tensor]): Feature maps from backbone
            rois (list[Tensor]): RoIs generated from RPN module
            rois_num (Tensor): The number of RoIs in each image
            inputs (dict{Tensor}): The ground-truth of image

        Returns:
            tuple: In training mode `(loss, bbox_feat)`, where `loss` is the
                dict from `get_loss` and `bbox_feat` is the head output.
                Otherwise `(pred, self.head)`, where `pred` is the tuple from
                `get_prediction`; note the second item is the head sub-layer
                itself, not a feature tensor.
        """
        if self.training:
            # The assigner replaces the RoIs and produces the targets; both
            # are cached for the get_assigned_* accessors.
            rois, rois_num, targets = self.bbox_assigner(rois, rois_num, inputs)
            self.assigned_rois = (rois, rois_num)
            self.assigned_targets = targets

        rois_feat = self.roi_extractor(body_feats, rois, rois_num)
        bbox_feat = self.head(rois_feat)
        if self.with_pool:
            # Pool to 1x1 and drop axes 2 and 3 before the linear layers.
            feat = F.adaptive_avg_pool2d(bbox_feat, output_size=1)
            feat = paddle.squeeze(feat, axis=[2, 3])
        else:
            feat = bbox_feat
        scores = self.bbox_score(feat)
        deltas = self.bbox_delta(feat)

        if self.training:
            loss = self.get_loss(scores, deltas, targets, rois,
                                 self.bbox_weight)
            return loss, bbox_feat
        else:
            pred = self.get_prediction(scores, deltas)
            return pred, self.head

    def get_loss(self, scores, deltas, targets, rois, bbox_weight):
        """
        Compute the classification and regression losses of the bbox head.

        Args:
            scores (Tensor): scores from bbox head outputs
            deltas (Tensor): deltas from bbox head outputs
            targets (list[List[Tensor]]): bbox targets containing tgt_labels,
                tgt_bboxes and tgt_gt_inds
            rois (List[Tensor]): RoIs generated in each batch
            bbox_weight (List[float]): Weights forwarded to `bbox2delta`

        Returns:
            dict: `{'loss_bbox_cls': ..., 'loss_bbox_reg': ...}`
        """
        # TODO: better pass args
        tgt_labels, tgt_bboxes, _ = targets

        # bbox cls
        tgt_labels = paddle.concat(tgt_labels) if len(
            tgt_labels) > 1 else tgt_labels[0]
        tgt_labels = tgt_labels.cast('int64')
        tgt_labels.stop_gradient = True
        loss_bbox_cls = F.cross_entropy(
            input=scores, label=tgt_labels, reduction='mean')

        # bbox reg
        # Four columns means one box per RoI shared by all classes.
        cls_agnostic_bbox_reg = deltas.shape[1] == 4

        # Foreground samples are those labelled in [0, num_classes).
        fg_inds = paddle.nonzero(
            paddle.logical_and(tgt_labels >= 0, tgt_labels <
                               self.num_classes)).flatten()

        loss_weight = 1.
        has_fg = True
        if fg_inds.numel() == 0:
            # No foreground sample: run the regression path on row 0 as a
            # placeholder and cancel the result with a zero weight.
            # NOTE(review): the zero weight is also applied to the
            # classification loss below; confirm that this is intended.
            fg_inds = paddle.zeros([1], dtype='int32')
            loss_weight = 0.
            has_fg = False

        if cls_agnostic_bbox_reg:
            reg_delta = paddle.gather(deltas, fg_inds)
        else:
            if has_fg:
                fg_gt_classes = paddle.gather(tgt_labels, fg_inds)
            else:
                # The label of the placeholder row is, by construction,
                # outside [0, num_classes); using it would build column
                # indices outside the 4 * num_classes columns of `deltas`.
                # Class 0 keeps the indices valid; the value is multiplied
                # by loss_weight == 0 anyway.
                fg_gt_classes = paddle.zeros([1], dtype='int64')

            # Build (row, col) pairs selecting, for every gathered row, the
            # four columns that belong to its class.
            reg_row_inds = paddle.arange(fg_gt_classes.shape[0]).unsqueeze(1)
            reg_row_inds = paddle.tile(reg_row_inds, [1, 4]).reshape([-1, 1])

            reg_col_inds = 4 * fg_gt_classes.unsqueeze(1) + paddle.arange(4)
            reg_col_inds = reg_col_inds.reshape([-1, 1])
            reg_inds = paddle.concat([reg_row_inds, reg_col_inds], axis=1)

            reg_delta = paddle.gather(deltas, fg_inds)
            reg_delta = paddle.gather_nd(reg_delta, reg_inds).reshape([-1, 4])

        rois = paddle.concat(rois) if len(rois) > 1 else rois[0]
        tgt_bboxes = paddle.concat(tgt_bboxes) if len(
            tgt_bboxes) > 1 else tgt_bboxes[0]

        reg_target = bbox2delta(rois, tgt_bboxes, bbox_weight)
        reg_target = paddle.gather(reg_target, fg_inds)
        reg_target.stop_gradient = True

        # Summed L1 distance divided by the number of labels (all sampled
        # RoIs, not only the foreground ones).
        loss_bbox_reg = paddle.abs(reg_delta - reg_target).sum(
        ) / tgt_labels.shape[0]

        return {
            'loss_bbox_cls': loss_bbox_cls * loss_weight,
            'loss_bbox_reg': loss_bbox_reg * loss_weight,
        }

    def get_prediction(self, score, delta):
        """
        Turn raw head outputs into predictions.

        Args:
            score (Tensor): Output of the score layer.
            delta (Tensor): Output of the delta layer.

        Returns:
            tuple: `(delta, bbox_prob)`, where `bbox_prob` is the softmax of
                `score` and `delta` is returned unchanged.
        """
        bbox_prob = F.softmax(score)
        return delta, bbox_prob

    def get_head(self, ):
        """Return the feature-extraction head passed to the constructor."""
        return self.head

    def get_assigned_targets(self, ):
        """Return the targets cached by the last training forward, or None."""
        return self.assigned_targets

    def get_assigned_rois(self, ):
        """Return the `(rois, rois_num)` pair cached by the last training
        forward, or None."""
        return self.assigned_rois