提交 90860430 编写于 作者: T tangwei12

bug fix and optimize

上级 486e1e33
...@@ -529,7 +529,6 @@ def load_checkpoint(executor, checkpoint_dir, serial, main_program): ...@@ -529,7 +529,6 @@ def load_checkpoint(executor, checkpoint_dir, serial, main_program):
raise ValueError("The values of 'main_program'should not be None") raise ValueError("The values of 'main_program'should not be None")
cur_dir = _get_serial_dir(checkpoint_dir, serial) cur_dir = _get_serial_dir(checkpoint_dir, serial)
cur_dir = _get_model_dir(cur_dir)
load_persist_vars_without_grad(executor, cur_dir, main_program) load_persist_vars_without_grad(executor, cur_dir, main_program)
......
...@@ -144,7 +144,7 @@ class Trainer(object): ...@@ -144,7 +144,7 @@ class Trainer(object):
raise TypeError( raise TypeError(
"The checkpoint_config shoule be an instance of CheckpointConfig" "The checkpoint_config shoule be an instance of CheckpointConfig"
) )
self.load_checkpoint_serial = io.need_load_checkpoint( self.checkpoint._load_serial = io.need_load_checkpoint(
self.checkpoint.checkpoint_dir) self.checkpoint.checkpoint_dir)
self.scope = core.Scope() self.scope = core.Scope()
...@@ -182,7 +182,7 @@ class Trainer(object): ...@@ -182,7 +182,7 @@ class Trainer(object):
self.startup_program) self.startup_program)
epoch_id, step_id = io.load_trainer_args( epoch_id, step_id = io.load_trainer_args(
self.checkpoint.checkpoint_dir, self.load_checkpoint_serial, self.checkpoint.checkpoint_dir, self.checkpoint._load_serial,
self.trainer_id, ["epoch_id", "step_id"]) self.trainer_id, ["epoch_id", "step_id"])
self.checkpoint._epoch_id = int(epoch_id) self.checkpoint._epoch_id = int(epoch_id)
self.checkpoint._step_id = int(step_id) self.checkpoint._step_id = int(step_id)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册