diff --git a/python/paddle/distributed/fleet/elastic.py b/python/paddle/distributed/fleet/elastic.py index da61169cacc3bc15e99d629dffc8b1600faa40e7..4e8853780f4dcd4a5f35c6fcb022413d068ae243 100644 --- a/python/paddle/distributed/fleet/elastic.py +++ b/python/paddle/distributed/fleet/elastic.py @@ -133,6 +133,7 @@ class ElasticManager(object): self.stopped = False self.sigint = 0 + self.need_sync = False if not server or ':' not in server or not name or not np: logger.info( @@ -177,6 +178,7 @@ class ElasticManager(object): logger.info('register host again {}'.format(self.host)) self.etcd.put(self.host_path, six.b(self.host)) + self.need_sync = True host_watch = self.etcd.add_watch_callback(self.host_path, host_call_back) @@ -254,6 +256,7 @@ class ElasticManager(object): return int(self.etcd.get(self.prefix)[0]) == 1 def _match(self): + self.hosts = [ six.ensure_str(i[0]) for i in self.etcd.get_prefix(self.node_prefix) ] @@ -307,7 +310,8 @@ class ElasticManager(object): self.hosts)) idx += 1 - time.sleep(3) + time.sleep(2) + return def run(self, launcher): @@ -319,6 +323,9 @@ class ElasticManager(object): def watch(self): + if self.need_sync: + self.need_sync = False + while not self.stopped: ret = self.launcher.watch() @@ -334,7 +341,7 @@ class ElasticManager(object): else: return ElasticStatus.ERROR - if not self._completed() and not self._match(): + if not self._completed() and (not self._match() or self.need_sync): self.launcher.stop() return ElasticStatus.HOLD