提交 4351b2de 编写于 作者: Y Yuefeng Zhou 提交者: TensorFlower Gardener

Stop the preemption handler thread synchronously in testOneWorkerPreemptionWithCancellation.

PiperOrigin-RevId: 358651724
Change-Id: Ia66fd8780e874fe7da96a153b763deb135886e95
上级 0e6b3a30
......@@ -623,15 +623,20 @@ class WorkerPreemptionHandler(object):
self._cluster_due_for_update_or_finish = threading.Event()
self._worker_up_cond = threading.Condition(self._cluster_update_lock)
self._should_preemption_thread_run = True
threading.Thread(target=self._preemption_handler,
name="WorkerPreemptionHandler",
daemon=True).start()
self._preemption_handler_thread = threading.Thread(
target=self._preemption_handler,
name="WorkerPreemptionHandler",
daemon=True)
self._preemption_handler_thread.start()
def stop(self):
  """Request a clean shutdown of the worker preemption handler thread.

  Clears the run flag first, then wakes the handler thread — which blocks
  on `_cluster_due_for_update_or_finish` — so it can observe the flag and
  exit its loop.
  """
  # Order matters: the flag must be off before the event fires, otherwise
  # the handler thread could wake and go back to waiting.
  self._should_preemption_thread_run = False
  update_lock = self._cluster_update_lock
  with update_lock:
    self._cluster_due_for_update_or_finish.set()
  # TODO(yuefengz): The preemption handler thread shouldn't be terminated
  # asynchronously since it touches eager context which is a process-wide
  # singleton. The problem is in OSS unit tests will time out.
def _validate_preemption_failure(self, e):
"""Validates that the given exception represents worker preemption."""
......@@ -656,6 +661,7 @@ class WorkerPreemptionHandler(object):
Yields:
None.
"""
assert self._should_preemption_thread_run
try:
yield
except errors.OpError as e:
......@@ -689,9 +695,11 @@ class WorkerPreemptionHandler(object):
it waits until all workers are back and updates the cluster about the
restarted workers.
"""
assert self._should_preemption_thread_run
while True:
self._cluster_due_for_update_or_finish.wait()
if not self._should_preemption_thread_run:
logging.info("Stopping the failure handing thread.")
break
with self._cluster_update_lock:
......@@ -704,7 +712,10 @@ class WorkerPreemptionHandler(object):
# all workers that they are recovered from failure.
logging.info("Cluster successfully recovered.")
self._worker_up_cond.notify_all()
self._cluster_due_for_update_or_finish.clear()
# The check for _should_preemption_thread_run is necessary since the
# `stop` may have already set _cluster_due_for_update_or_finish.
if self._should_preemption_thread_run:
self._cluster_due_for_update_or_finish.clear()
except Exception as e: # pylint: disable=broad-except
self._validate_preemption_failure(e)
# NOTE: Since the first RPC (GetStatus) of update_server_def is
......
......@@ -278,6 +278,12 @@ class BaseFaultToleranceTest(object): # pylint: disable=missing-docstring
self.cluster_coord.schedule(normal_function)
self.cluster_coord.join()
# The cluster is likely still being recovered since `join` returned early
# due to the error_function.
failure_handler = self.cluster_coord._cluster.failure_handler
failure_handler.stop()
failure_handler._preemption_handler_thread.join()
def testHandleDatasetCreationFailure(self):
model = Model(self.cluster_coord)
......@@ -443,6 +449,12 @@ class BaseFaultToleranceTest(object): # pylint: disable=missing-docstring
model.join_training_functions()
self.assertGreaterEqual(model.iterations.numpy(), 10)
def testPSFailureWhileRecoveryFromWokerFailure(self):
# Only by adding this empty test, can the problem of b/180348454 be
# reproduced.
# TODO(yuefengz): fill in this test.
pass
def testNumpyFetchedAfterWorkerFailure(self):
with self.strategy.scope():
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册