提交 b44ede80 编写于 作者: T tangwei12

bug fix

上级 d712af25
...@@ -252,14 +252,14 @@ class Trainer(object): ...@@ -252,14 +252,14 @@ class Trainer(object):
current_endpoint = os.getenv("PADDLE_CURRENT_IP", "") + ":" + port current_endpoint = os.getenv("PADDLE_CURRENT_IP", "") + ":" + port
# the unique trainer id, starting from 0, needed by trainer # the unique trainer id, starting from 0, needed by trainer
# only # only
trainer_id = int(os.getenv("PADDLE_TRAINER_ID", "0")) self.trainer_id = int(os.getenv("PADDLE_TRAINER_ID", "0"))
self.chief = self.trainer_id == 0 self.chief = self.trainer_id == 0
# the role, should be either PSERVER or TRAINER # the role, should be either PSERVER or TRAINER
training_role = os.getenv("PADDLE_TRAINING_ROLE") training_role = os.getenv("PADDLE_TRAINING_ROLE")
with self._prog_and_scope_guard(): with self._prog_and_scope_guard():
t = distribute_transpiler.DistributeTranspiler() t = distribute_transpiler.DistributeTranspiler()
t.transpile( t.transpile(
trainer_id, pservers=pserver_endpoints, trainers=trainers) self.trainer_id, pservers=pserver_endpoints, trainers=trainers)
if training_role == "PSERVER": if training_role == "PSERVER":
if self.checkpoint: if self.checkpoint:
self.is_pserver = True self.is_pserver = True
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册