提交 7b4bfd90 编写于 作者: K Katherine Wu 提交者: TensorFlower Gardener

Add keras parameterization to training generator tests.

PiperOrigin-RevId: 225404979
上级 ba40882c
......@@ -29,6 +29,7 @@ from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.ops import iterator_ops
from tensorflow.python.eager import context
from tensorflow.python.framework import test_util as tf_test_util
from tensorflow.python.keras import keras_parameterized
from tensorflow.python.keras import metrics as metrics_module
from tensorflow.python.keras import testing_utils
from tensorflow.python.keras.engine import training_generator
......@@ -60,20 +61,17 @@ def custom_generator(mode=2):
yield x, y, w
@tf_test_util.run_all_in_graph_and_eager_modes
class TestGeneratorMethods(test.TestCase, parameterized.TestCase):
class TestGeneratorMethods(keras_parameterized.TestCase):
@unittest.skipIf(
os.name == 'nt',
'use_multiprocessing=True does not work on windows properly.')
@parameterized.parameters('sequential', 'functional')
def test_fit_generator_method(self, model_type):
if model_type == 'sequential':
model = testing_utils.get_small_sequential_mlp(
num_hidden=3, num_classes=4, input_dim=2)
else:
model = testing_utils.get_small_functional_mlp(
num_hidden=3, num_classes=4, input_dim=2)
# TODO(b/120940700): Bug with subclassed model inputs.
@keras_parameterized.run_with_all_model_types(exclude_models='subclass')
@keras_parameterized.run_all_keras_modes
def test_fit_generator_method(self):
model = testing_utils.get_small_mlp(
num_hidden=3, num_classes=4, input_dim=2)
model.compile(
loss='mse',
optimizer='sgd',
......@@ -109,19 +107,17 @@ class TestGeneratorMethods(test.TestCase, parameterized.TestCase):
@unittest.skipIf(
os.name == 'nt',
'use_multiprocessing=True does not work on windows properly.')
@parameterized.parameters('sequential', 'functional')
def test_evaluate_generator_method(self, model_type):
if model_type == 'sequential':
model = testing_utils.get_small_sequential_mlp(
num_hidden=3, num_classes=4, input_dim=2)
else:
model = testing_utils.get_small_functional_mlp(
num_hidden=3, num_classes=4, input_dim=2)
# TODO(b/120940700): Bug with subclassed model inputs.
@keras_parameterized.run_with_all_model_types(exclude_models='subclass')
@keras_parameterized.run_all_keras_modes
def test_evaluate_generator_method(self):
model = testing_utils.get_small_mlp(
num_hidden=3, num_classes=4, input_dim=2)
model.compile(
loss='mse',
optimizer='sgd',
metrics=['mae', metrics_module.CategoricalAccuracy()])
model.summary()
metrics=['mae', metrics_module.CategoricalAccuracy()],
run_eagerly=testing_utils.should_run_eagerly())
model.evaluate_generator(custom_generator(),
steps=5,
......@@ -142,18 +138,16 @@ class TestGeneratorMethods(test.TestCase, parameterized.TestCase):
@unittest.skipIf(
os.name == 'nt',
'use_multiprocessing=True does not work on windows properly.')
@parameterized.parameters('sequential', 'functional')
def test_predict_generator_method(self, model_type):
if model_type == 'sequential':
model = testing_utils.get_small_sequential_mlp(
num_hidden=3, num_classes=4, input_dim=2)
else:
model = testing_utils.get_small_functional_mlp(
num_hidden=3, num_classes=4, input_dim=2)
@keras_parameterized.run_with_all_model_types
@keras_parameterized.run_all_keras_modes
def test_predict_generator_method(self):
model = testing_utils.get_small_mlp(
num_hidden=3, num_classes=4, input_dim=2)
model.compile(
loss='mse',
optimizer='sgd',
metrics=['mae', metrics_module.CategoricalAccuracy()])
metrics=['mae', metrics_module.CategoricalAccuracy()],
run_eagerly=testing_utils.should_run_eagerly())
model.predict_generator(custom_generator(),
steps=5,
......@@ -183,13 +177,17 @@ class TestGeneratorMethods(test.TestCase, parameterized.TestCase):
max_queue_size=10,
workers=0)
# TODO(b/120940700): Bug with subclassed model inputs.
@keras_parameterized.run_with_all_model_types(exclude_models='subclass')
@keras_parameterized.run_all_keras_modes
def test_generator_methods_with_sample_weights(self):
model = keras.models.Sequential()
model.add(keras.layers.Dense(4, input_shape=(2,)))
model = testing_utils.get_small_mlp(
num_hidden=3, num_classes=4, input_dim=2)
model.compile(
loss='mse',
optimizer='sgd',
metrics=['mae', metrics_module.CategoricalAccuracy()])
metrics=['mae', metrics_module.CategoricalAccuracy()],
run_eagerly=testing_utils.should_run_eagerly())
model.fit_generator(custom_generator(mode=3),
steps_per_epoch=5,
......@@ -214,15 +212,19 @@ class TestGeneratorMethods(test.TestCase, parameterized.TestCase):
max_queue_size=10,
use_multiprocessing=False)
# TODO(b/120940700): Bug with subclassed model inputs.
@keras_parameterized.run_with_all_model_types(exclude_models='subclass')
@keras_parameterized.run_all_keras_modes
def test_generator_methods_invalid_use_case(self):
def invalid_generator():
while 1:
yield 0
model = keras.models.Sequential()
model.add(keras.layers.Dense(4, input_shape=(2,)))
model.compile(loss='mse', optimizer='sgd')
model = testing_utils.get_small_mlp(
num_hidden=3, num_classes=4, input_dim=2)
model.compile(loss='mse', optimizer='sgd',
run_eagerly=testing_utils.should_run_eagerly())
with self.assertRaises(ValueError):
model.fit_generator(invalid_generator(),
......@@ -251,6 +253,9 @@ class TestGeneratorMethods(test.TestCase, parameterized.TestCase):
max_queue_size=10,
use_multiprocessing=False)
# TODO(b/120940700): Bug with subclassed model inputs.
@keras_parameterized.run_with_all_model_types(exclude_models='subclass')
@keras_parameterized.run_all_keras_modes
def test_generator_input_to_fit_eval_predict(self):
val_data = np.ones([10, 10], np.float32), np.ones([10, 1], np.float32)
......@@ -258,12 +263,11 @@ class TestGeneratorMethods(test.TestCase, parameterized.TestCase):
while True:
yield np.ones([10, 10], np.float32), np.ones([10, 1], np.float32)
inputs = keras.layers.Input(shape=(10,))
x = keras.layers.Dense(10, activation='relu')(inputs)
outputs = keras.layers.Dense(1, activation='sigmoid')(x)
model = keras.Model(inputs, outputs)
model = testing_utils.get_small_mlp(
num_hidden=10, num_classes=1, input_dim=10)
model.compile(RMSPropOptimizer(0.001), 'binary_crossentropy')
model.compile(RMSPropOptimizer(0.001), 'binary_crossentropy',
run_eagerly=testing_utils.should_run_eagerly())
model.fit(
ones_generator(),
steps_per_epoch=2,
......@@ -273,9 +277,11 @@ class TestGeneratorMethods(test.TestCase, parameterized.TestCase):
model.predict(ones_generator(), steps=2)
@tf_test_util.run_all_in_graph_and_eager_modes
class TestGeneratorMethodsWithSequences(test.TestCase):
class TestGeneratorMethodsWithSequences(keras_parameterized.TestCase):
# TODO(b/120940700): Bug with subclassed model inputs.
@keras_parameterized.run_with_all_model_types(exclude_models='subclass')
@keras_parameterized.run_all_keras_modes
def test_training_with_sequences(self):
class DummySequence(keras.utils.Sequence):
......@@ -286,8 +292,8 @@ class TestGeneratorMethodsWithSequences(test.TestCase):
def __len__(self):
return 10
model = keras.models.Sequential()
model.add(keras.layers.Dense(4, input_shape=(2,)))
model = testing_utils.get_small_mlp(
num_hidden=3, num_classes=4, input_dim=2)
model.compile(loss='mse', optimizer='sgd')
model.fit_generator(DummySequence(),
......@@ -305,6 +311,9 @@ class TestGeneratorMethodsWithSequences(test.TestCase):
workers=0,
use_multiprocessing=False)
# TODO(b/120940700): Bug with subclassed model inputs.
@keras_parameterized.run_with_all_model_types(exclude_models='subclass')
@keras_parameterized.run_all_keras_modes
def test_sequence_input_to_fit_eval_predict(self):
val_data = np.ones([10, 10], np.float32), np.ones([10, 1], np.float32)
......@@ -316,10 +325,8 @@ class TestGeneratorMethodsWithSequences(test.TestCase):
def __len__(self):
return 2
inputs = keras.layers.Input(shape=(10,))
x = keras.layers.Dense(10, activation='relu')(inputs)
outputs = keras.layers.Dense(1, activation='sigmoid')(x)
model = keras.Model(inputs, outputs)
model = testing_utils.get_small_mlp(
num_hidden=10, num_classes=1, input_dim=10)
model.compile(RMSPropOptimizer(0.001), 'binary_crossentropy')
model.fit(CustomSequence(), validation_data=val_data, epochs=2)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册