diff --git a/python/paddle/fluid/layers/rnn.py b/python/paddle/fluid/layers/rnn.py index 4c048355d2f73a8d627c48dedff149ab56cd029a..762e79e648b137e1589b9f398bbc062ced38b8da 100644 --- a/python/paddle/fluid/layers/rnn.py +++ b/python/paddle/fluid/layers/rnn.py @@ -601,6 +601,28 @@ class Decoder(object): """ raise NotImplementedError + @property + def tracks_own_finished(self): + """ + Describes whether the Decoder keeps track of finished states by itself. + + `decoder.step()` would emit a bool `finished` value at each decoding + step. The emited `finished` can be used to determine whether every + batch entries is finished directly, or it can be combined with the + finished tracker keeped in `dynamic_decode` by performing a logical OR + to take the already finished into account. + + If `False`, the latter would be took when performing `dynamic_decode`, + which is the default. Otherwise, the former would be took, which uses + the finished value emited by the decoder as all batch entry finished + status directly, and it is the case when batch entries might be + reordered such as beams in BeamSearchDecoder. + + Returns: + bool: A python bool `False`. + """ + return False + class BeamSearchDecoder(Decoder): """ @@ -1048,6 +1070,19 @@ class BeamSearchDecoder(Decoder): # TODO: use FinalBeamSearchDecoderOutput as output return predicted_ids, final_states + @property + def tracks_own_finished(self): + """ + BeamSearchDecoder reorders its beams and their finished state. Thus it + conflicts with `dynamic_decode` function's tracking of finished states. + Setting this property to true to avoid early stopping of decoding due + to mismanagement of the finished state. + + Returns: + bool: A python bool `True`. + """ + return True + def dynamic_decode(decoder, inits=None, @@ -1205,7 +1240,13 @@ def dynamic_decode(decoder, states_arrays) (outputs, next_states, next_inputs, next_finished) = decoder.step(step_idx, inputs, states, **kwargs) - next_finished = control_flow.logical_or(next_finished, global_finished) + if not decoder.tracks_own_finished: + # BeamSearchDecoder would track it own finished, since beams would + # be reordered and the finished status of each entry might change. + # Otherwise, perform logical OR which would not change the already + # finished. + next_finished = control_flow.logical_or(next_finished, + global_finished) next_sequence_lengths = nn.elementwise_add( sequence_lengths, tensor.cast( @@ -1226,6 +1267,10 @@ def dynamic_decode(decoder, lambda x, x_array: control_flow.array_write( x, i=step_idx, array=x_array), outputs, outputs_arrays) control_flow.increment(x=step_idx, value=1.0, in_place=True) + # update the global_finished first, since it might be also in states of + # decoder, which otherwise would write a stale finished status to array + tensor.assign(next_finished, global_finished) + tensor.assign(next_sequence_lengths, sequence_lengths) if is_test: map_structure(tensor.assign, next_inputs, global_inputs) map_structure(tensor.assign, next_states, global_states) @@ -1236,8 +1281,6 @@ def dynamic_decode(decoder, map_structure( lambda x, x_array: control_flow.array_write( x, i=step_idx, array=x_array), next_states, states_arrays) - tensor.assign(next_finished, global_finished) - tensor.assign(next_sequence_lengths, sequence_lengths) if max_step_num is not None: control_flow.logical_and( control_flow.logical_not(nn.reduce_all(global_finished)),