vgg.py 5.1 KB
Newer Older
W
WuHaobo 已提交
1 2
import paddle
import paddle.fluid as fluid
3 4 5 6 7
from paddle.fluid.param_attr import ParamAttr
from paddle.fluid.dygraph.nn import Conv2D, Pool2D, BatchNorm, Linear

# Public factory functions exported by ``from <module> import *``.
__all__ = [
    "VGG11",
    "VGG13",
    "VGG16",
    "VGG19",
]

W
wqz960 已提交
8
class ConvBlock(fluid.dygraph.Layer):
    """One VGG stage: ``groups`` 3x3 conv+ReLU layers followed by a 2x2 max-pool.

    Args:
        input_channels: number of channels fed to the first convolution.
        output_channels: number of filters in every convolution of the stage.
        groups: how many convolutions the stage stacks (must be >= 1).
        name: optional prefix for the convolution weight parameters; the
            ``i``-th convolution (1-based) gets the parameter name
            ``name + "<i>_weights"``. When ``None``, the weight ParamAttr is
            created without a name instead of crashing on ``None + str``.

    Raises:
        ValueError: if ``groups`` is smaller than 1.
    """

    def __init__(self,
                 input_channels,
                 output_channels,
                 groups,
                 name=None):
        super(ConvBlock, self).__init__()
        if groups < 1:
            raise ValueError(
                "ConvBlock needs at least one convolution, got groups={}".format(groups))

        self.groups = groups
        for idx in range(1, groups + 1):
            # Only the first convolution sees the stage's input channels; the
            # rest map output_channels -> output_channels.
            conv = Conv2D(num_channels=input_channels if idx == 1 else output_channels,
                          num_filters=output_channels,
                          filter_size=3,
                          stride=1,
                          padding=1,
                          act="relu",
                          param_attr=self._weight_attr(name, idx),
                          bias_attr=False)
            # Register under the same attribute names the hand-written version
            # used (_conv_1 ... _conv_N) so ``self._conv_1`` etc. keep working;
            # sublayer-derived keys should therefore also be unchanged.
            self.add_sublayer(self._conv_attr_name(idx), conv)
        self._pool = Pool2D(pool_size=2,
                            pool_type="max",
                            pool_stride=2)

    @staticmethod
    def _conv_attr_name(idx):
        """Return the sub-layer attribute name of the ``idx``-th convolution (1-based)."""
        return "_conv_{}".format(idx)

    @staticmethod
    def _weight_attr(prefix, idx):
        """Build the ParamAttr for the ``idx``-th convolution's weights.

        With a prefix the name follows the original scheme (prefix "conv1_"
        and idx 2 give "conv1_2_weights"); without one the ParamAttr is left
        unnamed.
        """
        if prefix is None:
            return ParamAttr()
        return ParamAttr(name="{}{}_weights".format(prefix, idx))

    def forward(self, inputs):
        """Apply the stacked convolutions in order, then max-pool the result."""
        x = inputs
        for idx in range(1, self.groups + 1):
            x = getattr(self, self._conv_attr_name(idx))(x)
        return self._pool(x)

class VGGNet(fluid.dygraph.Layer):
    """VGG classifier: five ConvBlock stages followed by three fully connected layers.

    Args:
        layers: network depth; must be a key of ``vgg_configure``
            (11, 13, 16 or 19).
        class_dim: number of outputs of the final fully connected layer.

    Raises:
        AssertionError: if ``layers`` is not a supported depth.
    """

    def __init__(self, layers=11, class_dim=1000):
        super(VGGNet, self).__init__()

        self.layers = layers
        # Depth -> number of convolutions in each of the five ConvBlock stages.
        self.vgg_configure = {11: [1, 1, 2, 2, 2],
                              13: [2, 2, 2, 2, 2],
                              16: [2, 2, 3, 3, 3],
                              19: [2, 2, 4, 4, 4]}
        # The table lives on the instance, so the message must read it via
        # ``self``; a bare ``vgg_configure`` would raise NameError here.
        assert self.layers in self.vgg_configure, \
            "supported layers are {} but input layer is {}".format(
                sorted(self.vgg_configure), layers)
        self.groups = self.vgg_configure[self.layers]

        self._conv_block_1 = ConvBlock(3, 64, self.groups[0], name="conv1_")
        self._conv_block_2 = ConvBlock(64, 128, self.groups[1], name="conv2_")
        self._conv_block_3 = ConvBlock(128, 256, self.groups[2], name="conv3_")
        self._conv_block_4 = ConvBlock(256, 512, self.groups[3], name="conv4_")
        self._conv_block_5 = ConvBlock(512, 512, self.groups[4], name="conv5_")

        # A single Dropout layer, applied after both fc1 and fc2 in forward().
        self._drop = fluid.dygraph.Dropout(p=0.5)
        # 512 channels from the last stage times a 7x7 spatial map. Each stage
        # halves the resolution (stride-2 pool), so this presumably assumes
        # 224x224 inputs (224 / 2**5 = 7) -- TODO confirm other sizes fail.
        self._fc1 = Linear(input_dim=7 * 7 * 512,
                           output_dim=4096,
                           act="relu",
                           param_attr=ParamAttr(name="fc6_weights"),
                           bias_attr=ParamAttr(name="fc6_offset"))
        self._fc2 = Linear(input_dim=4096,
                           output_dim=4096,
                           act="relu",
                           param_attr=ParamAttr(name="fc7_weights"),
                           bias_attr=ParamAttr(name="fc7_offset"))
        # No activation: the network returns raw scores of size class_dim.
        self._out = Linear(input_dim=4096,
                           output_dim=class_dim,
                           param_attr=ParamAttr(name="fc8_weights"),
                           bias_attr=ParamAttr(name="fc8_offset"))

    def forward(self, inputs):
        """Run the five conv stages, flatten, and apply the classifier head."""
        x = self._conv_block_1(inputs)
        x = self._conv_block_2(x)
        x = self._conv_block_3(x)
        x = self._conv_block_4(x)
        x = self._conv_block_5(x)

        # Keep the leading (batch) dimension and flatten everything else.
        x = fluid.layers.reshape(x, [0, -1])
        x = self._fc1(x)
        x = self._drop(x)
        x = self._fc2(x)
        x = self._drop(x)
        x = self._out(x)
        return x
W
WuHaobo 已提交
116

W
wqz960 已提交
117 118
def VGG11(**args):
    """Build an 11-layer VGGNet; keyword ``args`` are forwarded to VGGNet."""
    return VGGNet(layers=11, **args)
W
WuHaobo 已提交
120

W
wqz960 已提交
121 122
def VGG13(**args):
    """Build a 13-layer VGGNet; keyword ``args`` are forwarded to VGGNet."""
    return VGGNet(layers=13, **args)

W
wqz960 已提交
125 126
def VGG16(**args):
    """Build a 16-layer VGGNet; keyword ``args`` are forwarded to VGGNet."""
    return VGGNet(layers=16, **args)
W
WuHaobo 已提交
128

W
wqz960 已提交
129 130
def VGG19(**args):
    """Build a 19-layer VGGNet; keyword ``args`` are forwarded to VGGNet."""
    return VGGNet(layers=19, **args)