提交 bdd8178c 编写于 作者: L lyuwenyu

add `_load_pretrained_parameters`

上级 48cdd362
......@@ -67,8 +67,11 @@ def _load_pretrained_urls():
_checkpoints = _load_pretrained_urls()
def _load_parameters(model, ):
pass
def _load_pretrained_parameters(model, name):
assert name in _checkpoints, 'Not provide {} pretrained model.'.format(name)
path = paddle.utils.download.get_weights_path_from_url(_checkpoints[name])
model.set_state_dict(paddle.load(path))
return model
def AlexNet(pretrained=False, **kwargs):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册