提交 4c838379 编写于 作者: H Hui Zhang

size to shape; repeat to tile

上级 94918305
...@@ -532,7 +532,7 @@ class U2Tester(U2Trainer): ...@@ -532,7 +532,7 @@ class U2Tester(U2Trainer):
# 1. Encoder # 1. Encoder
encoder_out, encoder_mask = self.model._forward_encoder( encoder_out, encoder_mask = self.model._forward_encoder(
feat, feats_length) # (B, maxlen, encoder_dim) feat, feats_length) # (B, maxlen, encoder_dim)
maxlen = encoder_out.size(1) maxlen = encoder_out.shape[1]
ctc_probs = self.model.ctc.log_softmax( ctc_probs = self.model.ctc.log_softmax(
encoder_out) # (1, maxlen, vocab_size) encoder_out) # (1, maxlen, vocab_size)
...@@ -598,10 +598,20 @@ class U2Tester(U2Trainer): ...@@ -598,10 +598,20 @@ class U2Tester(U2Trainer):
def export(self): def export(self):
infer_model, input_spec = self.load_inferspec() infer_model, input_spec = self.load_inferspec()
assert isinstance(input_spec, list), type(input_spec) # assert isinstance(input_spec, list), type(input_spec)
infer_model.eval() infer_model.eval()
static_model = paddle.jit.to_static(infer_model, input_spec=input_spec) #static_model = paddle.jit.to_static(infer_model., input_spec=input_spec)
logger.info(f"Export code: {static_model.forward.code}")
static_model = paddle.jit.to_static(
infer_model.forward_attention_decoder,
input_spec=[
paddle.static.InputSpec(shape=[1, None],dtype='int32'),
paddle.static.InputSpec(shape=[1],dtype='int32'),
paddle.static.InputSpec(shape=[1, None, 256],dtype='int32'),
]
)
logger.info(f"Export code: {static_model}")
paddle.jit.save(static_model, self.args.export_path) paddle.jit.save(static_model, self.args.export_path)
def run_export(self): def run_export(self):
......
...@@ -299,8 +299,8 @@ class U2BaseModel(nn.Module): ...@@ -299,8 +299,8 @@ class U2BaseModel(nn.Module):
speech, speech_lengths, decoding_chunk_size, speech, speech_lengths, decoding_chunk_size,
num_decoding_left_chunks, num_decoding_left_chunks,
simulate_streaming) # (B, maxlen, encoder_dim) simulate_streaming) # (B, maxlen, encoder_dim)
maxlen = encoder_out.size(1) maxlen = encoder_out.shape[1]
encoder_dim = encoder_out.size(2) encoder_dim = encoder_out.shape[2]
running_size = batch_size * beam_size running_size = batch_size * beam_size
encoder_out = encoder_out.unsqueeze(1).repeat(1, beam_size, 1, 1).view( encoder_out = encoder_out.unsqueeze(1).repeat(1, beam_size, 1, 1).view(
running_size, maxlen, encoder_dim) # (B*N, maxlen, encoder_dim) running_size, maxlen, encoder_dim) # (B*N, maxlen, encoder_dim)
...@@ -405,7 +405,7 @@ class U2BaseModel(nn.Module): ...@@ -405,7 +405,7 @@ class U2BaseModel(nn.Module):
encoder_out, encoder_mask = self._forward_encoder( encoder_out, encoder_mask = self._forward_encoder(
speech, speech_lengths, decoding_chunk_size, speech, speech_lengths, decoding_chunk_size,
num_decoding_left_chunks, simulate_streaming) num_decoding_left_chunks, simulate_streaming)
maxlen = encoder_out.size(1) maxlen = encoder_out.shape[1]
# (TODO Hui Zhang): bool no support reduce_sum # (TODO Hui Zhang): bool no support reduce_sum
# encoder_out_lens = encoder_mask.squeeze(1).sum(1) # encoder_out_lens = encoder_mask.squeeze(1).sum(1)
encoder_out_lens = encoder_mask.squeeze(1).astype(paddle.int).sum(1) encoder_out_lens = encoder_mask.squeeze(1).astype(paddle.int).sum(1)
...@@ -455,7 +455,7 @@ class U2BaseModel(nn.Module): ...@@ -455,7 +455,7 @@ class U2BaseModel(nn.Module):
speech, speech_lengths, decoding_chunk_size, speech, speech_lengths, decoding_chunk_size,
num_decoding_left_chunks, num_decoding_left_chunks,
simulate_streaming) # (B, maxlen, encoder_dim) simulate_streaming) # (B, maxlen, encoder_dim)
maxlen = encoder_out.size(1) maxlen = encoder_out.shape[1]
ctc_probs = self.ctc.log_softmax(encoder_out) # (1, maxlen, vocab_size) ctc_probs = self.ctc.log_softmax(encoder_out) # (1, maxlen, vocab_size)
ctc_probs = ctc_probs.squeeze(0) ctc_probs = ctc_probs.squeeze(0)
# cur_hyps: (prefix, (blank_ending_score, none_blank_ending_score)) # cur_hyps: (prefix, (blank_ending_score, none_blank_ending_score))
...@@ -578,7 +578,7 @@ class U2BaseModel(nn.Module): ...@@ -578,7 +578,7 @@ class U2BaseModel(nn.Module):
hyps_lens = hyps_lens + 1 # Add <sos> at begining hyps_lens = hyps_lens + 1 # Add <sos> at begining
encoder_out = encoder_out.repeat(beam_size, 1, 1) encoder_out = encoder_out.repeat(beam_size, 1, 1)
encoder_mask = paddle.ones( encoder_mask = paddle.ones(
(beam_size, 1, encoder_out.size(1)), dtype=paddle.bool) (beam_size, 1, encoder_out.shape[1]), dtype=paddle.bool)
decoder_out, _ = self.decoder( decoder_out, _ = self.decoder(
encoder_out, encoder_mask, hyps_pad, encoder_out, encoder_mask, hyps_pad,
hyps_lens) # (beam_size, max_hyps_len, vocab_size) hyps_lens) # (beam_size, max_hyps_len, vocab_size)
...@@ -624,7 +624,7 @@ class U2BaseModel(nn.Module): ...@@ -624,7 +624,7 @@ class U2BaseModel(nn.Module):
""" """
return self.eos return self.eos
@jit.export # @jit.export
def forward_encoder_chunk( def forward_encoder_chunk(
self, self,
xs: paddle.Tensor, xs: paddle.Tensor,
...@@ -654,9 +654,7 @@ class U2BaseModel(nn.Module): ...@@ -654,9 +654,7 @@ class U2BaseModel(nn.Module):
xs, offset, required_cache_size, subsampling_cache, xs, offset, required_cache_size, subsampling_cache,
elayers_output_cache, conformer_cnn_cache) elayers_output_cache, conformer_cnn_cache)
# @jit.export([ # @jit.export
# paddle.static.InputSpec(shape=[1, None, feat_dim],dtype='float32'), # audio feat, [B,T,D]
# ])
def ctc_activation(self, xs: paddle.Tensor) -> paddle.Tensor: def ctc_activation(self, xs: paddle.Tensor) -> paddle.Tensor:
""" Export interface for c++ call, apply linear transform and log """ Export interface for c++ call, apply linear transform and log
softmax before ctc softmax before ctc
...@@ -667,7 +665,7 @@ class U2BaseModel(nn.Module): ...@@ -667,7 +665,7 @@ class U2BaseModel(nn.Module):
""" """
return self.ctc.log_softmax(xs) return self.ctc.log_softmax(xs)
@jit.export # @jit.export
def forward_attention_decoder( def forward_attention_decoder(
self, self,
hyps: paddle.Tensor, hyps: paddle.Tensor,
...@@ -683,13 +681,14 @@ class U2BaseModel(nn.Module): ...@@ -683,13 +681,14 @@ class U2BaseModel(nn.Module):
Returns: Returns:
paddle.Tensor: decoder output, (B, L) paddle.Tensor: decoder output, (B, L)
""" """
assert encoder_out.size(0) == 1 assert encoder_out.shape[0] == 1
num_hyps = hyps.size(0) num_hyps = hyps.shape[0]
assert hyps_lens.size(0) == num_hyps assert hyps_lens.shape[0] == num_hyps
encoder_out = encoder_out.repeat(num_hyps, 1, 1) # encoder_out = encoder_out.repeat(num_hyps, 1, 1)
encoder_out = encoder_out.tile([num_hyps, 1, 1])
# (B, 1, T) # (B, 1, T)
encoder_mask = paddle.ones( encoder_mask = paddle.ones(
[num_hyps, 1, encoder_out.size(1)], dtype=paddle.bool) [num_hyps, 1, encoder_out.shape[1]], dtype=paddle.bool)
# (num_hyps, max_hyps_len, vocab_size) # (num_hyps, max_hyps_len, vocab_size)
decoder_out, _ = self.decoder(encoder_out, encoder_mask, hyps, decoder_out, _ = self.decoder(encoder_out, encoder_mask, hyps,
hyps_lens) hyps_lens)
...@@ -744,7 +743,7 @@ class U2BaseModel(nn.Module): ...@@ -744,7 +743,7 @@ class U2BaseModel(nn.Module):
Returns: Returns:
List[List[int]]: transcripts. List[List[int]]: transcripts.
""" """
batch_size = feats.size(0) batch_size = feats.shape[0]
if decoding_method in ['ctc_prefix_beam_search', if decoding_method in ['ctc_prefix_beam_search',
'attention_rescoring'] and batch_size > 1: 'attention_rescoring'] and batch_size > 1:
logger.fatal( logger.fatal(
...@@ -772,7 +771,7 @@ class U2BaseModel(nn.Module): ...@@ -772,7 +771,7 @@ class U2BaseModel(nn.Module):
# result in List[int], change it to List[List[int]] for compatible # result in List[int], change it to List[List[int]] for compatible
# with other batch decoding mode # with other batch decoding mode
elif decoding_method == 'ctc_prefix_beam_search': elif decoding_method == 'ctc_prefix_beam_search':
assert feats.size(0) == 1 assert feats.shape[0] == 1
hyp = self.ctc_prefix_beam_search( hyp = self.ctc_prefix_beam_search(
feats, feats,
feats_lengths, feats_lengths,
...@@ -782,7 +781,7 @@ class U2BaseModel(nn.Module): ...@@ -782,7 +781,7 @@ class U2BaseModel(nn.Module):
simulate_streaming=simulate_streaming) simulate_streaming=simulate_streaming)
hyps = [hyp] hyps = [hyp]
elif decoding_method == 'attention_rescoring': elif decoding_method == 'attention_rescoring':
assert feats.size(0) == 1 assert feats.shape[0] == 1
hyp = self.attention_rescoring( hyp = self.attention_rescoring(
feats, feats,
feats_lengths, feats_lengths,
...@@ -922,7 +921,7 @@ class U2InferModel(U2Model): ...@@ -922,7 +921,7 @@ class U2InferModel(U2Model):
Returns: Returns:
List[List[int]]: best path result List[List[int]]: best path result
""" """
return self.ctc_greedy_search( return self.attention_rescoring(
feats, feats,
feats_lengths, feats_lengths,
decoding_chunk_size=decoding_chunk_size, decoding_chunk_size=decoding_chunk_size,
......
...@@ -70,7 +70,7 @@ class MultiHeadedAttention(nn.Layer): ...@@ -70,7 +70,7 @@ class MultiHeadedAttention(nn.Layer):
paddle.Tensor: Transformed value tensor, size paddle.Tensor: Transformed value tensor, size
(#batch, n_head, time2, d_k). (#batch, n_head, time2, d_k).
""" """
n_batch = query.size(0) n_batch = query.shape[0]
q = self.linear_q(query).view(n_batch, -1, self.h, self.d_k) q = self.linear_q(query).view(n_batch, -1, self.h, self.d_k)
k = self.linear_k(key).view(n_batch, -1, self.h, self.d_k) k = self.linear_k(key).view(n_batch, -1, self.h, self.d_k)
v = self.linear_v(value).view(n_batch, -1, self.h, self.d_k) v = self.linear_v(value).view(n_batch, -1, self.h, self.d_k)
...@@ -96,7 +96,7 @@ class MultiHeadedAttention(nn.Layer): ...@@ -96,7 +96,7 @@ class MultiHeadedAttention(nn.Layer):
paddle.Tensor: Transformed value weighted paddle.Tensor: Transformed value weighted
by the attention score, (#batch, time1, d_model). by the attention score, (#batch, time1, d_model).
""" """
n_batch = value.size(0) n_batch = value.shape[0]
if mask is not None: if mask is not None:
mask = mask.unsqueeze(1).eq(0) # (batch, 1, *, time2) mask = mask.unsqueeze(1).eq(0) # (batch, 1, *, time2)
scores = scores.masked_fill(mask, -float('inf')) scores = scores.masked_fill(mask, -float('inf'))
...@@ -172,15 +172,15 @@ class RelPositionMultiHeadedAttention(MultiHeadedAttention): ...@@ -172,15 +172,15 @@ class RelPositionMultiHeadedAttention(MultiHeadedAttention):
paddle.Tensor: Output tensor. (batch, head, time1, time1) paddle.Tensor: Output tensor. (batch, head, time1, time1)
""" """
zero_pad = paddle.zeros( zero_pad = paddle.zeros(
(x.size(0), x.size(1), x.size(2), 1), dtype=x.dtype) (x.shape[0], x.shape[1], x.shape[2], 1), dtype=x.dtype)
x_padded = paddle.cat([zero_pad, x], dim=-1) x_padded = paddle.cat([zero_pad, x], dim=-1)
x_padded = x_padded.view(x.size(0), x.size(1), x.size(3) + 1, x.size(2)) x_padded = x_padded.view(x.shape[0], x.shape[1], x.shape[3] + 1, x.shape[2])
x = x_padded[:, :, 1:].view_as(x) # [B, H, T1, T1] x = x_padded[:, :, 1:].view_as(x) # [B, H, T1, T1]
if zero_triu: if zero_triu:
ones = paddle.ones((x.size(2), x.size(3))) ones = paddle.ones((x.shape[2], x.shape[3]))
x = x * paddle.tril(ones, x.size(3) - x.size(2))[None, None, :, :] x = x * paddle.tril(ones, x.shape[3] - x.shape[2])[None, None, :, :]
return x return x
...@@ -205,7 +205,7 @@ class RelPositionMultiHeadedAttention(MultiHeadedAttention): ...@@ -205,7 +205,7 @@ class RelPositionMultiHeadedAttention(MultiHeadedAttention):
q, k, v = self.forward_qkv(query, key, value) q, k, v = self.forward_qkv(query, key, value)
q = q.transpose([0, 2, 1, 3]) # (batch, time1, head, d_k) q = q.transpose([0, 2, 1, 3]) # (batch, time1, head, d_k)
n_batch_pos = pos_emb.size(0) n_batch_pos = pos_emb.shape[0]
p = self.linear_pos(pos_emb).view(n_batch_pos, -1, self.h, self.d_k) p = self.linear_pos(pos_emb).view(n_batch_pos, -1, self.h, self.d_k)
p = p.transpose([0, 2, 1, 3]) # (batch, head, time1, d_k) p = p.transpose([0, 2, 1, 3]) # (batch, head, time1, d_k)
......
...@@ -122,7 +122,7 @@ class TransformerDecoder(nn.Module): ...@@ -122,7 +122,7 @@ class TransformerDecoder(nn.Module):
# tgt_mask: (B, 1, L) # tgt_mask: (B, 1, L)
tgt_mask = (make_non_pad_mask(ys_in_lens).unsqueeze(1)) tgt_mask = (make_non_pad_mask(ys_in_lens).unsqueeze(1))
# m: (1, L, L) # m: (1, L, L)
m = subsequent_mask(tgt_mask.size(-1)).unsqueeze(0) m = subsequent_mask(tgt_mask.shape[-1]).unsqueeze(0)
# tgt_mask: (B, L, L) # tgt_mask: (B, L, L)
# TODO(Hui Zhang): not support & for tensor # TODO(Hui Zhang): not support & for tensor
# tgt_mask = tgt_mask & m # tgt_mask = tgt_mask & m
......
...@@ -68,7 +68,7 @@ class PositionalEncoding(nn.Layer): ...@@ -68,7 +68,7 @@ class PositionalEncoding(nn.Layer):
paddle.Tensor: for compatibility to RelPositionalEncoding, (batch=1, time, ...) paddle.Tensor: for compatibility to RelPositionalEncoding, (batch=1, time, ...)
""" """
T = x.shape[1] T = x.shape[1]
assert offset + x.size(1) < self.max_len assert offset + x.shape[1] < self.max_len
#TODO(Hui Zhang): using T = x.size(1), __getitem__ not support Tensor #TODO(Hui Zhang): using T = x.size(1), __getitem__ not support Tensor
pos_emb = self.pe[:, offset:offset + T] pos_emb = self.pe[:, offset:offset + T]
x = x * self.xscale + pos_emb x = x * self.xscale + pos_emb
...@@ -114,7 +114,7 @@ class RelPositionalEncoding(PositionalEncoding): ...@@ -114,7 +114,7 @@ class RelPositionalEncoding(PositionalEncoding):
paddle.Tensor: Encoded tensor (batch, time, `*`). paddle.Tensor: Encoded tensor (batch, time, `*`).
paddle.Tensor: Positional embedding tensor (1, time, `*`). paddle.Tensor: Positional embedding tensor (1, time, `*`).
""" """
assert offset + x.size(1) < self.max_len assert offset + x.shape[1] < self.max_len
x = x * self.xscale x = x * self.xscale
#TODO(Hui Zhang): using x.size(1), __getitem__ not support Tensor #TODO(Hui Zhang): using x.size(1), __getitem__ not support Tensor
pos_emb = self.pe[:, offset:offset + x.shape[1]] pos_emb = self.pe[:, offset:offset + x.shape[1]]
......
...@@ -65,11 +65,11 @@ def pad_sequence(sequences: List[paddle.Tensor], ...@@ -65,11 +65,11 @@ def pad_sequence(sequences: List[paddle.Tensor],
# assuming trailing dimensions and type of all the Tensors # assuming trailing dimensions and type of all the Tensors
# in sequences are same and fetching those from sequences[0] # in sequences are same and fetching those from sequences[0]
max_size = sequences[0].size() max_size = sequences[0].shape
# (TODO Hui Zhang): slice not supprot `end==start` # (TODO Hui Zhang): slice not supprot `end==start`
# trailing_dims = max_size[1:] # trailing_dims = max_size[1:]
trailing_dims = max_size[1:] if max_size.ndim >= 2 else () trailing_dims = max_size[1:] if max_size.ndim >= 2 else ()
max_len = max([s.size(0) for s in sequences]) max_len = max([s.shape[0] for s in sequences])
if batch_first: if batch_first:
out_dims = (len(sequences), max_len) + trailing_dims out_dims = (len(sequences), max_len) + trailing_dims
else: else:
...@@ -77,7 +77,7 @@ def pad_sequence(sequences: List[paddle.Tensor], ...@@ -77,7 +77,7 @@ def pad_sequence(sequences: List[paddle.Tensor],
out_tensor = sequences[0].new_full(out_dims, padding_value) out_tensor = sequences[0].new_full(out_dims, padding_value)
for i, tensor in enumerate(sequences): for i, tensor in enumerate(sequences):
length = tensor.size(0) length = tensor.shape[0]
# use index notation to prevent duplicate references to the tensor # use index notation to prevent duplicate references to the tensor
if batch_first: if batch_first:
out_tensor[i, :length, ...] = tensor out_tensor[i, :length, ...] = tensor
...@@ -125,7 +125,7 @@ def add_sos_eos(ys_pad: paddle.Tensor, sos: int, eos: int, ...@@ -125,7 +125,7 @@ def add_sos_eos(ys_pad: paddle.Tensor, sos: int, eos: int,
#ys_in = [paddle.cat([_sos, y], dim=0) for y in ys] #ys_in = [paddle.cat([_sos, y], dim=0) for y in ys]
#ys_out = [paddle.cat([y, _eos], dim=0) for y in ys] #ys_out = [paddle.cat([y, _eos], dim=0) for y in ys]
#return pad_sequence(ys_in, padding_value=eos), pad_sequence(ys_out, padding_value=ignore_id) #return pad_sequence(ys_in, padding_value=eos), pad_sequence(ys_out, padding_value=ignore_id)
B = ys_pad.size(0) B = ys_pad.shape[0]
_sos = paddle.ones([B, 1], dtype=ys_pad.dtype) * sos _sos = paddle.ones([B, 1], dtype=ys_pad.dtype) * sos
_eos = paddle.ones([B, 1], dtype=ys_pad.dtype) * eos _eos = paddle.ones([B, 1], dtype=ys_pad.dtype) * eos
ys_in = paddle.cat([_sos, ys_pad], dim=1) ys_in = paddle.cat([_sos, ys_pad], dim=1)
...@@ -152,7 +152,7 @@ def th_accuracy(pad_outputs: paddle.Tensor, ...@@ -152,7 +152,7 @@ def th_accuracy(pad_outputs: paddle.Tensor,
float: Accuracy value (0.0 - 1.0). float: Accuracy value (0.0 - 1.0).
""" """
pad_pred = pad_outputs.view( pad_pred = pad_outputs.view(
pad_targets.size(0), pad_targets.size(1), pad_outputs.size(1)).argmax(2) pad_targets.shape[0], pad_targets.size(1), pad_outputs.size(1)).argmax(2)
mask = pad_targets != ignore_label mask = pad_targets != ignore_label
#TODO(Hui Zhang): sum not support bool type #TODO(Hui Zhang): sum not support bool type
# numerator = paddle.sum( # numerator = paddle.sum(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册