From 0140d74e2397a8ab78a390e5f5df5d33738ef60e Mon Sep 17 00:00:00 2001 From: LielinJiang <50691816+LielinJiang@users.noreply.github.com> Date: Wed, 14 Oct 2020 21:29:45 +0800 Subject: [PATCH] Fix load pretrained model in hapi (#27893) * fix hapi load --- python/paddle/tests/CMakeLists.txt | 2 ++ python/paddle/tests/test_pretrained_model.py | 4 ++-- python/paddle/vision/models/mobilenetv1.py | 5 ++--- python/paddle/vision/models/mobilenetv2.py | 5 ++--- python/paddle/vision/models/resnet.py | 5 ++--- python/paddle/vision/models/vgg.py | 5 ++--- 6 files changed, 12 insertions(+), 14 deletions(-) diff --git a/python/paddle/tests/CMakeLists.txt b/python/paddle/tests/CMakeLists.txt index e1bc65a5d15..9f64a6d2b7b 100644 --- a/python/paddle/tests/CMakeLists.txt +++ b/python/paddle/tests/CMakeLists.txt @@ -39,3 +39,5 @@ foreach(src ${DIST_TEST_OPS}) message(STATUS ${src}) py_dist_test(${src} SRCS ${src}.py) endforeach() + +set_tests_properties(test_pretrained_model PROPERTIES TIMEOUT 600) diff --git a/python/paddle/tests/test_pretrained_model.py b/python/paddle/tests/test_pretrained_model.py index bf9c2a2ae06..a36dd75549a 100644 --- a/python/paddle/tests/test_pretrained_model.py +++ b/python/paddle/tests/test_pretrained_model.py @@ -33,7 +33,7 @@ class TestPretrainedModel(unittest.TestCase): if not dygraph: paddle.enable_static() - net = models.__dict__[arch]() + net = models.__dict__[arch](pretrained=True) inputs = [InputSpec([None, 3, 224, 224], 'float32', 'image')] model = paddle.Model(network=net, inputs=inputs) model.prepare() @@ -52,7 +52,7 @@ class TestPretrainedModel(unittest.TestCase): np.testing.assert_allclose(res['dygraph'], res['static']) def test_models(self): - arches = ['mobilenet_v1', 'mobilenet_v2', 'resnet18'] + arches = ['mobilenet_v1', 'mobilenet_v2', 'resnet18', 'vgg16'] for arch in arches: self.infer(arch) diff --git a/python/paddle/vision/models/mobilenetv1.py b/python/paddle/vision/models/mobilenetv1.py index 39654122e3b..4e6030bd14b 100644 --- a/python/paddle/vision/models/mobilenetv1.py +++ b/python/paddle/vision/models/mobilenetv1.py @@ -240,9 +240,8 @@ def _mobilenet(arch, pretrained=False, **kwargs): arch) weight_path = get_weights_path_from_url(model_urls[arch][0], model_urls[arch][1]) - assert weight_path.endswith( - '.pdparams'), "suffix of weight must be .pdparams" - param, _ = paddle.load(weight_path) + + param = paddle.load(weight_path) model.load_dict(param) return model diff --git a/python/paddle/vision/models/mobilenetv2.py b/python/paddle/vision/models/mobilenetv2.py index bab8b7b2b1b..0f4dc22f679 100644 --- a/python/paddle/vision/models/mobilenetv2.py +++ b/python/paddle/vision/models/mobilenetv2.py @@ -194,9 +194,8 @@ def _mobilenet(arch, pretrained=False, **kwargs): arch) weight_path = get_weights_path_from_url(model_urls[arch][0], model_urls[arch][1]) - assert weight_path.endswith( - '.pdparams'), "suffix of weight must be .pdparams" - param, _ = paddle.load(weight_path) + + param = paddle.load(weight_path) model.load_dict(param) return model diff --git a/python/paddle/vision/models/resnet.py b/python/paddle/vision/models/resnet.py index f9e00aefd6b..3ae01b6fd7d 100644 --- a/python/paddle/vision/models/resnet.py +++ b/python/paddle/vision/models/resnet.py @@ -262,9 +262,8 @@ def _resnet(arch, Block, depth, pretrained, **kwargs): arch) weight_path = get_weights_path_from_url(model_urls[arch][0], model_urls[arch][1]) - assert weight_path.endswith( - '.pdparams'), "suffix of weight must be .pdparams" - param, _ = paddle.load(weight_path) + + param = paddle.load(weight_path) model.set_dict(param) return model diff --git a/python/paddle/vision/models/vgg.py b/python/paddle/vision/models/vgg.py index d11845b6616..2d62e1d22d4 100644 --- a/python/paddle/vision/models/vgg.py +++ b/python/paddle/vision/models/vgg.py @@ -117,9 +117,8 @@ def _vgg(arch, cfg, batch_norm, pretrained, **kwargs): arch) weight_path = get_weights_path_from_url(model_urls[arch][0], model_urls[arch][1]) - assert weight_path.endswith( - '.pdparams'), "suffix of weight must be .pdparams" - param, _ = paddle.load(weight_path) + + param = paddle.load(weight_path) model.load_dict(param) return model -- GitLab