提交 5f9ae657 编写于 作者: R Rick Chao 提交者: TensorFlower Gardener

Fix tsan failure in multi_process_runner_test.

PiperOrigin-RevId: 317747609
Change-Id: I8bf2e493431a69a0cf581012666045df2879055e
上级 1c12a84e
......@@ -1794,7 +1794,7 @@ py_test(
name = "multi_process_runner_test",
srcs = ["multi_process_runner_test.py"],
python_version = "PY3",
tags = ["notsan"], # TODO(b/158874970)
shard_count = 12,
deps = [
":multi_process_runner",
":multi_worker_test_base",
......
......@@ -423,6 +423,18 @@ class MultiProcessRunner(object):
def join(self, timeout=_DEFAULT_TIMEOUT_SEC):
"""Joins all the processes with timeout.
If any of the subprocesses does not exit approximately after `timeout`
seconds has passed after `join` call, this raises a
`SubprocessTimeoutError`.
Note: At timeout, it uses SIGTERM to terminate the subprocesses, in order to
log the stack traces of the subprocesses when they exit. However, this
results in timeout when the test runs with tsan (thread sanitizer); if tsan
is being run on the test targets that rely on timeout to assert information,
`MultiProcessRunner.terminate_all()` must be called after `join()`, before
the test exits, so the subprocesses are terminated with SIGKILL, and data
race is removed.
Args:
timeout: if set and not all processes report status within roughly
`timeout` seconds, a `SubprocessTimeoutError` exception will be raised.
......
......@@ -124,24 +124,6 @@ class MultiProcessRunnerTest(test.TestCase):
std_stream_results)
self.assertIn('This is returned data.', return_value)
def test_process_that_exits(self):
def func_to_exit_in_25_sec():
logging.error('foo')
time.sleep(100)
logging.error('bar')
mpr = multi_process_runner.MultiProcessRunner(
func_to_exit_in_25_sec,
multi_worker_test_base.create_cluster_spec(num_workers=1),
list_stdout=True,
max_run_time=25)
mpr.start()
stdout = mpr.join().stdout
self.assertLen([msg for msg in stdout if 'foo' in msg], 1)
self.assertLen([msg for msg in stdout if 'bar' in msg], 0)
def test_termination(self):
def proc_func():
......@@ -301,29 +283,21 @@ class MultiProcessRunnerTest(test.TestCase):
def test_stdout_available_when_timeout(self):
def proc_func():
for i in range(50):
logging.info('(logging) %s-%d, i: %d',
multi_worker_test_base.get_task_type(), self._worker_idx(),
i)
time.sleep(1)
logging.info('something printed')
time.sleep(10000) # Intentionally make the test timeout.
with self.assertRaises(multi_process_runner.SubprocessTimeoutError) as cm:
multi_process_runner.run(
mpr = multi_process_runner.MultiProcessRunner(
proc_func,
multi_worker_test_base.create_cluster_spec(num_workers=1, num_ps=1),
list_stdout=True,
timeout=5)
multi_worker_test_base.create_cluster_spec(num_workers=1),
list_stdout=True)
mpr.start()
mpr.join(timeout=60)
mpr.terminate_all()
list_to_assert = cm.exception.mpr_result.stdout
# We should see 5 iterations from worker and ps, however sometime on TAP
# due to CPU throttling and slugginess of msan/asan build, this became
# flaky. Therefore we allow more margin of errors to only check the first
# 3 iterations.
for job in ['worker', 'ps']:
for iteration in range(0, 3):
self.assertTrue(
any('(logging) {}-0, i: {}'.format(job, iteration) in line
for line in list_to_assert))
self.assertTrue(
any('something printed' in line for line in list_to_assert))
def test_seg_fault_raises_error(self):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册