diff --git a/python/paddle/fluid/io.py b/python/paddle/fluid/io.py index 8e58e5eb794e1bb507ab05394a1f7b57a1d2ed42..f626039363a6864a59ec12af7f2e2f19d4d60099 100644 --- a/python/paddle/fluid/io.py +++ b/python/paddle/fluid/io.py @@ -24,7 +24,8 @@ __all__ = [ 'save_vars', 'save_params', 'save_persistables', 'load_vars', 'load_params', 'load_persistables', 'save_inference_model', 'load_inference_model', 'get_inference_program', 'save_checkpoint', 'load_checkpoint', - 'clean_checkpoint' + 'clean_checkpoint', 'load_persist_vars_without_grad', + 'save_persist_vars_without_grad' ] @@ -455,6 +456,33 @@ def get_parameter_value_by_name(name, executor, program=None): return get_parameter_value(var, executor) +def load_persist_vars_without_grad(executor, dirname, program): + """ + load_persist_vars_without_grad will load variables from a directory by an executor, + the variable named end with "@GRAD" will not be loaded. + """ + load_vars( + executor, + dirname=dirname, + main_program=program, + predicate=_is_checkpoint_var, + filename=None) + + +def save_persist_vars_without_grad(executor, dirname, program): + """ + save_persist_vars_without_grad will save variables to a directory by an executor, + the variable named end with "@GRAD" will not be saved. + """ + save_vars( + executor, + dirname=dirname, + main_program=program, + vars=None, + predicate=_is_checkpoint_var, + filename=None) + + SUCCESS_MARK_FILENAME = "_SUCCESS" CHECKPOINT_PREFIX = "checkpoint" CHECKPOINT_SEPARATOR = "_" @@ -491,13 +519,7 @@ def save_checkpoint(executor, serial += 1 cur_dir = _get_serial_dir(serial, checkpoint_dir) - save_vars( - executor, - dirname=cur_dir, - main_program=main_program, - vars=None, - predicate=_is_checkpoint_var, - filename=None) + load_persist_vars_without_grad(executor, cur_dir, main_program) _write_success(cur_dir) _lru_delete(checkpoint_dir, max_num_checkpoints) @@ -521,13 +543,7 @@ def load_checkpoint(executor, checkpoint_dir=None, main_program=None): return cur_dir = _get_serial_dir(serial, checkpoint_dir) - - load_vars( - executor, - dirname=cur_dir, - main_program=main_program, - predicate=_is_checkpoint_var, - filename=None) + load_persist_vars_without_grad(executor, cur_dir, main_program) def clean_checkpoint(checkpoint_dir, delete_dir=False): diff --git a/python/paddle/fluid/trainer.py b/python/paddle/fluid/trainer.py index 24254b4980c130da886afbe293ea169075933688..b4b7b75b96e3155ed8e5c9d0cca6a6f5fb45d46b 100644 --- a/python/paddle/fluid/trainer.py +++ b/python/paddle/fluid/trainer.py @@ -162,7 +162,8 @@ class Trainer(object): if param_path: # load params from param_path into scope - io.load_persistables(exe, dirname=param_path) + io.load_persist_vars_without_grad( + exe, dirname=param_path, program=self.startup_program) def _transpile_nccl2_dist(self): # PADDLE_TRAINER_IPS