未验证 提交 31e59dfa 编写于 作者: C cuicheng01 提交者: GitHub

add load_dygraph_pretrain_from_url function

上级 5671f9d9
...@@ -24,6 +24,7 @@ import tempfile ...@@ -24,6 +24,7 @@ import tempfile
import paddle import paddle
from paddle.static import load_program_state from paddle.static import load_program_state
from paddle.utils.download import get_weights_path_from_url
from ppcls.utils import logger from ppcls.utils import logger
...@@ -70,6 +71,14 @@ def load_dygraph_pretrain(model, path=None, load_static_weights=False): ...@@ -70,6 +71,14 @@ def load_dygraph_pretrain(model, path=None, load_static_weights=False):
return return
def load_dygraph_pretrain_from_url(model, pretrained_url, use_ssld, load_static_weights=False):
if use_ssld:
pretrained_url = pretrained_url.replace("_pretrained", "_ssld_pretrained")
local_weight_path = get_weights_path_from_url(pretrained_url).replace(".pdparams", "")
load_dygraph_pretrain(model, path=local_weight_path, load_static_weights=load_static_weights)
return
def load_distillation_model(model, pretrained_model, load_static_weights): def load_distillation_model(model, pretrained_model, load_static_weights):
logger.info("In distillation mode, teacher model will be " logger.info("In distillation mode, teacher model will be "
"loaded firstly before student model.") "loaded firstly before student model.")
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册