Commit 5fea53a7 authored by Hongkun Yu, committed by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 418724903
Parent 77d9fd62
official/modeling/tf_utils.py
@@ -201,3 +201,74 @@ def safe_mean(losses):
  total = tf.reduce_sum(losses)
  num_elements = tf.cast(tf.size(losses), dtype=losses.dtype)
  return tf.math.divide_no_nan(total, num_elements)


def get_replica_id():
  """Gets replica id depending on the environment."""
  context = tf.distribute.get_replica_context()
  if context is not None:
    return context.replica_id_in_sync_group
  else:
    raise RuntimeError("Unknown replica context. The `get_replica_id` method "
                       "relies on TF 2.x tf.distribute API.")


def cross_replica_concat(value, axis, name="cross_replica_concat"):
  """Concatenates the given `value` across (GPU/TPU) cores, along `axis`.

  In general, each core ("replica") will pass a replica-specific value as
  `value` (corresponding to some element of a data-parallel computation
  taking place across replicas).

  The resulting concatenated `Tensor` will have the same shape as `value` for
  all dimensions except `axis`, where it will be larger by a factor of the
  number of replicas. It will also have the same `dtype` as `value`.

  The position of a given replica's `value` within the resulting concatenation
  is determined by that replica's replica ID. For example:

  With `value` for replica 0 given as

      0 0 0
      0 0 0

  and `value` for replica 1 given as

      1 1 1
      1 1 1

  the resulting concatenation along axis 0 will be

      0 0 0
      0 0 0
      1 1 1
      1 1 1

  and this result will be identical across all replicas.

  Note that this API only works in TF2 with `tf.distribute`.

  Args:
    value: The `Tensor` to concatenate across replicas. Each replica will have
      a different value for this `Tensor`, and these replica-specific values
      will be concatenated.
    axis: The axis along which to perform the concatenation as a Python integer
      (not a `Tensor`). E.g., `axis=0` to concatenate along the batch
      dimension.
    name: A name for the operation (used to create a name scope).

  Returns:
    The result of concatenating `value` along `axis` across replicas.

  Raises:
    RuntimeError: when the batch (0-th) dimension is None.
  """
  with tf.name_scope(name):
    context = tf.distribute.get_replica_context()
    # Typically this is hit only if the tensor is derived from a dataset with
    # finite epochs and drop_remainder=False, where the last batch may be of
    # a different batch size, making dim 0 dynamic.
    if value.shape.as_list()[0] is None:
      raise RuntimeError(f"{value} has an unknown batch dimension.")
    return context.all_gather(value, axis=axis)
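For context (not part of this commit): a minimal, hypothetical usage sketch of `get_replica_id` and `cross_replica_concat` under a two-replica `tf.distribute.MirroredStrategy`. The virtual-CPU setup and the tensor shape are illustrative assumptions, not taken from this change.

import tensorflow as tf

from official.modeling import tf_utils

# Illustrative assumption: split the one physical CPU into two logical
# devices so MirroredStrategy can run two replicas on a single host.
tf.config.set_logical_device_configuration(
    tf.config.list_physical_devices("CPU")[0],
    [tf.config.LogicalDeviceConfiguration()] * 2)
strategy = tf.distribute.MirroredStrategy(["/cpu:0", "/cpu:1"])


@tf.function
def per_replica_fn():
  # Each replica fills a (2, 3) tensor with its own replica id, ...
  value = tf.fill((2, 3), tf_utils.get_replica_id())
  # ... then gathers the per-replica tensors along axis 0 -> shape (4, 3).
  return tf_utils.cross_replica_concat(value, axis=0)


# Every element of `results.values` holds the same concatenated tensor:
# two rows of 0s followed by two rows of 1s.
results = strategy.run(per_replica_fn)
print(results.values[0])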
official/modeling/tf_utils_test.py (new file)

# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for tf_utils."""
from absl.testing import parameterized
import numpy as np
import tensorflow as tf
from tensorflow.python.distribute import combinations
from tensorflow.python.distribute import strategy_combinations
from official.modeling import tf_utils


def all_strategy_combinations():
  return combinations.combine(
      strategy=[
          strategy_combinations.cloud_tpu_strategy,
          strategy_combinations.mirrored_strategy_with_two_gpus,
      ],
      mode='eager',
  )


class TFUtilsTest(tf.test.TestCase, parameterized.TestCase):

  @combinations.generate(all_strategy_combinations())
  def test_cross_replica_concat(self, strategy):
    num_cores = strategy.num_replicas_in_sync
    shape = (2, 3, 4)

    def concat(axis):

      @tf.function
      def function():
        replica_value = tf.fill(shape, tf_utils.get_replica_id())
        return tf_utils.cross_replica_concat(replica_value, axis=axis)

      return function

    def expected(axis):
      values = [np.full(shape, i) for i in range(num_cores)]
      return np.concatenate(values, axis=axis)

    per_replica_results = strategy.run(concat(axis=0))
    replica_0_result = per_replica_results.values[0].numpy()
    for value in per_replica_results.values[1:]:
      self.assertAllClose(value.numpy(), replica_0_result)
    self.assertAllClose(replica_0_result, expected(axis=0))

    replica_0_result = strategy.run(concat(axis=1)).values[0].numpy()
    self.assertAllClose(replica_0_result, expected(axis=1))

    replica_0_result = strategy.run(concat(axis=2)).values[0].numpy()
    self.assertAllClose(replica_0_result, expected(axis=2))

  @combinations.generate(all_strategy_combinations())
  def test_cross_replica_concat_gradient(self, strategy):
    num_cores = strategy.num_replicas_in_sync
    shape = (10, 5)

    @tf.function
    def function():
      replica_value = tf.random.normal(shape)
      with tf.GradientTape() as tape:
        tape.watch(replica_value)
        concat_value = tf_utils.cross_replica_concat(replica_value, axis=0)
        output = tf.reduce_sum(concat_value)
      return tape.gradient(output, replica_value)

    per_replica_gradients = strategy.run(function)
    for gradient in per_replica_gradients.values:
      self.assertAllClose(gradient, num_cores * tf.ones(shape))


if __name__ == '__main__':
  tf.test.main()
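As a rough guide (not part of the commit): since the module calls `tf.test.main()`, the tests can be launched directly, e.g. `python3 official/modeling/tf_utils_test.py`. Note that the strategy combinations above expect a Cloud TPU or a host with two GPUs, so the cases will not exercise anything on a CPU-only machine.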