diff --git a/ppcls/utils/save_load.py b/ppcls/utils/save_load.py index d65b9954842d22b9422e6187e52499633b34ee97..35f0bc58074fe053bc0eeafec6dc149e2014e639 100644 --- a/ppcls/utils/save_load.py +++ b/ppcls/utils/save_load.py @@ -24,6 +24,7 @@ import tempfile import paddle from paddle.static import load_program_state +from paddle.utils.download import get_weights_path_from_url from ppcls.utils import logger @@ -70,6 +71,14 @@ def load_dygraph_pretrain(model, path=None, load_static_weights=False): return +def load_dygraph_pretrain_from_url(model, pretrained_url, use_ssld, load_static_weights=False): + if use_ssld: + pretrained_url = pretrained_url.replace("_pretrained", "_ssld_pretrained") + local_weight_path = get_weights_path_from_url(pretrained_url).replace(".pdparams", "") + load_dygraph_pretrain(model, path=local_weight_path, load_static_weights=load_static_weights) + return + + def load_distillation_model(model, pretrained_model, load_static_weights): logger.info("In distillation mode, teacher model will be " "loaded firstly before student model.")