diff --git a/keras/utils/test_utils.py b/keras/utils/test_utils.py index 7b06df50dc31acdd80acae8aedd38653123573a5..880b1605eccdb97b95e0cb0d5741c3b154ab0706 100644 --- a/keras/utils/test_utils.py +++ b/keras/utils/test_utils.py @@ -81,6 +81,32 @@ def layer_test(layer_cls, kwargs={}, input_shape=None, input_dtype=None, kwargs['weights'] = weights layer = layer_cls(**kwargs) + expected_output_shape = layer.compute_output_shape(input_shape) + + def _layer_in_model_test(model): + actual_output = model.predict(input_data) + actual_output_shape = actual_output.shape + for expected_dim, actual_dim in zip(expected_output_shape, + actual_output_shape): + if expected_dim is not None: + assert expected_dim == actual_dim + if expected_output is not None: + assert_allclose(actual_output, expected_output, rtol=1e-3) + + # test serialization, weight setting at model level + model_config = model.get_config() + recovered_model = model.__class__.from_config(model_config) + if model.weights: + weights = model.get_weights() + recovered_model.set_weights(weights) + _output = recovered_model.predict(input_data) + assert_allclose(_output, actual_output, rtol=1e-3) + + # test training mode (e.g. useful for dropout tests) + model.compile('rmsprop', 'mse') + model.train_on_batch(input_data, actual_output) + return actual_output + # test in functional API if fixed_batch_size: x = Input(batch_shape=input_shape, dtype=input_dtype) @@ -89,59 +115,19 @@ def layer_test(layer_cls, kwargs={}, input_shape=None, input_dtype=None, y = layer(x) assert K.dtype(y) == expected_output_dtype - # check shape inference + # check with the functional API model = Model(x, y) - expected_output_shape = layer.compute_output_shape(input_shape) - actual_output = model.predict(input_data) - actual_output_shape = actual_output.shape - for expected_dim, actual_dim in zip(expected_output_shape, - actual_output_shape): - if expected_dim is not None: - assert expected_dim == actual_dim - if expected_output is not None: - assert_allclose(actual_output, expected_output, rtol=1e-3) - - # test serialization, weight setting at model level - model_config = model.get_config() - recovered_model = Model.from_config(model_config) - if model.weights: - weights = model.get_weights() - recovered_model.set_weights(weights) - _output = recovered_model.predict(input_data) - assert_allclose(_output, actual_output, rtol=1e-3) - - # test training mode (e.g. useful for dropout tests) - model.compile('rmsprop', 'mse') - model.train_on_batch(input_data, actual_output) + _layer_in_model_test(model) # test as first layer in Sequential API layer_config = layer.get_config() layer_config['batch_input_shape'] = input_shape layer = layer.__class__.from_config(layer_config) + # check with the sequential API model = Sequential() model.add(layer) - actual_output = model.predict(input_data) - actual_output_shape = actual_output.shape - for expected_dim, actual_dim in zip(expected_output_shape, - actual_output_shape): - if expected_dim is not None: - assert expected_dim == actual_dim - if expected_output is not None: - assert_allclose(actual_output, expected_output, rtol=1e-3) - - # test serialization, weight setting at model level - model_config = model.get_config() - recovered_model = Sequential.from_config(model_config) - if model.weights: - weights = model.get_weights() - recovered_model.set_weights(weights) - _output = recovered_model.predict(input_data) - assert_allclose(_output, actual_output, rtol=1e-3) - - # test training mode (e.g. useful for dropout tests) - model.compile('rmsprop', 'mse') - model.train_on_batch(input_data, actual_output) + actual_output = _layer_in_model_test(model) # for further checks in the caller function return actual_output