diff --git a/hapi/text/text.py b/hapi/text/text.py index 319800d46597f1ad7cf6806843184534e7626807..ee74c516437a366e1dd91cde236346ecf2e1b787 100644 --- a/hapi/text/text.py +++ b/hapi/text/text.py @@ -818,7 +818,7 @@ class RNN(fluid.dygraph.Layer): lambda x: fluid.layers.transpose(x, [1, 0] + list( range(2, len(x.shape)))), inputs) - if sequence_length: + if sequence_length is not None: mask = fluid.layers.sequence_mask( sequence_length, maxlen=time_steps, @@ -829,7 +829,7 @@ class RNN(fluid.dygraph.Layer): inputs = map_structure( lambda x: fluid.layers.reverse(x, axis=[0]), inputs) mask = fluid.layers.reverse( - mask, axis=[0]) if sequence_length else None + mask, axis=[0]) if sequence_length is not None else None states = initial_states outputs = [] @@ -837,7 +837,7 @@ class RNN(fluid.dygraph.Layer): step_inputs = map_structure(lambda x: x[i], inputs) step_outputs, new_states = self.cell(step_inputs, states, **kwargs) - if sequence_length: + if sequence_length is not None: new_states = map_structure( partial( _maybe_copy, step_mask=mask[i]),