提交 28373cfe 编写于 作者: F Frank Chen 提交者: TensorFlower Gardener

Adds preliminary support for Cloud TPUs with Cluster Resolvers. This aims to...

Adds preliminary support for Cloud TPUs with Cluster Resolvers. This aims to allow users to have a better experienec when specifying one or multiple Cloud TPUs for their training jobs by allowing users to use names rather than IP addresses.

PiperOrigin-RevId: 163393443
上级 e5353c94
......@@ -28,6 +28,7 @@ py_library(
deps = [
":cluster_resolver_py",
":gce_cluster_resolver_py",
":tpu_cluster_resolver_py",
],
)
......@@ -54,6 +55,18 @@ py_library(
],
)
py_library(
name = "tpu_cluster_resolver_py",
srcs = [
"python/training/tpu_cluster_resolver.py",
],
srcs_version = "PY2AND3",
deps = [
":cluster_resolver_py",
"//tensorflow/python:training",
],
)
tf_py_test(
name = "cluster_resolver_py_test",
size = "small",
......@@ -81,3 +94,17 @@ tf_py_test(
],
main = "python/training/gce_cluster_resolver_test.py",
)
tf_py_test(
name = "tpu_cluster_resolver_py_test",
size = "small",
srcs = ["python/training/tpu_cluster_resolver_test.py"],
additional_deps = [
":tpu_cluster_resolver_py",
"//tensorflow/python:client_testlib",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:framework_test_lib",
"//tensorflow/python:platform_test",
],
main = "python/training/tpu_cluster_resolver_test.py",
)
......@@ -22,3 +22,4 @@ from tensorflow.contrib.cluster_resolver.python.training.cluster_resolver import
from tensorflow.contrib.cluster_resolver.python.training.cluster_resolver import SimpleClusterResolver
from tensorflow.contrib.cluster_resolver.python.training.cluster_resolver import UnionClusterResolver
from tensorflow.contrib.cluster_resolver.python.training.gce_cluster_resolver import GceClusterResolver
from tensorflow.contrib.cluster_resolver.python.training.tpu_cluster_resolver import TPUClusterResolver
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Implementation of Cluster Resolvers for Cloud TPUs."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.contrib.cluster_resolver.python.training.cluster_resolver import ClusterResolver
from tensorflow.python.training.server_lib import ClusterSpec
_GOOGLE_API_CLIENT_INSTALLED = True
try:
from googleapiclient import discovery # pylint: disable=g-import-not-at-top
except ImportError:
_GOOGLE_API_CLIENT_INSTALLED = False
class TPUClusterResolver(ClusterResolver):
"""Cluster Resolver for Google Cloud TPUs.
This is an implementation of cluster resolvers for the Google Cloud TPU
service. As Cloud TPUs are in alpha, you will need to specify a API definition
file for this to consume, in addition to a list of Cloud TPUs in your Google
Cloud Platform project.
"""
def __init__(self,
api_definition,
project,
zone,
tpu_names,
credentials,
job_name='tpu_worker',
service=None):
"""Creates a new TPUClusterResolver object.
The ClusterResolver will then use the parameters to query the Cloud TPU APIs
for the IP addresses and ports of each Cloud TPU listed.
Args:
api_definition: (Alpha only) A copy of the JSON API definitions for
Cloud TPUs. This will be removed once Cloud TPU enters beta.
project: Name of the GCP project containing Cloud TPUs
zone: Zone where the TPUs are located
tpu_names: A list of names of the target Cloud TPUs.
credentials: GCE Credentials.
job_name: Name of the TensorFlow job the TPUs belong to.
service: The GCE API object returned by the googleapiclient.discovery
function. If you specify a custom service object, then the credentials
parameter will be ignored.
Raises:
ImportError: If the googleapiclient is not installed.
"""
self._project = project
self._zone = zone
self._tpu_names = tpu_names
self._job_name = job_name
if service is None:
if not _GOOGLE_API_CLIENT_INSTALLED:
raise ImportError('googleapiclient must be installed before using the '
'TPU cluster resolver')
# TODO(frankchn): Remove once Cloud TPU API Definitions are public and
# replace with discovery.build('tpu', 'v1')
self._service = discovery.build_from_document(api_definition,
credentials=credentials)
else:
self._service = service
def cluster_spec(self):
"""Returns a ClusterSpec object based on the latest TPU information.
We retrieve the information from the GCE APIs every time this method is
called.
Returns:
A ClusterSpec containing host information returned from Cloud TPUs.
"""
worker_list = []
for tpu_name in self._tpu_names:
full_name = 'projects/%s/locations/%s/nodes/%s' % (
self._project, self._zone, tpu_name)
request = self._service.projects().locations().nodes().get(name=full_name)
response = request.execute()
instance_url = '%s:%s' % (response.ipAddress, response.port)
worker_list.append(instance_url)
return ClusterSpec({self._job_name: worker_list})
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for TPUClusterResolver."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.contrib.cluster_resolver.python.training.tpu_cluster_resolver import TPUClusterResolver
from tensorflow.python.platform import test
from tensorflow.python.training import server_lib
mock = test.mock
class TPUClusterResolverTest(test.TestCase):
def _verifyClusterSpecEquality(self, cluster_spec, expected_proto):
"""Verifies that the ClusterSpec generates the correct proto.
We are testing this four different ways to ensure that the ClusterSpec
returned by the TPUClusterResolver behaves identically to a normal
ClusterSpec when passed into the generic ClusterSpec libraries.
Args:
cluster_spec: ClusterSpec returned by the TPUClusterResolver
expected_proto: Expected protobuf
"""
self.assertProtoEquals(expected_proto, cluster_spec.as_cluster_def())
self.assertProtoEquals(
expected_proto, server_lib.ClusterSpec(cluster_spec).as_cluster_def())
self.assertProtoEquals(
expected_proto,
server_lib.ClusterSpec(cluster_spec.as_cluster_def()).as_cluster_def())
self.assertProtoEquals(
expected_proto,
server_lib.ClusterSpec(cluster_spec.as_dict()).as_cluster_def())
def mock_service_client(
self,
tpu_map=None):
if tpu_map is None:
tpu_map = {}
def get_side_effect(name):
return tpu_map[name]
mock_client = mock.MagicMock()
mock_client.projects.locations.nodes.get.side_effect = get_side_effect
return mock_client
def testSimpleSuccessfulRetrieval(self):
tpu_map = {
'projects/test-project/locations/us-central1-c/nodes/test-tpu-1': {
'ipAddress': '10.1.2.3',
'port': '8470'
}
}
tpu_cluster_resolver = TPUClusterResolver(
project='test-project',
zone='us-central1-c',
tpu_names=['test-tpu-1'],
credentials=None,
service=self.mock_service_client(tpu_map=tpu_map))
actual_cluster_spec = tpu_cluster_resolver.cluster_spec()
expected_proto = """
job { name: 'tpu_worker' tasks { key: 0 value: '10.1.2.3:8470' } }
"""
self._verifyClusterSpecEquality(actual_cluster_spec, expected_proto)
def testMultipleSuccessfulRetrieval(self):
tpu_map = {
'projects/test-project/locations/us-central1-c/nodes/test-tpu-1': {
'ipAddress': '10.1.2.3',
'port': '8470'
},
'projects/test-project/locations/us-central1-c/nodes/test-tpu-2': {
'ipAddress': '10.4.5.6',
'port': '8470'
}
}
tpu_cluster_resolver = TPUClusterResolver(
project='test-project',
zone='us-central1-c',
tpu_names=['test-tpu-2', 'test-tpu-1'],
credentials=None,
service=self.mock_service_client(tpu_map=tpu_map))
actual_cluster_spec = tpu_cluster_resolver.cluster_spec()
expected_proto = """
job { name: 'tpu_worker' tasks { key: 0 value: '10.4.5.6:8470' }
tasks { key: 1 value: '10.1.2.3:8470' } }
"""
self._verifyClusterSpecEquality(actual_cluster_spec, expected_proto)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册