提交 498104b0 编写于 作者: H Hui Zhang

refactor data feed order

上级 f5477d31
......@@ -338,7 +338,7 @@
}
],
"source": [
"for idx, (audio, text, audio_len, text_len) in enumerate(batch_reader()):\n",
"for idx, (audio, audio_len, text, text_len) in enumerate(batch_reader()):\n",
" print('test', text)\n",
" print(\"test raw\", ''.join( chr(i) for i in text[0][:int(text_len[0])] ))\n",
" print(\"test raw\", ''.join( chr(i) for i in text[-1][:int(text_len[-1])] ))\n",
......@@ -386,4 +386,4 @@
},
"nbformat": 4,
"nbformat_minor": 5
}
}
\ No newline at end of file
......@@ -249,7 +249,7 @@
}
],
"source": [
" for idx, (audio, text, audio_len, text_len) in enumerate(batch_reader()):\n",
" for idx, (audio, audio_len, text, text_len) in enumerate(batch_reader()):\n",
" print('test', text)\n",
" print(\"test raw\", ''.join(batch_reader.dataset.vocab_list[i] for i in text[0]))\n",
" print(\"test raw\", ''.join(batch_reader.dataset.vocab_list[i] for i in text[-1]))\n",
......@@ -835,7 +835,7 @@
"\n",
" return logits, probs, audio_len\n",
"\n",
" def forward(self, audio, text, audio_len, text_len):\n",
" def forward(self, audio, audio_len, text, text_len):\n",
" \"\"\"\n",
" audio: shape [B, D, T]\n",
" text: shape [B, T]\n",
......@@ -877,10 +877,10 @@
"metadata": {},
"outputs": [],
"source": [
"audio, text, audio_len, text_len = None, None, None, None\n",
"audio, audio_len, text, text_len = None, None, None, None\n",
"\n",
"for idx, inputs in enumerate(batch_reader):\n",
" audio, text, audio_len, text_len = inputs\n",
" audio, audio_len, text, text_len = inputs\n",
"# print(idx)\n",
"# print('a', audio.shape, audio.place)\n",
"# print('t', text)\n",
......@@ -960,7 +960,7 @@
}
],
"source": [
"outputs = dp_model(audio, text, audio_len, text_len)\n",
"outputs = dp_model(audio, audio_len, text, text_len)\n",
"logits, _, logits_len = outputs\n",
"print('logits len', logits_len)\n",
"loss = loss_fn.forward(logits, text, logits_len, text_len)\n",
......
......@@ -222,6 +222,31 @@ if not hasattr(paddle.Tensor, 'relu'):
logger.warn("register user relu to paddle.Tensor, remove this when fixed!")
setattr(paddle.Tensor, 'relu', paddle.nn.functional.relu)
def type_as(x: paddle.Tensor, other: paddle.Tensor) -> paddle.Tensor:
    """Cast ``x`` to the dtype of ``other``.

    Args:
        x (paddle.Tensor): tensor to convert.
        other (paddle.Tensor): tensor whose ``dtype`` is the cast target.

    Returns:
        paddle.Tensor: the result of ``x.astype(other.dtype)``.
    """
    target_dtype = other.dtype
    return x.astype(target_dtype)
# Compatibility shim: install the user-level `type_as` only when
# paddle.Tensor does not already provide an attribute with that name.
if not hasattr(paddle.Tensor, 'type_as'):
    # `Logger.warn` is a deprecated alias; `Logger.warning` is the supported call.
    logger.warning(
        "register user type_as to paddle.Tensor, remove this when fixed!")
    setattr(paddle.Tensor, 'type_as', type_as)
def to(x: paddle.Tensor, *args, **kwargs) -> paddle.Tensor:
    """Minimal stand-in for a ``Tensor.to`` method.

    Exactly one positional argument is accepted, and its type selects the
    behaviour:

    * ``str``: treated as a dtype name; returns ``x.astype(arg)``.
    * ``paddle.Tensor``: returns ``x`` cast to that tensor's dtype.
    * anything else: treated as a device and ignored; ``x`` is returned
      unchanged.

    Args:
        x (paddle.Tensor): tensor to convert.
        *args: exactly one positional argument, interpreted as above.
        **kwargs: accepted but ignored.

    Returns:
        paddle.Tensor: the converted tensor, or ``x`` itself in the device
        case.

    Raises:
        AssertionError: if the number of positional arguments is not 1.
    """
    assert len(args) == 1
    target = args[0]
    # The original spelled this `isinstace`, which raised NameError on
    # every call before any branch could run.
    if isinstance(target, str):  # dtype
        return x.astype(target)
    elif isinstance(target, paddle.Tensor):  # Tensor
        return x.astype(target.dtype)
    else:  # Device
        return x
# Compatibility shim: install the user-level `to` only when paddle.Tensor
# does not already provide an attribute with that name.
if not hasattr(paddle.Tensor, 'to'):
    # `Logger.warn` is a deprecated alias; `Logger.warning` is the supported call.
    logger.warning("register user to to paddle.Tensor, remove this when fixed!")
    setattr(paddle.Tensor, 'to', to)
########### hack paddle.nn.functional #############
......
......@@ -103,7 +103,7 @@ def tune(config, args):
trans.append(''.join([chr(i) for i in ids]))
return trans
audio, text, audio_len, text_len = infer_data
audio, audio_len, text, text_len = infer_data
target_transcripts = ordid2token(text, text_len)
num_ins += audio.shape[0]
......
......@@ -17,6 +17,7 @@ import numpy as np
from collections import namedtuple
from deepspeech.io.utility import pad_sequence
from deepspeech.utils.tensor_utils import IGNORE_ID
logger = logging.getLogger(__name__)
......@@ -29,10 +30,6 @@ class SpeechCollator():
Padding audio features with zeros to make them have the same shape (or
a user-defined shape) within one batch.
If ``padding_to`` is -1, the maximun shape in the batch will be used
as the target shape for padding. Otherwise, `padding_to` will be the
target shape (only refers to the second axis).
if ``is_training`` is True, text is token ids else is raw string.
"""
self._is_training = is_training
......@@ -48,8 +45,8 @@ class SpeechCollator():
Returns:
tuple(audio, text, audio_lens, text_lens): batched data.
audio : (B, Tmax, D)
text : (B, Umax)
audio_lens: (B)
text : (B, Umax)
text_lens: (B)
"""
audios = []
......@@ -76,7 +73,9 @@ class SpeechCollator():
padded_audios = pad_sequence(
audios, padding_value=0.0).astype(np.float32) #[B, T, D]
padded_texts = pad_sequence(texts, padding_value=-1).astype(np.int32)
audio_lens = np.array(audio_lens).astype(np.int64)
# (TODO:Hui Zhang) ctc loss does not support int64 labels
padded_texts = pad_sequence(
texts, padding_value=IGNORE_ID).astype(np.int32)
text_lens = np.array(text_lens).astype(np.int64)
return padded_audios, padded_texts, audio_lens, text_lens
return padded_audios, audio_lens, padded_texts, text_lens
......@@ -168,13 +168,13 @@ class DeepSpeech2Model(nn.Layer):
dropout_rate=0.0,
reduction=True)
def forward(self, audio, text, audio_len, text_len):
def forward(self, audio, audio_len, text, text_len):
"""Compute Model loss
Args:
audio (Tensor): [B, T, D]
text (Tensor): [B, U]
audio_len (Tensor): [B]
text (Tensor): [B, U]
text_len (Tensor): [B]
Returns:
......
......@@ -28,7 +28,12 @@ from paddle import jit
from paddle import nn
from paddle.nn import functional as F
from paddle.nn import initializer as I
from paddle.nn.utils.rnn import pad_sequence
from deepspeech.modules.mask import make_pad_mask
from deepspeech.modules.mask import mask_finished_preds
from deepspeech.modules.mask import mask_finished_scores
from deepspeech.modules.mask import subsequent_mask
from deepspeech.modules.cmvn import GlobalCMVN
from deepspeech.modules.encoder import ConformerEncoder
......@@ -36,10 +41,6 @@ from deepspeech.modules.encoder import TransformerEncoder
from deepspeech.modules.ctc import CTCDecoder
from deepspeech.modules.decoder import TransformerDecoder
from deepspeech.modules.label_smoothing_loss import LabelSmoothingLoss
from deepspeech.modules.mask import make_pad_mask
from deepspeech.modules.mask import mask_finished_preds
from deepspeech.modules.mask import mask_finished_scores
from deepspeech.modules.mask import subsequent_mask
from deepspeech.utils import checkpoint
from deepspeech.utils import layer_tools
......@@ -101,6 +102,8 @@ class U2Model(nn.Module):
speech_lengths: (Batch, )
text: (Batch, Length)
text_lengths: (Batch,)
Returns:
total_loss, attention_loss, ctc_loss
"""
assert text_lengths.dim() == 1, text_lengths.shape
# Check that batch_size is unified
......@@ -109,21 +112,19 @@ class U2Model(nn.Module):
text.shape, text_lengths.shape)
# 1. Encoder
encoder_out, encoder_mask = self.encoder(speech, speech_lengths)
encoder_out_lens = encoder_mask.squeeze(1).sum(1)
encoder_out_lens = encoder_mask.squeeze(1).sum(1) #[B, 1, T] -> [B]
# 2a. Attention-decoder branch
loss_att = None
if self.ctc_weight != 1.0:
loss_att, acc_att = self._calc_att_loss(encoder_out, encoder_mask,
text, text_lengths)
else:
loss_att = None
# 2b. CTC branch
loss_ctc = None
if self.ctc_weight != 0.0:
loss_ctc = self.ctc(encoder_out, encoder_out_lens, text,
text_lengths)
else:
loss_ctc = None
if loss_ctc is None:
loss = loss_att
......@@ -139,6 +140,17 @@ class U2Model(nn.Module):
encoder_mask: paddle.Tensor,
ys_pad: paddle.Tensor,
ys_pad_lens: paddle.Tensor, ) -> Tuple[paddle.Tensor, float]:
"""Calc attention loss.
Args:
encoder_out (paddle.Tensor): [B, Tmax, D]
encoder_mask (paddle.Tensor): [B, 1, Tmax]
ys_pad (paddle.Tensor): [B, Umax]
ys_pad_lens (paddle.Tensor): [B]
Returns:
Tuple[paddle.Tensor, float]: attention_loss, accuracy rate
"""
ys_in_pad, ys_out_pad = add_sos_eos(ys_pad, self.sos, self.eos,
self.ignore_id)
ys_in_lens = ys_pad_lens + 1
......@@ -163,6 +175,20 @@ class U2Model(nn.Module):
num_decoding_left_chunks: int=-1,
simulate_streaming: bool=False,
) -> Tuple[paddle.Tensor, paddle.Tensor]:
"""Encoder pass.
Args:
speech (paddle.Tensor): [B, Tmax, D]
speech_lengths (paddle.Tensor): [B]
decoding_chunk_size (int, optional): chunk size. Defaults to -1.
num_decoding_left_chunks (int, optional): nums chunks. Defaults to -1.
simulate_streaming (bool, optional): streaming or not. Defaults to False.
Returns:
Tuple[paddle.Tensor, paddle.Tensor]:
encoder hiddens (B, Tmax, D),
encoder hiddens mask (B, 1, Tmax).
"""
# Let's assume B = batch_size
# 1. Encoder
if simulate_streaming and decoding_chunk_size > 0:
......@@ -205,7 +231,7 @@ class U2Model(nn.Module):
"""
assert speech.shape[0] == speech_lengths.shape[0]
assert decoding_chunk_size != 0
device = speech.device
device = speech.place
batch_size = speech.shape[0]
# Let's assume B = batch_size and N = beam_size
......@@ -223,14 +249,14 @@ class U2Model(nn.Module):
1, beam_size, 1, 1).view(running_size, 1,
maxlen) # (B*N, 1, max_len)
hyps = torch.ones(
[running_size, 1], dtype=torch.long,
device=device).fill_(self.sos) # (B*N, 1)
scores = paddle.tensor(
[0.0] + [-float('inf')] * (beam_size - 1), dtype=torch.float)
hyps = paddle.ones(
[running_size, 1], dtype=paddle.long).fill_(self.sos) # (B*N, 1)
# log scale score
scores = paddle.to_tensor(
[0.0] + [-float('inf')] * (beam_size - 1), dtype=paddle.float)
scores = scores.to(device).repeat([batch_size]).unsqueeze(1).to(
device) # (B*N, 1)
end_flag = torch.zeros_like(scores, dtype=torch.bool, device=device)
end_flag = paddle.zeros_like(scores, dtype=paddle.bool) # (B*N, 1)
cache: Optional[List[paddle.Tensor]] = None
# 2. Decoder forward step by step
for i in range(1, maxlen + 1):
......
......@@ -152,12 +152,12 @@ class TransformerDecoder(nn.Module):
memory: encoded memory, float32 (batch, maxlen_in, feat)
memory_mask: encoded memory mask, (batch, 1, maxlen_in)
tgt: input token ids, int64 (batch, maxlen_out)
tgt_mask: input token mask, (batch, maxlen_out)
tgt_mask: input token mask, (batch, maxlen_out, maxlen_out)
dtype=paddle.bool
cache: cached output list of (batch, max_time_out-1, size)
Returns:
y, cache: NN output value and cache per `self.decoders`.
y.shape` is (batch, maxlen_out, token)
y.shape` is (batch, token)
"""
x, _ = self.embed(tgt)
new_cache = []
......
......@@ -88,7 +88,10 @@ class LabelSmoothingLoss(nn.Layer):
size (int): the number of class
padding_idx (int): padding class id which will be ignored for loss
smoothing (float): smoothing rate (0.0 means the conventional CE)
normalize_length (bool): True, normalize loss by sequence length; False, normalize loss by batch size. Defaults to False.
normalize_length (bool):
True, normalize loss by sequence length;
False, normalize loss by batch size.
Defaults to False.
"""
super().__init__()
self.size = size
......@@ -103,6 +106,7 @@ class LabelSmoothingLoss(nn.Layer):
The model outputs and data labels tensors are flatten to
(batch*seqlen, class) shape and a mask is applied to the
padding part which should not be calculated for loss.
Args:
x (paddle.Tensor): prediction (batch, seqlen, class)
target (paddle.Tensor):
......
......@@ -50,6 +50,52 @@ def sequence_mask(x_len, max_len=None, dtype='float32'):
return mask
def make_pad_mask(lengths: paddle.Tensor) -> paddle.Tensor:
    """Build a mask that flags the padded positions of each sequence.

    Position ``j`` of row ``i`` is set when ``j >= lengths[i]``. The mask
    width is the largest value in ``lengths``.
    See description of make_non_pad_mask.

    Args:
        lengths (paddle.Tensor): Batch of lengths (B,).

    Returns:
        paddle.Tensor: Mask tensor containing indices of padded part,
        one row per entry of ``lengths``.

    Examples:
        >>> lengths = [5, 3, 2]
        >>> make_pad_mask(lengths)
        masks = [[0, 0, 0, 0 ,0],
                 [0, 0, 0, 1, 1],
                 [0, 0, 1, 1, 1]]
    """
    num_seqs = int(lengths.shape[0])
    longest = int(lengths.max())
    # One row of positions 0..longest-1, repeated for every sequence.
    positions = paddle.arange(0, longest, dtype=paddle.int64)
    position_grid = positions.unsqueeze(0).expand([num_seqs, longest])
    # (B,) -> (B, 1) so the comparison broadcasts across positions.
    length_column = lengths.unsqueeze(-1)
    return position_grid >= length_column
def make_non_pad_mask(lengths: paddle.Tensor) -> paddle.Tensor:
    """Build a mask that flags the non-padded positions of each sequence.

    Sequences in a batch may have different lengths, so they are padded to
    a common size for batch computing. To keep the padded part from
    passing values to context-dependent blocks such as attention or
    convolution, that part is masked.
    This pad_mask is used in both encoder and decoder.
    1 for non-padded part and 0 for padded part.

    Args:
        lengths (paddle.Tensor): Batch of lengths (B,).

    Returns:
        paddle.Tensor: mask tensor containing indices of non-padded part;
        the elementwise inverse of ``make_pad_mask(lengths)``.

    Examples:
        >>> lengths = [5, 3, 2]
        >>> make_non_pad_mask(lengths)
        masks = [[1, 1, 1, 1 ,1],
                 [1, 1, 1, 0, 0],
                 [1, 1, 0, 0, 0]]
    """
    pad_mask = make_pad_mask(lengths)
    return ~pad_mask
def subsequent_mask(size: int) -> paddle.Tensor:
"""Create mask for subsequent steps (size, size).
This mask is used only in decoder which works in an auto-regressive mode.
......@@ -170,52 +216,6 @@ def add_optional_chunk_mask(xs: paddle.Tensor,
return chunk_masks
def make_pad_mask(lengths: paddle.Tensor) -> paddle.Tensor:
    """Build a mask that flags the padded positions of each sequence.

    Position ``j`` of row ``i`` is set when ``j >= lengths[i]``. The mask
    width is the largest value in ``lengths``.
    See description of make_non_pad_mask.

    Args:
        lengths (paddle.Tensor): Batch of lengths (B,).

    Returns:
        paddle.Tensor: Mask tensor containing indices of padded part,
        one row per entry of ``lengths``.

    Examples:
        >>> lengths = [5, 3, 2]
        >>> make_pad_mask(lengths)
        masks = [[0, 0, 0, 0 ,0],
                 [0, 0, 0, 1, 1],
                 [0, 0, 1, 1, 1]]
    """
    num_seqs = int(lengths.shape[0])
    longest = int(lengths.max())
    # One row of positions 0..longest-1, repeated for every sequence.
    positions = paddle.arange(0, longest, dtype=paddle.int64)
    position_grid = positions.unsqueeze(0).expand([num_seqs, longest])
    # (B,) -> (B, 1) so the comparison broadcasts across positions.
    length_column = lengths.unsqueeze(-1)
    return position_grid >= length_column
def make_non_pad_mask(lengths: paddle.Tensor) -> paddle.Tensor:
    """Build a mask that flags the non-padded positions of each sequence.

    Sequences in a batch may have different lengths, so they are padded to
    a common size for batch computing. To keep the padded part from
    passing values to context-dependent blocks such as attention or
    convolution, that part is masked.
    This pad_mask is used in both encoder and decoder.
    1 for non-padded part and 0 for padded part.

    Args:
        lengths (paddle.Tensor): Batch of lengths (B,).

    Returns:
        paddle.Tensor: mask tensor containing indices of non-padded part;
        the elementwise inverse of ``make_pad_mask(lengths)``.

    Examples:
        >>> lengths = [5, 3, 2]
        >>> make_non_pad_mask(lengths)
        masks = [[1, 1, 1, 1 ,1],
                 [1, 1, 1, 0, 0],
                 [1, 1, 0, 0, 0]]
    """
    pad_mask = make_pad_mask(lengths)
    return ~pad_mask
def mask_finished_scores(score: paddle.Tensor,
flag: paddle.Tensor) -> paddle.Tensor:
"""
......
......@@ -46,7 +46,7 @@ if __name__ == '__main__':
rnn_size=1024,
use_gru=False,
share_rnn_weights=False, )
logits, probs, logits_len = model(audio, text, audio_len, text_len)
logits, probs, logits_len = model(audio, audio_len, text, text_len)
print('probs.shape', probs.shape)
print("-----------------")
......@@ -58,7 +58,7 @@ if __name__ == '__main__':
rnn_size=1024,
use_gru=True,
share_rnn_weights=False, )
logits, probs, logits_len = model2(audio, text, audio_len, text_len)
logits, probs, logits_len = model2(audio, audio_len, text, text_len)
print('probs.shape', probs.shape)
print("-----------------")
......@@ -70,7 +70,7 @@ if __name__ == '__main__':
rnn_size=1024,
use_gru=False,
share_rnn_weights=True, )
logits, probs, logits_len = model3(audio, text, audio_len, text_len)
logits, probs, logits_len = model3(audio, audio_len, text, text_len)
print('probs.shape', probs.shape)
print("-----------------")
......@@ -82,7 +82,7 @@ if __name__ == '__main__':
rnn_size=1024,
use_gru=True,
share_rnn_weights=True, )
logits, probs, logits_len = model4(audio, text, audio_len, text_len)
logits, probs, logits_len = model4(audio, audio_len, text, text_len)
print('probs.shape', probs.shape)
print("-----------------")
......@@ -94,6 +94,6 @@ if __name__ == '__main__':
rnn_size=1024,
use_gru=False,
share_rnn_weights=False, )
logits, probs, logits_len = model5(audio, text, audio_len, text_len)
logits, probs, logits_len = model5(audio, audio_len, text, text_len)
print('probs.shape', probs.shape)
print("-----------------")
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册