diff --git a/keras/layers/containers.py b/keras/layers/containers.py index bde77d98d634334367c39ba5a46f3641d403a206..7e1416585ea00c82abab969b86c2a725775a0753 100644 --- a/keras/layers/containers.py +++ b/keras/layers/containers.py @@ -313,7 +313,7 @@ class Graph(Layer): if dtype == 'float': layer.input = K.placeholder(shape=layer.input_shape, name=name) else: - if len(input_shape) == 1: + if (input_shape and len(input_shape) == 1) or (batch_input_shape and len(batch_input_shape) == 2): layer.input = K.placeholder(shape=layer.input_shape, dtype='int32', name=name)