未验证 提交 8937205b 编写于 作者: N Nyakku Shigure 提交者: GitHub

add googlenet (#36034)

* update AvgPool2D to AdaptiveAvgPool2D
* class_num -> num_classes
* add en doc
* add googlenet to pretrained test
* remove weights name
* add parameter with_pool
* update en doc
* fix googlenet out shape
* 2020 -> 2021
Co-authored-by: Ainavo's avatarAinavo <ainavo@163.com>
Co-authored-by: Npithygit <pyg20200403@163.com>
Co-authored-by: Ainavo's avatarAinavo <ainavo@163.com>
Co-authored-by: Npithygit <pyg20200403@163.com>
上级 442688a8
......@@ -54,7 +54,7 @@ class TestPretrainedModel(unittest.TestCase):
def test_models(self):
arches = [
'mobilenet_v1', 'mobilenet_v2', 'resnet18', 'vgg16', 'alexnet',
'resnext50_32x4d', 'inception_v3', 'densenet121'
'resnext50_32x4d', 'inception_v3', 'densenet121', 'googlenet'
]
for arch in arches:
self.infer(arch)
......
......@@ -109,6 +109,9 @@ class TestVisonModels(unittest.TestCase):
def test_inception_v3(self):
self.models_infer('inception_v3')
def test_googlenet(self):
self.models_infer('googlenet')
def test_vgg16_num_classes(self):
vgg16 = models.__dict__['vgg16'](pretrained=False, num_classes=10)
......
......@@ -61,6 +61,8 @@ from .models import resnext152_32x4d # noqa: F401
from .models import resnext152_64x4d # noqa: F401
from .models import InceptionV3 # noqa: F401
from .models import inception_v3 # noqa: F401
from .models import GoogLeNet # noqa: F401
from .models import googlenet # noqa: F401
from .transforms import BaseTransform # noqa: F401
from .transforms import Compose # noqa: F401
from .transforms import Resize # noqa: F401
......
......@@ -45,6 +45,8 @@ from .resnext import resnext152_32x4d # noqa: F401
from .resnext import resnext152_64x4d # noqa: F401
from .inceptionv3 import InceptionV3 # noqa: F401
from .inceptionv3 import inception_v3 # noqa: F401
from .googlenet import GoogLeNet # noqa: F401
from .googlenet import googlenet # noqa: F401
__all__ = [ #noqa
'ResNet',
......@@ -79,5 +81,7 @@ __all__ = [ #noqa
'resnext152_32x4d',
'resnext152_64x4d',
'InceptionV3',
'inception_v3'
'inception_v3',
'GoogLeNet',
'googlenet',
]
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import division
from __future__ import print_function
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from paddle.nn import Conv2D, Linear, Dropout
from paddle.nn import MaxPool2D, AvgPool2D, AdaptiveAvgPool2D
from paddle.nn.initializer import Uniform
from paddle.fluid.param_attr import ParamAttr
from paddle.utils.download import get_weights_path_from_url
__all__ = []
model_urls = {
"googlenet":
("https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/GoogLeNet_pretrained.pdparams",
"80c06f038e905c53ab32c40eca6e26ae")
}
def xavier(channels, filter_size):
stdv = (3.0 / (filter_size**2 * channels))**0.5
param_attr = ParamAttr(initializer=Uniform(-stdv, stdv))
return param_attr
class ConvLayer(nn.Layer):
def __init__(self,
num_channels,
num_filters,
filter_size,
stride=1,
groups=1):
super(ConvLayer, self).__init__()
self._conv = Conv2D(
in_channels=num_channels,
out_channels=num_filters,
kernel_size=filter_size,
stride=stride,
padding=(filter_size - 1) // 2,
groups=groups,
bias_attr=False)
def forward(self, inputs):
y = self._conv(inputs)
return y
class Inception(nn.Layer):
def __init__(self, input_channels, output_channels, filter1, filter3R,
filter3, filter5R, filter5, proj):
super(Inception, self).__init__()
self._conv1 = ConvLayer(input_channels, filter1, 1)
self._conv3r = ConvLayer(input_channels, filter3R, 1)
self._conv3 = ConvLayer(filter3R, filter3, 3)
self._conv5r = ConvLayer(input_channels, filter5R, 1)
self._conv5 = ConvLayer(filter5R, filter5, 5)
self._pool = MaxPool2D(kernel_size=3, stride=1, padding=1)
self._convprj = ConvLayer(input_channels, proj, 1)
def forward(self, inputs):
conv1 = self._conv1(inputs)
conv3r = self._conv3r(inputs)
conv3 = self._conv3(conv3r)
conv5r = self._conv5r(inputs)
conv5 = self._conv5(conv5r)
pool = self._pool(inputs)
convprj = self._convprj(pool)
cat = paddle.concat([conv1, conv3, conv5, convprj], axis=1)
cat = F.relu(cat)
return cat
class GoogLeNet(nn.Layer):
"""GoogLeNet (Inception v1) model architecture from
`"Going Deeper with Convolutions" <https://arxiv.org/pdf/1409.4842.pdf>`_
Args:
num_classes (int): output dim of last fc layer. If num_classes <=0, last fc layer
will not be defined. Default: 1000.
with_pool (bool, optional): use pool before the last fc layer or not. Default: True.
Examples:
.. code-block:: python
import paddle
from paddle.vision.models import GoogLeNet
# build model
model = GoogLeNet()
x = paddle.rand([1, 3, 224, 224])
out, out1, out2 = model(x)
print(out.shape)
"""
def __init__(self, num_classes=1000, with_pool=True):
super(GoogLeNet, self).__init__()
self.num_classes = num_classes
self.with_pool = with_pool
self._conv = ConvLayer(3, 64, 7, 2)
self._pool = MaxPool2D(kernel_size=3, stride=2)
self._conv_1 = ConvLayer(64, 64, 1)
self._conv_2 = ConvLayer(64, 192, 3)
self._ince3a = Inception(192, 192, 64, 96, 128, 16, 32, 32)
self._ince3b = Inception(256, 256, 128, 128, 192, 32, 96, 64)
self._ince4a = Inception(480, 480, 192, 96, 208, 16, 48, 64)
self._ince4b = Inception(512, 512, 160, 112, 224, 24, 64, 64)
self._ince4c = Inception(512, 512, 128, 128, 256, 24, 64, 64)
self._ince4d = Inception(512, 512, 112, 144, 288, 32, 64, 64)
self._ince4e = Inception(528, 528, 256, 160, 320, 32, 128, 128)
self._ince5a = Inception(832, 832, 256, 160, 320, 32, 128, 128)
self._ince5b = Inception(832, 832, 384, 192, 384, 48, 128, 128)
if with_pool:
# out
self._pool_5 = AdaptiveAvgPool2D(1)
# out1
self._pool_o1 = AvgPool2D(kernel_size=5, stride=3)
# out2
self._pool_o2 = AvgPool2D(kernel_size=5, stride=3)
if num_classes > 0:
# out
self._drop = Dropout(p=0.4, mode="downscale_in_infer")
self._fc_out = Linear(
1024, num_classes, weight_attr=xavier(1024, 1))
# out1
self._conv_o1 = ConvLayer(512, 128, 1)
self._fc_o1 = Linear(1152, 1024, weight_attr=xavier(2048, 1))
self._drop_o1 = Dropout(p=0.7, mode="downscale_in_infer")
self._out1 = Linear(1024, num_classes, weight_attr=xavier(1024, 1))
# out2
self._conv_o2 = ConvLayer(528, 128, 1)
self._fc_o2 = Linear(1152, 1024, weight_attr=xavier(2048, 1))
self._drop_o2 = Dropout(p=0.7, mode="downscale_in_infer")
self._out2 = Linear(1024, num_classes, weight_attr=xavier(1024, 1))
def forward(self, inputs):
x = self._conv(inputs)
x = self._pool(x)
x = self._conv_1(x)
x = self._conv_2(x)
x = self._pool(x)
x = self._ince3a(x)
x = self._ince3b(x)
x = self._pool(x)
ince4a = self._ince4a(x)
x = self._ince4b(ince4a)
x = self._ince4c(x)
ince4d = self._ince4d(x)
x = self._ince4e(ince4d)
x = self._pool(x)
x = self._ince5a(x)
ince5b = self._ince5b(x)
out, out1, out2 = ince5b, ince4a, ince4d
if self.with_pool:
out = self._pool_5(out)
out1 = self._pool_o1(out1)
out2 = self._pool_o2(out2)
if self.num_classes > 0:
out = self._drop(out)
out = paddle.squeeze(out, axis=[2, 3])
out = self._fc_out(out)
out1 = self._conv_o1(out1)
out1 = paddle.flatten(out1, start_axis=1, stop_axis=-1)
out1 = self._fc_o1(out1)
out1 = F.relu(out1)
out1 = self._drop_o1(out1)
out1 = self._out1(out1)
out2 = self._conv_o2(out2)
out2 = paddle.flatten(out2, start_axis=1, stop_axis=-1)
out2 = self._fc_o2(out2)
out2 = self._drop_o2(out2)
out2 = self._out2(out2)
return [out, out1, out2]
def googlenet(pretrained=False, **kwargs):
"""GoogLeNet (Inception v1) model architecture from
`"Going Deeper with Convolutions" <https://arxiv.org/pdf/1409.4842.pdf>`_
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
Examples:
.. code-block:: python
import paddle
from paddle.vision.models import googlenet
# build model
model = googlenet()
# build model and load imagenet pretrained weight
# model = googlenet(pretrained=True)
x = paddle.rand([1, 3, 224, 224])
out, out1, out2 = model(x)
print(out.shape)
"""
model = GoogLeNet(**kwargs)
arch = "googlenet"
if pretrained:
assert (
arch in model_urls
), "{} model do not have a pretrained model now, you should set pretrained=False".format(
arch)
weight_path = get_weights_path_from_url(model_urls[arch][0],
model_urls[arch][1])
param = paddle.load(weight_path)
model.set_dict(param)
return model
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册