# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from ppdet.core.workspace import register, serializable
G
Guanghua Yu 已提交
19
from ppdet.modeling.layers import DropBlock
20
from ..backbones.darknet import ConvBNLayer
21
from ..shape_spec import ShapeSpec
F
Feng Ni 已提交
22
from ..backbones.csp_darknet import BaseConv, DWConv, CSPLayer
23

F
Feng Ni 已提交
24
__all__ = ['YOLOv3FPN', 'PPYOLOFPN', 'PPYOLOTinyFPN', 'PPYOLOPAN', 'YOLOCSPPAN']
25

Q
qingqing01 已提交
26

27
def add_coord(x, data_format):
    """
    Build the two coordinate channels consumed by CoordConv.

    Args:
        x (Tensor): input feature map; only its batch size, spatial size and
            dtype are read, never its values.
        data_format (str): 'NCHW' selects channels-first layout; any other
            value is treated as channels-last (NHWC).

    Returns:
        tuple(Tensor, Tensor): (gx, gy), each with one channel and the same
            batch/spatial size and dtype as ``x``. Both have
            ``stop_gradient`` set.
    """
    batch = paddle.shape(x)[0]
    channels_first = data_format == 'NCHW'
    if channels_first:
        h, w = x.shape[2], x.shape[3]
    else:
        h, w = x.shape[1], x.shape[2]

    # NOTE(review): arange(n) / ((n - 1) * 2) - 1 spans [-1, -0.5], not the
    # [-1, 1] range usually used by CoordConv, and it divides by zero when
    # n == 1. It is kept as-is because trained weights may depend on it;
    # confirm before changing.
    gx = paddle.cast(paddle.arange(w) / ((w - 1.) * 2.0) - 1., x.dtype)
    gy = paddle.cast(paddle.arange(h) / ((h - 1.) * 2.0) - 1., x.dtype)

    # x varies along the width axis, y along the height axis; both are then
    # broadcast to the full single-channel map.
    if channels_first:
        gx_view, gy_view = [1, 1, 1, w], [1, 1, h, 1]
        target = [batch, 1, h, w]
    else:
        gx_view, gy_view = [1, 1, w, 1], [1, h, 1, 1]
        target = [batch, h, w, 1]
    gx = gx.reshape(gx_view).expand(target)
    gy = gy.reshape(gy_view).expand(target)

    for grid in (gx, gy):
        grid.stop_gradient = True
    return gx, gy


Q
qingqing01 已提交
49
class YoloDetBlock(nn.Layer):
    def __init__(self,
                 ch_in,
                 channel,
                 norm_type,
                 freeze_norm=False,
                 name='',
                 data_format='NCHW'):
        """
        YOLODetBlock layer for yolov3, see https://arxiv.org/abs/1804.02767

        Five ConvBNLayers alternating 1x1 and 3x3 kernels form the route
        branch. A final 3x3 ConvBNLayer (the "tip") is applied to the route.

        Args:
            ch_in (int): input channel
            channel (int): base channel, must be even
            norm_type (str): batch norm type
            freeze_norm (bool): whether to freeze norm, default False
            name (str): layer name
            data_format (str): data format, NCHW or NHWC
        """
        super(YoloDetBlock, self).__init__()
        self.ch_in = ch_in
        self.channel = channel
        assert channel % 2 == 0, \
            "channel {} cannot be divided by 2".format(channel)

        doubled = channel * 2
        # (sublayer name, in channels, out channels, kernel size, name suffix)
        layer_specs = (
            ('conv0', ch_in, channel, 1, '.0.0'),
            ('conv1', channel, doubled, 3, '.0.1'),
            ('conv2', doubled, channel, 1, '.1.0'),
            ('conv3', channel, doubled, 3, '.1.1'),
            ('route', doubled, channel, 1, '.2'),
        )

        self.conv_module = nn.Sequential()
        for sub_name, c_in, c_out, kernel, suffix in layer_specs:
            conv = ConvBNLayer(
                ch_in=c_in,
                ch_out=c_out,
                filter_size=kernel,
                padding=(kernel - 1) // 2,
                norm_type=norm_type,
                freeze_norm=freeze_norm,
                data_format=data_format,
                name=name + suffix)
            self.conv_module.add_sublayer(sub_name, conv)

        self.tip = ConvBNLayer(
            ch_in=channel,
            ch_out=doubled,
            filter_size=3,
            padding=1,
            norm_type=norm_type,
            freeze_norm=freeze_norm,
            data_format=data_format,
            name=name + '.tip')

    def forward(self, inputs):
        """Return ``(route, tip)``, where ``tip`` is computed from ``route``."""
        route = self.conv_module(inputs)
        return route, self.tip(route)


W
wangxinxin08 已提交
112
class SPP(nn.Layer):
    def __init__(self,
                 ch_in,
                 ch_out,
                 k,
                 pool_size,
                 norm_type='bn',
                 freeze_norm=False,
                 name='',
                 act='leaky',
                 data_format='NCHW'):
        """
        SPP layer: the input is concatenated along the channel axis with one
        stride-1 max-pooled copy of itself per entry in ``pool_size``. The
        result is then fused by a ConvBNLayer.

        Args:
            ch_in (int): input channel of conv layer, i.e. the channel count
                after concatenation ((len(pool_size) + 1) * input channels)
            ch_out (int): output channel of conv layer
            k (int): kernel size of conv layer
            pool_size (list[int]): kernel sizes of the max-pooling layers
            norm_type (str): batch norm type
            freeze_norm (bool): whether to freeze norm, default False
            name (str): layer name
            act (str): activation function
            data_format (str): data format, NCHW or NHWC
        """
        super(SPP, self).__init__()
        self.pool = []
        self.data_format = data_format
        for idx, size in enumerate(pool_size):
            # Each pool gets its own sublayer name. With a shared name, every
            # add_sublayer call would overwrite the previous registration, so
            # only the last pool would be visible to sublayers()/train()/eval().
            pool = self.add_sublayer(
                '{}.pool{}'.format(name, idx),
                nn.MaxPool2D(
                    kernel_size=size,
                    stride=1,
                    padding=size // 2,
                    data_format=data_format,
                    ceil_mode=False))
            self.pool.append(pool)
        self.conv = ConvBNLayer(
            ch_in,
            ch_out,
            k,
            padding=k // 2,
            norm_type=norm_type,
            freeze_norm=freeze_norm,
            name=name,
            act=act,
            data_format=data_format)

    def forward(self, x):
        """Concatenate ``x`` with its pooled copies, then apply the conv."""
        pooled = [pool(x) for pool in self.pool]
        axis = 1 if self.data_format == "NCHW" else -1
        y = paddle.concat([x] + pooled, axis=axis)
        return self.conv(y)


class CoordConv(nn.Layer):
    def __init__(self,
                 ch_in,
                 ch_out,
                 filter_size,
                 padding,
                 norm_type,
                 freeze_norm=False,
                 name='',
                 data_format='NCHW'):
        """
        CoordConv layer, see https://arxiv.org/abs/1807.03247

        Two coordinate channels produced by ``add_coord`` are appended to the
        input before a ConvBNLayer, so the conv sees ``ch_in + 2`` channels.

        Args:
            ch_in (int): input channel, excluding the 2 coordinate channels
            ch_out (int): output channel
            filter_size (int): filter size
            padding (int): padding size
            norm_type (str): batch norm type
            freeze_norm (bool): whether to freeze norm, default False
            name (str): layer name
            data_format (str): data format, NCHW or NHWC
        """
        super(CoordConv, self).__init__()
        self.conv = ConvBNLayer(
            ch_in + 2,
            ch_out,
            filter_size=filter_size,
            padding=padding,
            norm_type=norm_type,
            freeze_norm=freeze_norm,
            data_format=data_format,
            name=name)
        self.data_format = data_format

    def forward(self, x):
        """Append the coordinate channels to ``x`` and apply the conv."""
        gx, gy = add_coord(x, self.data_format)
        axis = 1 if self.data_format == 'NCHW' else -1
        y = paddle.concat([x, gx, gy], axis=axis)
        return self.conv(y)


class PPYOLODetBlock(nn.Layer):
    def __init__(self, cfg, name, data_format='NCHW'):
        """
        PPYOLODetBlock layer

        Every entry of ``cfg`` except the last is stacked into
        ``conv_module``. The last entry becomes ``tip``, which is applied to
        the output of ``conv_module``.

        Args:
            cfg (list): layer configs for this block; each entry is
                ``[sublayer_name, layer_class, args, kwargs]``. The kwargs
                dicts are copied, never modified, because callers may share
                one cfg entry between several blocks.
            name (str): block name
            data_format (str): data format, NCHW or NHWC
        """
        super(PPYOLODetBlock, self).__init__()
        self.conv_module = nn.Sequential()
        for conv_name, layer, args, kwargs in cfg[:-1]:
            layer_kwargs = dict(
                kwargs,
                name='{}.{}'.format(name, conv_name),
                data_format=data_format)
            self.conv_module.add_sublayer(conv_name,
                                          layer(*args, **layer_kwargs))

        conv_name, layer, args, kwargs = cfg[-1]
        tip_kwargs = dict(
            kwargs,
            name='{}.{}'.format(name, conv_name),
            data_format=data_format)
        self.tip = layer(*args, **tip_kwargs)

    def forward(self, inputs):
        """Return ``(route, tip)``, where ``tip`` is computed from ``route``."""
        route = self.conv_module(inputs)
        tip = self.tip(route)
        return route, tip


K
Kaipeng Deng 已提交
246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315
class PPYOLOTinyDetBlock(nn.Layer):
    def __init__(self,
                 ch_in,
                 ch_out,
                 name,
                 drop_block=False,
                 block_size=3,
                 keep_prob=0.9,
                 data_format='NCHW'):
        """
        PPYOLO Tiny DetBlock layer

        Args:
            ch_in (int): input channel number
            ch_out (int): output channel number
            name (str): block name
            drop_block (bool): whether to apply DropBlock to the block input
            block_size (int): drop block size
            keep_prob (float): probability to keep block in DropBlock
            data_format (str): data format, NCHW or NHWC; applied to every
                conv layer and to DropBlock
        """
        super(PPYOLOTinyDetBlock, self).__init__()
        self.drop_block_ = drop_block
        self.conv_module = nn.Sequential()

        cfgs = [
            # name, in channels, out channels, filter_size,
            # stride, padding, groups
            ['.0', ch_in, ch_out, 1, 1, 0, 1],
            ['.1', ch_out, ch_out, 5, 1, 2, ch_out],
            ['.2', ch_out, ch_out, 1, 1, 0, 1],
            ['.route', ch_out, ch_out, 5, 1, 2, ch_out],
        ]
        for (conv_name, conv_ch_in, conv_ch_out, filter_size, stride, padding,
             groups) in cfgs:
            self.conv_module.add_sublayer(
                name + conv_name,
                ConvBNLayer(
                    ch_in=conv_ch_in,
                    ch_out=conv_ch_out,
                    filter_size=filter_size,
                    stride=stride,
                    padding=padding,
                    groups=groups,
                    data_format=data_format,
                    name=name + conv_name))

        # Previously named after the leftover loop variable ('.route'), which
        # duplicated the route conv's name; give the tip its own name.
        self.tip = ConvBNLayer(
            ch_in=ch_out,
            ch_out=ch_out,
            filter_size=1,
            stride=1,
            padding=0,
            groups=1,
            data_format=data_format,
            name=name + '.tip')

        if self.drop_block_:
            self.drop_block = DropBlock(
                block_size=block_size,
                keep_prob=keep_prob,
                data_format=data_format,
                name=name + '.dropblock')

    def forward(self, inputs):
        """Optionally apply DropBlock, then return ``(route, tip)``."""
        if self.drop_block_:
            inputs = self.drop_block(inputs)
        route = self.conv_module(inputs)
        tip = self.tip(route)
        return route, tip


W
wangxinxin08 已提交
316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382
class PPYOLODetBlockCSP(nn.Layer):
    def __init__(self,
                 cfg,
                 ch_in,
                 ch_out,
                 act,
                 norm_type,
                 name,
                 data_format='NCHW'):
        """
        PPYOLODetBlockCSP layer

        The input is split into a left and a right 1x1 conv branch. Only the
        left branch runs through ``conv_module`` (built from ``cfg``). The
        two branches are concatenated and fused by a final 1x1 conv.

        Args:
            cfg (list): layer configs for this block; each entry is
                ``[layer_name, layer_class, args, kwargs]``. The kwargs dicts
                are copied, never modified.
            ch_in (int): input channel
            ch_out (int): output channel of each branch; the block outputs
                ``ch_out * 2`` channels
            act (str): activation function
            norm_type (str): batch norm type
            name (str): block name
            data_format (str): data format, NCHW or NHWC
        """
        super(PPYOLODetBlockCSP, self).__init__()
        self.data_format = data_format
        self.conv1 = ConvBNLayer(
            ch_in,
            ch_out,
            1,
            padding=0,
            act=act,
            norm_type=norm_type,
            name=name + '.left',
            data_format=data_format)
        self.conv2 = ConvBNLayer(
            ch_in,
            ch_out,
            1,
            padding=0,
            act=act,
            norm_type=norm_type,
            name=name + '.right',
            data_format=data_format)
        self.conv3 = ConvBNLayer(
            ch_out * 2,
            ch_out * 2,
            1,
            padding=0,
            act=act,
            norm_type=norm_type,
            name=name,
            data_format=data_format)
        self.conv_module = nn.Sequential()
        for layer_name, layer, args, kwargs in cfg:
            # Copy so the caller's cfg dicts are left untouched.
            layer_kwargs = dict(
                kwargs, name=name + layer_name, data_format=data_format)
            self.conv_module.add_sublayer(layer_name,
                                          layer(*args, **layer_kwargs))

    def forward(self, inputs):
        """Return the fused output twice: it serves as both route and tip."""
        conv_left = self.conv1(inputs)
        conv_right = self.conv2(inputs)
        conv_left = self.conv_module(conv_left)
        axis = 1 if self.data_format == 'NCHW' else -1
        conv = paddle.concat([conv_left, conv_right], axis=axis)
        conv = self.conv3(conv)
        return conv, conv


Q
qingqing01 已提交
383 384 385
@register
@serializable
class YOLOv3FPN(nn.Layer):
    __shared__ = ['norm_type', 'data_format']

    def __init__(self,
                 in_channels=[256, 512, 1024],
                 norm_type='bn',
                 freeze_norm=False,
                 data_format='NCHW'):
        """
        YOLOv3FPN layer

        Args:
            in_channels (list): input channels for fpn, ordered from the
                shallowest to the deepest feature; processed deepest first
            norm_type (str): batch norm type, default bn
            freeze_norm (bool): whether to freeze norm, default False
            data_format (str): data format, NCHW or NHWC

        """
        super(YOLOv3FPN, self).__init__()
        assert len(in_channels) > 0, "in_channels length should > 0"
        self.in_channels = in_channels
        self.num_blocks = len(in_channels)

        self._out_channels = []
        self.yolo_blocks = []
        self.routes = []
        self.data_format = data_format
        for i in range(self.num_blocks):
            block_name = 'yolo_block.{}'.format(i)
            block_ch_in = in_channels[-i - 1]
            if i > 0:
                # widen by the channels of the upsampled transition output
                block_ch_in += 512 // (2**i)
            block = YoloDetBlock(
                block_ch_in,
                channel=512 // (2**i),
                norm_type=norm_type,
                freeze_norm=freeze_norm,
                data_format=data_format,
                name=block_name)
            self.yolo_blocks.append(self.add_sublayer(block_name, block))
            # tip layer output channel doubled
            self._out_channels.append(1024 // (2**i))

            if i >= self.num_blocks - 1:
                continue
            transition_name = 'yolo_transition.{}'.format(i)
            transition = ConvBNLayer(
                ch_in=512 // (2**i),
                ch_out=256 // (2**i),
                filter_size=1,
                stride=1,
                padding=0,
                norm_type=norm_type,
                freeze_norm=freeze_norm,
                data_format=data_format,
                name=transition_name)
            self.routes.append(self.add_sublayer(transition_name, transition))

    def forward(self, blocks, for_mot=False):
        """
        Args:
            blocks (list): backbone features, shallowest first
            for_mot (bool): also return the route features as embeddings

        Returns:
            list of tip features (deepest first), or a dict with keys
            'yolo_feats' and 'emb_feats' when ``for_mot`` is True.
        """
        assert len(blocks) == self.num_blocks
        concat_axis = 1 if self.data_format == 'NCHW' else -1
        last = self.num_blocks - 1
        yolo_feats = []
        # add embedding features output for multi-object tracking model
        if for_mot:
            emb_feats = []

        for i, block in enumerate(blocks[::-1]):
            if i > 0:
                block = paddle.concat([route, block], axis=concat_axis)
            route, tip = self.yolo_blocks[i](block)
            yolo_feats.append(tip)
            if for_mot:
                emb_feats.append(route)
            if i < last:
                route = F.interpolate(
                    self.routes[i](route),
                    scale_factor=2.,
                    data_format=self.data_format)

        if not for_mot:
            return yolo_feats
        return {'yolo_feats': yolo_feats, 'emb_feats': emb_feats}

    @classmethod
    def from_config(cls, cfg, input_shape):
        return {'in_channels': [spec.channels for spec in input_shape], }

    @property
    def out_shape(self):
        return [ShapeSpec(channels=ch) for ch in self._out_channels]

W
wangxinxin08 已提交
485 486 487 488

@register
@serializable
class PPYOLOFPN(nn.Layer):
    __shared__ = ['norm_type', 'data_format']

    def __init__(self,
                 in_channels=[512, 1024, 2048],
                 norm_type='bn',
                 freeze_norm=False,
                 data_format='NCHW',
                 coord_conv=False,
                 conv_block_num=2,
                 drop_block=False,
                 block_size=3,
                 keep_prob=0.9,
                 spp=False):
        """
        PPYOLOFPN layer

        Args:
            in_channels (list): input channels for fpn
            norm_type (str): batch norm type, default bn
            freeze_norm (bool): whether to freeze norm, default False
            data_format (str): data format, NCHW or NHWC
            coord_conv (bool): whether use CoordConv or not
            conv_block_num (int): conv block num of each pan block; only 0
                and 2 are supported
            drop_block (bool): whether use DropBlock or not
            block_size (int): block size of DropBlock
            keep_prob (float): keep probability of DropBlock
            spp (bool): whether use spp or not

        Raises:
            ValueError: if conv_block_num is neither 0 nor 2.
        """
        super(PPYOLOFPN, self).__init__()
        assert len(in_channels) > 0, "in_channels length should > 0"
        if conv_block_num not in (0, 2):
            # Only these two layouts are defined below; any other value would
            # leave the per-block ``cfg`` undefined.
            raise ValueError(
                "conv_block_num should be 0 or 2, but got {}".format(
                    conv_block_num))
        self.in_channels = in_channels
        self.num_blocks = len(in_channels)
        # parse kwargs
        self.coord_conv = coord_conv
        self.drop_block = drop_block
        self.block_size = block_size
        self.keep_prob = keep_prob
        self.spp = spp
        self.conv_block_num = conv_block_num
        self.data_format = data_format
        ConvLayer = CoordConv if self.coord_conv else ConvBNLayer

        if self.drop_block:
            dropblock_cfg = [[
                'dropblock', DropBlock, [self.block_size, self.keep_prob],
                dict()
            ]]
        else:
            dropblock_cfg = []

        self._out_channels = []
        self.yolo_blocks = []
        self.routes = []
        for i, ch_in in enumerate(self.in_channels[::-1]):
            if i > 0:
                # widen by the channels of the upsampled transition output
                ch_in += 512 // (2**i)
            channel = 64 * (2**self.num_blocks) // (2**i)
            base_cfg = []
            c_in, c_out = ch_in, channel
            for j in range(self.conv_block_num):
                base_cfg += [
                    [
                        'conv{}'.format(2 * j), ConvLayer, [c_in, c_out, 1],
                        dict(
                            padding=0,
                            norm_type=norm_type,
                            freeze_norm=freeze_norm)
                    ],
                    [
                        'conv{}'.format(2 * j + 1), ConvBNLayer,
                        [c_out, c_out * 2, 3], dict(
                            padding=1,
                            norm_type=norm_type,
                            freeze_norm=freeze_norm)
                    ],
                ]
                c_in, c_out = c_out * 2, c_out

            base_cfg += [[
                'route', ConvLayer, [c_in, c_out, 1], dict(
                    padding=0, norm_type=norm_type, freeze_norm=freeze_norm)
            ], [
                'tip', ConvLayer, [c_out, c_out * 2, 3], dict(
                    padding=1, norm_type=norm_type, freeze_norm=freeze_norm)
            ]]

            if self.conv_block_num == 2:
                if i == 0:
                    if self.spp:
                        spp_cfg = [[
                            'spp', SPP, [channel * 4, channel, 1], dict(
                                pool_size=[5, 9, 13],
                                norm_type=norm_type,
                                freeze_norm=freeze_norm)
                        ]]
                    else:
                        spp_cfg = []
                    cfg = base_cfg[0:3] + spp_cfg + base_cfg[
                        3:4] + dropblock_cfg + base_cfg[4:6]
                else:
                    cfg = base_cfg[0:2] + dropblock_cfg + base_cfg[2:6]
            else:  # conv_block_num == 0
                if self.spp and i == 0:
                    spp_cfg = [[
                        'spp', SPP, [c_in * 4, c_in, 1], dict(
                            pool_size=[5, 9, 13],
                            norm_type=norm_type,
                            freeze_norm=freeze_norm)
                    ]]
                else:
                    spp_cfg = []
                cfg = spp_cfg + dropblock_cfg + base_cfg
            name = 'yolo_block.{}'.format(i)
            # Forward data_format so the block's layers match this FPN's
            # layout (the block's own default is NCHW).
            yolo_block = self.add_sublayer(
                name, PPYOLODetBlock(
                    cfg, name, data_format=data_format))
            self.yolo_blocks.append(yolo_block)
            self._out_channels.append(channel * 2)
            if i < self.num_blocks - 1:
                name = 'yolo_transition.{}'.format(i)
                route = self.add_sublayer(
                    name,
                    ConvBNLayer(
                        ch_in=channel,
                        ch_out=256 // (2**i),
                        filter_size=1,
                        stride=1,
                        padding=0,
                        norm_type=norm_type,
                        freeze_norm=freeze_norm,
                        data_format=data_format,
                        name=name))
                self.routes.append(route)

    def forward(self, blocks, for_mot=False):
        """
        Args:
            blocks (list): backbone features, shallowest first
            for_mot (bool): also return the route features as embeddings

        Returns:
            list of tip features (deepest first), or a dict with keys
            'yolo_feats' and 'emb_feats' when ``for_mot`` is True.
        """
        assert len(blocks) == self.num_blocks
        blocks = blocks[::-1]
        yolo_feats = []
        concat_axis = 1 if self.data_format == 'NCHW' else -1

        # add embedding features output for multi-object tracking model
        if for_mot:
            emb_feats = []

        for i, block in enumerate(blocks):
            if i > 0:
                block = paddle.concat([route, block], axis=concat_axis)
            route, tip = self.yolo_blocks[i](block)
            yolo_feats.append(tip)

            if for_mot:
                # add embedding features output
                emb_feats.append(route)

            if i < self.num_blocks - 1:
                route = self.routes[i](route)
                route = F.interpolate(
                    route, scale_factor=2., data_format=self.data_format)

        if for_mot:
            return {'yolo_feats': yolo_feats, 'emb_feats': emb_feats}
        else:
            return yolo_feats

    @classmethod
    def from_config(cls, cfg, input_shape):
        return {'in_channels': [i.channels for i in input_shape], }

    @property
    def out_shape(self):
        return [ShapeSpec(channels=c) for c in self._out_channels]
K
Kaipeng Deng 已提交
663 664 665 666 667 668 669 670 671 672 673 674 675 676 677 678 679 680 681 682 683 684 685 686 687 688 689 690 691 692 693 694 695 696 697 698 699 700 701 702 703 704 705 706 707 708 709 710 711 712 713 714 715 716 717 718 719 720 721 722 723 724 725 726 727 728 729 730 731 732 733 734 735 736 737 738 739 740 741 742


@register
@serializable
class PPYOLOTinyFPN(nn.Layer):
    __shared__ = ['norm_type', 'data_format']

    def __init__(self,
                 in_channels=[80, 56, 34],
                 detection_block_channels=[160, 128, 96],
                 norm_type='bn',
                 data_format='NCHW',
                 **kwargs):
        """
        PPYOLO Tiny FPN layer

        Args:
            in_channels (list): input channels for fpn
            detection_block_channels (list): channels in fpn; must have at
                least as many entries as in_channels
            norm_type (str): batch norm type, default bn
            data_format (str): data format, NCHW or NHWC; forwarded to SPP,
                every detection block and every transition conv
            kwargs: extra key-value pairs: drop_block, block_size,
                keep_prob (DropBlock) and spp
        """
        super(PPYOLOTinyFPN, self).__init__()
        assert len(in_channels) > 0, "in_channels length should > 0"
        self.in_channels = in_channels[::-1]
        assert len(detection_block_channels
                   ) > 0, "detection_block_channels length should > 0"
        # zip() below would silently build too few blocks otherwise, and
        # forward() would then fail with an IndexError.
        assert len(detection_block_channels) >= len(in_channels), \
            "detection_block_channels should have at least {} entries".format(
                len(in_channels))
        self.detection_block_channels = detection_block_channels
        self.data_format = data_format
        self.num_blocks = len(in_channels)
        # parse kwargs
        self.drop_block = kwargs.get('drop_block', False)
        self.block_size = kwargs.get('block_size', 3)
        self.keep_prob = kwargs.get('keep_prob', 0.9)

        self.spp_ = kwargs.get('spp', False)
        if self.spp_:
            self.spp = SPP(self.in_channels[0] * 4,
                           self.in_channels[0],
                           k=1,
                           pool_size=[5, 9, 13],
                           norm_type=norm_type,
                           name='spp',
                           data_format=data_format)

        self._out_channels = []
        self.yolo_blocks = []
        self.routes = []
        for i, (
                ch_in, ch_out
        ) in enumerate(zip(self.in_channels, self.detection_block_channels)):
            name = 'yolo_block.{}'.format(i)
            if i > 0:
                # widen by the channels of the upsampled transition output
                ch_in += self.detection_block_channels[i - 1]
            yolo_block = self.add_sublayer(
                name,
                PPYOLOTinyDetBlock(
                    ch_in,
                    ch_out,
                    name,
                    drop_block=self.drop_block,
                    block_size=self.block_size,
                    keep_prob=self.keep_prob,
                    data_format=data_format))
            self.yolo_blocks.append(yolo_block)
            self._out_channels.append(ch_out)

            if i < self.num_blocks - 1:
                name = 'yolo_transition.{}'.format(i)
                route = self.add_sublayer(
                    name,
                    ConvBNLayer(
                        ch_in=ch_out,
                        ch_out=ch_out,
                        filter_size=1,
                        stride=1,
                        padding=0,
                        norm_type=norm_type,
                        data_format=data_format,
                        name=name))
                self.routes.append(route)

    def forward(self, blocks, for_mot=False):
        """
        Args:
            blocks (list): backbone features, shallowest first
            for_mot (bool): also return the route features as embeddings

        Returns:
            list of tip features (deepest first), or a dict with keys
            'yolo_feats' and 'emb_feats' when ``for_mot`` is True.
        """
        assert len(blocks) == self.num_blocks
        blocks = blocks[::-1]
        yolo_feats = []
        concat_axis = 1 if self.data_format == 'NCHW' else -1

        # add embedding features output for multi-object tracking model
        if for_mot:
            emb_feats = []

        for i, block in enumerate(blocks):
            if i == 0 and self.spp_:
                block = self.spp(block)

            if i > 0:
                block = paddle.concat([route, block], axis=concat_axis)
            route, tip = self.yolo_blocks[i](block)
            yolo_feats.append(tip)

            if for_mot:
                # add embedding features output
                emb_feats.append(route)

            if i < self.num_blocks - 1:
                route = self.routes[i](route)
                route = F.interpolate(
                    route, scale_factor=2., data_format=self.data_format)

        if for_mot:
            return {'yolo_feats': yolo_feats, 'emb_feats': emb_feats}
        else:
            return yolo_feats

    @classmethod
    def from_config(cls, cfg, input_shape):
        return {'in_channels': [i.channels for i in input_shape], }

    @property
    def out_shape(self):
        return [ShapeSpec(channels=c) for c in self._out_channels]
W
wangxinxin08 已提交
785 786 787 788 789 790 791 792 793 794 795 796 797 798 799 800 801 802 803 804 805 806 807 808 809 810 811 812 813 814 815 816 817 818 819 820 821 822 823 824 825 826 827 828 829 830 831 832 833 834 835 836 837 838 839 840 841 842 843 844 845 846 847 848 849 850 851 852 853 854 855 856 857 858 859 860 861 862 863 864 865 866 867 868 869 870 871 872 873 874 875 876 877 878 879 880 881 882 883 884 885 886 887 888 889 890 891 892 893 894 895 896 897 898 899 900 901 902 903 904 905 906 907 908 909 910 911 912 913 914 915 916 917 918 919 920 921 922 923 924 925 926 927 928 929 930 931 932 933 934 935 936 937


@register
@serializable
class PPYOLOPAN(nn.Layer):
    """
    PP-YOLO PAN neck: a top-down FPN path followed by a bottom-up PAN path,
    both built from PPYOLODetBlockCSP blocks, with optional SPP (first FPN
    block only) and optional DropBlock (inside every block).
    """
    __shared__ = ['norm_type', 'data_format']

    def __init__(self,
                 in_channels=[512, 1024, 2048],
                 norm_type='bn',
                 data_format='NCHW',
                 act='mish',
                 conv_block_num=3,
                 drop_block=False,
                 block_size=3,
                 keep_prob=0.9,
                 spp=False):
        """
        PPYOLOPAN layer with SPP, DropBlock and CSP connection.

        Args:
            in_channels (list): input channels for fpn, one per input level.
                The FPN path walks them in reverse order (last entry first).
            norm_type (str): batch norm type, default bn
            data_format (str): data format, NCHW or NHWC
            act (str): activation function, default mish
            conv_block_num (int): number of (1x1, 3x3) ConvBNLayer pairs in
                each block; must be >= 2 when spp is True
            drop_block (bool): whether use DropBlock or not
            block_size (int): block size of DropBlock
            keep_prob (float): keep probability of DropBlock
            spp (bool): whether use spp or not

        Raises:
            ValueError: if spp is True and conv_block_num < 2.
        """
        super(PPYOLOPAN, self).__init__()
        assert len(in_channels) > 0, "in_channels length should > 0"
        if spp and conv_block_num < 2:
            # SPP replaces cfg entry 3 (the 3x3 conv of the second pair);
            # with fewer than two pairs that entry does not exist.
            raise ValueError(
                "PPYOLOPAN with spp=True requires conv_block_num >= 2, "
                "got {}".format(conv_block_num))
        self.in_channels = in_channels
        self.num_blocks = len(in_channels)
        # parse kwargs
        self.drop_block = drop_block
        self.block_size = block_size
        self.keep_prob = keep_prob
        self.spp = spp
        self.conv_block_num = conv_block_num
        self.data_format = data_format
        if self.drop_block:
            dropblock_cfg = [[
                'dropblock', DropBlock, [self.block_size, self.keep_prob],
                dict()
            ]]
        else:
            dropblock_cfg = []

        # fpn: level i (counted over the reversed inputs) uses width
        # 512 // 2**i; its output channel count is recorded in fpn_channels.
        self.fpn_blocks = []
        self.fpn_routes = []
        fpn_channels = []
        for i, ch_in in enumerate(self.in_channels[::-1]):
            if i > 0:
                # input is concatenated with the upsampled route of level
                # i - 1, whose transition conv outputs 512 // 2**(i - 1)
                ch_in += 512 // (2**(i - 1))
            channel = 512 // (2**i)
            base_cfg = self._make_conv_cfg(channel, self.conv_block_num, act,
                                           norm_type)

            if i == 0 and self.spp:
                base_cfg[3] = [
                    'spp', SPP, [channel * 4, channel, 1], dict(
                        pool_size=[5, 9, 13], act=act, norm_type=norm_type)
                ]

            # DropBlock (when enabled) sits after the first two conv pairs.
            cfg = base_cfg[:4] + dropblock_cfg + base_cfg[4:]
            name = 'fpn.{}'.format(i)
            fpn_block = self.add_sublayer(
                name,
                PPYOLODetBlockCSP(cfg, ch_in, channel, act, norm_type, name,
                                  data_format))
            self.fpn_blocks.append(fpn_block)
            fpn_channels.append(channel * 2)
            if i < self.num_blocks - 1:
                name = 'fpn_transition.{}'.format(i)
                route = self.add_sublayer(
                    name,
                    ConvBNLayer(
                        ch_in=channel * 2,
                        ch_out=channel,
                        filter_size=1,
                        stride=1,
                        padding=0,
                        act=act,
                        norm_type=norm_type,
                        data_format=data_format,
                        name=name))
                self.fpn_routes.append(route)
        # pan
        self.pan_blocks = []
        self.pan_routes = []
        # The first PAN output is the last FPN output, so reuse its recorded
        # channel count (previously recomputed via 512 // 2**(num_blocks - 2),
        # which yields a float 1024.0 when there is a single input level).
        self._out_channels = [fpn_channels[-1], ]
        for i in reversed(range(self.num_blocks - 1)):
            name = 'pan_transition.{}'.format(i)
            # stride-2 3x3 conv brings the deeper feature down to level i
            route = self.add_sublayer(
                name,
                ConvBNLayer(
                    ch_in=fpn_channels[i + 1],
                    ch_out=fpn_channels[i + 1],
                    filter_size=3,
                    stride=2,
                    padding=1,
                    act=act,
                    norm_type=norm_type,
                    data_format=data_format,
                    name=name))
            self.pan_routes = [route, ] + self.pan_routes
            ch_in = fpn_channels[i] + fpn_channels[i + 1]
            channel = 512 // (2**i)
            base_cfg = self._make_conv_cfg(channel, self.conv_block_num, act,
                                           norm_type)

            cfg = base_cfg[:4] + dropblock_cfg + base_cfg[4:]
            name = 'pan.{}'.format(i)
            pan_block = self.add_sublayer(
                name,
                PPYOLODetBlockCSP(cfg, ch_in, channel, act, norm_type, name,
                                  data_format))

            self.pan_blocks = [pan_block, ] + self.pan_blocks
            self._out_channels.append(channel * 2)

        # match the order in which forward() returns its features
        self._out_channels = self._out_channels[::-1]

    @staticmethod
    def _make_conv_cfg(channel, conv_block_num, act, norm_type):
        """Build the layer config of one CSP block: `conv_block_num` pairs of
        (1x1, 3x3) ConvBNLayer, each channel -> channel, named '{j}.0'/'{j}.1'.
        """
        cfg = []
        for j in range(conv_block_num):
            cfg += [
                # name, layer, args, kwargs
                [
                    '{}.0'.format(j), ConvBNLayer, [channel, channel, 1],
                    dict(
                        padding=0, act=act, norm_type=norm_type)
                ],
                [
                    '{}.1'.format(j), ConvBNLayer, [channel, channel, 3],
                    dict(
                        padding=1, act=act, norm_type=norm_type)
                ],
            ]
        return cfg

    def forward(self, blocks, for_mot=False):
        """
        Run the FPN (top-down) then PAN (bottom-up) paths.

        Args:
            blocks (list): one feature map per entry of in_channels, in the
                same order as in_channels.
            for_mot (bool): also return the FPN route features as embeddings
                for multi-object tracking.

        Returns:
            list of PAN features in reversed input order (the first output
            corresponds to blocks[-1]); when for_mot is True, a dict with
            'yolo_feats' (that list) and 'emb_feats' (FPN routes, in FPN
            processing order).
        """
        assert len(blocks) == self.num_blocks
        blocks = blocks[::-1]
        fpn_feats = []

        # add embedding features output for multi-object tracking model
        if for_mot:
            emb_feats = []

        for i, block in enumerate(blocks):
            if i > 0:
                if self.data_format == 'NCHW':
                    block = paddle.concat([route, block], axis=1)
                else:
                    block = paddle.concat([route, block], axis=-1)
            route, tip = self.fpn_blocks[i](block)
            fpn_feats.append(tip)

            if for_mot:
                # add embedding features output
                emb_feats.append(route)

            if i < self.num_blocks - 1:
                route = self.fpn_routes[i](route)
                route = F.interpolate(
                    route, scale_factor=2., data_format=self.data_format)

        pan_feats = [fpn_feats[-1], ]
        route = fpn_feats[self.num_blocks - 1]
        for i in reversed(range(self.num_blocks - 1)):
            block = fpn_feats[i]
            route = self.pan_routes[i](route)
            if self.data_format == 'NCHW':
                block = paddle.concat([route, block], axis=1)
            else:
                block = paddle.concat([route, block], axis=-1)

            route, tip = self.pan_blocks[i](block)
            pan_feats.append(tip)

        if for_mot:
            return {'yolo_feats': pan_feats[::-1], 'emb_feats': emb_feats}
        else:
            return pan_feats[::-1]

    @classmethod
    def from_config(cls, cfg, input_shape):
        return {'in_channels': [i.channels for i in input_shape], }

    @property
    def out_shape(self):
        return [ShapeSpec(channels=c) for c in self._out_channels]
F
Feng Ni 已提交
990 991 992 993 994 995 996 997 998 999 1000 1001 1002 1003 1004 1005 1006 1007 1008 1009 1010 1011 1012 1013 1014 1015 1016 1017 1018 1019 1020 1021 1022 1023 1024 1025 1026 1027 1028 1029 1030 1031 1032 1033 1034 1035 1036 1037 1038 1039 1040 1041 1042 1043 1044 1045 1046 1047 1048 1049 1050 1051 1052 1053 1054 1055 1056 1057 1058 1059 1060 1061 1062 1063 1064 1065 1066 1067 1068 1069 1070 1071 1072 1073 1074 1075 1076 1077 1078 1079 1080 1081 1082 1083 1084 1085 1086 1087 1088


@register
@serializable
class YOLOCSPPAN(nn.Layer):
    """
    CSP-PAN neck in the style of YOLOv5 / YOLOX.

    A top-down pass (1x1 lateral BaseConv, nearest 2x upsample, concat,
    CSPLayer) is followed by a bottom-up pass (stride-2 3x3 conv, concat,
    CSPLayer). Output channels equal `in_channels`.

    Args:
        depth_mult (float): each CSPLayer is built with round(3 * depth_mult)
            as its third positional argument.
        in_channels (list): channels of the input maps. Because level k is
            upsampled 2x and concatenated with level k - 1, consecutive
            entries must differ 2x in spatial size (earlier = larger).
        depthwise (bool): use DWConv for the bottom-up downsampling convs;
            also forwarded to every CSPLayer.
        act (str): activation name passed to every conv and CSPLayer.
    """
    __shared__ = ['depth_mult', 'act']

    def __init__(self,
                 depth_mult=1.0,
                 in_channels=[256, 512, 1024],
                 depthwise=False,
                 act='silu'):
        super(YOLOCSPPAN, self).__init__()
        self.in_channels = in_channels
        self._out_channels = in_channels
        # only the bottom-up downsampling convs switch to depthwise
        down_conv = DWConv if depthwise else BaseConv
        csp_depth = round(3 * depth_mult)

        self.upsample = nn.Upsample(scale_factor=2, mode="nearest")

        # top-down: deepest level first, walking towards the shallowest
        self.lateral_convs = nn.LayerList()
        self.fpn_blocks = nn.LayerList()
        for level in reversed(range(1, len(in_channels))):
            shallow_ch = int(in_channels[level - 1])
            self.lateral_convs.append(
                BaseConv(int(in_channels[level]), shallow_ch, 1, 1, act=act))
            self.fpn_blocks.append(
                CSPLayer(
                    int(in_channels[level - 1] * 2),
                    shallow_ch,
                    csp_depth,
                    shortcut=False,
                    depthwise=depthwise,
                    act=act))

        # bottom-up: shallowest level first, walking towards the deepest
        self.downsample_convs = nn.LayerList()
        self.pan_blocks = nn.LayerList()
        for level in range(len(in_channels) - 1):
            level_ch = int(in_channels[level])
            self.downsample_convs.append(
                down_conv(level_ch, level_ch, 3, stride=2, act=act))
            self.pan_blocks.append(
                CSPLayer(
                    int(in_channels[level] * 2),
                    int(in_channels[level + 1]),
                    csp_depth,
                    shortcut=False,
                    depthwise=depthwise,
                    act=act))

    def forward(self, feats, for_mot=False):
        """
        Args:
            feats (list): one feature map per entry of in_channels; tensors
                are concatenated on axis 1.
            for_mot (bool): accepted for call compatibility; unused here.

        Returns:
            list of fused feature maps, in the same level order as `feats`.
        """
        num_levels = len(self.in_channels)
        assert len(feats) == num_levels

        # top-down: slot 0 holds the most recently fused (finest so far) map;
        # it is overwritten by its lateral projection before the next level
        # is prepended.
        inner_outs = [feats[-1]]
        for k, level in enumerate(reversed(range(1, num_levels))):
            projected = self.lateral_convs[k](inner_outs[0])
            inner_outs[0] = projected
            fused = paddle.concat(
                [self.upsample(projected), feats[level - 1]], axis=1)
            inner_outs.insert(0, self.fpn_blocks[k](fused))

        # bottom-up: repeatedly downsample the last output and fuse it with
        # the next coarser top-down map
        outs = [inner_outs[0]]
        for level in range(num_levels - 1):
            reduced = self.downsample_convs[level](outs[-1])
            outs.append(self.pan_blocks[level](paddle.concat(
                [reduced, inner_outs[level + 1]], axis=1)))

        return outs

    @classmethod
    def from_config(cls, cfg, input_shape):
        """Constructor kwargs taken from the channel counts of `input_shape`."""
        return dict(in_channels=[spec.channels for spec in input_shape])

    @property
    def out_shape(self):
        """One ShapeSpec per output level, carrying only the channel count."""
        return list(map(lambda ch: ShapeSpec(channels=ch), self._out_channels))