ppyoloe_head.py 26.0 KB
Newer Older
S
shangliang Xu 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from ppdet.core.workspace import register
19 20 21
from paddle import ParamAttr
from paddle.nn.initializer import KaimingNormal
from paddle.nn.initializer import Normal, Constant
S
shangliang Xu 已提交
22 23 24 25 26

from ..bbox_utils import batch_distance2bbox
from ..losses import GIoULoss
from ..initializer import bias_init_with_prob, constant_, normal_
from ..assigners.utils import generate_anchors_for_grid_cell
27
from ppdet.modeling.backbones.cspresnet import ConvBNLayer, RepVggBlock
W
wangguanzhong 已提交
28
from ppdet.modeling.ops import get_static_shape, get_act_fn
W
wangxinxin08 已提交
29
from ppdet.modeling.layers import MultiClassNMS
S
shangliang Xu 已提交
30

31
# Public API of this module: the two registered detection heads.
__all__ = [
    'PPYOLOEHead',
    'SimpleConvHead',
]
S
shangliang Xu 已提交
32 33 34


class ESEAttn(nn.Layer):
    """eSE-style channel attention used in the PP-YOLOE head stems.

    A 1x1 conv (``fc``) on a pooled feature produces a sigmoid gate that
    rescales the input feature map; the gated map then goes through
    ``conv``.

    Args:
        feat_channels (int): number of input and output channels.
        act: activation forwarded to the inner conv layer.
        attn_conv (str): ``'convbn'`` selects a 1x1 ``ConvBNLayer``; any
            other value selects a ``RepVggBlock``.
    """

    def __init__(self, feat_channels, act='swish', attn_conv='convbn'):
        super(ESEAttn, self).__init__()
        self.fc = nn.Conv2D(feat_channels, feat_channels, 1)
        if attn_conv == 'convbn':
            self.conv = ConvBNLayer(feat_channels, feat_channels, 1, act=act)
        else:
            self.conv = RepVggBlock(feat_channels, feat_channels, act=act)
        self._init_weights()

    def _init_weights(self):
        # Very small std so the gate starts out nearly input-independent.
        normal_(self.fc.weight, std=0.001)

    def forward(self, feat, avg_feat):
        """Return ``conv(feat * sigmoid(fc(avg_feat)))``.

        ``avg_feat`` is expected to be a pooled version of ``feat`` that
        broadcasts against it (e.g. ``F.adaptive_avg_pool2d(feat, (1, 1))``).
        """
        weight = F.sigmoid(self.fc(avg_feat))
        return self.conv(feat * weight)


@register
class PPYOLOEHead(nn.Layer):
    """PP-YOLOE detection head.

    For every FPN level the head applies two ``ESEAttn`` stems (one for
    classification, one for regression) followed by 3x3 prediction convs.
    Box regression predicts, for each of the 4 box sides, logits over
    ``reg_channels`` bins. A softmax followed by the frozen ``proj_conv``
    (whose weights are the bin values) reduces them to an expected distance.

    Args:
        in_channels (list[int]): channels of each input feature level.
        num_classes (int): number of foreground classes; also used as the
            background index for label assignment.
        act: activation for the stems (resolved through ``get_act_fn`` when
            it is None, a str or a dict; otherwise used as-is).
        fpn_strides (tuple[int]): stride of each level; its length must match
            the number of feature maps passed to ``forward``.
        grid_cell_scale (float): forwarded to
            ``generate_anchors_for_grid_cell`` during training.
        grid_cell_offset (float): offset added to grid indices to form
            anchor points.
        reg_max (int): when ``reg_range`` is unset, bins span
            ``[0, reg_max]`` (``reg_max + 1`` bins).
        reg_range (tuple[int, int] | None): explicit ``[low, high)`` bin
            range. Setting it also enables ``sm_use``, which calls the
            dynamic assigner with the PPYOLOE-SOD argument list.
        static_assigner_epoch (int): while ``epoch_id`` is below this value
            ``static_assigner`` is used instead of ``assigner``.
        use_varifocal_loss (bool): varifocal loss if True, else focal loss.
        static_assigner, assigner, nms: injected components.
        eval_size (list[int] | None): if set, anchors for an input of height
            ``eval_size[0]`` and width ``eval_size[1]`` are precomputed once
            instead of being regenerated on every eval forward.
        loss_weight (dict): weights for the 'class', 'iou' and 'dfl' terms.
        trt (bool): TensorRT mode; forwarded to ``get_act_fn`` and to the
            NMS when it is a ``MultiClassNMS``.
        attn_conv (str): conv type used inside ``ESEAttn``.
        exclude_nms (bool): return rescaled boxes and scores without NMS.
        exclude_post_process (bool): return concatenated raw boxes/scores
            without rescaling or NMS.
        use_shared_conv (bool): in eval, project each level's distribution
            separately (True) or after concatenating all levels (False).
        for_distill (bool): record intermediate tensors in
            ``self.distill_pairs`` and cache dynamic label assignments on
            the instance for distillation.
    """
    __shared__ = [
        'num_classes', 'eval_size', 'trt', 'exclude_nms',
        'exclude_post_process', 'use_shared_conv', 'for_distill'
    ]
    __inject__ = ['static_assigner', 'assigner', 'nms']

    def __init__(self,
                 in_channels=[1024, 512, 256],
                 num_classes=80,
                 act='swish',
                 fpn_strides=(32, 16, 8),
                 grid_cell_scale=5.0,
                 grid_cell_offset=0.5,
                 reg_max=16,
                 reg_range=None,
                 static_assigner_epoch=4,
                 use_varifocal_loss=True,
                 static_assigner='ATSSAssigner',
                 assigner='TaskAlignedAssigner',
                 nms='MultiClassNMS',
                 eval_size=None,
                 loss_weight={
                     'class': 1.0,
                     'iou': 2.5,
                     'dfl': 0.5,
                 },
                 trt=False,
                 attn_conv='convbn',
                 exclude_nms=False,
                 exclude_post_process=False,
                 use_shared_conv=True,
                 for_distill=False):
        super(PPYOLOEHead, self).__init__()
        assert len(in_channels) > 0, "len(in_channels) should > 0"
        self.in_channels = in_channels
        self.num_classes = num_classes
        self.fpn_strides = fpn_strides
        self.grid_cell_scale = grid_cell_scale
        self.grid_cell_offset = grid_cell_offset
        if reg_range:
            self.sm_use = True
            self.reg_range = reg_range
        else:
            self.sm_use = False
            self.reg_range = (0, reg_max + 1)
        self.reg_channels = self.reg_range[1] - self.reg_range[0]
        self.iou_loss = GIoULoss()
        self.loss_weight = loss_weight
        self.use_varifocal_loss = use_varifocal_loss
        self.eval_size = eval_size

        self.static_assigner_epoch = static_assigner_epoch
        self.static_assigner = static_assigner
        self.assigner = assigner
        self.nms = nms
        if isinstance(self.nms, MultiClassNMS) and trt:
            self.nms.trt = trt
        self.exclude_nms = exclude_nms
        self.exclude_post_process = exclude_post_process
        self.use_shared_conv = use_shared_conv
        self.for_distill = for_distill

        # stem
        self.stem_cls = nn.LayerList()
        self.stem_reg = nn.LayerList()
        act = get_act_fn(
            act, trt=trt) if act is None or isinstance(act,
                                                       (str, dict)) else act
        for in_c in self.in_channels:
            self.stem_cls.append(ESEAttn(in_c, act=act, attn_conv=attn_conv))
            self.stem_reg.append(ESEAttn(in_c, act=act, attn_conv=attn_conv))
        # pred head
        self.pred_cls = nn.LayerList()
        self.pred_reg = nn.LayerList()
        for in_c in self.in_channels:
            self.pred_cls.append(
                nn.Conv2D(
                    in_c, self.num_classes, 3, padding=1))
            self.pred_reg.append(
                nn.Conv2D(
                    in_c, 4 * self.reg_channels, 3, padding=1))
        # projection conv: turns a per-side bin distribution into a distance
        self.proj_conv = nn.Conv2D(self.reg_channels, 1, 1, bias_attr=False)
        self.proj_conv.skip_quant = True
        self._init_weights()

        if self.for_distill:
            self.distill_pairs = {}

    @classmethod
    def from_config(cls, cfg, input_shape):
        """Derive ``in_channels`` from the neck's output shapes."""
        return {'in_channels': [i.channels for i in input_shape], }

    def _init_weights(self):
        """Initialize prediction convs, the projection kernel and eval anchors."""
        bias_cls = bias_init_with_prob(0.01)
        for cls_, reg_ in zip(self.pred_cls, self.pred_reg):
            # constant_ with its default value (presumably zero) for weights.
            constant_(cls_.weight)
            constant_(cls_.bias, bias_cls)
            constant_(reg_.weight)
            constant_(reg_.bias, 1.0)

        # The projection weights are the bin values reg_range[0] ..
        # reg_range[1] - 1, so conv(softmax(logits)) is the expected value.
        proj = paddle.linspace(self.reg_range[0], self.reg_range[1] - 1,
                               self.reg_channels).reshape(
                                   [1, self.reg_channels, 1, 1])
        self.proj_conv.weight.set_value(proj)
        self.proj_conv.weight.stop_gradient = True
        if self.eval_size:
            anchor_points, stride_tensor = self._generate_anchors()
            self.anchor_points = anchor_points
            self.stride_tensor = stride_tensor

    def forward_train(self, feats, targets, aux_pred=None):
        """Run the head on ``feats`` and return the loss dict from ``get_loss``."""
        anchors, anchor_points, num_anchors_list, stride_tensor = \
            generate_anchors_for_grid_cell(
                feats, self.fpn_strides, self.grid_cell_scale,
                self.grid_cell_offset)

        cls_score_list, reg_distri_list = [], []
        for i, feat in enumerate(feats):
            avg_feat = F.adaptive_avg_pool2d(feat, (1, 1))
            cls_logit = self.pred_cls[i](self.stem_cls[i](feat, avg_feat) +
                                         feat)
            reg_distri = self.pred_reg[i](self.stem_reg[i](feat, avg_feat))
            # cls and reg: [N, C, H, W] -> [N, H*W, C]
            cls_score = F.sigmoid(cls_logit)
            cls_score_list.append(cls_score.flatten(2).transpose([0, 2, 1]))
            reg_distri_list.append(reg_distri.flatten(2).transpose([0, 2, 1]))
        cls_score_list = paddle.concat(cls_score_list, axis=1)
        reg_distri_list = paddle.concat(reg_distri_list, axis=1)

        return self.get_loss([
            cls_score_list, reg_distri_list, anchors, anchor_points,
            num_anchors_list, stride_tensor
        ], targets, aux_pred)

    def _generate_anchors(self, feats=None, dtype='float32'):
        """Build anchor points (grid units) and per-anchor strides.

        Sizes come from ``feats`` when given, otherwise from
        ``eval_size // stride``. Returns ``(anchor_points [A, 2],
        stride_tensor [A, 1])`` concatenated over all levels.
        """
        anchor_points = []
        stride_tensor = []
        for i, stride in enumerate(self.fpn_strides):
            if feats is not None:
                _, _, h, w = feats[i].shape
            else:
                h = int(self.eval_size[0] / stride)
                w = int(self.eval_size[1] / stride)
            shift_x = paddle.arange(end=w) + self.grid_cell_offset
            shift_y = paddle.arange(end=h) + self.grid_cell_offset
            shift_y, shift_x = paddle.meshgrid(shift_y, shift_x)
            anchor_point = paddle.cast(
                paddle.stack(
                    [shift_x, shift_y], axis=-1), dtype=dtype)
            anchor_points.append(anchor_point.reshape([-1, 2]))
            stride_tensor.append(paddle.full([h * w, 1], stride, dtype=dtype))
        anchor_points = paddle.concat(anchor_points)
        stride_tensor = paddle.concat(stride_tensor)
        return anchor_points, stride_tensor

    def forward_eval(self, feats):
        """Run the head for inference.

        Returns:
            tuple: ``(cls_scores [N, num_classes, A], reg_dists [N, A, 4],
            anchor_points, stride_tensor)``; distances are in grid units.
        """
        if self.eval_size:
            anchor_points, stride_tensor = self.anchor_points, self.stride_tensor
        else:
            anchor_points, stride_tensor = self._generate_anchors(feats)
        cls_score_list, reg_dist_list = [], []
        for i, feat in enumerate(feats):
            _, _, h, w = feat.shape
            hw = h * w
            avg_feat = F.adaptive_avg_pool2d(feat, (1, 1))
            cls_logit = self.pred_cls[i](self.stem_cls[i](feat, avg_feat) +
                                         feat)
            reg_dist = self.pred_reg[i](self.stem_reg[i](feat, avg_feat))
            # [N, 4 * reg_channels, H, W] -> [N, reg_channels, H*W, 4]
            reg_dist = reg_dist.reshape(
                [-1, 4, self.reg_channels, hw]).transpose([0, 2, 3, 1])
            if self.use_shared_conv:
                reg_dist = self.proj_conv(F.softmax(
                    reg_dist, axis=1)).squeeze(1)
            else:
                reg_dist = F.softmax(reg_dist, axis=1)
            # cls and reg
            cls_score = F.sigmoid(cls_logit)
            cls_score_list.append(cls_score.reshape([-1, self.num_classes, hw]))
            reg_dist_list.append(reg_dist)

        cls_score_list = paddle.concat(cls_score_list, axis=-1)
        if self.use_shared_conv:
            reg_dist_list = paddle.concat(reg_dist_list, axis=1)
        else:
            # concatenate distributions along the anchor axis, then project
            reg_dist_list = paddle.concat(reg_dist_list, axis=2)
            reg_dist_list = self.proj_conv(reg_dist_list).squeeze(1)

        return cls_score_list, reg_dist_list, anchor_points, stride_tensor

    def forward(self, feats, targets=None, aux_pred=None):
        """Dispatch to ``forward_train`` (training) or ``forward_eval``."""
        assert len(feats) == len(self.fpn_strides), \
            "The size of feats is not equal to size of fpn_strides"

        if self.training:
            return self.forward_train(feats, targets, aux_pred)
        else:
            return self.forward_eval(feats)

    @staticmethod
    def _focal_loss(score, label, alpha=0.25, gamma=2.0):
        """Summed sigmoid focal loss on probabilities; ``alpha <= 0`` disables alpha."""
        weight = (score - label).pow(gamma)
        if alpha > 0:
            alpha_t = alpha * label + (1 - alpha) * (1 - label)
            weight *= alpha_t
        loss = F.binary_cross_entropy(
            score, label, weight=weight, reduction='sum')
        return loss

    @staticmethod
    def _varifocal_loss(pred_score, gt_score, label, alpha=0.75, gamma=2.0):
        """Summed varifocal loss: negatives down-weighted by ``alpha * p**gamma``."""
        weight = alpha * pred_score.pow(gamma) * (1 - label) + gt_score * label
        loss = F.binary_cross_entropy(
            pred_score, gt_score, weight=weight, reduction='sum')
        return loss

    def _bbox_decode(self, anchor_points, pred_dist):
        """Decode ``[N, A, 4 * reg_channels]`` bin logits into boxes.

        The softmax runs over the bins of each side and ``proj_conv`` turns
        it into an expected distance before ``batch_distance2bbox``.
        """
        _, num_anchors, _ = get_static_shape(pred_dist)
        pred_dist = F.softmax(
            pred_dist.reshape([-1, num_anchors, 4, self.reg_channels]))
        pred_dist = self.proj_conv(pred_dist.transpose([0, 3, 1, 2])).squeeze(1)
        return batch_distance2bbox(anchor_points, pred_dist)

    def _bbox2distance(self, points, bbox):
        """Convert x1y1x2y2 boxes into (l, t, r, b) distances from ``points``.

        Distances are clipped just below the last bin so that
        ``floor(d) + 1`` stays a valid bin index in ``_df_loss``.
        """
        x1y1, x2y2 = paddle.split(bbox, 2, -1)
        lt = points - x1y1
        rb = x2y2 - points
        return paddle.concat([lt, rb], -1).clip(self.reg_range[0],
                                                self.reg_range[1] - 1 - 0.01)

    def _df_loss(self, pred_dist, target, lower_bound=0):
        """Distribution focal loss, averaged over the last axis (box sides).

        The continuous ``target`` is split between its two neighbouring bins
        with linear weights; ``lower_bound`` shifts targets to bin indices.
        """
        target_left = paddle.cast(target.floor(), 'int64')
        target_right = target_left + 1
        weight_left = target_right.astype('float32') - target
        weight_right = 1 - weight_left
        loss_left = F.cross_entropy(
            pred_dist, target_left - lower_bound,
            reduction='none') * weight_left
        loss_right = F.cross_entropy(
            pred_dist, target_right - lower_bound,
            reduction='none') * weight_right
        return (loss_left + loss_right).mean(-1, keepdim=True)

    def _bbox_loss(self, pred_dist, pred_bboxes, anchor_points, assigned_labels,
                   assigned_bboxes, assigned_scores, assigned_scores_sum):
        """Compute box losses on positive anchors.

        Positives are anchors whose assigned label is not the background
        index ``num_classes``. IoU and DFL terms are weighted by the summed
        assigned score of each positive and normalized by
        ``assigned_scores_sum``.

        Returns:
            tuple: ``(loss_l1, loss_iou, loss_dfl)``.
        """
        # select positive samples mask
        mask_positive = (assigned_labels != self.num_classes)
        num_pos = mask_positive.sum()
        # pos/neg loss
        if num_pos > 0:
            # l1 + iou
            bbox_mask = mask_positive.unsqueeze(-1).tile([1, 1, 4])
            pred_bboxes_pos = paddle.masked_select(pred_bboxes,
                                                   bbox_mask).reshape([-1, 4])
            assigned_bboxes_pos = paddle.masked_select(
                assigned_bboxes, bbox_mask).reshape([-1, 4])
            bbox_weight = paddle.masked_select(
                assigned_scores.sum(-1), mask_positive).unsqueeze(-1)

            loss_l1 = F.l1_loss(pred_bboxes_pos, assigned_bboxes_pos)

            loss_iou = self.iou_loss(pred_bboxes_pos,
                                     assigned_bboxes_pos) * bbox_weight
            loss_iou = loss_iou.sum() / assigned_scores_sum

            dist_mask = mask_positive.unsqueeze(-1).tile(
                [1, 1, self.reg_channels * 4])
            pred_dist_pos = paddle.masked_select(
                pred_dist, dist_mask).reshape([-1, 4, self.reg_channels])
            assigned_ltrb = self._bbox2distance(anchor_points, assigned_bboxes)
            assigned_ltrb_pos = paddle.masked_select(
                assigned_ltrb, bbox_mask).reshape([-1, 4])
            loss_dfl = self._df_loss(pred_dist_pos, assigned_ltrb_pos,
                                     self.reg_range[0]) * bbox_weight
            loss_dfl = loss_dfl.sum() / assigned_scores_sum
            if self.for_distill:
                self.distill_pairs['pred_bboxes_pos'] = pred_bboxes_pos
                self.distill_pairs['pred_dist_pos'] = pred_dist_pos
                self.distill_pairs['bbox_weight'] = bbox_weight
        else:
            loss_l1 = paddle.zeros([1])
            loss_iou = paddle.zeros([1])
            # keeps pred_dist in the graph while contributing zero
            loss_dfl = pred_dist.sum() * 0.
        return loss_l1, loss_iou, loss_dfl

    def get_loss(self, head_outs, gt_meta, aux_pred=None):
        """Assign targets and compute the training losses.

        Args:
            head_outs (list): ``[pred_scores, pred_distri, anchors,
                anchor_points, num_anchors_list, stride_tensor]`` as built
                by ``forward_train``.
            gt_meta (dict): must provide 'gt_class', 'gt_bbox',
                'pad_gt_mask' and 'epoch_id'.
            aux_pred (list | None): optional ``[scores, distri]`` from an
                auxiliary head. When given (and ``sm_use`` is off) its
                predictions drive the dynamic assignment, and its losses are
                added to the main ones.

        Returns:
            dict: see ``get_loss_from_assign``.
        """
        pred_scores, pred_distri, anchors,\
        anchor_points, num_anchors_list, stride_tensor = head_outs

        anchor_points_s = anchor_points / stride_tensor
        pred_bboxes = self._bbox_decode(anchor_points_s, pred_distri)

        if aux_pred is not None:
            pred_scores_aux = aux_pred[0]
            pred_bboxes_aux = self._bbox_decode(anchor_points_s, aux_pred[1])

        gt_labels = gt_meta['gt_class']
        gt_bboxes = gt_meta['gt_bbox']
        pad_gt_mask = gt_meta['pad_gt_mask']
        # label assignment
        if gt_meta['epoch_id'] < self.static_assigner_epoch:
            assigned_labels, assigned_bboxes, assigned_scores, mask_positive = \
                self.static_assigner(
                    anchors,
                    num_anchors_list,
                    gt_labels,
                    gt_bboxes,
                    pad_gt_mask,
                    bg_index=self.num_classes,
                    pred_bboxes=pred_bboxes.detach() * stride_tensor)
            alpha_l = 0.25
        else:
            if self.sm_use:
                # only used in smalldet of PPYOLOE-SOD model
                assigned_labels, assigned_bboxes, assigned_scores, mask_positive = \
                    self.assigner(
                        pred_scores.detach(),
                        pred_bboxes.detach() * stride_tensor,
                        anchor_points,
                        stride_tensor,
                        gt_labels,
                        gt_bboxes,
                        pad_gt_mask,
                        bg_index=self.num_classes)
            elif aux_pred is not None:
                assigned_labels, assigned_bboxes, assigned_scores, mask_positive = \
                    self.assigner(
                        pred_scores_aux.detach(),
                        pred_bboxes_aux.detach() * stride_tensor,
                        anchor_points,
                        num_anchors_list,
                        gt_labels,
                        gt_bboxes,
                        pad_gt_mask,
                        bg_index=self.num_classes)
            elif hasattr(self, "assigned_labels"):
                # Assignment cached/injected for distillation.
                assigned_labels = self.assigned_labels
                assigned_bboxes = self.assigned_bboxes
                assigned_scores = self.assigned_scores
                mask_positive = self.mask_positive
            else:
                assigned_labels, assigned_bboxes, assigned_scores, mask_positive = \
                    self.assigner(
                        pred_scores.detach(),
                        pred_bboxes.detach() * stride_tensor,
                        anchor_points,
                        num_anchors_list,
                        gt_labels,
                        gt_bboxes,
                        pad_gt_mask,
                        bg_index=self.num_classes)
                # Cache only for distillation. Caching unconditionally would
                # make every later batch reuse this batch's assignment.
                if self.for_distill:
                    self.assigned_labels = assigned_labels
                    self.assigned_bboxes = assigned_bboxes
                    self.assigned_scores = assigned_scores
                    self.mask_positive = mask_positive
            alpha_l = -1
        # rescale bbox to grid units; out-of-place so a cached assignment
        # tensor is never modified.
        assigned_bboxes = assigned_bboxes / stride_tensor

        assign_out_dict = self.get_loss_from_assign(
            pred_scores, pred_distri, pred_bboxes, anchor_points_s,
            assigned_labels, assigned_bboxes, assigned_scores, mask_positive,
            alpha_l)

        if aux_pred is not None:
            assign_out_dict_aux = self.get_loss_from_assign(
                aux_pred[0], aux_pred[1], pred_bboxes_aux, anchor_points_s,
                assigned_labels, assigned_bboxes, assigned_scores,
                mask_positive, alpha_l)
            loss = {}
            for key in assign_out_dict.keys():
                loss[key] = assign_out_dict[key] + assign_out_dict_aux[key]
        else:
            loss = assign_out_dict

        return loss

    def get_loss_from_assign(self, pred_scores, pred_distri, pred_bboxes,
                             anchor_points_s, assigned_labels, assigned_bboxes,
                             assigned_scores, mask_positive, alpha_l):
        """Compute classification and box losses for a given assignment.

        Returns:
            dict: 'loss' (weighted sum of the class, iou and dfl terms),
            'loss_cls', 'loss_iou', 'loss_dfl' and 'loss_l1' (reported only,
            not part of 'loss').
        """
        # cls loss
        if self.use_varifocal_loss:
            one_hot_label = F.one_hot(assigned_labels,
                                      self.num_classes + 1)[..., :-1]
            loss_cls = self._varifocal_loss(pred_scores, assigned_scores,
                                            one_hot_label)
        else:
            loss_cls = self._focal_loss(pred_scores, assigned_scores, alpha_l)

        # normalizer: mean over workers of the summed assigned scores, >= 1
        assigned_scores_sum = assigned_scores.sum()
        if paddle.distributed.get_world_size() > 1:
            paddle.distributed.all_reduce(assigned_scores_sum)
            assigned_scores_sum /= paddle.distributed.get_world_size()
        assigned_scores_sum = paddle.clip(assigned_scores_sum, min=1.)
        loss_cls /= assigned_scores_sum

        if self.for_distill:
            self.distill_pairs['pred_cls_scores'] = pred_scores
            self.distill_pairs['pos_num'] = assigned_scores_sum
            self.distill_pairs['assigned_scores'] = assigned_scores
            self.distill_pairs['mask_positive'] = mask_positive
            one_hot_label = F.one_hot(assigned_labels,
                                      self.num_classes + 1)[..., :-1]
            self.distill_pairs['target_labels'] = one_hot_label

        loss_l1, loss_iou, loss_dfl = \
            self._bbox_loss(pred_distri, pred_bboxes, anchor_points_s,
                            assigned_labels, assigned_bboxes, assigned_scores,
                            assigned_scores_sum)
        loss = self.loss_weight['class'] * loss_cls + \
               self.loss_weight['iou'] * loss_iou + \
               self.loss_weight['dfl'] * loss_dfl
        out_dict = {
            'loss': loss,
            'loss_cls': loss_cls,
            'loss_iou': loss_iou,
            'loss_dfl': loss_dfl,
            'loss_l1': loss_l1,
        }
        return out_dict

    def post_process(self, head_outs, scale_factor):
        """Decode eval outputs into boxes in original-image coordinates.

        Args:
            head_outs (tuple): output of ``forward_eval``.
            scale_factor: per-image ``(scale_y, scale_x)`` split along the
                last axis.

        Returns:
            tuple: ``(concat[boxes, scores], None, None)`` when
            ``exclude_post_process``; ``(boxes, scores, None)`` when
            ``exclude_nms``; otherwise the NMS output
            ``(bbox_pred, bbox_num, before_nms_indexes)``.
        """
        pred_scores, pred_dist, anchor_points, stride_tensor = head_outs
        pred_bboxes = batch_distance2bbox(anchor_points, pred_dist)
        pred_bboxes *= stride_tensor
        if self.exclude_post_process:
            return paddle.concat(
                [pred_bboxes, pred_scores.transpose([0, 2, 1])],
                axis=-1), None, None
        else:
            # scale bbox to origin
            scale_y, scale_x = paddle.split(scale_factor, 2, axis=-1)
            scale_factor = paddle.concat(
                [scale_x, scale_y, scale_x, scale_y],
                axis=-1).reshape([-1, 1, 4])
            pred_bboxes /= scale_factor
            if self.exclude_nms:
                # `exclude_nms=True` just use in benchmark
                return pred_bboxes, pred_scores, None
            else:
                bbox_pred, bbox_num, before_nms_indexes = self.nms(pred_bboxes,
                                                                   pred_scores)
                return bbox_pred, bbox_num, before_nms_indexes


def get_activation(name="LeakyReLU"):
    """Build an activation layer from its name.

    Args:
        name (str|None): 'silu', 'relu', one of 'LeakyReLU'/'leakyrelu'/
            'lrelu' (negative slope 0.1), or None for an identity layer.

    Returns:
        nn.Layer: the activation module.

    Raises:
        AttributeError: if ``name`` is not one of the supported values.
    """
    if name is None:
        return nn.Identity()
    if name == "silu":
        return nn.Silu()
    if name == "relu":
        return nn.ReLU()
    if name in ("LeakyReLU", "leakyrelu", "lrelu"):
        return nn.LeakyReLU(0.1)
    raise AttributeError("Unsupported act type: {}".format(name))


class ConvNormLayer(nn.Layer):
    """Bias-free Conv2D (Kaiming-normal init) -> optional norm -> activation.

    Args:
        in_channels (int): input channels.
        out_channels (int): output channels.
        kernel_size (int): convolution kernel size.
        stride, padding, dilation, groups: forwarded to ``nn.Conv2D``.
        norm_type (str|None): 'bn', 'sync_bn' or 'syncbn' -> ``BatchNorm2D``;
            'gn' -> ``GroupNorm`` with 32 groups; None -> no normalization.
        activation (str|None): name resolved by ``get_activation``.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 stride=1,
                 padding=0,
                 dilation=1,
                 groups=1,
                 norm_type='gn',
                 activation="LeakyReLU"):
        super(ConvNormLayer, self).__init__()
        assert norm_type in ['bn', 'sync_bn', 'syncbn', 'gn', None]
        kaiming_attr = ParamAttr(initializer=KaimingNormal())
        self.conv = nn.Conv2D(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups,
            bias_attr=False,
            weight_attr=kaiming_attr)

        batch_norm_names = ('bn', 'sync_bn', 'syncbn')
        norm_layer = None
        if norm_type in batch_norm_names:
            norm_layer = nn.BatchNorm2D(out_channels)
        elif norm_type == 'gn':
            norm_layer = nn.GroupNorm(num_groups=32, num_channels=out_channels)
        self.norm = norm_layer

        self.act = get_activation(activation)

    def forward(self, x):
        """Apply conv, then the norm (if any), then the activation."""
        out = self.conv(x)
        out = out if self.norm is None else self.norm(out)
        return self.act(out)


class ScaleReg(nn.Layer):
    """Single learnable scalar that multiplies the regression outputs.

    Args:
        scale (float): initial value of the learnable parameter.
    """

    def __init__(self, scale=1.0):
        super(ScaleReg, self).__init__()
        init_value = paddle.to_tensor(scale)
        assign_init = nn.initializer.Assign(init_value)
        self.scale = self.create_parameter(
            shape=[1], dtype='float32', default_initializer=assign_init)

    def forward(self, x):
        """Return ``x`` scaled (broadcast) by the learned parameter."""
        scaled = x * self.scale
        return scaled


@register
class SimpleConvHead(nn.Layer):
    """Lightweight per-level head producing class scores and box-bin logits.

    Each level runs separate classification and regression conv towers
    (``num_convs`` ``ConvNormLayer`` blocks each). A 3x3 ``gfl_cls`` conv
    follows the classification tower, and a 3x3 ``gfl_reg`` conv followed by
    a per-level ``ScaleReg`` follows the regression tower.

    Args:
        num_classes (int): number of classes (sigmoid outputs).
        feat_in (int): channels of the input feature maps.
        feat_out (int): channels inside the conv towers.
        num_convs (int): number of tower convs per branch.
        fpn_strides (list[int]): one ``ScaleReg`` is created per entry.
        norm_type (str|None): normalization for the tower convs.
        act (str|None): activation name for the tower convs.
        prior_prob (float): prior used to initialize the class bias.
        reg_max (int): regression outputs are ``4 * (reg_max + 1)`` channels.
    """
    __shared__ = ['num_classes']

    def __init__(self,
                 num_classes=80,
                 feat_in=288,
                 feat_out=288,
                 num_convs=1,
                 fpn_strides=[32, 16, 8, 4],
                 norm_type='gn',
                 act='LeakyReLU',
                 prior_prob=0.01,
                 reg_max=16):
        super(SimpleConvHead, self).__init__()
        self.num_classes = num_classes
        self.feat_in = feat_in
        self.feat_out = feat_out
        self.num_convs = num_convs
        self.fpn_strides = fpn_strides
        self.reg_max = reg_max

        def tower_conv(channels_in):
            # 3x3, stride 1, same padding: keeps the spatial size.
            return ConvNormLayer(
                channels_in,
                feat_out,
                3,
                stride=1,
                padding=1,
                norm_type=norm_type,
                activation=act)

        self.cls_convs = nn.LayerList()
        self.reg_convs = nn.LayerList()
        for depth in range(num_convs):
            channels_in = feat_out if depth > 0 else feat_in
            # cls / reg convs are created alternately, level by level.
            self.cls_convs.append(tower_conv(channels_in))
            self.reg_convs.append(tower_conv(channels_in))

        cls_bias_value = bias_init_with_prob(prior_prob)
        self.gfl_cls = nn.Conv2D(
            feat_out,
            num_classes,
            kernel_size=3,
            stride=1,
            padding=1,
            weight_attr=ParamAttr(initializer=Normal(mean=0.0, std=0.01)),
            bias_attr=ParamAttr(initializer=Constant(value=cls_bias_value)))
        self.gfl_reg = nn.Conv2D(
            feat_out,
            4 * (reg_max + 1),
            kernel_size=3,
            stride=1,
            padding=1,
            weight_attr=ParamAttr(initializer=Normal(mean=0.0, std=0.01)),
            bias_attr=ParamAttr(initializer=Constant(value=0)))

        self.scales = nn.LayerList([ScaleReg(1.0) for _ in fpn_strides])

    @staticmethod
    def _flatten_hw(x):
        # (N, C, H, W) -> (N, H*W, C), assuming NCHW input.
        return x.flatten(2).transpose([0, 2, 1])

    def forward(self, feats):
        """Run the head on each level and concatenate along the anchor axis.

        Levels are paired with ``self.scales`` via ``zip``, so only
        ``min(len(feats), len(self.scales))`` levels are processed.

        Returns:
            tuple: ``(cls_scores, bbox_preds)``; class scores are sigmoid
            probabilities and bbox_preds are the scaled bin logits.
        """
        cls_outputs, reg_outputs = [], []
        for level_feat, level_scale in zip(feats, self.scales):
            cls_feat = level_feat
            for layer in self.cls_convs:
                cls_feat = layer(cls_feat)
            reg_feat = level_feat
            for layer in self.reg_convs:
                reg_feat = layer(reg_feat)

            cls_prob = F.sigmoid(self.gfl_cls(cls_feat))
            cls_outputs.append(self._flatten_hw(cls_prob))

            reg_logits = level_scale(self.gfl_reg(reg_feat))
            reg_outputs.append(self._flatten_hw(reg_logits))

        return (paddle.concat(cls_outputs, axis=1),
                paddle.concat(reg_outputs, axis=1))