From bdd8178c15038805d3e2e4897fbe42d08ca03357 Mon Sep 17 00:00:00 2001 From: lyuwenyu Date: Fri, 2 Apr 2021 20:01:55 +0800 Subject: [PATCH] add `_load_pretrained_parameters` --- hubconf.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/hubconf.py b/hubconf.py index 7eca4160..a843cba6 100644 --- a/hubconf.py +++ b/hubconf.py @@ -67,9 +67,12 @@ def _load_pretrained_urls(): _checkpoints = _load_pretrained_urls() -def _load_parameters(model, ): - pass - +def _load_pretrained_parameters(model, name): + assert name in _checkpoints, 'Not provide {} pretrained model.'.format(name) + path = paddle.utils.download.get_weights_path_from_url(_checkpoints[name]) + model.set_state_dict(paddle.load(path)) + return model + def AlexNet(pretrained=False, **kwargs): '''AlexNet -- GitLab