Commit 45663511 authored by Hui Zhang

add transformer lm and encoder score api

Parent c5f66921
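A minimal usage sketch of the new scoring API (illustrative only: the import path, vocab size and token ids below are placeholders, not part of this commit; the new file's `__main__` block exercises the same calls):

    import paddle
    # NOTE: hypothetical import path; use wherever the new TransformerLM module lives
    from deepspeech.lm.transformer import TransformerLM

    paddle.set_device("cpu")
    lm = TransformerLM(n_vocab=5002, pos_enc=None, embed_unit=128,
                       att_unit=512, head=8, unit=2048, layer=16)
    lm.eval()

    prefix = paddle.to_tensor([5])              # 1D int64 prefix tokens
    logp, state = lm.score(prefix, None, None)  # (n_vocab,) log-probs + new cache
    logp, state = lm.score(paddle.to_tensor([10]), state, None)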
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from typing import Any
from typing import List
from typing import Tuple
import numpy as np
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from deepspeech.modules.mask import subsequent_mask
from deepspeech.modules.encoder import TransformerEncoder
from deepspeech.decoders.scorers.scorer_interface import BatchScorerInterface
#LMInterface
class TransformerLM(nn.Layer, BatchScorerInterface):
    def __init__(
            self,
            n_vocab: int,
            pos_enc: str=None,
            embed_unit: int=128,
            att_unit: int=256,
            head: int=2,
            unit: int=1024,
            layer: int=4,
            dropout_rate: float=0.5,
            emb_dropout_rate: float=0.0,
            att_dropout_rate: float=0.0,
            tie_weights: bool=False, ):
        nn.Layer.__init__(self)

        if pos_enc == "sinusoidal":
            pos_enc_layer_type = "abs_pos"
        elif pos_enc is None:
            #TODO
            pos_enc_layer_type = "None"
        else:
            raise ValueError(f"unknown pos-enc option: {pos_enc}")

        self.embed = nn.Embedding(n_vocab, embed_unit)

        if emb_dropout_rate == 0.0:
            self.embed_drop = None
        else:
            self.embed_drop = nn.Dropout(emb_dropout_rate)

        self.encoder = TransformerEncoder(
            input_size=embed_unit,
            output_size=att_unit,
            attention_heads=head,
            linear_units=unit,
            num_blocks=layer,
            dropout_rate=dropout_rate,
            attention_dropout_rate=att_dropout_rate,
            input_layer="linear",
            pos_enc_layer_type=pos_enc_layer_type,
            concat_after=False,
            static_chunk_size=1,
            use_dynamic_chunk=False,
            use_dynamic_left_chunk=False)

        self.decoder = nn.Linear(att_unit, n_vocab)

        logging.info("Tie weights set to {}".format(tie_weights))
        logging.info("Dropout set to {}".format(dropout_rate))
        logging.info("Emb Dropout set to {}".format(emb_dropout_rate))
        logging.info("Att Dropout set to {}".format(att_dropout_rate))

        if tie_weights:
            assert (
                att_unit == embed_unit
            ), "Tie Weights: True need embedding and final dimensions to match"
            self.decoder.weight = self.embed.weight

    def _target_mask(self, ys_in_pad):
        # Combine the padding mask (non-zero tokens) with a causal mask,
        # yielding a (B, L, L) attention mask.
        ys_mask = ys_in_pad != 0
        m = subsequent_mask(ys_mask.shape[-1]).unsqueeze(0)
        return ys_mask.unsqueeze(-2) & m

    def forward(
            self, x: paddle.Tensor, xlens: paddle.Tensor, t: paddle.Tensor
    ) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor]:
        """Compute LM loss value from buffer sequences.

        Args:
            x (paddle.Tensor): Input ids. (batch, len)
            xlens (paddle.Tensor): Input lengths. (batch,) Not used here; the
                lengths are recomputed from the padding mask of `x`.
            t (paddle.Tensor): Target ids. (batch, len)

        Returns:
            tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor]: Tuple of
                loss to backward (scalar),
                negative log-likelihood of t: -log p(t) (scalar) and
                the number of elements in x (scalar)

        Notes:
            The last two return values are used
            in perplexity: p(t)^{-1/n} = exp(-log p(t) / n)
        """
        xm = x != 0

        if self.embed_drop is not None:
            emb = self.embed_drop(self.embed(x))
        else:
            emb = self.embed(x)

        xlen = xm.sum(axis=1)
        h, _ = self.encoder(emb, xlen)
        y = self.decoder(h)

        loss = F.cross_entropy(
            y.reshape([-1, y.shape[-1]]), t.reshape([-1]), reduction="none")
        mask = xm.astype(loss.dtype)
        logp = loss * mask.reshape([-1])
        logp = logp.sum()
        count = mask.sum()
        return logp / count, logp, count
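
    # Editorial note (illustrative): corpus perplexity can be recovered from the
    # three return values, e.g.
    #   avg_nll, nll, count = lm(x, xlens, t)
    #   ppl = float(paddle.exp(nll / count))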

    # beam search API (see ScorerInterface)
    def score(self, y: paddle.Tensor, state: Any,
              x: paddle.Tensor) -> Tuple[paddle.Tensor, Any]:
        """Score new token.

        Args:
            y (paddle.Tensor): 1D paddle.int64 prefix tokens.
            state: Scorer state for prefix tokens
            x (paddle.Tensor): encoder feature that generates ys.

        Returns:
            tuple[paddle.Tensor, Any]: Tuple of
                paddle.float32 scores for next token (n_vocab)
                and next state for ys
        """
        y = y.unsqueeze(0)

        if self.embed_drop is not None:
            emb = self.embed_drop(self.embed(y))
        else:
            emb = self.embed(y)

        h, _, cache = self.encoder.forward_one_step(
            emb, self._target_mask(y), cache=state)
        h = self.decoder(h[:, -1])
        logp = F.log_softmax(h, axis=-1).squeeze(0)
        return logp, cache

    # batch beam search API (see BatchScorerInterface)
    def batch_score(self, ys: paddle.Tensor, states: List[Any],
                    xs: paddle.Tensor) -> Tuple[paddle.Tensor, List[Any]]:
        """Score new token batch (required).

        Args:
            ys (paddle.Tensor): paddle.int64 prefix tokens (n_batch, ylen).
            states (List[Any]): Scorer states for prefix tokens.
            xs (paddle.Tensor):
                The encoder feature that generates ys (n_batch, xlen, n_feat).

        Returns:
            tuple[paddle.Tensor, List[Any]]: Tuple of
                batchified scores for next token with shape of `(n_batch, n_vocab)`
                and next state list for ys.
        """
        # merge states
        n_batch = len(ys)
        n_layers = len(self.encoder.encoders)
        if states[0] is None:
            batch_state = None
        else:
            # transpose state of [batch, layer] into [layer, batch]
            batch_state = [
                paddle.stack([states[b][i] for b in range(n_batch)])
                for i in range(n_layers)
            ]

        if self.embed_drop is not None:
            emb = self.embed_drop(self.embed(ys))
        else:
            emb = self.embed(ys)

        # batch decoding
        h, _, states = self.encoder.forward_one_step(
            emb, self._target_mask(ys), cache=batch_state)
        h = self.decoder(h[:, -1])
        logp = F.log_softmax(h, axis=-1)

        # transpose state of [layer, batch] into [batch, layer]
        state_list = [[states[i][b] for i in range(n_layers)]
                      for b in range(n_batch)]
        return logp, state_list


if __name__ == "__main__":
    tlm = TransformerLM(
        n_vocab=5002,
        pos_enc=None,
        embed_unit=128,
        att_unit=512,
        head=8,
        unit=2048,
        layer=16,
        dropout_rate=0.5, )
    # Constructor signature, for reference:
    # n_vocab: int,
    # pos_enc: str=None,
    # embed_unit: int=128,
    # att_unit: int=256,
    # head: int=2,
    # unit: int=1024,
    # layer: int=4,
    # dropout_rate: float=0.5,
    # emb_dropout_rate: float=0.0,
    # att_dropout_rate: float=0.0,
    # tie_weights: bool=False,):

    paddle.set_device("cpu")
    model_dict = paddle.load("transformerLM.pdparams")
    tlm.set_state_dict(model_dict)
    tlm.eval()

    # Test the score API
    input2 = np.array([5])
    input2 = paddle.to_tensor(input2)
    state = None  # no cache yet for the first call
    output, state = tlm.score(input2, state, None)

    input3 = np.array([10])
    input3 = paddle.to_tensor(input3)
    output, state = tlm.score(input3, state, None)

    input4 = np.array([0])
    input4 = paddle.to_tensor(input4)
    output, state = tlm.score(input4, state, None)
    print("output", output)

    """
    # Test the batch score API
    batch_size = 2
    inp2 = np.array([[5], [10]])
    inp2 = paddle.to_tensor(inp2)
    output, states = tlm.batch_score(
        inp2, [None] * batch_size, None)

    inp3 = np.array([[100], [30]])
    inp3 = paddle.to_tensor(inp3)
    output, states = tlm.batch_score(inp3, states, None)
    print("output", output)
    #print("cache", cache)
    #np.save("output_pd.npy", output)
    """
\ No newline at end of file
@@ -31,6 +31,7 @@ from deepspeech.modules.encoder_layer import TransformerEncoderLayer
from deepspeech.modules.mask import add_optional_chunk_mask
from deepspeech.modules.mask import make_non_pad_mask
from deepspeech.modules.positionwise_feed_forward import PositionwiseFeedForward
from deepspeech.modules.subsampling import Conv2dSubsampling
from deepspeech.modules.subsampling import Conv2dSubsampling4
from deepspeech.modules.subsampling import Conv2dSubsampling6
from deepspeech.modules.subsampling import Conv2dSubsampling8
@@ -370,6 +371,46 @@ class TransformerEncoder(BaseEncoder):
                concat_after=concat_after) for _ in range(num_blocks)
        ])

    def forward_one_step(
            self,
            xs: paddle.Tensor,
            masks: paddle.Tensor,
            cache=None,
    ) -> Tuple[paddle.Tensor, paddle.Tensor, List[paddle.Tensor]]:
        """Encode input frame.

        Args:
            xs (paddle.Tensor): Input tensor. (B, T, D)
            masks (paddle.Tensor): Mask tensor. (B, 1, T)
            cache (List[paddle.Tensor]): List of cache tensors.

        Returns:
            paddle.Tensor: Output tensor.
            paddle.Tensor: Mask tensor.
            List[paddle.Tensor]: List of new cache tensors.
        """
        if self.global_cmvn is not None:
            xs = self.global_cmvn(xs)

        if isinstance(self.embed, Conv2dSubsampling):
            # xs, masks = self.embed(xs, masks)
            # TODO(Hui Zhang): self.embed(xs, masks, offset=0),
            #   stride_slice not support bool tensor
            xs, pos_emb, masks = self.embed(xs, masks.astype(xs.dtype), offset=0)
        else:
            xs = self.embed(xs)

        # TODO(Hui Zhang): remove masks.astype, stride_slice not support bool tensor
        masks = masks.astype(paddle.bool)

        if cache is None:
            cache = [None for _ in range(len(self.encoders))]
        new_cache = []
        for c, e in zip(cache, self.encoders):
            xs, masks, _ = e(xs, masks, output_cache=c)
            new_cache.append(xs)

        if self.normalize_before:
            xs = self.after_norm(xs)
        return xs, masks, new_cache
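
    # Editorial note (illustrative): `new_cache` collects each layer's output for
    # the positions processed so far; passing it back as `cache` on the next call
    # (as TransformerLM.score()/batch_score() do) lets the encoder extend a
    # prefix incrementally via each layer's `output_cache` argument.
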
class ConformerEncoder(BaseEncoder):
"""Conformer encoder module."""
......
@@ -71,7 +71,7 @@ class TransformerEncoderLayer(nn.Layer):
            self,
            x: paddle.Tensor,
            mask: paddle.Tensor,
            pos_emb: paddle.Tensor,
            pos_emb: Optional[paddle.Tensor]=None,
            mask_pad: Optional[paddle.Tensor]=None,
            output_cache: Optional[paddle.Tensor]=None,
            cnn_cache: Optional[paddle.Tensor]=None,
@@ -82,8 +82,8 @@ class TransformerEncoderLayer(nn.Layer):
            mask (paddle.Tensor): Mask tensor for the input (#batch, time).
            pos_emb (paddle.Tensor): just for interface compatibility
                to ConformerEncoderLayer
            mask_pad (paddle.Tensor): does not used in transformer layer,
                just for unified api with conformer.
            mask_pad (paddle.Tensor): not used here, it's for interface
                compatibility to ConformerEncoderLayer
            output_cache (paddle.Tensor): Cache tensor of the output
                (#batch, time2, size), time2 < time in x.
            cnn_cache (paddle.Tensor): not used here, it's for interface
......
@@ -82,8 +82,11 @@ class LinearNoSubsampling(BaseSubsampling):
        x, pos_emb = self.pos_enc(x, offset)
        return x, pos_emb, x_mask
class Conv2dSubsampling(BaseSubsampling):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

class Conv2dSubsampling4(BaseSubsampling):
class Conv2dSubsampling4(Conv2dSubsampling):
    """Convolutional 2D subsampling (to 1/4 length)."""
    def __init__(self,
@@ -134,7 +137,7 @@ class Conv2dSubsampling4(BaseSubsampling):
        return x, pos_emb, x_mask[:, :, :-2:2][:, :, :-2:2]
class Conv2dSubsampling6(BaseSubsampling):
class Conv2dSubsampling6(Conv2dSubsampling):
    """Convolutional 2D subsampling (to 1/6 length)."""
    def __init__(self,
@@ -187,7 +190,7 @@ class Conv2dSubsampling6(BaseSubsampling):
        return x, pos_emb, x_mask[:, :, :-2:2][:, :, :-4:3]
class Conv2dSubsampling8(BaseSubsampling):
class Conv2dSubsampling8(Conv2dSubsampling):
    """Convolutional 2D subsampling (to 1/8 length)."""
    def __init__(self,
......