Commit bce62166 authored by Vincent Vanhoucke, committed by TensorFlower Gardener

Switch nn.moments() to using a one-pass stable algorithm.

Helps with: https://github.com/tensorflow/tensorflow/issues/917
Also fixes https://github.com/tensorflow/tensorflow/issues/1162

The main benefit is that the computation of the sufficient statistics is now decoupled from the aggregation of the moments, which means that if you want to perform the accumulation incrementally, you don't have to keep all the inputs around: you can instead keep the much more compact sum and sum-of-squares. Accumulation could also be performed locally if you aggregate across multiple devices.
Computing sum and sum-of-squares can also theoretically be performed in parallel now.
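
As a minimal sketch of the incremental use case (plain NumPy, not part of this change): a consumer keeps only a running count, sum, and sum-of-squares across minibatches, never the inputs themselves, and recovers the moments at the end.

    import numpy as np

    count, total, total_sq = 0, 0.0, 0.0
    for _ in range(100):  # a stream of minibatches
      batch = np.random.randn(256)
      count += batch.size            # only three accumulators are kept,
      total += batch.sum()           # not the inputs themselves
      total_sq += np.square(batch).sum()

    mean = total / count
    variance = total_sq / count - mean ** 2  # population variance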

Tested running inception: same performance, same step time.
The batch normalization benchmark is a bit faster on CPU and a bit slower on GPU:

Before:
cpu shape:4/3 #layers:10 mode:py scale:True train:False - 1.139310 secs
gpu shape:4/3 #layers:10 mode:py scale:True train:False - 0.021970 secs
cpu shape:4/3 #layers:10 mode:py scale:True train:True - 2.767147 secs
gpu shape:4/3 #layers:10 mode:py scale:True train:True - 0.074531 secs
cpu shape:4/3 #layers:10 mode:py scale:True train:False - 0.742835 secs
gpu shape:4/3 #layers:10 mode:py scale:True train:False - 0.013473 secs
cpu shape:4/3 #layers:10 mode:py scale:True train:True - 1.738806 secs
gpu shape:4/3 #layers:10 mode:py scale:True train:True - 0.052777 secs
cpu shape:2/1 #layers:10 mode:py scale:True train:False - 0.119180 secs
gpu shape:2/1 #layers:10 mode:py scale:True train:False - 0.011201 secs
cpu shape:2/1 #layers:10 mode:py scale:True train:True - 0.218297 secs
gpu shape:2/1 #layers:10 mode:py scale:True train:True - 0.048526 secs

After:
cpu shape:4/3 #layers:10 mode:py scale:True train:False - 0.998944 secs
gpu shape:4/3 #layers:10 mode:py scale:True train:False - 0.025828 secs
cpu shape:4/3 #layers:10 mode:py scale:True train:True - 2.657428 secs
gpu shape:4/3 #layers:10 mode:py scale:True train:True - 0.086614 secs
cpu shape:4/3 #layers:10 mode:py scale:True train:False - 0.603137 secs
gpu shape:4/3 #layers:10 mode:py scale:True train:False - 0.017668 secs
cpu shape:4/3 #layers:10 mode:py scale:True train:True - 1.519533 secs
gpu shape:4/3 #layers:10 mode:py scale:True train:True - 0.055214 secs
cpu shape:2/1 #layers:10 mode:py scale:True train:False - 0.071344 secs
gpu shape:2/1 #layers:10 mode:py scale:True train:False - 0.016440 secs
cpu shape:2/1 #layers:10 mode:py scale:True train:True - 0.222093 secs
gpu shape:2/1 #layers:10 mode:py scale:True train:True - 0.039967 secs
Change: 115507032
Parent 2cc5ed87
@@ -134,6 +134,8 @@ have varying scale, and to aid generalization.
@@l2_normalize
@@local_response_normalization
@@sufficient_statistics
@@aggregate_moments
@@moments
## Losses
@@ -495,6 +497,101 @@ def separable_conv2d(input, depthwise_filter, pointwise_filter, strides,
                       padding="VALID", name=name)
def sufficient_statistics(x, axes, shift=True, keep_dims=False, name=None):
  """Calculate the sufficient statistics for the mean and variance of `x`.

  These sufficient statistics are computed using the one pass algorithm on
  an input that's optionally shifted using the value of the 1st element in `x`.
  See:
  https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Computing_shifted_data

  Args:
    x: A `Tensor`.
    axes: Array of ints. Axes along which to compute mean and variance.
    shift: If true, shift the data to provide more numerically stable results.
    keep_dims: produce statistics with the same dimensionality as the input.
    name: Name used to scope the operations that compute the sufficient stats.

  Returns:
    Four `Tensor` objects of the same type as `x`:
    * the count (number of elements to average over).
    * the (possibly shifted) sum of the elements in the array.
    * the (possibly shifted) sum of squares of the elements in the array.
    * the shift by which the mean must be corrected or None if `shift` is
      False.
  """
  with ops.op_scope([x, axes], name, "sufficient_statistics"):
    x = ops.convert_to_tensor(x, name="x")
    x_shape = x.get_shape()
    if x_shape.is_fully_defined():
      counts = 1
      m_shape = []
      for d in xrange(x_shape.ndims):
        dim = x_shape[d].value
        if d in set(axes):
          counts *= dim
          dim = 1
        m_shape.append(dim)
      counts = constant_op.constant(counts, dtype=x.dtype)
    else:  # shape needs to be inferred at runtime.
      x_shape = array_ops.shape(x)
      select_axes = sparse_ops.sparse_to_dense(axes, array_ops.shape(x_shape),
                                               True, False)
      m_shape = math_ops.select(select_axes, array_ops.ones_like(x_shape),
                                x_shape)
      counts = math_ops.cast(
          math_ops.reduce_prod(x_shape / m_shape),
          x.dtype,
          name="count")
    if shift:
      shift_value = array_ops.slice(x, array_ops.zeros_like(m_shape), m_shape)
      m_ss = math_ops.sub(x, shift_value)
      v_ss = math_ops.squared_difference(x, shift_value)
      if keep_dims:
        shift_value = array_ops.identity(shift_value, name="shift")
      else:
        shift_value = array_ops.squeeze(shift_value,
                                        squeeze_dims=axes,
                                        name="shift")
    else:  # not shift.
      m_ss = x
      v_ss = math_ops.square(x)
      shift_value = None
    m_ss = math_ops.reduce_sum(m_ss, axes, keep_dims=keep_dims, name="mean_ss")
    v_ss = math_ops.reduce_sum(v_ss, axes, keep_dims=keep_dims, name="var_ss")
    return counts, m_ss, v_ss, shift_value
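
(A hedged aside, not part of this commit: the cancellation problem the optional shift guards against, in plain NumPy. For float32 data with a large mean and a small variance, the raw E[x^2] - E[x]^2 formula collapses, while shifting by the first element, exactly as the function above does, stays accurate.)

    import numpy as np

    np.random.seed(0)
    # Large mean, small variance: the worst case for the naive formula.
    x = (1e4 + 0.1 * np.random.randn(100000)).astype(np.float32)
    n = x.size

    # Naive one-pass moments: two nearly equal ~1e8 terms are subtracted,
    # so float32 cancellation destroys the result (it can even go negative).
    naive_var = np.sum(x * x) / n - (np.sum(x) / n) ** 2

    # Shifted one-pass moments, shifting by the first element as
    # sufficient_statistics does when shift=True.
    k = x[0]
    m_ss = np.sum(x - k)            # shifted sum
    v_ss = np.sum((x - k) ** 2)     # shifted sum of squares
    shifted_mean = m_ss / n
    variance = v_ss / n - shifted_mean ** 2   # accurate: ~0.01
    mean = shifted_mean + k                   # ~1e4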

def aggregate_moments(counts, mean_ss, variance_ss, shift, name=None):
  """Calculate the mean and variance based on the sufficient statistics.

  Args:
    counts: A `Tensor` containing the total count of the data (one value).
    mean_ss: A `Tensor` containing the mean sufficient statistics: the
      (possibly shifted) sum of the elements to average over.
    variance_ss: A `Tensor` containing the variance sufficient statistics: the
      (possibly shifted) squared sum of the data to compute the variance over.
    shift: A `Tensor` containing the value by which the data is shifted for
      numerical stability, or `None` if no shift was performed.
    name: Name used to scope the operations that compute the moments.

  Returns:
    Two `Tensor` objects: `mean` and `variance`.
  """
  with ops.op_scope([counts, mean_ss, variance_ss, shift], name, "aggregate"):
    divisor = math_ops.inv(counts, name="divisor")
    if shift is not None:
      shifted_mean = math_ops.mul(mean_ss, divisor, name="shifted_mean")
      mean = math_ops.add(shifted_mean, shift, name="mean")
    else:  # no shift.
      shifted_mean = math_ops.mul(mean_ss, divisor, name="mean")
      mean = shifted_mean
    variance = math_ops.sub(
        math_ops.mul(variance_ss, divisor),
        math_ops.square(shifted_mean),
        name="variance")
    return (mean, variance)
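
(Another hedged aside: a sketch of the cross-device aggregation the commit message mentions, built only from the two new primitives; the surrounding tf.add_n and tf.random_normal calls are assumed from the contemporary API. Passing shift=False keeps the per-device sums directly additive, at the cost of the stability shift, since each device would otherwise pick a different shift value.)

    import tensorflow as tf

    # Four per-device shards of activations, features along axis 1.
    shards = [tf.random_normal([128, 10]) for _ in range(4)]

    # Local sufficient statistics; shift=False so the sums are additive
    # across devices.
    stats = [tf.nn.sufficient_statistics(s, [0], shift=False) for s in shards]

    count = tf.add_n([c for c, _, _, _ in stats])
    mean_ss = tf.add_n([m for _, m, _, _ in stats])
    var_ss = tf.add_n([v for _, v, _, _ in stats])

    # One aggregation step recovers the global per-feature moments.
    mean, variance = tf.nn.aggregate_moments(count, mean_ss, var_ss, None)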

def moments(x, axes, name=None, keep_dims=False):
  """Calculate the mean and variance of `x`.
@@ -519,40 +616,11 @@ def moments(x, axes, name=None, keep_dims=False):
    Two `Tensor` objects: `mean` and `variance`.
  """
  with ops.op_scope([x, axes], name, "moments"):
    x = ops.convert_to_tensor(x, name="x")
    x_shape = x.get_shape()
    if all(x_shape[d].value is not None for d in axes):
      # The shape is known in the relevant axes, so we can statically
      # compute the divisor.
      divisor = 1.0
      for d in set(axes):
        divisor *= x.get_shape()[d].value
      divisor = constant_op.constant(1.0 / divisor, x.dtype, name="divisor")
    else:
      divisor = constant_op.constant(1.0, dtype=x.dtype)
      x_dynamic_shape = array_ops.shape(x)
      for d in set(axes):
        divisor *= math_ops.cast(x_dynamic_shape[d], x.dtype)
      divisor = math_ops.inv(divisor, name="divisor")
    constant_axes = constant_op.constant(axes, name="axes")
    # Note: We do not use Mean here because it is very slow on GPU.
    mean = math_ops.mul(
        math_ops.reduce_sum(x,
                            constant_axes,
                            keep_dims=True),
        divisor,
        name="mean")
    var = math_ops.mul(
        math_ops.reduce_sum(
            math_ops.squared_difference(x, mean),
            constant_axes,
            keep_dims=keep_dims),
        divisor,
        name="variance")
    if keep_dims:
      return mean, var
    else:
      return array_ops.squeeze(mean, squeeze_dims=axes), var
    counts, m_ss, v_ss, shift = sufficient_statistics(x,
                                                      axes,
                                                      keep_dims=keep_dims,
                                                      name=name)
    return aggregate_moments(counts, m_ss, v_ss, shift, name=name)
def batch_normalization(x,
......
@@ -476,7 +476,7 @@ class DropoutTest(tf.test.TestCase):
    _ = tf.nn.dropout(t, keep_prob, noise_shape=[1, 1])

class BatchNormWithGlobalNormalizationTest(tf.test.TestCase):
class BatchNormalizationTest(tf.test.TestCase):

  def _npBatchNorm(self, x, m, v, beta, gamma, epsilon,
                   scale_after_normalization, shift_after_normalization):
@@ -670,8 +670,7 @@ class BatchNormWithGlobalNormalizationTest(tf.test.TestCase):
      else:
        all_grads = sess.run([dx, dm, dv, db, odx, odm, odv, odb])
      to_check = ["dx", "dm", "dv", "db"]
      for i, n in enumerate(to_check):
        print(n)
      for i, _ in enumerate(to_check):
        self.assertAllClose(
            all_grads[i + len(to_check)], all_grads[i], atol=0.000001)
@@ -759,6 +758,117 @@ class BatchNormWithGlobalNormalizationTest(tf.test.TestCase):
          atol=0.005)

class SufficientStatisticsTest(tf.test.TestCase):

  def _npSuffStats(self, x, axes, shift, keep_dims):
    axis = tuple(axes)
    if shift:
      shift_value = x[[slice(None) if i not in set(axis) else slice(0, 1)
                       for i in xrange(x.ndim)]]
      m_ss = np.sum(x - shift_value, axis=axis, keepdims=keep_dims)
      v_ss = np.sum(
          (x - shift_value) * (x - shift_value),
          axis=axis,
          keepdims=keep_dims)
    else:
      shift_value = None
      m_ss = np.sum(x, axis=axis, keepdims=keep_dims)
      v_ss = np.sum(x * x, axis=axis, keepdims=keep_dims)
    count = 1.0
    for d in xrange(x.ndim):
      if d in set(axes):
        count *= x.shape[d]
    if shift and not keep_dims:
      # Only squeeze when a shift value was actually computed.
      shift_value = np.squeeze(shift_value, axis=axis)
    return count, m_ss, v_ss, shift_value

  def _opSuffStats(self, x, axes, shift, keep_dims):
    return tf.nn.sufficient_statistics(x, axes, shift, keep_dims)

  def _testSuffStats(self, x_shape, axes, shift, keep_dims, has_shape):
    x_val = np.random.random_sample(x_shape).astype(np.float32)
    np_c, np_m, np_v, np_s = self._npSuffStats(x_val, axes, shift, keep_dims)
    for use_gpu in [True, False]:
      with self.test_session(use_gpu=use_gpu) as sess:
        if has_shape:
          x = tf.constant(x_val, name="x")
          x.set_shape(x_shape)
          op_c, op_m, op_v, op_s = self._opSuffStats(x, axes, shift, keep_dims)
          if shift:
            tf_c, tf_m, tf_v, tf_s = sess.run([op_c, op_m, op_v, op_s])
          else:
            tf_c, tf_m, tf_v = sess.run([op_c, op_m, op_v])
        else:
          x = tf.placeholder(dtype=tf.float32,
                             shape=[None] * len(x_shape),
                             name="x")
          op_c, op_m, op_v, op_s = self._opSuffStats(x, axes, shift, keep_dims)
          if shift:
            tf_c, tf_m, tf_v, tf_s = sess.run(
                [op_c, op_m, op_v, op_s],
                feed_dict={x: x_val})
          else:
            tf_c, tf_m, tf_v = sess.run(
                [op_c, op_m, op_v],
                feed_dict={x: x_val})
        self.assertAllClose(np_c, tf_c, atol=0.000001)
        self.assertAllClose(np_m, tf_m, atol=0.000001)
        self.assertAllClose(np_v, tf_v, atol=0.000001)
        if shift:
          self.assertAllClose(np_s, tf_s, atol=0.000001)

  def testSuffStats(self):
    for has_shape in [True, False]:
      for keep_dims in [True, False]:
        for shift in [True, False]:
          self._testSuffStats([2, 3], [1], shift, keep_dims, has_shape)
          self._testSuffStats([2, 3], [0], shift, keep_dims, has_shape)
          self._testSuffStats([1, 2, 3], [0, 2], shift, keep_dims, has_shape)

class AggregateMomentsTest(tf.test.TestCase):

  def _npAggregateMoments(self, counts, mean_ss, variance_ss, shift):
    mean = mean_ss / counts
    variance = variance_ss / counts - mean * mean
    if shift is not None:
      mean += shift
    return mean, variance

  def _opAggregateMoments(self, counts, mean_ss, variance_ss, shift):
    return tf.nn.aggregate_moments(counts, mean_ss, variance_ss, shift)

  def _testAggregateMoments(self, shape, shift):
    counts = np.ones([1]).astype(np.float32)
    mean_ss = np.random.random_sample(shape).astype(np.float32)
    variance_ss = np.random.random_sample(shape).astype(np.float32)
    variance_ss *= variance_ss
    if shift:
      shift_v = np.random.random_sample(shape).astype(np.float32)
    else:
      shift_v = None
    npm, npv = self._npAggregateMoments(counts, mean_ss, variance_ss, shift_v)
    for use_gpu in [True, False]:
      with self.test_session(use_gpu=use_gpu) as sess:
        tf_counts = tf.constant(counts, name="counts")
        tf_mean_ss = tf.constant(mean_ss, name="mean_ss")
        tf_variance_ss = tf.constant(variance_ss, name="variance_ss")
        if shift:
          tf_shift_v = tf.constant(shift_v, name="shift")
        else:
          tf_shift_v = None
        opm, opv = self._opAggregateMoments(tf_counts, tf_mean_ss,
                                            tf_variance_ss, tf_shift_v)
        tfm, tfv = sess.run([opm, opv])
        self.assertAllClose(npm, tfm, atol=0.000001)
        self.assertAllClose(npv, tfv, atol=0.000001)

  def testAggregateMoments(self):
    for shift in [True, False]:
      self._testAggregateMoments([3], shift)
      self._testAggregateMoments([2, 3], shift)

class MomentsTest(tf.test.TestCase):

  def RunMomentTestWithDynamicShape(self, shape, axes, keep_dims):
@@ -857,6 +967,20 @@ class MomentsTest(tf.test.TestCase):
  def testVarGlobalGradient(self):
    self._testGlobalGradient(from_y="var")

  def testOutputNamesNoKeep(self):
    """Make sure the output names are stable."""
    with self.test_session():
      mean, var = tf.nn.moments(tf.constant([1]), [0], keep_dims=False)
      self.assertEquals(mean.op.name, "moments/aggregate/mean")
      self.assertEquals(var.op.name, "moments/aggregate/variance")

  def testOutputNamesKeep(self):
    """Make sure the output names are stable."""
    with self.test_session():
      mean, var = tf.nn.moments(tf.constant([1]), [0], keep_dims=True)
      self.assertEquals(mean.op.name, "moments/aggregate/mean")
      self.assertEquals(var.op.name, "moments/aggregate/variance")
class ComputeSampledLogitsTest(tf.test.TestCase):
......