# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#   
# Licensed under the Apache License, Version 2.0 (the "License");   
# you may not use this file except in compliance with the License.  
# You may obtain a copy of the License at   
#   
#     http://www.apache.org/licenses/LICENSE-2.0    
#   
# Unless required by applicable law or agreed to in writing, software   
# distributed under the License is distributed on an "AS IS" BASIS, 
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.  
# See the License for the specific language governing permissions and   
# limitations under the License.

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import math
import numpy as np
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from paddle import ParamAttr
from paddle.nn.initializer import Normal, Constant

from ppdet.core.workspace import register
from ppdet.modeling.layers import ConvNormLayer
from ppdet.modeling.bbox_utils import batch_distance2bbox
from .simota_head import OTAVFLHead


@register
class PicoFeat(nn.Layer):
    """
    PicoFeat of PicoDet

    Builds, for every FPN level, a classification tower and (unless
    ``share_cls_reg`` is set) a separate regression tower. Each tower is a
    stack of ``num_convs`` pairs made of a 5x5 grouped ("dw") ConvNormLayer
    followed by a 1x1 ("pw") ConvNormLayer.

    Args:
        feat_in (int): The channel number of input Tensor.
        feat_out (int): The channel number of output Tensor.
        num_fpn_stride (int): Number of FPN levels; one tower set is built
            per level.
        num_convs (int): Number of dw + pw conv pairs in each tower.
        norm_type (str): Normalization type, 'bn'/'sync_bn'/'gn'.
        share_cls_reg (bool): If True, no regression tower is built and
            ``forward`` returns the untouched input as the regression
            feature.
        act (str): Activation applied after every conv layer, 'leaky_relu'
            or 'hard_swish'. Any other value applies no activation.
    """

    def __init__(self,
                 feat_in=256,
                 feat_out=96,
                 num_fpn_stride=3,
                 num_convs=2,
                 norm_type='bn',
                 share_cls_reg=False,
                 act='hard_swish'):
        super(PicoFeat, self).__init__()
        self.num_convs = num_convs
        self.norm_type = norm_type
        self.share_cls_reg = share_cls_reg
        self.act = act
        # Plain Python lists indexed as [stage_idx][layer_idx]; the layers
        # themselves are registered through add_sublayer in the helper.
        self.cls_convs = []
        self.reg_convs = []
        for stage_idx in range(num_fpn_stride):
            cls_subnet_convs = []
            reg_subnet_convs = []
            for i in range(self.num_convs):
                # Only the first pair of a tower sees the FPN channel count.
                in_c = feat_in if i == 0 else feat_out
                cls_subnet_convs.extend(
                    self._add_dw_pw_pair('cls', stage_idx, i, in_c, feat_out))
                if not self.share_cls_reg:
                    reg_subnet_convs.extend(
                        self._add_dw_pw_pair('reg', stage_idx, i, in_c,
                                             feat_out))
            self.cls_convs.append(cls_subnet_convs)
            self.reg_convs.append(reg_subnet_convs)

    def _add_dw_pw_pair(self, branch, stage_idx, conv_idx, ch_in, feat_out):
        """Register one dw(5x5) + pw(1x1) ConvNormLayer pair.

        Args:
            branch (str): Sublayer name prefix, 'cls' or 'reg'.
            stage_idx (int): Index of the FPN level.
            conv_idx (int): Index of the pair inside the tower.
            ch_in (int): Input channels of the dw conv.
            feat_out (int): Output channels of both convs.

        Returns:
            list: ``[dw_layer, pw_layer]`` in execution order.
        """
        conv_dw = self.add_sublayer(
            '{}_conv_dw{}.{}'.format(branch, stage_idx, conv_idx),
            ConvNormLayer(
                ch_in=ch_in,
                ch_out=feat_out,
                filter_size=5,
                stride=1,
                groups=feat_out,
                norm_type=self.norm_type,
                bias_on=False,
                lr_scale=2.))
        # The pw conv consumes the dw output, which always has feat_out
        # channels. Declaring ch_in=ch_in here (as before) only worked when
        # feat_in == feat_out and failed at forward time otherwise.
        conv_pw = self.add_sublayer(
            '{}_conv_pw{}.{}'.format(branch, stage_idx, conv_idx),
            ConvNormLayer(
                ch_in=feat_out,
                ch_out=feat_out,
                filter_size=1,
                stride=1,
                norm_type=self.norm_type,
                bias_on=False,
                lr_scale=2.))
        return [conv_dw, conv_pw]

    def act_func(self, x):
        """Apply the activation selected by ``self.act``.

        Returns ``x`` unchanged when ``self.act`` is neither 'leaky_relu'
        nor 'hard_swish'.
        """
        if self.act == "leaky_relu":
            x = F.leaky_relu(x)
        elif self.act == "hard_swish":
            x = F.hardswish(x)
        return x

    def forward(self, fpn_feat, stage_idx):
        """Run the towers of FPN level ``stage_idx`` on ``fpn_feat``.

        Args:
            fpn_feat (Tensor): Feature map of one FPN level.
            stage_idx (int): Index of that level; must be smaller than the
                number of levels the layer was built with.

        Returns:
            tuple: ``(cls_feat, reg_feat)``. When ``share_cls_reg`` is True,
            ``reg_feat`` is ``fpn_feat`` itself.
        """
        assert stage_idx < len(self.cls_convs)
        cls_feat = fpn_feat
        reg_feat = fpn_feat
        for i, cls_conv in enumerate(self.cls_convs[stage_idx]):
            cls_feat = self.act_func(cls_conv(cls_feat))
            if not self.share_cls_reg:
                reg_feat = self.act_func(self.reg_convs[stage_idx][i](reg_feat))
        return cls_feat, reg_feat


@register
class PicoHead(OTAVFLHead):
    """
    PicoHead

    Per-FPN-level prediction head: a 1x1 conv for the class scores and,
    unless the feature extractor shares the cls/reg tower, a second 1x1 conv
    for the box distribution.

    Args:
        conv_feat (object): Instance of 'PicoFeat'
        dgqp_module (object): Optional module. When set, it is called on the
            raw box prediction and its output multiplies the sigmoid of the
            class score.
        num_classes (int): Number of classes
        fpn_stride (list): The stride of each FPN Layer
        prior_prob (float): Used to set the bias init for the class prediction layer
        loss_class (object): Instance of VariFocalLoss.
        loss_dfl (object): Instance of DistributionFocalLoss.
        loss_bbox (object): Instance of bbox loss.
        assigner (object): Instance of label assigner.
        reg_max (int): Max value of integral set :math: `{0, ..., reg_max}`
            in QFL setting. Default: 16.
        feat_in_chan (int): Input channel number of the prediction convs.
        nms (object): NMS instance; forwarded to the parent class and stored,
            not used by the code in this class.
        nms_pre (int): Forwarded to the parent class and stored, not used by
            the code in this class.
        cell_offset (float): Offset passed to
            ``get_single_level_center_point`` when computing center points.
    """
    __inject__ = [
        'conv_feat', 'dgqp_module', 'loss_class', 'loss_dfl', 'loss_bbox',
        'assigner', 'nms'
    ]
    __shared__ = ['num_classes']

    def __init__(self,
                 conv_feat='PicoFeat',
                 dgqp_module=None,
                 num_classes=80,
                 fpn_stride=[8, 16, 32],
                 prior_prob=0.01,
                 loss_class='VariFocalLoss',
                 loss_dfl='DistributionFocalLoss',
                 loss_bbox='GIoULoss',
                 assigner='SimOTAAssigner',
                 reg_max=16,
                 feat_in_chan=96,
                 nms=None,
                 nms_pre=1000,
                 cell_offset=0):
        super(PicoHead, self).__init__(
            conv_feat=conv_feat,
            dgqp_module=dgqp_module,
            num_classes=num_classes,
            fpn_stride=fpn_stride,
            prior_prob=prior_prob,
            loss_class=loss_class,
            loss_dfl=loss_dfl,
            loss_bbox=loss_bbox,
            assigner=assigner,
            reg_max=reg_max,
            feat_in_chan=feat_in_chan,
            nms=nms,
            nms_pre=nms_pre,
            cell_offset=cell_offset)
        self.conv_feat = conv_feat
        self.num_classes = num_classes
        self.fpn_stride = fpn_stride
        self.prior_prob = prior_prob
        self.loss_vfl = loss_class
        self.loss_dfl = loss_dfl
        self.loss_bbox = loss_bbox
        self.assigner = assigner
        self.reg_max = reg_max
        self.feat_in_chan = feat_in_chan
        self.nms = nms
        self.nms_pre = nms_pre
        self.cell_offset = cell_offset

        self.use_sigmoid = self.loss_vfl.use_sigmoid
        if self.use_sigmoid:
            self.cls_out_channels = self.num_classes
        else:
            self.cls_out_channels = self.num_classes + 1
        bias_init_value = -math.log((1 - self.prior_prob) / self.prior_prob)
        # Clear the super class initialization
        self.gfl_head_cls = None
        self.gfl_head_reg = None
        self.scales_regs = None

        # Channel counts do not depend on the level, so compute them once.
        reg_out_channels = 4 * (self.reg_max + 1)
        if self.conv_feat.share_cls_reg:
            # A single conv predicts class scores and box distribution
            # together; forward() splits the result along the channel axis.
            cls_head_channels = self.cls_out_channels + reg_out_channels
        else:
            cls_head_channels = self.cls_out_channels

        self.head_cls_list = []
        self.head_reg_list = []
        for i in range(len(fpn_stride)):
            head_cls = self.add_sublayer(
                "head_cls" + str(i),
                nn.Conv2D(
                    in_channels=self.feat_in_chan,
                    out_channels=cls_head_channels,
                    kernel_size=1,
                    stride=1,
                    padding=0,
                    weight_attr=ParamAttr(initializer=Normal(
                        mean=0., std=0.01)),
                    bias_attr=ParamAttr(
                        initializer=Constant(value=bias_init_value))))
            self.head_cls_list.append(head_cls)
            if not self.conv_feat.share_cls_reg:
                head_reg = self.add_sublayer(
                    "head_reg" + str(i),
                    nn.Conv2D(
                        in_channels=self.feat_in_chan,
                        out_channels=reg_out_channels,
                        kernel_size=1,
                        stride=1,
                        padding=0,
                        weight_attr=ParamAttr(initializer=Normal(
                            mean=0., std=0.01)),
                        bias_attr=ParamAttr(initializer=Constant(value=0))))
                self.head_reg_list.append(head_reg)

    def forward(self, fpn_feats, export_post_process=True):
        """Predict class scores and boxes for every FPN level.

        Args:
            fpn_feats (list): One feature map per entry of ``fpn_stride``.
            export_post_process (bool): If False, outputs are sigmoid-ed,
                flattened and transposed to ``[1, N, C]`` for deployment
                (batch size 1 only). If True and the layer is not in
                training mode, boxes are decoded from the distribution.
                If True and in training mode, the head outputs are returned
                without further processing.

        Returns:
            tuple: ``(cls_logits_list, bboxes_reg_list)``, each with one
            entry per FPN level.
        """
        assert len(fpn_feats) == len(
            self.fpn_stride
        ), "The size of fpn_feats is not equal to size of fpn_stride"
        cls_logits_list = []
        bboxes_reg_list = []
        for i, fpn_feat in enumerate(fpn_feats):
            conv_cls_feat, conv_reg_feat = self.conv_feat(fpn_feat, i)
            if self.conv_feat.share_cls_reg:
                cls_logits = self.head_cls_list[i](conv_cls_feat)
                cls_score, bbox_pred = paddle.split(
                    cls_logits,
                    [self.cls_out_channels, 4 * (self.reg_max + 1)],
                    axis=1)
            else:
                cls_score = self.head_cls_list[i](conv_cls_feat)
                bbox_pred = self.head_reg_list[i](conv_reg_feat)

            if self.dgqp_module:
                # NOTE(review): both inference branches below apply sigmoid
                # to cls_score again, so with dgqp_module set the score goes
                # through sigmoid twice. Left unchanged because trained
                # weights may rely on it; confirm this is intended.
                quality_score = self.dgqp_module(bbox_pred)
                cls_score = F.sigmoid(cls_score) * quality_score

            if not export_post_process:
                # Now only supports batch size = 1 in deploy
                # TODO(ygh): support batch size > 1
                cls_score = F.sigmoid(cls_score).reshape(
                    [1, self.cls_out_channels, -1]).transpose([0, 2, 1])
                bbox_pred = bbox_pred.reshape([1, (self.reg_max + 1) * 4,
                                               -1]).transpose([0, 2, 1])
            elif not self.training:
                # Move channels last before flattening the spatial dims.
                cls_score = F.sigmoid(cls_score.transpose([0, 2, 3, 1]))
                bbox_pred = bbox_pred.transpose([0, 2, 3, 1])
                stride = self.fpn_stride[i]
                b, cell_h, cell_w, _ = paddle.shape(cls_score)
                # get_single_level_center_point and distribution_project are
                # not defined in this class; presumably inherited from the
                # parent head.
                y, x = self.get_single_level_center_point(
                    [cell_h, cell_w], stride, cell_offset=self.cell_offset)
                center_points = paddle.stack([x, y], axis=-1)
                cls_score = cls_score.reshape([b, -1, self.cls_out_channels])
                bbox_pred = self.distribution_project(bbox_pred) * stride
                bbox_pred = bbox_pred.reshape([b, cell_h * cell_w, 4])

                # NOTE: If keep_ratio=False and the image shape is a multiple
                # of 32, distance2bbox is called without max_shapes to speed
                # up prediction. If max_shapes is needed, use
                # inputs['im_shape'].
                bbox_pred = batch_distance2bbox(
                    center_points, bbox_pred, max_shapes=None)

            cls_logits_list.append(cls_score)
            bboxes_reg_list.append(bbox_pred)

        return (cls_logits_list, bboxes_reg_list)