diff --git a/PaddleNLP/PaddleLARK/BERT/utils/init.py b/PaddleNLP/PaddleLARK/BERT/utils/init.py index 8948eae548fc3779c701a40fca08f07a136e5298..e359487190a87ca033c128bb8123323771bf648b 100644 --- a/PaddleNLP/PaddleLARK/BERT/utils/init.py +++ b/PaddleNLP/PaddleLARK/BERT/utils/init.py @@ -62,7 +62,19 @@ def init_pretraining_params(exe, assert os.path.exists(pretraining_params_path ), "[%s] cann't be found." % pretraining_params_path - fluid.load( main_program, pretraining_params_path, exe) + def existed_params(var): + if not isinstance(var, fluid.framework.Parameter): + return False + if os.path.exists(os.path.join(pretraining_params_path, var.name)): + print("INIT {}".format(var.name)) + return True + else: + print("SKIP {}".format(var.name)) + return False + + load_var_list = list(filter(existed_params, main_program.list_vars()) ) + para_state = fluid.load_program_state( pretraining_params_path, var_list = load_var_list) + fluid.set_program_state( main_program, para_state) print("Load pretraining parameters from {}.".format( pretraining_params_path))