Commit ea02fb88 authored by Frank Chen, committed by TensorFlower Gardener

Unify num_accelerators for all Cluster Resolvers

PiperOrigin-RevId: 224843723
Parent 8eb8217c
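This change gives the abstract ClusterResolver a working default num_accelerators(): it resolves the master for the requested task, opens a session against it, and counts the devices whose device_type matches accelerator_type. A minimal sketch of what a concrete resolver now has to provide (the class name is hypothetical, and the empty master string simply falls back to the local in-process session):

from tensorflow.python.distribute.cluster_resolver import ClusterResolver

class LocalClusterResolver(ClusterResolver):
  """Illustrative subclass: only the remaining abstract members are defined."""

  def cluster_spec(self):
    return None

  def master(self, task_type=None, task_index=None, rpc_layer=None):
    return ''  # empty master -> query the in-process session

  @property
  def environment(self):
    return ''

resolver = LocalClusterResolver()
# num_accelerators() is inherited from ClusterResolver: it lists the devices
# visible to master() and counts those matching accelerator_type.
print(resolver.num_accelerators(accelerator_type='GPU'))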
@@ -22,6 +22,8 @@ import abc
import six
from tensorflow.python.client import session
from tensorflow.python.framework import ops
from tensorflow.python.training.server_lib import ClusterSpec
@@ -32,6 +34,14 @@ def format_master_url(master, rpc_layer=None):
return master
def get_accelerator_devices(master, config_proto):
# TODO(frankchn): Add support for eager mode as well as graph mode.
with ops.Graph().as_default():
with session.Session(master, config=config_proto) as s:
devices = s.list_devices()
return devices
@six.add_metaclass(abc.ABCMeta)
class ClusterResolver(object):
"""Abstract class for all implementations of ClusterResolvers.
@@ -91,7 +101,6 @@ class ClusterResolver(object):
"""
raise NotImplementedError()
@abc.abstractmethod
def num_accelerators(self,
task_type=None,
task_index=None,
@@ -119,7 +128,9 @@ class ClusterResolver(object):
config_proto: (Optional) Configuration for starting a new session to
query how many accelerator cores it has.
"""
raise NotImplementedError()
master = self.master(task_type, task_index)
devices = get_accelerator_devices(master, config_proto)
return sum(1 for d in devices if d.device_type == accelerator_type)
@abc.abstractproperty
def environment(self):
......
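The new module-level helper can also be used on its own; a small sketch, with an assumed master address (the device-type filter mirrors what the base num_accelerators() does internally, and eager-mode support is still a TODO at this commit):

from tensorflow.python.distribute.cluster_resolver.cluster_resolver import get_accelerator_devices

# Opens a temporary graph-mode session against the master and returns the
# devices it reports; the address below is illustrative.
devices = get_accelerator_devices('grpc://localhost:2222', config_proto=None)
num_gpus = sum(1 for d in devices if d.device_type == 'GPU')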
@@ -18,11 +18,64 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.python.client import session
from tensorflow.python.distribute.cluster_resolver import ClusterResolver
from tensorflow.python.distribute.cluster_resolver import SimpleClusterResolver
from tensorflow.python.distribute.cluster_resolver import UnionClusterResolver
from tensorflow.python.platform import test
from tensorflow.python.training import server_lib
mock = test.mock
class MockBaseClusterResolver(ClusterResolver):
def cluster_spec(self):
return None
def master(self, task_type=None, task_index=None, rpc_layer=None):
return ""
def environment(self):
return ""
class BaseClusterResolverTest(test.TestCase):
@mock.patch.object(session.BaseSession, "list_devices")
def testNumAcceleratorsSuccess(self, mock_list_devices):
device_names = [
"/job:worker/task:0/device:GPU:0",
"/job:worker/task:0/device:GPU:1",
"/job:worker/task:0/device:GPU:2",
"/job:worker/task:0/device:GPU:3",
]
device_list = [
session._DeviceAttributes(
name, "GPU", 1024, 0) for name in device_names
]
mock_list_devices.return_value = device_list
resolver = MockBaseClusterResolver()
self.assertEqual(resolver.num_accelerators(), 4)
@mock.patch.object(session.BaseSession, "list_devices")
def testNumAcceleratorsFilterSuccess(self, mock_list_devices):
device_names = [
"/job:worker/task:0/device:TPU:0",
"/job:worker/task:0/device:TPU:1",
"/job:worker/task:0/device:TPU:2",
"/job:worker/task:0/device:TPU:3",
]
device_list = [
session._DeviceAttributes(
name, "TPU", 1024, 0) for name in device_names
]
mock_list_devices.return_value = device_list
resolver = MockBaseClusterResolver()
self.assertEqual(resolver.num_accelerators(), 0)
class UnionClusterResolverTest(test.TestCase):
# TODO(frankchn): Transform to parameterized test after it is included in the
......
@@ -51,7 +51,6 @@ class GceClusterResolver(ClusterResolver):
task_type='worker',
task_index=0,
rpc_layer='grpc',
num_accelerators=0,
credentials='default',
service=None):
"""Creates a new GceClusterResolver object.
@@ -73,8 +72,6 @@ class GceClusterResolver(ClusterResolver):
can be distinguished from each other.
rpc_layer: The RPC layer TensorFlow should use to communicate across
instances.
num_accelerators: Number of accelerators (GPUs) present per
instance.
credentials: GCE Credentials. If nothing is specified, this defaults to
GoogleCredentials.get_application_default().
service: The GCE API object returned by the googleapiclient.discovery
@@ -90,7 +87,6 @@ class GceClusterResolver(ClusterResolver):
self._task_type = task_type
self._task_index = task_index
self._rpc_layer = rpc_layer
self._num_accelerators = num_accelerators
self._port = port
self._credentials = credentials
@@ -201,12 +197,3 @@ class GceClusterResolver(ClusterResolver):
@rpc_layer.setter
def rpc_layer(self, rpc_layer):
self._rpc_layer = rpc_layer
def num_accelerators(self,
task_type=None,
task_index=None,
accelerator_type='GPU',
config_proto=None):
# Unused
del task_type, task_index, accelerator_type, config_proto
return self._num_accelerators
@@ -18,7 +18,6 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.python.client import device_lib
from tensorflow.python.distribute.cluster_resolver.cluster_resolver import ClusterResolver
from tensorflow.python.distribute.cluster_resolver.cluster_resolver import format_master_url
from tensorflow.python.training import server_lib
@@ -167,16 +166,3 @@ class KubernetesClusterResolver(ClusterResolver):
on internal systems.
"""
return ''
def num_accelerators(self,
task_type=None,
task_index=None,
accelerator_type='GPU',
config_proto=None):
# TODO(frankchn): Make querying non-local accelerators work
if task_type is not None or task_index is not None:
raise NotImplementedError('Querying non-local accelerators is not yet'
'implemented.')
local_devices = device_lib.list_local_devices(config_proto)
return sum(d.device_type == accelerator_type for d in local_devices)
@@ -54,8 +54,7 @@ class TFConfigClusterResolver(ClusterResolver):
task_type=None,
task_index=None,
rpc_layer=None,
environment=None,
num_accelerators=0):
environment=None):
"""Creates a new TFConfigClusterResolver.
Args:
@@ -66,17 +65,11 @@ class TFConfigClusterResolver(ClusterResolver):
rpc_layer: (String, optional) Overrides the rpc layer TensorFlow uses.
environment: (String, optional) Overrides the environment TensorFlow
operates in.
num_accelerators: (Integer, optional) Specifies the number of
accelerators (e.g. GPUs, TPUs, others) that each node has.
"""
# TODO(frankchn): num_accelerators is a stop-gap and will be removed
# in favor of autodetection of devices soon.
self._task_type = task_type
self._task_index = task_index
self._rpc_layer = rpc_layer
self._environment = environment
self._num_accelerators = num_accelerators
@property
def task_type(self):
@@ -117,16 +110,6 @@ class TFConfigClusterResolver(ClusterResolver):
def rpc_layer(self, rpc_layer):
self._rpc_layer = rpc_layer
def num_accelerators(self,
task_type=None,
task_index=None,
accelerator_type='GPU',
config_proto=None):
# TODO(frankchn): Connect to server (w/ session_config) in the future.
# Unused, we do not connect to another server here right now.
del task_type, task_index, accelerator_type, config_proto
return self._num_accelerators
def cluster_spec(self):
"""Returns a ClusterSpec based on the TF_CONFIG environment variable.
......
@@ -168,13 +168,11 @@ class TFConfigClusterResolverTest(test.TestCase):
}
"""
cluster_resolver = TFConfigClusterResolver(task_type='ps', task_index=0,
num_accelerators=8)
cluster_resolver = TFConfigClusterResolver(task_type='ps', task_index=0)
self.assertEqual('grpc://ps0:2222', cluster_resolver.master())
self.assertEqual('ps', cluster_resolver.task_type)
self.assertEqual(0, cluster_resolver.task_index)
self.assertEqual(8, cluster_resolver.num_accelerators())
cluster_resolver.task_type = 'worker'
cluster_resolver.task_index = 1
......
@@ -25,11 +25,10 @@ import re
from six.moves.urllib.request import Request
from six.moves.urllib.request import urlopen
from tensorflow.python.client import session
from tensorflow.python.distribute.cluster_resolver.cluster_resolver import ClusterResolver
from tensorflow.python.distribute.cluster_resolver.cluster_resolver import format_master_url
from tensorflow.python.distribute.cluster_resolver.cluster_resolver import get_accelerator_devices
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import server_lib
from tensorflow.python.util import compat
@@ -451,17 +450,16 @@ class TPUClusterResolver(ClusterResolver):
retrieve the system metadata.
Raises:
RuntimeError: If this is used with a non-TPU accelerator_type.
RuntimeError: If we cannot talk to a TPU worker after retrying or if the
number of TPU devices per host is different.
"""
retry_count = 1
# TODO(b/120564445): Replace with standard library for retries.
while True:
try:
with ops.Graph().as_default():
with session.Session(self.master(), config=config_proto) as s:
devices = s.list_devices()
device_details = _get_device_dict_and_cores(devices)
break
device_details = _get_device_dict_and_cores(
get_accelerator_devices(self.master(), config_proto=config_proto))
break
except errors.DeadlineExceededError:
error_message = ('Failed to connect to master. The TPU might not be '
'ready (e.g. still scheduling) or the master '
......