diff --git a/python/paddle/fluid/trainer.py b/python/paddle/fluid/trainer.py index 01c40bb90e464f8c39f7310044969fc153792f27..24254b4980c130da886afbe293ea169075933688 100644 --- a/python/paddle/fluid/trainer.py +++ b/python/paddle/fluid/trainer.py @@ -105,6 +105,7 @@ class Trainer(object): def __init__(self, train_func, optimizer, + param_path=None, place=None, parallel=False, checkpoint_config=None): @@ -120,8 +121,8 @@ class Trainer(object): # only chief worker will save variables self.chief = True self.checkpoint = checkpoint_config - if self.checkpoint and not isinstance(self.checkpoint, - CheckpointConfig): + if self.checkpoint and \ + not isinstance(self.checkpoint, CheckpointConfig): raise TypeError( "The checkpoint_config shoule be an instance of CheckpointConfig" ) @@ -159,6 +160,10 @@ class Trainer(object): io.load_checkpoint(exe, self.checkpoint.checkpoint_dir, self.startup_program) + if param_path: + # load params from param_path into scope + io.load_persistables(exe, dirname=param_path) + def _transpile_nccl2_dist(self): # PADDLE_TRAINER_IPS if "PADDLE_TRAINER_IPS" not in os.environ: