Commit 5103fd81 authored by Gabriel de Marmiesse, committed by François Chollet

Refactoring the `layer_test` function. (#10660)

Parent 8fe86302
@@ -81,6 +81,32 @@ def layer_test(layer_cls, kwargs={}, input_shape=None, input_dtype=None,
         kwargs['weights'] = weights
         layer = layer_cls(**kwargs)
 
+    expected_output_shape = layer.compute_output_shape(input_shape)
+
+    def _layer_in_model_test(model):
+        actual_output = model.predict(input_data)
+        actual_output_shape = actual_output.shape
+        for expected_dim, actual_dim in zip(expected_output_shape,
+                                            actual_output_shape):
+            if expected_dim is not None:
+                assert expected_dim == actual_dim
+        if expected_output is not None:
+            assert_allclose(actual_output, expected_output, rtol=1e-3)
+
+        # test serialization, weight setting at model level
+        model_config = model.get_config()
+        recovered_model = model.__class__.from_config(model_config)
+        if model.weights:
+            weights = model.get_weights()
+            recovered_model.set_weights(weights)
+            _output = recovered_model.predict(input_data)
+            assert_allclose(_output, actual_output, rtol=1e-3)
+
+        # test training mode (e.g. useful for dropout tests)
+        model.compile('rmsprop', 'mse')
+        model.train_on_batch(input_data, actual_output)
+        return actual_output
+
     # test in functional API
     if fixed_batch_size:
         x = Input(batch_shape=input_shape, dtype=input_dtype)
@@ -89,59 +115,19 @@ def layer_test(layer_cls, kwargs={}, input_shape=None, input_dtype=None,
     y = layer(x)
     assert K.dtype(y) == expected_output_dtype
 
-    # check shape inference
+    # check with the functional API
     model = Model(x, y)
-    expected_output_shape = layer.compute_output_shape(input_shape)
-    actual_output = model.predict(input_data)
-    actual_output_shape = actual_output.shape
-    for expected_dim, actual_dim in zip(expected_output_shape,
-                                        actual_output_shape):
-        if expected_dim is not None:
-            assert expected_dim == actual_dim
-    if expected_output is not None:
-        assert_allclose(actual_output, expected_output, rtol=1e-3)
-
-    # test serialization, weight setting at model level
-    model_config = model.get_config()
-    recovered_model = Model.from_config(model_config)
-    if model.weights:
-        weights = model.get_weights()
-        recovered_model.set_weights(weights)
-        _output = recovered_model.predict(input_data)
-        assert_allclose(_output, actual_output, rtol=1e-3)
-
-    # test training mode (e.g. useful for dropout tests)
-    model.compile('rmsprop', 'mse')
-    model.train_on_batch(input_data, actual_output)
+    _layer_in_model_test(model)
 
     # test as first layer in Sequential API
     layer_config = layer.get_config()
    layer_config['batch_input_shape'] = input_shape
     layer = layer.__class__.from_config(layer_config)
 
+    # check with the sequential API
     model = Sequential()
     model.add(layer)
-    actual_output = model.predict(input_data)
-    actual_output_shape = actual_output.shape
-    for expected_dim, actual_dim in zip(expected_output_shape,
-                                        actual_output_shape):
-        if expected_dim is not None:
-            assert expected_dim == actual_dim
-    if expected_output is not None:
-        assert_allclose(actual_output, expected_output, rtol=1e-3)
-
-    # test serialization, weight setting at model level
-    model_config = model.get_config()
-    recovered_model = Sequential.from_config(model_config)
-    if model.weights:
-        weights = model.get_weights()
-        recovered_model.set_weights(weights)
-        _output = recovered_model.predict(input_data)
-        assert_allclose(_output, actual_output, rtol=1e-3)
-
-    # test training mode (e.g. useful for dropout tests)
-    model.compile('rmsprop', 'mse')
-    model.train_on_batch(input_data, actual_output)
+    actual_output = _layer_in_model_test(model)
 
     # for further checks in the caller function
     return actual_output
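
For context, below is a minimal usage sketch of the refactored helper (illustrative only, not part of this commit). It assumes `layer_test` is imported from `keras.utils.test_utils`, as in Keras 2.x, and uses `Dense` and `Activation` purely as example layers; both calls exercise the functional and Sequential code paths that now share `_layer_in_model_test`.

# Minimal usage sketch (assumed import path: keras.utils.test_utils).
import numpy as np

from keras import layers
from keras.utils.test_utils import layer_test

# Shape inference, serialization, weight setting and a training step are all
# exercised through both the functional and Sequential paths.
layer_test(layers.Dense,
           kwargs={'units': 3},
           input_shape=(2, 4))

# Passing explicit input data and an expected output additionally checks the
# numerical result (compared with assert_allclose, rtol=1e-3).
data = np.random.random((2, 4)).astype('float32')
layer_test(layers.Activation,
           kwargs={'activation': 'relu'},
           input_data=data,
           expected_output=np.maximum(data, 0.))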