squeezenet.py 5.7 KB
Newer Older
W
WuHaobo 已提交
1 2 3
import paddle
import paddle.fluid as fluid
from paddle.fluid.param_attr import ParamAttr
4 5 6 7
from paddle.fluid.dygraph.nn import Conv2D, Pool2D, BatchNorm, Linear, Dropout

# Public API of this module: factory functions for the two SqueezeNet variants.
__all__ = ["SqueezeNet1_0", "SqueezeNet1_1"]

W
fix  
wqz960 已提交
8
class MakeFireConv(fluid.dygraph.Layer):
    """Conv2D followed by ReLU, the basic unit of a Fire module.

    Args:
        input_channels: number of channels of the incoming feature map.
        output_channels: number of convolution filters (output channels).
        filter_size: convolution kernel size.
        padding: zero padding forwarded to Conv2D (default 0).
        name: optional prefix for the parameter names. The weight is named
            "<name>_weights" and the bias "<name>_offset". When None, the
            parameter names are left to Paddle (ParamAttr(name=None)).
    """

    def __init__(self,
                 input_channels,
                 output_channels,
                 filter_size,
                 padding=0,
                 name=None):
        super(MakeFireConv, self).__init__()
        # Previously `name + "_weights"` raised TypeError for the default
        # name=None; fall back to unnamed parameters instead.
        weight_name = None if name is None else name + "_weights"
        bias_name = None if name is None else name + "_offset"
        self._conv = Conv2D(input_channels,
                            output_channels,
                            filter_size,
                            padding=padding,
                            act="relu",
                            param_attr=ParamAttr(name=weight_name),
                            bias_attr=ParamAttr(name=bias_name))

    def forward(self, inputs):
        """Apply the convolution (with its fused ReLU) to ``inputs``."""
        return self._conv(inputs)
W
WuHaobo 已提交
26

W
fix  
wqz960 已提交
27
class MakeFire(fluid.dygraph.Layer):
    """SqueezeNet Fire module: a 1x1 squeeze conv feeding two expand convs.

    The squeeze conv reduces the input to ``squeeze_channels``. Its output
    goes through a 1x1 expand conv and, in parallel, a 3x3 expand conv
    (padding 1). The two results are concatenated along axis 1, so the
    module outputs ``expand1x1_channels + expand3x3_channels`` channels
    (axis 1 is the channel axis assuming Paddle's default NCHW layout).

    Args:
        input_channels: number of channels of the incoming feature map.
        squeeze_channels: output channels of the 1x1 squeeze conv.
        expand1x1_channels: output channels of the 1x1 expand conv.
        expand3x3_channels: output channels of the 3x3 expand conv.
        name: optional prefix for the sub-convolution parameter names
            ("<name>_squeeze1x1", "<name>_expand1x1", "<name>_expand3x3").
            When None, parameter names are left to Paddle.
    """

    def __init__(self,
                 input_channels,
                 squeeze_channels,
                 expand1x1_channels,
                 expand3x3_channels,
                 name=None):
        super(MakeFire, self).__init__()

        def _sub_name(suffix):
            # Concatenating onto the default name=None used to raise
            # TypeError; propagate None so the sub-convs get default names.
            return None if name is None else name + suffix

        self._conv = MakeFireConv(input_channels,
                                  squeeze_channels,
                                  1,
                                  name=_sub_name("_squeeze1x1"))
        self._conv_path1 = MakeFireConv(squeeze_channels,
                                        expand1x1_channels,
                                        1,
                                        name=_sub_name("_expand1x1"))
        self._conv_path2 = MakeFireConv(squeeze_channels,
                                        expand3x3_channels,
                                        3,
                                        padding=1,
                                        name=_sub_name("_expand3x3"))

    def forward(self, inputs):
        """Squeeze ``inputs``, run both expand paths, concatenate on axis 1."""
        x = self._conv(inputs)
        x1 = self._conv_path1(x)
        x2 = self._conv_path2(x)
        return fluid.layers.concat([x1, x2], axis=1)
W
WuHaobo 已提交
54

55 56 57
class SqueezeNet(fluid.dygraph.Layer):
    """SqueezeNet v1.0 / v1.1 classifier built from Fire modules.

    The two versions differ in three places:

    * the stem conv: 96 filters of 7x7 with stride 2 for "1.0", versus
      64 filters of 3x3 with stride 2 and padding 1 for "1.1";
    * the input channel count of the first Fire module (96 vs 64);
    * where the shared 3x3/stride-2 max pool is applied between Fire
      modules.

    Both versions then apply dropout (p=0.5), a 1x1 conv to ``class_dim``
    channels with ReLU, global average pooling, and a squeeze of axes 2
    and 3.

    Args:
        version: "1.0" or "1.1".
        class_dim: number of output classes (default 1000).

    Raises:
        ValueError: if ``version`` is not "1.0" or "1.1".
    """

    def __init__(self, version, class_dim=1000):
        super(SqueezeNet, self).__init__()
        # Unknown versions used to silently build the 1.1 network.
        if version not in ("1.0", "1.1"):
            raise ValueError(
                "Unsupported SqueezeNet version {!r}; expected '1.0' or "
                "'1.1'".format(version))
        self.version = version

        if version == "1.0":
            self._conv = Conv2D(3,
                                96,
                                7,
                                stride=2,
                                act="relu",
                                param_attr=ParamAttr(name="conv1_weights"),
                                bias_attr=ParamAttr(name="conv1_offset"))
            first_fire_channels = 96
            # 1-based indices of Fire modules preceded by a max pool
            # (i.e. pool after fire4 and after fire8).
            self._pool_before = (4, 8)
        else:
            self._conv = Conv2D(3,
                                64,
                                3,
                                stride=2,
                                padding=1,
                                act="relu",
                                param_attr=ParamAttr(name="conv1_weights"),
                                bias_attr=ParamAttr(name="conv1_offset"))
            first_fire_channels = 64
            # Pool after fire3 and after fire5.
            self._pool_before = (3, 5)

        # A single pooling layer is reused at every down-sampling point.
        self._pool = Pool2D(pool_size=3,
                            pool_stride=2,
                            pool_type="max")

        # Attribute names are kept as _conv1.._conv8 so existing
        # checkpoints keep loading. Only fire2's input width differs
        # between versions.
        self._conv1 = MakeFire(first_fire_channels, 16, 64, 64, name="fire2")
        self._conv2 = MakeFire(128, 16, 64, 64, name="fire3")
        self._conv3 = MakeFire(128, 32, 128, 128, name="fire4")
        self._conv4 = MakeFire(256, 32, 128, 128, name="fire5")
        self._conv5 = MakeFire(256, 48, 192, 192, name="fire6")
        self._conv6 = MakeFire(384, 48, 192, 192, name="fire7")
        self._conv7 = MakeFire(384, 64, 256, 256, name="fire8")
        self._conv8 = MakeFire(512, 64, 256, 256, name="fire9")

        self._drop = Dropout(p=0.5)
        # Classifier conv ("conv10"): 1x1 to class_dim channels. The ReLU
        # makes the pooled outputs non-negative.
        self._conv9 = Conv2D(512,
                             class_dim,
                             1,
                             act="relu",
                             param_attr=ParamAttr(name="conv10_weights"),
                             bias_attr=ParamAttr(name="conv10_offset"))
        self._avg_pool = Pool2D(pool_type="avg",
                                global_pooling=True)

    def forward(self, inputs):
        """Run the network on ``inputs`` and return per-class scores.

        Axes 2 and 3 (left at size 1 by global average pooling) are
        squeezed out before returning.
        """
        x = self._conv(inputs)
        x = self._pool(x)
        fires = (self._conv1, self._conv2, self._conv3, self._conv4,
                 self._conv5, self._conv6, self._conv7, self._conv8)
        for index, fire in enumerate(fires, start=1):
            # Version-dependent down-sampling between Fire modules.
            if index in self._pool_before:
                x = self._pool(x)
            x = fire(x)
        x = self._drop(x)
        x = self._conv9(x)
        x = self._avg_pool(x)
        return fluid.layers.squeeze(x, axes=[2, 3])

W
wqz960 已提交
145 146
def SqueezeNet1_0(**kwargs):
    """Build a SqueezeNet v1.0 model.

    Keyword arguments (e.g. ``class_dim``) are forwarded to ``SqueezeNet``.
    """
    return SqueezeNet(version="1.0", **kwargs)
W
WuHaobo 已提交
148

W
wqz960 已提交
149 150
def SqueezeNet1_1(**kwargs):
    """Build a SqueezeNet v1.1 model.

    Keyword arguments (e.g. ``class_dim``) are forwarded to ``SqueezeNet``.
    """
    return SqueezeNet(version="1.1", **kwargs)