Commit adf76904 authored by Anjali Sridhar, committed by TensorFlower Gardener

Add support for aggregating batch statistics across devices by using the newly added tf.keras.layers.experimental.SyncBatchNormalization layer.

PiperOrigin-RevId: 292723222
Change-Id: I1c0458ec24c7e712ffa5e12dcf1f5efd6b4ce8ac
Parent 319b73c6
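
For context, a minimal usage sketch of the new layer under `tf.distribute.MirroredStrategy` (illustrative only and not part of this commit; the data, shapes, and hyperparameters below are assumptions):

```python
import tensorflow as tf

# Illustrative sketch only; synthetic data and arbitrary layer sizes.
strategy = tf.distribute.MirroredStrategy()
with strategy.scope():
  model = tf.keras.Sequential([
      tf.keras.layers.Dense(16, activation='relu', input_shape=(8,)),
      # Batch statistics are all-reduced across replicas at each training step.
      tf.keras.layers.experimental.SyncBatchNormalization(),
      tf.keras.layers.Dense(1),
  ])
  model.compile(optimizer='sgd', loss='mse')

x = tf.random.normal((64, 8))
y = tf.random.normal((64, 1))
model.fit(x, y, batch_size=32, epochs=1)
```
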
@@ -35,7 +35,7 @@ from tensorflow.python.framework import random_seed
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn
_NUM_SAMPLES = 64
_NUM_SAMPLES = 66
_BATCH_SIZE = 32
_RANDOM_SEED = 1337
_NUM_EPOCHS = 2
@@ -60,12 +60,16 @@ class MaybeStrategyScope(object):
self._scope = None
def get_model():
def get_model(sync_batchnorm=False):
model = keras.Sequential()
model.add(keras.layers.Dense(10, activation='relu', input_shape=(1,)))
model.add(keras.layers.Dense(
10, activation='relu',
kernel_regularizer=keras.regularizers.l2(1e-4)))
if sync_batchnorm:
model.add(keras.layers.SyncBatchNormalization())
else:
model.add(keras.layers.BatchNormalization())
model.add(keras.layers.Dense(10, activation='relu'))
model.add(keras.layers.Dense(1))
return model
@@ -90,10 +94,13 @@ def compute_loss(labels, logits, reg_losses):
def iteration_inside_func(initial_weights, dataset, optimizer_fn,
iteration_type, strategy=None):
iteration_type, strategy=None, sync_batchnorm=None):
"""Helper function to test iterating over data inside a tf.function."""
with MaybeStrategyScope(strategy):
model = get_model()
if strategy and sync_batchnorm:
model = get_model(sync_batchnorm)
else:
model = get_model()
model.set_weights(initial_weights)
optimizer = optimizer_fn()
@@ -153,10 +160,10 @@ def iteration_inside_func(initial_weights, dataset, optimizer_fn,
def iteration_outside_func(initial_weights, dataset, optimizer_fn,
iteration_type, strategy=None):
iteration_type, strategy=None, sync_batchnorm=None):
"""Helper function to test iterating over data outside a tf.function."""
with MaybeStrategyScope(strategy):
model = get_model()
model = get_model(sync_batchnorm=sync_batchnorm)
model.set_weights(initial_weights)
optimizer = optimizer_fn()
@@ -223,16 +230,21 @@ class TestDistributionStrategyDnnCorrectness(test.TestCase,
optimizer_fn=strategy_combinations.optimizers_v1_and_v2,
mode=['eager'],
iteration_type=['iterator', 'dataset'],
inside_func=[False, True]
inside_func=[False, True],
sync_batchnorm=[True, False]
))
def test_dnn_correctness_minus_tpus(self, distribution, optimizer_fn,
iteration_type, inside_func):
iteration_type, inside_func,
sync_batchnorm):
# TODO(anjs): Identify why this particular V1 optimizer needs a higher tol.
if 'FtrlV1' in optimizer_fn._name and 'TPU' in type(distribution).__name__:
self.skipTest('Reduced tolerance of the order of 1e-1 required.')
self.dnn_correctness(distribution, optimizer_fn, iteration_type,
inside_func)
inside_func, sync_batchnorm)
def dnn_correctness(self, distribution, optimizer_fn, iteration_type,
inside_func):
model = get_model()
inside_func, sync_batchnorm=None):
model = get_model(sync_batchnorm)
initial_weights = model.get_weights()
dataset = get_data()
if inside_func:
@@ -241,13 +253,15 @@ class TestDistributionStrategyDnnCorrectness(test.TestCase,
iteration_func = iteration_outside_func
wts_with_ds, loss_with_ds, acc_with_ds = iteration_func(
initial_weights, dataset, optimizer_fn, iteration_type,
strategy=distribution)
strategy=distribution, sync_batchnorm=sync_batchnorm)
wts, loss, acc = iteration_func(initial_weights, dataset, optimizer_fn,
iteration_type)
iteration_type,
sync_batchnorm=sync_batchnorm)
self.assertAllClose(wts, wts_with_ds, atol=1e-3, rtol=1e-3)
self.assertAllClose(loss, loss_with_ds, atol=1e-3, rtol=1e-3)
self.assertAllClose(acc, acc_with_ds, atol=1e-3, rtol=1e-3)
if __name__ == '__main__':
test.main()
@@ -386,7 +386,7 @@ class TestDistributionStrategyCorrectnessBase(test.TestCase,
def set_up_test_config(self,
use_numpy=False,
use_validation_data=False,
with_batch_norm=False):
with_batch_norm=None):
self.use_numpy = use_numpy
self.use_validation_data = use_validation_data
self.with_batch_norm = with_batch_norm
@@ -435,7 +435,7 @@ class TestDistributionStrategyCorrectnessBase(test.TestCase,
use_numpy,
use_validation_data,
experimental_run_tf_function=None,
with_batch_norm=False,
with_batch_norm=None,
is_stateful_model=False,
partial_last_batch=None,
training_epochs=2):
@@ -503,7 +503,8 @@ class TestDistributionStrategyCorrectnessBase(test.TestCase,
# First, special case, for multi-replica distributed training, batch
# norm is not aggregated globally. So it is expected to have different
# weights.
if (self.with_batch_norm and distribution.num_replicas_in_sync > 1):
if (self.with_batch_norm == 'regular' and
distribution.num_replicas_in_sync > 1):
with self.assertRaises(AssertionError):
compare_results(
results_with_ds,
......
@@ -20,6 +20,7 @@ from __future__ import print_function
import numpy as np
from tensorflow.python import keras
from tensorflow.python.distribute import combinations
from tensorflow.python.eager import context
from tensorflow.python.eager import test
from tensorflow.python.keras.distribute import keras_correctness_test_base
from tensorflow.python.keras.optimizer_v2 import gradient_descent
@@ -43,8 +44,10 @@ class DistributionStrategyCnnCorrectnessTest(
strides=(4, 4),
kernel_regularizer=keras.regularizers.l2(1e-4))(
image)
if self.with_batch_norm:
if self.with_batch_norm == 'regular':
c1 = keras.layers.BatchNormalization(name='bn1')(c1)
elif self.with_batch_norm == 'sync':
c1 = keras.layers.SyncBatchNormalization(name='bn1')(c1)
c1 = keras.layers.MaxPooling2D(pool_size=(2, 2))(c1)
logits = keras.layers.Dense(
10, activation='softmax', name='pred')(
@@ -107,7 +110,22 @@ class DistributionStrategyCnnCorrectnessTest(
distribution,
use_numpy,
use_validation_data,
with_batch_norm=True,
with_batch_norm='regular',
experimental_run_tf_function=experimental_run_tf_function)
@combinations.generate(
keras_correctness_test_base.all_strategy_and_input_config_combinations())
def test_cnn_with_sync_batch_norm_correctness(self, distribution, use_numpy,
use_validation_data,
experimental_run_tf_function):
if not context.executing_eagerly() or not experimental_run_tf_function:
self.skipTest('SyncBatchNorm is not enabled in graph mode.')
self.run_correctness_test(
distribution,
use_numpy,
use_validation_data,
with_batch_norm='sync',
experimental_run_tf_function=experimental_run_tf_function)
@combinations.generate(
@@ -134,7 +152,7 @@ class DistributionStrategyCnnCorrectnessTest(
distribution,
use_numpy,
use_validation_data,
with_batch_norm=True,
with_batch_norm='regular',
partial_last_batch=True)
......
@@ -135,6 +135,8 @@ from tensorflow.python.keras.layers.noise import GaussianDropout
# Normalization layers.
from tensorflow.python.keras.layers.normalization import LayerNormalization
from tensorflow.python.keras.layers.normalization_v2 import SyncBatchNormalization
if tf2.enabled():
from tensorflow.python.keras.layers.normalization_v2 import BatchNormalization
from tensorflow.python.keras.layers.normalization import BatchNormalization as BatchNormalizationV1
......
@@ -652,8 +652,12 @@ class BatchNormalizationBase(Layer):
return (r, d, out_mean, out_variance)
def _calculate_mean_and_var(self, inputs, reduction_axes, keep_dims):
return nn.moments(inputs, reduction_axes, keep_dims=keep_dims)
def _moments(self, inputs, reduction_axes, keep_dims):
mean, variance = nn.moments(inputs, reduction_axes, keep_dims=keep_dims)
mean, variance = self._calculate_mean_and_var(inputs, reduction_axes,
keep_dims)
# TODO(b/129279393): Support zero batch input in non DistributionStrategy
# code as well.
if self._support_zero_size_input():
......
@@ -18,10 +18,192 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.python.distribute import distribution_strategy_context as ds
from tensorflow.python.distribute import reduce_util
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.keras.layers import normalization
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.util.tf_export import keras_export
@keras_export('keras.layers.experimental.SyncBatchNormalization', v1=[]) # pylint: disable=g-classes-have-attributes
class SyncBatchNormalization(normalization.BatchNormalizationBase):
r"""Normalize and scale inputs or activations synchronously across replicas.
Applies batch normalization to activations of the previous layer at each batch
by synchronizing the global batch statistics across all devices that are
training the model. For specific details about batch normalization please
refer to the `tf.keras.layers.BatchNormalization` layer docs.
If this layer is used with a tf.distribute strategy to train models
across devices/workers, there will be an allreduce call to aggregate batch
statistics across all replicas at every training step. Without a tf.distribute
strategy, this layer behaves as a regular `tf.keras.layers.BatchNormalization`
layer.
Example usage:
```
strategy = tf.distribute.MirroredStrategy()
with strategy.scope():
model = tf.keras.Sequential()
model.add(tf.keras.layers.Dense(16))
model.add(tf.keras.layers.experimental.SyncBatchNormalization())
```
Arguments:
axis: Integer, the axis that should be normalized
(typically the features axis).
For instance, after a `Conv2D` layer with
`data_format="channels_first"`,
set `axis=1` in `BatchNormalization`.
momentum: Momentum for the moving average.
epsilon: Small float added to variance to avoid dividing by zero.
center: If True, add offset of `beta` to normalized tensor.
If False, `beta` is ignored.
scale: If True, multiply by `gamma`.
If False, `gamma` is not used.
When the next layer is linear (also e.g. `nn.relu`),
this can be disabled since the scaling
will be done by the next layer.
beta_initializer: Initializer for the beta weight.
gamma_initializer: Initializer for the gamma weight.
moving_mean_initializer: Initializer for the moving mean.
moving_variance_initializer: Initializer for the moving variance.
beta_regularizer: Optional regularizer for the beta weight.
gamma_regularizer: Optional regularizer for the gamma weight.
beta_constraint: Optional constraint for the beta weight.
gamma_constraint: Optional constraint for the gamma weight.
renorm: Whether to use Batch Renormalization
(https://arxiv.org/abs/1702.03275). This adds extra variables during
training. The inference is the same for either value of this parameter.
renorm_clipping: A dictionary that may map keys 'rmax', 'rmin', 'dmax' to
scalar `Tensors` used to clip the renorm correction. The correction
`(r, d)` is used as `corrected_value = normalized_value * r + d`, with
`r` clipped to [rmin, rmax], and `d` to [-dmax, dmax]. Missing rmax, rmin,
dmax are set to inf, 0, inf, respectively.
renorm_momentum: Momentum used to update the moving means and standard
deviations with renorm. Unlike `momentum`, this affects training
and should be neither too small (which would add noise) nor too large
(which would give stale estimates). Note that `momentum` is still applied
to get the means and variances for inference.
trainable: Boolean, if `True` the variables will be marked as trainable.
Call arguments:
inputs: Input tensor (of any rank).
training: Python boolean indicating whether the layer should behave in
training mode or in inference mode.
- `training=True`: The layer will normalize its inputs using the
mean and variance of the current batch of inputs.
- `training=False`: The layer will normalize its inputs using the
mean and variance of its moving statistics, learned during training.
Input shape:
Arbitrary. Use the keyword argument `input_shape`
(tuple of integers, does not include the samples axis)
when using this layer as the first layer in a model.
Output shape:
Same shape as input.
"""
def __init__(self,
axis=-1,
momentum=0.99,
epsilon=1e-3,
center=True,
scale=True,
beta_initializer='zeros',
gamma_initializer='ones',
moving_mean_initializer='zeros',
moving_variance_initializer='ones',
beta_regularizer=None,
gamma_regularizer=None,
beta_constraint=None,
gamma_constraint=None,
renorm=False,
renorm_clipping=None,
renorm_momentum=0.99,
trainable=True,
adjustment=None,
name=None,
**kwargs):
# Currently we only support aggregating over the global batch size.
super(SyncBatchNormalization, self).__init__(
axis=axis,
momentum=momentum,
epsilon=epsilon,
center=center,
scale=scale,
beta_initializer=beta_initializer,
gamma_initializer=gamma_initializer,
moving_mean_initializer=moving_mean_initializer,
moving_variance_initializer=moving_variance_initializer,
beta_regularizer=beta_regularizer,
gamma_regularizer=gamma_regularizer,
beta_constraint=beta_constraint,
gamma_constraint=gamma_constraint,
renorm=renorm,
renorm_clipping=renorm_clipping,
renorm_momentum=renorm_momentum,
fused=False,
trainable=trainable,
virtual_batch_size=None,
name=name,
**kwargs)
def _calculate_mean_and_var(self, x, axes, keep_dims):
with ops.name_scope('moments', values=[x, axes]):
# The dynamic range of fp16 is too limited to support the collection of
# sufficient statistics. As a workaround we simply perform the operations
# on 32-bit floats before converting the mean and variance back to fp16
y = math_ops.cast(x, dtypes.float32) if x.dtype == dtypes.float16 else x
replica_ctx = ds.get_replica_context()
if replica_ctx:
local_sum = math_ops.reduce_sum(y, axis=axes, keepdims=True)
local_squared_sum = math_ops.reduce_sum(math_ops.square(y), axis=axes,
keepdims=True)
y_sum, y_squared_sum, global_batch_size = (
replica_ctx.all_reduce(reduce_util.ReduceOp.SUM, [
local_sum, local_squared_sum, array_ops.shape_v2(y)[0]]))
axes_vals = [(array_ops.shape_v2(y))[i] for i in range(1, len(axes))]
multiplier = math_ops.cast(math_ops.reduce_prod(axes_vals),
dtypes.float32)
multiplier = multiplier * math_ops.cast(global_batch_size,
dtypes.float32)
mean = y_sum / multiplier
y_squared_mean = y_squared_sum / multiplier
# var = E(x^2) - E(x)^2
variance = y_squared_mean - math_ops.square(mean)
else:
# Compute true mean while keeping the dims for proper broadcasting.
mean = math_ops.reduce_mean(y, axes, keepdims=True, name='mean')
# sample variance, not unbiased variance
# Note: stop_gradient does not change the gradient that gets
# backpropagated to the mean from the variance calculation,
# because that gradient is zero
variance = math_ops.reduce_mean(
math_ops.squared_difference(y, array_ops.stop_gradient(mean)),
axes,
keepdims=True,
name='variance')
if not keep_dims:
mean = array_ops.squeeze(mean, axes)
variance = array_ops.squeeze(variance, axes)
if x.dtype == dtypes.float16:
return (math_ops.cast(mean, dtypes.float16),
math_ops.cast(variance, dtypes.float16))
else:
return (mean, variance)
@keras_export('keras.layers.BatchNormalization', v1=[]) # pylint: disable=missing-docstring
class BatchNormalization(normalization.BatchNormalizationBase):
......
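
A minimal NumPy sketch (not part of the commit; the replica count and shapes are arbitrary) showing why all-reducing per-replica sums and squared sums, as `_calculate_mean_and_var` does above, reproduces the global-batch mean and the biased variance `E[x^2] - E[x]^2`:

```python
import numpy as np

# Two hypothetical replicas, each with a local batch of 8 examples x 4 features.
replica_batches = [np.random.randn(8, 4), np.random.randn(8, 4)]

# Each replica contributes its local sum, sum of squares, and local batch size;
# the all-reduce in the layer adds these across replicas.
y_sum = sum(b.sum(axis=0) for b in replica_batches)
y_squared_sum = sum((b ** 2).sum(axis=0) for b in replica_batches)
global_batch_size = sum(b.shape[0] for b in replica_batches)

mean = y_sum / global_batch_size
variance = y_squared_sum / global_batch_size - mean ** 2  # E[x^2] - E[x]^2

# Matches the statistics of the concatenated global batch.
full = np.concatenate(replica_batches, axis=0)
assert np.allclose(mean, full.mean(axis=0))
assert np.allclose(variance, full.var(axis=0))
```

In the layer itself the reduction also spans the non-channel dimensions (e.g. height and width for NHWC inputs), which is why the code above multiplies the all-reduced global batch size by the product of those dimensions before dividing.
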
path: "tensorflow.keras.layers.experimental.SyncBatchNormalization"
tf_class {
is_instance: "<class \'tensorflow.python.keras.layers.normalization_v2.SyncBatchNormalization\'>"
is_instance: "<class \'tensorflow.python.keras.layers.normalization.BatchNormalizationBase\'>"
is_instance: "<class \'tensorflow.python.keras.engine.base_layer.Layer\'>"
is_instance: "<class \'tensorflow.python.module.module.Module\'>"
is_instance: "<class \'tensorflow.python.training.tracking.tracking.AutoTrackable\'>"
is_instance: "<class \'tensorflow.python.training.tracking.base.Trackable\'>"
is_instance: "<type \'object\'>"
member {
name: "activity_regularizer"
mtype: "<type \'property\'>"
}
member {
name: "dtype"
mtype: "<type \'property\'>"
}
member {
name: "dynamic"
mtype: "<type \'property\'>"
}
member {
name: "inbound_nodes"
mtype: "<type \'property\'>"
}
member {
name: "input"
mtype: "<type \'property\'>"
}
member {
name: "input_mask"
mtype: "<type \'property\'>"
}
member {
name: "input_shape"
mtype: "<type \'property\'>"
}
member {
name: "input_spec"
mtype: "<type \'property\'>"
}
member {
name: "losses"
mtype: "<type \'property\'>"
}
member {
name: "metrics"
mtype: "<type \'property\'>"
}
member {
name: "name"
mtype: "<type \'property\'>"
}
member {
name: "name_scope"
mtype: "<type \'property\'>"
}
member {
name: "non_trainable_variables"
mtype: "<type \'property\'>"
}
member {
name: "non_trainable_weights"
mtype: "<type \'property\'>"
}
member {
name: "outbound_nodes"
mtype: "<type \'property\'>"
}
member {
name: "output"
mtype: "<type \'property\'>"
}
member {
name: "output_mask"
mtype: "<type \'property\'>"
}
member {
name: "output_shape"
mtype: "<type \'property\'>"
}
member {
name: "stateful"
mtype: "<type \'property\'>"
}
member {
name: "submodules"
mtype: "<type \'property\'>"
}
member {
name: "trainable"
mtype: "<type \'property\'>"
}
member {
name: "trainable_variables"
mtype: "<type \'property\'>"
}
member {
name: "trainable_weights"
mtype: "<type \'property\'>"
}
member {
name: "updates"
mtype: "<type \'property\'>"
}
member {
name: "variables"
mtype: "<type \'property\'>"
}
member {
name: "weights"
mtype: "<type \'property\'>"
}
member_method {
name: "__init__"
argspec: "args=[\'self\', \'axis\', \'momentum\', \'epsilon\', \'center\', \'scale\', \'beta_initializer\', \'gamma_initializer\', \'moving_mean_initializer\', \'moving_variance_initializer\', \'beta_regularizer\', \'gamma_regularizer\', \'beta_constraint\', \'gamma_constraint\', \'renorm\', \'renorm_clipping\', \'renorm_momentum\', \'trainable\', \'adjustment\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'-1\', \'0.99\', \'0.001\', \'True\', \'True\', \'zeros\', \'ones\', \'zeros\', \'ones\', \'None\', \'None\', \'None\', \'None\', \'False\', \'None\', \'0.99\', \'True\', \'None\', \'None\'], "
}
member_method {
name: "add_loss"
argspec: "args=[\'self\', \'losses\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "add_metric"
argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
}
member_method {
name: "add_update"
argspec: "args=[\'self\', \'updates\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "add_variable"
argspec: "args=[\'self\'], varargs=args, keywords=kwargs, defaults=None"
}
member_method {
name: "add_weight"
argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
argspec: "args=[\'self\', \'inputs\'], varargs=args, keywords=kwargs, defaults=None"
}
member_method {
name: "build"
argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "call"
argspec: "args=[\'self\', \'inputs\', \'training\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "compute_mask"
argspec: "args=[\'self\', \'inputs\', \'mask\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "compute_output_shape"
argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "compute_output_signature"
argspec: "args=[\'self\', \'input_signature\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "count_params"
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "from_config"
argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "get_config"
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "get_input_at"
argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "get_input_mask_at"
argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "get_input_shape_at"
argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "get_losses_for"
argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "get_output_at"
argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "get_output_mask_at"
argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "get_output_shape_at"
argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "get_updates_for"
argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "get_weights"
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "set_weights"
argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "with_name_scope"
argspec: "args=[\'cls\', \'method\'], varargs=None, keywords=None, defaults=None"
}
}
path: "tensorflow.keras.layers.experimental"
tf_module {
member {
name: "SyncBatchNormalization"
mtype: "<type \'type\'>"
}
member {
name: "preprocessing"
mtype: "<type \'module\'>"
......