提交 9e026a93 编写于 作者: T tangwei12

remove chief

上级 7fbddaa6
......@@ -466,7 +466,6 @@ CHECKPOINT_SEPARATOR = "_"
def save_checkpoint(executor,
checkpoint_dir,
trainer_id,
is_chief=False,
trainer_args=None,
main_program=None,
max_num_checkpoints=3):
......@@ -478,8 +477,7 @@ def save_checkpoint(executor,
:param executor executor for save the value
:param checkpoint_dir the checkpoint directory
:param trainer_id currect trainer id
:param is_chief if the trainer id equals 0, the is_chief will be true
:param trainer_id current trainer id; if the id equals 0, this trainer is the chief
:param main_program will save all variables in program
:param max_num_checkpoints keep at most this many checkpoint serials; older ones are deleted
"""
......@@ -497,7 +495,7 @@ def save_checkpoint(executor,
save_trainer_args(cur_dir, trainer_id, trainer_args)
if is_chief:
if trainer_id == 0:
save_persist_vars_without_grad(executor, cur_dir, main_program)
_scroll_delete(checkpoint_dir, max_num_checkpoints)
......
......@@ -136,7 +136,6 @@ class Trainer(object):
# config for checkpoint
# only chief worker will save variables
self.trainer_id = 0
self.chief = True
self.checkpoint_cfg = checkpoint_config
if self.checkpoint_cfg:
assert isinstance(self.checkpoint_cfg, CheckpointConfig)
......@@ -201,7 +200,6 @@ class Trainer(object):
self.nccl_id_var = None
else:
self.trainer_id = int(os.getenv("PADDLE_TRAINER_ID"))
self.chief = self.trainer_id == 0
port = os.getenv("PADDLE_PSERVER_PORT")
worker_ips = os.getenv("PADDLE_TRAINER_IPS")
worker_endpoints = []
......@@ -250,7 +248,7 @@ class Trainer(object):
# the unique trainer id, starting from 0, needed by trainer
# only
self.trainer_id = int(os.getenv("PADDLE_TRAINER_ID", "0"))
self.chief = self.trainer_id == 0
# the role, should be either PSERVER or TRAINER
training_role = os.getenv("PADDLE_TRAINING_ROLE")
with self._prog_and_scope_guard():
......@@ -456,7 +454,6 @@ class Trainer(object):
executor=exe,
checkpoint_dir=self.checkpoint_cfg.checkpoint_dir,
trainer_id=self.trainer_id,
is_chief=self.chief,
trainer_args=self._get_checkpoint_save_args(epoch_id, step_id),
main_program=self.train_program,
max_num_checkpoints=self.checkpoint_cfg.max_num_checkpoints)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册