提交 5659bd23 编写于 作者: H Hui Zhang

add u2 model

上级 498104b0
......@@ -78,7 +78,32 @@ if not hasattr(paddle, 'cat'):
"override cat of paddle if exists or register, remove this when fixed!")
paddle.cat = cat
########### hcak paddle.Tensor #############
def item(x: paddle.Tensor):
if x.dtype == paddle.fluid.core_avx.VarDesc.VarType.FP32:
return float(x)
else:
raise ValueError("not support")
if not hasattr(paddle.Tensor, 'item'):
logger.warn(
"override item of paddle.Tensor if exists or register, remove this when fixed!"
)
paddle.Tensor.item = item
def func_long(x: paddle.Tensor):
return paddle.cast(x, paddle.long)
if not hasattr(paddle.Tensor, 'long'):
logger.warn(
"override long of paddle.Tensor if exists or register, remove this when fixed!"
)
paddle.Tensor.long = func_long
if not hasattr(paddle.Tensor, 'numel'):
logger.warn(
"override numel of paddle.Tensor if exists or register, remove this when fixed!"
......@@ -247,6 +272,15 @@ if not hasattr(paddle.Tensor, 'to'):
logger.warn("register user to to paddle.Tensor, remove this when fixed!")
setattr(paddle.Tensor, 'to', to)
def func_float(x: paddle.Tensor) -> paddle.Tensor:
return x.astype(paddle.float)
if not hasattr(paddle.Tensor, 'float'):
logger.warn("register user float to paddle.Tensor, remove this when fixed!")
setattr(paddle.Tensor, 'float', func_float)
########### hcak paddle.nn.functional #############
......
......@@ -53,22 +53,21 @@ from deepspeech.utils.ctc_utils import remove_duplicates_and_blank
logger = logging.getLogger(__name__)
__all__ = ['U2Model']
__all__ = ['U2TransformerModel', "U2ConformerModel"]
class U2Model(nn.Module):
"""CTC-Attention hybrid Encoder-Decoder model"""
def __init__(
self,
vocab_size: int,
encoder: TransformerEncoder,
decoder: TransformerDecoder,
ctc: CTCDecoder,
ctc_weight: float=0.5,
ignore_id: int=IGNORE_ID,
lsm_weight: float=0.0,
length_normalized_loss: bool=False, ):
def __init__(self,
vocab_size: int,
encoder: TransformerEncoder,
decoder: TransformerDecoder,
ctc: CTCDecoder,
ctc_weight: float=0.5,
ignore_id: int=IGNORE_ID,
lsm_weight: float=0.0,
length_normalized_loss: bool=False):
assert 0.0 <= ctc_weight <= 1.0, ctc_weight
super().__init__()
......@@ -263,51 +262,54 @@ class U2Model(nn.Module):
# Stop if all batch and all beam produce eos
if end_flag.sum() == running_size:
break
# 2.1 Forward decoder step
hyps_mask = subsequent_mask(i).unsqueeze(0).repeat(
running_size, 1, 1).to(device) # (B*N, i, i)
# logp: (B*N, vocab)
logp, cache = self.decoder.forward_one_step(
encoder_out, encoder_mask, hyps, hyps_mask, cache)
# 2.2 First beam prune: select topk best prob at current time
top_k_logp, top_k_index = logp.topk(beam_size) # (B*N, N)
top_k_logp = mask_finished_scores(top_k_logp, end_flag)
top_k_index = mask_finished_preds(top_k_index, end_flag, self.eos)
# 2.3 Seconde beam prune: select topk score with history
scores = scores + top_k_logp # (B*N, N), broadcast add
scores = scores.view(batch_size, beam_size * beam_size) # (B, N*N)
scores, offset_k_index = scores.topk(k=beam_size) # (B, N)
scores = scores.view(-1, 1) # (B*N, 1)
# 2.4. Compute base index in top_k_index,
# regard top_k_index as (B*N*N),regard offset_k_index as (B*N),
# then find offset_k_index in top_k_index
base_k_index = torch.arange(
batch_size,
device=device).view(-1, 1).repeat([1, beam_size]) # (B, N)
base_k_index = paddle.arange(batch_size).view(-1, 1).repeat(
[1, beam_size]) # (B, N)
base_k_index = base_k_index * beam_size * beam_size
best_k_index = base_k_index.view(-1) + offset_k_index.view(
-1) # (B*N)
# 2.5 Update best hyps
best_k_pred = torch.index_select(
top_k_index.view(-1), dim=-1, index=best_k_index) # (B*N)
best_k_pred = paddle.index_select(
top_k_index.view(-1), index=best_k_index, axis=0) # (B*N)
best_hyps_index = best_k_index // beam_size
last_best_k_hyps = torch.index_select(
hyps, dim=0, index=best_hyps_index) # (B*N, i)
hyps = torch.cat(
last_best_k_hyps = paddle.index_select(
hyps, index=best_hyps_index, axis=0) # (B*N, i)
hyps = paddle.cat(
(last_best_k_hyps, best_k_pred.view(-1, 1)),
dim=1) # (B*N, i+1)
# 2.6 Update end flag
end_flag = torch.eq(hyps[:, -1], self.eos).view(-1, 1)
end_flag = paddle.eq(hyps[:, -1], self.eos).view(-1, 1)
# 3. Select best of best
scores = scores.view(batch_size, beam_size)
# TODO: length normalization
best_index = torch.argmax(scores, dim=-1).long()
best_hyps_index = best_index + torch.arange(
batch_size, dtype=torch.long, device=device) * beam_size
best_hyps = torch.index_select(hyps, dim=0, index=best_hyps_index)
best_index = paddle.argmax(scores, axis=-1).long() # (B)
best_hyps_index = best_index + paddle.arange(
batch_size, dtype=paddle.long) * beam_size
best_hyps = paddle.index_select(hyps, index=best_hyps_index, axis=0)
best_hyps = best_hyps[:, 1:]
return best_hyps
......@@ -346,8 +348,8 @@ class U2Model(nn.Module):
ctc_probs = self.ctc.log_softmax(encoder_out) # (B, maxlen, vocab_size)
topk_prob, topk_index = ctc_probs.topk(1, dim=2) # (B, maxlen, 1)
topk_index = topk_index.view(batch_size, maxlen) # (B, maxlen)
mask = make_pad_mask(encoder_out_lens) # (B, maxlen)
topk_index = topk_index.masked_fill_(mask, self.eos) # (B, maxlen)
pad_mask = make_pad_mask(encoder_out_lens) # (B, maxlen)
topk_index = topk_index.masked_fill_(pad_mask, self.eos) # (B, maxlen)
hyps = [hyp.tolist() for hyp in topk_index]
hyps = [remove_duplicates_and_blank(hyp) for hyp in hyps]
return hyps
......@@ -360,7 +362,7 @@ class U2Model(nn.Module):
decoding_chunk_size: int=-1,
num_decoding_left_chunks: int=-1,
simulate_streaming: bool=False,
) -> Tuple[List[List[int]], paddle.Tensor]:
blank_id: int=0, ) -> Tuple[List[Tuple[int, float]], paddle.Tensor]:
""" CTC prefix beam search inner implementation
Args:
speech (paddle.Tensor): (batch, max_len, feat_dim)
......@@ -374,7 +376,7 @@ class U2Model(nn.Module):
simulate_streaming (bool): whether do encoder forward in a
streaming fashion
Returns:
List[List[int]]: nbest results
List[Tuple[int, float]]: nbest results, (N,1), (text, likelihood)
paddle.Tensor: encoder output, (1, max_len, encoder_dim),
it will be used for rescoring in attention rescoring mode
"""
......@@ -406,7 +408,7 @@ class U2Model(nn.Module):
ps = logp[s].item()
for prefix, (pb, pnb) in cur_hyps:
last = prefix[-1] if len(prefix) > 0 else None
if s == 0: # blank
if s == blank_id: # blank
n_pb, n_pnb = next_hyps[prefix]
n_pb = log_add([n_pb, pb + ps, pnb + ps])
next_hyps[prefix] = (n_pb, n_pnb)
......@@ -491,7 +493,7 @@ class U2Model(nn.Module):
"""
assert speech.shape[0] == speech_lengths.shape[0]
assert decoding_chunk_size != 0
device = speech.device
device = speech.place
batch_size = speech.shape[0]
# For attention rescoring we only support batch_size=1
assert batch_size == 1
......@@ -502,22 +504,22 @@ class U2Model(nn.Module):
assert len(hyps) == beam_size
hyps_pad = pad_sequence([
paddle.tensor(hyp[0], device=device, dtype=torch.long)
paddle.to_tensor(hyp[0], place=device, dtype=paddle.long)
for hyp in hyps
], True, self.ignore_id) # (beam_size, max_hyps_len)
hyps_lens = paddle.tensor(
[len(hyp[0]) for hyp in hyps], device=device,
dtype=torch.long) # (beam_size,)
hyps_lens = paddle.to_tensor(
[len(hyp[0]) for hyp in hyps], place=device,
dtype=paddle.long) # (beam_size,)
hyps_pad, _ = add_sos_eos(hyps_pad, self.sos, self.eos, self.ignore_id)
hyps_lens = hyps_lens + 1 # Add <sos> at begining
encoder_out = encoder_out.repeat(beam_size, 1, 1)
encoder_mask = torch.ones(
beam_size, 1, encoder_out.size(1), dtype=torch.bool, device=device)
encoder_mask = paddle.ones(
beam_size, 1, encoder_out.size(1), dtype=paddle.bool)
decoder_out, _ = self.decoder(
encoder_out, encoder_mask, hyps_pad,
hyps_lens) # (beam_size, max_hyps_len, vocab_size)
decoder_out = torch.nn.functional.log_softmax(decoder_out, dim=-1)
decoder_out = decoder_out.cpu().numpy()
decoder_out = paddle.nn.functional.log_softmax(decoder_out, dim=-1)
decoder_out = decoder_out.numpy()
# Only use decoder score for rescoring
best_score = -float('inf')
best_index = 0
......@@ -609,56 +611,83 @@ class U2Model(nn.Module):
hypothesis from ctc prefix beam search and one encoder output
Args:
hyps (paddle.Tensor): hyps from ctc prefix beam search, already
pad sos at the begining
hyps_lens (paddle.Tensor): length of each hyp in hyps
encoder_out (paddle.Tensor): corresponding encoder output
pad sos at the begining, (B, T)
hyps_lens (paddle.Tensor): length of each hyp in hyps, (B)
encoder_out (paddle.Tensor): corresponding encoder output, (B=1, T, D)
Returns:
paddle.Tensor: decoder output
paddle.Tensor: decoder output, (B, L)
"""
assert encoder_out.size(0) == 1
num_hyps = hyps.size(0)
assert hyps_lens.size(0) == num_hyps
encoder_out = encoder_out.repeat(num_hyps, 1, 1)
encoder_mask = torch.ones(
num_hyps,
1,
encoder_out.size(1),
dtype=torch.bool,
device=encoder_out.device)
decoder_out, _ = self.decoder(
encoder_out, encoder_mask, hyps,
hyps_lens) # (num_hyps, max_hyps_len, vocab_size)
decoder_out = torch.nn.functional.log_softmax(decoder_out, dim=-1)
# (B, 1, T)
encoder_mask = paddle.ones(
[num_hyps, 1, encoder_out.size(1)], dtype=paddle.bool)
# (num_hyps, max_hyps_len, vocab_size)
decoder_out, _ = self.decoder(encoder_out, encoder_mask, hyps,
hyps_lens)
decoder_out = paddle.nn.functional.log_softmax(decoder_out, dim=-1)
return decoder_out
def init_asr_model(configs):
if configs['cmvn_file'] is not None:
mean, istd = load_cmvn(configs['cmvn_file'], configs['is_json_cmvn'])
global_cmvn = GlobalCMVN(
torch.from_numpy(mean).float(), torch.from_numpy(istd).float())
else:
global_cmvn = None
class U2TransformerModel(U2Model):
def __init__(configs: dict):
if configs['cmvn_file'] is not None:
mean, istd = load_cmvn(configs['cmvn_file'],
configs['is_json_cmvn'])
global_cmvn = GlobalCMVN(
paddle.to_tensor(mean).float(), paddle.to_tensor(istd).float())
else:
global_cmvn = None
input_dim = configs['input_dim']
vocab_size = configs['output_dim']
encoder_type = configs.get('encoder', 'transformer')
assert encoder_type == 'transformer'
encoder = TransformerEncoder(
input_dim, global_cmvn=global_cmvn, **configs['encoder_conf'])
decoder = TransformerDecoder(vocab_size,
encoder.output_size(),
**configs['decoder_conf'])
ctc = CTCDecoder(vocab_size, encoder.output_size())
self.__init__(
vocab_size=vocab_size,
encoder=encoder,
decoder=decoder,
ctc=ctc,
**configs['model_conf'])
class U2ConformerModel(U2Model):
def __init__(configs: dict):
if configs['cmvn_file'] is not None:
mean, istd = load_cmvn(configs['cmvn_file'],
configs['is_json_cmvn'])
global_cmvn = GlobalCMVN(
paddle.to_tensor(mean).float(), paddle.to_tensor(istd).float())
else:
global_cmvn = None
input_dim = configs['input_dim']
vocab_size = configs['output_dim']
input_dim = configs['input_dim']
vocab_size = configs['output_dim']
encoder_type = configs.get('encoder', 'conformer')
if encoder_type == 'conformer':
encoder_type = configs.get('encoder', 'conformer')
assert encoder_type == 'conformer'
encoder = ConformerEncoder(
input_dim, global_cmvn=global_cmvn, **configs['encoder_conf'])
else:
encoder = TransformerEncoder(
input_dim, global_cmvn=global_cmvn, **configs['encoder_conf'])
decoder = TransformerDecoder(vocab_size,
encoder.output_size(),
**configs['decoder_conf'])
ctc = CTCDecoder(vocab_size, encoder.output_size())
model = U2Model(
vocab_size=vocab_size,
encoder=encoder,
decoder=decoder,
ctc=ctc,
**configs['model_conf'], )
return model
decoder = TransformerDecoder(vocab_size,
encoder.output_size(),
**configs['decoder_conf'])
ctc = CTCDecoder(vocab_size, encoder.output_size())
self.__init__(
vocab_size=vocab_size,
encoder=encoder,
decoder=decoder,
ctc=ctc,
**configs['model_conf'])
......@@ -37,7 +37,7 @@ class BaseSubsampling(nn.Layer):
self.pos_enc = pos_enc_class
# window size = (1 + right_context) + (chunk_size -1) * subsampling_rate
self.right_context = 0
# stride = chunk_size * subsampling_rate
# stride = subsampling_rate * chunk_size
self.subsampling_rate = 1
def position_encoding(self, offset: int, size: int) -> paddle.Tensor:
......
......@@ -63,8 +63,13 @@ def add_arguments(argname, type, default, help, argparser, **kwargs):
def log_add(args: List[int]) -> float:
"""
Stable log add
"""Stable log add
Args:
args (List[int]): log scores
Returns:
float: sum of log scores
"""
if all(a == -float('inf') for a in args):
return -float('inf')
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册