提交 8b8adf85 编写于 作者: S Scott Zhu 提交者: TensorFlower Gardener

Update simplernn_test to use v2 mode.

Test case that can only run in v1 has bug attached.

PiperOrigin-RevId: 225271476
上级 e9f8aff8
...@@ -22,14 +22,15 @@ import numpy as np ...@@ -22,14 +22,15 @@ import numpy as np
from tensorflow.python import keras from tensorflow.python import keras
from tensorflow.python.framework import test_util as tf_test_util from tensorflow.python.framework import test_util as tf_test_util
from tensorflow.python.keras import keras_parameterized
from tensorflow.python.keras import testing_utils from tensorflow.python.keras import testing_utils
from tensorflow.python.platform import test from tensorflow.python.platform import test
from tensorflow.python.training import gradient_descent from tensorflow.python.training import gradient_descent
from tensorflow.python.training.rmsprop import RMSPropOptimizer from tensorflow.python.training.rmsprop import RMSPropOptimizer
@tf_test_util.run_all_in_graph_and_eager_modes @keras_parameterized.run_all_keras_modes
class SimpleRNNLayerTest(test.TestCase): class SimpleRNNLayerTest(keras_parameterized.TestCase):
def test_return_sequences_SimpleRNN(self): def test_return_sequences_SimpleRNN(self):
num_samples = 2 num_samples = 2
...@@ -118,93 +119,91 @@ class SimpleRNNLayerTest(test.TestCase): ...@@ -118,93 +119,91 @@ class SimpleRNNLayerTest(test.TestCase):
l2 = layer_class.from_config(l1.get_config()) l2 = layer_class.from_config(l1.get_config())
assert l1.get_config() == l2.get_config() assert l1.get_config() == l2.get_config()
class SimpleRNNLayerGraphOnlyTest(test.TestCase):
@tf_test_util.run_v1_only('b/120545219')
def test_statefulness_SimpleRNN(self): def test_statefulness_SimpleRNN(self):
num_samples = 2 num_samples = 2
timesteps = 3 timesteps = 3
embedding_dim = 4 embedding_dim = 4
units = 2 units = 2
layer_class = keras.layers.SimpleRNN layer_class = keras.layers.SimpleRNN
with self.cached_session(): model = keras.models.Sequential()
model = keras.models.Sequential() model.add(
model.add( keras.layers.Embedding(
keras.layers.Embedding( 4,
4, embedding_dim,
embedding_dim, mask_zero=True,
mask_zero=True, input_length=timesteps,
input_length=timesteps, batch_input_shape=(num_samples, timesteps)))
batch_input_shape=(num_samples, timesteps))) layer = layer_class(
layer = layer_class( units, return_sequences=False, stateful=True, weights=None)
units, return_sequences=False, stateful=True, weights=None) model.add(layer)
model.add(layer) model.compile(optimizer=gradient_descent.GradientDescentOptimizer(0.01),
model.compile(optimizer=gradient_descent.GradientDescentOptimizer(0.01), loss='mse')
loss='mse') out1 = model.predict(np.ones((num_samples, timesteps)))
out1 = model.predict(np.ones((num_samples, timesteps))) self.assertEqual(out1.shape, (num_samples, units))
self.assertEqual(out1.shape, (num_samples, units))
# train once so that the states change
# train once so that the states change model.train_on_batch(
model.train_on_batch( np.ones((num_samples, timesteps)), np.ones((num_samples, units)))
np.ones((num_samples, timesteps)), np.ones((num_samples, units))) out2 = model.predict(np.ones((num_samples, timesteps)))
out2 = model.predict(np.ones((num_samples, timesteps)))
# if the state is not reset, output should be different
# if the state is not reset, output should be different self.assertNotEqual(out1.max(), out2.max())
self.assertNotEqual(out1.max(), out2.max())
# check that output changes after states are reset
# check that output changes after states are reset # (even though the model itself didn't change)
# (even though the model itself didn't change) layer.reset_states()
layer.reset_states() out3 = model.predict(np.ones((num_samples, timesteps)))
out3 = model.predict(np.ones((num_samples, timesteps))) self.assertNotEqual(out2.max(), out3.max())
self.assertNotEqual(out2.max(), out3.max())
# check that container-level reset_states() works
# check that container-level reset_states() works model.reset_states()
model.reset_states() out4 = model.predict(np.ones((num_samples, timesteps)))
out4 = model.predict(np.ones((num_samples, timesteps))) np.testing.assert_allclose(out3, out4, atol=1e-5)
np.testing.assert_allclose(out3, out4, atol=1e-5)
# check that the call to `predict` updated the states
# check that the call to `predict` updated the states out5 = model.predict(np.ones((num_samples, timesteps)))
out5 = model.predict(np.ones((num_samples, timesteps))) self.assertNotEqual(out4.max(), out5.max())
self.assertNotEqual(out4.max(), out5.max())
# Check masking
# Check masking layer.reset_states()
layer.reset_states()
left_padded_input = np.ones((num_samples, timesteps))
left_padded_input[0, :1] = 0
left_padded_input[1, :2] = 0
out6 = model.predict(left_padded_input)
layer.reset_states()
right_padded_input = np.ones((num_samples, timesteps))
right_padded_input[0, -1:] = 0
right_padded_input[1, -2:] = 0
out7 = model.predict(right_padded_input)
np.testing.assert_allclose(out7, out6, atol=1e-5)
left_padded_input = np.ones((num_samples, timesteps))
left_padded_input[0, :1] = 0
left_padded_input[1, :2] = 0
out6 = model.predict(left_padded_input)
layer.reset_states()
right_padded_input = np.ones((num_samples, timesteps))
right_padded_input[0, -1:] = 0
right_padded_input[1, -2:] = 0
out7 = model.predict(right_padded_input)
np.testing.assert_allclose(out7, out6, atol=1e-5)
class SimpleRNNLayerGraphOnlyTest(test.TestCase):
# b/120919032
@tf_test_util.run_deprecated_v1 @tf_test_util.run_deprecated_v1
def test_regularizers_SimpleRNN(self): def test_regularizers_SimpleRNN(self):
embedding_dim = 4 embedding_dim = 4
layer_class = keras.layers.SimpleRNN layer_class = keras.layers.SimpleRNN
with self.cached_session(): layer = layer_class(
layer = layer_class( 5,
5, return_sequences=False,
return_sequences=False, weights=None,
weights=None, input_shape=(None, embedding_dim),
input_shape=(None, embedding_dim), kernel_regularizer=keras.regularizers.l1(0.01),
kernel_regularizer=keras.regularizers.l1(0.01), recurrent_regularizer=keras.regularizers.l1(0.01),
recurrent_regularizer=keras.regularizers.l1(0.01), bias_regularizer='l2',
bias_regularizer='l2', activity_regularizer='l1')
activity_regularizer='l1') layer.build((None, None, 2))
layer.build((None, None, 2)) self.assertEqual(len(layer.losses), 3)
self.assertEqual(len(layer.losses), 3)
x = keras.backend.variable(np.ones((2, 3, 2)))
x = keras.backend.variable(np.ones((2, 3, 2))) layer(x)
layer(x) self.assertEqual(len(layer.get_losses_for(x)), 1)
self.assertEqual(len(layer.get_losses_for(x)), 1)
if __name__ == '__main__': if __name__ == '__main__':
test.main() test.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册