提交 69c30a15 编写于 作者: Y Yu-Yang Huang 提交者: François Chollet

Fix nested sequential deferred build (#10655)

* Fix nested sequential deferred build

* Add test

* Unify style
上级 3dcd9c76
......@@ -149,8 +149,6 @@ class Sequential(Model):
first_layer = layer.layers[0]
while isinstance(first_layer, (Model, Sequential)):
first_layer = first_layer.layers[0]
batch_shape = first_layer.batch_input_shape
dtype = first_layer.dtype
if hasattr(first_layer, 'batch_input_shape'):
batch_shape = first_layer.batch_input_shape
......
......@@ -422,5 +422,45 @@ def test_sequential_deferred_build():
assert len(new_model.weights) == 4
@keras_test
def test_nested_sequential_deferred_build():
inner_model = keras.models.Sequential()
inner_model.add(keras.layers.Dense(3))
inner_model.add(keras.layers.Dense(3))
model = keras.models.Sequential()
model.add(inner_model)
model.add(keras.layers.Dense(5))
model.compile('sgd', 'mse')
assert inner_model.built is False
assert len(inner_model.layers) == 2
assert len(inner_model.weights) == 0
assert model.built is False
assert len(model.layers) == 2
assert len(model.weights) == 0
model.train_on_batch(
np.random.random((2, 4)), np.random.random((2, 5)))
assert inner_model.built is True
assert len(inner_model.layers) == 2
assert len(inner_model.weights) == 4
assert model.built is True
assert len(model.layers) == 2
assert len(model.weights) == 6
config = model.get_config()
new_model = keras.models.Sequential.from_config(config)
assert new_model.built is True
assert len(new_model.layers) == 2
assert len(new_model.weights) == 6
new_inner_model = new_model.layers[0]
assert new_inner_model.built is True
assert len(new_inner_model.layers) == 2
assert len(new_inner_model.weights) == 4
if __name__ == '__main__':
pytest.main([__file__])
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册