yolo_fpn.py 11.4 KB
Newer Older
Q
qingqing01 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from paddle import ParamAttr
from ppdet.core.workspace import register, serializable
20
from ..backbones.darknet import ConvBNLayer
W
wangxinxin08 已提交
21
import numpy as np
Q
qingqing01 已提交
22

23 24 25 26
from ..shape_spec import ShapeSpec

# Public neck classes exported by this module.
__all__ = ['YOLOv3FPN', 'PPYOLOFPN']

Q
qingqing01 已提交
27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69

class YoloDetBlock(nn.Layer):
    """YOLOv3 detection block.

    A stack of five ConvBNLayers alternating between 1x1 convs (``channel``
    outputs) and 3x3 convs (``channel * 2`` outputs), followed by a separate
    3x3 "tip" ConvBNLayer.

    Args:
        ch_in (int): number of input channels of the first conv.
        channel (int): width of the 1x1 convs; must be even.
        norm_type (str): normalization type forwarded to every ConvBNLayer.
        name (str): prefix used to build the sub-layer names.
    """

    def __init__(self, ch_in, channel, norm_type, name):
        super(YoloDetBlock, self).__init__()
        self.ch_in = ch_in
        self.channel = channel
        assert channel % 2 == 0, \
            "channel {} cannot be divided by 2".format(channel)

        wide = channel * 2
        # (sub-layer name, in channels, out channels, kernel size, name suffix)
        layer_specs = (
            ('conv0', ch_in, channel, 1, '.0.0'),
            ('conv1', channel, wide, 3, '.0.1'),
            ('conv2', wide, channel, 1, '.1.0'),
            ('conv3', channel, wide, 3, '.1.1'),
            ('route', wide, channel, 1, '.2'),
        )

        self.conv_module = nn.Sequential()
        for layer_name, c_in, c_out, k, suffix in layer_specs:
            conv = ConvBNLayer(
                ch_in=c_in,
                ch_out=c_out,
                filter_size=k,
                padding=(k - 1) // 2,
                norm_type=norm_type,
                name=name + suffix)
            self.conv_module.add_sublayer(layer_name, conv)

        self.tip = ConvBNLayer(
            ch_in=channel,
            ch_out=wide,
            filter_size=3,
            padding=1,
            norm_type=norm_type,
            name=name + '.tip')

    def forward(self, inputs):
        """Return ``(route, tip)``: the conv-stack output and the tip conv of it."""
        route = self.conv_module(inputs)
        return route, self.tip(route)


W
wangxinxin08 已提交
70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164
class SPP(nn.Layer):
    """Spatial pyramid pooling followed by a ConvBNLayer.

    The input is concatenated along axis 1 with stride-1 max-pooled copies of
    itself (one per entry of ``pool_size``); the result is fused by a
    ``k x k`` ConvBNLayer.

    Args:
        ch_in (int): channels of the concatenated tensor, i.e. the input
            channels times ``len(pool_size) + 1``.
        ch_out (int): output channels of the fusing conv.
        k (int): kernel size of the fusing conv (padding ``k // 2``).
        pool_size (list[int]): kernel sizes of the max-pool branches.
            NOTE: with stride 1 and padding ``size // 2`` the spatial size is
            preserved only for odd sizes; even sizes would break the concat.
        norm_type (str): normalization type forwarded to ConvBNLayer.
        name (str): prefix used for the sub-layer names.
    """

    def __init__(self, ch_in, ch_out, k, pool_size, norm_type, name):
        super(SPP, self).__init__()
        self.pool = []
        for i, size in enumerate(pool_size):
            # Each branch needs its own sub-layer name: add_sublayer stores
            # layers by name, so reusing one name would overwrite the
            # previously registered pool in the layer registry.
            pool = self.add_sublayer(
                '{}.pool{}'.format(name, i),
                nn.MaxPool2D(
                    kernel_size=size,
                    stride=1,
                    padding=size // 2,
                    ceil_mode=False))
            self.pool.append(pool)
        self.conv = ConvBNLayer(
            ch_in, ch_out, k, padding=k // 2, norm_type=norm_type, name=name)

    def forward(self, x):
        """Concatenate ``x`` with its pooled copies on axis 1, then fuse."""
        outs = [x]
        for pool in self.pool:
            outs.append(pool(x))
        y = paddle.concat(outs, axis=1)
        y = self.conv(y)
        return y


class DropBlock(nn.Layer):
    """DropBlock regularization: zero out square regions of a feature map.

    Only active in training mode; in eval mode, or when ``keep_prob == 1``,
    the input is returned unchanged.

    Args:
        block_size (int): side length of each dropped square.
        keep_prob (float): target probability of keeping an activation.
        name (str): layer name (stored only).
    """

    def __init__(self, block_size, keep_prob, name):
        super(DropBlock, self).__init__()
        self.block_size = block_size
        self.keep_prob = keep_prob
        self.name = name

    def forward(self, x):
        # Pass-through at inference time or when nothing would be dropped.
        if self.keep_prob == 1 or not self.training:
            return x

        size = self.block_size
        # Seed-drop probability, corrected by s / (s - size + 1) for every
        # dimension after the first two (the region where a block fits).
        gamma = (1. - self.keep_prob) / (size**2)
        for extent in x.shape[2:]:
            gamma *= extent / (extent - size + 1)

        seeds = paddle.cast(paddle.rand(x.shape, x.dtype) < gamma, x.dtype)
        # A stride-1 max pool grows every seed into a size x size square.
        dropped = F.max_pool2d(seeds, size, stride=1, padding=size // 2)
        mask = 1. - dropped
        # Scale by (total elements / kept elements) of the mask.
        return x * mask * (mask.numel() / mask.sum())


class CoordConv(nn.Layer):
    """ConvBNLayer whose input is augmented with two coordinate channels.

    Normalized x and y position maps in [-1, 1] are concatenated to the input
    along axis 1, so the wrapped ConvBNLayer takes ``ch_in + 2`` channels.

    Args:
        ch_in (int): channels of the input tensor (before augmentation).
        ch_out (int): output channels.
        filter_size (int): kernel size.
        padding (int): conv padding.
        norm_type (str): normalization type forwarded to ConvBNLayer.
        name (str): layer name prefix.
    """

    def __init__(self, ch_in, ch_out, filter_size, padding, norm_type, name):
        super(CoordConv, self).__init__()
        self.conv = ConvBNLayer(
            ch_in + 2,
            ch_out,
            filter_size=filter_size,
            padding=padding,
            norm_type=norm_type,
            name=name)

    def forward(self, x):
        # x is indexed as (batch, channels, height, width) below.
        b = x.shape[0]
        h = x.shape[2]
        w = x.shape[3]

        # Coordinates run from -1 (first column/row) to 1 (last). Clamping the
        # denominator to >= 1 avoids 0/0 = NaN for a 1-pixel-wide/high map
        # (a single column/row then maps to -1).
        gx = paddle.arange(w, dtype='float32') / max(w - 1., 1.) * 2.0 - 1.
        gx = gx.reshape([1, 1, 1, w]).expand([b, 1, h, w])
        # Match the input dtype so the concat below also works for
        # non-float32 inputs.
        gx = paddle.cast(gx, x.dtype)
        gx.stop_gradient = True

        gy = paddle.arange(h, dtype='float32') / max(h - 1., 1.) * 2.0 - 1.
        gy = gy.reshape([1, 1, h, 1]).expand([b, 1, h, w])
        gy = paddle.cast(gy, x.dtype)
        gy.stop_gradient = True

        y = paddle.concat([x, gx, gy], axis=1)
        y = self.conv(y)
        return y


class PPYOLODetBlock(nn.Layer):
    """Detection block assembled from a layer configuration list.

    Args:
        cfg (list): entries ``[conv_name, layer_cls, args, kwargs]``. All
            entries but the last are chained in ``self.conv_module`` (an
            ``nn.Sequential``); the last entry becomes ``self.tip``. Each
            layer is built as ``layer_cls(*args, **kwargs,
            name='{name}.{conv_name}')``.
        name (str): prefix for the per-layer names.
    """

    def __init__(self, cfg, name):
        super(PPYOLODetBlock, self).__init__()
        self.conv_module = nn.Sequential()
        for conv_name, layer, args, kwargs in cfg[:-1]:
            # Build a fresh kwargs dict rather than kwargs.update(...): the
            # caller's cfg dicts may be shared between blocks and must not be
            # mutated as a side effect.
            layer_kwargs = dict(kwargs, name='{}.{}'.format(name, conv_name))
            self.conv_module.add_sublayer(conv_name,
                                          layer(*args, **layer_kwargs))

        conv_name, layer, args, kwargs = cfg[-1]
        tip_kwargs = dict(kwargs, name='{}.{}'.format(name, conv_name))
        self.tip = layer(*args, **tip_kwargs)

    def forward(self, inputs):
        """Return ``(route, tip)``: the conv-stack output and the tip layer of it."""
        route = self.conv_module(inputs)
        tip = self.tip(route)
        return route, tip


Q
qingqing01 已提交
165 166 167 168 169
@register
@serializable
class YOLOv3FPN(nn.Layer):
    """YOLOv3 feature pyramid neck.

    ``forward`` walks the input feature maps in reverse order. Each level runs
    a YoloDetBlock. Except at the last level, its route output goes through
    a 1x1 transition ConvBNLayer, is upsampled 2x with ``F.interpolate`` and
    is concatenated (axis 1) in front of the next level's feature map.

    Args:
        in_channels (list[int]): channels of the input feature maps, ordered
            so that the last entry is the first level processed (the one
            whose route is upsampled into the preceding entry).
        norm_type (str): normalization type forwarded to every conv layer.
    """
    __shared__ = ['norm_type']

    def __init__(self, in_channels=[256, 512, 1024], norm_type='bn'):
        super(YOLOv3FPN, self).__init__()
        assert len(in_channels) > 0, "in_channels length should > 0"
        self.in_channels = in_channels
        self.num_blocks = len(in_channels)

        self._out_channels = []
        self.yolo_blocks = []
        self.routes = []
        for i in range(self.num_blocks):
            name = 'yolo_block.{}'.format(i)
            in_channel = in_channels[-i - 1]
            if i > 0:
                # plus the channels of the upsampled route from level i - 1,
                # whose transition conv outputs 256 // 2**(i - 1) channels
                in_channel += 512 // (2**i)
            yolo_block = self.add_sublayer(
                name,
                YoloDetBlock(
                    in_channel,
                    channel=512 // (2**i),
                    norm_type=norm_type,
                    name=name))
            self.yolo_blocks.append(yolo_block)
            # tip layer output channel doubled
            self._out_channels.append(1024 // (2**i))

            if i < self.num_blocks - 1:
                name = 'yolo_transition.{}'.format(i)
                route = self.add_sublayer(
                    name,
                    ConvBNLayer(
                        ch_in=512 // (2**i),
                        ch_out=256 // (2**i),
                        filter_size=1,
                        stride=1,
                        padding=0,
                        norm_type=norm_type,
                        name=name))
                self.routes.append(route)

    def forward(self, blocks):
        """Return the list of tip outputs, one per level, in processing order.

        Args:
            blocks (list[Tensor]): feature maps matching ``in_channels``.
        """
        assert len(blocks) == self.num_blocks, \
            "expected {} feature maps, got {}".format(self.num_blocks,
                                                      len(blocks))
        blocks = blocks[::-1]
        yolo_feats = []
        for i, block in enumerate(blocks):
            if i > 0:
                # `route` was produced by the previous iteration
                block = paddle.concat([route, block], axis=1)
            route, tip = self.yolo_blocks[i](block)
            yolo_feats.append(tip)

            if i < self.num_blocks - 1:
                route = self.routes[i](route)
                route = F.interpolate(route, scale_factor=2.)

        return yolo_feats

    @classmethod
    def from_config(cls, cfg, input_shape):
        """Build constructor kwargs from the backbone's output shapes."""
        return {'in_channels': [i.channels for i in input_shape], }

    @property
    def out_shape(self):
        """One ShapeSpec per output level, carrying its channel count."""
        return [ShapeSpec(channels=c) for c in self._out_channels]

W
wangxinxin08 已提交
233 234 235 236 237 238

@register
@serializable
class PPYOLOFPN(nn.Layer):
    """PP-YOLO feature pyramid neck.

    Same top-down structure as YOLOv3FPN: levels are processed in reverse
    input order, and each route is passed through a 1x1 transition conv,
    upsampled 2x and concatenated in front of the next level's input.
    Optional tricks are read from ``**kwargs``:

    * ``coord_conv`` (default False): use CoordConv instead of ConvBNLayer
      for conv0, conv2, route and tip.
    * ``drop_block`` (default False): insert a DropBlock in every block,
      configured by ``block_size`` (default 3) and ``keep_prob``
      (default 0.9).
    * ``spp`` (default False): insert an SPP layer (pool sizes 5, 9, 13) in
      the first processed block.

    Args:
        in_channels (list[int]): channels of the input feature maps; the last
            entry is processed first.
        norm_type (str): normalization type forwarded to every conv layer.
    """
    __shared__ = ['norm_type']

    def __init__(self, in_channels=[512, 1024, 2048], norm_type='bn', **kwargs):
        super(PPYOLOFPN, self).__init__()
        assert len(in_channels) > 0, "in_channels length should > 0"
        self.in_channels = in_channels
        self.num_blocks = len(in_channels)
        # parse kwargs
        self.coord_conv = kwargs.get('coord_conv', False)
        self.drop_block = kwargs.get('drop_block', False)
        if self.drop_block:
            self.block_size = kwargs.get('block_size', 3)
            self.keep_prob = kwargs.get('keep_prob', 0.9)

        self.spp = kwargs.get('spp', False)
        if self.coord_conv:
            ConvLayer = CoordConv
        else:
            ConvLayer = ConvBNLayer

        self._out_channels = []
        self.yolo_blocks = []
        self.routes = []
        # Channels of the upsampled route concatenated onto the next level's
        # input. Taken from the previous transition conv's output so it is
        # correct for any number of levels (a hard-coded 512 // 2**i only
        # matches when num_blocks == 3).
        route_channel = 0
        for i, ch_in in enumerate(self.in_channels[::-1]):
            if i > 0:
                ch_in += route_channel
            channel = 64 * (2**self.num_blocks) // (2**i)
            base_cfg = [
                # name of layer, Layer, args
                ['conv0', ConvLayer, [ch_in, channel, 1]],
                ['conv1', ConvBNLayer, [channel, channel * 2, 3]],
                ['conv2', ConvLayer, [channel * 2, channel, 1]],
                ['conv3', ConvBNLayer, [channel, channel * 2, 3]],
                ['route', ConvLayer, [channel * 2, channel, 1]],
                ['tip', ConvLayer, [channel, channel * 2, 3]]
            ]
            for conf in base_cfg:
                # the filter size is the last positional arg
                filter_size = conf[-1][-1]
                conf.append(dict(padding=filter_size // 2, norm_type=norm_type))

            # Built per level so no kwargs dict is shared between blocks.
            if self.drop_block:
                dropblock_cfg = [[
                    'dropblock', DropBlock, [self.block_size, self.keep_prob],
                    dict()
                ]]
            else:
                dropblock_cfg = []

            if i == 0:
                if self.spp:
                    pool_size = [5, 9, 13]
                    spp_cfg = [[
                        'spp', SPP,
                        [channel * (len(pool_size) + 1), channel, 1], dict(
                            pool_size=pool_size, norm_type=norm_type)
                    ]]
                else:
                    spp_cfg = []
                cfg = base_cfg[0:3] + spp_cfg + base_cfg[
                    3:4] + dropblock_cfg + base_cfg[4:6]
            else:
                cfg = base_cfg[0:2] + dropblock_cfg + base_cfg[2:6]
            name = 'yolo_block.{}'.format(i)
            yolo_block = self.add_sublayer(name, PPYOLODetBlock(cfg, name))
            self.yolo_blocks.append(yolo_block)
            # tip layer doubles the block width
            self._out_channels.append(channel * 2)

            if i < self.num_blocks - 1:
                name = 'yolo_transition.{}'.format(i)
                route = self.add_sublayer(
                    name,
                    ConvBNLayer(
                        ch_in=channel,
                        ch_out=channel // 2,
                        filter_size=1,
                        stride=1,
                        padding=0,
                        norm_type=norm_type,
                        name=name))
                self.routes.append(route)
                route_channel = channel // 2

    def forward(self, blocks):
        """Return the list of tip outputs, one per level, in processing order.

        Args:
            blocks (list[Tensor]): feature maps matching ``in_channels``.
        """
        assert len(blocks) == self.num_blocks, \
            "expected {} feature maps, got {}".format(self.num_blocks,
                                                      len(blocks))
        blocks = blocks[::-1]
        yolo_feats = []
        for i, block in enumerate(blocks):
            if i > 0:
                # `route` was produced by the previous iteration
                block = paddle.concat([route, block], axis=1)
            route, tip = self.yolo_blocks[i](block)
            yolo_feats.append(tip)

            if i < self.num_blocks - 1:
                route = self.routes[i](route)
                route = F.interpolate(route, scale_factor=2.)

        return yolo_feats

    @classmethod
    def from_config(cls, cfg, input_shape):
        """Build constructor kwargs from the backbone's output shapes."""
        return {'in_channels': [i.channels for i in input_shape], }

    @property
    def out_shape(self):
        """One ShapeSpec per output level, carrying its channel count."""
        return [ShapeSpec(channels=c) for c in self._out_channels]