Commit a3d63443 authored by F Francois Chollet, committed by TensorFlower Gardener

Run BatchNorm layer tests in all relevant execution modes.

PiperOrigin-RevId: 225287527
Parent 1ec3b398
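For context, a minimal, self-contained sketch of the test pattern this commit adopts (the class name and test body below are illustrative, not part of the diff): tests subclass keras_parameterized.TestCase, each test method is decorated with @keras_parameterized.run_all_keras_modes, and models are compiled with run_eagerly=testing_utils.should_run_eagerly() so the same test body exercises every relevant execution mode.

import numpy as np

from tensorflow.python import keras
from tensorflow.python.keras import keras_parameterized
from tensorflow.python.keras import testing_utils
from tensorflow.python.platform import test
from tensorflow.python.training import gradient_descent


class ExampleBatchNormModesTest(keras_parameterized.TestCase):
  # Illustrative test class; not part of the actual change.

  @keras_parameterized.run_all_keras_modes
  def test_batchnorm_fits_in_every_mode(self):
    model = keras.models.Sequential()
    model.add(keras.layers.BatchNormalization(input_shape=(2, 2, 2),
                                              momentum=0.8))
    # should_run_eagerly() reflects the mode selected by the decorator.
    model.compile(
        loss='mse',
        optimizer=gradient_descent.GradientDescentOptimizer(0.01),
        run_eagerly=testing_utils.should_run_eagerly())
    x = np.random.normal(loc=5.0, scale=10.0,
                         size=(100, 2, 2, 2)).astype('float32')
    model.fit(x, x, epochs=1, verbose=0)


if __name__ == '__main__':
  test.main()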
@@ -466,6 +466,7 @@ py_test(
name = "normalization_test",
size = "medium",
srcs = ["layers/normalization_test.py"],
shard_count = 3,
srcs_version = "PY2AND3",
tags = ["notsan"],
deps = [
@@ -22,16 +22,16 @@ import numpy as np
from tensorflow.python import keras
from tensorflow.python.framework import test_util as tf_test_util
from tensorflow.python.keras import keras_parameterized
from tensorflow.python.keras import testing_utils
from tensorflow.python.keras.layers import normalization
from tensorflow.python.platform import test
from tensorflow.python.training import gradient_descent
@tf_test_util.run_all_in_graph_and_eager_modes
@tf_test_util.run_v1_only('b/120545219')
class NormalizationLayersTest(test.TestCase):
class BatchNormalizationTest(keras_parameterized.TestCase):
@keras_parameterized.run_all_keras_modes
def test_basic_batchnorm(self):
testing_utils.layer_test(
keras.layers.BatchNormalization,
@@ -56,15 +56,8 @@ class NormalizationLayersTest(test.TestCase):
kwargs={'scale': False,
'center': False},
input_shape=(3, 3))
testing_utils.layer_test(
normalization.BatchNormalizationV2,
kwargs={'fused': True},
input_shape=(3, 3, 3, 3))
testing_utils.layer_test(
normalization.BatchNormalizationV2,
kwargs={'fused': None},
input_shape=(3, 3, 3))
@tf_test_util.run_in_graph_and_eager_modes
def test_batchnorm_weights(self):
layer = keras.layers.BatchNormalization(scale=False, center=False)
layer.build((None, 3, 4))
@@ -76,6 +69,7 @@ class NormalizationLayersTest(test.TestCase):
self.assertEqual(len(layer.trainable_weights), 2)
self.assertEqual(len(layer.weights), 4)
@tf_test_util.run_in_graph_and_eager_modes
def test_batchnorm_regularization(self):
layer = keras.layers.BatchNormalization(
gamma_regularizer='l1', beta_regularizer='l1')
@@ -88,36 +82,7 @@ class NormalizationLayersTest(test.TestCase):
self.assertEqual(layer.gamma.constraint, max_norm)
self.assertEqual(layer.beta.constraint, max_norm)
def _test_batchnorm_correctness(self, dtype, use_v2=True, fused=False):
model = keras.models.Sequential()
layer_ctor = (normalization.BatchNormalizationV2 if use_v2
else normalization.BatchNormalizationV1)
norm = layer_ctor(input_shape=(2, 2, 2), momentum=0.8, fused=fused)
model.add(norm)
model.compile(loss='mse',
optimizer=gradient_descent.GradientDescentOptimizer(0.01))
# centered on 5.0, variance 10.0
x = (np.random.normal(loc=5.0, scale=10.0, size=(1000, 2, 2, 2))
.astype(dtype))
model.fit(x, x, epochs=4, verbose=0)
out = model.predict(x)
out -= keras.backend.eval(norm.beta)
out /= keras.backend.eval(norm.gamma)
np.testing.assert_allclose(out.mean(), 0.0, atol=1e-1)
np.testing.assert_allclose(out.std(), 1.0, atol=1e-1)
def test_batchnorm_correctness(self):
self._test_batchnorm_correctness(np.float32)
self._test_batchnorm_correctness(np.float32, fused=True)
self._test_batchnorm_correctness(np.float32, use_v2=False)
def test_batchnorm_mixed_precision(self):
self._test_batchnorm_correctness(np.float16)
self._test_batchnorm_correctness(np.float16, fused=True)
self._test_batchnorm_correctness(np.float16, use_v2=False)
@keras_parameterized.run_all_keras_modes
def test_batchnorm_convnet(self):
if test.is_gpu_available(cuda_only=True):
with self.session(use_gpu=True):
@@ -126,7 +91,8 @@ class NormalizationLayersTest(test.TestCase):
axis=1, input_shape=(3, 4, 4), momentum=0.8)
model.add(norm)
model.compile(loss='mse',
optimizer=gradient_descent.GradientDescentOptimizer(0.01))
optimizer=gradient_descent.GradientDescentOptimizer(0.01),
run_eagerly=testing_utils.should_run_eagerly())
# centered on 5.0, variance 10.0
x = np.random.normal(loc=5.0, scale=10.0, size=(1000, 3, 4, 4))
@@ -138,13 +104,15 @@ class NormalizationLayersTest(test.TestCase):
np.testing.assert_allclose(np.mean(out, axis=(0, 2, 3)), 0.0, atol=1e-1)
np.testing.assert_allclose(np.std(out, axis=(0, 2, 3)), 1.0, atol=1e-1)
@keras_parameterized.run_all_keras_modes
def test_batchnorm_convnet_channel_last(self):
model = keras.models.Sequential()
norm = keras.layers.BatchNormalization(
axis=-1, input_shape=(4, 4, 3), momentum=0.8)
model.add(norm)
model.compile(loss='mse',
optimizer=gradient_descent.GradientDescentOptimizer(0.01))
optimizer=gradient_descent.GradientDescentOptimizer(0.01),
run_eagerly=testing_utils.should_run_eagerly())
# centered on 5.0, variance 10.0
x = np.random.normal(loc=5.0, scale=10.0, size=(1000, 4, 4, 3))
@@ -156,6 +124,28 @@ class NormalizationLayersTest(test.TestCase):
np.testing.assert_allclose(np.mean(out, axis=(0, 1, 2)), 0.0, atol=1e-1)
np.testing.assert_allclose(np.std(out, axis=(0, 1, 2)), 1.0, atol=1e-1)
@keras_parameterized.run_all_keras_modes
def test_batchnorm_correctness(self):
_run_batchnorm_correctness_test(
normalization.BatchNormalization, dtype='float32')
_run_batchnorm_correctness_test(
normalization.BatchNormalization, dtype='float32', fused=True)
_run_batchnorm_correctness_test(
normalization.BatchNormalization, dtype='float32', fused=False)
@keras_parameterized.run_all_keras_modes
def test_batchnorm_mixed_precision(self):
_run_batchnorm_correctness_test(
normalization.BatchNormalization, dtype='float16')
_run_batchnorm_correctness_test(
normalization.BatchNormalization, dtype='float16', fused=True)
_run_batchnorm_correctness_test(
normalization.BatchNormalization, dtype='float16', fused=False)
class BatchNormalizationV1Test(test.TestCase):
@tf_test_util.run_in_graph_and_eager_modes
def test_v1_fused_attribute(self):
norm = normalization.BatchNormalizationV1()
inp = keras.layers.Input((4, 4, 4))
@@ -174,6 +164,21 @@ class NormalizationLayersTest(test.TestCase):
norm(inp)
self.assertEqual(norm.fused, False)
class BatchNormalizationV2Test(keras_parameterized.TestCase):
@keras_parameterized.run_all_keras_modes
def test_basic_batchnorm_v2(self):
testing_utils.layer_test(
normalization.BatchNormalizationV2,
kwargs={'fused': True},
input_shape=(3, 3, 3, 3))
testing_utils.layer_test(
normalization.BatchNormalizationV2,
kwargs={'fused': None},
input_shape=(3, 3, 3))
@tf_test_util.run_in_graph_and_eager_modes
def test_v2_fused_attribute(self):
norm = normalization.BatchNormalizationV2()
self.assertEqual(norm.fused, None)
@@ -228,6 +233,26 @@ class NormalizationLayersTest(test.TestCase):
norm(inp)
def _run_batchnorm_correctness_test(layer, dtype='float32', fused=False):
model = keras.models.Sequential()
norm = layer(input_shape=(2, 2, 2), momentum=0.8, fused=fused)
model.add(norm)
model.compile(loss='mse',
optimizer=gradient_descent.GradientDescentOptimizer(0.01),
run_eagerly=testing_utils.should_run_eagerly())
# centered on 5.0, variance 10.0
x = (np.random.normal(loc=5.0, scale=10.0, size=(1000, 2, 2, 2))
.astype(dtype))
model.fit(x, x, epochs=4, verbose=0)
out = model.predict(x)
out -= keras.backend.eval(norm.beta)
out /= keras.backend.eval(norm.gamma)
np.testing.assert_allclose(out.mean(), 0.0, atol=1e-1)
np.testing.assert_allclose(out.std(), 1.0, atol=1e-1)
@tf_test_util.run_v1_only('b/120545219')
class NormalizationLayersGraphModeOnlyTest(test.TestCase):
@@ -309,6 +334,8 @@ class NormalizationLayersGraphModeOnlyTest(test.TestCase):
Computes mean and std for current inputs then
applies batch normalization using them.
"""
# TODO(fchollet): enable in all execution modes when issue with
# learning phase setting is resolved.
with self.cached_session():
bn_mean = 0.5
bn_std = 10.