# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import math
import numpy as np
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from paddle import ParamAttr
from paddle.nn.initializer import Normal, Constant

from ppdet.modeling.ops import get_static_shape
from ..initializer import normal_
from ..assigners.utils import generate_anchors_for_grid_cell
from ..bbox_utils import bbox_center, batch_distance2bbox, bbox2distance
from ppdet.core.workspace import register
from ppdet.modeling.layers import ConvNormLayer
from .simota_head import OTAVFLHead
from .gfl_head import Integral, GFLHead
from ppdet.modeling.necks.csp_pan import DPModule

# Small constant added inside the sqrt that fuses the class probability with
# the alignment probability in PicoHeadV2.
eps = 1.0e-9

__all__ = [
    'PicoHead',
    'PicoHeadV2',
    'PicoFeat',
]


class PicoSE(nn.Layer):
    """Channel-gating block used by PicoFeat when ``use_se`` is enabled.

    A 1x1 conv on ``avg_feat`` followed by a sigmoid yields per-channel
    gates; ``feat`` is multiplied by them and then passed through a 1x1
    ConvNormLayer.

    Args:
        feat_channels (int): Channel number of ``feat`` (input and output).
    """

    def __init__(self, feat_channels):
        super(PicoSE, self).__init__()
        # Created in this order (fc, then conv) so parameter naming and
        # initialization order stay stable.
        self.fc = nn.Conv2D(feat_channels, feat_channels, 1)
        self.conv = ConvNormLayer(feat_channels, feat_channels, 1, 1)
        self._init_weights()

    def _init_weights(self):
        # The gating conv starts with very small weights (std=0.001).
        normal_(self.fc.weight, std=0.001)

    def forward(self, feat, avg_feat):
        """Gate ``feat`` with sigmoid(fc(avg_feat)) and project it.

        PicoFeat passes ``avg_feat = adaptive_avg_pool2d(feat, (1, 1))``.
        """
        gate = F.sigmoid(self.fc(avg_feat))
        return self.conv(feat * gate)


@register
class PicoFeat(nn.Layer):
    """
    PicoFeat of PicoDet

    Builds, for every FPN level, a tower of ``num_convs`` depthwise (5x5) +
    pointwise (1x1) ConvNormLayer pairs for the cls branch and, unless
    ``share_cls_reg`` is set, a parallel set for the reg branch.

    Args:
        feat_in (int): The channel number of input Tensor.
        feat_out (int): The channel number of output Tensor.
        num_fpn_stride (int): Number of FPN levels; one tower is built per level.
        num_convs (int): Number of depthwise + pointwise conv pairs per tower.
        norm_type (str): Normalization type, 'bn'/'sync_bn'/'gn'.
        share_cls_reg (bool): Whether to share the cls and reg output.
        act (str): Activation applied after every conv: 'leaky_relu',
            'hard_swish' or 'relu6'. Any other value applies no activation.
        use_se (bool): Whether to use se module (PicoSE on the cls feature).
            Requires share_cls_reg=True.
    """

    # Activation name -> function; unknown names fall back to identity.
    _ACT_FUNCS = {
        "leaky_relu": F.leaky_relu,
        "hard_swish": F.hardswish,
        "relu6": F.relu6,
    }

    def __init__(self,
                 feat_in=256,
                 feat_out=96,
                 num_fpn_stride=3,
                 num_convs=2,
                 norm_type='bn',
                 share_cls_reg=False,
                 act='hard_swish',
                 use_se=False):
        super(PicoFeat, self).__init__()
        self.num_convs = num_convs
        self.norm_type = norm_type
        self.share_cls_reg = share_cls_reg
        self.act = act
        self.use_se = use_se
        # Plain lists for indexing; the layers themselves are registered via
        # add_sublayer below (names are part of the checkpoint format).
        self.cls_convs = []
        self.reg_convs = []
        if use_se:
            assert share_cls_reg, \
                'In the case of using se, share_cls_reg must be set to True'
            self.se = nn.LayerList()
        for stage_idx in range(num_fpn_stride):
            cls_subnet_convs = []
            reg_subnet_convs = []
            for i in range(self.num_convs):
                in_c = feat_in if i == 0 else feat_out
                cls_subnet_convs.append(
                    self.add_sublayer(
                        'cls_conv_dw{}.{}'.format(stage_idx, i),
                        self._dw_conv(in_c, feat_out, norm_type)))
                # The pointwise conv consumes the depthwise output, which
                # always has feat_out channels (not in_c).
                cls_subnet_convs.append(
                    self.add_sublayer(
                        'cls_conv_pw{}.{}'.format(stage_idx, i),
                        self._pw_conv(feat_out, feat_out, norm_type)))

                if not self.share_cls_reg:
                    reg_subnet_convs.append(
                        self.add_sublayer(
                            'reg_conv_dw{}.{}'.format(stage_idx, i),
                            self._dw_conv(in_c, feat_out, norm_type)))
                    reg_subnet_convs.append(
                        self.add_sublayer(
                            'reg_conv_pw{}.{}'.format(stage_idx, i),
                            self._pw_conv(feat_out, feat_out, norm_type)))
            self.cls_convs.append(cls_subnet_convs)
            self.reg_convs.append(reg_subnet_convs)
            if use_se:
                self.se.append(PicoSE(feat_out))

    @staticmethod
    def _dw_conv(ch_in, ch_out, norm_type):
        """5x5 depthwise ConvNormLayer (groups=ch_out, no bias, lr_scale=2)."""
        return ConvNormLayer(
            ch_in=ch_in,
            ch_out=ch_out,
            filter_size=5,
            stride=1,
            groups=ch_out,
            norm_type=norm_type,
            bias_on=False,
            lr_scale=2.)

    @staticmethod
    def _pw_conv(ch_in, ch_out, norm_type):
        """1x1 pointwise ConvNormLayer (no bias, lr_scale=2)."""
        return ConvNormLayer(
            ch_in=ch_in,
            ch_out=ch_out,
            filter_size=1,
            stride=1,
            norm_type=norm_type,
            bias_on=False,
            lr_scale=2.)

    def act_func(self, x):
        """Apply the activation named by ``self.act`` (identity if unknown)."""
        fn = self._ACT_FUNCS.get(self.act)
        return x if fn is None else fn(x)

    def forward(self, fpn_feat, stage_idx):
        """Run the tower of FPN level ``stage_idx``.

        Returns:
            tuple: ``(cls_feat, se_feat)`` when use_se is set, otherwise
            ``(cls_feat, reg_feat)``.
        """
        assert stage_idx < len(self.cls_convs)
        cls_feat = fpn_feat
        reg_feat = fpn_feat
        for i, cls_conv in enumerate(self.cls_convs[stage_idx]):
            cls_feat = self.act_func(cls_conv(cls_feat))
            # NOTE(review): reg_feat is re-seeded from cls_feat on every step,
            # so with share_cls_reg=False each reg conv is applied to the cls
            # feature of the same step and only the last reg conv's output is
            # returned (earlier reg outputs are discarded). It also means the
            # first reg depthwise conv (declared with ch_in=feat_in) receives
            # feat_out channels. Kept unchanged because altering the wiring
            # would change the outputs of models trained this way -- confirm
            # intent before modifying.
            reg_feat = cls_feat
            if not self.share_cls_reg:
                reg_feat = self.act_func(self.reg_convs[stage_idx][i](reg_feat))
        if self.use_se:
            avg_feat = F.adaptive_avg_pool2d(cls_feat, (1, 1))
            se_feat = self.act_func(self.se[stage_idx](cls_feat, avg_feat))
            return cls_feat, se_feat
        return cls_feat, reg_feat


@register
class PicoHead(OTAVFLHead):
    """
    PicoHead
    Args:
        conv_feat (object): Instance of 'PicoFeat'
        dgqp_module (object|None): Optional quality module. When set, its
            output (computed from the box-regression output) is multiplied
            into the sigmoid class scores.
        num_classes (int): Number of classes
        fpn_stride (list): The stride of each FPN Layer
        prior_prob (float): Used to set the bias init for the class prediction layer
        loss_class (object): Instance of VariFocalLoss.
        loss_dfl (object): Instance of DistributionFocalLoss.
        loss_bbox (object): Instance of bbox loss.
        assigner (object): Instance of label assigner.
        reg_max (int): Max value of integral set :math:`{0, ..., reg_max}`
            used by the distribution box regression. Default: 16.
        feat_in_chan (int): Channel number of the features from conv_feat.
        nms (object): NMS instance used by post_process.
        nms_pre (int): Forwarded to the parent head.
        cell_offset (float): Offset added to grid coordinates of anchor points.
        eval_size (list|None): [h, w] of the eval input. When given, anchor
            points are precomputed once in __init__.
    """
    __inject__ = [
        'conv_feat', 'dgqp_module', 'loss_class', 'loss_dfl', 'loss_bbox',
        'assigner', 'nms'
    ]
    __shared__ = ['num_classes', 'eval_size']

    def __init__(self,
                 conv_feat='PicoFeat',
                 dgqp_module=None,
                 num_classes=80,
                 fpn_stride=[8, 16, 32],
                 prior_prob=0.01,
                 loss_class='VariFocalLoss',
                 loss_dfl='DistributionFocalLoss',
                 loss_bbox='GIoULoss',
                 assigner='SimOTAAssigner',
                 reg_max=16,
                 feat_in_chan=96,
                 nms=None,
                 nms_pre=1000,
                 cell_offset=0,
                 eval_size=None):
        super(PicoHead, self).__init__(
            conv_feat=conv_feat,
            dgqp_module=dgqp_module,
            num_classes=num_classes,
            fpn_stride=fpn_stride,
            prior_prob=prior_prob,
            loss_class=loss_class,
            loss_dfl=loss_dfl,
            loss_bbox=loss_bbox,
            assigner=assigner,
            reg_max=reg_max,
            feat_in_chan=feat_in_chan,
            nms=nms,
            nms_pre=nms_pre,
            cell_offset=cell_offset)
        self.conv_feat = conv_feat
        self.num_classes = num_classes
        self.fpn_stride = fpn_stride
        self.prior_prob = prior_prob
        self.loss_vfl = loss_class
        self.loss_dfl = loss_dfl
        self.loss_bbox = loss_bbox
        self.assigner = assigner
        self.reg_max = reg_max
        self.feat_in_chan = feat_in_chan
        self.nms = nms
        self.nms_pre = nms_pre
        self.cell_offset = cell_offset
        self.eval_size = eval_size

        self.use_sigmoid = self.loss_vfl.use_sigmoid
        if self.use_sigmoid:
            self.cls_out_channels = self.num_classes
        else:
            self.cls_out_channels = self.num_classes + 1
        bias_init_value = -math.log((1 - self.prior_prob) / self.prior_prob)
        # Clear the super class initialization
        self.gfl_head_cls = None
        self.gfl_head_reg = None
        self.scales_regs = None

        self.head_cls_list = []
        self.head_reg_list = []
        for i in range(len(fpn_stride)):
            # With share_cls_reg a single conv emits cls and box channels,
            # which are split apart in _head_single_level.
            head_cls = self.add_sublayer(
                "head_cls" + str(i),
                nn.Conv2D(
                    in_channels=self.feat_in_chan,
                    out_channels=self.cls_out_channels + 4 * (self.reg_max + 1)
                    if self.conv_feat.share_cls_reg else self.cls_out_channels,
                    kernel_size=1,
                    stride=1,
                    padding=0,
                    weight_attr=ParamAttr(initializer=Normal(
                        mean=0., std=0.01)),
                    bias_attr=ParamAttr(
                        initializer=Constant(value=bias_init_value))))
            self.head_cls_list.append(head_cls)
            if not self.conv_feat.share_cls_reg:
                head_reg = self.add_sublayer(
                    "head_reg" + str(i),
                    nn.Conv2D(
                        in_channels=self.feat_in_chan,
                        out_channels=4 * (self.reg_max + 1),
                        kernel_size=1,
                        stride=1,
                        padding=0,
                        weight_attr=ParamAttr(initializer=Normal(
                            mean=0., std=0.01)),
                        bias_attr=ParamAttr(initializer=Constant(value=0))))
                self.head_reg_list.append(head_reg)

        # initialize the anchor points
        if self.eval_size:
            self.anchor_points, self.stride_tensor = self._generate_anchors()

    def forward(self, fpn_feats, export_post_process=True):
        assert len(fpn_feats) == len(
            self.fpn_stride
        ), "The size of fpn_feats is not equal to size of fpn_stride"

        if self.training:
            return self.forward_train(fpn_feats)
        else:
            return self.forward_eval(
                fpn_feats, export_post_process=export_post_process)

    def _head_single_level(self, fpn_feat, stage_idx):
        """Run conv_feat and the prediction conv(s) for one FPN level.

        Returns:
            tuple: ``(cls_score, bbox_pred)``, the raw conv outputs with
            ``cls_out_channels`` and ``4 * (reg_max + 1)`` channels.
        """
        conv_cls_feat, conv_reg_feat = self.conv_feat(fpn_feat, stage_idx)
        if self.conv_feat.share_cls_reg:
            cls_logits = self.head_cls_list[stage_idx](conv_cls_feat)
            cls_score, bbox_pred = paddle.split(
                cls_logits, [self.cls_out_channels, 4 * (self.reg_max + 1)],
                axis=1)
        else:
            cls_score = self.head_cls_list[stage_idx](conv_cls_feat)
            bbox_pred = self.head_reg_list[stage_idx](conv_reg_feat)
        return cls_score, bbox_pred

    def forward_train(self, fpn_feats):
        cls_logits_list, bboxes_reg_list = [], []
        for i, fpn_feat in enumerate(fpn_feats):
            cls_score, bbox_pred = self._head_single_level(fpn_feat, i)
            if self.dgqp_module:
                quality_score = self.dgqp_module(bbox_pred)
                cls_score = F.sigmoid(cls_score) * quality_score

            cls_logits_list.append(cls_score)
            bboxes_reg_list.append(bbox_pred)

        return (cls_logits_list, bboxes_reg_list)

    def forward_eval(self, fpn_feats, export_post_process=True):
        if self.eval_size:
            anchor_points, stride_tensor = self.anchor_points, self.stride_tensor
        else:
            anchor_points, stride_tensor = self._generate_anchors(fpn_feats)
        cls_logits_list, bboxes_reg_list = [], []
        for i, fpn_feat in enumerate(fpn_feats):
            cls_score, bbox_pred = self._head_single_level(fpn_feat, i)

            # Convert logits to probabilities exactly once; the dgqp quality
            # score multiplies the probability and must not be passed
            # through a second sigmoid.
            cls_prob = F.sigmoid(cls_score)
            if self.dgqp_module:
                cls_prob = cls_prob * self.dgqp_module(bbox_pred)

            if not export_post_process:
                # Now only supports batch size = 1 in deploy
                # TODO(ygh): support batch size > 1
                cls_score_out = cls_prob.reshape(
                    [1, self.cls_out_channels, -1]).transpose([0, 2, 1])
                bbox_pred = bbox_pred.reshape([1, (self.reg_max + 1) * 4,
                                               -1]).transpose([0, 2, 1])
            else:
                _, _, h, w = fpn_feat.shape
                num_points = h * w
                cls_score_out = cls_prob.reshape(
                    [-1, self.cls_out_channels, num_points])
                bbox_pred = bbox_pred.transpose([0, 2, 3, 1])
                bbox_pred = self.distribution_project(bbox_pred)
                bbox_pred = bbox_pred.reshape([-1, num_points, 4])

            cls_logits_list.append(cls_score_out)
            bboxes_reg_list.append(bbox_pred)

        if export_post_process:
            cls_logits_list = paddle.concat(cls_logits_list, axis=-1)
            bboxes_reg_list = paddle.concat(bboxes_reg_list, axis=1)
            # Distances are in grid units; decode around the grid anchor
            # points, then scale to pixels with the per-point stride.
            bboxes_reg_list = batch_distance2bbox(anchor_points,
                                                  bboxes_reg_list)
            bboxes_reg_list *= stride_tensor

        return (cls_logits_list, bboxes_reg_list)

    def _generate_anchors(self, feats=None):
        """Build grid anchor points and per-point strides for all levels.

        Sizes come from ``feats`` when given, otherwise from ``eval_size``.

        Returns:
            tuple: ``(anchor_points [N, 2] float32 in grid units (x, y),
            stride_tensor [N, 1] float32)``.
        """
        # just use in eval time
        anchor_points = []
        stride_tensor = []
        for i, stride in enumerate(self.fpn_stride):
            if feats is not None:
                _, _, h, w = feats[i].shape
            else:
                h = math.ceil(self.eval_size[0] / stride)
                w = math.ceil(self.eval_size[1] / stride)
            shift_x = paddle.arange(end=w) + self.cell_offset
            shift_y = paddle.arange(end=h) + self.cell_offset
            shift_y, shift_x = paddle.meshgrid(shift_y, shift_x)
            anchor_point = paddle.cast(
                paddle.stack(
                    [shift_x, shift_y], axis=-1), dtype='float32')
            anchor_points.append(anchor_point.reshape([-1, 2]))
            stride_tensor.append(
                paddle.full(
                    [h * w, 1], stride, dtype='float32'))
        anchor_points = paddle.concat(anchor_points)
        stride_tensor = paddle.concat(stride_tensor)
        return anchor_points, stride_tensor

    def post_process(self, head_outs, scale_factor, export_nms=True):
        pred_scores, pred_bboxes = head_outs
        if not export_nms:
            return pred_bboxes, pred_scores
        else:
            # rescale: [h_scale, w_scale] -> [w_scale, h_scale, w_scale, h_scale]
            scale_y, scale_x = paddle.split(scale_factor, 2, axis=-1)
            scale_factor = paddle.concat(
                [scale_x, scale_y, scale_x, scale_y],
                axis=-1).reshape([-1, 1, 4])
            # scale bbox to origin image size.
            pred_bboxes /= scale_factor
            bbox_pred, bbox_num, _ = self.nms(pred_bboxes, pred_scores)
            return bbox_pred, bbox_num


@register
class PicoHeadV2(GFLHead):
    """
    PicoHeadV2
    Args:
        conv_feat (object): Instance of 'PicoFeat'
        dgqp_module (object|None): Forwarded to the parent head.
        num_classes (int): Number of classes
        fpn_stride (list): The stride of each FPN Layer
        prior_prob (float): Used to set the bias init for the class prediction layer
        use_align_head (bool): If True, a DPModule per level predicts an
            alignment probability and the class score becomes
            ``sqrt(sigmoid(cls_logit) * align_prob + eps)``.
        loss_class (object): Instance of VariFocalLoss.
        loss_dfl (object): Instance of DistributionFocalLoss.
        loss_bbox (object): Instance of bbox loss.
        static_assigner_epoch (int): While ``gt_meta['epoch_id']`` is below
            this value, ``static_assigner`` is used; afterwards ``assigner``.
        static_assigner (object): Label assigner for the early epochs.
        assigner (object): Instance of label assigner.
        reg_max (int): Max value of integral set :math:`{0, ..., reg_max}`
            used by the distribution box regression. Default: 16.
        feat_in_chan (int): Channel number of the features from conv_feat.
        nms (object): NMS instance used by post_process.
        nms_pre (int): Forwarded to the parent head.
        cell_offset (float): Offset added to grid coordinates of anchor points.
        act (str): Activation of the alignment DPModule.
        grid_cell_scale (float): Passed to generate_anchors_for_grid_cell
            when building anchors for label assignment.
        eval_size (list|None): [h, w] of the eval input. When given, anchor
            points are precomputed once in __init__.
    """
    __inject__ = [
        'conv_feat', 'dgqp_module', 'loss_class', 'loss_dfl', 'loss_bbox',
        'static_assigner', 'assigner', 'nms'
    ]
    __shared__ = ['num_classes', 'eval_size']

    def __init__(self,
                 conv_feat='PicoFeatV2',
                 dgqp_module=None,
                 num_classes=80,
                 fpn_stride=[8, 16, 32],
                 prior_prob=0.01,
                 use_align_head=True,
                 loss_class='VariFocalLoss',
                 loss_dfl='DistributionFocalLoss',
                 loss_bbox='GIoULoss',
                 static_assigner_epoch=60,
                 static_assigner='ATSSAssigner',
                 assigner='TaskAlignedAssigner',
                 reg_max=16,
                 feat_in_chan=96,
                 nms=None,
                 nms_pre=1000,
                 cell_offset=0,
                 act='hard_swish',
                 grid_cell_scale=5.0,
                 eval_size=None):
        super(PicoHeadV2, self).__init__(
            conv_feat=conv_feat,
            dgqp_module=dgqp_module,
            num_classes=num_classes,
            fpn_stride=fpn_stride,
            prior_prob=prior_prob,
            loss_class=loss_class,
            loss_dfl=loss_dfl,
            loss_bbox=loss_bbox,
            reg_max=reg_max,
            feat_in_chan=feat_in_chan,
            nms=nms,
            nms_pre=nms_pre,
            cell_offset=cell_offset, )
        self.conv_feat = conv_feat
        self.num_classes = num_classes
        self.fpn_stride = fpn_stride
        self.prior_prob = prior_prob
        self.loss_vfl = loss_class
        self.loss_dfl = loss_dfl
        self.loss_bbox = loss_bbox

        self.static_assigner_epoch = static_assigner_epoch
        self.static_assigner = static_assigner
        self.assigner = assigner

        self.reg_max = reg_max
        self.feat_in_chan = feat_in_chan
        self.nms = nms
        self.nms_pre = nms_pre
        self.cell_offset = cell_offset
        self.act = act
        self.grid_cell_scale = grid_cell_scale
        self.use_align_head = use_align_head
        self.cls_out_channels = self.num_classes
        self.eval_size = eval_size

        bias_init_value = -math.log((1 - self.prior_prob) / self.prior_prob)
        # Clear the super class initialization
        self.gfl_head_cls = None
        self.gfl_head_reg = None
        self.scales_regs = None

        # Layers are registered both via add_sublayer and in the LayerLists;
        # kept as-is because parameter names are part of the checkpoint format.
        self.head_cls_list = nn.LayerList()
        self.head_reg_list = nn.LayerList()
        self.cls_align = nn.LayerList()

        for i in range(len(fpn_stride)):
            head_cls = self.add_sublayer(
                "head_cls" + str(i),
                nn.Conv2D(
                    in_channels=self.feat_in_chan,
                    out_channels=self.cls_out_channels,
                    kernel_size=1,
                    stride=1,
                    padding=0,
                    weight_attr=ParamAttr(initializer=Normal(
                        mean=0., std=0.01)),
                    bias_attr=ParamAttr(
                        initializer=Constant(value=bias_init_value))))
            self.head_cls_list.append(head_cls)
            head_reg = self.add_sublayer(
                "head_reg" + str(i),
                nn.Conv2D(
                    in_channels=self.feat_in_chan,
                    out_channels=4 * (self.reg_max + 1),
                    kernel_size=1,
                    stride=1,
                    padding=0,
                    weight_attr=ParamAttr(initializer=Normal(
                        mean=0., std=0.01)),
                    bias_attr=ParamAttr(initializer=Constant(value=0))))
            self.head_reg_list.append(head_reg)
            if self.use_align_head:
                self.cls_align.append(
                    DPModule(
                        self.feat_in_chan,
                        1,
                        5,
                        act=self.act,
                        use_act_in_out=False))

        # initialize the anchor points
        if self.eval_size:
            self.anchor_points, self.stride_tensor = self._generate_anchors()

    def forward(self, fpn_feats, export_post_process=True):
        assert len(fpn_feats) == len(
            self.fpn_stride
        ), "The size of fpn_feats is not equal to size of fpn_stride"

        if self.training:
            return self.forward_train(fpn_feats)
        else:
            return self.forward_eval(
                fpn_feats, export_post_process=export_post_process)

    def _head_single_level(self, fpn_feat, stage_idx):
        """Predict class scores and the raw box distribution for one level.

        Returns:
            tuple: ``(cls_score, reg_pred)``. ``cls_score`` is already a
            probability (sigmoid, optionally fused with the align branch);
            ``reg_pred`` is the raw ``head_reg`` output with
            ``4 * (reg_max + 1)`` channels.
        """
        # task decomposition
        conv_cls_feat, se_feat = self.conv_feat(fpn_feat, stage_idx)
        cls_logit = self.head_cls_list[stage_idx](se_feat)
        reg_pred = self.head_reg_list[stage_idx](se_feat)

        # cls prediction and alignment
        if self.use_align_head:
            cls_prob = F.sigmoid(self.cls_align[stage_idx](conv_cls_feat))
            # eps keeps the sqrt argument (and its gradient) away from 0.
            cls_score = (F.sigmoid(cls_logit) * cls_prob + eps).sqrt()
        else:
            cls_score = F.sigmoid(cls_logit)
        return cls_score, reg_pred

    def forward_train(self, fpn_feats):
        """Returns (cls_scores, reg_logits, boxes, fpn_feats); boxes are in
        grid units (pixel boxes divided by the level stride)."""
        cls_score_list, reg_list, box_list = [], [], []
        for i, (fpn_feat, stride) in enumerate(zip(fpn_feats, self.fpn_stride)):
            cls_score, reg_pred = self._head_single_level(fpn_feat, i)

            bbox_pred = reg_pred.transpose([0, 2, 3, 1])
            b, cell_h, cell_w, _ = paddle.shape(bbox_pred)
            y, x = self.get_single_level_center_point(
                [cell_h, cell_w], stride, cell_offset=self.cell_offset)
            center_points = paddle.stack([x, y], axis=-1)
            bbox_pred = self.distribution_project(bbox_pred) * stride
            bbox_pred = bbox_pred.reshape([b, cell_h * cell_w, 4])
            bbox_pred = batch_distance2bbox(
                center_points, bbox_pred, max_shapes=None)
            cls_score_list.append(cls_score.flatten(2).transpose([0, 2, 1]))
            reg_list.append(reg_pred.flatten(2).transpose([0, 2, 1]))
            box_list.append(bbox_pred / stride)

        cls_score_list = paddle.concat(cls_score_list, axis=1)
        box_list = paddle.concat(box_list, axis=1)
        reg_list = paddle.concat(reg_list, axis=1)
        return cls_score_list, reg_list, box_list, fpn_feats

    def forward_eval(self, fpn_feats, export_post_process=True):
        if self.eval_size:
            anchor_points, stride_tensor = self.anchor_points, self.stride_tensor
        else:
            anchor_points, stride_tensor = self._generate_anchors(fpn_feats)
        cls_score_list, box_list = [], []
        for i, (fpn_feat, _) in enumerate(zip(fpn_feats, self.fpn_stride)):
            _, _, h, w = fpn_feat.shape
            cls_score, reg_pred = self._head_single_level(fpn_feat, i)

            if not export_post_process:
                # Now only supports batch size = 1 in deploy
                cls_score_list.append(
                    cls_score.reshape([1, self.cls_out_channels, -1]).transpose(
                        [0, 2, 1]))
                box_list.append(
                    reg_pred.reshape([1, (self.reg_max + 1) * 4, -1]).transpose(
                        [0, 2, 1]))
            else:
                num_points = h * w
                cls_score_list.append(
                    cls_score.reshape([-1, self.cls_out_channels, num_points]))
                bbox_pred = reg_pred.transpose([0, 2, 3, 1])
                bbox_pred = self.distribution_project(bbox_pred)
                box_list.append(bbox_pred.reshape([-1, num_points, 4]))

        if export_post_process:
            cls_score_list = paddle.concat(cls_score_list, axis=-1)
            box_list = paddle.concat(box_list, axis=1)
            box_list = batch_distance2bbox(anchor_points, box_list)
            box_list *= stride_tensor

        return cls_score_list, box_list

    def get_loss(self, head_outs, gt_meta):
        pred_scores, pred_regs, pred_bboxes, fpn_feats = head_outs
        gt_labels = gt_meta['gt_class']
        gt_bboxes = gt_meta['gt_bbox']
        gt_scores = gt_meta['gt_score'] if 'gt_score' in gt_meta else None
        num_imgs = gt_meta['im_id'].shape[0]
        pad_gt_mask = gt_meta['pad_gt_mask']

        anchors, _, num_anchors_list, stride_tensor_list = generate_anchors_for_grid_cell(
            fpn_feats, self.fpn_stride, self.grid_cell_scale, self.cell_offset)

        centers = bbox_center(anchors)

        # label assignment (pred_bboxes are in grid units, so they are scaled
        # back to pixels for the assigners)
        if gt_meta['epoch_id'] < self.static_assigner_epoch:
            assigned_labels, assigned_bboxes, assigned_scores = self.static_assigner(
                anchors,
                num_anchors_list,
                gt_labels,
                gt_bboxes,
                pad_gt_mask,
                bg_index=self.num_classes,
                gt_scores=gt_scores,
                pred_bboxes=pred_bboxes.detach() * stride_tensor_list)

        else:
            assigned_labels, assigned_bboxes, assigned_scores = self.assigner(
                pred_scores.detach(),
                pred_bboxes.detach() * stride_tensor_list,
                centers,
                num_anchors_list,
                gt_labels,
                gt_bboxes,
                pad_gt_mask,
                bg_index=self.num_classes,
                gt_scores=gt_scores)

        # Back to grid units to match pred_bboxes.
        assigned_bboxes /= stride_tensor_list

        centers_shape = centers.shape
        flatten_centers = centers.expand(
            [num_imgs, centers_shape[0], centers_shape[1]]).reshape([-1, 2])
        flatten_strides = stride_tensor_list.expand(
            [num_imgs, centers_shape[0], 1]).reshape([-1, 1])
        flatten_cls_preds = pred_scores.reshape([-1, self.num_classes])
        flatten_regs = pred_regs.reshape([-1, 4 * (self.reg_max + 1)])
        flatten_bboxes = pred_bboxes.reshape([-1, 4])
        flatten_bbox_targets = assigned_bboxes.reshape([-1, 4])
        flatten_labels = assigned_labels.reshape([-1])
        flatten_assigned_scores = assigned_scores.reshape(
            [-1, self.num_classes])

        pos_inds = paddle.nonzero(
            paddle.logical_and((flatten_labels >= 0),
                               (flatten_labels < self.num_classes)),
            as_tuple=False).squeeze(1)

        num_total_pos = len(pos_inds)

        if num_total_pos > 0:
            pos_bbox_targets = paddle.gather(
                flatten_bbox_targets, pos_inds, axis=0)
            pos_decode_bbox_pred = paddle.gather(
                flatten_bboxes, pos_inds, axis=0)
            pos_reg = paddle.gather(flatten_regs, pos_inds, axis=0)
            pos_strides = paddle.gather(flatten_strides, pos_inds, axis=0)
            pos_centers = paddle.gather(
                flatten_centers, pos_inds, axis=0) / pos_strides

            weight_targets = flatten_assigned_scores.detach()
            weight_targets = paddle.gather(
                weight_targets.max(axis=1, keepdim=True), pos_inds, axis=0)

            pred_corners = pos_reg.reshape([-1, self.reg_max + 1])
            target_corners = bbox2distance(pos_centers, pos_bbox_targets,
                                           self.reg_max).reshape([-1])
            # regression loss
            loss_bbox = paddle.sum(
                self.loss_bbox(pos_decode_bbox_pred,
                               pos_bbox_targets) * weight_targets)

            # dfl loss
            loss_dfl = self.loss_dfl(
                pred_corners,
                target_corners,
                weight=weight_targets.expand([-1, 4]).reshape([-1]),
                avg_factor=4.0)
        else:
            loss_bbox = paddle.zeros([1])
            loss_dfl = paddle.zeros([1])

        avg_factor = flatten_assigned_scores.sum()
        world_size = paddle.distributed.get_world_size()
        if world_size > 1:
            paddle.distributed.all_reduce(avg_factor)
            avg_factor = avg_factor / world_size
        # Clamp on every path (not only multi-GPU): a batch without positive
        # assignments would otherwise divide all losses by zero.
        avg_factor = paddle.clip(avg_factor, min=1)
        loss_vfl = self.loss_vfl(
            flatten_cls_preds, flatten_assigned_scores, avg_factor=avg_factor)

        loss_bbox = loss_bbox / avg_factor
        loss_dfl = loss_dfl / avg_factor

        loss_states = dict(
            loss_vfl=loss_vfl, loss_bbox=loss_bbox, loss_dfl=loss_dfl)

        return loss_states

    def _generate_anchors(self, feats=None):
        """Build grid anchor points and per-point strides for all levels.

        Sizes come from ``feats`` when given, otherwise from ``eval_size``.

        Returns:
            tuple: ``(anchor_points [N, 2] float32 in grid units (x, y),
            stride_tensor [N, 1] float32)``.
        """
        # just use in eval time
        anchor_points = []
        stride_tensor = []
        for i, stride in enumerate(self.fpn_stride):
            if feats is not None:
                _, _, h, w = feats[i].shape
            else:
                h = math.ceil(self.eval_size[0] / stride)
                w = math.ceil(self.eval_size[1] / stride)
            shift_x = paddle.arange(end=w) + self.cell_offset
            shift_y = paddle.arange(end=h) + self.cell_offset
            shift_y, shift_x = paddle.meshgrid(shift_y, shift_x)
            anchor_point = paddle.cast(
                paddle.stack(
                    [shift_x, shift_y], axis=-1), dtype='float32')
            anchor_points.append(anchor_point.reshape([-1, 2]))
            stride_tensor.append(
                paddle.full(
                    [h * w, 1], stride, dtype='float32'))
        anchor_points = paddle.concat(anchor_points)
        stride_tensor = paddle.concat(stride_tensor)
        return anchor_points, stride_tensor

    def post_process(self, head_outs, scale_factor, export_nms=True):
        pred_scores, pred_bboxes = head_outs
        if not export_nms:
            return pred_bboxes, pred_scores
        else:
            # rescale: [h_scale, w_scale] -> [w_scale, h_scale, w_scale, h_scale]
            scale_y, scale_x = paddle.split(scale_factor, 2, axis=-1)
            scale_factor = paddle.concat(
                [scale_x, scale_y, scale_x, scale_y],
                axis=-1).reshape([-1, 1, 4])
            # scale bbox to origin image size.
            pred_bboxes /= scale_factor
            bbox_pred, bbox_num, _ = self.nms(pred_bboxes, pred_scores)
            return bbox_pred, bbox_num