alexnet.py 5.6 KB
Newer Older
C
cuicheng01 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14
# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

15
import paddle
littletomatodonkey's avatar
littletomatodonkey 已提交
16 17 18
from paddle import ParamAttr
import paddle.nn as nn
import paddle.nn.functional as F
19 20
from paddle.nn import Conv2D, BatchNorm, Linear, Dropout, ReLU
from paddle.nn import AdaptiveAvgPool2D, MaxPool2D, AvgPool2D
littletomatodonkey's avatar
littletomatodonkey 已提交
21
from paddle.nn.initializer import Uniform
W
WuHaobo 已提交
22
import math
23

C
cuicheng01 已提交
24 25
from ppcls.utils.save_load import load_dygraph_pretrain, load_dygraph_pretrain_from_url

littletomatodonkey's avatar
littletomatodonkey 已提交
26 27 28 29
# Download locations of the pretrained weights, keyed by the name of the
# model-builder function exported from this module.
MODEL_URLS = dict(
    AlexNet="https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/AlexNet_pretrained.pdparams",
)
W
wqz960 已提交
30

C
cuicheng01 已提交
31
__all__ = list(MODEL_URLS)  # every model with a pretrained URL is public API
littletomatodonkey's avatar
littletomatodonkey 已提交
32

littletomatodonkey's avatar
littletomatodonkey 已提交
33

littletomatodonkey's avatar
littletomatodonkey 已提交
34 35
class ConvPoolLayer(nn.Layer):
    """Conv2D, then an optional ReLU, then a 3x3 / stride-2 max-pool.

    Args:
        input_channels (int): number of input channels of the convolution.
        output_channels (int): number of convolution filters.
        filter_size (int): convolution kernel size.
        stride (int): convolution stride.
        padding (int): convolution padding.
        stdv (float): both the weight and the bias are initialized from
            ``Uniform(-stdv, stdv)``.
        groups (int): number of convolution groups. Default: 1.
        act (str|None): ``"relu"`` inserts a ReLU after the convolution;
            any other value (including ``None``) applies no activation.
        name (str|None): prefix for the parameter names
            (``<name>_weights`` / ``<name>_offset``). When ``None``,
            Paddle chooses the parameter names itself.
    """

    def __init__(self,
                 input_channels,
                 output_channels,
                 filter_size,
                 stride,
                 padding,
                 stdv,
                 groups=1,
                 act=None,
                 name=None):
        super(ConvPoolLayer, self).__init__()

        self.relu = ReLU() if act == "relu" else None

        # ``name`` defaults to None, and ``None + "_weights"`` would raise a
        # TypeError. Only build explicit parameter names when a prefix is given.
        weight_name = None if name is None else name + "_weights"
        bias_name = None if name is None else name + "_offset"

        self._conv = Conv2D(
            in_channels=input_channels,
            out_channels=output_channels,
            kernel_size=filter_size,
            stride=stride,
            padding=padding,
            groups=groups,
            weight_attr=ParamAttr(
                name=weight_name, initializer=Uniform(-stdv, stdv)),
            bias_attr=ParamAttr(
                name=bias_name, initializer=Uniform(-stdv, stdv)))
        self._pool = MaxPool2D(kernel_size=3, stride=2, padding=0)

    def forward(self, inputs):
        """Apply conv, the optional ReLU, then max-pooling to ``inputs``."""
        x = self._conv(inputs)
        if self.relu is not None:
            x = self.relu(x)
        x = self._pool(x)
        return x


littletomatodonkey's avatar
littletomatodonkey 已提交
70
class AlexNetDY(nn.Layer):
    """AlexNet for dynamic-graph mode: five conv stages plus three FC layers.

    Every convolution and linear layer initializes its weight and bias from
    ``Uniform(-b, b)`` with ``b = 1 / sqrt(fan)``. ``fan`` is written next
    to each layer below.

    Args:
        class_num (int): number of logits produced by the last linear
            layer. Default: 1000.
    """

    def __init__(self, class_num=1000):
        super(AlexNetDY, self).__init__()

        def uniform_attrs(prefix, fan):
            # Named weight/bias ParamAttr pair sharing one uniform range
            # of half-width 1 / sqrt(fan).
            bound = 1.0 / math.sqrt(fan)
            return dict(
                weight_attr=ParamAttr(
                    name=prefix + "_weights",
                    initializer=Uniform(-bound, bound)),
                bias_attr=ParamAttr(
                    name=prefix + "_offset",
                    initializer=Uniform(-bound, bound)))

        # Feature extractor. Layers are created in forward order, so parameter
        # creation order (and attribute / state_dict names) stay fixed.
        self._conv1 = ConvPoolLayer(
            3, 64, 11, 4, 2, 1.0 / math.sqrt(3 * 11 * 11),
            act="relu", name="conv1")
        self._conv2 = ConvPoolLayer(
            64, 192, 5, 1, 2, 1.0 / math.sqrt(64 * 5 * 5),
            act="relu", name="conv2")
        self._conv3 = Conv2D(
            in_channels=192,
            out_channels=384,
            kernel_size=3,
            stride=1,
            padding=1,
            **uniform_attrs("conv3", 192 * 3 * 3))
        self._conv4 = Conv2D(
            in_channels=384,
            out_channels=256,
            kernel_size=3,
            stride=1,
            padding=1,
            **uniform_attrs("conv4", 384 * 3 * 3))
        self._conv5 = ConvPoolLayer(
            256, 256, 3, 1, 1, 1.0 / math.sqrt(256 * 3 * 3),
            act="relu", name="conv5")

        # Classifier. _fc6 expects 256 * 6 * 6 flattened features, so conv5
        # must output a 6x6 map. Presumably this means 224x224 inputs;
        # TODO confirm.
        # NOTE(review): fc7 and fc8 have fan-in 4096 but reuse the fc6 range
        # (fan 256 * 6 * 6). This is kept as-is to preserve initialization.
        fc_fan = 256 * 6 * 6
        self._drop1 = Dropout(p=0.5, mode="downscale_in_infer")
        self._fc6 = Linear(
            in_features=fc_fan,
            out_features=4096,
            **uniform_attrs("fc6", fc_fan))
        self._drop2 = Dropout(p=0.5, mode="downscale_in_infer")
        self._fc7 = Linear(
            in_features=4096,
            out_features=4096,
            **uniform_attrs("fc7", fc_fan))
        self._fc8 = Linear(
            in_features=4096,
            out_features=class_num,
            **uniform_attrs("fc8", fc_fan))

    def forward(self, inputs):
        """Return the class logits for ``inputs``."""
        # conv1/conv2: conv + relu + pool. conv3/conv4: conv + relu.
        # conv5: conv + relu + pool.
        feat = self._conv2(self._conv1(inputs))
        feat = F.relu(self._conv3(feat))
        feat = F.relu(self._conv4(feat))
        feat = self._conv5(feat)
        # Keep the batch axis and flatten everything else.
        feat = paddle.flatten(feat, start_axis=1, stop_axis=-1)
        # Two dropout -> linear -> relu stages, then the logits layer.
        feat = F.relu(self._fc6(self._drop1(feat)))
        feat = F.relu(self._fc7(self._drop2(feat)))
        return self._fc8(feat)
W
WuHaobo 已提交
149

littletomatodonkey's avatar
littletomatodonkey 已提交
150

C
cuicheng01 已提交
151 152 153 154 155 156 157 158 159 160 161
def _load_pretrained(pretrained, model, model_url, use_ssld=False):
    if pretrained is False:
        pass
    elif pretrained is True:
        load_dygraph_pretrain_from_url(model, model_url, use_ssld=use_ssld)
    elif isinstance(pretrained, str):
        load_dygraph_pretrain(model, pretrained)
    else:
        raise RuntimeError(
            "pretrained type is not available. Please use `string` or `boolean` type."
        )
littletomatodonkey's avatar
littletomatodonkey 已提交
162

littletomatodonkey's avatar
littletomatodonkey 已提交
163

C
cuicheng01 已提交
164 165
def AlexNet(pretrained=False, use_ssld=False, **kwargs):
    """Build an ``AlexNetDY`` and optionally load pretrained weights into it.

    Args:
        pretrained (bool|str): ``False`` for no weights, ``True`` to download
            from ``MODEL_URLS["AlexNet"]``, or a string handed to the
            local-weights loader.
        use_ssld (bool): forwarded to the URL-based weight loader.
        **kwargs: passed unchanged to ``AlexNetDY`` (e.g. ``class_num``).

    Returns:
        The constructed ``AlexNetDY`` instance.
    """
    net = AlexNetDY(**kwargs)
    url = MODEL_URLS["AlexNet"]
    _load_pretrained(pretrained, net, url, use_ssld=use_ssld)
    return net