From 1c3eef4cee16b327c0a305c4eebe6dc369fd1121 Mon Sep 17 00:00:00 2001 From: LielinJiang <50691816+LielinJiang@users.noreply.github.com> Date: Mon, 16 Nov 2020 11:28:03 +0800 Subject: [PATCH] Fix vgg error when num_classes is given (#28557) * fix vgg num classes --- python/paddle/tests/test_vision_models.py | 3 +++ python/paddle/vision/models/vgg.py | 5 +---- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/python/paddle/tests/test_vision_models.py b/python/paddle/tests/test_vision_models.py index 5f35a1e0e5..a25a8f373c 100644 --- a/python/paddle/tests/test_vision_models.py +++ b/python/paddle/tests/test_vision_models.py @@ -71,6 +71,9 @@ class TestVisonModels(unittest.TestCase): def test_resnet152(self): self.models_infer('resnet152') + def test_vgg16_num_classes(self): + vgg16 = models.__dict__['vgg16'](pretrained=False, num_classes=10) + def test_lenet(self): input = InputSpec([None, 1, 28, 28], 'float32', 'x') lenet = paddle.Model(models.__dict__['LeNet'](), input) diff --git a/python/paddle/vision/models/vgg.py b/python/paddle/vision/models/vgg.py index bb158569d3..00f6cccbdf 100644 --- a/python/paddle/vision/models/vgg.py +++ b/python/paddle/vision/models/vgg.py @@ -107,10 +107,7 @@ cfgs = { def _vgg(arch, cfg, batch_norm, pretrained, **kwargs): - model = VGG(make_layers( - cfgs[cfg], batch_norm=batch_norm), - num_classes=1000, - **kwargs) + model = VGG(make_layers(cfgs[cfg], batch_norm=batch_norm), **kwargs) if pretrained: assert arch in model_urls, "{} model do not have a pretrained model now, you should set pretrained=False".format( -- GitLab