vgg.py 4.4 KB
Newer Older
W
WuHaobo 已提交
1
import paddle
2 3 4 5 6
from paddle import ParamAttr
import paddle.nn as nn
import paddle.nn.functional as F
from paddle.nn import Conv2d, BatchNorm, Linear, Dropout
from paddle.nn import AdaptiveAvgPool2d, MaxPool2d, AvgPool2d
7 8 9

__all__ = ["VGG11", "VGG13", "VGG16", "VGG19"]

10 11 12

class ConvBlock(nn.Layer):
    """One VGG stage: stacked 3x3 conv + ReLU pairs followed by 2x2 max pooling.

    Args:
        input_channels: Number of channels of the incoming feature map.
        output_channels: Number of channels produced by every convolution
            in the block.
        groups: How many conv + ReLU pairs to stack. 2, 3 and 4 are honoured;
            any other value results in a single convolution.
        name: Prefix for the convolution weight names; the i-th convolution
            is named ``name + "<i>_weights"``. When ``None`` no explicit
            name is passed to ``ParamAttr``.
    """

    def __init__(self, input_channels, output_channels, groups, name=None):
        super(ConvBlock, self).__init__()

        self.groups = groups

        # Only the first convolution changes the channel count; the rest map
        # output_channels -> output_channels.
        in_channels = input_channels
        for index in range(1, self._conv_count(groups) + 1):
            # Attribute names (_conv_1 .. _conv_4) are kept so that the
            # sublayer names of the block do not change.
            setattr(self, "_conv_{}".format(index), Conv2d(
                in_channels=in_channels,
                out_channels=output_channels,
                kernel_size=3,
                stride=1,
                padding=1,
                weight_attr=ParamAttr(name=self._weight_name(name, index)),
                bias_attr=False))
            in_channels = output_channels

        self._pool = MaxPool2d(kernel_size=2, stride=2, padding=0)

    @staticmethod
    def _conv_count(groups):
        """Return the number of convolutions to stack for ``groups``."""
        # Anything other than 2, 3 or 4 yields exactly one convolution.
        return int(groups) if groups in (2, 3, 4) else 1

    @staticmethod
    def _weight_name(name, index):
        """Return the weight name of the ``index``-th conv, or None if unnamed."""
        # The signature defaults ``name`` to None; concatenating it with a
        # string would raise TypeError, so pass None through instead.
        if name is None:
            return None
        return "{}{}_weights".format(name, index)

    def forward(self, inputs):
        """Apply every conv + ReLU pair in order, then the max pooling."""
        x = inputs
        for index in range(1, self._conv_count(self.groups) + 1):
            x = getattr(self, "_conv_{}".format(index))(x)
            x = F.relu(x)
        return self._pool(x)

69 70

class VGGNet(nn.Layer):
    """VGG classifier: five ConvBlock stages followed by three Linear layers.

    Args:
        layers: Network depth; one of 11, 13, 16 or 19. It selects how many
            convolutions each of the five ConvBlock stages stacks.
        class_dim: Size of the final Linear layer's output.

    Raises:
        AssertionError: If ``layers`` is not a supported depth.
    """

    def __init__(self, layers=11, class_dim=1000):
        super(VGGNet, self).__init__()

        self.layers = layers
        # Depth -> number of convolutions in each of the five stages.
        self.vgg_configure = {
            11: [1, 1, 2, 2, 2],
            13: [2, 2, 2, 2, 2],
            16: [2, 2, 3, 3, 3],
            19: [2, 2, 4, 4, 4]
        }
        # The message must read the dict through ``self``; a bare
        # ``vgg_configure`` would raise NameError instead of AssertionError.
        assert self.layers in self.vgg_configure.keys(), \
            "supported layers are {} but input layer is {}".format(
                self.vgg_configure.keys(), layers)
        self.groups = self.vgg_configure[self.layers]

        self._conv_block_1 = ConvBlock(3, 64, self.groups[0], name="conv1_")
        self._conv_block_2 = ConvBlock(64, 128, self.groups[1], name="conv2_")
        self._conv_block_3 = ConvBlock(128, 256, self.groups[2], name="conv3_")
        self._conv_block_4 = ConvBlock(256, 512, self.groups[3], name="conv4_")
        self._conv_block_5 = ConvBlock(512, 512, self.groups[4], name="conv5_")

        # A single Dropout layer is shared by both hidden Linear layers.
        self._drop = Dropout(p=0.5)
        # The flattened feature size is fixed at 7 * 7 * 512. Each of the five
        # stages ends in a stride-2 pooling, so this matches 224x224 inputs.
        self._fc1 = Linear(
            7 * 7 * 512,
            4096,
            weight_attr=ParamAttr(name="fc6_weights"),
            bias_attr=ParamAttr(name="fc6_offset"))
        self._fc2 = Linear(
            4096,
            4096,
            weight_attr=ParamAttr(name="fc7_weights"),
            bias_attr=ParamAttr(name="fc7_offset"))
        self._out = Linear(
            4096,
            class_dim,
            weight_attr=ParamAttr(name="fc8_weights"),
            bias_attr=ParamAttr(name="fc8_offset"))

    def forward(self, inputs):
        """Run the five conv stages and the classifier head.

        Returns the output of the last Linear layer; no activation is
        applied to it.
        """
        x = self._conv_block_1(inputs)
        x = self._conv_block_2(x)
        x = self._conv_block_3(x)
        x = self._conv_block_4(x)
        x = self._conv_block_5(x)

        # Flatten everything after the first dimension for the Linear layers.
        x = paddle.reshape(x, [0, -1])
        x = self._fc1(x)
        x = F.relu(x)
        x = self._drop(x)
        x = self._fc2(x)
        x = F.relu(x)
        x = self._drop(x)
        x = self._out(x)
        return x
W
WuHaobo 已提交
125

126

W
wqz960 已提交
127 128
def VGG11(**args):
    """Build a VGGNet with ``layers=11``; keyword arguments go to VGGNet."""
    return VGGNet(layers=11, **args)

W
WuHaobo 已提交
131

W
wqz960 已提交
132 133
def VGG13(**args):
    """Build a VGGNet with ``layers=13``; keyword arguments go to VGGNet."""
    return VGGNet(layers=13, **args)

136

W
wqz960 已提交
137 138
def VGG16(**args):
    """Build a VGGNet with ``layers=16``; keyword arguments go to VGGNet."""
    return VGGNet(layers=16, **args)

W
WuHaobo 已提交
141

W
wqz960 已提交
142 143
def VGG19(**args):
    """Build a VGGNet with ``layers=19``; keyword arguments go to VGGNet."""
    return VGGNet(layers=19, **args)