diff --git a/python/paddle/distributed/fleet/launch_utils.py b/python/paddle/distributed/fleet/launch_utils.py index f7f50e76af61b9b530f207d296ad3d77d467050d..d87bdb47932ef16f0c6d47d66a7900c275631014 100644 --- a/python/paddle/distributed/fleet/launch_utils.py +++ b/python/paddle/distributed/fleet/launch_utils.py @@ -1180,18 +1180,14 @@ class ParameterServerLauncher(object): _, self.current_node_ip = get_host_name_ip() else: self.current_node_ip = pod_ip - if not self.distribute_mode == DistributeMode.PS_HETER: - assert self.current_node_ip in self.node_ips, "Can't find your local ip {%s} in args.servers and args.workers ips: {%s}" \ - % (self.current_node_ip, self.node_ips) - if self.current_node_ip in self.node_ips: - self.node_rank = self.node_ips.index(self.current_node_ip) - logger.debug( - "parsed from args: node_ips:{} current_node_ip:{} node_rank:{}". - format(self.node_ips, self.current_node_ip, self.node_rank)) + assert self.current_node_ip in self.node_ips, "Can't find your local ip {%s} in args.servers and args.workers ips: {%s}" \ + % (self.current_node_ip, self.node_ips) + self.node_rank = self.node_ips.index(self.current_node_ip) + logger.debug( + "parsed from args: node_ips:{} current_node_ip:{} node_rank:{}". + format(self.node_ips, self.current_node_ip, self.node_rank)) def start_ps(self): - if not self.current_node_ip in self.node_ips: - return cluster = Cluster(hdfs=None) server_rank = 0 worker_rank = 0