diff --git a/text.py b/text.py index 2702981d6e274ab20e250b7499b65632ffd7a3ba..3b6cac2fea9539b5836d0ada1184ab50e2424e1e 100644 --- a/text.py +++ b/text.py @@ -521,7 +521,15 @@ class DynamicDecode(Layer): (step_outputs, next_states, next_inputs, next_finished) = self.decoder.step(step_idx_tensor, inputs, states, **kwargs) - next_finished = layers.logical_or(next_finished, finished) + if not self.decoder.tracks_own_finished: + # BeamSearchDecoder would track it own finished, since + # beams would be reordered and the finished status of each + # entry might change. Otherwise, perform logical OR which + # would not change the already finished. + next_finished = layers.logical_or(next_finished, finished) + # To confirm states.finished/finished be consistent with + # next_finished. + layers.assign(next_finished, finished) next_sequence_lengths = layers.elementwise_add( sequence_lengths, layers.cast(