From dca0b6d9ccc5b770e78a0903839f2ed89d79be58 Mon Sep 17 00:00:00 2001 From: tangwei12 Date: Wed, 23 May 2018 19:50:25 +0800 Subject: [PATCH] restore param_path --- python/paddle/fluid/trainer.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/python/paddle/fluid/trainer.py b/python/paddle/fluid/trainer.py index 01c40bb90e..24254b4980 100644 --- a/python/paddle/fluid/trainer.py +++ b/python/paddle/fluid/trainer.py @@ -105,6 +105,7 @@ class Trainer(object): def __init__(self, train_func, optimizer, + param_path=None, place=None, parallel=False, checkpoint_config=None): @@ -120,8 +121,8 @@ class Trainer(object): # only chief worker will save variables self.chief = True self.checkpoint = checkpoint_config - if self.checkpoint and not isinstance(self.checkpoint, - CheckpointConfig): + if self.checkpoint and \ + not isinstance(self.checkpoint, CheckpointConfig): raise TypeError( "The checkpoint_config shoule be an instance of CheckpointConfig" ) @@ -159,6 +160,10 @@ class Trainer(object): io.load_checkpoint(exe, self.checkpoint.checkpoint_dir, self.startup_program) + if param_path: + # load params from param_path into scope + io.load_persistables(exe, dirname=param_path) + def _transpile_nccl2_dist(self): # PADDLE_TRAINER_IPS if "PADDLE_TRAINER_IPS" not in os.environ: -- GitLab