diff --git a/python/paddle/fluid/trainer.py b/python/paddle/fluid/trainer.py index 6d8d4a3e43bad3d98ba1b7182f58fae744f57654..e98672f3187da36b0c5f8efd979003733c61bcb4 100644 --- a/python/paddle/fluid/trainer.py +++ b/python/paddle/fluid/trainer.py @@ -81,7 +81,8 @@ class CheckpointConfig(object): self.epoch_id = 0 self.step_id = 0 - self._load_serial = None + self.load_serial = None + self.is_pserver = False def check_and_get_place(place): @@ -145,7 +146,7 @@ class Trainer(object): "The checkpoint_config shoule be an instance of CheckpointConfig" ) else: - self.checkpoint._load_serial = io.need_load_checkpoint( + self.checkpoint.load_serial = io.need_load_checkpoint( self.checkpoint.checkpoint_dir) self.scope = core.Scope() @@ -176,17 +177,18 @@ class Trainer(object): exe = executor.Executor(place) exe.run(self.startup_program) - if self.checkpoint and self.checkpoint._load_serial: + if self.checkpoint and self.checkpoint.load_serial: exe = executor.Executor(place) io.load_checkpoint(exe, self.checkpoint.checkpoint_dir, - self.checkpoint._load_serial, + self.checkpoint.load_serial, self.startup_program) - epoch_id, step_id = io.load_trainer_args( - self.checkpoint.checkpoint_dir, self.checkpoint._load_serial, - self.trainer_id, ["epoch_id", "step_id"]) - self.checkpoint.epoch_id = int(epoch_id) - self.checkpoint.step_id = int(step_id) + if not self.checkpoint.is_pserver: + epoch_id, step_id = io.load_trainer_args( + self.checkpoint.checkpoint_dir, self.checkpoint.load_serial, + self.trainer_id, ["epoch_id", "step_id"]) + self.checkpoint.epoch_id = int(epoch_id) + self.checkpoint.step_id = int(step_id) if param_path and os.path.isdir(param_path): # load params from param_path into scope @@ -259,6 +261,9 @@ class Trainer(object): t.transpile( trainer_id, pservers=pserver_endpoints, trainers=trainers) if training_role == "PSERVER": + if self.checkpoint: + self.is_pserver = True + self.train_program = t.get_pserver_program(current_endpoint) self.startup_program = t.get_startup_program(current_endpoint, self.train_program) @@ -362,7 +367,7 @@ class Trainer(object): self._clean_checkpoint() return - if self.checkpoint and self.checkpoint._load_serial \ + if self.checkpoint and self.checkpoint.load_serial \ and self.checkpoint.step_id >= step_id and self.checkpoint.epoch_id == epoch_id: continue