提交 9735f250 编写于 作者: T tangwei12

optimized

上级 bfdcf187
...@@ -492,7 +492,7 @@ def save_checkpoint(executor, ...@@ -492,7 +492,7 @@ def save_checkpoint(executor,
if not os.path.isdir(checkpoint_dir): if not os.path.isdir(checkpoint_dir):
os.makedirs(checkpoint_dir) os.makedirs(checkpoint_dir)
serial = _get_latest_checkpoint_dir(checkpoint_dir) + 1 serial = get_latest_checkpoint_serial(checkpoint_dir) + 1
cur_dir = _get_serial_dir(checkpoint_dir, serial) cur_dir = _get_serial_dir(checkpoint_dir, serial)
save_trainer_args(cur_dir, trainer_id, trainer_args) save_trainer_args(cur_dir, trainer_id, trainer_args)
...@@ -503,18 +503,6 @@ def save_checkpoint(executor, ...@@ -503,18 +503,6 @@ def save_checkpoint(executor,
_lru_delete(checkpoint_dir, max_num_checkpoints) _lru_delete(checkpoint_dir, max_num_checkpoints)
def get_latest_checkpoint_serial(checkpoint_dir):
"""
If the directory have checkpoint files, it will return latest checkpoint directory serial number
:param checkpoint_dir
"""
serial = _get_latest_checkpoint_dir(checkpoint_dir)
if serial < 0:
return None
return serial
def load_checkpoint(executor, checkpoint_dir, serial, main_program): def load_checkpoint(executor, checkpoint_dir, serial, main_program):
""" """
Load checkpoint from a directory by executor, Load checkpoint from a directory by executor,
...@@ -527,17 +515,16 @@ def load_checkpoint(executor, checkpoint_dir, serial, main_program): ...@@ -527,17 +515,16 @@ def load_checkpoint(executor, checkpoint_dir, serial, main_program):
""" """
if checkpoint_dir is None: if checkpoint_dir is None:
raise ValueError( raise ValueError("The values of 'checkpoint_dir' should not be None")
"The values of 'checkpoint_dir' or 'serial' should not be None")
if serial is None or serial < 0: if serial is None or serial < 0:
raise ValueError("The values of 'serial' should not be None or <0 ") raise ValueError("The values of 'serial' should not be None or <0 ")
if main_program is None: if main_program is None:
raise ValueError("The values of 'main_program'should not be None") raise ValueError('main_program should not be None.')
cur_dir = _get_serial_dir(checkpoint_dir, serial) cur_dir = _get_serial_dir(checkpoint_dir, serial)
load_persist_vars_without_grad(executor, cur_dir, main_program) load_persist_vars_without_grad(executor, cur_dir, main_program, True)
def clean_checkpoint(checkpoint_dir, delete_dir=False): def clean_checkpoint(checkpoint_dir, delete_dir=False):
...@@ -557,18 +544,21 @@ def clean_checkpoint(checkpoint_dir, delete_dir=False): ...@@ -557,18 +544,21 @@ def clean_checkpoint(checkpoint_dir, delete_dir=False):
os.rmdir(checkpoint_dir) os.rmdir(checkpoint_dir)
def load_persist_vars_without_grad(executor, dirname, program, nest=True): def load_persist_vars_without_grad(executor,
dirname,
program,
has_model_dir=False):
""" """
load_persist_vars_without_grad will load variables from a directory by an executor, load_persist_vars_without_grad will load variables from a directory by an executor,
the variable named end with "@GRAD" will not be loaded. the variable named end with "@GRAD" will not be loaded.
:param executor :param executor executor for load the value
:param dirname :param dirname the checkpoint directory
:param program :param program will load all variables in program
:param nest :param has_model_dir if has_model_dir is True, will load variables from sub directory named __model__
""" """
if nest: if has_model_dir:
dirname = _get_model_dir(dirname) dirname = _get_model_dir(dirname)
load_vars( load_vars(
...@@ -584,9 +574,9 @@ def save_persist_vars_without_grad(executor, dirname, program): ...@@ -584,9 +574,9 @@ def save_persist_vars_without_grad(executor, dirname, program):
save_persist_vars_without_grad will save variables to a directory by an executor, save_persist_vars_without_grad will save variables to a directory by an executor,
the variable named end with "@GRAD" will not be saved. the variable named end with "@GRAD" will not be saved.
:param executor :param executor executor for load the value
:param dirname :param dirname the checkpoint directory
:param program :param program will load all variables in program
""" """
cur_dir = _get_model_dir(dirname) cur_dir = _get_model_dir(dirname)
save_vars( save_vars(
...@@ -722,7 +712,7 @@ def _write_success(dirname): ...@@ -722,7 +712,7 @@ def _write_success(dirname):
f.write(now) f.write(now)
def _get_latest_checkpoint_dir(checkpoint_dir): def get_latest_checkpoint_serial(checkpoint_dir):
""" """
get the latest file in checkpoint directory, the _SUCCESS file must exist in the directory get the latest file in checkpoint directory, the _SUCCESS file must exist in the directory
......
...@@ -146,8 +146,9 @@ class Trainer(object): ...@@ -146,8 +146,9 @@ class Trainer(object):
"The checkpoint_config shoule be an instance of CheckpointConfig" "The checkpoint_config shoule be an instance of CheckpointConfig"
) )
else: else:
self.checkpoint.load_serial = io.get_latest_checkpoint_serial( serial = io.get_latest_checkpoint_serial(
self.checkpoint.checkpoint_dir) self.checkpoint.checkpoint_dir)
self.checkpoint.load_serial = serial if serial >= 0 else None
self.scope = core.Scope() self.scope = core.Scope()
...@@ -194,10 +195,7 @@ class Trainer(object): ...@@ -194,10 +195,7 @@ class Trainer(object):
if param_path and os.path.isdir(param_path): if param_path and os.path.isdir(param_path):
# load params from param_path into scope # load params from param_path into scope
io.load_persist_vars_without_grad( io.load_persist_vars_without_grad(
exe, exe, dirname=param_path, program=self.startup_program)
dirname=param_path,
program=self.startup_program,
nest=False)
def _transpile_nccl2_dist(self): def _transpile_nccl2_dist(self):
# PADDLE_TRAINER_IPS # PADDLE_TRAINER_IPS
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册