提交 dca0b6d9 编写于 作者: T tangwei12

restore param_path

上级 b044724d
...@@ -105,6 +105,7 @@ class Trainer(object): ...@@ -105,6 +105,7 @@ class Trainer(object):
def __init__(self, def __init__(self,
train_func, train_func,
optimizer, optimizer,
param_path=None,
place=None, place=None,
parallel=False, parallel=False,
checkpoint_config=None): checkpoint_config=None):
...@@ -120,8 +121,8 @@ class Trainer(object): ...@@ -120,8 +121,8 @@ class Trainer(object):
# only chief worker will save variables # only chief worker will save variables
self.chief = True self.chief = True
self.checkpoint = checkpoint_config self.checkpoint = checkpoint_config
if self.checkpoint and not isinstance(self.checkpoint, if self.checkpoint and \
CheckpointConfig): not isinstance(self.checkpoint, CheckpointConfig):
raise TypeError( raise TypeError(
"The checkpoint_config shoule be an instance of CheckpointConfig" "The checkpoint_config shoule be an instance of CheckpointConfig"
) )
...@@ -159,6 +160,10 @@ class Trainer(object): ...@@ -159,6 +160,10 @@ class Trainer(object):
io.load_checkpoint(exe, self.checkpoint.checkpoint_dir, io.load_checkpoint(exe, self.checkpoint.checkpoint_dir,
self.startup_program) self.startup_program)
if param_path:
# load params from param_path into scope
io.load_persistables(exe, dirname=param_path)
def _transpile_nccl2_dist(self): def _transpile_nccl2_dist(self):
# PADDLE_TRAINER_IPS # PADDLE_TRAINER_IPS
if "PADDLE_TRAINER_IPS" not in os.environ: if "PADDLE_TRAINER_IPS" not in os.environ:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册