提交 79066ac6 编写于 作者: G guosheng

Fix sequence_length when None for RNN

上级 0a326f39
......@@ -818,7 +818,7 @@ class RNN(fluid.dygraph.Layer):
lambda x: fluid.layers.transpose(x, [1, 0] + list(
range(2, len(x.shape)))), inputs)
if sequence_length:
if sequence_length is not None:
mask = fluid.layers.sequence_mask(
sequence_length,
maxlen=time_steps,
......@@ -829,7 +829,7 @@ class RNN(fluid.dygraph.Layer):
inputs = map_structure(
lambda x: fluid.layers.reverse(x, axis=[0]), inputs)
mask = fluid.layers.reverse(
mask, axis=[0]) if sequence_length else None
mask, axis=[0]) if sequence_length is not None else None
states = initial_states
outputs = []
......@@ -837,7 +837,7 @@ class RNN(fluid.dygraph.Layer):
step_inputs = map_structure(lambda x: x[i], inputs)
step_outputs, new_states = self.cell(step_inputs, states,
**kwargs)
if sequence_length:
if sequence_length is not None:
new_states = map_structure(
partial(
_maybe_copy, step_mask=mask[i]),
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册