Launch: bound pod shutdown with timeouts and pass POD_NAME to the context.

Summary of the hunks below:
  * Context.set_envs() merges the str-valued entries of a dict into
    self.envs.
  * ControllerBase calls set_envs({"POD_NAME": self.pod.name}) right after
    creating the Pod, and replaces the fixed time.sleep() calls in
    signal_handler with Pod.stop(timeout=...).
  * Container.wait() and Pod.join() now return a bool instead of None;
    Pod.stop() force-terminates the containers when join() reports a
    timeout (only when timeout is an int).
  * Watcher.stop() no longer joins self.proc; ETCDMaster.stop() only has
    its comment updated.

NOTE(review): this patch was recovered from a copy in which line breaks and
leading whitespace had been collapsed. Indentation and the position of blank
context lines were reconstructed so that every hunk matches its @@ line
counts -- verify against the target tree before applying.

diff --git a/python/paddle/distributed/launch/context/__init__.py b/python/paddle/distributed/launch/context/__init__.py
index 902c8189b17200cd296c5d33f2e6c534dca2e4dc..3e8f0de3e69d5539293b7514539acd155f758099 100644
--- a/python/paddle/distributed/launch/context/__init__.py
+++ b/python/paddle/distributed/launch/context/__init__.py
@@ -76,6 +76,10 @@ class Context(object):
     def get_envs(self):
         return self.envs.copy()
 
+    def set_envs(self, env={}):
+        env = {k: v for k, v in env.items() if isinstance(v, str)}
+        self.envs.update(env)
+
     def _enable_plugin(self):
         for pl in plugins.enabled_plugins:
             pl(self)
diff --git a/python/paddle/distributed/launch/controllers/controller.py b/python/paddle/distributed/launch/controllers/controller.py
index 1f43679d748f1175b09cc3033f8ca63f1751286a..bc628be59dc22e04d937529c3a0afd66f6e30b0d 100644
--- a/python/paddle/distributed/launch/controllers/controller.py
+++ b/python/paddle/distributed/launch/controllers/controller.py
@@ -49,6 +49,8 @@ class ControllerBase(object):
                        jid=self.ctx.args.job_id)
         self.pod = Pod()
 
+        self.ctx.set_envs({"POD_NAME": self.pod.name})
+
         self.join_server = None
 
     def deploy_pod(self):
@@ -104,17 +106,18 @@ class ControllerBase(object):
                 self.ctx.logger.info("Pod {}".format(status))
                 self.ctx.logger.error("Container failed !!!\n{}".format(fc[0]))
                 fc[0].tail()
-                self.pod.stop()
 
                 if self.ctx.args.elastic_level <= 0:
+                    self.pod.stop(timeout=3)
                     return True
                 else:
+                    self.pod.stop(timeout=30)
                     return False
 
             # peer failure
             if self.ctx.status.is_restarting(
             ) and self.master.get_status() != self.ctx.status.COMPLETED:
-                self.pod.stop()
+                self.pod.stop(timeout=30)
                 return False
 
     def stop(self, sigint=None):
@@ -123,7 +126,7 @@ class ControllerBase(object):
         self.watcher.stop()
 
         self.master.stop()
-        self.pod.stop(sigint)
+        self.pod.stop(timeout=30)
 
     def finalize(self):
         self.pod.join()
@@ -133,17 +136,16 @@ class ControllerBase(object):
         sys.exit(self.pod.exit_code)
 
     def signal_handler(self, sigint, frame):
-        self.ctx.logger.info("Terminating with signal {}".format(sigint))
-
         if hasattr(self, 'sigint'):
             self.ctx.logger.info("Force quit in 10 seconds...")
-            time.sleep(11)
+            self.pod.stop(timeout=10)
             sys.exit(sigint)
 
+        self.ctx.logger.info("Terminating with signal {}".format(sigint))
+
         self.sigint = sigint
         self.ctx.status.done()
-        self.stop(sigint)
-        time.sleep(1)
+        self.stop(sigint=sigint)
         self.ctx.logger.info("Exit with signal {}".format(sigint))
         sys.exit(sigint)
 
diff --git a/python/paddle/distributed/launch/controllers/master.py b/python/paddle/distributed/launch/controllers/master.py
index 8e8d31f86dd9fe9d805c79dd2760eb24003b63cd..825be9c36888ccd8178f6abb118626eaf51cc040 100644
--- a/python/paddle/distributed/launch/controllers/master.py
+++ b/python/paddle/distributed/launch/controllers/master.py
@@ -316,5 +316,5 @@ class ETCDMaster(Master):
     def stop(self):
         if hasattr(self, 'beat_thread'):
             self.ctx.status.done()
-            # TODO(kuizhiqing) thread should exit
+            # daemon thread
             #self.beat_thread.join()
diff --git a/python/paddle/distributed/launch/controllers/watcher.py b/python/paddle/distributed/launch/controllers/watcher.py
index 6e8a2cc4e87818902dc100f5361e9c1873b88205..4b8e346e7908fc9ece77c72a3cba1608369b6be7 100644
--- a/python/paddle/distributed/launch/controllers/watcher.py
+++ b/python/paddle/distributed/launch/controllers/watcher.py
@@ -93,4 +93,6 @@ class Watcher(object):
 
     def stop(self):
         if hasattr(self, "proc"):
-            self.proc.join()
+            # daemon without join
+            # self.proc.join()
+            pass
diff --git a/python/paddle/distributed/launch/job/container.py b/python/paddle/distributed/launch/job/container.py
index 8f515d9e6f38b6770daf4abbde2c936234bdcb26..e0f580da0ac45d125fa8355029dba601f7519436 100644
--- a/python/paddle/distributed/launch/job/container.py
+++ b/python/paddle/distributed/launch/job/container.py
@@ -131,7 +131,11 @@ class Container(object):
             return self._proc.terminate(force)
 
     def wait(self, timeout=None):
-        self._proc.wait(timeout)
+        try:
+            self._proc.wait(timeout)
+            return True
+        except Exception:
+            return False
 
     @property
     def exit_code(self):
diff --git a/python/paddle/distributed/launch/job/pod.py b/python/paddle/distributed/launch/job/pod.py
index cda400f0a324a62e90439874897e24ee2a4dc8e1..c99b2db547a268465458cff5bca3903b54f20ef1 100644
--- a/python/paddle/distributed/launch/job/pod.py
+++ b/python/paddle/distributed/launch/job/pod.py
@@ -116,14 +116,26 @@ class Pod(PodSepc):
 
         self._restart += 1
 
-    def stop(self, sigint=0):
+    def stop(self, sigint=15, timeout=None):
         for c in self._containers:
-            force = True if sigint == 9 else False
-            c.terminate(force)
+            if isinstance(sigint, int) and timeout is None:
+                c.send_signal(sigint)
+            else:
+                c.terminate()
+
+        if isinstance(timeout, int):
+            if not self.join(timeout):
+                for c in self._containers:
+                    c.terminate(force=True)
+                return False
+            else:
+                return True
 
-    def join(self):
+    def join(self, timeout=None):
         for c in self._containers:
-            c.wait(None)
+            if not c.wait(timeout):
+                return False
+        return True
 
     @property
     def status(self):