diff --git a/python/paddle/fluid/layers/rnn.py b/python/paddle/fluid/layers/rnn.py index c99578bb348ce12cd0f8fb390738dede4bc54c4f..cd1499d4e47e4a7c336200d04d4bf54135f350d7 100644 --- a/python/paddle/fluid/layers/rnn.py +++ b/python/paddle/fluid/layers/rnn.py @@ -494,8 +494,14 @@ def rnn(cell, if isinstance(initial_states, (list, tuple)): states = map_structure(lambda x: x, initial_states)[0] for i, state in enumerate(states): - check_variable_and_dtype(state, 'states[' + str(i) + ']', - ['float32', 'float64'], 'rnn') + if isinstance(state, (list, tuple)): + for j, state_j in enumerate(state): + check_variable_and_dtype(state_j, 'state_j[' + str(j) + ']', + ['float32', 'float64'], 'rnn') + else: + check_variable_and_dtype(state, 'states[' + str(i) + ']', + ['float32', 'float64'], 'rnn') + check_type(sequence_length, 'sequence_length', (Variable, type(None)), 'rnn')