Commit 8644b6d4 authored by Sourabh Bajaj, committed by TensorFlower Gardener

Move reduce non distributed values and share the code with TPU Strategy and also improve print output of TPUMirroredVariable.

PiperOrigin-RevId: 225259008
Parent 9ed22473
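
For orientation: after this change each strategy's `_reduce_to` first checks whether `value` is a `values.DistributedValues` instance and, if it is not, defers to a single shared helper in cross_device_ops.py instead of a MirroredStrategy-private one. Below is a minimal, framework-free sketch of that dispatch; `SimpleDistributedValues`, `strategy_reduce` and the string reduce ops are illustrative stand-ins, not TensorFlow APIs.

# Illustrative sketch only; not TensorFlow code.
class SimpleDistributedValues(object):
  """Stand-in for values.DistributedValues (one value per device)."""

  def __init__(self, per_device):
    self.per_device = per_device


def reduce_non_distributed_value_sketch(reduce_op, value):
  """Greatly simplified stand-in for the shared helper added below."""
  if isinstance(value, SimpleDistributedValues):
    raise ValueError("Expected a non-distributed value.")
  if value == 0:
    return 0
  if reduce_op == "mean":
    # A single value is already its own mean, so it can be used directly.
    return value
  # The real helper validates destinations and may broadcast or raise here;
  # this sketch simply refuses anything else.
  raise ValueError("Cannot %s-reduce the plain value %r in this sketch."
                   % (reduce_op, value))


def strategy_reduce(reduce_op, value):
  """The dispatch each strategy's _reduce_to now performs."""
  if not isinstance(value, SimpleDistributedValues):
    return reduce_non_distributed_value_sketch(reduce_op, value)
  # Real strategies hand DistributedValues to their cross-device ops here.
  return sum(value.per_device.values())


print(strategy_reduce("mean", 3.0))  # prints 3.0
print(strategy_reduce("sum",
                      SimpleDistributedValues({"GPU:0": 1.0, "GPU:1": 2.0})))  # prints 3.0
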
@@ -28,7 +28,6 @@ from tensorflow.python.distribute import values
# pylint: disable=protected-access,invalid-name
_call_for_each_replica = mirrored_strategy._call_for_each_replica
-_reduce_non_distributed_value = mirrored_strategy._reduce_non_distributed_value
_create_mirrored_variable = mirrored_strategy._create_mirrored_variable
all_local_devices = mirrored_strategy.all_local_devices
CoreMirroredStrategy = mirrored_strategy.MirroredStrategy
@@ -356,7 +356,7 @@ class ParameterServerExtended(distribute_lib.DistributionStrategyExtended):
    self._verify_destinations_not_different_worker(destinations)
    if not isinstance(value, values.DistributedValues):
      # pylint: disable=protected-access
-      return mirrored_strategy._reduce_non_distributed_value(
+      return cross_device_ops_lib.reduce_non_distributed_value(
          self, reduce_op, value, destinations)
    return self._cross_device_ops.reduce(
        reduce_op, value, destinations=destinations)
@@ -465,6 +465,14 @@ class TPUExtended(distribute_lib.DistributionStrategyExtended):
"Currently only support sum & mean in TPUStrategy.")
return tpu_ops.cross_replica_sum(value)
if not isinstance(value, values.DistributedValues):
# This function handles reducing values that are not PerReplica or
# Mirrored values. For example, the same value could be present on all
# replicas in which case `value` would be a single value or value could
# be 0.
return cross_device_ops_lib.reduce_non_distributed_value(
self, reduce_op, value, destinations)
# Validate that the destination is same as the host device
# Note we don't do this when in replicate context as the reduction is
# performed on the TPU device itself.
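
Inside a replica context the TPU path reduces with an all-reduce (`tpu_ops.cross_replica_sum`) rather than with the host-side helper. The following framework-free emulation shows what a cross-replica sum computes (every replica ends up holding the sum of all replicas' values) and how a mean can be built on top of it by pre-scaling; the pre-scaling step is an assumption about surrounding code that is not part of this hunk.

# Illustrative emulation of a cross-replica all-reduce; not TPU code.
def cross_replica_sum_sketch(per_replica_values):
  """Every replica receives the sum of all replicas' values."""
  total = sum(per_replica_values)
  return [total for _ in per_replica_values]


def cross_replica_mean_sketch(per_replica_values):
  """Assumed pattern: pre-scale by 1/num_replicas, then all-reduce as a sum."""
  num_replicas = len(per_replica_values)
  return cross_replica_sum_sketch([v / num_replicas for v in per_replica_values])


print(cross_replica_sum_sketch([1.0, 2.0, 3.0, 4.0]))   # [10.0, 10.0, 10.0, 10.0]
print(cross_replica_mean_sketch([1.0, 2.0, 3.0, 4.0]))  # [2.5, 2.5, 2.5, 2.5]
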
@@ -62,6 +62,43 @@ def validate_destinations(destinations):
raise ValueError("destinations can not be empty")
def reduce_non_distributed_value(extended, reduce_op, value, destinations):
"""Reduce a non-DistributedValue `value` to `destinations`."""
if isinstance(value, value_lib.DistributedValues):
raise ValueError("You are passing a `DistributedValue` to "
"`reduce_non_distributed_value`, which is not allowed.")
# If the same value is present on all replicas then the PerReplica value will
# be a single value. We also handle the case when `value` is a single value
# and equal to 0.
if value == 0:
return 0
# If there is only a single value and the reduce op is MEAN,
# that value should be on all destinations.
if reduce_op == reduce_util.ReduceOp.MEAN:
return value
validate_destinations(destinations)
# We do not support a reduce op of SUM if the value is the same across
# all replicas. We call this as part of assign functions for MirroredVariables
# and summing up identical values across replicas is not clearly defined.
if (len(extended.worker_devices) != 1 or
not check_destinations(destinations)):
raise ValueError("A non-DistributedValues value %s cannot be reduced with "
"the given reduce op %s." % (value, reduce_op))
# TODO(anjalisridhar): Moves these methods to a device utility file?
devices = get_devices_from(destinations)
if len(devices) == 1:
with ops.device(devices[0]):
return array_ops.identity(value)
else:
value_updates = {}
for d in devices:
with ops.device(d):
value_updates[d] = array_ops.identity(value)
return value_lib.Mirrored(value_updates)
def _make_tensor_into_per_replica(input_tensor):
"""Converts a single tensor into a PerReplica object."""
if isinstance(input_tensor, (tuple, list)):
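
A concrete way to read the comments in the new helper: for a value that is identical on every replica, SUM is ambiguous (did the caller mean one copy, or one copy per replica?), while MEAN gives the same answer under either reading, which is why the helper can simply return `value` for ReduceOp.MEAN. The numbers below are purely illustrative.

# Purely illustrative numbers; not TensorFlow code.
identical_value = 3.0
num_replicas = 4
replica_values = [identical_value] * num_replicas

# Two defensible readings of "sum" for a value the caller passed only once:
print(sum(replica_values))  # 12.0 if it counts once per replica
print(identical_value)      # 3.0 if it counts as a single shared value

# MEAN collapses to the same number under either reading:
print(sum(replica_values) / num_replicas)  # 3.0
print(identical_value)                     # 3.0
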
@@ -74,10 +74,9 @@ class _RequestedStop(Exception):  # pylint: disable=g-bad-exception-name
  pass


-# _call_for_each_replica and _reduce_non_distributed_value are not members of
-# MirroredStrategy so that they are generally not allowed to use anything
-# specific to MirroredStrategy and thus can be shared with other distribution
-# strategies.
+# _call_for_each_replica is not a member of MirroredStrategy so that it is
+# not allowed to use anything specific to MirroredStrategy and thus
+# can be shared with other distribution strategies.
# TODO(yuefengz): maybe create a common class for those who need to call this
@@ -192,43 +191,6 @@ def _call_for_each_replica(distribution, fn, args, kwargs):
  return values.regroup({t.device: t.main_result for t in threads})


-def _reduce_non_distributed_value(extended, reduce_op, value, destinations):
-  """Reduce a non-DistributedValue `value` to `destinations`."""
-  if isinstance(value, values.DistributedValues):
-    raise ValueError("You are passing a `DistributedValue` to "
-                     "`_reduce_non_distributed_value`, which is not allowed.")
-
-  # If the same value is present on all replicas then the PerReplica value will
-  # be a single value. We also handle the case when `value` is a single value
-  # and equal to 0.
-  if value == 0:
-    return 0
-  # If there is only a single value and the reduce op is MEAN,
-  # that value should be on all destinations.
-  if reduce_op == reduce_util.ReduceOp.MEAN:
-    return value
-
-  cross_device_ops_lib.validate_destinations(destinations)
-  # We do not support a reduce op of SUM if the value is the same across
-  # all replicas. We call this as part of assign functions for MirroredVariables
-  # and summing up identical values across replicas is not clearly defined.
-  if (len(extended.worker_devices) != 1 or
-      not cross_device_ops_lib.check_destinations(destinations)):
-    raise ValueError("A non-DistributedValues value %s cannot be reduced with "
-                     "the given reduce op %s." % (value, reduce_op))
-  # TODO(anjalisridhar): Moves these methods to a device utility file?
-  devices = cross_device_ops_lib.get_devices_from(destinations)
-  if len(devices) == 1:
-    with ops.device(devices[0]):
-      return array_ops.identity(value)
-  else:
-    value_updates = {}
-    for d in devices:
-      with ops.device(d):
-        value_updates[d] = array_ops.identity(value)
-    return values.Mirrored(value_updates)
-
-
def _create_mirrored_variable(devices, real_mirrored_creator, *args, **kwargs):  # pylint: disable=g-missing-docstring
  # Figure out what collections this variable should be added to.
  # We'll add the MirroredVariable to those collections instead.
@@ -714,8 +676,8 @@ class MirroredExtended(distribute_lib.DistributionStrategyExtended):
      # Mirrored values. For example, the same value could be present on all
      # replicas in which case `value` would be a single value or value could
      # be 0.
-      return _reduce_non_distributed_value(self, reduce_op, value,
-                                           destinations)
+      return cross_device_ops_lib.reduce_non_distributed_value(
+          self, reduce_op, value, destinations)
    return self._get_cross_device_ops().reduce(
        reduce_op, value, destinations=destinations)
@@ -570,6 +570,12 @@ class TPUMirroredVariable(checkpointable.CheckpointableBase):
    # See https://docs.python.org/3/library/constants.html#NotImplemented
    return NotImplemented

+  def __str__(self):
+    return "%s:%s" % (self.__class__.__name__, self._index)
+
+  def __repr__(self):
+    return "%s(%r)" % (self.__class__.__name__, self._index)
+
  @property
  def handle(self):
    # If we're in a tpu.rewrite(), return the replicated handle.
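
The `_index` attribute of `TPUMirroredVariable` holds the per-device variables keyed by device string, so the new `__str__`/`__repr__` expose the class name together with that mapping when the variable is printed. A small stub reproducing the pattern; the class name and device strings below are illustrative, not the real TPU types.

# Stub that mirrors the __str__/__repr__ pattern added above; illustrative only.
class FakeTPUMirroredVariable(object):

  def __init__(self, index):
    self._index = index  # maps device strings to per-device values

  def __str__(self):
    return "%s:%s" % (self.__class__.__name__, self._index)

  def __repr__(self):
    return "%s(%r)" % (self.__class__.__name__, self._index)


v = FakeTPUMirroredVariable({"/device:TPU:0": 1.0, "/device:TPU:1": 1.0})
print(str(v))   # FakeTPUMirroredVariable:{'/device:TPU:0': 1.0, '/device:TPU:1': 1.0}
print(repr(v))  # FakeTPUMirroredVariable({'/device:TPU:0': 1.0, '/device:TPU:1': 1.0})
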