提交 913640d4 编写于 作者: H Hongkun Yu 提交者: A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 285765110
上级 722d9e57
......@@ -323,13 +323,16 @@ class SequenceBeamSearch(object):
new state dictionary.
"""
# Grow alive sequences by one token.
new_seq, new_log_probs, new_cache = self._grow_alive_seq(state)
new_seq, new_log_probs, topk_ids, new_cache = self._grow_alive_seq(state)
new_finished_flags = tf.equal(topk_ids, self.eos_id)
# Collect top beam_size alive sequences
alive_state = self._get_new_alive_state(new_seq, new_log_probs, new_cache)
alive_state = self._get_new_alive_state(new_seq, new_log_probs,
new_finished_flags, new_cache)
# Combine newly finished sequences with existing finished sequences, and
# collect the top k scoring sequences.
finished_state = self._get_new_finished_state(state, new_seq, new_log_probs)
finished_state = self._get_new_finished_state(state, new_seq, new_log_probs,
new_finished_flags)
# Increment loop index and create new state dictionary
new_state = {_StateKeys.CUR_INDEX: state[_StateKeys.CUR_INDEX] + 1}
......@@ -407,18 +410,20 @@ class SequenceBeamSearch(object):
tf.expand_dims(topk_ids, axis=0))
topk_seq = tf.transpose(topk_seq, perm=[1, 2, 0])
else:
topk_ids = tf.expand_dims(topk_ids, axis=2)
topk_seq = tf.concat([topk_seq, topk_ids], axis=2)
return topk_seq, topk_log_probs, new_cache
topk_seq = tf.concat([topk_seq, tf.expand_dims(topk_ids, axis=2)], axis=2)
return topk_seq, topk_log_probs, topk_ids, new_cache
def _get_new_alive_state(self, new_seq, new_log_probs, new_cache):
def _get_new_alive_state(self, new_seq, new_log_probs, new_finished_flags,
new_cache):
"""Gather the top k sequences that are still alive.
Args:
new_seq: New sequences generated by growing the current alive sequences
int32 tensor with shape [batch_size, 2 * beam_size, cur_index + 1]
new_log_probs: Log probabilities of new sequences
float32 tensor with shape [batch_size, beam_size]
new_log_probs: Log probabilities of new sequences float32 tensor with
shape [batch_size, beam_size]
new_finished_flags: A boolean Tensor indicates which sequences are live
inside the beam.
new_cache: Dict of cached values for each sequence.
Returns:
......@@ -428,7 +433,6 @@ class SequenceBeamSearch(object):
Dict cache storing decoder states for top alive sequences}
"""
# To prevent finished sequences from being considered, set log probs to -inf
new_finished_flags = tf.equal(new_seq[:, :, -1], self.eos_id)
new_log_probs += tf.cast(new_finished_flags, self.dtype) * -inf(self.dtype)
top_alive_seq, top_alive_log_probs, top_alive_cache = _gather_topk_beams(
......@@ -441,15 +445,18 @@ class SequenceBeamSearch(object):
_StateKeys.ALIVE_CACHE: top_alive_cache
}
def _get_new_finished_state(self, state, new_seq, new_log_probs):
def _get_new_finished_state(self, state, new_seq, new_log_probs,
new_finished_flags):
"""Combine new and old finished sequences, and gather the top k sequences.
Args:
state: A dictionary with the current loop state.
new_seq: New sequences generated by growing the current alive sequences
int32 tensor with shape [batch_size, beam_size, i + 1]
new_log_probs: Log probabilities of new sequences
float32 tensor with shape [batch_size, beam_size]
new_log_probs: Log probabilities of new sequences float32 tensor with
shape [batch_size, beam_size]
new_finished_flags: A boolean Tensor indicates which sequences are live
inside the beam.
Returns:
Dictionary with finished keys from _StateKeys:
......@@ -476,7 +483,6 @@ class SequenceBeamSearch(object):
new_scores = new_log_probs / length_norm
# Set the scores of the still-alive seq in new_seq to large negative values.
new_finished_flags = tf.equal(new_seq[:, :, -1], self.eos_id)
new_scores += ((1. - tf.cast(new_finished_flags, self.dtype)) *
-inf(self.dtype))
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册