bbox_head.py 13.1 KB
Newer Older
Q
qingqing01 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. 
#   
# Licensed under the Apache License, Version 2.0 (the "License");   
# you may not use this file except in compliance with the License.  
# You may obtain a copy of the License at   
#   
#     http://www.apache.org/licenses/LICENSE-2.0    
#   
# Unless required by applicable law or agreed to in writing, software   
# distributed under the License is distributed on an "AS IS" BASIS, 
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.  
# See the License for the specific language governing permissions and   
# limitations under the License.

G
Guanghua Yu 已提交
15 16
import numpy as np

Q
qingqing01 已提交
17 18 19
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
F
Feng Ni 已提交
20
from paddle.nn.initializer import Normal, XavierUniform, KaimingNormal
Q
qingqing01 已提交
21
from paddle.regularizer import L2Decay
22 23 24 25 26

from ppdet.core.workspace import register, create
from .roi_extractor import RoIAlign
from ..shape_spec import ShapeSpec
from ..bbox_utils import bbox2delta
F
Feng Ni 已提交
27 28 29
from ppdet.modeling.layers import ConvNormLayer

__all__ = ['TwoFCHead', 'XConvNormHead', 'BBoxHead']
30

Q
qingqing01 已提交
31 32 33

@register
class TwoFCHead(nn.Layer):
    """
    RCNN bbox head that extracts RoI features with two fully connected layers

    Args:
        in_channel (int): Input channel which can be derived by from_config
        out_channel (int): Output channel
        resolution (int): Resolution of input feature map, default 7
    """

    def __init__(self, in_channel=256, out_channel=1024, resolution=7):
        super(TwoFCHead, self).__init__()
        self.in_channel = in_channel
        self.out_channel = out_channel

        # Size of one flattened RoI feature (C * H * W). The same value is
        # handed to the Xavier initializer as its fan_out.
        flat_dim = in_channel * resolution * resolution
        fc6_weight = paddle.ParamAttr(
            initializer=XavierUniform(fan_out=flat_dim))
        self.fc6 = nn.Linear(flat_dim, out_channel, weight_attr=fc6_weight)
        # NOTE(review): skip_quant is only set here, never read in this file;
        # presumably consumed by quantization tooling -- confirm.
        self.fc6.skip_quant = True

        fc7_weight = paddle.ParamAttr(initializer=XavierUniform())
        self.fc7 = nn.Linear(out_channel, out_channel, weight_attr=fc7_weight)
        self.fc7.skip_quant = True

    @classmethod
    def from_config(cls, cfg, input_shape):
        """Derive `in_channel` from the (first) input shape spec."""
        if isinstance(input_shape, (list, tuple)):
            spec = input_shape[0]
        else:
            spec = input_shape
        return {'in_channel': spec.channels}

    @property
    def out_shape(self):
        """Single-element list holding the output ShapeSpec of this head."""
        return [ShapeSpec(channels=self.out_channel, )]

    def forward(self, rois_feat):
        """Flatten the RoI features and apply fc6 -> ReLU -> fc7 -> ReLU."""
        flat = paddle.flatten(rois_feat, start_axis=1, stop_axis=-1)
        hidden = F.relu(self.fc6(flat))
        return F.relu(self.fc7(hidden))
Q
qingqing01 已提交
78 79


F
Feng Ni 已提交
80 81
@register
class XConvNormHead(nn.Layer):
    """
    RCNN bbox head with several convolution layers

    Args:
        in_channel (int): Input channels which can be derived by from_config
        num_convs (int): The number of conv layers
        conv_dim (int): The number of channels for the conv layers
        out_channel (int): Output channels
        resolution (int): Resolution of input feature map
        norm_type (string): Norm type, bn, gn, sync_bn are available,
            default `gn`
        freeze_norm (bool): Whether to freeze the norm
        stage_name (string): Prefix name for conv layer,  '' by default
    """
    # Declared after the docstring so the string above is the class __doc__
    # rather than a discarded expression statement.
    __shared__ = ['norm_type', 'freeze_norm']

    def __init__(self,
                 in_channel=256,
                 num_convs=4,
                 conv_dim=256,
                 out_channel=1024,
                 resolution=7,
                 norm_type='gn',
                 freeze_norm=False,
                 stage_name=''):
        super(XConvNormHead, self).__init__()
        self.in_channel = in_channel
        self.num_convs = num_convs
        self.conv_dim = conv_dim
        self.out_channel = out_channel
        self.norm_type = norm_type
        self.freeze_norm = freeze_norm

        # Plain Python list used for ordered iteration in forward(); each
        # layer is registered on this module through add_sublayer below.
        self.bbox_head_convs = []
        fan = conv_dim * 3 * 3
        initializer = KaimingNormal(fan_in=fan)
        for i in range(self.num_convs):
            # Only the first conv consumes the incoming RoI feature channels.
            in_c = in_channel if i == 0 else conv_dim
            head_conv_name = stage_name + 'bbox_head_conv{}'.format(i)
            head_conv = self.add_sublayer(
                head_conv_name,
                ConvNormLayer(
                    ch_in=in_c,
                    ch_out=conv_dim,
                    filter_size=3,
                    stride=1,
                    norm_type=self.norm_type,
                    freeze_norm=self.freeze_norm,
                    initializer=initializer))
            self.bbox_head_convs.append(head_conv)

        # Flattened size of the conv output fed to the final linear layer.
        fan = conv_dim * resolution * resolution
        self.fc6 = nn.Linear(
            fan,
            out_channel,
            weight_attr=paddle.ParamAttr(
                initializer=XavierUniform(fan_out=fan)),
            bias_attr=paddle.ParamAttr(
                learning_rate=2., regularizer=L2Decay(0.)))

    @classmethod
    def from_config(cls, cfg, input_shape):
        """Derive `in_channel` from the (first) input shape spec."""
        s = input_shape
        s = s[0] if isinstance(s, (list, tuple)) else s
        return {'in_channel': s.channels}

    @property
    def out_shape(self):
        """Single-element list holding the output ShapeSpec of this head."""
        return [ShapeSpec(channels=self.out_channel, )]

    def forward(self, rois_feat):
        """Apply each conv+norm layer with ReLU, flatten, then fc6 + ReLU."""
        for head_conv in self.bbox_head_convs:
            rois_feat = F.relu(head_conv(rois_feat))
        rois_feat = paddle.flatten(rois_feat, start_axis=1, stop_axis=-1)
        return F.relu(self.fc6(rois_feat))


Q
qingqing01 已提交
160 161
@register
class BBoxHead(nn.Layer):
    """
    RCNN bbox head

    Args:
        head (nn.Layer): Extract feature in bbox head
        in_channel (int): Input channel after RoI extractor
        roi_extractor (object): The module of RoI Extractor
        bbox_assigner (object): The module of Box Assigner, label and sample the
            box.
        with_pool (bool): Whether to use pooling for the RoI feature.
        num_classes (int): The number of classes
        bbox_weight (List[float]): The weight to get the decode box
        bbox_loss (object): Optional regression loss applied to the
            transformed boxes; when None an L1 loss on the deltas is used.
    """
    # Declared after the docstring so the string above is the class __doc__
    # rather than a discarded expression statement.
    __shared__ = ['num_classes']
    __inject__ = ['bbox_assigner', 'bbox_loss']

    def __init__(self,
                 head,
                 in_channel,
                 roi_extractor=RoIAlign().__dict__,
                 bbox_assigner='BboxAssigner',
                 with_pool=False,
                 num_classes=80,
                 bbox_weight=[10., 10., 5., 5.],
                 bbox_loss=None):
        super(BBoxHead, self).__init__()
        self.head = head
        self.roi_extractor = roi_extractor
        # A dict is treated as keyword arguments for building a RoIAlign.
        if isinstance(roi_extractor, dict):
            self.roi_extractor = RoIAlign(**roi_extractor)
        self.bbox_assigner = bbox_assigner

        self.with_pool = with_pool
        self.num_classes = num_classes
        self.bbox_weight = bbox_weight
        self.bbox_loss = bbox_loss

        # Classification branch: one extra output beyond num_classes.
        self.bbox_score = nn.Linear(
            in_channel,
            self.num_classes + 1,
            weight_attr=paddle.ParamAttr(initializer=Normal(
                mean=0.0, std=0.01)))
        self.bbox_score.skip_quant = True

        # Regression branch: four deltas per class.
        self.bbox_delta = nn.Linear(
            in_channel,
            4 * self.num_classes,
            weight_attr=paddle.ParamAttr(initializer=Normal(
                mean=0.0, std=0.001)))
        self.bbox_delta.skip_quant = True

        # `assigned_label` is kept for backward compatibility; forward() and
        # get_assigned_targets() use `assigned_targets`, which must exist
        # before the first training step so the getter does not raise
        # AttributeError.
        self.assigned_label = None
        self.assigned_targets = None
        self.assigned_rois = None

    @classmethod
    def from_config(cls, cfg, input_shape):
        """Build constructor kwargs: RoI extractor config, head, in_channel.

        The `roi_extractor` dict from `cfg` is updated in place with the
        values returned by `RoIAlign.from_config`, and the head is created
        from `cfg['head']` with `input_shape`.
        """
        roi_pooler = cfg['roi_extractor']
        assert isinstance(roi_pooler, dict)
        kwargs = RoIAlign.from_config(cfg, input_shape)
        roi_pooler.update(kwargs)
        kwargs = {'input_shape': input_shape}
        head = create(cfg['head'], **kwargs)
        return {
            'roi_extractor': roi_pooler,
            'head': head,
            'in_channel': head.out_shape[0].channels
        }

    def forward(self, body_feats=None, rois=None, rois_num=None, inputs=None):
        """
        body_feats (list[Tensor]): Feature maps from backbone
        rois (list[Tensor]): RoIs generated from RPN module
        rois_num (Tensor): The number of RoIs in each image
        inputs (dict{Tensor}): The ground-truth of image

        Returns:
            In training mode ``(loss, bbox_feat)`` where ``loss`` is the dict
            produced by ``get_loss``; otherwise ``(pred, self.head)`` where
            ``pred`` is the tuple produced by ``get_prediction``.
        """
        if self.training:
            # Re-sample the RoIs and remember the assignment so it can be
            # queried through the get_assigned_* accessors.
            rois, rois_num, targets = self.bbox_assigner(rois, rois_num, inputs)
            self.assigned_rois = (rois, rois_num)
            self.assigned_targets = targets

        rois_feat = self.roi_extractor(body_feats, rois, rois_num)
        bbox_feat = self.head(rois_feat)
        if self.with_pool:
            # Pool the head output to 1x1 and drop the two size-1 axes.
            feat = F.adaptive_avg_pool2d(bbox_feat, output_size=1)
            feat = paddle.squeeze(feat, axis=[2, 3])
        else:
            feat = bbox_feat
        scores = self.bbox_score(feat)
        deltas = self.bbox_delta(feat)

        if self.training:
            loss = self.get_loss(scores, deltas, targets, rois,
                                 self.bbox_weight)
            return loss, bbox_feat
        else:
            pred = self.get_prediction(scores, deltas)
            return pred, self.head

    def get_loss(self, scores, deltas, targets, rois, bbox_weight):
        """
        scores (Tensor): scores from bbox head outputs
        deltas (Tensor): deltas from bbox head outputs
        targets (list[List[Tensor]]): bbox targets containing tgt_labels, tgt_bboxes and tgt_gt_inds
        rois (List[Tensor]): RoIs generated in each batch
        bbox_weight (List[float]): weights forwarded to bbox2delta

        Returns:
            dict with keys 'loss_bbox_cls' and 'loss_bbox_reg'.
        """
        cls_name = 'loss_bbox_cls'
        reg_name = 'loss_bbox_reg'
        loss_bbox = {}

        # TODO: better pass args
        # The third element (tgt_gt_inds) is not used by this loss.
        tgt_labels, tgt_bboxes, _ = targets

        # bbox cls
        tgt_labels = paddle.concat(tgt_labels) if len(
            tgt_labels) > 1 else tgt_labels[0]
        valid_inds = paddle.nonzero(tgt_labels >= 0).flatten()
        if valid_inds.shape[0] == 0:
            loss_bbox[cls_name] = paddle.zeros([1], dtype='float32')
        else:
            tgt_labels = tgt_labels.cast('int64')
            tgt_labels.stop_gradient = True
            loss_bbox_cls = F.cross_entropy(
                input=scores, label=tgt_labels, reduction='mean')
            loss_bbox[cls_name] = loss_bbox_cls

        # bbox reg

        # A delta width of exactly 4 means one shared box per RoI.
        cls_agnostic_bbox_reg = deltas.shape[1] == 4

        # Foreground samples: labels in [0, num_classes).
        fg_inds = paddle.nonzero(
            paddle.logical_and(tgt_labels >= 0, tgt_labels <
                               self.num_classes)).flatten()

        if fg_inds.numel() == 0:
            loss_bbox[reg_name] = paddle.zeros([1], dtype='float32')
            return loss_bbox

        if cls_agnostic_bbox_reg:
            reg_delta = paddle.gather(deltas, fg_inds)
        else:
            fg_gt_classes = paddle.gather(tgt_labels, fg_inds)

            # Build (row, col) index pairs selecting, for every foreground
            # row, the four delta columns belonging to its target class.
            reg_row_inds = paddle.arange(fg_gt_classes.shape[0]).unsqueeze(1)
            reg_row_inds = paddle.tile(reg_row_inds, [1, 4]).reshape([-1, 1])

            reg_col_inds = 4 * fg_gt_classes.unsqueeze(1) + paddle.arange(4)

            reg_col_inds = reg_col_inds.reshape([-1, 1])
            reg_inds = paddle.concat([reg_row_inds, reg_col_inds], axis=1)

            reg_delta = paddle.gather(deltas, fg_inds)
            reg_delta = paddle.gather_nd(reg_delta, reg_inds).reshape([-1, 4])
        rois = paddle.concat(rois) if len(rois) > 1 else rois[0]
        tgt_bboxes = paddle.concat(tgt_bboxes) if len(
            tgt_bboxes) > 1 else tgt_bboxes[0]

        reg_target = bbox2delta(rois, tgt_bboxes, bbox_weight)
        reg_target = paddle.gather(reg_target, fg_inds)
        reg_target.stop_gradient = True

        # Both branches normalise by the total number of sampled labels,
        # not by the number of foreground samples.
        if self.bbox_loss is not None:
            reg_delta = self.bbox_transform(reg_delta)
            reg_target = self.bbox_transform(reg_target)
            loss_bbox_reg = self.bbox_loss(
                reg_delta, reg_target).sum() / tgt_labels.shape[0]
            loss_bbox_reg *= self.num_classes
        else:
            loss_bbox_reg = paddle.abs(reg_delta - reg_target).sum(
            ) / tgt_labels.shape[0]

        loss_bbox[reg_name] = loss_bbox_reg

        return loss_bbox

    def bbox_transform(self, deltas, weights=(0.1, 0.1, 0.2, 0.2)):
        """Turn weighted (dx, dy, dw, dh) deltas into corner coordinates.

        Args:
            deltas (Tensor): deltas whose last axis can be reshaped to 4
            weights (sequence of 4 floats): multipliers for dx, dy, dw, dh

        Returns:
            Tensor: 1-D concatenation of all x1, then y1, then x2, then y2.
        """
        wx, wy, ww, wh = weights

        deltas = paddle.reshape(deltas, shape=(0, -1, 4))

        dx = paddle.slice(deltas, axes=[2], starts=[0], ends=[1]) * wx
        dy = paddle.slice(deltas, axes=[2], starts=[1], ends=[2]) * wy
        dw = paddle.slice(deltas, axes=[2], starts=[2], ends=[3]) * ww
        dh = paddle.slice(deltas, axes=[2], starts=[3], ends=[4]) * wh

        # Upper-bound the log-size terms before exponentiating.
        dw = paddle.clip(dw, -1.e10, np.log(1000. / 16))
        dh = paddle.clip(dh, -1.e10, np.log(1000. / 16))

        # Centres are the weighted offsets themselves; sizes are exp(dw/dh).
        pred_ctr_x = dx
        pred_ctr_y = dy
        pred_w = paddle.exp(dw)
        pred_h = paddle.exp(dh)

        x1 = pred_ctr_x - 0.5 * pred_w
        y1 = pred_ctr_y - 0.5 * pred_h
        x2 = pred_ctr_x + 0.5 * pred_w
        y2 = pred_ctr_y + 0.5 * pred_h

        x1 = paddle.reshape(x1, shape=(-1, ))
        y1 = paddle.reshape(y1, shape=(-1, ))
        x2 = paddle.reshape(x2, shape=(-1, ))
        y2 = paddle.reshape(y2, shape=(-1, ))

        return paddle.concat([x1, y1, x2, y2])

    def get_prediction(self, score, delta):
        """Return ``(delta, bbox_prob)`` with ``bbox_prob = softmax(score)``."""
        bbox_prob = F.softmax(score)
        return delta, bbox_prob

    def get_head(self, ):
        """Return the feature-extraction head passed to the constructor."""
        return self.head

    def get_assigned_targets(self, ):
        """Return targets from the last training forward, or None if none ran."""
        return self.assigned_targets

    def get_assigned_rois(self, ):
        """Return ``(rois, rois_num)`` from the last training forward, or None."""
        return self.assigned_rois