diff --git a/python/paddle/vision/models/resnet.py b/python/paddle/vision/models/resnet.py index 8cf797f1719e992148537e66dc4736e96cf19f03..1f44e0bc6dfeb18cd1eb99489860500a390c33de 100644 --- a/python/paddle/vision/models/resnet.py +++ b/python/paddle/vision/models/resnet.py @@ -245,7 +245,7 @@ class ResNet(nn.Layer): x = self.layer3(x) x = self.layer4(x) - if self.with_pool > 0: + if self.with_pool: x = self.avgpool(x) if self.num_classes > 0: diff --git a/python/paddle/vision/models/vgg.py b/python/paddle/vision/models/vgg.py index 00f6cccbdfe9f11c0f234e923383ecc566958a33..f6b4c75e84f01379264fb2066b218747204fd6da 100644 --- a/python/paddle/vision/models/vgg.py +++ b/python/paddle/vision/models/vgg.py @@ -36,9 +36,10 @@ class VGG(nn.Layer): `"Very Deep Convolutional Networks For Large-Scale Image Recognition" `_ Args: - features (nn.Layer): vgg features create by function make_layers. - num_classes (int): output dim of last fc layer. If num_classes <=0, last fc layer + features (nn.Layer): Vgg features create by function make_layers. + num_classes (int): Output dim of last fc layer. If num_classes <=0, last fc layer will not be defined. Default: 1000. + with_pool (bool): Use pool before the last three fc layer or not. Default: True. Examples: .. code-block:: python @@ -54,24 +55,35 @@ class VGG(nn.Layer): """ - def __init__(self, features, num_classes=1000): + def __init__(self, features, num_classes=1000, with_pool=True): super(VGG, self).__init__() self.features = features - self.avgpool = nn.AdaptiveAvgPool2D((7, 7)) - self.classifier = nn.Sequential( - nn.Linear(512 * 7 * 7, 4096), - nn.ReLU(), - nn.Dropout(), - nn.Linear(4096, 4096), - nn.ReLU(), - nn.Dropout(), - nn.Linear(4096, num_classes), ) + self.num_classes = num_classes + self.with_pool = with_pool + + if with_pool: + self.avgpool = nn.AdaptiveAvgPool2D((7, 7)) + + if num_classes > 0: + self.classifier = nn.Sequential( + nn.Linear(512 * 7 * 7, 4096), + nn.ReLU(), + nn.Dropout(), + nn.Linear(4096, 4096), + nn.ReLU(), + nn.Dropout(), + nn.Linear(4096, num_classes), ) def forward(self, x): x = self.features(x) - x = self.avgpool(x) - x = paddle.flatten(x, 1) - x = self.classifier(x) + + if self.with_pool: + x = self.avgpool(x) + + if self.num_classes > 0: + x = paddle.flatten(x, 1) + x = self.classifier(x) + return x