Commit 1fb23815 authored by: Hui Zhang

add export function for decoder

Parent 660efceb
@@ -610,10 +610,11 @@ class U2Tester(U2Trainer):
input_spec=[
paddle.static.InputSpec(
shape=[1, decoder_max_time], dtype='int32'), # tgt
paddle.static.InputSpec(shape=[1], dtype='int32'), # tgt_len
paddle.static.InputSpec(
shape=[1, decoder_max_time], dtype='bool'), # tgt_mask
paddle.static.InputSpec(
shape=[1, encoder_max_time, encoder_model_size],
-                        dtype='int32'),  # encoder_out
+                        dtype='float32'),  # encoder_out
])
logger.info(f"Export code: {static_model}")
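The hunk above corrects the `encoder_out` InputSpec dtype from int32 to float32, matching what the encoder actually emits. As a minimal sketch of how such an InputSpec list drives export — the sizes, model handle, and save path below are illustrative assumptions, not the repo's configuration:

import paddle

# Hypothetical sizes; the real values come from the model config.
decoder_max_time = 100    # max target token length
encoder_max_time = 200    # max encoder frames
encoder_model_size = 256  # encoder output dimension

input_spec = [
    paddle.static.InputSpec(shape=[1, decoder_max_time], dtype='int32'),  # tgt
    paddle.static.InputSpec(shape=[1], dtype='int32'),  # tgt_len
    paddle.static.InputSpec(shape=[1, decoder_max_time], dtype='bool'),  # tgt_mask
    paddle.static.InputSpec(
        shape=[1, encoder_max_time, encoder_model_size],
        dtype='float32'),  # encoder_out
]
# static_model = paddle.jit.to_static(model.forward, input_spec=input_spec)
# paddle.jit.save(static_model, 'export/decoder')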
@@ -927,3 +927,32 @@ class U2InferModel(U2Model):
decoding_chunk_size=decoding_chunk_size,
num_decoding_left_chunks=num_decoding_left_chunks,
simulate_streaming=simulate_streaming)
def forward_attention_decoder(
self,
hyps: paddle.Tensor,
hyps_masks: paddle.Tensor,
encoder_out: paddle.Tensor, ) -> paddle.Tensor:
""" Export interface for c++ call, forward decoder with multiple
hypothesis from ctc prefix beam search and one encoder output
Args:
hyps (paddle.Tensor): hyps from ctc prefix beam search, already
pad sos at the begining, (B, U)
hyps_masks (paddle.Tensor): length of each hyp in hyps, (B, U)
encoder_out (paddle.Tensor): corresponding encoder output, (B=1, T, D)
Returns:
paddle.Tensor: decoder output, (B, V)
"""
assert encoder_out.shape[0] == 1
num_hyps = hyps.shape[0]
assert hyps_masks.shape[0] == num_hyps
        # torch-style `encoder_out.repeat(num_hyps, 1, 1)`; paddle uses tile
encoder_out = encoder_out.tile([num_hyps, 1, 1])
# (B, 1, T)
encoder_mask = paddle.ones(
[num_hyps, 1, encoder_out.shape[1]], dtype=paddle.bool)
# (num_hyps, max_hyps_len, vocab_size)
decoder_out, _ = self.decoder.export(encoder_out, encoder_mask, hyps,
hyps_masks)
        decoder_out = paddle.nn.functional.log_softmax(decoder_out, axis=-1)
return decoder_out
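A hypothetical smoke test for `forward_attention_decoder`; the model handle, sizes, and dtypes are assumptions for illustration, not the repo's test code:

import paddle

num_hyps, max_hyp_len, enc_time, enc_dim = 4, 10, 50, 256

# hypotheses from CTC prefix beam search, already sos-padded
hyps = paddle.zeros([num_hyps, max_hyp_len], dtype='int64')
# True at positions that hold a real token
hyps_masks = paddle.ones([num_hyps, max_hyp_len], dtype='bool')
# encoder output for a single utterance (B=1)
encoder_out = paddle.randn([1, enc_time, enc_dim])

# decoder_out = model.forward_attention_decoder(hyps, hyps_masks, encoder_out)
# expected: (num_hyps, max_hyp_len, vocab_size) log probabilities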
@@ -185,3 +185,45 @@ class TransformerDecoder(nn.Module):
            y = paddle.nn.functional.log_softmax(self.output_layer(y), axis=-1)
return y, new_cache
def export(
self,
memory: paddle.Tensor,
memory_mask: paddle.Tensor,
ys_in_pad: paddle.Tensor,
ys_in_mask: paddle.Tensor, ) -> Tuple[paddle.Tensor, paddle.Tensor]:
"""Forward decoder.
Args:
memory: encoded memory, float32 (batch, maxlen_in, feat)
memory_mask: encoder memory mask, (batch, 1, maxlen_in)
ys_in_pad: padded input token ids, int64 (batch, maxlen_out)
ys_in_mask: input mask of this batch (batch, maxlen_out)
Returns:
(tuple): tuple containing:
x: decoded token score before softmax (batch, maxlen_out, vocab_size)
if use_output_layer is True,
olens: (batch, )
"""
tgt = ys_in_pad
# tgt_mask: (B, 1, L)
tgt_mask = ys_in_mask.unsqueeze(1)
# m: (1, L, L)
m = subsequent_mask(tgt_mask.shape[-1]).unsqueeze(0)
# tgt_mask: (B, L, L)
        # TODO(Hui Zhang): paddle does not support `&` on bool tensors
        # tgt_mask = tgt_mask & m
        tgt_mask = tgt_mask.logical_and(m)
x, _ = self.embed(tgt)
for layer in self.decoders:
x, tgt_mask, memory, memory_mask = layer(x, tgt_mask, memory,
memory_mask)
if self.normalize_before:
x = self.after_norm(x)
if self.use_output_layer:
x = self.output_layer(x)
        # TODO(Hui Zhang): reduce_sum does not support bool type
        # olens = tgt_mask.sum(1)
        olens = tgt_mask.astype(paddle.int64).sum(1)
return x, olens
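To make the mask arithmetic in `export` concrete, here is a standalone sketch of the padding-mask/causal-mask combination; `subsequent_mask` below is a local stand-in for the module's helper, written under the usual lower-triangular definition:

import paddle

def subsequent_mask(size: int) -> paddle.Tensor:
    # lower-triangular (L, L) bool matrix: step i may attend to steps <= i
    return paddle.tril(paddle.ones([size, size], dtype='int32')).astype('bool')

L = 5
ys_in_mask = paddle.to_tensor([[True, True, True, False, False]])  # (B=1, L)
tgt_mask = ys_in_mask.unsqueeze(1)   # (B, 1, L) padding mask
m = subsequent_mask(L).unsqueeze(0)  # (1, L, L) causal mask
tgt_mask = tgt_mask.logical_and(m)   # (B, L, L), broadcast AND
# row i marks which key positions step i may attend to
print(tgt_mask.astype('int64'))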