未验证 提交 0140d74e 编写于 作者: L LielinJiang 提交者: GitHub

Fix load pretrained model in hapi (#27893)

* fix hapi load
上级 ed6ee53e
...@@ -39,3 +39,5 @@ foreach(src ${DIST_TEST_OPS}) ...@@ -39,3 +39,5 @@ foreach(src ${DIST_TEST_OPS})
message(STATUS ${src}) message(STATUS ${src})
py_dist_test(${src} SRCS ${src}.py) py_dist_test(${src} SRCS ${src}.py)
endforeach() endforeach()
set_tests_properties(test_pretrained_model PROPERTIES TIMEOUT 600)
...@@ -33,7 +33,7 @@ class TestPretrainedModel(unittest.TestCase): ...@@ -33,7 +33,7 @@ class TestPretrainedModel(unittest.TestCase):
if not dygraph: if not dygraph:
paddle.enable_static() paddle.enable_static()
net = models.__dict__[arch]() net = models.__dict__[arch](pretrained=True)
inputs = [InputSpec([None, 3, 224, 224], 'float32', 'image')] inputs = [InputSpec([None, 3, 224, 224], 'float32', 'image')]
model = paddle.Model(network=net, inputs=inputs) model = paddle.Model(network=net, inputs=inputs)
model.prepare() model.prepare()
...@@ -52,7 +52,7 @@ class TestPretrainedModel(unittest.TestCase): ...@@ -52,7 +52,7 @@ class TestPretrainedModel(unittest.TestCase):
np.testing.assert_allclose(res['dygraph'], res['static']) np.testing.assert_allclose(res['dygraph'], res['static'])
def test_models(self): def test_models(self):
arches = ['mobilenet_v1', 'mobilenet_v2', 'resnet18'] arches = ['mobilenet_v1', 'mobilenet_v2', 'resnet18', 'vgg16']
for arch in arches: for arch in arches:
self.infer(arch) self.infer(arch)
......
...@@ -240,9 +240,8 @@ def _mobilenet(arch, pretrained=False, **kwargs): ...@@ -240,9 +240,8 @@ def _mobilenet(arch, pretrained=False, **kwargs):
arch) arch)
weight_path = get_weights_path_from_url(model_urls[arch][0], weight_path = get_weights_path_from_url(model_urls[arch][0],
model_urls[arch][1]) model_urls[arch][1])
assert weight_path.endswith(
'.pdparams'), "suffix of weight must be .pdparams" param = paddle.load(weight_path)
param, _ = paddle.load(weight_path)
model.load_dict(param) model.load_dict(param)
return model return model
......
...@@ -194,9 +194,8 @@ def _mobilenet(arch, pretrained=False, **kwargs): ...@@ -194,9 +194,8 @@ def _mobilenet(arch, pretrained=False, **kwargs):
arch) arch)
weight_path = get_weights_path_from_url(model_urls[arch][0], weight_path = get_weights_path_from_url(model_urls[arch][0],
model_urls[arch][1]) model_urls[arch][1])
assert weight_path.endswith(
'.pdparams'), "suffix of weight must be .pdparams" param = paddle.load(weight_path)
param, _ = paddle.load(weight_path)
model.load_dict(param) model.load_dict(param)
return model return model
......
...@@ -262,9 +262,8 @@ def _resnet(arch, Block, depth, pretrained, **kwargs): ...@@ -262,9 +262,8 @@ def _resnet(arch, Block, depth, pretrained, **kwargs):
arch) arch)
weight_path = get_weights_path_from_url(model_urls[arch][0], weight_path = get_weights_path_from_url(model_urls[arch][0],
model_urls[arch][1]) model_urls[arch][1])
assert weight_path.endswith(
'.pdparams'), "suffix of weight must be .pdparams" param = paddle.load(weight_path)
param, _ = paddle.load(weight_path)
model.set_dict(param) model.set_dict(param)
return model return model
......
...@@ -117,9 +117,8 @@ def _vgg(arch, cfg, batch_norm, pretrained, **kwargs): ...@@ -117,9 +117,8 @@ def _vgg(arch, cfg, batch_norm, pretrained, **kwargs):
arch) arch)
weight_path = get_weights_path_from_url(model_urls[arch][0], weight_path = get_weights_path_from_url(model_urls[arch][0],
model_urls[arch][1]) model_urls[arch][1])
assert weight_path.endswith(
'.pdparams'), "suffix of weight must be .pdparams" param = paddle.load(weight_path)
param, _ = paddle.load(weight_path)
model.load_dict(param) model.load_dict(param)
return model return model
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册