hardnet.py 8.7 KB
Newer Older
jm_12138's avatar
jm_12138 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

15 16 17 18 19 20 21
import paddle
import paddle.nn as nn

from ppcls.utils.save_load import load_dygraph_pretrain, load_dygraph_pretrain_from_url

# Download locations of the pretrained weights, keyed by the name of the
# factory function defined in this module.
MODEL_URLS = {
    'HarDNet39_ds':
    'https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/HarDNet39_ds_pretrained.pdparams',
    'HarDNet68_ds':
    'https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/HarDNet68_ds_pretrained.pdparams',
    'HarDNet68':
    'https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/HarDNet68_pretrained.pdparams',
    'HarDNet85':
    'https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/HarDNet85_pretrained.pdparams',
}

# The import system requires __all__ to be a sequence of strings; a dict_keys
# view is not a sequence and makes `from <this module> import *` fail, so the
# model names are materialised as a list.
__all__ = list(MODEL_URLS.keys())


littletomatodonkey's avatar
littletomatodonkey 已提交
34 35 36 37 38
def ConvLayer(in_channels,
              out_channels,
              kernel_size=3,
              stride=1,
              bias_attr=False):
    """Build a Conv2D -> BatchNorm2D -> ReLU6 unit.

    Padding is ``kernel_size // 2``, so for an odd kernel and ``stride=1`` the
    spatial size of the input is kept.

    Args:
        in_channels: channels of the input tensor.
        out_channels: channels produced by the convolution.
        kernel_size: square convolution kernel size.
        stride: convolution stride.
        bias_attr: forwarded to ``nn.Conv2D``; the default ``False`` disables
            the convolution bias (a BatchNorm follows it anyway).

    Returns:
        ``nn.Sequential`` whose sublayers are named ``conv``, ``norm`` and
        ``relu`` (these names appear in the parameter names).
    """
    conv = nn.Conv2D(
        in_channels,
        out_channels,
        kernel_size=kernel_size,
        stride=stride,
        padding=kernel_size // 2,
        groups=1,
        bias_attr=bias_attr)
    norm = nn.BatchNorm2D(out_channels)
    activation = nn.ReLU6()
    return nn.Sequential(('conv', conv), ('norm', norm), ('relu', activation))


littletomatodonkey's avatar
littletomatodonkey 已提交
52 53 54 55 56
def DWConvLayer(in_channels,
                out_channels,
                kernel_size=3,
                stride=1,
                bias_attr=False):
    """Build a grouped (depth-wise) Conv2D -> BatchNorm2D unit, no activation.

    Padding is derived from ``kernel_size`` (``kernel_size // 2``), matching
    ``ConvLayer``. The previous hard-coded ``padding=1`` was only correct for
    the default 3x3 kernel; for that kernel the result is unchanged.

    Args:
        in_channels: channels of the input tensor. The conv uses
            ``groups=out_channels``; every caller in this module passes
            ``in_channels == out_channels``.
        out_channels: channels produced by the convolution (also the group
            count).
        kernel_size: square convolution kernel size.
        stride: convolution stride (2 is used for downsampling in HarDNet).
        bias_attr: forwarded to ``nn.Conv2D``; ``False`` disables the bias.

    Returns:
        ``nn.Sequential`` with sublayers named ``dwconv`` and ``norm``.
    """
    layer = nn.Sequential(
        ('dwconv', nn.Conv2D(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=kernel_size // 2,
            groups=out_channels,
            bias_attr=bias_attr)), ('norm', nn.BatchNorm2D(out_channels)))
    return layer


def CombConvLayer(in_channels, out_channels, kernel_size=1, stride=1):
    """Build a ConvLayer followed by a DWConvLayer (used by the "_ds" models).

    Args:
        in_channels: channels of the input tensor.
        out_channels: channels produced by both stages.
        kernel_size: kernel size of the first (``ConvLayer``) stage.
        stride: stride of the second (depth-wise) stage.

    Returns:
        ``nn.Sequential`` with sublayers named ``layer1`` and ``layer2``.
    """
    first = ConvLayer(in_channels, out_channels, kernel_size=kernel_size)
    second = DWConvLayer(out_channels, out_channels, stride=stride)
    return nn.Sequential(('layer1', first), ('layer2', second))


class HarDBlock(nn.Layer):
    """Harmonic densely connected block.

    Layer ``k`` (1-based; index 0 is the block input) is fed the concatenation
    of the outputs of layers ``k - 2**i`` for every ``2**i`` (``i < 10``) that
    divides ``k``. The block output concatenates, along axis 1, the block
    input (only if ``keepBase``), every odd-indexed layer output and the last
    layer output.

    Args:
        in_channels: channels of the block input.
        growth_rate: base number of output channels of a layer.
        grmul: factor applied to a layer's width once per extra link.
        n_layers: number of conv layers in the block.
        keepBase: also include the block input in the output.
        residual_out: accepted for signature compatibility; not used here.
        dwconv: build each layer as ``CombConvLayer`` instead of ``ConvLayer``.

    Attributes:
        out_channels: channel count of the tensor returned by ``forward``.
        links: per layer, the indices of the earlier outputs it consumes.
    """

    def __init__(self,
                 in_channels,
                 growth_rate,
                 grmul,
                 n_layers,
                 keepBase=False,
                 residual_out=False,
                 dwconv=False):
        super().__init__()
        self.keepBase = keepBase
        self.links = []
        layers_ = []
        # forward() emits the block input as well when keepBase is set, so its
        # channels must be part of the advertised output width.
        self.out_channels = in_channels if keepBase else 0
        layer_cls = CombConvLayer if dwconv else ConvLayer
        for i in range(n_layers):
            outch, inch, link = self.get_link(i + 1, in_channels, growth_rate,
                                              grmul)
            self.links.append(link)
            layers_.append(layer_cls(inch, outch))

            # Layer i + 1 is emitted by forward() when its index is odd or it
            # is the last layer; keep this in sync with forward().
            if (i % 2 == 0) or (i == n_layers - 1):
                self.out_channels += outch
        self.layers = nn.LayerList(layers_)

    def get_link(self, layer, base_ch, growth_rate, grmul):
        """Return ``(out_channels, in_channels, link)`` for ``layer``.

        ``link`` lists the indices of earlier outputs (0 = block input) that
        are concatenated to form this layer's input, and ``in_channels`` is
        the sum of their widths. Index 0 denotes the block input itself and
        yields ``(base_ch, 0, [])``.
        """
        if layer == 0:
            return base_ch, 0, []
        out_channels = growth_rate

        link = []
        for i in range(10):
            dv = 2**i
            if layer % dv == 0:
                link.append(layer - dv)
                if i > 0:
                    out_channels *= grmul

        # grmul may be fractional; convert the width to an even integer.
        out_channels = int(int(out_channels + 1) / 2) * 2
        in_channels = sum(
            self.get_link(k, base_ch, growth_rate, grmul)[0] for k in link)

        return out_channels, in_channels, link

    def forward(self, x):
        outputs = [x]

        for layer, link in zip(self.layers, self.links):
            inputs = [outputs[k] for k in link]
            if len(inputs) > 1:
                x = paddle.concat(inputs, 1)
            else:
                x = inputs[0]
            outputs.append(layer(x))

        last = len(outputs) - 1
        selected = [
            feat for idx, feat in enumerate(outputs)
            if (idx == 0 and self.keepBase) or idx == last or idx % 2 == 1
        ]
        return paddle.concat(selected, 1)


class HarDNet(nn.Layer):
    """HarDNet classification network (HarDNet39 / 68 / 85, optionally "_ds").

    All sublayers are appended to ``self.base`` (an ``nn.LayerList``) and run
    sequentially in ``forward``.

    NOTE(review): the LayerList indices become part of the parameter names,
    so the append order below presumably has to stay in sync with the
    pretrained checkpoints in MODEL_URLS. Do not reorder.

    Args:
        depth_wise: build the depth-wise variant. This uses a 1x1 second stem
            conv, DWConvLayer downsampling instead of max pooling, CombConvLayer
            inside the blocks, and dropout 0.05.
        arch: one of 39, 68 or 85.
        class_num: output classes; when <= 0 no Flatten/Dropout/Linear head is
            added.
        with_pool: append ``AdaptiveAvgPool2D((1, 1))`` before the head.
            NOTE(review): with ``with_pool=False`` and ``class_num > 0`` the
            Linear layer expects exactly ``ch`` flattened features, which
            presumably only holds for a 1x1 feature map.

    Raises:
        ValueError: if ``arch`` is not a supported depth. Previously any other
            value silently built the HarDNet68 configuration.
    """

    # Per-architecture settings:
    #   first_ch  - output channels of the two stem convolutions
    #   ch_list   - output channels of the 1x1 transition conv after each block
    #   gr        - growth rate of each HarDBlock
    #   n_layers  - number of conv layers in each HarDBlock
    #   down_samp - 1 if a stride-2 downsampling follows that stage
    #   grmul     - growth multiplier passed to every HarDBlock
    #   drop_rate - classifier dropout (replaced by 0.05 when depth_wise)
    _ARCH_CONFIGS = {
        39: {
            "first_ch": (24, 48),
            "ch_list": (96, 320, 640, 1024),
            "gr": (16, 20, 64, 160),
            "n_layers": (4, 16, 8, 4),
            "down_samp": (1, 1, 1, 0),
            "grmul": 1.6,
            "drop_rate": 0.1,
        },
        68: {
            "first_ch": (32, 64),
            "ch_list": (128, 256, 320, 640, 1024),
            "gr": (14, 16, 20, 40, 160),
            "n_layers": (8, 16, 16, 16, 4),
            "down_samp": (1, 0, 1, 1, 0),
            "grmul": 1.7,
            "drop_rate": 0.1,
        },
        85: {
            "first_ch": (48, 96),
            "ch_list": (192, 256, 320, 480, 720, 1280),
            "gr": (24, 24, 28, 36, 48, 256),
            "n_layers": (8, 16, 16, 16, 16, 4),
            "down_samp": (1, 0, 1, 0, 1, 0),
            "grmul": 1.7,
            "drop_rate": 0.2,
        },
    }

    def __init__(self,
                 depth_wise=False,
                 arch=85,
                 class_num=1000,
                 with_pool=True):
        super().__init__()
        if arch not in self._ARCH_CONFIGS:
            raise ValueError(
                "Unsupported HarDNet arch: {}. Expected one of {}.".format(
                    arch, sorted(self._ARCH_CONFIGS)))
        cfg = self._ARCH_CONFIGS[arch]
        first_ch = cfg["first_ch"]
        grmul = cfg["grmul"]
        second_kernel = 1 if depth_wise else 3
        max_pool = not depth_wise
        drop_rate = 0.05 if depth_wise else cfg["drop_rate"]

        self.base = nn.LayerList([])

        # First layer: standard 3x3 conv, stride 2.
        self.base.append(
            ConvLayer(
                in_channels=3,
                out_channels=first_ch[0],
                kernel_size=3,
                stride=2,
                bias_attr=False))

        # Second layer.
        self.base.append(
            ConvLayer(
                first_ch[0], first_ch[1], kernel_size=second_kernel))

        # Stem downsampling: max pooling or a stride-2 DWConv.
        if max_pool:
            self.base.append(nn.MaxPool2D(kernel_size=3, stride=2, padding=1))
        else:
            self.base.append(DWConvLayer(first_ch[1], first_ch[1], stride=2))

        # Stages: HarDBlock -> 1x1 transition conv -> optional downsampling.
        stages = list(
            zip(cfg["gr"], cfg["n_layers"], cfg["ch_list"], cfg["down_samp"]))
        ch = first_ch[1]
        for i, (growth, depth, trans_ch, down) in enumerate(stages):
            blk = HarDBlock(ch, growth, grmul, depth, dwconv=depth_wise)
            self.base.append(blk)

            # HarDNet85 applies extra dropout before the last transition conv.
            if i == len(stages) - 1 and arch == 85:
                self.base.append(nn.Dropout(0.1))

            self.base.append(
                ConvLayer(blk.out_channels, trans_ch, kernel_size=1))
            ch = trans_ch
            if down == 1:
                if max_pool:
                    self.base.append(nn.MaxPool2D(kernel_size=2, stride=2))
                else:
                    self.base.append(DWConvLayer(ch, ch, stride=2))

        # Head: optional global pooling, then optional classifier.
        head = []
        if with_pool:
            head.append(nn.AdaptiveAvgPool2D((1, 1)))
        if class_num > 0:
            head.append(nn.Flatten())
            head.append(nn.Dropout(drop_rate))
            head.append(nn.Linear(ch, class_num))
        self.base.append(nn.Sequential(*head))

    def forward(self, x):
        for layer in self.base:
            x = layer(x)
        return x


def _load_pretrained(pretrained, model, model_url, use_ssld=False):
    """Optionally load pretrained weights into ``model`` in place.

    Args:
        pretrained: ``False`` to skip loading, ``True`` to fetch the weights
            from ``model_url``, or a ``str`` path to local weights. Only the
            ``True``/``False`` singletons count as booleans (identity checks).
        model: the network receiving the weights.
        model_url: source used when ``pretrained is True``.
        use_ssld: forwarded to ``load_dygraph_pretrain_from_url``.

    Raises:
        RuntimeError: for any other ``pretrained`` value.
    """
    if pretrained is True:
        load_dygraph_pretrain_from_url(model, model_url, use_ssld=use_ssld)
        return
    if isinstance(pretrained, str):
        load_dygraph_pretrain(model, pretrained)
        return
    if pretrained is not False:
        raise RuntimeError(
            "pretrained type is not available. Please use `string` or `boolean` type."
        )


def HarDNet39_ds(pretrained=False, use_ssld=False, **kwargs):
    """Build HarDNet39 with depth-wise convolutions ("_ds").

    Args:
        pretrained: ``False``, ``True`` (use MODEL_URLS["HarDNet39_ds"]) or a
            local weight path.
        use_ssld: forwarded to ``_load_pretrained``. Before this parameter
            existed it leaked into ``**kwargs`` and made ``HarDNet`` raise
            TypeError.
        **kwargs: forwarded to ``HarDNet`` (e.g. ``class_num``, ``with_pool``).
    """
    model = HarDNet(arch=39, depth_wise=True, **kwargs)
    _load_pretrained(
        pretrained, model, MODEL_URLS["HarDNet39_ds"], use_ssld=use_ssld)
    return model


def HarDNet68_ds(pretrained=False, use_ssld=False, **kwargs):
    """Build HarDNet68 with depth-wise convolutions ("_ds").

    Args:
        pretrained: ``False``, ``True`` (use MODEL_URLS["HarDNet68_ds"]) or a
            local weight path.
        use_ssld: forwarded to ``_load_pretrained``. Before this parameter
            existed it leaked into ``**kwargs`` and made ``HarDNet`` raise
            TypeError.
        **kwargs: forwarded to ``HarDNet`` (e.g. ``class_num``, ``with_pool``).
    """
    model = HarDNet(arch=68, depth_wise=True, **kwargs)
    _load_pretrained(
        pretrained, model, MODEL_URLS["HarDNet68_ds"], use_ssld=use_ssld)
    return model


def HarDNet68(pretrained=False, use_ssld=False, **kwargs):
    """Build HarDNet68.

    Args:
        pretrained: ``False``, ``True`` (use MODEL_URLS["HarDNet68"]) or a
            local weight path.
        use_ssld: forwarded to ``_load_pretrained``. Before this parameter
            existed it leaked into ``**kwargs`` and made ``HarDNet`` raise
            TypeError.
        **kwargs: forwarded to ``HarDNet`` (e.g. ``class_num``, ``with_pool``).
    """
    model = HarDNet(arch=68, **kwargs)
    _load_pretrained(
        pretrained, model, MODEL_URLS["HarDNet68"], use_ssld=use_ssld)
    return model


def HarDNet85(pretrained=False, use_ssld=False, **kwargs):
    """Build HarDNet85.

    Args:
        pretrained: ``False``, ``True`` (use MODEL_URLS["HarDNet85"]) or a
            local weight path.
        use_ssld: forwarded to ``_load_pretrained``. Before this parameter
            existed it leaked into ``**kwargs`` and made ``HarDNet`` raise
            TypeError.
        **kwargs: forwarded to ``HarDNet`` (e.g. ``class_num``, ``with_pool``).
    """
    model = HarDNet(arch=85, **kwargs)
    _load_pretrained(
        pretrained, model, MODEL_URLS["HarDNet85"], use_ssld=use_ssld)
    return model