提交 514b2427 编写于 作者: T tangwei12

add save/load persist_vars_without_grad

上级 b2cb7c6f
......@@ -24,7 +24,8 @@ __all__ = [
'save_vars', 'save_params', 'save_persistables', 'load_vars', 'load_params',
'load_persistables', 'save_inference_model', 'load_inference_model',
'get_inference_program', 'save_checkpoint', 'load_checkpoint',
'clean_checkpoint'
'clean_checkpoint', 'load_persist_vars_without_grad',
'save_persist_vars_without_grad'
]
......@@ -455,6 +456,33 @@ def get_parameter_value_by_name(name, executor, program=None):
return get_parameter_value(var, executor)
def load_persist_vars_without_grad(executor, dirname, program):
"""
load_persist_vars_without_grad will load variables from a directory by an executor,
the variable named end with "@GRAD" will not be loaded.
"""
load_vars(
executor,
dirname=dirname,
main_program=program,
predicate=_is_checkpoint_var,
filename=None)
def save_persist_vars_without_grad(executor, dirname, program):
"""
save_persist_vars_without_grad will save variables to a directory by an executor,
the variable named end with "@GRAD" will not be saved.
"""
save_vars(
executor,
dirname=dirname,
main_program=program,
vars=None,
predicate=_is_checkpoint_var,
filename=None)
SUCCESS_MARK_FILENAME = "_SUCCESS"
CHECKPOINT_PREFIX = "checkpoint"
CHECKPOINT_SEPARATOR = "_"
......@@ -491,13 +519,7 @@ def save_checkpoint(executor,
serial += 1
cur_dir = _get_serial_dir(serial, checkpoint_dir)
save_vars(
executor,
dirname=cur_dir,
main_program=main_program,
vars=None,
predicate=_is_checkpoint_var,
filename=None)
load_persist_vars_without_grad(executor, cur_dir, main_program)
_write_success(cur_dir)
_lru_delete(checkpoint_dir, max_num_checkpoints)
......@@ -521,13 +543,7 @@ def load_checkpoint(executor, checkpoint_dir=None, main_program=None):
return
cur_dir = _get_serial_dir(serial, checkpoint_dir)
load_vars(
executor,
dirname=cur_dir,
main_program=main_program,
predicate=_is_checkpoint_var,
filename=None)
load_persist_vars_without_grad(executor, cur_dir, main_program)
def clean_checkpoint(checkpoint_dir, delete_dir=False):
......
......@@ -162,7 +162,8 @@ class Trainer(object):
if param_path:
# load params from param_path into scope
io.load_persistables(exe, dirname=param_path)
io.load_persist_vars_without_grad(
exe, dirname=param_path, program=self.startup_program)
def _transpile_nccl2_dist(self):
# PADDLE_TRAINER_IPS
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册