提交 dca0b6d9 编写于 作者: T tangwei12

restore param_path

上级 b044724d
......@@ -105,6 +105,7 @@ class Trainer(object):
def __init__(self,
train_func,
optimizer,
param_path=None,
place=None,
parallel=False,
checkpoint_config=None):
......@@ -120,8 +121,8 @@ class Trainer(object):
# only chief worker will save variables
self.chief = True
self.checkpoint = checkpoint_config
if self.checkpoint and not isinstance(self.checkpoint,
CheckpointConfig):
if self.checkpoint and \
not isinstance(self.checkpoint, CheckpointConfig):
raise TypeError(
"The checkpoint_config shoule be an instance of CheckpointConfig"
)
......@@ -159,6 +160,10 @@ class Trainer(object):
io.load_checkpoint(exe, self.checkpoint.checkpoint_dir,
self.startup_program)
if param_path:
# load params from param_path into scope
io.load_persistables(exe, dirname=param_path)
def _transpile_nccl2_dist(self):
# PADDLE_TRAINER_IPS
if "PADDLE_TRAINER_IPS" not in os.environ:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册