resnext101_wsl.py 11.7 KB
Newer Older
W
WuHaobo 已提交
1 2 3
import paddle
import paddle.fluid as fluid
from paddle.fluid.param_attr import ParamAttr
4
from paddle.fluid.dygraph.nn import Conv2D, Pool2D, BatchNorm, Linear
W
WuHaobo 已提交
5

6 7 8 9
# Public factory functions. These must match the function names defined in
# this module exactly, otherwise `from <module> import *` raises AttributeError.
__all__ = [
    "ResNeXt101_32x8d_wsl",
    "ResNeXt101_32x16d_wsl",
    "ResNeXt101_32x32d_wsl",
    "ResNeXt101_32x48d_wsl",
]
W
WuHaobo 已提交
10

11 12 13 14 15 16 17 18 19 20
def _conv_bn_param_names(name):
    """Derive the conv and batch-norm parameter-name prefixes for a layer.

    Mapping examples::

        "conv1"               -> ("conv1", "bn1")
        "layer1.0.conv2"      -> ("layer1.0.conv2", "layer1.0.bn2")
        "layer3.10.conv3"     -> ("layer3.10.conv3", "layer3.10.bn3")
        "layer2.0.downsample" -> ("layer2.0.downsample.0",
                                  "layer2.0.downsample.1")

    Splitting on the last "." (rather than slicing fixed character offsets)
    works for any block index.

    Args:
        name: dotted layer name, or None.

    Returns:
        (conv_name, bn_name); both None when ``name`` is None.
    """
    if name is None:
        return None, None
    if "downsample" in name:
        # Projection shortcut: conv is sub-module ".0", its BN is ".1".
        return name + ".0", name + ".1"
    prefix, _, leaf = name.rpartition(".")
    # "convK" pairs with "bnK": reuse the trailing digit of the leaf name.
    bn_name = "bn" + leaf[-1]
    if prefix:
        bn_name = prefix + "." + bn_name
    return name, bn_name


class ConvBNLayer(fluid.dygraph.Layer):
    """Bias-free convolution followed by BatchNorm (activation applied by BN).

    Args:
        input_channels: channels of the input feature map.
        output_channels: number of convolution filters / BN channels.
        filter_size: square kernel size; padding is (filter_size - 1) // 2,
            so spatial size is preserved at stride 1 for odd kernels.
        stride: convolution stride.
        groups: number of convolution groups.
        act: activation passed to BatchNorm (e.g. "relu"); None for none.
        name: dotted layer name from which parameter names are derived via
            ``_conv_bn_param_names``; when None, ParamAttr/BatchNorm receive
            name=None (Paddle's auto-naming).
    """

    def __init__(self,
                 input_channels,
                 output_channels,
                 filter_size,
                 stride=1,
                 groups=1,
                 act=None,
                 name=None):
        super(ConvBNLayer, self).__init__()
        conv_name, bn_name = _conv_bn_param_names(name)

        def _pname(prefix, suffix):
            # Keep the parameter name None (auto-generated) without a prefix.
            return None if prefix is None else prefix + suffix

        self._conv = Conv2D(num_channels=input_channels,
                            num_filters=output_channels,
                            filter_size=filter_size,
                            stride=stride,
                            padding=(filter_size - 1) // 2,
                            groups=groups,
                            act=None,
                            param_attr=ParamAttr(
                                name=_pname(conv_name, ".weight")),
                            bias_attr=False)
        self._bn = BatchNorm(num_channels=output_channels,
                             act=act,
                             param_attr=ParamAttr(
                                 name=_pname(bn_name, ".weight")),
                             bias_attr=ParamAttr(
                                 name=_pname(bn_name, ".bias")),
                             moving_mean_name=_pname(bn_name,
                                                     ".running_mean"),
                             moving_variance_name=_pname(bn_name,
                                                         ".running_var"))

    def forward(self, inputs):
        x = self._conv(inputs)
        x = self._bn(x)
        return x

W
wqz960 已提交
53
class ShortCut(fluid.dygraph.Layer):
    """Residual shortcut: identity, or a 1x1 ConvBN projection.

    A projection is built only when the channel count changes or the stride
    is not 1; otherwise the input is returned unchanged.
    """

    def __init__(self, input_channels, output_channels, stride, name=None):
        super(ShortCut, self).__init__()

        self.input_channels = input_channels
        self.output_channels = output_channels
        self.stride = stride

        needs_projection = input_channels != output_channels or stride != 1
        if needs_projection:
            self._conv = ConvBNLayer(input_channels,
                                     output_channels,
                                     filter_size=1,
                                     stride=stride,
                                     name=name)

    def forward(self, inputs):
        is_identity = (self.input_channels == self.output_channels
                       and self.stride == 1)
        if is_identity:
            return inputs
        return self._conv(inputs)

W
wqz960 已提交
69
class BottleneckBlock(fluid.dygraph.Layer):
    """ResNeXt bottleneck: 1x1 -> grouped 3x3 (strided) -> 1x1, plus shortcut.

    Args:
        input_channels: channels entering the block.
        output_channels: inner (grouped-conv) width of the block.
        stride: stride of the grouped 3x3 convolution and of the shortcut.
        cardinality: number of groups in the 3x3 convolution.
        width: per-group width; the block output has
            output_channels // (width // 8) channels, so width must be >= 8.
        name: dotted prefix for parameter names (e.g. "layer1.0").
    """

    def __init__(self, input_channels, output_channels, stride, cardinality,
                 width, name):
        super(BottleneckBlock, self).__init__()

        block_out = output_channels // (width // 8)

        self._conv0 = ConvBNLayer(input_channels,
                                  output_channels,
                                  filter_size=1,
                                  act="relu",
                                  name=name + ".conv1")
        self._conv1 = ConvBNLayer(output_channels,
                                  output_channels,
                                  filter_size=3,
                                  act="relu",
                                  stride=stride,
                                  groups=cardinality,
                                  name=name + ".conv2")
        self._conv2 = ConvBNLayer(output_channels,
                                  block_out,
                                  filter_size=1,
                                  act=None,
                                  name=name + ".conv3")
        self._short = ShortCut(input_channels,
                               block_out,
                               stride=stride,
                               name=name + ".downsample")

    def forward(self, inputs):
        residual = self._conv2(self._conv1(self._conv0(inputs)))
        # ReLU is applied after the residual sum.
        return fluid.layers.elementwise_add(residual,
                                            self._short(inputs),
                                            act="relu")

W
fix  
wqz960 已提交
89
class ResNeXt101WSL(fluid.dygraph.Layer):
    """ResNeXt-101 (cardinality x width) backbone with a linear classifier.

    Structure: 7x7 stem ConvBN + 3x3 max-pool, four stages of
    BottleneckBlocks with depths [3, 4, 23, 3] (first block of stages 2-4
    uses stride 2), global average pooling, then a Linear layer.

    Args:
        layers: stored on ``self.layers`` only; the stage depths are fixed
            to [3, 4, 23, 3] regardless of this value.
        cardinality: number of groups in every grouped 3x3 convolution.
        width: per-group width; must be >= 8 because block outputs are
            divided by ``width // 8``.
        class_dim: output size of the final Linear layer.
    """

    def __init__(self, layers=101, cardinality=32, width=48, class_dim=1000):
        super(ResNeXt101WSL, self).__init__()

        self.class_dim = class_dim

        self.layers = layers
        self.cardinality = cardinality
        self.width = width
        self.scale = width // 8

        self.depth = [3, 4, 23, 3]
        self.base_width = cardinality * width
        # Inner (grouped-conv) width of each stage. Block outputs are these
        # divided by self.scale, i.e. cardinality * 8 * [1, 2, 4, 8]
        # ([256, 512, 1024, 2048] for cardinality 32) when width % 8 == 0.
        num_filters = [self.base_width * i for i in [1, 2, 4, 8]]
        self._conv_stem = ConvBNLayer(
            3, 64, 7, stride=2, act="relu", name="conv1")
        self._pool = Pool2D(pool_size=3,
                            pool_stride=2,
                            pool_padding=1,
                            pool_type="max")

        # Build the residual stages. Blocks are registered under the same
        # attribute names (_conv1_0 ... _conv4_2), in the same order, with the
        # same parameter-name prefixes ("layer1.0" ...) as an unrolled listing.
        self._block_names = []
        in_channels = 64
        for stage, (depth, filters) in enumerate(
                zip(self.depth, num_filters), start=1):
            for idx in range(depth):
                # Only the first block of stages 2-4 downsamples spatially.
                stride = 2 if (stage > 1 and idx == 0) else 1
                attr_name = "_conv{}_{}".format(stage, idx)
                block = BottleneckBlock(in_channels,
                                        filters,
                                        stride=stride,
                                        cardinality=self.cardinality,
                                        width=self.width,
                                        name="layer{}.{}".format(stage, idx))
                # setattr goes through Layer.__setattr__, registering the
                # block as a sublayer exactly like `self._convX_Y = ...`.
                setattr(self, attr_name, block)
                self._block_names.append(attr_name)
                in_channels = filters // self.scale

        self._avg_pool = Pool2D(pool_type="avg", global_pooling=True)
        # in_channels now equals num_filters[3] // self.scale.
        self._out = Linear(input_dim=in_channels,
                           output_dim=class_dim,
                           param_attr=ParamAttr(name="fc.weight"),
                           bias_attr=ParamAttr(name="fc.bias"))

    def forward(self, inputs):
        x = self._conv_stem(inputs)
        x = self._pool(x)

        for block_name in self._block_names:
            x = getattr(self, block_name)(x)

        x = self._avg_pool(x)
        # Drop the 1x1 spatial dims left by global pooling.
        x = fluid.layers.squeeze(x, axes=[2, 3])
        x = self._out(x)
        return x
W
WuHaobo 已提交
231

W
wqz960 已提交
232 233
def ResNeXt101_32x8d_wsl(**args):
    """Build ResNeXt101WSL with cardinality 32 and group width 8.

    Extra keyword arguments are forwarded to ResNeXt101WSL.
    """
    return ResNeXt101WSL(cardinality=32, width=8, **args)
W
WuHaobo 已提交
235

W
wqz960 已提交
236 237
def ResNeXt101_32x16d_wsl(**args):
    """Build ResNeXt101WSL with cardinality 32 and group width 16.

    Extra keyword arguments are forwarded to ResNeXt101WSL.
    """
    return ResNeXt101WSL(cardinality=32, width=16, **args)
W
WuHaobo 已提交
239

W
wqz960 已提交
240 241
def ResNeXt101_32x32d_wsl(**args):
    """Build ResNeXt101WSL with cardinality 32 and group width 32.

    Extra keyword arguments are forwarded to ResNeXt101WSL.
    """
    return ResNeXt101WSL(cardinality=32, width=32, **args)
W
WuHaobo 已提交
243

W
wqz960 已提交
244 245
def ResNeXt101_32x48d_wsl(**args):
    """Build ResNeXt101WSL with cardinality 32 and group width 48.

    Extra keyword arguments are forwarded to ResNeXt101WSL.
    """
    return ResNeXt101WSL(cardinality=32, width=48, **args)