diff --git a/python/paddle/tests/test_vision_models.py b/python/paddle/tests/test_vision_models.py index 5f35a1e0e5a4ba9d6c6683918a2a03c190089762..a25a8f373c29c4e678f87453eade1fd958c9ac33 100644 --- a/python/paddle/tests/test_vision_models.py +++ b/python/paddle/tests/test_vision_models.py @@ -71,6 +71,9 @@ class TestVisonModels(unittest.TestCase): def test_resnet152(self): self.models_infer('resnet152') + def test_vgg16_num_classes(self): + vgg16 = models.__dict__['vgg16'](pretrained=False, num_classes=10) + def test_lenet(self): input = InputSpec([None, 1, 28, 28], 'float32', 'x') lenet = paddle.Model(models.__dict__['LeNet'](), input) diff --git a/python/paddle/vision/models/vgg.py b/python/paddle/vision/models/vgg.py index bb158569d3bc9fb658c1e103b41bb02784e68d8b..00f6cccbdfe9f11c0f234e923383ecc566958a33 100644 --- a/python/paddle/vision/models/vgg.py +++ b/python/paddle/vision/models/vgg.py @@ -107,10 +107,7 @@ cfgs = { def _vgg(arch, cfg, batch_norm, pretrained, **kwargs): - model = VGG(make_layers( - cfgs[cfg], batch_norm=batch_norm), - num_classes=1000, - **kwargs) + model = VGG(make_layers(cfgs[cfg], batch_norm=batch_norm), **kwargs) if pretrained: assert arch in model_urls, "{} model do not have a pretrained model now, you should set pretrained=False".format(