未验证 提交 54a47cd2 编写于 作者: G Guo Sheng 提交者: GitHub

Add tracks_own_finished to Decoder to avoid mismanagement of the finished...

Add tracks_own_finished to Decoder to avoid mismanagement of the finished state in dynamic_decode. (#23664)

test=develop
上级 614eb942
......@@ -601,6 +601,28 @@ class Decoder(object):
"""
raise NotImplementedError
@property
def tracks_own_finished(self):
"""
Describes whether the Decoder keeps track of finished states by itself.
`decoder.step()` would emit a bool `finished` value at each decoding
step. The emited `finished` can be used to determine whether every
batch entries is finished directly, or it can be combined with the
finished tracker keeped in `dynamic_decode` by performing a logical OR
to take the already finished into account.
If `False`, the latter would be took when performing `dynamic_decode`,
which is the default. Otherwise, the former would be took, which uses
the finished value emited by the decoder as all batch entry finished
status directly, and it is the case when batch entries might be
reordered such as beams in BeamSearchDecoder.
Returns:
bool: A python bool `False`.
"""
return False
class BeamSearchDecoder(Decoder):
"""
......@@ -1048,6 +1070,19 @@ class BeamSearchDecoder(Decoder):
# TODO: use FinalBeamSearchDecoderOutput as output
return predicted_ids, final_states
@property
def tracks_own_finished(self):
"""
BeamSearchDecoder reorders its beams and their finished state. Thus it
conflicts with `dynamic_decode` function's tracking of finished states.
Setting this property to true to avoid early stopping of decoding due
to mismanagement of the finished state.
Returns:
bool: A python bool `True`.
"""
return True
def dynamic_decode(decoder,
inits=None,
......@@ -1205,7 +1240,13 @@ def dynamic_decode(decoder,
states_arrays)
(outputs, next_states, next_inputs,
next_finished) = decoder.step(step_idx, inputs, states, **kwargs)
next_finished = control_flow.logical_or(next_finished, global_finished)
if not decoder.tracks_own_finished:
# BeamSearchDecoder would track it own finished, since beams would
# be reordered and the finished status of each entry might change.
# Otherwise, perform logical OR which would not change the already
# finished.
next_finished = control_flow.logical_or(next_finished,
global_finished)
next_sequence_lengths = nn.elementwise_add(
sequence_lengths,
tensor.cast(
......@@ -1226,6 +1267,10 @@ def dynamic_decode(decoder,
lambda x, x_array: control_flow.array_write(
x, i=step_idx, array=x_array), outputs, outputs_arrays)
control_flow.increment(x=step_idx, value=1.0, in_place=True)
# update the global_finished first, since it might be also in states of
# decoder, which otherwise would write a stale finished status to array
tensor.assign(next_finished, global_finished)
tensor.assign(next_sequence_lengths, sequence_lengths)
if is_test:
map_structure(tensor.assign, next_inputs, global_inputs)
map_structure(tensor.assign, next_states, global_states)
......@@ -1236,8 +1281,6 @@ def dynamic_decode(decoder,
map_structure(
lambda x, x_array: control_flow.array_write(
x, i=step_idx, array=x_array), next_states, states_arrays)
tensor.assign(next_finished, global_finished)
tensor.assign(next_sequence_lengths, sequence_lengths)
if max_step_num is not None:
control_flow.logical_and(
control_flow.logical_not(nn.reduce_all(global_finished)),
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册