import paddle
from paddle import ParamAttr
import paddle.nn as nn
import paddle.nn.functional as F
from paddle.nn import Conv2d, BatchNorm, Linear, Dropout
from paddle.nn import AdaptiveAvgPool2d, MaxPool2d, AvgPool2d

# Public API of this module: the two model factory functions defined below.
__all__ = [
    "SqueezeNet1_0",
    "SqueezeNet1_1",
]


class MakeFireConv(nn.Layer):
    """Conv2d followed by ReLU; the building unit of a Fire module.

    Args:
        input_channels: number of input channels of the convolution.
        output_channels: number of output channels of the convolution.
        filter_size: convolution kernel size.
        padding: convolution padding (default 0).
        name: optional prefix for the parameter names. When given, the
            weight is named ``name + "_weights"`` and the bias
            ``name + "_offset"``. When ``None``, ``ParamAttr(name=None)``
            is used, i.e. Paddle picks the parameter names.
    """

    def __init__(self,
                 input_channels,
                 output_channels,
                 filter_size,
                 padding=0,
                 name=None):
        super(MakeFireConv, self).__init__()
        # ``name`` defaults to None, so only build prefixed parameter names
        # when a prefix was actually supplied (``None + str`` would raise).
        weight_name = None if name is None else name + "_weights"
        bias_name = None if name is None else name + "_offset"
        self._conv = Conv2d(
            input_channels,
            output_channels,
            filter_size,
            padding=padding,
            weight_attr=ParamAttr(name=weight_name),
            bias_attr=ParamAttr(name=bias_name))

    def forward(self, x):
        """Apply the convolution, then ReLU."""
        return F.relu(self._conv(x))


class MakeFire(nn.Layer):
    """SqueezeNet Fire module.

    A 1x1 "squeeze" convolution feeds two parallel "expand" convolutions
    (1x1 and 3x3 with padding 1). Their outputs are concatenated along
    axis 1, the channel axis under Paddle's default NCHW layout
    (assumed here). The output therefore has
    ``expand1x1_channels + expand3x3_channels`` channels.

    Args:
        input_channels: channels entering the squeeze convolution.
        squeeze_channels: output channels of the squeeze convolution.
        expand1x1_channels: output channels of the 1x1 expand branch.
        expand3x3_channels: output channels of the 3x3 expand branch.
        name: optional prefix for the parameter names of the three
            convolutions (suffixes ``_squeeze1x1``, ``_expand1x1`` and
            ``_expand3x3``). When ``None``, no prefix is passed down.
    """

    def __init__(self,
                 input_channels,
                 squeeze_channels,
                 expand1x1_channels,
                 expand3x3_channels,
                 name=None):
        super(MakeFire, self).__init__()

        def _sub_name(suffix):
            # ``name`` defaults to None; avoid ``None + str`` in that case.
            return None if name is None else name + suffix

        self._conv = MakeFireConv(
            input_channels,
            squeeze_channels,
            1,
            name=_sub_name("_squeeze1x1"))
        self._conv_path1 = MakeFireConv(
            squeeze_channels,
            expand1x1_channels,
            1,
            name=_sub_name("_expand1x1"))
        self._conv_path2 = MakeFireConv(
            squeeze_channels,
            expand3x3_channels,
            3,
            padding=1,
            name=_sub_name("_expand3x3"))

    def forward(self, inputs):
        """Squeeze, run both expand branches, and concatenate them on axis 1."""
        x = self._conv(inputs)
        x1 = self._conv_path1(x)
        x2 = self._conv_path2(x)
        return paddle.concat([x1, x2], axis=1)

class SqueezeNet(nn.Layer):
    """SqueezeNet classifier in its 1.0 and 1.1 variants.

    Args:
        version: ``"1.0"`` builds the 1.0 stem (96 filters, 7x7, stride 2)
            and its pooling schedule. Any other value builds the 1.1
            variant (64 filters, 3x3, stride 2, padding 1).
        class_dim: number of output classes (default 1000).
    """

    def __init__(self, version, class_dim=1000):
        super(SqueezeNet, self).__init__()
        self.version = version

        if self.version == "1.0":
            stem_channels = 96
            self._conv = Conv2d(
                3,
                stem_channels,
                7,
                stride=2,
                weight_attr=ParamAttr(name="conv1_weights"),
                bias_attr=ParamAttr(name="conv1_offset"))
        else:
            stem_channels = 64
            self._conv = Conv2d(
                3,
                stem_channels,
                3,
                stride=2,
                padding=1,
                weight_attr=ParamAttr(name="conv1_weights"),
                bias_attr=ParamAttr(name="conv1_offset"))
        # A single stateless max-pool layer, reused at every pooling point.
        self._pool = MaxPool2d(kernel_size=3, stride=2, padding=0)

        # (input, squeeze, expand1x1, expand3x3) channels for fire2 .. fire9.
        # The two versions differ only in fire2's input width, which must
        # match the stem's output channels.
        fire_channels = (
            (stem_channels, 16, 64, 64),
            (128, 16, 64, 64),
            (128, 32, 128, 128),
            (256, 32, 128, 128),
            (256, 48, 192, 192),
            (384, 48, 192, 192),
            (384, 64, 256, 256),
            (512, 64, 256, 256),
        )
        for index, channels in enumerate(fire_channels, start=1):
            # Registers self._conv1 .. self._conv8 (parameter prefixes
            # fire2 .. fire9) in the same order as explicit assignment would.
            setattr(self, "_conv%d" % index,
                    MakeFire(*channels, name="fire%d" % (index + 1)))

        self._drop = Dropout(p=0.5)
        # Final 1x1 classifier convolution (conv10 in the parameter names).
        self._conv9 = Conv2d(
            512,
            class_dim,
            1,
            weight_attr=ParamAttr(name="conv10_weights"),
            bias_attr=ParamAttr(name="conv10_offset"))
        self._avg_pool = AdaptiveAvgPool2d(1)

    def forward(self, inputs):
        """Run the network and return per-class scores.

        Assumes a 4-D image batch (NCHW, TODO confirm). The result is
        squeezed on axes 2 and 3 after global average pooling, so it has
        shape ``(batch, class_dim)``. The scores pass through ReLU before
        pooling and are therefore non-negative.
        """
        x = self._pool(F.relu(self._conv(inputs)))

        # Fire modules interleaved with max-pooling; where the pools sit
        # depends on the version.
        if self.version == "1.0":
            stages = [
                self._conv1, self._conv2, self._conv3, self._pool,
                self._conv4, self._conv5, self._conv6, self._conv7,
                self._pool, self._conv8,
            ]
        else:
            stages = [
                self._conv1, self._conv2, self._pool,
                self._conv3, self._conv4, self._pool,
                self._conv5, self._conv6, self._conv7, self._conv8,
            ]
        for stage in stages:
            x = stage(x)

        logits = self._conv9(self._drop(x))
        pooled = self._avg_pool(F.relu(logits))
        return paddle.squeeze(pooled, axis=[2, 3])

def SqueezeNet1_0(**args):
    """Build a SqueezeNet 1.0 model; keyword ``args`` go to ``SqueezeNet``."""
    return SqueezeNet(version="1.0", **args)

def SqueezeNet1_1(**args):
    """Build a SqueezeNet 1.1 model; keyword ``args`` go to ``SqueezeNet``."""
    return SqueezeNet(version="1.1", **args)