未验证 提交 4af3ec0f 编写于 作者: X Xing Wu 提交者: GitHub

fix rnn check_type list error (#24346)

* fix rnn check_type list error

* tigger ci, test=develop

* update modify, test=develop
上级 63da846d
......@@ -494,8 +494,14 @@ def rnn(cell,
if isinstance(initial_states, (list, tuple)):
states = map_structure(lambda x: x, initial_states)[0]
for i, state in enumerate(states):
check_variable_and_dtype(state, 'states[' + str(i) + ']',
['float32', 'float64'], 'rnn')
if isinstance(state, (list, tuple)):
for j, state_j in enumerate(state):
check_variable_and_dtype(state_j, 'state_j[' + str(j) + ']',
['float32', 'float64'], 'rnn')
else:
check_variable_and_dtype(state, 'states[' + str(i) + ']',
['float32', 'float64'], 'rnn')
check_type(sequence_length, 'sequence_length', (Variable, type(None)),
'rnn')
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册