alexnet.py 5.7 KB
Newer Older
C
cuicheng01 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14
# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

G
gaotingquan 已提交
15 16
# reference: https://proceedings.neurips.cc/paper/2012/file/c399862d3b9d6b76c8436e924a68c45b-Paper.pdf

17
import paddle
littletomatodonkey's avatar
littletomatodonkey 已提交
18 19 20
from paddle import ParamAttr
import paddle.nn as nn
import paddle.nn.functional as F
21 22
from paddle.nn import Conv2D, BatchNorm, Linear, Dropout, ReLU
from paddle.nn import AdaptiveAvgPool2D, MaxPool2D, AvgPool2D
littletomatodonkey's avatar
littletomatodonkey 已提交
23
from paddle.nn.initializer import Uniform
W
WuHaobo 已提交
24
import math
25

R
root 已提交
26
from ....utils.save_load import load_dygraph_pretrain, load_dygraph_pretrain_from_url
C
cuicheng01 已提交
27

littletomatodonkey's avatar
littletomatodonkey 已提交
28 29 30 31
# Pretrained-weight download location for every model this module exports.
MODEL_URLS = {
    "AlexNet": (
        "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/"
        "AlexNet_pretrained.pdparams"
    ),
}

# The public API is exactly the set of models listed in MODEL_URLS.
__all__ = list(MODEL_URLS)
littletomatodonkey's avatar
littletomatodonkey 已提交
34

littletomatodonkey's avatar
littletomatodonkey 已提交
35

littletomatodonkey's avatar
littletomatodonkey 已提交
36 37
class ConvPoolLayer(nn.Layer):
    """Conv2D, an optional ReLU, then a 3x3 / stride-2 max-pool.

    Args:
        input_channels (int): channels of the incoming feature map.
        output_channels (int): number of convolution filters.
        filter_size (int): convolution kernel size.
        stride (int): convolution stride.
        padding (int): convolution padding.
        stdv (float): weights and biases are initialized from
            ``Uniform(-stdv, stdv)``.
        groups (int): number of convolution groups. Defaults to 1.
        act (str|None): ``"relu"`` applies a ReLU after the convolution;
            any other value skips the activation.
        name (str|None): prefix for the parameter names
            (``<name>_weights`` / ``<name>_offset``). When ``None`` the
            parameter names are left to Paddle to generate.
    """

    def __init__(self,
                 input_channels,
                 output_channels,
                 filter_size,
                 stride,
                 padding,
                 stdv,
                 groups=1,
                 act=None,
                 name=None):
        super(ConvPoolLayer, self).__init__()

        self.relu = ReLU() if act == "relu" else None

        # `name` defaults to None; concatenating a suffix onto None raises
        # TypeError, so only build explicit parameter names when one is given.
        weight_name = None if name is None else name + "_weights"
        bias_name = None if name is None else name + "_offset"

        self._conv = Conv2D(
            in_channels=input_channels,
            out_channels=output_channels,
            kernel_size=filter_size,
            stride=stride,
            padding=padding,
            groups=groups,
            weight_attr=ParamAttr(
                name=weight_name, initializer=Uniform(-stdv, stdv)),
            bias_attr=ParamAttr(
                name=bias_name, initializer=Uniform(-stdv, stdv)))
        self._pool = MaxPool2D(kernel_size=3, stride=2, padding=0)

    def forward(self, inputs):
        """Apply conv -> (optional) ReLU -> max-pool to ``inputs``."""
        x = self._conv(inputs)
        if self.relu is not None:
            x = self.relu(x)
        x = self._pool(x)
        return x


littletomatodonkey's avatar
littletomatodonkey 已提交
72
class AlexNetDY(nn.Layer):
    """AlexNet image classifier (Krizhevsky et al., 2012), dygraph version.

    Five convolution stages (conv1, conv2 and conv5 followed by max-pooling)
    feed three fully connected layers with dropout in front of fc6 and fc7.
    Every weight and bias is initialized from ``Uniform(-b, b)`` with
    ``b = 1 / sqrt(fan_in)``.

    fc6 expects a 256x6x6 feature map, which is what a 224x224 input
    produces after the five convolution stages.

    Args:
        class_num (int): width of the final fc8 layer. Defaults to 1000.
    """

    def __init__(self, class_num=1000):
        super(AlexNetDY, self).__init__()

        def inv_sqrt(fan_in):
            # Uniform init bound derived from a layer's fan-in.
            return 1.0 / math.sqrt(fan_in)

        def uniform_attrs(prefix, bound):
            # Explicitly named weight/bias ParamAttrs sharing one
            # Uniform(-bound, bound) range.
            return dict(
                weight_attr=ParamAttr(
                    name=prefix + "_weights",
                    initializer=Uniform(-bound, bound)),
                bias_attr=ParamAttr(
                    name=prefix + "_offset",
                    initializer=Uniform(-bound, bound)))

        # Feature extractor.
        self._conv1 = ConvPoolLayer(
            3, 64, 11, 4, 2, inv_sqrt(3 * 11 * 11), act="relu", name="conv1")
        self._conv2 = ConvPoolLayer(
            64, 192, 5, 1, 2, inv_sqrt(64 * 5 * 5), act="relu", name="conv2")
        self._conv3 = Conv2D(
            192, 384, 3, stride=1, padding=1,
            **uniform_attrs("conv3", inv_sqrt(192 * 3 * 3)))
        self._conv4 = Conv2D(
            384, 256, 3, stride=1, padding=1,
            **uniform_attrs("conv4", inv_sqrt(384 * 3 * 3)))
        self._conv5 = ConvPoolLayer(
            256, 256, 3, 1, 1, inv_sqrt(256 * 3 * 3), act="relu",
            name="conv5")

        # Classifier head: fc6, fc7 and fc8 all use the bound computed from
        # fc6's fan-in (256 * 6 * 6).
        # NOTE(review): fc7/fc8 have fan-in 4096, so a per-layer bound would
        # be 1/64; the shared bound is kept to reproduce the existing init.
        fc_bound = inv_sqrt(256 * 6 * 6)

        self._drop1 = Dropout(p=0.5, mode="downscale_in_infer")
        self._fc6 = Linear(
            in_features=256 * 6 * 6,
            out_features=4096,
            **uniform_attrs("fc6", fc_bound))

        self._drop2 = Dropout(p=0.5, mode="downscale_in_infer")
        self._fc7 = Linear(
            in_features=4096,
            out_features=4096,
            **uniform_attrs("fc7", fc_bound))
        self._fc8 = Linear(
            in_features=4096,
            out_features=class_num,
            **uniform_attrs("fc8", fc_bound))

    def forward(self, inputs):
        """Return the raw fc8 outputs (no softmax) for a batch of images."""
        x = self._conv2(self._conv1(inputs))
        # conv3/conv4 are plain Conv2D layers, so their ReLU is applied here.
        x = F.relu(self._conv3(x))
        x = F.relu(self._conv4(x))
        x = self._conv5(x)

        x = paddle.flatten(x, start_axis=1, stop_axis=-1)
        x = F.relu(self._fc6(self._drop1(x)))
        x = F.relu(self._fc7(self._drop2(x)))
        return self._fc8(x)
W
WuHaobo 已提交
151

littletomatodonkey's avatar
littletomatodonkey 已提交
152

C
cuicheng01 已提交
153 154 155 156 157 158 159 160 161 162 163
def _load_pretrained(pretrained, model, model_url, use_ssld=False):
    """Optionally load weights into ``model`` according to ``pretrained``.

    Args:
        pretrained (bool|str): ``False`` leaves ``model`` untouched; ``True``
            loads from ``model_url`` via ``load_dygraph_pretrain_from_url``;
            a string is forwarded to ``load_dygraph_pretrain`` (presumably a
            local weights path -- confirm against the loader).
        model: the network whose parameters are loaded.
        model_url (str): URL used when ``pretrained`` is ``True``.
        use_ssld (bool): forwarded to the URL loader.

    Raises:
        RuntimeError: if ``pretrained`` is neither a bool nor a string.
    """
    if pretrained is True:
        load_dygraph_pretrain_from_url(model, model_url, use_ssld=use_ssld)
        return
    if isinstance(pretrained, str):
        load_dygraph_pretrain(model, pretrained)
        return
    if pretrained is not False:
        raise RuntimeError(
            "pretrained type is not available. Please use `string` or `boolean` type."
        )
littletomatodonkey's avatar
littletomatodonkey 已提交
164

littletomatodonkey's avatar
littletomatodonkey 已提交
165

C
cuicheng01 已提交
166 167
def AlexNet(pretrained=False, use_ssld=False, **kwargs):
    """Build an ``AlexNetDY`` and optionally load pretrained weights.

    Args:
        pretrained (bool|str): ``False`` for no loading, ``True`` to load
            from ``MODEL_URLS["AlexNet"]``, or a string handed to the local
            checkpoint loader.
        use_ssld (bool): forwarded to the URL-based loader.
        **kwargs: forwarded to ``AlexNetDY`` (e.g. ``class_num``).

    Returns:
        AlexNetDY: the constructed network.
    """
    net = AlexNetDY(**kwargs)
    weights_url = MODEL_URLS["AlexNet"]
    _load_pretrained(pretrained, net, weights_url, use_ssld=use_ssld)
    return net