提交 6db240d7 编写于 作者: T tangwei12

update trainer about epoch_id and step id

上级 3b5e3f9b
...@@ -188,7 +188,7 @@ class Trainer(object): ...@@ -188,7 +188,7 @@ class Trainer(object):
if not self.checkpoint.is_pserver: if not self.checkpoint.is_pserver:
epoch_id, step_id = io.load_trainer_args( epoch_id, step_id = io.load_trainer_args(
self.checkpoint.checkpoint_dir, self.checkpoint.load_serial, self.checkpoint.checkpoint_dir, self.checkpoint.load_serial,
self.trainer_id, ["epoch_id", "step_id"]) self.trainer_id, self._get_checkpoint_load_args())
self.checkpoint.epoch_id = int(epoch_id) self.checkpoint.epoch_id = int(epoch_id)
self.checkpoint.step_id = int(step_id) self.checkpoint.step_id = int(step_id)
...@@ -432,22 +432,33 @@ class Trainer(object): ...@@ -432,22 +432,33 @@ class Trainer(object):
return return
io.clean_checkpoint(checkpoint_dir=self.checkpoint.checkpoint_dir) io.clean_checkpoint(checkpoint_dir=self.checkpoint.checkpoint_dir)
def _get_checkpoint_load_args(self):
"""
epoch_id and step_id are runtime arguments, they are not variables, will load them independently.
"""
return ["epoch_id", "step_id"]
def _get_checkpoint_save_args(self, epoch_id, step_id):
"""
epoch_id and step_id are runtime arguments, they are not variables, will save them independently.
"""
trainer_args = {}
trainer_args["epoch_id"] = epoch_id
trainer_args["step_id"] = step_id
return trainer_args
def _save_checkpoint(self, epoch_id, step_id): def _save_checkpoint(self, epoch_id, step_id):
if not self.checkpoint: if not self.checkpoint:
return return
if epoch_id % self.checkpoint.epoch_interval == 0 and step_id % self.checkpoint.step_interval == 0: if epoch_id % self.checkpoint.epoch_interval == 0 and step_id % self.checkpoint.step_interval == 0:
trainer_args = {}
trainer_args["epoch_id"] = epoch_id
trainer_args["step_id"] = step_id
exe = executor.Executor(self.place) exe = executor.Executor(self.place)
io.save_checkpoint( io.save_checkpoint(
executor=exe, executor=exe,
checkpoint_dir=self.checkpoint.checkpoint_dir, checkpoint_dir=self.checkpoint.checkpoint_dir,
trainer_id=self.trainer_id, trainer_id=self.trainer_id,
is_chief=self.chief, is_chief=self.chief,
trainer_args=trainer_args, trainer_args=self._get_checkpoint_save_args(epoch_id, step_id),
main_program=self.train_program, main_program=self.train_program,
max_num_checkpoints=self.checkpoint.max_num_checkpoints) max_num_checkpoints=self.checkpoint.max_num_checkpoints)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册