darknet.py 5.7 KB
Newer Older
W
wangxinxin08 已提交
1 2 3 4
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from paddle import ParamAttr
5
from paddle.regularizer import L2Decay
6
from ppdet.core.workspace import register, serializable
7
from ppdet.modeling.ops import BatchNorm
8 9 10 11

# Public names exported by `from <this module> import *`.
__all__ = [
    'DarkNet',
    'ConvBNLayer',
]


W
wangxinxin08 已提交
12
class ConvBNLayer(nn.Layer):
    """Bias-free Conv2D, then BatchNorm, then an optional LeakyReLU.

    Args:
        ch_in (int): number of input channels.
        ch_out (int): number of output channels.
        filter_size (int): convolution kernel size.
        stride (int): convolution stride.
        groups (int): convolution groups.
        padding (int): convolution padding.
        norm_type (str): forwarded unchanged to ``BatchNorm``.
        act (str): ``'leaky'`` applies ``F.leaky_relu`` with slope 0.1; any
            other value returns the normalized tensor with no activation.
        name (str): prefix for parameter names. It must be a string: the
            conv weight is named ``name + '.conv.weights'``, so the default
            ``None`` raises ``TypeError``.
    """

    def __init__(self,
                 ch_in,
                 ch_out,
                 filter_size=3,
                 stride=1,
                 groups=1,
                 padding=0,
                 norm_type='bn',
                 act="leaky",
                 name=None):
        super(ConvBNLayer, self).__init__()

        conv_weight_attr = ParamAttr(name=name + '.conv.weights')
        # Attribute names `conv` / `batch_norm` (and this registration order)
        # determine the state_dict keys, so they must not change.
        self.conv = nn.Conv2D(
            in_channels=ch_in,
            out_channels=ch_out,
            kernel_size=filter_size,
            stride=stride,
            padding=padding,
            groups=groups,
            weight_attr=conv_weight_attr,
            # No conv bias; presumably the following BatchNorm supplies the
            # shift term.
            bias_attr=False)
        self.batch_norm = BatchNorm(ch_out, norm_type=norm_type, name=name)
        self.act = act

    def forward(self, inputs):
        normed = self.batch_norm(self.conv(inputs))
        if self.act != 'leaky':
            return normed
        return F.leaky_relu(normed, 0.1)


W
wangxinxin08 已提交
45
class DownSample(nn.Layer):
    """A single ``ConvBNLayer`` with stride 2 and padding 1 by default.

    Args:
        ch_in (int): number of input channels.
        ch_out (int): number of output channels.
        filter_size (int): convolution kernel size.
        stride (int): convolution stride (2 by default).
        padding (int): convolution padding.
        norm_type (str): forwarded to the inner ``ConvBNLayer``.
        name (str): parameter-name prefix forwarded to the inner layer.
    """

    def __init__(self,
                 ch_in,
                 ch_out,
                 filter_size=3,
                 stride=2,
                 padding=1,
                 norm_type='bn',
                 name=None):

        super(DownSample, self).__init__()

        self.conv_bn_layer = ConvBNLayer(
            ch_in,
            ch_out,
            filter_size=filter_size,
            stride=stride,
            padding=padding,
            norm_type=norm_type,
            name=name)
        # Number of channels this layer emits.
        self.ch_out = ch_out

    def forward(self, inputs):
        return self.conv_bn_layer(inputs)


W
wangxinxin08 已提交
72
class BasicBlock(nn.Layer):
    """Residual unit with an identity shortcut.

    The path is 1x1 ConvBNLayer (``ch_in -> ch_out``), then 3x3 ConvBNLayer
    (``ch_out -> 2 * ch_out``). The result is added element-wise to
    ``inputs``, so ``ch_in`` is expected to equal ``2 * ch_out``.

    Args:
        ch_in (int): number of input channels.
        ch_out (int): width of the 1x1 bottleneck.
        norm_type (str): forwarded to both ConvBNLayers.
        name (str): parameter-name prefix; the convs use ``name + '.0'`` and
            ``name + '.1'``, so ``None`` raises ``TypeError``.
    """

    def __init__(self, ch_in, ch_out, norm_type='bn', name=None):
        super(BasicBlock, self).__init__()

        # 1x1 reduction to ch_out channels.
        self.conv1 = ConvBNLayer(
            ch_in, ch_out, filter_size=1, stride=1, padding=0,
            norm_type=norm_type, name=name + '.0')
        # 3x3 expansion back to 2 * ch_out channels.
        self.conv2 = ConvBNLayer(
            ch_out, ch_out * 2, filter_size=3, stride=1, padding=1,
            norm_type=norm_type, name=name + '.1')

    def forward(self, inputs):
        residual = self.conv2(self.conv1(inputs))
        return paddle.add(x=inputs, y=residual)


W
wangxinxin08 已提交
100
class Blocks(nn.Layer):
    """A sequential stack of ``count`` BasicBlocks forming one stage.

    The first block takes ``ch_in`` channels. The remaining blocks take the
    ``2 * ch_out`` channels emitted by their predecessor. One block is always
    built, even when ``count <= 1``.

    Args:
        ch_in (int): number of channels entering the stage.
        ch_out (int): bottleneck width; the stage outputs ``2 * ch_out``.
        count (int): number of BasicBlocks.
        norm_type (str): forwarded to every BasicBlock.
        name (str): prefix; block ``k`` is named ``'{name}.{k}'``.
    """

    def __init__(self, ch_in, ch_out, count, norm_type='bn', name=None):
        super(Blocks, self).__init__()

        self.basicblock0 = BasicBlock(
            ch_in, ch_out, norm_type=norm_type, name=name + '.0')
        self.res_out_list = []
        for idx in range(1, count):
            sub_name = '{}.{}'.format(name, idx)
            block = BasicBlock(
                ch_out * 2, ch_out, norm_type=norm_type, name=sub_name)
            # add_sublayer registers the block so its parameters are tracked;
            # the plain list only keeps the call order for forward().
            self.res_out_list.append(self.add_sublayer(sub_name, block))
        self.ch_out = ch_out

    def forward(self, inputs):
        out = self.basicblock0(inputs)
        for block in self.res_out_list:
            out = block(out)
        return out


# depth -> number of BasicBlocks in each successive stage.
DarkNet_cfg = {53: [1, 2, 8, 8, 4]}


@register
@serializable
class DarkNet(nn.Layer):
    """DarkNet backbone returning the feature maps of selected stages.

    Layout: a 3x3 stem ConvBNLayer (3 -> 32 channels), a stride-2 DownSample
    (32 -> 64), then the residual stages from ``DarkNet_cfg[depth]``. Stage
    ``i`` takes ``32 * 2**(i + 1)`` channels and emits the same number.
    Consecutive stages are joined by a stride-2 DownSample that doubles the
    channel count.

    Args:
        depth (int): key into ``DarkNet_cfg``; only 53 is defined here.
        freeze_at (int): index of the stage whose output is marked
            ``stop_gradient = True`` (presumably to freeze the stem and the
            stages up to it). The default -1 matches no stage.
        return_idx (list[int]): indices of the stages whose outputs
            ``forward`` returns.
        num_stages (int): how many leading entries of the depth config to
            build.
        norm_type (str): forwarded to every ConvBNLayer. It is declared in
            ``__shared__`` for the ppdet config system.
    """
    __shared__ = ['norm_type']

    def __init__(self,
                 depth=53,
                 freeze_at=-1,
                 return_idx=[2, 3, 4],
                 num_stages=5,
                 norm_type='bn'):
        super(DarkNet, self).__init__()
        self.depth = depth
        self.freeze_at = freeze_at
        # Copy so that no instance aliases (or can mutate) the single list
        # object that serves as the default argument. The default stays a
        # list literal to keep the signature unchanged.
        self.return_idx = list(return_idx)
        self.num_stages = num_stages
        # Stored like the other constructor arguments.
        self.norm_type = norm_type
        self.stages = DarkNet_cfg[self.depth][0:num_stages]

        self.conv0 = ConvBNLayer(
            ch_in=3,
            ch_out=32,
            filter_size=3,
            stride=1,
            padding=1,
            norm_type=norm_type,
            name='yolo_input')

        self.downsample0 = DownSample(
            ch_in=32,
            ch_out=32 * 2,
            norm_type=norm_type,
            name='yolo_input.downsample')

        self.darknet_conv_block_list = []
        self.downsample_list = []
        for i, stage in enumerate(self.stages):
            name = 'stage.{}'.format(i)
            # Stage i consumes 32 * 2**(i + 1) channels (64, 128, 256, ...):
            # exactly what the preceding DownSample emits.
            conv_block = self.add_sublayer(
                name,
                Blocks(
                    32 * (2**(i + 1)),
                    32 * (2**i),
                    stage,
                    norm_type=norm_type,
                    name=name))
            self.darknet_conv_block_list.append(conv_block)

        # One stride-2 transition between each pair of consecutive stages.
        # The count is capped by the stages actually built. Previously,
        # num_stages > len(config) registered an extra transition after the
        # last stage whose output was computed and then discarded.
        num_transitions = min(num_stages, len(self.stages)) - 1
        for i in range(num_transitions):
            down_name = 'stage.{}.downsample'.format(i)
            downsample = self.add_sublayer(
                down_name,
                DownSample(
                    ch_in=32 * (2**(i + 1)),
                    ch_out=32 * (2**(i + 2)),
                    norm_type=norm_type,
                    name=down_name))
            self.downsample_list.append(downsample)

    def forward(self, inputs):
        """Run the backbone on ``inputs['image']``.

        Returns:
            list: the outputs of the stages whose index is in ``return_idx``,
            in stage order (not in ``return_idx`` order).
        """
        x = inputs['image']

        out = self.conv0(x)
        out = self.downsample0(out)
        blocks = []
        for i, conv_block_i in enumerate(self.darknet_conv_block_list):
            out = conv_block_i(out)
            if i == self.freeze_at:
                out.stop_gradient = True
            if i in self.return_idx:
                blocks.append(out)
            # Only transition into a following stage that actually exists.
            if i < len(self.downsample_list):
                out = self.downsample_list[i](out)
        return blocks