提交 9735f250 编写于 作者: T tangwei12

optimized

上级 bfdcf187
......@@ -492,7 +492,7 @@ def save_checkpoint(executor,
if not os.path.isdir(checkpoint_dir):
os.makedirs(checkpoint_dir)
serial = _get_latest_checkpoint_dir(checkpoint_dir) + 1
serial = get_latest_checkpoint_serial(checkpoint_dir) + 1
cur_dir = _get_serial_dir(checkpoint_dir, serial)
save_trainer_args(cur_dir, trainer_id, trainer_args)
......@@ -503,18 +503,6 @@ def save_checkpoint(executor,
_lru_delete(checkpoint_dir, max_num_checkpoints)
def get_latest_checkpoint_serial(checkpoint_dir):
"""
If the directory have checkpoint files, it will return latest checkpoint directory serial number
:param checkpoint_dir
"""
serial = _get_latest_checkpoint_dir(checkpoint_dir)
if serial < 0:
return None
return serial
def load_checkpoint(executor, checkpoint_dir, serial, main_program):
"""
Load checkpoint from a directory by executor,
......@@ -527,17 +515,16 @@ def load_checkpoint(executor, checkpoint_dir, serial, main_program):
"""
if checkpoint_dir is None:
raise ValueError(
"The values of 'checkpoint_dir' or 'serial' should not be None")
raise ValueError("The values of 'checkpoint_dir' should not be None")
if serial is None or serial < 0:
raise ValueError("The values of 'serial' should not be None or <0 ")
if main_program is None:
raise ValueError("The values of 'main_program'should not be None")
raise ValueError('main_program should not be None.')
cur_dir = _get_serial_dir(checkpoint_dir, serial)
load_persist_vars_without_grad(executor, cur_dir, main_program)
load_persist_vars_without_grad(executor, cur_dir, main_program, True)
def clean_checkpoint(checkpoint_dir, delete_dir=False):
......@@ -557,18 +544,21 @@ def clean_checkpoint(checkpoint_dir, delete_dir=False):
os.rmdir(checkpoint_dir)
def load_persist_vars_without_grad(executor, dirname, program, nest=True):
def load_persist_vars_without_grad(executor,
dirname,
program,
has_model_dir=False):
"""
load_persist_vars_without_grad will load variables from a directory by an executor,
the variable named end with "@GRAD" will not be loaded.
:param executor
:param dirname
:param program
:param nest
:param executor executor for load the value
:param dirname the checkpoint directory
:param program will load all variables in program
:param has_model_dir if has_model_dir is True, will load variables from sub directory named __model__
"""
if nest:
if has_model_dir:
dirname = _get_model_dir(dirname)
load_vars(
......@@ -584,9 +574,9 @@ def save_persist_vars_without_grad(executor, dirname, program):
save_persist_vars_without_grad will save variables to a directory by an executor,
the variable named end with "@GRAD" will not be saved.
:param executor
:param dirname
:param program
:param executor executor for load the value
:param dirname the checkpoint directory
:param program will load all variables in program
"""
cur_dir = _get_model_dir(dirname)
save_vars(
......@@ -722,7 +712,7 @@ def _write_success(dirname):
f.write(now)
def _get_latest_checkpoint_dir(checkpoint_dir):
def get_latest_checkpoint_serial(checkpoint_dir):
"""
get the latest file in checkpoint directory, the _SUCCESS file must exist in the directory
......
......@@ -146,8 +146,9 @@ class Trainer(object):
"The checkpoint_config shoule be an instance of CheckpointConfig"
)
else:
self.checkpoint.load_serial = io.get_latest_checkpoint_serial(
serial = io.get_latest_checkpoint_serial(
self.checkpoint.checkpoint_dir)
self.checkpoint.load_serial = serial if serial >= 0 else None
self.scope = core.Scope()
......@@ -194,10 +195,7 @@ class Trainer(object):
if param_path and os.path.isdir(param_path):
# load params from param_path into scope
io.load_persist_vars_without_grad(
exe,
dirname=param_path,
program=self.startup_program,
nest=False)
exe, dirname=param_path, program=self.startup_program)
def _transpile_nccl2_dist(self):
# PADDLE_TRAINER_IPS
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册