提交 4d13d641 编写于 作者: R Rick Chao 提交者: TensorFlower Gardener

Make cluster_resolver standard property in tf.distribute strategies.

PiperOrigin-RevId: 317771299
Change-Id: I71b5c585cef7bd7ef80e66b75e30287fddcf89e2
上级 e74a115b
......@@ -204,6 +204,7 @@ py_test(
"//tensorflow/python:variables",
"//tensorflow/python/autograph/core:test_lib",
"//tensorflow/python/data/ops:dataset_ops",
"//tensorflow/python/distribute/cluster_resolver:cluster_resolver_lib",
"//third_party/py/numpy",
],
)
......@@ -1847,10 +1848,11 @@ py_test(
],
)
cuda_py_test(
distribute_py_test(
name = "strategy_common_test",
srcs = ["strategy_common_test.py"],
python_version = "PY3",
shard_count = 12,
tags = [
"multi_and_single_gpu",
# TODO(b/155301154): Enable this test on multi-gpu guitar once multi process
......@@ -1859,6 +1861,7 @@ cuda_py_test(
],
xla_enable_strict_auto_jit = True,
deps = [
":collective_all_reduce_strategy",
":combinations",
":multi_worker_test_base",
":reduce_util",
......
......@@ -138,6 +138,18 @@ class CollectiveAllReduceStrategy(distribute_lib.Strategy):
"""
return super(CollectiveAllReduceStrategy, self).scope()
@property
def cluster_resolver(self):
"""Returns the cluster resolver associated with this strategy.
As a multi-worker strategy,
`tf.distribute.experimental.MultiWorkerMirroredStrategy` provides the
associated `tf.distribute.cluster_resolver.ClusterResolver`. If the user
provides one in `__init__`, that instance is returned; if the user does
not, a default `TFConfigClusterResolver` is provided.
"""
return self.extended._cluster_resolver # pylint: disable=protected-access
@tf_export(v1=["distribute.experimental.MultiWorkerMirroredStrategy"]) # pylint: disable=missing-docstring
class CollectiveAllReduceStrategyV1(distribute_lib.StrategyV1):
......
......@@ -505,8 +505,7 @@ class DistributedCollectiveAllReduceStrategyTest(
self.assertEqual(['CollectiveReduce'],
new_rewrite_options.scoped_allocator_opts.enable_op)
@combinations.generate(combinations.combine(mode=['eager']))
def testEnableCollectiveOps(self):
def _get_strategy_with_mocked_methods(self):
mock_called = [False]
# pylint: disable=dangerous-default-value
......@@ -525,9 +524,21 @@ class DistributedCollectiveAllReduceStrategyTest(
mock_configure_collective_ops):
strategy, _, _ = self._get_test_object(
task_type='worker', task_id=1, num_gpus=2)
return strategy, mock_called
@combinations.generate(combinations.combine(mode=['eager']))
def testEnableCollectiveOps(self):
strategy, mock_called = self._get_strategy_with_mocked_methods()
self.assertTrue(strategy.extended._std_server_started)
self.assertTrue(mock_called[0])
@combinations.generate(combinations.combine(mode=['eager']))
def testEnableCollectiveOpsAndClusterResolver(self):
strategy, _ = self._get_strategy_with_mocked_methods()
self.assertEqual(strategy.cluster_resolver.task_type, 'worker')
self.assertEqual(strategy.cluster_resolver.task_id, 1)
class DistributedCollectiveAllReduceStrategyTestWithChief(
CollectiveAllReduceStrategyTestBase, parameterized.TestCase):
......
......@@ -1439,6 +1439,65 @@ class StrategyBase(object):
def __copy__(self):
raise RuntimeError("Must only deepcopy DistributionStrategy.")
@property
def cluster_resolver(self):
"""Returns the cluster resolver associated with this strategy.
In general, when using a multi-worker `tf.distribute` strategy such as
`tf.distribute.experimental.MultiWorkerMirroredStrategy` or
`tf.distribute.experimental.TPUStrategy()`, there is a
`tf.distribute.cluster_resolver.ClusterResolver` associated with the
strategy used, and such an instance is returned by this property.
Strategies that intend to have an associated
`tf.distribute.cluster_resolver.ClusterResolver` must set the
relevant attribute, or override this property; otherwise, `None` is returned
by default. Those strategies should also provide information regarding what
is returned by this property.
Single-worker strategies usually do not have a
`tf.distribute.cluster_resolver.ClusterResolver`, and in those cases this
property will return `None`.
The `tf.distribute.cluster_resolver.ClusterResolver` may be useful when the
user needs to access information such as the cluster spec, task type or task
id. For example,
```python
os.environ['TF_CONFIG'] = json.dumps({
'cluster': {
'worker': ["localhost:12345", "localhost:23456"],
'ps': ["localhost:34567"]
},
'task': {'type': 'worker', 'index': 0}
})
# This implicitly uses TF_CONFIG for the cluster and current task info.
strategy = tf.distribute.experimental.MultiWorkerMirroredStrategy()
...
if strategy.cluster_resolver.task_type == 'worker':
# Perform something that's only applicable on workers. Since we set this
# as a worker above, this block will run on this particular instance.
elif strategy.cluster_resolver.task_type == 'ps':
# Perform something that's only applicable on parameter servers. Since we
# set this as a worker above, this block will not run on this particular
# instance.
```
For more information, please see
`tf.distribute.cluster_resolver.ClusterResolver`'s API docstring.
Returns:
The cluster resolver associated with this strategy. Returns `None` if a
cluster resolver is not applicable or available in this strategy.
"""
if hasattr(self.extended, "_cluster_resolver"):
return self.extended._cluster_resolver # pylint: disable=protected-access
return None
@tf_export("distribute.Strategy", v1=[]) # pylint: disable=g-missing-docstring
class Strategy(StrategyBase):
......
......@@ -28,6 +28,7 @@ from tensorflow.python.distribute import distribute_lib
from tensorflow.python.distribute import distribution_strategy_context as ds_context
from tensorflow.python.distribute import input_lib
from tensorflow.python.distribute import reduce_util
from tensorflow.python.distribute.cluster_resolver import SimpleClusterResolver
from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
from tensorflow.python.framework import constant_op
......@@ -36,6 +37,7 @@ from tensorflow.python.framework import ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
from tensorflow.python.platform import test
from tensorflow.python.training import server_lib
from tensorflow.python.util import nest
......@@ -422,6 +424,17 @@ class TestStrategyTest(test.TestCase):
test_fn()
def testClusterResolverDefaultNotImplemented(self):
dist = _TestStrategy()
self.assertIsNone(dist.cluster_resolver)
base_cluster_spec = server_lib.ClusterSpec({
"ps": ["ps0:2222", "ps1:2222"],
"worker": ["worker0:2222", "worker1:2222", "worker2:2222"]
})
cluster_resolver = SimpleClusterResolver(base_cluster_spec)
dist.extended._cluster_resolver = cluster_resolver
self.assertIs(dist.cluster_resolver, cluster_resolver)
# _TestStrategy2 is like _TestStrategy, except it doesn't change variable
# creation.
......
......@@ -27,6 +27,8 @@ from tensorflow.python.distribute import multi_worker_test_base
from tensorflow.python.distribute import reduce_util
from tensorflow.python.distribute import strategy_combinations
from tensorflow.python.distribute import strategy_test_lib
from tensorflow.python.distribute.collective_all_reduce_strategy import CollectiveAllReduceStrategy
from tensorflow.python.distribute.tpu_strategy import TPUStrategy
from tensorflow.python.eager import def_function
from tensorflow.python.framework import dtypes
from tensorflow.python.ops import array_ops
......@@ -184,5 +186,38 @@ class DistributedCollectiveAllReduceStrategyTest(
# worker strategy combinations can run on a fixed number of GPUs.
class StrategyClusterResolverTest(test.TestCase, parameterized.TestCase):
@combinations.generate(
combinations.combine(
strategy=[strategy_combinations.multi_worker_mirrored_two_workers] +
strategy_combinations.all_strategies,
mode=['eager']))
def testClusterResolverProperty(self, strategy):
# CollectiveAllReduceStrategy and TPUStrategy must have a cluster resolver.
# `None` otherwise.
resolver = strategy.cluster_resolver
if not isinstance(strategy, CollectiveAllReduceStrategy) and not isinstance(
strategy, TPUStrategy):
self.assertIsNone(resolver)
return
with strategy.scope():
self.assertIs(strategy.cluster_resolver, resolver)
self.assertTrue(hasattr(resolver, 'cluster_spec'))
self.assertTrue(hasattr(resolver, 'environment'))
self.assertTrue(hasattr(resolver, 'master'))
self.assertTrue(hasattr(resolver, 'num_accelerators'))
self.assertIsNone(resolver.rpc_layer)
if isinstance(strategy, CollectiveAllReduceStrategy):
self.assertGreaterEqual(resolver.task_id, 0)
self.assertLessEqual(resolver.task_id, 1)
self.assertEqual(resolver.task_type, 'worker')
elif isinstance(strategy, TPUStrategy):
# TPUStrategy does not have task_id and task_type applicable.
self.assertIsNone(resolver.task_id)
self.assertIsNone(resolver.task_type)
if __name__ == '__main__':
combinations.main()
......@@ -345,6 +345,18 @@ class TPUStrategy(distribute_lib.Strategy):
options = options or distribute_lib.RunOptions()
return self.extended.tpu_run(fn, args, kwargs, options)
@property
def cluster_resolver(self):
"""Returns the cluster resolver associated with this strategy.
`tf.distribute.experimental.TPUStrategy` provides the
associated `tf.distribute.cluster_resolver.ClusterResolver`. If the user
provides one in `__init__`, that instance is returned; if the user does
not, a default
`tf.distribute.cluster_resolver.TPUClusterResolver` is provided.
"""
return self.extended._tpu_cluster_resolver # pylint: disable=protected-access
@tf_export(v1=["distribute.experimental.TPUStrategy"])
class TPUStrategyV1(distribute_lib.StrategyV1):
......
......@@ -555,6 +555,13 @@ class TPUStrategyTest(test.TestCase, parameterized.TestCase):
update_variable.get_concrete_function()
self.assertLen(strategy.extended.worker_devices, trace_count[0])
def test_cluster_resolver_available(self, enable_packed_var):
resolver = get_tpu_cluster_resolver()
remote.connect_to_cluster(resolver)
tpu_strategy_util.initialize_tpu_system(resolver)
strategy = tpu_lib.TPUStrategy(resolver)
self.assertIs(strategy.cluster_resolver, resolver)
class TPUStrategyDataPrefetchTest(test.TestCase):
......
......@@ -4,6 +4,10 @@ tf_class {
is_instance: "<class \'tensorflow.python.distribute.distribute_lib.StrategyV1\'>"
is_instance: "<class \'tensorflow.python.distribute.distribute_lib.StrategyBase\'>"
is_instance: "<type \'object\'>"
member {
name: "cluster_resolver"
mtype: "<type \'property\'>"
}
member {
name: "extended"
mtype: "<type \'property\'>"
......
......@@ -4,6 +4,10 @@ tf_class {
is_instance: "<class \'tensorflow.python.distribute.distribute_lib.StrategyV1\'>"
is_instance: "<class \'tensorflow.python.distribute.distribute_lib.StrategyBase\'>"
is_instance: "<type \'object\'>"
member {
name: "cluster_resolver"
mtype: "<type \'property\'>"
}
member {
name: "extended"
mtype: "<type \'property\'>"
......
......@@ -3,6 +3,10 @@ tf_class {
is_instance: "<class \'tensorflow.python.distribute.distribute_lib.StrategyV1\'>"
is_instance: "<class \'tensorflow.python.distribute.distribute_lib.StrategyBase\'>"
is_instance: "<type \'object\'>"
member {
name: "cluster_resolver"
mtype: "<type \'property\'>"
}
member {
name: "extended"
mtype: "<type \'property\'>"
......
......@@ -4,6 +4,10 @@ tf_class {
is_instance: "<class \'tensorflow.python.distribute.distribute_lib.StrategyV1\'>"
is_instance: "<class \'tensorflow.python.distribute.distribute_lib.StrategyBase\'>"
is_instance: "<type \'object\'>"
member {
name: "cluster_resolver"
mtype: "<type \'property\'>"
}
member {
name: "extended"
mtype: "<type \'property\'>"
......
......@@ -4,6 +4,10 @@ tf_class {
is_instance: "<class \'tensorflow.python.distribute.distribute_lib.StrategyV1\'>"
is_instance: "<class \'tensorflow.python.distribute.distribute_lib.StrategyBase\'>"
is_instance: "<type \'object\'>"
member {
name: "cluster_resolver"
mtype: "<type \'property\'>"
}
member {
name: "extended"
mtype: "<type \'property\'>"
......
......@@ -4,6 +4,10 @@ tf_class {
is_instance: "<class \'tensorflow.python.distribute.distribute_lib.StrategyV1\'>"
is_instance: "<class \'tensorflow.python.distribute.distribute_lib.StrategyBase\'>"
is_instance: "<type \'object\'>"
member {
name: "cluster_resolver"
mtype: "<type \'property\'>"
}
member {
name: "extended"
mtype: "<type \'property\'>"
......
......@@ -4,6 +4,10 @@ tf_class {
is_instance: "<class \'tensorflow.python.distribute.distribute_lib.StrategyV1\'>"
is_instance: "<class \'tensorflow.python.distribute.distribute_lib.StrategyBase\'>"
is_instance: "<type \'object\'>"
member {
name: "cluster_resolver"
mtype: "<type \'property\'>"
}
member {
name: "extended"
mtype: "<type \'property\'>"
......
......@@ -4,6 +4,10 @@ tf_class {
is_instance: "<class \'tensorflow.python.distribute.distribute_lib.Strategy\'>"
is_instance: "<class \'tensorflow.python.distribute.distribute_lib.StrategyBase\'>"
is_instance: "<type \'object\'>"
member {
name: "cluster_resolver"
mtype: "<type \'property\'>"
}
member {
name: "extended"
mtype: "<type \'property\'>"
......
......@@ -4,6 +4,10 @@ tf_class {
is_instance: "<class \'tensorflow.python.distribute.distribute_lib.Strategy\'>"
is_instance: "<class \'tensorflow.python.distribute.distribute_lib.StrategyBase\'>"
is_instance: "<type \'object\'>"
member {
name: "cluster_resolver"
mtype: "<type \'property\'>"
}
member {
name: "extended"
mtype: "<type \'property\'>"
......
......@@ -3,6 +3,10 @@ tf_class {
is_instance: "<class \'tensorflow.python.distribute.distribute_lib.Strategy\'>"
is_instance: "<class \'tensorflow.python.distribute.distribute_lib.StrategyBase\'>"
is_instance: "<type \'object\'>"
member {
name: "cluster_resolver"
mtype: "<type \'property\'>"
}
member {
name: "extended"
mtype: "<type \'property\'>"
......
......@@ -4,6 +4,10 @@ tf_class {
is_instance: "<class \'tensorflow.python.distribute.distribute_lib.Strategy\'>"
is_instance: "<class \'tensorflow.python.distribute.distribute_lib.StrategyBase\'>"
is_instance: "<type \'object\'>"
member {
name: "cluster_resolver"
mtype: "<type \'property\'>"
}
member {
name: "extended"
mtype: "<type \'property\'>"
......
......@@ -4,6 +4,10 @@ tf_class {
is_instance: "<class \'tensorflow.python.distribute.distribute_lib.Strategy\'>"
is_instance: "<class \'tensorflow.python.distribute.distribute_lib.StrategyBase\'>"
is_instance: "<type \'object\'>"
member {
name: "cluster_resolver"
mtype: "<type \'property\'>"
}
member {
name: "extended"
mtype: "<type \'property\'>"
......
......@@ -4,6 +4,10 @@ tf_class {
is_instance: "<class \'tensorflow.python.distribute.distribute_lib.Strategy\'>"
is_instance: "<class \'tensorflow.python.distribute.distribute_lib.StrategyBase\'>"
is_instance: "<type \'object\'>"
member {
name: "cluster_resolver"
mtype: "<type \'property\'>"
}
member {
name: "extended"
mtype: "<type \'property\'>"
......
......@@ -4,6 +4,10 @@ tf_class {
is_instance: "<class \'tensorflow.python.distribute.distribute_lib.Strategy\'>"
is_instance: "<class \'tensorflow.python.distribute.distribute_lib.StrategyBase\'>"
is_instance: "<type \'object\'>"
member {
name: "cluster_resolver"
mtype: "<type \'property\'>"
}
member {
name: "extended"
mtype: "<type \'property\'>"
......
......@@ -4,6 +4,10 @@ tf_class {
is_instance: "<class \'tensorflow.python.distribute.distribute_lib.Strategy\'>"
is_instance: "<class \'tensorflow.python.distribute.distribute_lib.StrategyBase\'>"
is_instance: "<type \'object\'>"
member {
name: "cluster_resolver"
mtype: "<type \'property\'>"
}
member {
name: "extended"
mtype: "<type \'property\'>"
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册