diff --git a/python/paddle/distributed/launch/context/node.py b/python/paddle/distributed/launch/context/node.py index 39f42d02107a2fbe9bcbb55e20bbc41b32da4ebf..6ee8fa6d10c86703fd347e719bd0f413e635fadf 100644 --- a/python/paddle/distributed/launch/context/node.py +++ b/python/paddle/distributed/launch/context/node.py @@ -49,6 +49,8 @@ class Node(object): for _ in range(100): with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as s: + s.setsockopt(socket.SOL_SOCKET, socket.SO_LINGER, + struct.pack('ii', 1, 0)) s.bind(('', 0)) port = s.getsockname()[1] if port in self._allocated_ports: diff --git a/python/paddle/distributed/launch/controllers/collective.py b/python/paddle/distributed/launch/controllers/collective.py index e155d31b459ff51a610a56d572c64d63726dadcd..302070e55c1ed9d84d13a3dbcc38c2499c2a230a 100644 --- a/python/paddle/distributed/launch/controllers/collective.py +++ b/python/paddle/distributed/launch/controllers/collective.py @@ -93,7 +93,7 @@ class CollectiveController(Controller): self.pod.replicas = self.pod_replicas() # rank will be reset when restart - self.pod.rank = self.ctx.args.rank + self.pod.rank = int(self.ctx.args.rank) port = self.ctx.node.get_free_port() diff --git a/python/paddle/distributed/launch/controllers/master.py b/python/paddle/distributed/launch/controllers/master.py index 825be9c36888ccd8178f6abb118626eaf51cc040..c71d0890f196a0394f2edd78d61415b841f2562f 100644 --- a/python/paddle/distributed/launch/controllers/master.py +++ b/python/paddle/distributed/launch/controllers/master.py @@ -102,7 +102,7 @@ class HTTPMaster(Master): print(" ".join(cmd)) print("-" * 80) - if self.ctx.args.rank >= 0: + if int(self.ctx.args.rank) >= 0: self.ctx.logger.warning( "--rank set in the command may not compatible in auto mode") diff --git a/python/paddle/distributed/launch/controllers/ps.py b/python/paddle/distributed/launch/controllers/ps.py index 573f578d249e133b3a3f39444c93c273cb695844..f785311a525402c49008658f73a39e69502a1e30 100644 --- a/python/paddle/distributed/launch/controllers/ps.py +++ b/python/paddle/distributed/launch/controllers/ps.py @@ -111,7 +111,7 @@ class PSController(Controller): def _build_pod_with_master(self): - self.pod.rank = self.ctx.args.rank + self.pod.rank = int(self.ctx.args.rank) server_num = self.ctx.args.server_num or 1 servers = [