diff --git a/python/paddle/distributed/launch/controllers/collective.py b/python/paddle/distributed/launch/controllers/collective.py index 302070e55c1ed9d84d13a3dbcc38c2499c2a230a..06612bd7c823d508d8baeeb26efafa453a950686 100644 --- a/python/paddle/distributed/launch/controllers/collective.py +++ b/python/paddle/distributed/launch/controllers/collective.py @@ -52,7 +52,8 @@ class CollectiveController(Controller): self.ctx.logger.debug("job endpoints: {}".format(job_endpoints)) rank_offset = ips.index( - self.ctx.node.ip) if self.ctx.node.ip in ips else 0 + self.ctx.node.ip + ) * self.pod.replicas if self.ctx.node.ip in ips else 0 self.save_pod_log(job_endpoints) @@ -66,7 +67,7 @@ class CollectiveController(Controller): "PADDLE_LOCAL_SIZE": "{}".format(self.pod.replicas), "PADDLE_GLOBAL_RANK": "{}".format(i + rank_offset), "PADDLE_LOCAL_RANK": "{}".format(i), - "PADDLE_NNODES": "{}".format(self.job.replicas), + "PADDLE_NNODES": "{}".format(len(ips)), ## compatible env "PADDLE_TRAINER_ENDPOINTS": ",".join(job_endpoints), "PADDLE_CURRENT_ENDPOINT": job_endpoints[i + rank_offset],