未验证 提交 0140d74e 编写于 作者: L LielinJiang 提交者: GitHub

Fix load pretrained model in hapi (#27893)

* fix hapi load
上级 ed6ee53e
......@@ -39,3 +39,5 @@ foreach(src ${DIST_TEST_OPS})
message(STATUS ${src})
py_dist_test(${src} SRCS ${src}.py)
endforeach()
set_tests_properties(test_pretrained_model PROPERTIES TIMEOUT 600)
......@@ -33,7 +33,7 @@ class TestPretrainedModel(unittest.TestCase):
if not dygraph:
paddle.enable_static()
net = models.__dict__[arch]()
net = models.__dict__[arch](pretrained=True)
inputs = [InputSpec([None, 3, 224, 224], 'float32', 'image')]
model = paddle.Model(network=net, inputs=inputs)
model.prepare()
......@@ -52,7 +52,7 @@ class TestPretrainedModel(unittest.TestCase):
np.testing.assert_allclose(res['dygraph'], res['static'])
def test_models(self):
arches = ['mobilenet_v1', 'mobilenet_v2', 'resnet18']
arches = ['mobilenet_v1', 'mobilenet_v2', 'resnet18', 'vgg16']
for arch in arches:
self.infer(arch)
......
......@@ -240,9 +240,8 @@ def _mobilenet(arch, pretrained=False, **kwargs):
arch)
weight_path = get_weights_path_from_url(model_urls[arch][0],
model_urls[arch][1])
assert weight_path.endswith(
'.pdparams'), "suffix of weight must be .pdparams"
param, _ = paddle.load(weight_path)
param = paddle.load(weight_path)
model.load_dict(param)
return model
......
......@@ -194,9 +194,8 @@ def _mobilenet(arch, pretrained=False, **kwargs):
arch)
weight_path = get_weights_path_from_url(model_urls[arch][0],
model_urls[arch][1])
assert weight_path.endswith(
'.pdparams'), "suffix of weight must be .pdparams"
param, _ = paddle.load(weight_path)
param = paddle.load(weight_path)
model.load_dict(param)
return model
......
......@@ -262,9 +262,8 @@ def _resnet(arch, Block, depth, pretrained, **kwargs):
arch)
weight_path = get_weights_path_from_url(model_urls[arch][0],
model_urls[arch][1])
assert weight_path.endswith(
'.pdparams'), "suffix of weight must be .pdparams"
param, _ = paddle.load(weight_path)
param = paddle.load(weight_path)
model.set_dict(param)
return model
......
......@@ -117,9 +117,8 @@ def _vgg(arch, cfg, batch_norm, pretrained, **kwargs):
arch)
weight_path = get_weights_path_from_url(model_urls[arch][0],
model_urls[arch][1])
assert weight_path.endswith(
'.pdparams'), "suffix of weight must be .pdparams"
param, _ = paddle.load(weight_path)
param = paddle.load(weight_path)
model.load_dict(param)
return model
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册