# vgg.py
import paddle
from paddle import ParamAttr
import paddle.nn as nn
import paddle.nn.functional as F
from paddle.nn import Conv2D, BatchNorm, Linear, Dropout
from paddle.nn import AdaptiveAvgPool2D, MaxPool2D, AvgPool2D

__all__ = ["VGG11", "VGG13", "VGG16", "VGG19"]


class ConvBlock(nn.Layer):
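    """One VGG stage: `groups` (1 to 4) 3x3 conv + ReLU layers at a fixed
    width, followed by a 2x2 max pool that halves the spatial size."""
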
    def __init__(self, input_channels, output_channels, groups, name=None):
        super(ConvBlock, self).__init__()

        self.groups = groups
        self._conv_1 = Conv2D(
            in_channels=input_channels,
            out_channels=output_channels,
            kernel_size=3,
            stride=1,
            padding=1,
            weight_attr=ParamAttr(name=name + "1_weights"),
            bias_attr=False)
        if groups in (2, 3, 4):
            self._conv_2 = Conv2D(
                in_channels=output_channels,
                out_channels=output_channels,
                kernel_size=3,
                stride=1,
                padding=1,
                weight_attr=ParamAttr(name=name + "2_weights"),
                bias_attr=False)
        if groups in (3, 4):
            self._conv_3 = Conv2D(
                in_channels=output_channels,
                out_channels=output_channels,
                kernel_size=3,
                stride=1,
                padding=1,
                weight_attr=ParamAttr(name=name + "3_weights"),
                bias_attr=False)
        if groups == 4:
            self._conv_4 = Conv2D(
                in_channels=output_channels,
                out_channels=output_channels,
                kernel_size=3,
                stride=1,
                padding=1,
                weight_attr=ParamAttr(name=name + "4_weights"),
                bias_attr=False)

        self._pool = MaxPool2D(kernel_size=2, stride=2, padding=0)

    def forward(self, inputs):
        x = self._conv_1(inputs)
        x = F.relu(x)
        if self.groups in (2, 3, 4):
            x = self._conv_2(x)
            x = F.relu(x)
        if self.groups in (3, 4):
            x = self._conv_3(x)
            x = F.relu(x)
        if self.groups == 4:
            x = self._conv_4(x)
            x = F.relu(x)
        x = self._pool(x)
        return x


class VGGNet(nn.Layer):
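    """VGG backbone of configurable depth (11/13/16/19) with the classic
    4096-4096-class_dim fully connected head. The first FC layer assumes
    7x7x512 pooled features, i.e. 224x224 inputs."""
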
    def __init__(self, layers=11, stop_grad_layers=0, class_dim=1000):
        super(VGGNet, self).__init__()

        self.layers = layers
        self.stop_grad_layers = stop_grad_layers
        self.vgg_configure = {
            11: [1, 1, 2, 2, 2],
            13: [2, 2, 2, 2, 2],
            16: [2, 2, 3, 3, 3],
            19: [2, 2, 4, 4, 4]
        }
        assert self.layers in self.vgg_configure, \
            "supported layers are {}, but got {} layers".format(
                list(self.vgg_configure.keys()), layers)
        self.groups = self.vgg_configure[self.layers]

        self._conv_block_1 = ConvBlock(3, 64, self.groups[0], name="conv1_")
        self._conv_block_2 = ConvBlock(64, 128, self.groups[1], name="conv2_")
        self._conv_block_3 = ConvBlock(128, 256, self.groups[2], name="conv3_")
        self._conv_block_4 = ConvBlock(256, 512, self.groups[3], name="conv4_")
        self._conv_block_5 = ConvBlock(512, 512, self.groups[4], name="conv5_")

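        # Freeze the first `stop_grad_layers` conv blocks so their parameters
        # are excluded from gradient updates (useful for fine-tuning).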
        for idx, block in enumerate([
                self._conv_block_1, self._conv_block_2, self._conv_block_3,
                self._conv_block_4, self._conv_block_5
        ]):
            if self.stop_grad_layers >= idx + 1:
                for param in block.parameters():
                    param.trainable = False

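        # Classifier head: dropout (downscaled at inference) and two 4096-wide
        # FC layers before the final class_dim-way output layer.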
        self._drop = Dropout(p=0.5, mode="downscale_in_infer")
        self._fc1 = Linear(
            7 * 7 * 512,
            4096,
            weight_attr=ParamAttr(name="fc6_weights"),
            bias_attr=ParamAttr(name="fc6_offset"))
        self._fc2 = Linear(
            4096,
            4096,
            weight_attr=ParamAttr(name="fc7_weights"),
            bias_attr=ParamAttr(name="fc7_offset"))
        self._out = Linear(
            4096,
            class_dim,
            weight_attr=ParamAttr(name="fc8_weights"),
            bias_attr=ParamAttr(name="fc8_offset"))

    def forward(self, inputs):
        x = self._conv_block_1(inputs)
        x = self._conv_block_2(x)
        x = self._conv_block_3(x)
        x = self._conv_block_4(x)
        x = self._conv_block_5(x)
        x = paddle.flatten(x, start_axis=1, stop_axis=-1)
        x = self._fc1(x)
        x = F.relu(x)
        x = self._drop(x)
        x = self._fc2(x)
        x = F.relu(x)
        x = self._drop(x)
        x = self._out(x)
        return x


def VGG11(**args):
    model = VGGNet(layers=11, **args)
    return model


def VGG13(**args):
    model = VGGNet(layers=13, **args)
    return model


def VGG16(**args):
    model = VGGNet(layers=16, **args)
    return model


def VGG19(**args):
    model = VGGNet(layers=19, **args)
    return model
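

if __name__ == "__main__":
    # Minimal smoke test (an illustrative addition, not part of the original
    # file): builds VGG16 and checks the output shape on a random 224x224
    # batch. Assumes a working paddle installation.
    model = VGG16()
    image = paddle.rand([1, 3, 224, 224])
    logits = model(image)
    print(logits.shape)  # expect [1, 1000] with the default class_dim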