提交 79066ac6 编写于 作者: G guosheng

Fix sequence_length when None for RNN

上级 0a326f39
...@@ -818,7 +818,7 @@ class RNN(fluid.dygraph.Layer): ...@@ -818,7 +818,7 @@ class RNN(fluid.dygraph.Layer):
lambda x: fluid.layers.transpose(x, [1, 0] + list( lambda x: fluid.layers.transpose(x, [1, 0] + list(
range(2, len(x.shape)))), inputs) range(2, len(x.shape)))), inputs)
if sequence_length: if sequence_length is not None:
mask = fluid.layers.sequence_mask( mask = fluid.layers.sequence_mask(
sequence_length, sequence_length,
maxlen=time_steps, maxlen=time_steps,
...@@ -829,7 +829,7 @@ class RNN(fluid.dygraph.Layer): ...@@ -829,7 +829,7 @@ class RNN(fluid.dygraph.Layer):
inputs = map_structure( inputs = map_structure(
lambda x: fluid.layers.reverse(x, axis=[0]), inputs) lambda x: fluid.layers.reverse(x, axis=[0]), inputs)
mask = fluid.layers.reverse( mask = fluid.layers.reverse(
mask, axis=[0]) if sequence_length else None mask, axis=[0]) if sequence_length is not None else None
states = initial_states states = initial_states
outputs = [] outputs = []
...@@ -837,7 +837,7 @@ class RNN(fluid.dygraph.Layer): ...@@ -837,7 +837,7 @@ class RNN(fluid.dygraph.Layer):
step_inputs = map_structure(lambda x: x[i], inputs) step_inputs = map_structure(lambda x: x[i], inputs)
step_outputs, new_states = self.cell(step_inputs, states, step_outputs, new_states = self.cell(step_inputs, states,
**kwargs) **kwargs)
if sequence_length: if sequence_length is not None:
new_states = map_structure( new_states = map_structure(
partial( partial(
_maybe_copy, step_mask=mask[i]), _maybe_copy, step_mask=mask[i]),
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册