Commit 25978176 authored by: Hui Zhang

comment u2 model for easier understanding

Parent 96c64237
@@ -399,6 +399,7 @@ class U2BaseModel(nn.Module):
assert speech.shape[0] == speech_lengths.shape[0]
assert decoding_chunk_size != 0
batch_size = speech.shape[0]
# Let's assume B = batch_size
# encoder_out: (B, maxlen, encoder_dim)
# encoder_mask: (B, 1, Tmax)
@@ -410,10 +411,12 @@ class U2BaseModel(nn.Module):
# encoder_out_lens = encoder_mask.squeeze(1).sum(1)
encoder_out_lens = encoder_mask.squeeze(1).astype(paddle.int).sum(1)
ctc_probs = self.ctc.log_softmax(encoder_out)  # (B, maxlen, vocab_size)
topk_prob, topk_index = ctc_probs.topk(1, axis=2)  # (B, maxlen, 1)
topk_index = topk_index.view(batch_size, maxlen)  # (B, maxlen)
pad_mask = make_pad_mask(encoder_out_lens)  # (B, maxlen)
topk_index = topk_index.masked_fill_(pad_mask, self.eos)  # (B, maxlen)
hyps = [hyp.tolist() for hyp in topk_index]
hyps = [remove_duplicates_and_blank(hyp) for hyp in hyps]
return hyps
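Note: the greedy path above picks the arg-max token per frame, fills the padded frames with `eos`, and then applies the standard CTC collapse rule: merge consecutive repeats, then drop blanks. A minimal sketch of what `remove_duplicates_and_blank` does, assuming the usual convention that the blank id is 0 (the real helper lives elsewhere in the repo and may differ in signature):

```python
from typing import List

def remove_duplicates_and_blank(hyp: List[int], blank_id: int = 0) -> List[int]:
    """Collapse a frame-level CTC path: merge consecutive duplicates,
    then drop blank symbols."""
    new_hyp: List[int] = []
    prev = None
    for token in hyp:
        if token != blank_id and token != prev:
            new_hyp.append(token)
        prev = token
    return new_hyp

# e.g. a frame-level path [0, 3, 3, 0, 0, 5, 5, 5, 0] collapses to [3, 5]
```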
@@ -449,6 +452,7 @@ class U2BaseModel(nn.Module):
batch_size = speech.shape[0]
# For CTC prefix beam search, we only support batch_size=1
assert batch_size == 1
# Let's assume B = batch_size and N = beam_size
# 1. Encoder forward and get CTC score
encoder_out, encoder_mask = self._forward_encoder(
@@ -458,7 +462,9 @@ class U2BaseModel(nn.Module):
maxlen = encoder_out.size(1)
ctc_probs = self.ctc.log_softmax(encoder_out)  # (1, maxlen, vocab_size)
ctc_probs = ctc_probs.squeeze(0)
# cur_hyps: (prefix, (blank_ending_score, none_blank_ending_score))
# blank_ending_score and none_blank_ending_score are in ln domain
cur_hyps = [(tuple(), (0.0, -float('inf')))]
# 2. CTC beam search step by step
for t in range(0, maxlen):
@@ -498,6 +504,7 @@ class U2BaseModel(nn.Module):
key=lambda x: log_add(list(x[1])),
reverse=True)
cur_hyps = next_hyps[:beam_size]
hyps = [(y[0], log_add([y[1][0], y[1][1]])) for y in cur_hyps]
return hyps, encoder_out
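The two scores tracked per prefix stay in the natural-log (ln) domain, so merging them uses a numerically stable log-sum-exp rather than plain addition; that is what `log_add([pb, pnb])` does when a prefix's blank-ending and non-blank-ending scores are combined above. An illustrative, self-contained reimplementation (not necessarily the repo's exact code):

```python
import math
from typing import List

def log_add(args: List[float]) -> float:
    """Stable ln(sum(exp(a) for a in args)) for log-domain scores."""
    if all(a == -float('inf') for a in args):
        return -float('inf')
    a_max = max(args)  # factor out the max to avoid underflow
    return a_max + math.log(sum(math.exp(a - a_max) for a in args))

# log_add([-2.3, -1.6]) == ln(e**-2.3 + e**-1.6) ≈ -1.20
```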
@@ -561,12 +568,13 @@ class U2BaseModel(nn.Module):
batch_size = speech.shape[0]
# For attention rescoring we only support batch_size=1
assert batch_size == 1
# len(hyps) = beam_size, encoder_out: (1, maxlen, encoder_dim)
hyps, encoder_out = self._ctc_prefix_beam_search(
speech, speech_lengths, beam_size, decoding_chunk_size,
num_decoding_left_chunks, simulate_streaming)
assert len(hyps) == beam_size
hyps_pad = pad_sequence([
paddle.to_tensor(hyp[0], place=device, dtype=paddle.long)
for hyp in hyps
@@ -576,23 +584,28 @@ class U2BaseModel(nn.Module):
dtype=paddle.long)  # (beam_size,)
hyps_pad, _ = add_sos_eos(hyps_pad, self.sos, self.eos, self.ignore_id)
hyps_lens = hyps_lens + 1  # Add <sos> at beginning
encoder_out = encoder_out.repeat(beam_size, 1, 1)
encoder_mask = paddle.ones(
(beam_size, 1, encoder_out.size(1)), dtype=paddle.bool)
decoder_out, _ = self.decoder(
encoder_out, encoder_mask, hyps_pad,
hyps_lens)  # (beam_size, max_hyps_len, vocab_size)
# decoder score in ln domain
decoder_out = paddle.nn.functional.log_softmax(decoder_out, axis=-1)
decoder_out = decoder_out.numpy()
# Only use decoder score for rescoring
best_score = -float('inf')
best_index = 0
# hyps is List[(Text=List[int], Score=float)], len(hyps)=beam_size
for i, hyp in enumerate(hyps):
score = 0.0
for j, w in enumerate(hyp[0]):
score += decoder_out[i][j][w]
# the last decoder output token is `eos`, for the last decoder input token
score += decoder_out[i][len(hyp[0])][self.eos]
# add ctc score (which is in ln domain)
score += hyp[1] * ctc_weight
if score > best_score:
best_score = score
...
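In effect, each of the `beam_size` candidates from the CTC prefix beam search is rescored as `decoder_score + ctc_weight * ctc_score`, with both terms in the ln domain. A standalone illustration of that combination with hypothetical numbers (not taken from any real model run):

```python
# hyps: (token ids, ctc score in ln domain), as returned by the beam search
hyps = [([5, 9, 2], -3.1), ([5, 9], -2.7), ([5, 2], -4.0)]
# decoder score per hypothesis: sum of token log-probs plus the final <eos>
decoder_scores = [-2.0, -2.9, -1.8]
ctc_weight = 0.5

best_index = max(
    range(len(hyps)),
    key=lambda i: decoder_scores[i] + ctc_weight * hyps[i][1])
print(hyps[best_index][0])  # -> [5, 9, 2], combined score -3.55
```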
@@ -219,11 +219,14 @@ class BaseEncoder(nn.Layer):
xs, pos_emb, _ = self.embed(
xs, tmp_masks, offset=offset)  # xs=(B, T, D), pos_emb=(B=1, T, D)
if subsampling_cache is not None:
cache_size = subsampling_cache.size(1)  # T
xs = paddle.cat((subsampling_cache, xs), dim=1)
else:
cache_size = 0
# only used when using `RelPositionMultiHeadedAttention`
pos_emb = self.embed.position_encoding(
offset=offset - cache_size, size=xs.size(1))
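The `offset - cache_size` argument is the point of the added comment: once the cached frames are concatenated in front of the new chunk, the position encoding must start at the global index of the first cached frame, not of the new chunk, so that cache and chunk share one contiguous positional span. A small worked example with assumed numbers:

```python
# Suppose 16 subsampled frames were emitted so far (offset=16), the cache
# keeps the last 4 of them (cache_size=4), and the new chunk adds 8 frames.
offset, cache_size, chunk_frames = 16, 4, 8
xs_len = cache_size + chunk_frames      # xs now holds 12 frames
start = offset - cache_size             # global position 12
positions = list(range(start, start + xs_len))
print(positions)  # [12, 13, ..., 23]: cached frames and new chunk, contiguous
```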
@@ -237,7 +240,7 @@ class BaseEncoder(nn.Layer):
# Real mask for transformer/conformer layers
masks = paddle.ones([1, xs.size(1)], dtype=paddle.bool)
masks = masks.unsqueeze(1)  # [B=1, L'=1, T]
r_elayers_output_cache = []
r_conformer_cnn_cache = []
for i, layer in enumerate(self.encoders):
...
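For context, these caches are threaded through successive `forward_chunk` calls during streaming decoding. A hedged sketch of a driver loop, with the signature inferred from the fragment above (the full source may differ in argument names and details):

```python
# chunks: an iterable of feature chunks, each of shape (B=1, T_chunk, feat_dim)
offset = 0
required_cache_size = -1  # assumed convention: negative means keep all history
subsampling_cache = None
elayers_output_cache = None
conformer_cnn_cache = None
outputs = []
for chunk in chunks:
    (ys, subsampling_cache, elayers_output_cache,
     conformer_cnn_cache) = encoder.forward_chunk(
         chunk, offset, required_cache_size, subsampling_cache,
         elayers_output_cache, conformer_cnn_cache)
    outputs.append(ys)
    offset += ys.size(1)  # advance by the subsampled frames just produced
```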