提交 0deb6f90 编写于 作者: T tangwei12

annotation optimized and code style optimized

上级 0211c5df
...@@ -478,9 +478,10 @@ def save_checkpoint(executor, ...@@ -478,9 +478,10 @@ def save_checkpoint(executor,
:param executor :param executor
:param checkpoint_dir :param checkpoint_dir
:param trainer_id
:param is_chief
:param main_program :param main_program
:param max_num_checkpoints :param max_num_checkpoints
:param is_chief
""" """
if checkpoint_dir is None: if checkpoint_dir is None:
raise ValueError("The values of 'checkpoint_dir' should not be None") raise ValueError("The values of 'checkpoint_dir' should not be None")
...@@ -502,6 +503,11 @@ def save_checkpoint(executor, ...@@ -502,6 +503,11 @@ def save_checkpoint(executor,
def need_load_checkpoint(checkpoint_dir): def need_load_checkpoint(checkpoint_dir):
"""
If the directory has checkpoint files, it will return the latest checkpoint directory's serial number
:param checkpoint_dir
"""
serial = _get_lastest_checkpoint_dir(checkpoint_dir) serial = _get_lastest_checkpoint_dir(checkpoint_dir)
if serial < 0: if serial < 0:
return None return None
...@@ -515,6 +521,7 @@ def load_checkpoint(executor, checkpoint_dir, serial, main_program): ...@@ -515,6 +521,7 @@ def load_checkpoint(executor, checkpoint_dir, serial, main_program):
:param executor :param executor
:param checkpoint_dir :param checkpoint_dir
:param serial
:param main_program :param main_program
""" """
...@@ -536,7 +543,11 @@ def clean_checkpoint(checkpoint_dir, delete_dir=False): ...@@ -536,7 +543,11 @@ def clean_checkpoint(checkpoint_dir, delete_dir=False):
""" """
clean the checkpoint dir; when training exits normally, the trainer will call clean_checkpoint to delete the checkpoint directory saved before. clean the checkpoint dir; when training exits normally, the trainer will call clean_checkpoint to delete the checkpoint directory saved before.
delete_dir only works when the directory is empty, otherwise, OSError is raised. delete_dir only works when the directory is empty, otherwise, OSError is raised.
:param checkpoint_dir
:param delete_dir
""" """
if checkpoint_dir is None: if checkpoint_dir is None:
raise ValueError("The values of 'checkpoint_dir' should not be None") raise ValueError("The values of 'checkpoint_dir' should not be None")
_lru_delete(checkpoint_dir, max_num_checkpoints=0) _lru_delete(checkpoint_dir, max_num_checkpoints=0)
...@@ -549,6 +560,11 @@ def load_persist_vars_without_grad(executor, dirname, program, nest=True): ...@@ -549,6 +560,11 @@ def load_persist_vars_without_grad(executor, dirname, program, nest=True):
""" """
load_persist_vars_without_grad will load variables from a directory by an executor, load_persist_vars_without_grad will load variables from a directory by an executor,
variables whose names end with "@GRAD" will not be loaded. variables whose names end with "@GRAD" will not be loaded.
:param executor
:param dirname
:param program
:param nest
""" """
if nest: if nest:
...@@ -566,6 +582,10 @@ def save_persist_vars_without_grad(executor, dirname, program): ...@@ -566,6 +582,10 @@ def save_persist_vars_without_grad(executor, dirname, program):
""" """
save_persist_vars_without_grad will save variables to a directory by an executor, save_persist_vars_without_grad will save variables to a directory by an executor,
variables whose names end with "@GRAD" will not be saved. variables whose names end with "@GRAD" will not be saved.
:param executor
:param dirname
:param program
""" """
cur_dir = _get_model_dir(dirname) cur_dir = _get_model_dir(dirname)
save_vars( save_vars(
......
...@@ -79,8 +79,8 @@ class CheckpointConfig(object): ...@@ -79,8 +79,8 @@ class CheckpointConfig(object):
else: else:
self.step_interval = step_interval self.step_interval = step_interval
self._epoch_id = 0 self.epoch_id = 0
self._step_id = 0 self.step_id = 0
self._load_serial = None self._load_serial = None
...@@ -185,8 +185,8 @@ class Trainer(object): ...@@ -185,8 +185,8 @@ class Trainer(object):
epoch_id, step_id = io.load_trainer_args( epoch_id, step_id = io.load_trainer_args(
self.checkpoint.checkpoint_dir, self.checkpoint._load_serial, self.checkpoint.checkpoint_dir, self.checkpoint._load_serial,
self.trainer_id, ["epoch_id", "step_id"]) self.trainer_id, ["epoch_id", "step_id"])
self.checkpoint._epoch_id = int(epoch_id) self.checkpoint.epoch_id = int(epoch_id)
self.checkpoint._step_id = int(step_id) self.checkpoint.step_id = int(step_id)
if param_path and os.path.isdir(param_path): if param_path and os.path.isdir(param_path):
# load params from param_path into scope # load params from param_path into scope
...@@ -353,7 +353,7 @@ class Trainer(object): ...@@ -353,7 +353,7 @@ class Trainer(object):
def _train_by_any_executor(self, event_handler, exe, num_epochs, reader): def _train_by_any_executor(self, event_handler, exe, num_epochs, reader):
epochs = [ epochs = [
epoch_id for epoch_id in range(num_epochs) epoch_id for epoch_id in range(num_epochs)
if epoch_id >= self.checkpoint._epoch_id if epoch_id >= self.checkpoint.epoch_id
] ]
for epoch_id in epochs: for epoch_id in epochs:
event_handler(BeginEpochEvent(epoch_id)) event_handler(BeginEpochEvent(epoch_id))
...@@ -363,7 +363,7 @@ class Trainer(object): ...@@ -363,7 +363,7 @@ class Trainer(object):
return return
if self.checkpoint and self.checkpoint._load_serial \ if self.checkpoint and self.checkpoint._load_serial \
and self.checkpoint._step_id >= step_id and self.checkpoint._epoch_id == epoch_id: and self.checkpoint.step_id >= step_id and self.checkpoint.epoch_id == epoch_id:
continue continue
begin_event = BeginStepEvent(epoch_id, step_id) begin_event = BeginStepEvent(epoch_id, step_id)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册