# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# This code is based on https://github.com/huawei-noah/CV-Backbones/tree/master/ghostnet_pytorch
# reference: https://arxiv.org/abs/1911.11907

import math
import paddle
from paddle import ParamAttr
import paddle.nn as nn
import paddle.nn.functional as F
from paddle.nn import Conv2D, BatchNorm, AdaptiveAvgPool2D, Linear
from paddle.regularizer import L2Decay
from paddle.nn.initializer import Uniform, KaimingNormal

from ppcls.utils.save_load import (load_dygraph_pretrain,
                                   load_dygraph_pretrain_from_url)

MODEL_URLS = {
    "GhostNet_x0_5":
    "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/GhostNet_x0_5_pretrained.pdparams",
    "GhostNet_x1_0":
    "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/GhostNet_x1_0_pretrained.pdparams",
    "GhostNet_x1_3":
    "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/GhostNet_x1_3_pretrained.pdparams",
}

__all__ = list(MODEL_URLS.keys())


class ConvBNLayer(nn.Layer):
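    """A Conv2D + BatchNorm block, with the activation fused into the BN
    layer. This is the basic unit shared by every layer in this file."""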
    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 stride=1,
                 groups=1,
                 act="relu",
                 name=None):
        super(ConvBNLayer, self).__init__()
        self._conv = Conv2D(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=(kernel_size - 1) // 2,
            groups=groups,
            weight_attr=ParamAttr(
                initializer=KaimingNormal(), name=name + "_weights"),
            bias_attr=False)
        bn_name = name + "_bn"

        self._batch_norm = BatchNorm(
            num_channels=out_channels,
            act=act,
            param_attr=ParamAttr(
                name=bn_name + "_scale", regularizer=L2Decay(0.0)),
            bias_attr=ParamAttr(
                name=bn_name + "_offset", regularizer=L2Decay(0.0)),
            moving_mean_name=bn_name + "_mean",
            moving_variance_name=bn_name + "_variance")

    def forward(self, inputs):
        y = self._conv(inputs)
        y = self._batch_norm(y)
        return y


class SEBlock(nn.Layer):
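    """Squeeze-and-Excitation block: global average pooling followed by a
    two-layer bottleneck (reduce by `reduction_ratio`, then expand back)
    that produces per-channel gating coefficients."""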
    def __init__(self, num_channels, reduction_ratio=4, name=None):
        super(SEBlock, self).__init__()
        self.pool2d_gap = AdaptiveAvgPool2D(1)
        self._num_channels = num_channels
        stdv = 1.0 / math.sqrt(num_channels * 1.0)
        med_ch = num_channels // reduction_ratio
        self.squeeze = Linear(
            num_channels,
            med_ch,
            weight_attr=ParamAttr(
                initializer=Uniform(-stdv, stdv), name=name + "_1_weights"),
            bias_attr=ParamAttr(name=name + "_1_offset"))
        stdv = 1.0 / math.sqrt(med_ch * 1.0)
        self.excitation = Linear(
            med_ch,
            num_channels,
            weight_attr=ParamAttr(
                initializer=Uniform(-stdv, stdv), name=name + "_2_weights"),
            bias_attr=ParamAttr(name=name + "_2_offset"))

    def forward(self, inputs):
        pool = self.pool2d_gap(inputs)
        pool = paddle.squeeze(pool, axis=[2, 3])
        squeeze = self.squeeze(pool)
        squeeze = F.relu(squeeze)
        excitation = self.excitation(squeeze)
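        # Clip to [0, 1] so the excitation acts as a per-channel gate.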
        excitation = paddle.clip(x=excitation, min=0, max=1)
        excitation = paddle.unsqueeze(excitation, axis=[2, 3])
        out = paddle.multiply(inputs, excitation)
        return out


class GhostModule(nn.Layer):
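    """Ghost module: a primary conv produces the intrinsic feature maps,
    and a cheap depthwise conv derives "ghost" feature maps from them;
    the two sets are concatenated along the channel axis."""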
    def __init__(self,
                 in_channels,
                 output_channels,
                 kernel_size=1,
                 ratio=2,
                 dw_size=3,
                 stride=1,
                 relu=True,
                 name=None):
        super(GhostModule, self).__init__()
        init_channels = int(math.ceil(output_channels / ratio))
        new_channels = int(init_channels * (ratio - 1))
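        # With the default ratio=2, half the output channels are "intrinsic"
        # features from the primary conv and half are "ghost" features
        # produced from them by the cheap depthwise conv below.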
        self.primary_conv = ConvBNLayer(
            in_channels=in_channels,
            out_channels=init_channels,
            kernel_size=kernel_size,
            stride=stride,
            groups=1,
            act="relu" if relu else None,
            name=name + "_primary_conv")
        self.cheap_operation = ConvBNLayer(
            in_channels=init_channels,
            out_channels=new_channels,
            kernel_size=dw_size,
            stride=1,
            groups=init_channels,
            act="relu" if relu else None,
            name=name + "_cheap_operation")

    def forward(self, inputs):
        x = self.primary_conv(inputs)
        y = self.cheap_operation(x)
        out = paddle.concat([x, y], axis=1)
        return out


class GhostBottleneck(nn.Layer):
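    """GhostNet bottleneck: expand with a ghost module, optionally
    downsample (depthwise conv, stride 2) and reweight with SE, then
    project back with a second ghost module, plus a residual shortcut."""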
    def __init__(self,
                 in_channels,
                 hidden_dim,
                 output_channels,
                 kernel_size,
                 stride,
                 use_se,
                 name=None):
        super(GhostBottleneck, self).__init__()
        self._stride = stride
        self._use_se = use_se
        self._num_channels = in_channels
        self._output_channels = output_channels
        self.ghost_module_1 = GhostModule(
            in_channels=in_channels,
            output_channels=hidden_dim,
            kernel_size=1,
            stride=1,
            relu=True,
            name=name + "_ghost_module_1")
        if stride == 2:
            self.depthwise_conv = ConvBNLayer(
                in_channels=hidden_dim,
                out_channels=hidden_dim,
                kernel_size=kernel_size,
                stride=stride,
                groups=hidden_dim,
                act=None,
                name=name +
                "_depthwise_depthwise"  # looks strange due to an old typo, will be fixed later.
            )
        if use_se:
            self.se_block = SEBlock(num_channels=hidden_dim, name=name + "_se")
        self.ghost_module_2 = GhostModule(
            in_channels=hidden_dim,
            output_channels=output_channels,
            kernel_size=1,
            relu=False,
            name=name + "_ghost_module_2")
        if stride != 1 or in_channels != output_channels:
            self.shortcut_depthwise = ConvBNLayer(
                in_channels=in_channels,
                out_channels=in_channels,
                kernel_size=kernel_size,
                stride=stride,
                groups=in_channels,
                act=None,
                name=name +
                "_shortcut_depthwise_depthwise"  # looks strange due to an old typo, will be fixed later.
            )
            self.shortcut_conv = ConvBNLayer(
                in_channels=in_channels,
                out_channels=output_channels,
                kernel_size=1,
                stride=1,
                groups=1,
                act=None,
                name=name + "_shortcut_conv")

    def forward(self, inputs):
        x = self.ghost_module_1(inputs)
        if self._stride == 2:
            x = self.depthwise_conv(x)
        if self._use_se:
            x = self.se_block(x)
        x = self.ghost_module_2(x)
        if self._stride == 1 and self._num_channels == self._output_channels:
            shortcut = inputs
        else:
            shortcut = self.shortcut_depthwise(inputs)
            shortcut = self.shortcut_conv(shortcut)
        return paddle.add(x=x, y=shortcut)


class GhostNet(nn.Layer):
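    """GhostNet backbone. `scale` multiplies the channel width of every
    stage; `class_num` sets the size of the final classifier."""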
    def __init__(self, scale, class_num=1000):
        super(GhostNet, self).__init__()
        self.cfgs = [
            # k: kernel size, t: hidden (expansion) channels,
            # c: output channels, SE: whether to use an SE block, s: stride
            [3, 16, 16, 0, 1],
            [3, 48, 24, 0, 2],
            [3, 72, 24, 0, 1],
            [5, 72, 40, 1, 2],
            [5, 120, 40, 1, 1],
            [3, 240, 80, 0, 2],
            [3, 200, 80, 0, 1],
            [3, 184, 80, 0, 1],
            [3, 184, 80, 0, 1],
            [3, 480, 112, 1, 1],
            [3, 672, 112, 1, 1],
            [5, 672, 160, 1, 2],
            [5, 960, 160, 0, 1],
            [5, 960, 160, 1, 1],
            [5, 960, 160, 0, 1],
            [5, 960, 160, 1, 1]
        ]
        self.scale = scale
        output_channels = int(self._make_divisible(16 * self.scale, 4))
        self.conv1 = ConvBNLayer(
            in_channels=3,
            out_channels=output_channels,
            kernel_size=3,
            stride=2,
            groups=1,
            act="relu",
            name="conv1")
        # build the ghost bottleneck blocks (inverted-residual style)
        idx = 0
        self.ghost_bottleneck_list = []
        for k, exp_size, c, use_se, s in self.cfgs:
            in_channels = output_channels
            output_channels = int(self._make_divisible(c * self.scale, 4))
            hidden_dim = int(self._make_divisible(exp_size * self.scale, 4))
            ghost_bottleneck = self.add_sublayer(
                name="_ghostbottleneck_" + str(idx),
                sublayer=GhostBottleneck(
                    in_channels=in_channels,
                    hidden_dim=hidden_dim,
                    output_channels=output_channels,
                    kernel_size=k,
                    stride=s,
                    use_se=use_se,
                    name="_ghostbottleneck_" + str(idx)))
            self.ghost_bottleneck_list.append(ghost_bottleneck)
            idx += 1
        # build last several layers
        in_channels = output_channels
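        # `exp_size` still holds the expansion size of the last cfg entry (960).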
        output_channels = int(self._make_divisible(exp_size * self.scale, 4))
        self.conv_last = ConvBNLayer(
            in_channels=in_channels,
            out_channels=output_channels,
            kernel_size=1,
            stride=1,
            groups=1,
            act="relu",
            name="conv_last")
        self.pool2d_gap = AdaptiveAvgPool2D(1)
        in_channels = output_channels
        self._fc0_output_channels = 1280
        self.fc_0 = ConvBNLayer(
            in_channels=in_channels,
            out_channels=self._fc0_output_channels,
            kernel_size=1,
            stride=1,
            act="relu",
            name="fc_0")
        self.dropout = nn.Dropout(p=0.2)
        stdv = 1.0 / math.sqrt(self._fc0_output_channels * 1.0)
        self.fc_1 = Linear(
            self._fc0_output_channels,
            class_num,
            weight_attr=ParamAttr(
                name="fc_1_weights", initializer=Uniform(-stdv, stdv)),
            bias_attr=ParamAttr(name="fc_1_offset"))

    def forward(self, inputs):
        x = self.conv1(inputs)
        for ghost_bottleneck in self.ghost_bottleneck_list:
            x = ghost_bottleneck(x)
        x = self.conv_last(x)
        x = self.pool2d_gap(x)
        x = self.fc_0(x)
        x = self.dropout(x)
        x = paddle.reshape(x, shape=[-1, self._fc0_output_channels])
        x = self.fc_1(x)
        return x

    def _make_divisible(self, v, divisor, min_value=None):
        """
        This function is taken from the original tf repo.
        It ensures that all layers have a channel number that is divisible
        by `divisor` (8 in the original repo, 4 as used here).
        It can be seen here:
        https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py
        """
        if min_value is None:
            min_value = divisor
        new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
        # Make sure that rounding down does not reduce the value by more than 10%.
        if new_v < 0.9 * v:
            new_v += divisor
        return new_v


def _load_pretrained(pretrained, model, model_url, use_ssld=False):
    if pretrained is False:
        pass
    elif pretrained is True:
        load_dygraph_pretrain_from_url(model, model_url, use_ssld=use_ssld)
    elif isinstance(pretrained, str):
        load_dygraph_pretrain(model, pretrained)
    else:
        raise RuntimeError(
            "pretrained type is not available. Please use `string` or `boolean` type."
        )


def GhostNet_x0_5(pretrained=False, use_ssld=False, **kwargs):
    model = GhostNet(scale=0.5, **kwargs)
    _load_pretrained(
        pretrained, model, MODEL_URLS["GhostNet_x0_5"], use_ssld=use_ssld)
    return model


def GhostNet_x1_0(pretrained=False, use_ssld=False, **kwargs):
    model = GhostNet(scale=1.0, **kwargs)
    _load_pretrained(
        pretrained, model, MODEL_URLS["GhostNet_x1_0"], use_ssld=use_ssld)
    return model


def GhostNet_x1_3(pretrained=False, use_ssld=False, **kwargs):
    model = GhostNet(scale=1.3, **kwargs)
    _load_pretrained(
        pretrained, model, MODEL_URLS["GhostNet_x1_3"], use_ssld=use_ssld)
    return model
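

# A minimal usage sketch (an illustration, not part of the upstream file):
# build GhostNet_x1_0 without pretrained weights and run a dummy
# ImageNet-sized batch through it. Assumes paddle and ppcls are importable.
if __name__ == "__main__":
    model = GhostNet_x1_0(pretrained=False)
    model.eval()
    x = paddle.rand([1, 3, 224, 224])
    y = model(x)
    print(y.shape)  # [1, 1000]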