From 01a14e1be209b3300be1f36a27152cfd429533a4 Mon Sep 17 00:00:00 2001 From: LielinJiang <50691816+LielinJiang@users.noreply.github.com> Date: Wed, 18 Nov 2020 15:26:50 +0800 Subject: [PATCH] Add with_pool args for vgg (#28684) * add arg for vgg --- python/paddle/vision/models/resnet.py | 2 +- python/paddle/vision/models/vgg.py | 42 +++++++++++++++++---------- 2 files changed, 28 insertions(+), 16 deletions(-) diff --git a/python/paddle/vision/models/resnet.py b/python/paddle/vision/models/resnet.py index 8cf797f1719..1f44e0bc6df 100644 --- a/python/paddle/vision/models/resnet.py +++ b/python/paddle/vision/models/resnet.py @@ -245,7 +245,7 @@ class ResNet(nn.Layer): x = self.layer3(x) x = self.layer4(x) - if self.with_pool > 0: + if self.with_pool: x = self.avgpool(x) if self.num_classes > 0: diff --git a/python/paddle/vision/models/vgg.py b/python/paddle/vision/models/vgg.py index 00f6cccbdfe..f6b4c75e84f 100644 --- a/python/paddle/vision/models/vgg.py +++ b/python/paddle/vision/models/vgg.py @@ -36,9 +36,10 @@ class VGG(nn.Layer): `"Very Deep Convolutional Networks For Large-Scale Image Recognition" `_ Args: - features (nn.Layer): vgg features create by function make_layers. - num_classes (int): output dim of last fc layer. If num_classes <=0, last fc layer + features (nn.Layer): Vgg features create by function make_layers. + num_classes (int): Output dim of last fc layer. If num_classes <=0, last fc layer will not be defined. Default: 1000. + with_pool (bool): Use pool before the last three fc layer or not. Default: True. Examples: .. code-block:: python @@ -54,24 +55,35 @@ class VGG(nn.Layer): """ - def __init__(self, features, num_classes=1000): + def __init__(self, features, num_classes=1000, with_pool=True): super(VGG, self).__init__() self.features = features - self.avgpool = nn.AdaptiveAvgPool2D((7, 7)) - self.classifier = nn.Sequential( - nn.Linear(512 * 7 * 7, 4096), - nn.ReLU(), - nn.Dropout(), - nn.Linear(4096, 4096), - nn.ReLU(), - nn.Dropout(), - nn.Linear(4096, num_classes), ) + self.num_classes = num_classes + self.with_pool = with_pool + + if with_pool: + self.avgpool = nn.AdaptiveAvgPool2D((7, 7)) + + if num_classes > 0: + self.classifier = nn.Sequential( + nn.Linear(512 * 7 * 7, 4096), + nn.ReLU(), + nn.Dropout(), + nn.Linear(4096, 4096), + nn.ReLU(), + nn.Dropout(), + nn.Linear(4096, num_classes), ) def forward(self, x): x = self.features(x) - x = self.avgpool(x) - x = paddle.flatten(x, 1) - x = self.classifier(x) + + if self.with_pool: + x = self.avgpool(x) + + if self.num_classes > 0: + x = paddle.flatten(x, 1) + x = self.classifier(x) + return x -- GitLab