未验证 提交 01a14e1b 编写于 作者: L LielinJiang 提交者: GitHub

Add with_pool args for vgg (#28684)

* add arg for vgg
上级 532e4bbf
...@@ -245,7 +245,7 @@ class ResNet(nn.Layer): ...@@ -245,7 +245,7 @@ class ResNet(nn.Layer):
x = self.layer3(x) x = self.layer3(x)
x = self.layer4(x) x = self.layer4(x)
if self.with_pool > 0: if self.with_pool:
x = self.avgpool(x) x = self.avgpool(x)
if self.num_classes > 0: if self.num_classes > 0:
......
...@@ -36,9 +36,10 @@ class VGG(nn.Layer): ...@@ -36,9 +36,10 @@ class VGG(nn.Layer):
`"Very Deep Convolutional Networks For Large-Scale Image Recognition" <https://arxiv.org/pdf/1409.1556.pdf>`_ `"Very Deep Convolutional Networks For Large-Scale Image Recognition" <https://arxiv.org/pdf/1409.1556.pdf>`_
Args: Args:
features (nn.Layer): vgg features create by function make_layers. features (nn.Layer): Vgg features create by function make_layers.
num_classes (int): output dim of last fc layer. If num_classes <=0, last fc layer num_classes (int): Output dim of last fc layer. If num_classes <=0, last fc layer
will not be defined. Default: 1000. will not be defined. Default: 1000.
with_pool (bool): Use pool before the last three fc layer or not. Default: True.
Examples: Examples:
.. code-block:: python .. code-block:: python
...@@ -54,10 +55,16 @@ class VGG(nn.Layer): ...@@ -54,10 +55,16 @@ class VGG(nn.Layer):
""" """
def __init__(self, features, num_classes=1000): def __init__(self, features, num_classes=1000, with_pool=True):
super(VGG, self).__init__() super(VGG, self).__init__()
self.features = features self.features = features
self.num_classes = num_classes
self.with_pool = with_pool
if with_pool:
self.avgpool = nn.AdaptiveAvgPool2D((7, 7)) self.avgpool = nn.AdaptiveAvgPool2D((7, 7))
if num_classes > 0:
self.classifier = nn.Sequential( self.classifier = nn.Sequential(
nn.Linear(512 * 7 * 7, 4096), nn.Linear(512 * 7 * 7, 4096),
nn.ReLU(), nn.ReLU(),
...@@ -69,9 +76,14 @@ class VGG(nn.Layer): ...@@ -69,9 +76,14 @@ class VGG(nn.Layer):
def forward(self, x): def forward(self, x):
x = self.features(x) x = self.features(x)
if self.with_pool:
x = self.avgpool(x) x = self.avgpool(x)
if self.num_classes > 0:
x = paddle.flatten(x, 1) x = paddle.flatten(x, 1)
x = self.classifier(x) x = self.classifier(x)
return x return x
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册