Commit 43aad7a0 authored by J Junkun

beam search with optimality guarantees

Parent 26524031
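
The beam search below keeps a list of hypothesis dicts instead of batched score tensors, and it stops early only when optimality is preserved: a hypothesis that ends with <eos> is credited word_reward for every emitted token, and an unfinished hypothesis is kept only while its score plus the largest reward it could still collect (bounded by maxlen * word_reward, since future log-probabilities can only lower the score) might still beat the best finished score. A minimal plain-Python sketch of that rule, with illustrative names (beam_search_with_bound, step_logprobs) that are not part of the repository:

    import math

    def beam_search_with_bound(step_logprobs, sos, eos, beam_size, word_reward, maxlen):
        """step_logprobs(prefix) -> {token: log_prob}; a stand-in for the real decoder."""
        hyps = [{"score": 0.0, "yseq": [sos]}]
        ended, best_ended = [], -math.inf
        for i in range(1, maxlen + 1):
            expanded = []
            for hyp in hyps:
                logp = step_logprobs(hyp["yseq"])
                # first prune: top-k tokens for this hypothesis
                for tok, lp in sorted(logp.items(), key=lambda kv: -kv[1])[:beam_size]:
                    expanded.append({"score": hyp["score"] + lp,
                                     "yseq": hyp["yseq"] + [tok]})
            # second prune: keep the best beam_size hypotheses overall
            hyps = sorted(expanded, key=lambda h: -h["score"])[:beam_size]
            if i == maxlen:  # force-terminate everything at the length bound
                for hyp in hyps:
                    hyp["yseq"].append(eos)
            remained = []
            for hyp in hyps:
                if hyp["yseq"][-1] == eos:
                    hyp["score"] += (i - 1) * word_reward  # reward the finished length
                    best_ended = max(best_ended, hyp["score"])
                    ended.append(hyp)
                elif hyp["score"] + maxlen * word_reward > best_ended:
                    remained.append(hyp)  # could still overtake the best finished hypothesis
            if not remained:  # no live hypothesis can win any more: stop early
                break
            hyps = remained
        return max(ended, key=lambda h: h["score"])

The pruning test is the same condition the new code uses (hyp["score"] + maxlen * word_reward > cur_best_score), so no hypothesis that could still become the best finished one is ever dropped, and the loop exits as soon as none remain.
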
......@@ -285,7 +285,7 @@ class U2STTrainer(Trainer):
subsampling_factor=1,
load_aux_output=load_transcript,
num_encs=1,
- dist_sampler=True)
+ dist_sampler=False)
logger.info("Setup train/valid Dataloader!")
else:
# test dataset, return raw text
......@@ -408,6 +408,7 @@ class U2STTester(U2STTrainer):
decoding_method=decode_cfg.decoding_method,
beam_size=decode_cfg.beam_size,
word_reward=decode_cfg.word_reward,
+ maxlen_ratio=decode_cfg.maxlen_ratio,
decoding_chunk_size=decode_cfg.decoding_chunk_size,
num_decoding_left_chunks=decode_cfg.num_decoding_left_chunks,
simulate_streaming=decode_cfg.simulate_streaming)
......@@ -435,6 +436,7 @@ class U2STTester(U2STTrainer):
decoding_method=decode_cfg.decoding_method,
beam_size=decode_cfg.beam_size,
word_reward=decode_cfg.word_reward,
+ maxlen_ratio=decode_cfg.maxlen_ratio,
decoding_chunk_size=decode_cfg.decoding_chunk_size,
num_decoding_left_chunks=decode_cfg.num_decoding_left_chunks,
simulate_streaming=decode_cfg.simulate_streaming)
......
......@@ -264,14 +264,17 @@ class U2STBaseModel(nn.Layer):
speech_lengths: paddle.Tensor,
beam_size: int=10,
word_reward: float=0.0,
+ maxlen_ratio: float=0.5,
decoding_chunk_size: int=-1,
num_decoding_left_chunks: int=-1,
simulate_streaming: bool=False, ) -> paddle.Tensor:
""" Apply beam search on attention decoder
""" Apply beam search on attention decoder with length penalty
Args:
speech (paddle.Tensor): (batch, max_len, feat_dim)
speech_length (paddle.Tensor): (batch, )
beam_size (int): beam size for beam search
word_reward (float): word reward used in beam search
+ maxlen_ratio (float): max length ratio to bound the length of translated text
decoding_chunk_size (int): decoding chunk for dynamic chunk
trained model.
<0: for decoding, use full chunk.
......@@ -284,90 +287,84 @@ class U2STBaseModel(nn.Layer):
"""
assert speech.shape[0] == speech_lengths.shape[0]
assert decoding_chunk_size != 0
+ assert speech.shape[0] == 1
device = speech.place
batch_size = speech.shape[0]
# Let's assume B = batch_size and N = beam_size
- # 1. Encoder
+ # 1. Encoder and init hypothesis
encoder_out, encoder_mask = self._forward_encoder(
speech, speech_lengths, decoding_chunk_size,
num_decoding_left_chunks,
simulate_streaming) # (B, maxlen, encoder_dim)
- maxlen = encoder_out.shape[1]
- encoder_dim = encoder_out.shape[2]
- running_size = batch_size * beam_size
- encoder_out = encoder_out.unsqueeze(1).repeat(1, beam_size, 1, 1).view(
-     running_size, maxlen, encoder_dim)  # (B*N, maxlen, encoder_dim)
- encoder_mask = encoder_mask.unsqueeze(1).repeat(
-     1, beam_size, 1, 1).view(running_size, 1,
-                              maxlen)  # (B*N, 1, max_len)
- hyps = paddle.ones(
-     [running_size, 1], dtype=paddle.long).fill_(self.sos)  # (B*N, 1)
- # log scale score
- scores = paddle.to_tensor(
-     [0.0] + [-float('inf')] * (beam_size - 1), dtype=paddle.float)
- scores = scores.to(device).repeat(batch_size).unsqueeze(1).to(
-     device)  # (B*N, 1)
- end_flag = paddle.zeros_like(scores, dtype=paddle.bool)  # (B*N, 1)
- cache: Optional[List[paddle.Tensor]] = None
+ maxlen = max(int(encoder_out.shape[1] * maxlen_ratio), 5)
+ hyp = {"score": 0.0, "yseq": [self.sos], "cache": None}
+ hyps = [hyp]
+ ended_hyps = []
+ cur_best_score = -float("inf")
+ cache = None
# 2. Decoder forward step by step
for i in range(1, maxlen + 1):
- # Stop if all batch and all beam produce eos
- # TODO(Hui Zhang): if end_flag.sum() == running_size:
- if end_flag.cast(paddle.int64).sum() == running_size:
-     break
- # 2.1 Forward decoder step
- hyps_mask = subsequent_mask(i).unsqueeze(0).repeat(
-     running_size, 1, 1).to(device)  # (B*N, i, i)
- # logp: (B*N, vocab)
+ ys = paddle.ones((len(hyps), i), dtype=paddle.long)
+ if hyps[0]["cache"] is not None:
+     cache = [paddle.ones((len(hyps), i-1, hyps[0]["cache"][0].shape[-1]), dtype=paddle.float32) for _ in range(len(hyps[0]["cache"]))]
+ for j, hyp in enumerate(hyps):
+     ys[j, :] = paddle.to_tensor(hyp["yseq"])
+     if hyps[0]["cache"] is not None:
+         for k in range(len(cache)):
+             cache[k][j] = hyps[j]["cache"][k]
+ ys_mask = subsequent_mask(i).unsqueeze(0).to(device)
logp, cache = self.st_decoder.forward_one_step(
-     encoder_out, encoder_mask, hyps, hyps_mask, cache)
- # 2.2 First beam prune: select topk best prob at current time
- top_k_logp, top_k_index = logp.topk(beam_size)  # (B*N, N)
- top_k_logp += word_reward
- top_k_logp = mask_finished_scores(top_k_logp, end_flag)
- top_k_index = mask_finished_preds(top_k_index, end_flag, self.eos)
- # 2.3 Seconde beam prune: select topk score with history
- scores = scores + top_k_logp  # (B*N, N), broadcast add
- scores = scores.view(batch_size, beam_size * beam_size)  # (B, N*N)
- scores, offset_k_index = scores.topk(k=beam_size)  # (B, N)
- scores = scores.view(-1, 1)  # (B*N, 1)
- # 2.4. Compute base index in top_k_index,
- # regard top_k_index as (B*N*N),regard offset_k_index as (B*N),
- # then find offset_k_index in top_k_index
- base_k_index = paddle.arange(batch_size).view(-1, 1).repeat(
-     1, beam_size)  # (B, N)
- base_k_index = base_k_index * beam_size * beam_size
- best_k_index = base_k_index.view(-1) + offset_k_index.view(
-     -1)  # (B*N)
- # 2.5 Update best hyps
- best_k_pred = paddle.index_select(
-     top_k_index.view(-1), index=best_k_index, axis=0)  # (B*N)
- best_hyps_index = best_k_index // beam_size
- last_best_k_hyps = paddle.index_select(
-     hyps, index=best_hyps_index, axis=0)  # (B*N, i)
- hyps = paddle.cat(
-     (last_best_k_hyps, best_k_pred.view(-1, 1)),
-     dim=1)  # (B*N, i+1)
- # 2.6 Update end flag
- end_flag = paddle.eq(hyps[:, -1], self.eos).view(-1, 1)
+     encoder_out.repeat(len(hyps), 1, 1), encoder_mask.repeat(len(hyps), 1, 1), ys, ys_mask, cache)
+ hyps_best_kept = []
+ for j, hyp in enumerate(hyps):
+     top_k_logp, top_k_index = logp[j : j + 1].topk(beam_size)
+     for b in range(beam_size):
+         new_hyp = {}
+         new_hyp["score"] = hyp["score"] + float(top_k_logp[0, b])
+         new_hyp["yseq"] = [0] * (1 + len(hyp["yseq"]))
+         new_hyp["yseq"][: len(hyp["yseq"])] = hyp["yseq"]
+         new_hyp["yseq"][len(hyp["yseq"])] = int(top_k_index[0, b])
+         new_hyp["cache"] = [cache_[j] for cache_ in cache]
+         # will be (2 x beam) hyps at most
+         hyps_best_kept.append(new_hyp)
+     hyps_best_kept = sorted(
+         hyps_best_kept, key=lambda x: -x["score"])[:beam_size]
+ # sort and get nbest
+ hyps = hyps_best_kept
+ if i == maxlen:
+     for hyp in hyps:
+         hyp["yseq"].append(self.eos)
+ # finalize the ended hypotheses with word reward (by length)
+ remained_hyps = []
+ for hyp in hyps:
+     if hyp["yseq"][-1] == self.eos:
+         hyp["score"] += (i - 1) * word_reward
+         cur_best_score = max(cur_best_score, hyp["score"])
+         ended_hyps.append(hyp)
+     else:
+         # prune only when even the maximum remaining word reward cannot
+         # lift this hypothesis above the best ended one (keeps optimality)
+         if hyp["score"] + maxlen * word_reward > cur_best_score:
+             remained_hyps.append(hyp)
+ # stop prediction when there is no unended hypothesis
+ if not remained_hyps:
+     break
+ hyps = remained_hyps
# 3. Select best of best
- scores = scores.view(batch_size, beam_size)
- # TODO: length normalization
- best_index = paddle.argmax(scores, axis=-1).long()  # (B)
- best_hyps_index = best_index + paddle.arange(
-     batch_size, dtype=paddle.long) * beam_size
- best_hyps = paddle.index_select(hyps, index=best_hyps_index, axis=0)
- best_hyps = best_hyps[:, 1:]
- return best_hyps
+ best_hyp = max(ended_hyps, key=lambda x: x["score"])
+ return paddle.to_tensor([best_hyp["yseq"][1:]])
# @jit.to_static
def subsampling_rate(self) -> int:
......@@ -472,6 +469,7 @@ class U2STBaseModel(nn.Layer):
decoding_method: str,
beam_size: int,
word_reward: float=0.0,
+ maxlen_ratio: float=0.5,
decoding_chunk_size: int=-1,
num_decoding_left_chunks: int=-1,
simulate_streaming: bool=False):
......@@ -507,6 +505,7 @@ class U2STBaseModel(nn.Layer):
feats_lengths,
beam_size=beam_size,
word_reward=word_reward,
+ maxlen_ratio=maxlen_ratio,
decoding_chunk_size=decoding_chunk_size,
num_decoding_left_chunks=num_decoding_left_chunks,
simulate_streaming=simulate_streaming)
......
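
With this change, the decode config read by U2STTester must provide maxlen_ratio, and decoding runs for at most max(int(encoder_out.shape[1] * maxlen_ratio), 5) steps on a single utterance. A hedged usage sketch: model, feats, feats_lengths and the method name translate are assumptions, while the keyword names follow the signature in the diff above:

    # feats: (1, max_len, feat_dim), feats_lengths: (1,); batch size must be 1 now
    hyps = model.translate(          # assumed name of the beam-search entry point
        feats,
        feats_lengths,
        beam_size=10,
        word_reward=0.0,
        maxlen_ratio=0.5,            # decode at most 0.5 * encoder frames, never fewer than 5 steps
        decoding_chunk_size=-1,
        num_decoding_left_chunks=-1,
        simulate_streaming=False)
    # hyps: (1, ylen) tensor of target token ids, without the leading <sos>
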