Commit 43aad7a0 authored by Junkun

beam search with optimality guarantees

Parent 26524031
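The change replaces the batched, fixed-length attention beam search with a hypothesis-list search: decoding runs for at most maxlen = max(int(num_frames * maxlen_ratio), 5) steps, finished hypotheses receive a per-token word reward, and a live hypothesis is kept only while it could still overtake the best finished one. Below is a minimal standalone sketch of that pruning rule; the function name and numbers are illustrative only, while the "score" field mirrors the hypothesis dicts in the diff.

# Standalone sketch of the optimality-guaranteed pruning rule used by the new
# beam search (illustrative only; the function name and values are made up).
def prune_live_hyps(live_hyps, cur_best_score, maxlen, word_reward):
    """Keep a live hypothesis only if its score plus the largest possible
    remaining reward (maxlen * word_reward) can still beat the best finished
    hypothesis; an empty result means the search can stop without losing
    the optimum."""
    return [
        hyp for hyp in live_hyps
        if hyp["score"] + maxlen * word_reward > cur_best_score
    ]

# Example: with word_reward = 0.5 and maxlen = 10, a live hypothesis at -12.0
# has an upper bound of -7.0, so it is dropped once some finished hypothesis
# already scores -6.5.
print(prune_live_hyps([{"score": -12.0}], cur_best_score=-6.5,
                      maxlen=10, word_reward=0.5))  # -> []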
@@ -285,7 +285,7 @@ class U2STTrainer(Trainer):
                 subsampling_factor=1,
                 load_aux_output=load_transcript,
                 num_encs=1,
-                dist_sampler=True)
+                dist_sampler=False)
             logger.info("Setup train/valid Dataloader!")
         else:
             # test dataset, return raw text
@@ -408,6 +408,7 @@ class U2STTester(U2STTrainer):
                 decoding_method=decode_cfg.decoding_method,
                 beam_size=decode_cfg.beam_size,
                 word_reward=decode_cfg.word_reward,
+                maxlen_ratio=decode_cfg.maxlen_ratio,
                 decoding_chunk_size=decode_cfg.decoding_chunk_size,
                 num_decoding_left_chunks=decode_cfg.num_decoding_left_chunks,
                 simulate_streaming=decode_cfg.simulate_streaming)
@@ -435,6 +436,7 @@ class U2STTester(U2STTrainer):
                 decoding_method=decode_cfg.decoding_method,
                 beam_size=decode_cfg.beam_size,
                 word_reward=decode_cfg.word_reward,
+                maxlen_ratio=decode_cfg.maxlen_ratio,
                 decoding_chunk_size=decode_cfg.decoding_chunk_size,
                 num_decoding_left_chunks=decode_cfg.num_decoding_left_chunks,
                 simulate_streaming=decode_cfg.simulate_streaming)
......
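Because the tester now forwards decode_cfg.maxlen_ratio, the decode configuration has to carry that field alongside word_reward. The snippet below is a hedged sketch of a minimal config object exposing the attributes read above; the concrete values and the decoding method string are placeholders, not taken from this commit.

# Hypothetical minimal decode config; values are placeholders and the
# decoding_method string is an assumption, not shown in this diff.
from types import SimpleNamespace

decode_cfg = SimpleNamespace(
    decoding_method="fullsentence",   # assumed name for attention-based decoding
    beam_size=10,
    word_reward=0.7,                  # per-token bonus added to finished hypotheses
    maxlen_ratio=0.5,                 # bounds output length by a ratio of encoder frames
    decoding_chunk_size=-1,           # <0: use the full chunk when decoding
    num_decoding_left_chunks=-1,
    simulate_streaming=False,
)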
@@ -264,14 +264,17 @@ class U2STBaseModel(nn.Layer):
             speech_lengths: paddle.Tensor,
             beam_size: int=10,
             word_reward: float=0.0,
+            maxlen_ratio: float=0.5,
             decoding_chunk_size: int=-1,
             num_decoding_left_chunks: int=-1,
             simulate_streaming: bool=False, ) -> paddle.Tensor:
-        """ Apply beam search on attention decoder
+        """ Apply beam search on attention decoder with length penalty
         Args:
             speech (paddle.Tensor): (batch, max_len, feat_dim)
             speech_length (paddle.Tensor): (batch, )
             beam_size (int): beam size for beam search
+            word_reward (float): word reward used in beam search
+            maxlen_ratio (float): max length ratio to bound the length of translated text
             decoding_chunk_size (int): decoding chunk for dynamic chunk
                 trained model.
                 <0: for decoding, use full chunk.
@@ -284,90 +287,84 @@ class U2STBaseModel(nn.Layer):
         """
         assert speech.shape[0] == speech_lengths.shape[0]
         assert decoding_chunk_size != 0
+        assert speech.shape[0] == 1
         device = speech.place
-        batch_size = speech.shape[0]
         # Let's assume B = batch_size and N = beam_size
-        # 1. Encoder
+        # 1. Encoder and init hypothesis
         encoder_out, encoder_mask = self._forward_encoder(
             speech, speech_lengths, decoding_chunk_size,
             num_decoding_left_chunks,
             simulate_streaming)  # (B, maxlen, encoder_dim)
-        maxlen = encoder_out.shape[1]
-        encoder_dim = encoder_out.shape[2]
-        running_size = batch_size * beam_size
-        encoder_out = encoder_out.unsqueeze(1).repeat(1, beam_size, 1, 1).view(
-            running_size, maxlen, encoder_dim)  # (B*N, maxlen, encoder_dim)
-        encoder_mask = encoder_mask.unsqueeze(1).repeat(
-            1, beam_size, 1, 1).view(running_size, 1,
-                                     maxlen)  # (B*N, 1, max_len)
-        hyps = paddle.ones(
-            [running_size, 1], dtype=paddle.long).fill_(self.sos)  # (B*N, 1)
-        # log scale score
-        scores = paddle.to_tensor(
-            [0.0] + [-float('inf')] * (beam_size - 1), dtype=paddle.float)
-        scores = scores.to(device).repeat(batch_size).unsqueeze(1).to(
-            device)  # (B*N, 1)
-        end_flag = paddle.zeros_like(scores, dtype=paddle.bool)  # (B*N, 1)
-        cache: Optional[List[paddle.Tensor]] = None
+        maxlen = max(int(encoder_out.shape[1] * maxlen_ratio), 5)
+
+        hyp = {"score": 0.0, "yseq": [self.sos], "cache": None}
+        hyps = [hyp]
+        ended_hyps = []
+        cur_best_score = -float("inf")
+        cache = None
         # 2. Decoder forward step by step
         for i in range(1, maxlen + 1):
-            # Stop if all batch and all beam produce eos
-            # TODO(Hui Zhang): if end_flag.sum() == running_size:
-            if end_flag.cast(paddle.int64).sum() == running_size:
-                break
-
-            # 2.1 Forward decoder step
-            hyps_mask = subsequent_mask(i).unsqueeze(0).repeat(
-                running_size, 1, 1).to(device)  # (B*N, i, i)
-            # logp: (B*N, vocab)
+            ys = paddle.ones((len(hyps), i), dtype=paddle.long)
+            if hyps[0]["cache"] is not None:
+                cache = [
+                    paddle.ones(
+                        (len(hyps), i - 1, hyps[0]["cache"][0].shape[-1]),
+                        dtype=paddle.float32)
+                    for _ in range(len(hyps[0]["cache"]))
+                ]
+            for j, hyp in enumerate(hyps):
+                ys[j, :] = paddle.to_tensor(hyp["yseq"])
+                if hyps[0]["cache"] is not None:
+                    for k in range(len(cache)):
+                        cache[k][j] = hyps[j]["cache"][k]
+            ys_mask = subsequent_mask(i).unsqueeze(0).to(device)
             logp, cache = self.st_decoder.forward_one_step(
-                encoder_out, encoder_mask, hyps, hyps_mask, cache)
+                encoder_out.repeat(len(hyps), 1, 1),
+                encoder_mask.repeat(len(hyps), 1, 1), ys, ys_mask, cache)
-            # 2.2 First beam prune: select topk best prob at current time
-            top_k_logp, top_k_index = logp.topk(beam_size)  # (B*N, N)
-            top_k_logp += word_reward
-            top_k_logp = mask_finished_scores(top_k_logp, end_flag)
-            top_k_index = mask_finished_preds(top_k_index, end_flag, self.eos)
-
-            # 2.3 Second beam prune: select topk score with history
-            scores = scores + top_k_logp  # (B*N, N), broadcast add
-            scores = scores.view(batch_size, beam_size * beam_size)  # (B, N*N)
-            scores, offset_k_index = scores.topk(k=beam_size)  # (B, N)
-            scores = scores.view(-1, 1)  # (B*N, 1)
-
-            # 2.4. Compute base index in top_k_index,
-            # regard top_k_index as (B*N*N), regard offset_k_index as (B*N),
-            # then find offset_k_index in top_k_index
-            base_k_index = paddle.arange(batch_size).view(-1, 1).repeat(
-                1, beam_size)  # (B, N)
-            base_k_index = base_k_index * beam_size * beam_size
-            best_k_index = base_k_index.view(-1) + offset_k_index.view(
-                -1)  # (B*N)
-
-            # 2.5 Update best hyps
-            best_k_pred = paddle.index_select(
-                top_k_index.view(-1), index=best_k_index, axis=0)  # (B*N)
-            best_hyps_index = best_k_index // beam_size
-            last_best_k_hyps = paddle.index_select(
-                hyps, index=best_hyps_index, axis=0)  # (B*N, i)
-            hyps = paddle.cat(
-                (last_best_k_hyps, best_k_pred.view(-1, 1)),
-                dim=1)  # (B*N, i+1)
-
-            # 2.6 Update end flag
-            end_flag = paddle.eq(hyps[:, -1], self.eos).view(-1, 1)
+            hyps_best_kept = []
+            for j, hyp in enumerate(hyps):
+                top_k_logp, top_k_index = logp[j:j + 1].topk(beam_size)
+
+                for b in range(beam_size):
+                    new_hyp = {}
+                    new_hyp["score"] = hyp["score"] + float(top_k_logp[0, b])
+                    new_hyp["yseq"] = [0] * (1 + len(hyp["yseq"]))
+                    new_hyp["yseq"][:len(hyp["yseq"])] = hyp["yseq"]
+                    new_hyp["yseq"][len(hyp["yseq"])] = int(top_k_index[0, b])
+                    new_hyp["cache"] = [cache_[j] for cache_ in cache]
+                    # will be (2 x beam) hyps at most
+                    hyps_best_kept.append(new_hyp)
+
+                hyps_best_kept = sorted(
+                    hyps_best_kept, key=lambda x: -x["score"])[:beam_size]
+
+            # sort and get nbest
+            hyps = hyps_best_kept
+            if i == maxlen:
+                for hyp in hyps:
+                    hyp["yseq"].append(self.eos)
+
+            # finalize the ended hypotheses with word reward (by length)
+            remained_hyps = []
+            for hyp in hyps:
+                if hyp["yseq"][-1] == self.eos:
+                    hyp["score"] += (i - 1) * word_reward
+                    cur_best_score = max(cur_best_score, hyp["score"])
+                    ended_hyps.append(hyp)
+                else:
+                    # stop while guaranteeing optimality
+                    if hyp["score"] + maxlen * word_reward > cur_best_score:
+                        remained_hyps.append(hyp)
+
+            # stop prediction when there is no unended hypothesis
+            if not remained_hyps:
+                break
+            hyps = remained_hyps
         # 3. Select best of best
-        scores = scores.view(batch_size, beam_size)
-        # TODO: length normalization
-        best_index = paddle.argmax(scores, axis=-1).long()  # (B)
-        best_hyps_index = best_index + paddle.arange(
-            batch_size, dtype=paddle.long) * beam_size
-        best_hyps = paddle.index_select(hyps, index=best_hyps_index, axis=0)
-        best_hyps = best_hyps[:, 1:]
-        return best_hyps
+        best_hyp = max(ended_hyps, key=lambda x: x["score"])
+
+        return paddle.to_tensor([best_hyp["yseq"][1:]])
     # @jit.to_static
     def subsampling_rate(self) -> int:
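For concreteness, the new length bound and the word-reward finalization in the hunk above reduce to simple arithmetic; the numbers below are invented just to trace the formulas.

# Toy walk-through of the new length bound and finished-hypothesis scoring
# (all numbers invented; uses the default maxlen_ratio=0.5 and a positive word_reward).
encoder_frames = 38                   # encoder_out.shape[1] for one utterance
maxlen_ratio = 0.5
maxlen = max(int(encoder_frames * maxlen_ratio), 5)  # -> 19 decoding steps at most

word_reward = 0.5
log_prob_sum = -9.2                   # cumulative decoder log-probability
tokens_before_eos = 7                 # i - 1 tokens when <eos> appears at step i = 8
final_score = log_prob_sum + tokens_before_eos * word_reward  # -> -5.7
print(maxlen, final_score)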
@@ -472,6 +469,7 @@ class U2STBaseModel(nn.Layer):
             decoding_method: str,
             beam_size: int,
             word_reward: float=0.0,
+            maxlen_ratio: float=0.5,
             decoding_chunk_size: int=-1,
             num_decoding_left_chunks: int=-1,
             simulate_streaming: bool=False):
@@ -507,6 +505,7 @@ class U2STBaseModel(nn.Layer):
                 feats_lengths,
                 beam_size=beam_size,
                 word_reward=word_reward,
+                maxlen_ratio=maxlen_ratio,
                 decoding_chunk_size=decoding_chunk_size,
                 num_decoding_left_chunks=num_decoding_left_chunks,
                 simulate_streaming=simulate_streaming)
......