提交 f74cc7a6 编写于 作者: Y Yuefeng Zhou 提交者: TensorFlower Gardener

Use MPR for fault tolerance test

PiperOrigin-RevId: 327766188
Change-Id: I247539f5561940a29fef658818b1e815dd194c1d
上级 e918c5c7
......@@ -870,6 +870,7 @@ py_library(
srcs = ["multi_worker_test_base.py"],
srcs_version = "PY2AND3",
deps = [
":multi_process_runner",
"//tensorflow/core:protos_all_py",
"//tensorflow/python:client_testlib",
"//tensorflow/python:distributed_framework_test_lib",
......@@ -879,12 +880,22 @@ py_library(
"//tensorflow/python:session",
"//tensorflow/python:training_lib",
"//tensorflow/python:util",
"//tensorflow/python/distribute/cluster_resolver:cluster_resolver_lib",
"//tensorflow/python/eager:context",
"//tensorflow/python/eager:remote",
"//third_party/py/numpy",
],
)
# Unit tests for multi_worker_test_base (e.g. the MultiProcessCluster
# helper); runs in-process, no GPU required.
tf_py_test(
    name = "multi_worker_test_base_test",
    srcs = ["multi_worker_test_base_test.py"],
    srcs_version = "PY2AND3",
    deps = [
        ":multi_worker_test_base",
    ],
)
cuda_py_test(
name = "checkpoint_utils_test",
size = "medium",
......
......@@ -41,6 +41,9 @@ from tensorflow.core.protobuf import config_pb2
from tensorflow.core.protobuf import rewriter_config_pb2
from tensorflow.python.client import session
from tensorflow.python.distribute import distribute_coordinator as dc
from tensorflow.python.distribute import multi_process_runner
from tensorflow.python.distribute.cluster_resolver import SimpleClusterResolver
from tensorflow.python.distribute.cluster_resolver import TFConfigClusterResolver
from tensorflow.python.eager import context
from tensorflow.python.eager import remote
from tensorflow.python.framework import errors
......@@ -200,6 +203,156 @@ def create_in_process_cluster(num_workers,
return cluster
class MultiProcessCluster(object):
  """A cluster of TensorFlow servers in separate processes.

  This class is not thread-safe.
  """

  def __init__(self, cluster_resolver):
    """Initializes the cluster, without starting any server processes.

    Args:
      cluster_resolver: a cluster resolver describing the cluster topology
        (task types, addresses) and the RPC layer to use.
    """
    self._cluster_resolver = cluster_resolver
    # Plain dict form, e.g. {'worker': ['host:port', ...], 'ps': [...]}.
    self._cluster_spec = cluster_resolver.cluster_spec().as_dict()
    self._rpc_layer = cluster_resolver.rpc_layer
    # Cross-process events keyed as [task_type][task_id]:
    #   _start_events: set by a subprocess once its server has started.
    #   _finish_events: set by this (parent) process to tell a subprocess
    #     to exit.
    self._start_events = {}
    self._finish_events = {}
    # Manager whose Event objects can be shared with the subprocesses.
    self._mpr_manager = multi_process_runner.manager()

    def task_function(start_events, finish_events):
      # Runs inside each subprocess: starts one TF server for this task,
      # signals the parent that it is up, then blocks until told to exit.
      cluster_resolver = TFConfigClusterResolver()
      cluster_spec = cluster_resolver.cluster_spec()
      task_type = cluster_resolver.task_type
      task_id = cluster_resolver.task_id
      rpc_layer = cluster_resolver.rpc_layer

      logging.info(
          'Starting server with cluster_spec = %r, task_type = %r, '
          'task_id = %r, rpc_layer = %r', cluster_spec, task_type, task_id,
          rpc_layer)

      # TODO(yuefengz): support GPU clusters.
      server_config = config_pb2.ConfigProto()
      server_config.device_count['GPU'] = 0

      # start=True brings the server up immediately; it serves in
      # background threads.
      server_lib.Server(
          cluster_spec,
          job_name=task_type,
          protocol=rpc_layer,
          task_index=task_id,
          config=server_config,
          start=True)

      # Tell the parent process that this task's server is up.
      start_event = start_events[task_type][task_id]
      start_event.set()

      # Block until the parent asks this task to exit.
      finish_event = finish_events[task_type][task_id]
      finish_event.wait()

      # Exit without interpreter cleanup; the server has no graceful
      # shutdown path.
      os._exit(0)  # pylint: disable=protected-access

    self._task_function = task_function
    # The MultiProcessRunner managing the subprocesses; None when the
    # cluster is not running.
    self._mpr = None

  def start(self):
    """Starts one TensorFlow server for each task in the cluster_resolver.

    It will wait until all the servers are up before returning.

    Raises:
      ValueError: if the cluster has already been started.
    """
    if self._mpr:
      raise ValueError('The cluster has already been started.')
    # Create one (start, finish) event pair per task, via the manager so
    # the events work across process boundaries.
    for task_type, task_addresses in self._cluster_spec.items():
      self._start_events[task_type] = []
      self._finish_events[task_type] = []
      for _ in task_addresses:
        self._start_events[task_type].append(self._mpr_manager.Event())
        self._finish_events[task_type].append(self._mpr_manager.Event())

    self._mpr = multi_process_runner.MultiProcessRunner(
        self._task_function,
        self._cluster_spec,
        args=(self._start_events, self._finish_events),
        rpc_layer=self._rpc_layer,
        stream_stdout=False,
        list_stdout=False,
        use_dill_for_args=False)
    self._mpr.start()
    # Block until every task has reported that its server is up.
    for task_type, task_addresses in self._cluster_spec.items():
      for i in range(len(task_addresses)):
        self._start_events[task_type][i].wait()

  def stop(self):
    """Stops all the servers."""
    # Signal every subprocess to exit, then wait for them.
    for task_type, task_addresses in self._cluster_spec.items():
      for i in range(len(task_addresses)):
        self._finish_events[task_type][i].set()
    try:
      self._mpr.join()
    except multi_process_runner.UnexpectedSubprocessExitError:
      # TODO(yuefengz): investigate why processes exit with 255.
      pass
    self._mpr = None
    self._start_events = {}
    self._finish_events = {}

  def kill_task(self, task_type, task_id):
    """Kill a server given task_type and task_id.

    Args:
      task_type: the type of the task such as "worker".
      task_id: the id of the task such as 1.

    Raises:
      ValueError: if the task is not currently alive (never started, or
        already finished).
    """
    assert self._mpr
    if (not self._start_events[task_type][task_id].is_set() or
        self._finish_events[task_type][task_id].is_set()):
      raise ValueError("The task %s:%d doesn't exist." % (task_type, task_id))

    # Ask the subprocess to exit and wait for it to terminate.
    self._finish_events[task_type][task_id].set()
    self._mpr._processes[(task_type, task_id)].join()

  def start_task(self, task_type, task_id):
    """Starts a server given task_type and task_id.

    Args:
      task_type: the type of the task such as "worker".
      task_id: the id of the task such as 1.

    Raises:
      ValueError: if the server already exists.
    """
    assert self._mpr

    # A task may only be (re)started after it has both started and
    # finished, i.e. after it was killed.
    if (not self._start_events[task_type][task_id].is_set() or
        not self._finish_events[task_type][task_id].is_set()):
      raise ValueError(
          'The task %s:%d is still alive. You cannot start another one.' %
          (task_type, task_id))
    # Fresh events for the new incarnation of this task.
    self._start_events[task_type][task_id] = self._mpr_manager.Event()
    self._finish_events[task_type][task_id] = self._mpr_manager.Event()

    self._mpr.start_single_process(task_type=task_type, task_id=task_id)
    # Block until the restarted server reports it is up.
    self._start_events[task_type][task_id].wait()

  @property
  def cluster_resolver(self):
    # Deep copy so callers cannot mutate the cluster's own resolver.
    return copy.deepcopy(self._cluster_resolver)
def create_multi_process_cluster(num_workers,
                                 num_ps,
                                 has_chief=False,
                                 has_eval=False,
                                 rpc_layer='grpc'):
  """Creates and starts a `MultiProcessCluster` with the given topology.

  Args:
    num_workers: number of "worker" tasks in the cluster.
    num_ps: number of "ps" tasks in the cluster.
    has_chief: whether the cluster has a "chief" task.
    has_eval: whether the cluster has an "evaluator" task.
    rpc_layer: the protocol the servers communicate with, e.g. 'grpc'.

  Returns:
    A started `MultiProcessCluster` with one server process per task.
  """
  spec_dict = create_cluster_spec(
      num_workers=num_workers,
      num_ps=num_ps,
      has_chief=has_chief,
      has_eval=has_eval)
  resolver = SimpleClusterResolver(
      server_lib.ClusterSpec(spec_dict), rpc_layer=rpc_layer)

  mp_cluster = MultiProcessCluster(resolver)
  mp_cluster.start()
  return mp_cluster
# TODO(rchao): Remove `test_obj` once estimator repo picks up the updated
# nightly TF.
def create_cluster_spec(has_chief=False,
......
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for multi-process clusters."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.python.distribute import multi_process_runner
from tensorflow.python.distribute import multi_worker_test_base
from tensorflow.python.eager import context
from tensorflow.python.eager import remote
from tensorflow.python.eager import test
class MultiProcessClusterTest(test.TestCase):
  """Tests for multi-process clusters (kill/restart/stop of TF servers)."""

  def setUp(self):
    super(MultiProcessClusterTest, self).setUp()
    # One chief, two workers and one ps, each in its own process.
    self._cluster = multi_worker_test_base.create_multi_process_cluster(
        num_workers=2, num_ps=1, has_chief=True, rpc_layer="grpc")
    remote.connect_to_cluster(
        self._cluster.cluster_resolver.cluster_spec(), protocol="grpc")
    context.ensure_initialized()

  def testClusterIsAlive(self):
    # Every task of the freshly started cluster must be reachable.
    for device in ("/job:worker/replica:0/task:0",
                   "/job:worker/replica:0/task:1",
                   "/job:ps/replica:0/task:0",
                   "/job:chief/replica:0/task:0"):
      self.assertTrue(context.check_alive(device))

  def testKillAndStartTask(self):
    worker0 = "/job:worker/replica:0/task:0"
    self.assertTrue(context.check_alive(worker0))

    # Starting a task that is still alive is rejected.
    with self.assertRaises(ValueError):
      self._cluster.start_task("worker", 0)

    self._cluster.kill_task("worker", 0)
    self.assertFalse(context.check_alive(worker0))

    # Killing a task that is already dead is rejected as well.
    with self.assertRaises(ValueError):
      self._cluster.kill_task("worker", 0)

    self._cluster.start_task("worker", 0)

    # Without a call to update_server_def, the next check_alive will return
    # False. Alternatively sleeping for 2 seconds here also works.
    context.context().update_server_def(context.get_server_def())
    self.assertTrue(context.check_alive(worker0))

  def testStop(self):
    self._cluster.stop()
    # After stop(), no task should be reachable anymore.
    for device in ("/job:worker/replica:0/task:0",
                   "/job:worker/replica:0/task:1",
                   "/job:ps/replica:0/task:0",
                   "/job:chief/replica:0/task:0"):
      self.assertFalse(context.check_alive(device))

  def testClusterResolverProperty(self):
    spec = self._cluster.cluster_resolver.cluster_spec().as_dict()

    self.assertEqual(len(spec["worker"]), 2)
    self.assertEqual(len(spec["ps"]), 1)
    self.assertEqual(len(spec["chief"]), 1)


if __name__ == "__main__":
  multi_process_runner.test_main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册