From 79066ac6ca9cab6668bbe27ae43dc1d90bc0a472 Mon Sep 17 00:00:00 2001 From: guosheng Date: Tue, 21 Apr 2020 23:02:47 +0800 Subject: [PATCH] Fix sequence_length when None for RNN --- hapi/text/text.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/hapi/text/text.py b/hapi/text/text.py index 319800d..ee74c51 100644 --- a/hapi/text/text.py +++ b/hapi/text/text.py @@ -818,7 +818,7 @@ class RNN(fluid.dygraph.Layer): lambda x: fluid.layers.transpose(x, [1, 0] + list( range(2, len(x.shape)))), inputs) - if sequence_length: + if sequence_length is not None: mask = fluid.layers.sequence_mask( sequence_length, maxlen=time_steps, @@ -829,7 +829,7 @@ class RNN(fluid.dygraph.Layer): inputs = map_structure( lambda x: fluid.layers.reverse(x, axis=[0]), inputs) mask = fluid.layers.reverse( - mask, axis=[0]) if sequence_length else None + mask, axis=[0]) if sequence_length is not None else None states = initial_states outputs = [] @@ -837,7 +837,7 @@ class RNN(fluid.dygraph.Layer): step_inputs = map_structure(lambda x: x[i], inputs) step_outputs, new_states = self.cell(step_inputs, states, **kwargs) - if sequence_length: + if sequence_length is not None: new_states = map_structure( partial( _maybe_copy, step_mask=mask[i]), -- GitLab