From 8644b6d4c77646407758a2ef93eb3567f9f03577 Mon Sep 17 00:00:00 2001
From: Sourabh Bajaj
Date: Wed, 12 Dec 2018 14:50:13 -0800
Subject: [PATCH] Move reduce non distributed values and share the code with
 TPU Strategy and also improve print output of TPUMirroredVariable.

PiperOrigin-RevId: 225259008
---
 .../distribute/python/mirrored_strategy.py     |  1 -
 .../python/parameter_server_strategy.py        |  2 +-
 .../contrib/distribute/python/tpu_strategy.py  |  8 ++++
 .../python/distribute/cross_device_ops.py      | 37 ++++++++++++++
 .../python/distribute/mirrored_strategy.py     | 48 ++-----------------
 tensorflow/python/distribute/values.py         |  6 +++
 6 files changed, 57 insertions(+), 45 deletions(-)

diff --git a/tensorflow/contrib/distribute/python/mirrored_strategy.py b/tensorflow/contrib/distribute/python/mirrored_strategy.py
index 20f1a08d426..24399db6522 100644
--- a/tensorflow/contrib/distribute/python/mirrored_strategy.py
+++ b/tensorflow/contrib/distribute/python/mirrored_strategy.py
@@ -28,7 +28,6 @@ from tensorflow.python.distribute import values
 
 # pylint: disable=protected-access,invalid-name
 _call_for_each_replica = mirrored_strategy._call_for_each_replica
-_reduce_non_distributed_value = mirrored_strategy._reduce_non_distributed_value
 _create_mirrored_variable = mirrored_strategy._create_mirrored_variable
 all_local_devices = mirrored_strategy.all_local_devices
 CoreMirroredStrategy = mirrored_strategy.MirroredStrategy
diff --git a/tensorflow/contrib/distribute/python/parameter_server_strategy.py b/tensorflow/contrib/distribute/python/parameter_server_strategy.py
index 2c7766f95fb..ca51b07be66 100644
--- a/tensorflow/contrib/distribute/python/parameter_server_strategy.py
+++ b/tensorflow/contrib/distribute/python/parameter_server_strategy.py
@@ -356,7 +356,7 @@ class ParameterServerExtended(distribute_lib.DistributionStrategyExtended):
     self._verify_destinations_not_different_worker(destinations)
     if not isinstance(value, values.DistributedValues):
       # pylint: disable=protected-access
-      return mirrored_strategy._reduce_non_distributed_value(
+      return cross_device_ops_lib.reduce_non_distributed_value(
           self, reduce_op, value, destinations)
     return self._cross_device_ops.reduce(
         reduce_op, value, destinations=destinations)
diff --git a/tensorflow/contrib/distribute/python/tpu_strategy.py b/tensorflow/contrib/distribute/python/tpu_strategy.py
index b6f5b492017..7ea245eb6eb 100644
--- a/tensorflow/contrib/distribute/python/tpu_strategy.py
+++ b/tensorflow/contrib/distribute/python/tpu_strategy.py
@@ -465,6 +465,14 @@ class TPUExtended(distribute_lib.DistributionStrategyExtended):
             "Currently only support sum & mean in TPUStrategy.")
       return tpu_ops.cross_replica_sum(value)
 
+    if not isinstance(value, values.DistributedValues):
+      # This function handles reducing values that are not PerReplica or
+      # Mirrored values. For example, the same value could be present on all
+      # replicas in which case `value` would be a single value or value could
+      # be 0.
+      return cross_device_ops_lib.reduce_non_distributed_value(
+          self, reduce_op, value, destinations)
+
     # Validate that the destination is same as the host device
     # Note we don't do this when in replicate context as the reduction is
     # performed on the TPU device itself.
diff --git a/tensorflow/python/distribute/cross_device_ops.py b/tensorflow/python/distribute/cross_device_ops.py
index 57c552ca8f0..6bb3639bf01 100644
--- a/tensorflow/python/distribute/cross_device_ops.py
+++ b/tensorflow/python/distribute/cross_device_ops.py
@@ -62,6 +62,43 @@ def validate_destinations(destinations):
     raise ValueError("destinations can not be empty")
 
 
+def reduce_non_distributed_value(extended, reduce_op, value, destinations):
+  """Reduce a non-DistributedValue `value` to `destinations`."""
+  if isinstance(value, value_lib.DistributedValues):
+    raise ValueError("You are passing a `DistributedValue` to "
+                     "`reduce_non_distributed_value`, which is not allowed.")
+
+  # If the same value is present on all replicas then the PerReplica value will
+  # be a single value. We also handle the case when `value` is a single value
+  # and equal to 0.
+  if value == 0:
+    return 0
+  # If there is only a single value and the reduce op is MEAN,
+  # that value should be on all destinations.
+  if reduce_op == reduce_util.ReduceOp.MEAN:
+    return value
+
+  validate_destinations(destinations)
+  # We do not support a reduce op of SUM if the value is the same across
+  # all replicas. We call this as part of assign functions for MirroredVariables
+  # and summing up identical values across replicas is not clearly defined.
+  if (len(extended.worker_devices) != 1 or
+      not check_destinations(destinations)):
+    raise ValueError("A non-DistributedValues value %s cannot be reduced with "
+                     "the given reduce op %s." % (value, reduce_op))
+  # TODO(anjalisridhar): Moves these methods to a device utility file?
+  devices = get_devices_from(destinations)
+  if len(devices) == 1:
+    with ops.device(devices[0]):
+      return array_ops.identity(value)
+  else:
+    value_updates = {}
+    for d in devices:
+      with ops.device(d):
+        value_updates[d] = array_ops.identity(value)
+    return value_lib.Mirrored(value_updates)
+
+
 def _make_tensor_into_per_replica(input_tensor):
   """Converts a single tensor into a PerReplica object."""
   if isinstance(input_tensor, (tuple, list)):
diff --git a/tensorflow/python/distribute/mirrored_strategy.py b/tensorflow/python/distribute/mirrored_strategy.py
index 605e2cc8e78..fb3cf844492 100644
--- a/tensorflow/python/distribute/mirrored_strategy.py
+++ b/tensorflow/python/distribute/mirrored_strategy.py
@@ -74,10 +74,9 @@ class _RequestedStop(Exception):  # pylint: disable=g-bad-exception-name
   pass
 
 
-# _call_for_each_replica and _reduce_non_distributed_value are not members of
-# MirroredStrategy so that they are generally not allowed to use anything
-# specific to MirroredStrategy and thus can be shared with other distribution
-# strategies.
+# _call_for_each_replica is not a member of MirroredStrategy so that it is
+# not allowed to use anything specific to MirroredStrategy and thus
+# can be shared with other distribution strategies.
 
 
 # TODO(yuefengz): maybe create a common class for those who need to call this
@@ -192,43 +191,6 @@ def _call_for_each_replica(distribution, fn, args, kwargs):
   return values.regroup({t.device: t.main_result for t in threads})
 
 
-def _reduce_non_distributed_value(extended, reduce_op, value, destinations):
-  """Reduce a non-DistributedValue `value` to `destinations`."""
-  if isinstance(value, values.DistributedValues):
-    raise ValueError("You are passing a `DistributedValue` to "
-                     "`_reduce_non_distributed_value`, which is not allowed.")
-
-  # If the same value is present on all replicas then the PerReplica value will
-  # be a single value. We also handle the case when `value` is a single value
-  # and equal to 0.
-  if value == 0:
-    return 0
-  # If there is only a single value and the reduce op is MEAN,
-  # that value should be on all destinations.
-  if reduce_op == reduce_util.ReduceOp.MEAN:
-    return value
-
-  cross_device_ops_lib.validate_destinations(destinations)
-  # We do not support a reduce op of SUM if the value is the same across
-  # all replicas. We call this as part of assign functions for MirroredVariables
-  # and summing up identical values across replicas is not clearly defined.
-  if (len(extended.worker_devices) != 1 or
-      not cross_device_ops_lib.check_destinations(destinations)):
-    raise ValueError("A non-DistributedValues value %s cannot be reduced with "
-                     "the given reduce op %s." % (value, reduce_op))
-  # TODO(anjalisridhar): Moves these methods to a device utility file?
-  devices = cross_device_ops_lib.get_devices_from(destinations)
-  if len(devices) == 1:
-    with ops.device(devices[0]):
-      return array_ops.identity(value)
-  else:
-    value_updates = {}
-    for d in devices:
-      with ops.device(d):
-        value_updates[d] = array_ops.identity(value)
-    return values.Mirrored(value_updates)
-
-
 def _create_mirrored_variable(devices, real_mirrored_creator, *args, **kwargs):  # pylint: disable=g-missing-docstring
   # Figure out what collections this variable should be added to.
   # We'll add the MirroredVariable to those collections instead.
@@ -714,8 +676,8 @@ class MirroredExtended(distribute_lib.DistributionStrategyExtended):
       # Mirrored values. For example, the same value could be present on all
       # replicas in which case `value` would be a single value or value could
       # be 0.
-      return _reduce_non_distributed_value(self, reduce_op, value,
-                                           destinations)
+      return cross_device_ops_lib.reduce_non_distributed_value(
+          self, reduce_op, value, destinations)
     return self._get_cross_device_ops().reduce(
         reduce_op, value, destinations=destinations)
 
diff --git a/tensorflow/python/distribute/values.py b/tensorflow/python/distribute/values.py
index 01a1680a246..a5918b7b731 100644
--- a/tensorflow/python/distribute/values.py
+++ b/tensorflow/python/distribute/values.py
@@ -570,6 +570,12 @@ class TPUMirroredVariable(checkpointable.CheckpointableBase):
     # See https://docs.python.org/3/library/constants.html#NotImplemented
     return NotImplemented
 
+  def __str__(self):
+    return "%s:%s" % (self.__class__.__name__, self._index)
+
+  def __repr__(self):
+    return "%s(%r)" % (self.__class__.__name__, self._index)
+
   @property
   def handle(self):
     # If we're in a tpu.rewrite(), return the replicated handle.
-- 
GitLab
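
Reviewer note: the sketch below models, in plain Python outside TensorFlow, the branching of the new shared helper `reduce_non_distributed_value`. `ReduceOp`, `FakeExtended`, and the dict standing in for a `Mirrored` result are hypothetical stand-ins invented for this example; the real helper takes `reduce_util.ReduceOp`, a `DistributionStrategyExtended` object, and copies the value per device with `array_ops.identity` under `ops.device`.

# Minimal sketch of the control flow in reduce_non_distributed_value.
# All names here are stand-ins for this example only.
import enum


class ReduceOp(enum.Enum):
  SUM = "sum"
  MEAN = "mean"


class FakeExtended(object):
  """Stand-in exposing only the attribute the helper reads."""

  def __init__(self, worker_devices):
    self.worker_devices = worker_devices


def reduce_non_distributed_value_sketch(extended, reduce_op, value, devices):
  # A literal 0 (e.g. a no-op update) is returned unchanged.
  if value == 0:
    return 0
  # A MEAN of one identical value is simply that value on every destination.
  if reduce_op == ReduceOp.MEAN:
    return value
  # A SUM of the same value across replicas is ambiguous, so it is rejected
  # unless exactly one worker device and a non-empty destination are involved.
  if len(extended.worker_devices) != 1 or not devices:
    raise ValueError("A non-DistributedValues value %s cannot be reduced with "
                     "the given reduce op %s." % (value, reduce_op))
  if len(devices) == 1:
    return value  # the real code returns array_ops.identity(value) on device
  # Otherwise place a copy per destination device (a dict models Mirrored).
  return {d: value for d in devices}


if __name__ == "__main__":
  ext = FakeExtended(worker_devices=["/cpu:0"])
  print(reduce_non_distributed_value_sketch(
      ext, ReduceOp.MEAN, 3.0, ["/gpu:0", "/gpu:1"]))  # 3.0
  print(reduce_non_distributed_value_sketch(
      ext, ReduceOp.SUM, 3.0, ["/cpu:0"]))             # 3.0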
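
The values.py change only affects how TPUMirroredVariable prints. A toy class reusing the same format strings shows the shape of the output; the real `_index` maps device strings to per-device variable objects, whereas the floats used here are placeholder values.

# Toy demonstration of the __str__/__repr__ format strings added to
# TPUMirroredVariable; FakeTPUMirroredVariable and its contents are made up.
class FakeTPUMirroredVariable(object):

  def __init__(self, index):
    self._index = index

  def __str__(self):
    return "%s:%s" % (self.__class__.__name__, self._index)

  def __repr__(self):
    return "%s(%r)" % (self.__class__.__name__, self._index)


v = FakeTPUMirroredVariable({"/device:TPU:0": 1.0, "/device:TPU:1": 1.0})
print(str(v))   # FakeTPUMirroredVariable:{'/device:TPU:0': 1.0, ...}
print(repr(v))  # FakeTPUMirroredVariable({'/device:TPU:0': 1.0, ...})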