From 9e026a93cff29f1d49fac900b3110968da8594cf Mon Sep 17 00:00:00 2001 From: tangwei12 Date: Thu, 7 Jun 2018 16:59:53 +0800 Subject: [PATCH] remove chief --- python/paddle/fluid/io.py | 6 ++---- python/paddle/fluid/trainer.py | 5 +---- 2 files changed, 3 insertions(+), 8 deletions(-) diff --git a/python/paddle/fluid/io.py b/python/paddle/fluid/io.py index 34c527b62..6323c9899 100644 --- a/python/paddle/fluid/io.py +++ b/python/paddle/fluid/io.py @@ -466,7 +466,6 @@ CHECKPOINT_SEPARATOR = "_" def save_checkpoint(executor, checkpoint_dir, trainer_id, - is_chief=False, trainer_args=None, main_program=None, max_num_checkpoints=3): @@ -478,8 +477,7 @@ def save_checkpoint(executor, :param executor executor for save the value :param checkpoint_dir the checkpoint directory - :param trainer_id currect trainer id - :param is_chief if the trainer id equals 0, the is_chief will be true + :param trainer_id current trainer id, if id is equal to 0, the trainer is chief :param main_program will save all variables in program :param max_num_checkpoints will keep numbers of checkpoint serials not bigger than max_num_checkpoints """ @@ -497,7 +495,7 @@ def save_checkpoint(executor, save_trainer_args(cur_dir, trainer_id, trainer_args) - if is_chief: + if trainer_id == 0: save_persist_vars_without_grad(executor, cur_dir, main_program) _scroll_delete(checkpoint_dir, max_num_checkpoints) diff --git a/python/paddle/fluid/trainer.py b/python/paddle/fluid/trainer.py index 5230ded7d..2737f1c70 100644 --- a/python/paddle/fluid/trainer.py +++ b/python/paddle/fluid/trainer.py @@ -136,7 +136,6 @@ class Trainer(object): # config for checkpoint # only chief worker will save variables self.trainer_id = 0 - self.chief = True self.checkpoint_cfg = checkpoint_config if self.checkpoint_cfg: assert isinstance(self.checkpoint_cfg, CheckpointConfig) @@ -201,7 +200,6 @@ class Trainer(object): self.nccl_id_var = None else: self.trainer_id = int(os.getenv("PADDLE_TRAINER_ID")) - self.chief = self.trainer_id 
== 0 port = os.getenv("PADDLE_PSERVER_PORT") worker_ips = os.getenv("PADDLE_TRAINER_IPS") worker_endpoints = [] @@ -250,7 +248,7 @@ class Trainer(object): # the unique trainer id, starting from 0, needed by trainer # only self.trainer_id = int(os.getenv("PADDLE_TRAINER_ID", "0")) - self.chief = self.trainer_id == 0 + # the role, should be either PSERVER or TRAINER training_role = os.getenv("PADDLE_TRAINING_ROLE") with self._prog_and_scope_guard(): @@ -456,7 +454,6 @@ class Trainer(object): executor=exe, checkpoint_dir=self.checkpoint_cfg.checkpoint_dir, trainer_id=self.trainer_id, - is_chief=self.chief, trainer_args=self._get_checkpoint_save_args(epoch_id, step_id), main_program=self.train_program, max_num_checkpoints=self.checkpoint_cfg.max_num_checkpoints) -- GitLab