未验证 提交 1c3eef4c 编写于 作者: L LielinJiang 提交者: GitHub

Fix vgg error when num_classes is given (#28557)

* fix vgg num classes
上级 1de3cdd0
......@@ -71,6 +71,9 @@ class TestVisonModels(unittest.TestCase):
def test_resnet152(self):
self.models_infer('resnet152')
def test_vgg16_num_classes(self):
vgg16 = models.__dict__['vgg16'](pretrained=False, num_classes=10)
def test_lenet(self):
input = InputSpec([None, 1, 28, 28], 'float32', 'x')
lenet = paddle.Model(models.__dict__['LeNet'](), input)
......
......@@ -107,10 +107,7 @@ cfgs = {
def _vgg(arch, cfg, batch_norm, pretrained, **kwargs):
model = VGG(make_layers(
cfgs[cfg], batch_norm=batch_norm),
num_classes=1000,
**kwargs)
model = VGG(make_layers(cfgs[cfg], batch_norm=batch_norm), **kwargs)
if pretrained:
assert arch in model_urls, "{} model do not have a pretrained model now, you should set pretrained=False".format(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册