提交 514b2427 编写于 作者: T tangwei12

add save/load persist_vars_without_grad

上级 b2cb7c6f
...@@ -24,7 +24,8 @@ __all__ = [ ...@@ -24,7 +24,8 @@ __all__ = [
'save_vars', 'save_params', 'save_persistables', 'load_vars', 'load_params', 'save_vars', 'save_params', 'save_persistables', 'load_vars', 'load_params',
'load_persistables', 'save_inference_model', 'load_inference_model', 'load_persistables', 'save_inference_model', 'load_inference_model',
'get_inference_program', 'save_checkpoint', 'load_checkpoint', 'get_inference_program', 'save_checkpoint', 'load_checkpoint',
'clean_checkpoint' 'clean_checkpoint', 'load_persist_vars_without_grad',
'save_persist_vars_without_grad'
] ]
...@@ -455,6 +456,33 @@ def get_parameter_value_by_name(name, executor, program=None): ...@@ -455,6 +456,33 @@ def get_parameter_value_by_name(name, executor, program=None):
return get_parameter_value(var, executor) return get_parameter_value(var, executor)
def load_persist_vars_without_grad(executor, dirname, program):
    """
    Restore the checkpoint variables of `program` from the directory `dirname`.

    The work is delegated to `load_vars`, which is asked to pick variables
    with the `_is_checkpoint_var` predicate; variables whose names end with
    "@GRAD" are not meant to be loaded. Every variable is read from its own
    file, since no combined `filename` is supplied.

    Args:
        executor: the executor used to run the load.
        dirname: directory the variables are read from.
        program: program whose variables are restored; forwarded to
            `load_vars` as `main_program`.
    """
    # Gather the keyword options first so the delegation is a single call.
    load_options = {
        "dirname": dirname,
        "main_program": program,
        "predicate": _is_checkpoint_var,
        "filename": None,
    }
    load_vars(executor, **load_options)
def save_persist_vars_without_grad(executor, dirname, program):
    """
    Write the checkpoint variables of `program` into the directory `dirname`.

    The work is delegated to `save_vars`. No explicit variable list is given
    (`vars=None`); the `_is_checkpoint_var` predicate is passed so that the
    selection is made from `program`, and variables whose names end with
    "@GRAD" are not meant to be saved. Every variable goes to its own file,
    since no combined `filename` is supplied.

    Args:
        executor: the executor used to run the save.
        dirname: directory the variables are written to.
        program: program whose variables are saved; forwarded to
            `save_vars` as `main_program`.
    """
    # Gather the keyword options first so the delegation is a single call.
    save_options = {
        "dirname": dirname,
        "main_program": program,
        "vars": None,
        "predicate": _is_checkpoint_var,
        "filename": None,
    }
    save_vars(executor, **save_options)
SUCCESS_MARK_FILENAME = "_SUCCESS" SUCCESS_MARK_FILENAME = "_SUCCESS"
CHECKPOINT_PREFIX = "checkpoint" CHECKPOINT_PREFIX = "checkpoint"
CHECKPOINT_SEPARATOR = "_" CHECKPOINT_SEPARATOR = "_"
...@@ -491,13 +519,7 @@ def save_checkpoint(executor, ...@@ -491,13 +519,7 @@ def save_checkpoint(executor,
serial += 1 serial += 1
cur_dir = _get_serial_dir(serial, checkpoint_dir) cur_dir = _get_serial_dir(serial, checkpoint_dir)
save_vars( load_persist_vars_without_grad(executor, cur_dir, main_program)
executor,
dirname=cur_dir,
main_program=main_program,
vars=None,
predicate=_is_checkpoint_var,
filename=None)
_write_success(cur_dir) _write_success(cur_dir)
_lru_delete(checkpoint_dir, max_num_checkpoints) _lru_delete(checkpoint_dir, max_num_checkpoints)
...@@ -521,13 +543,7 @@ def load_checkpoint(executor, checkpoint_dir=None, main_program=None): ...@@ -521,13 +543,7 @@ def load_checkpoint(executor, checkpoint_dir=None, main_program=None):
return return
cur_dir = _get_serial_dir(serial, checkpoint_dir) cur_dir = _get_serial_dir(serial, checkpoint_dir)
load_persist_vars_without_grad(executor, cur_dir, main_program)
load_vars(
executor,
dirname=cur_dir,
main_program=main_program,
predicate=_is_checkpoint_var,
filename=None)
def clean_checkpoint(checkpoint_dir, delete_dir=False): def clean_checkpoint(checkpoint_dir, delete_dir=False):
......
...@@ -162,7 +162,8 @@ class Trainer(object): ...@@ -162,7 +162,8 @@ class Trainer(object):
if param_path: if param_path:
# load params from param_path into scope # load params from param_path into scope
io.load_persistables(exe, dirname=param_path) io.load_persist_vars_without_grad(
exe, dirname=param_path, program=self.startup_program)
def _transpile_nccl2_dist(self): def _transpile_nccl2_dist(self):
# PADDLE_TRAINER_IPS # PADDLE_TRAINER_IPS
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册