diff --git a/hubconf.py b/hubconf.py index 7eca4160dba9d984c54fb06dbae269bd524fff6e..a843cba6f48ca00793a98b23da11a7710fa50614 100644 --- a/hubconf.py +++ b/hubconf.py @@ -67,9 +67,12 @@ def _load_pretrained_urls(): _checkpoints = _load_pretrained_urls() -def _load_parameters(model, ): - pass - +def _load_pretrained_parameters(model, name): + assert name in _checkpoints, 'Not provide {} pretrained model.'.format(name) + path = paddle.utils.download.get_weights_path_from_url(_checkpoints[name]) + model.set_state_dict(paddle.load(path)) + return model + def AlexNet(pretrained=False, **kwargs): '''AlexNet