darknet.py 10.9 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

Q
qingqing01 已提交
15 16 17 18 19 20
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from paddle import ParamAttr
from paddle.regularizer import L2Decay
from ppdet.core.workspace import register, serializable
W
wangxinxin08 已提交
21
from ppdet.modeling.ops import batch_norm, mish
22
from ..shape_spec import ShapeSpec
Q
qingqing01 已提交
23 24 25 26 27 28 29 30 31 32 33 34 35

# Public names exported by this module via `from ... import *`.
__all__ = ['DarkNet', 'ConvBNLayer']


class ConvBNLayer(nn.Layer):
    def __init__(self,
                 ch_in,
                 ch_out,
                 filter_size=3,
                 stride=1,
                 groups=1,
                 padding=0,
                 norm_type='bn',
                 norm_decay=0.,
                 act="leaky",
                 freeze_norm=False,
                 data_format='NCHW',
                 name=''):
        """
        Convolution followed by a normalization layer and an optional activation.

        The convolution is created without a bias term (bias_attr=False); the
        normalization layer is built by the `batch_norm` helper.

        Args:
            ch_in (int): number of input channels
            ch_out (int): number of output channels
            filter_size (int): convolution kernel size, default 3
            stride (int): convolution stride, default 1
            groups (int): number of convolution groups, default 1
            padding (int): convolution padding, default 0
            norm_type (str): normalization type forwarded to `batch_norm`,
                default 'bn'
            norm_decay (float): decay for weight and bias of the norm layer,
                forwarded to `batch_norm`, default 0.
            act (str): 'leaky' applies leaky_relu with slope 0.1, 'mish'
                applies mish; any other value (e.g. None) applies no
                activation. Default 'leaky'
            freeze_norm (bool): forwarded to `batch_norm`, default False
            data_format (str): data format, NCHW or NHWC
            name (str): unused; kept for interface compatibility
        """
        super().__init__()

        self.conv = nn.Conv2D(
            in_channels=ch_in,
            out_channels=ch_out,
            kernel_size=filter_size,
            stride=stride,
            padding=padding,
            groups=groups,
            data_format=data_format,
            bias_attr=False)
        self.batch_norm = batch_norm(
            ch_out,
            norm_type=norm_type,
            norm_decay=norm_decay,
            freeze_norm=freeze_norm,
            data_format=data_format)
        self.act = act

    def forward(self, inputs):
        """Apply conv -> norm -> activation to `inputs`."""
        normed = self.batch_norm(self.conv(inputs))
        if self.act == 'leaky':
            return F.leaky_relu(normed, negative_slope=0.1)
        if self.act == 'mish':
            return mish(normed)
        # Unrecognized activation name: return the normalized tensor as-is.
        return normed


class DownSample(nn.Layer):
    def __init__(self,
                 ch_in,
                 ch_out,
                 filter_size=3,
                 stride=2,
                 padding=1,
                 norm_type='bn',
                 norm_decay=0.,
                 freeze_norm=False,
                 data_format='NCHW'):
        """
        Downsampling layer: a single strided ConvBNLayer.

        No activation is passed, so ConvBNLayer's default ('leaky') is used.

        Args:
            ch_in (int): number of input channels
            ch_out (int): number of output channels
            filter_size (int): convolution kernel size, default 3
            stride (int): convolution stride, default 2
            padding (int): convolution padding, default 1
            norm_type (str): normalization type, default 'bn'
            norm_decay (float): decay for weight and bias of the norm layer,
                default 0.
            freeze_norm (bool): whether to freeze norm, default False
            data_format (str): data format, NCHW or NHWC
        """
        super().__init__()

        self.conv_bn_layer = ConvBNLayer(
            ch_in=ch_in,
            ch_out=ch_out,
            filter_size=filter_size,
            stride=stride,
            padding=padding,
            norm_type=norm_type,
            norm_decay=norm_decay,
            freeze_norm=freeze_norm,
            data_format=data_format)
        # Output channel count, exposed for callers.
        self.ch_out = ch_out

    def forward(self, inputs):
        """Run the strided conv-bn-activation on `inputs`."""
        return self.conv_bn_layer(inputs)


class BasicBlock(nn.Layer):
    def __init__(self,
                 ch_in,
                 ch_out,
                 norm_type='bn',
                 norm_decay=0.,
                 freeze_norm=False,
                 data_format='NCHW'):
        """
        Residual unit of DarkNet.

        Computes a 1x1 ConvBNLayer (ch_in -> ch_out), then a 3x3 ConvBNLayer
        (ch_out -> 2 * ch_out), and adds the block input to the result. For
        the residual sum the input is expected to carry 2 * ch_out channels.

        Args:
            ch_in (int): number of input channels
            ch_out (int): channels of the 1x1 bottleneck conv; the block
                outputs 2 * ch_out channels
            norm_type (str): normalization type, default 'bn'
            norm_decay (float): decay for weight and bias of the norm layers,
                default 0.
            freeze_norm (bool): whether to freeze norm, default False
            data_format (str): data format, NCHW or NHWC
        """
        super().__init__()

        # Settings shared by both convolutions of the block.
        common = dict(
            norm_type=norm_type,
            norm_decay=norm_decay,
            freeze_norm=freeze_norm,
            data_format=data_format)

        self.conv1 = ConvBNLayer(
            ch_in=ch_in, ch_out=ch_out, filter_size=1, stride=1, padding=0,
            **common)
        self.conv2 = ConvBNLayer(
            ch_in=ch_out, ch_out=ch_out * 2, filter_size=3, stride=1,
            padding=1, **common)

    def forward(self, inputs):
        """Return `inputs + conv2(conv1(inputs))`."""
        branch = self.conv2(self.conv1(inputs))
        return paddle.add(x=inputs, y=branch)


class Blocks(nn.Layer):
    def __init__(self,
                 ch_in,
                 ch_out,
                 count,
                 norm_type='bn',
                 norm_decay=0.,
                 freeze_norm=False,
                 name=None,
                 data_format='NCHW'):
        """
        A stage of stacked BasicBlock layers.

        The first block (`basicblock0`) takes `ch_in` channels; each of the
        remaining `count - 1` blocks takes `2 * ch_out` channels and is
        registered as a sublayer named '<name>.<index>' (index starting at 1).
        At least one block is always built, even if `count` < 1.

        Args:
            ch_in (int): number of input channels
            ch_out (int): bottleneck channels of each BasicBlock; the stage
                outputs 2 * ch_out channels
            count (int): number of BasicBlock layers
            norm_type (str): normalization type, default 'bn'
            norm_decay (float): decay for weight and bias of the norm layers,
                default 0.
            freeze_norm (bool): whether to freeze norm, default False
            name (str): prefix of the sublayer names; when None the sublayers
                are named 'None.<index>'
            data_format (str): data format, NCHW or NHWC
        """
        super().__init__()

        block_cfg = dict(
            norm_type=norm_type,
            norm_decay=norm_decay,
            freeze_norm=freeze_norm,
            data_format=data_format)

        self.basicblock0 = BasicBlock(ch_in, ch_out, **block_cfg)
        # Registered in the same order and under the same names as before, so
        # state_dict keys are unchanged.
        self.res_out_list = [
            self.add_sublayer('{}.{}'.format(name, idx),
                              BasicBlock(ch_out * 2, ch_out, **block_cfg))
            for idx in range(1, count)
        ]
        self.ch_out = ch_out

    def forward(self, inputs):
        """Run `inputs` through every BasicBlock of the stage in order."""
        feat = self.basicblock0(inputs)
        for block in self.res_out_list:
            feat = block(feat)
        return feat


# Number of BasicBlocks in each stage, keyed by network depth.
DarkNet_cfg = {53: [1, 2, 8, 8, 4]}


@register
@serializable
class DarkNet(nn.Layer):
    __shared__ = ['norm_type', 'data_format']

    def __init__(self,
                 depth=53,
                 freeze_at=-1,
                 return_idx=[2, 3, 4],
                 num_stages=5,
                 norm_type='bn',
                 norm_decay=0.,
                 freeze_norm=False,
                 data_format='NCHW'):
        """
        Darknet, see https://pjreddie.com/darknet/yolo/

        Args:
            depth (int): depth of network; must be a key of DarkNet_cfg
            freeze_at (int): index of the stage whose output gets
                stop_gradient = True, so no gradient flows back into that
                stage or anything before it; a negative value (default -1)
                disables freezing
            return_idx (list): indices of stages whose feature maps are
                returned, default [2, 3, 4]
            num_stages (int): number of stages to build, default 5; it is
                capped by the number of stages configured for `depth`
            norm_type (str): normalization type, default 'bn'
            norm_decay (float): decay for weight and bias of the norm layers,
                default 0.
            freeze_norm (bool): whether to freeze norm, default False
            data_format (str): data format, NCHW or NHWC
        """
        super(DarkNet, self).__init__()
        self.depth = depth
        self.freeze_at = freeze_at
        # Copy so the instance never aliases the shared default list.
        self.return_idx = list(return_idx)
        self.num_stages = num_stages
        # Slicing caps the stage count at what the depth config defines.
        self.stages = DarkNet_cfg[self.depth][0:num_stages]

        self.conv0 = ConvBNLayer(
            ch_in=3,
            ch_out=32,
            filter_size=3,
            stride=1,
            padding=1,
            norm_type=norm_type,
            norm_decay=norm_decay,
            freeze_norm=freeze_norm,
            data_format=data_format)

        self.downsample0 = DownSample(
            ch_in=32,
            ch_out=32 * 2,
            norm_type=norm_type,
            norm_decay=norm_decay,
            freeze_norm=freeze_norm,
            data_format=data_format)

        self._out_channels = []
        self.darknet_conv_block_list = []
        self.downsample_list = []
        for i, stage in enumerate(self.stages):
            name = 'stage.{}'.format(i)
            # Stage i takes 64 * 2**i channels and, since each BasicBlock
            # outputs twice its bottleneck width, also emits 64 * 2**i.
            conv_block = self.add_sublayer(
                name,
                Blocks(
                    64 * (2**i),
                    32 * (2**i),
                    stage,
                    norm_type=norm_type,
                    norm_decay=norm_decay,
                    freeze_norm=freeze_norm,
                    data_format=data_format,
                    name=name))
            self.darknet_conv_block_list.append(conv_block)
            if i in self.return_idx:
                self._out_channels.append(64 * (2**i))

        # One downsample between each pair of consecutive stages that were
        # actually built. Using num_stages here would create (and run) an
        # extra, unused downsample after the last stage when num_stages
        # exceeds the depth configuration.
        for i in range(len(self.stages) - 1):
            down_name = 'stage.{}.downsample'.format(i)
            downsample = self.add_sublayer(
                down_name,
                DownSample(
                    ch_in=32 * (2**(i + 1)),
                    ch_out=32 * (2**(i + 2)),
                    norm_type=norm_type,
                    norm_decay=norm_decay,
                    freeze_norm=freeze_norm,
                    data_format=data_format))
            self.downsample_list.append(downsample)

    def forward(self, inputs):
        """
        Args:
            inputs (dict): must contain the input tensor under key 'image'

        Returns:
            list: feature maps of the stages listed in return_idx, in stage
                order
        """
        x = inputs['image']

        out = self.conv0(x)
        out = self.downsample0(out)
        blocks = []
        for i, conv_block_i in enumerate(self.darknet_conv_block_list):
            out = conv_block_i(out)
            if i == self.freeze_at:
                # Block gradients into this stage and everything before it.
                out.stop_gradient = True
            if i in self.return_idx:
                blocks.append(out)
            # Downsample only between stages, never after the last one.
            if i < len(self.downsample_list):
                out = self.downsample_list[i](out)
        return blocks

    @property
    def out_shape(self):
        return [ShapeSpec(channels=c) for c in self._out_channels]