...
 
Commits (12)

- [03e9ea9e](https://gitcode.net/paddlepaddle/DeepSpeech/-/commit/03e9ea9e52d61ffd4420bf9bfdc2f422752ad29c) add roformer (2023-07-12T08:59:00+00:00, Hui Zhang <zhtclz@foxmail.com>)
- [55870ffb](https://gitcode.net/paddlepaddle/DeepSpeech/-/commit/55870ffbb3581af4a0b7aed61a168f80f0f045fb) fix bugs (2023-07-12T09:36:13+00:00, Hui Zhang <zhtclz@foxmail.com>)
- [3b6b6807](https://gitcode.net/paddlepaddle/DeepSpeech/-/commit/3b6b680771c454151f5ac99013bbc934e967f703) add roformer result (2023-07-12T11:24:48+00:00, Hui Zhang <zhtclz@foxmail.com>)
- [b91b1c9b](https://gitcode.net/paddlepaddle/DeepSpeech/-/commit/b91b1c9b083002fb716c60c29adf3a20e51262e1) support position interpolation for longer attention context window length (2023-07-13T03:58:31+00:00, Hui Zhang <zhtclz@foxmail.com>)
- [b56fb85c](https://gitcode.net/paddlepaddle/DeepSpeech/-/commit/b56fb85ca08e7a24b9fc9f8859d9c5a472b553fa) RoPE with position interpolation (2023-07-14T04:45:24+00:00, Hui Zhang <zhtclz@foxmail.com>)
- [0a5cc555](https://gitcode.net/paddlepaddle/DeepSpeech/-/commit/0a5cc5556e602a202304860d5221f9b573582196) rope for streaming decoding (2023-07-14T07:37:55+00:00, Hui Zhang <zhtclz@foxmail.com>)
- [596f7140](https://gitcode.net/paddlepaddle/DeepSpeech/-/commit/596f71407cd0daa8f1d3e1edd60ce44300d37413) update result (2023-07-17T02:49:07+00:00, Hui Zhang <zhtclz@foxmail.com>)
- [d94db47f](https://gitcode.net/paddlepaddle/DeepSpeech/-/commit/d94db47f784d30b4c8b07c5f2a44c82cc4c7f24f) fix rotary embedding (2023-07-17T03:13:24+00:00, Hui Zhang <zhtclz@foxmail.com>)
- [897dcc37](https://gitcode.net/paddlepaddle/DeepSpeech/-/commit/897dcc37e65fa3260d72045afd58f79741240191) Merge pull request #3407 from zh794390558/roformer (Roformer) (2023-07-20T10:44:00+08:00, Hui Zhang <zhtclz@foxmail.com>)
- [5d10d6e8](https://gitcode.net/paddlepaddle/DeepSpeech/-/commit/5d10d6e884d0d7eccf4f1724c3ee64bf70a25aaa) Update README.md (2023-07-21T16:51:44+08:00, Hui Zhang <zhtclz@foxmail.com>)
- [2faa49a3](https://gitcode.net/paddlepaddle/DeepSpeech/-/commit/2faa49a39fcd810a2c896f61006c6e9958a5e85c) fix weight decay (2023-07-25T02:38:36+00:00, Hui Zhang <zhtclz@foxmail.com>)
- [a1745944](https://gitcode.net/paddlepaddle/DeepSpeech/-/commit/a1745944657feaaae2fe22201aedd6d42b0a536e) Merge pull request #3424 from zh794390558/fix_opt (fix weight decay) (2023-07-26T09:55:52+08:00, Hui Zhang <zhtclz@foxmail.com>)
......@@ -893,10 +893,6 @@ The Text-to-Speech module is originally called [Parakeet](https://github.com/Pad
- **[VTuberTalk](https://github.com/jerryuhoo/VTuberTalk): Use PaddleSpeech TTS and ASR to clone voice from videos.**
<div align="center">
<img src="https://raw.githubusercontent.com/jerryuhoo/VTuberTalk/main/gui/gui.png" width = "500px" />
</div>
## Citation
......
# Aishell
## Conformer
paddle version: 2.2.2
paddlespeech version: 1.0.1
| Model | Params | Config | Augmentation| Test set | Decode method | Loss | CER |
| --- | --- | --- | --- | --- | --- | --- | --- |
| conformer | 47.07M | conf/conformer.yaml | spec_aug | test | attention | - | 0.0522 |
| conformer | 47.07M | conf/conformer.yaml | spec_aug | test | ctc_greedy_search | - | 0.0481 |
| conformer | 47.07M | conf/conformer.yaml | spec_aug| test | ctc_prefix_beam_search | - | 0.0480 |
| conformer | 47.07M | conf/conformer.yaml | spec_aug | test | attention_rescoring | - | 0.0460 |
## RoFormer Streaming
paddle version: 2.5.0
paddlespeech version: 1.5.0
Tesla V100-SXM2-32GB: 1 node, 4 cards
Global batch size: 32 * 4
Training time: 1 day, 12:56:39.639646
### `decoding.decoding_chunk_size=16`
> chunk_size=16, so each chunk covers ((16 - 1) * 4 + 7) * 10ms = (16 * 4 + 3) * 10ms = 670ms of audio.
| Model | Params | Config | Augmentation| Test set | Decode method | Chunk Size & Left Chunks | Loss | CER |
| --- | --- | --- | --- | --- | --- | --- | --- | --- |
| roformer | 44.80M | conf/chunk_roformer.yaml | spec_aug | test | attention | 16, -1 | - | 5.63 |
| roformer | 44.80M | conf/chunk_roformer.yaml | spec_aug | test | ctc_greedy_search | 16, -1 | - | 6.13 |
| roformer | 44.80M | conf/chunk_roformer.yaml | spec_aug | test | ctc_prefix_beam_search | 16, -1 | - | 6.13 |
| roformer | 44.80M | conf/chunk_roformer.yaml | spec_aug | test | attention_rescoring | 16, -1 | - | 5.44 |
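As a quick sanity check on the 670ms figure above, a minimal sketch (assuming the usual conv2d subsampling rate of 4, a right context of 7 frames, and a 10ms frame shift):

```python
def chunk_audio_ms(chunk_size: int, subsampling: int = 4,
                   context: int = 7, frame_shift_ms: int = 10) -> int:
    """Audio covered by one decoding chunk, in milliseconds.

    A chunk of `chunk_size` subsampled frames maps back to
    (chunk_size - 1) * subsampling + context raw frames.
    """
    return ((chunk_size - 1) * subsampling + context) * frame_shift_ms

print(chunk_audio_ms(16))  # 670
```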
### `decoding.decoding_chunk_size=-1`
| Model | Params | Config | Augmentation| Test set | Decode method | Chunk Size & Left Chunks | Loss | CER |
| --- | --- | --- | --- | --- | --- | --- | --- | --- |
| roformer | 44.80M | conf/chunk_roformer.yaml | spec_aug | test | attention | -1, -1 | - | 5.39 |
| roformer | 44.80M | conf/chunk_roformer.yaml | spec_aug | test | ctc_greedy_search | -1, -1 | - | 5.51 |
| roformer | 44.80M | conf/chunk_roformer.yaml | spec_aug | test | ctc_prefix_beam_search | -1, -1 | - | 5.51 |
| roformer | 44.80M | conf/chunk_roformer.yaml | spec_aug | test | attention_rescoring | -1, -1 | - | 4.99 |
## Conformer Streaming
......@@ -24,6 +41,17 @@ Need to set `decoding.decoding_chunk_size=16` when decoding.
| conformer | 47.06M | conf/chunk_conformer.yaml | spec_aug | test | attention_rescoring | 16, -1 | - | 0.051968 |
## Conformer
paddle version: 2.2.2
paddlespeech version: 1.0.1
| Model | Params | Config | Augmentation| Test set | Decode method | Loss | CER |
| --- | --- | --- | --- | --- | --- | --- | --- |
| conformer | 47.07M | conf/conformer.yaml | spec_aug | test | attention | - | 0.0522 |
| conformer | 47.07M | conf/conformer.yaml | spec_aug | test | ctc_greedy_search | - | 0.0481 |
| conformer | 47.07M | conf/conformer.yaml | spec_aug | test | ctc_prefix_beam_search | - | 0.0480 |
| conformer | 47.07M | conf/conformer.yaml | spec_aug | test | attention_rescoring | - | 0.0460 |
## Transformer
| Model | Params | Config | Augmentation| Test set | Decode method | Loss | CER |
......
############################################
# Network Architecture #
############################################
cmvn_file:
cmvn_file_type: "json"
# encoder related
encoder: conformer
encoder_conf:
    output_size: 256    # dimension of attention
    attention_heads: 4
    linear_units: 2048  # the number of units of position-wise feed forward
    num_blocks: 12      # the number of encoder blocks
    dropout_rate: 0.1   # sublayer output dropout
    positional_dropout_rate: 0.1
    attention_dropout_rate: 0.0
    input_layer: conv2d # encoder input type, you can choose conv2d, conv2d6 and conv2d8
    normalize_before: True
    cnn_module_kernel: 15
    use_cnn_module: True
    activation_type: 'swish'
    pos_enc_layer_type: 'rope_pos' # abs_pos, rel_pos, rope_pos
    selfattention_layer_type: 'rel_selfattn' # unused
    causal: true
    use_dynamic_chunk: true
    cnn_module_norm: 'layer_norm' # using nn.LayerNorm makes model converge faster
    use_dynamic_left_chunk: false

# decoder related
decoder: transformer # transformer, bitransformer
decoder_conf:
    attention_heads: 4
    linear_units: 2048
    num_blocks: 6
    r_num_blocks: 0     # only for bitransformer
    dropout_rate: 0.1   # sublayer output dropout
    positional_dropout_rate: 0.1
    self_attention_dropout_rate: 0.0
    src_attention_dropout_rate: 0.0

# hybrid CTC/attention
model_conf:
    ctc_weight: 0.3
    lsm_weight: 0.1     # label smoothing option
    reverse_weight: 0.0 # only for bitransformer
    length_normalized_loss: false
    init_type: 'kaiming_uniform' # !Warning: needed for convergence
###########################################
# Data #
###########################################
train_manifest: data/manifest.train
dev_manifest: data/manifest.dev
test_manifest: data/manifest.test
###########################################
# Dataloader #
###########################################
vocab_filepath: data/lang_char/vocab.txt
spm_model_prefix: ''
unit_type: 'char'
preprocess_config: conf/preprocess.yaml
feat_dim: 80
stride_ms: 10.0
window_ms: 25.0
sortagrad: 0 # Feed samples from shortest to longest ; -1: enabled for all epochs, 0: disabled, other: enabled for 'other' epochs
batch_size: 32
maxlen_in: 512 # if input length > maxlen-in, batchsize is automatically reduced
maxlen_out: 150 # if output length > maxlen-out, batchsize is automatically reduced
minibatches: 0 # for debug
batch_count: auto
batch_bins: 0
batch_frames_in: 0
batch_frames_out: 0
batch_frames_inout: 0
num_workers: 2
subsampling_factor: 1
num_encs: 1
###########################################
# Training #
###########################################
n_epoch: 240
accum_grad: 1
global_grad_clip: 5.0
dist_sampler: True
optim: adam
optim_conf:
    lr: 0.001
    weight_decay: 1.0e-6
scheduler: warmuplr
scheduler_conf:
    warmup_steps: 25000
    lr_decay: 1.0
log_interval: 100
checkpoint:
    kbest_n: 50
    latest_n: 5
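For reference, a minimal sketch of how a config in this shape is consumed, assuming PyYAML (the real entry point is the PaddleSpeech training script; the path is illustrative):

```python
import yaml  # PyYAML, assumed available

with open("conf/chunk_roformer.yaml") as f:
    configs = yaml.safe_load(f)

# Nested blocks become plain dicts that are splatted into module
# constructors, e.g. ConformerEncoder(**configs["encoder_conf"]).
print(configs["encoder_conf"]["pos_enc_layer_type"])  # 'rope_pos'
print(configs["model_conf"]["ctc_weight"])            # 0.3
```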
############################################
# Network Architecture #
############################################
cmvn_file:
cmvn_file_type: "json"
# encoder related
encoder: conformer
encoder_conf:
    output_size: 256    # dimension of attention
    attention_heads: 4
    linear_units: 2048  # the number of units of position-wise feed forward
    num_blocks: 12      # the number of encoder blocks
    dropout_rate: 0.1   # sublayer output dropout
    positional_dropout_rate: 0.1
    attention_dropout_rate: 0.0
    input_layer: conv2d # encoder input type, you can choose conv2d, conv2d6 and conv2d8
    normalize_before: True
    cnn_module_kernel: 15
    use_cnn_module: True
    activation_type: 'swish'
    pos_enc_layer_type: 'rope_pos' # abs_pos, rel_pos, rope_pos
    selfattention_layer_type: 'rel_selfattn' # unused
    causal: true
    use_dynamic_chunk: true
    cnn_module_norm: 'layer_norm' # using nn.LayerNorm makes model converge faster
    use_dynamic_left_chunk: false

# decoder related
decoder: bitransformer # transformer, bitransformer
decoder_conf:
    attention_heads: 4
    linear_units: 2048
    num_blocks: 3
    r_num_blocks: 3     # only for bitransformer
    dropout_rate: 0.1   # sublayer output dropout
    positional_dropout_rate: 0.1
    self_attention_dropout_rate: 0.0
    src_attention_dropout_rate: 0.0

# hybrid CTC/attention
model_conf:
    ctc_weight: 0.3
    lsm_weight: 0.1     # label smoothing option
    reverse_weight: 0.3 # only for bitransformer
    length_normalized_loss: false
    init_type: 'kaiming_uniform' # !Warning: needed for convergence
###########################################
# Data #
###########################################
train_manifest: data/manifest.train
dev_manifest: data/manifest.dev
test_manifest: data/manifest.test
###########################################
# Dataloader #
###########################################
vocab_filepath: data/lang_char/vocab.txt
spm_model_prefix: ''
unit_type: 'char'
preprocess_config: conf/preprocess.yaml
feat_dim: 80
stride_ms: 10.0
window_ms: 25.0
sortagrad: 0 # Feed samples from shortest to longest ; -1: enabled for all epochs, 0: disabled, other: enabled for 'other' epochs
batch_size: 32
maxlen_in: 512 # if input length > maxlen-in, batchsize is automatically reduced
maxlen_out: 150 # if output length > maxlen-out, batchsize is automatically reduced
minibatches: 0 # for debug
batch_count: auto
batch_bins: 0
batch_frames_in: 0
batch_frames_out: 0
batch_frames_inout: 0
num_workers: 2
subsampling_factor: 1
num_encs: 1
###########################################
# Training #
###########################################
n_epoch: 240
accum_grad: 1
global_grad_clip: 5.0
dist_sampler: True
optim: adam
optim_conf:
    lr: 0.001
    weight_decay: 1.0e-6
scheduler: warmuplr
scheduler_conf:
    warmup_steps: 25000
    lr_decay: 1.0
log_interval: 100
checkpoint:
    kbest_n: 50
    latest_n: 5
......@@ -20,30 +20,6 @@ import numpy as np
import paddle
def define_argparse():
    parser = argparse.ArgumentParser(description='average model')
    parser.add_argument('--dst_model', required=True, help='averaged model')
    parser.add_argument(
        '--ckpt_dir', required=True, help='ckpt model dir for average')
    parser.add_argument(
        '--val_best', action="store_true", help='averaged model')
    parser.add_argument(
        '--num', default=5, type=int, help='nums for averaged model')
    parser.add_argument(
        '--min_epoch',
        default=0,
        type=int,
        help='min epoch used for averaging model')
    parser.add_argument(
        '--max_epoch',
        default=65536,  # Big enough
        type=int,
        help='max epoch used for averaging model')

    args = parser.parse_args()
    return args

def average_checkpoints(dst_model="",
ckpt_dir="",
val_best=True,
......@@ -85,7 +61,7 @@ def average_checkpoints(dst_model="",
    print(path_list)

    avg = None
-    num = args.num
+    num = num
    assert num == len(path_list)
    for path in path_list:
        print(f'Processing {path}')
......@@ -100,14 +76,14 @@ def average_checkpoints(dst_model="",
        if avg[k] is not None:
            avg[k] /= num
-    paddle.save(avg, args.dst_model)
-    print(f'Saving to {args.dst_model}')
+    paddle.save(avg, dst_model)
+    print(f'Saving to {dst_model}')

-    meta_path = os.path.splitext(args.dst_model)[0] + '.avg.json'
+    meta_path = os.path.splitext(dst_model)[0] + '.avg.json'
    with open(meta_path, 'w') as f:
        data = json.dumps({
-            "mode": 'val_best' if args.val_best else 'latest',
-            "avg_ckpt": args.dst_model,
+            "mode": 'val_best' if val_best else 'latest',
+            "avg_ckpt": dst_model,
            "val_loss_mean": avg_val_score,
            "ckpts": path_list,
            "epochs": selected_epochs.tolist(),
......@@ -116,9 +92,40 @@ def average_checkpoints(dst_model="",
f.write(data + "\n")
def define_argparse():
    parser = argparse.ArgumentParser(description='average model')
    parser.add_argument('--dst_model', required=True, help='averaged model')
    parser.add_argument(
        '--ckpt_dir', required=True, help='ckpt model dir for average')
    parser.add_argument(
        '--val_best', action="store_true", help='averaged model')
    parser.add_argument(
        '--num', default=5, type=int, help='nums for averaged model')
    parser.add_argument(
        '--min_epoch',
        default=0,
        type=int,
        help='min epoch used for averaging model')
    parser.add_argument(
        '--max_epoch',
        default=65536,  # Big enough
        type=int,
        help='max epoch used for averaging model')

    args = parser.parse_args()
    print(args)
    return args

def main():
    args = define_argparse()
-    average_checkpoints(args)
+    average_checkpoints(
+        dst_model=args.dst_model,
+        ckpt_dir=args.ckpt_dir,
+        val_best=args.val_best,
+        num=args.num,
+        min_epoch=args.min_epoch,
+        max_epoch=args.max_epoch)
if __name__ == '__main__':
......
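After this refactor, `average_checkpoints` can be driven without argparse; a minimal usage sketch (the module name and paths are hypothetical):

```python
from avg_model import average_checkpoints  # module name assumed

# Average the 5 checkpoints with the best validation loss.
average_checkpoints(
    dst_model="exp/roformer/checkpoints/avg_5.pdparams",  # illustrative path
    ckpt_dir="exp/roformer/checkpoints",
    val_best=True,
    num=5,
    min_epoch=0,
    max_epoch=65536)
```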
......@@ -145,7 +145,6 @@ class U2BaseModel(ASRInterface, nn.Layer):
text_lengths)
        ctc_time = time.time() - start
-        #logger.debug(f"ctc time: {ctc_time}")
        if loss_ctc is None:
            loss = loss_att
        elif loss_att is None:
......@@ -916,6 +915,8 @@ class U2Model(U2DecodeModel):
decoder_type = configs.get('decoder', 'transformer')
logger.debug(f"U2 Decoder type: {decoder_type}")
        if decoder_type == 'transformer':
+            configs['model_conf'].pop('reverse_weight', None)
+            configs['decoder_conf'].pop('r_num_blocks', None)
            decoder = TransformerDecoder(vocab_size,
                                         encoder.output_size(),
                                         **configs['decoder_conf'])
......
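The two `pop` calls above matter because the config dict is splatted into the decoder constructor, and `TransformerDecoder` does not accept the bitransformer-only keys. A generic illustration with a toy class (not the real decoder):

```python
class ToyDecoder:
    """Stand-in for a decoder that only knows transformer options."""

    def __init__(self, attention_heads=4, linear_units=2048):
        self.attention_heads = attention_heads
        self.linear_units = linear_units

conf = {"attention_heads": 4, "linear_units": 2048, "r_num_blocks": 3}
conf.pop("r_num_blocks", None)  # drop keys the constructor does not accept
decoder = ToyDecoder(**conf)    # would raise TypeError without the pop
```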
......@@ -15,6 +15,7 @@
# Modified from wenet(https://github.com/wenet-e2e/wenet)
"""Multi-Head Attention layer definition."""
import math
from typing import List
from typing import Tuple
import paddle
......@@ -26,7 +27,10 @@ from paddlespeech.s2t.utils.log import Log
logger = Log(__name__).getlog()
__all__ = ["MultiHeadedAttention", "RelPositionMultiHeadedAttention"]
__all__ = [
"MultiHeadedAttention", "RelPositionMultiHeadedAttention",
"RoPERelPositionMultiHeadedAttention"
]
# Relative Positional Encodings
# https://www.jianshu.com/p/c0608efcc26f
......@@ -165,6 +169,7 @@ class MultiHeadedAttention(nn.Layer):
                and `head * d_k == size`
        """
+        # (B,T,D) -> (B,T,H,D/H)
        q, k, v = self.forward_qkv(query, key, value)

        # when export onnx model, for 1st chunk, we feed
......@@ -373,3 +378,139 @@ class RelPositionMultiHeadedAttention(MultiHeadedAttention):
                               self.d_k)  # (batch, head, time1, time2)

        return self.forward_attention(v, scores, mask), new_cache
class RoPERelPositionMultiHeadedAttention(MultiHeadedAttention):
    """Multi-Head Attention layer with RoPE relative position encoding."""

    def __init__(self,
                 n_head,
                 n_feat,
                 dropout_rate,
                 adaptive_scale=False,
                 init_weights=False):
        """Construct a RoPERelPositionMultiHeadedAttention object.
        Paper: https://arxiv.org/abs/2104.09864
        Args:
            n_head (int): The number of heads.
            n_feat (int): The number of features.
            dropout_rate (float): Dropout rate.
        """
        super().__init__(n_head, n_feat, dropout_rate)

    def align(self, tensor: paddle.Tensor, axes: List[int], ndim=None):
        """Re-align a tensor (a batched version of expand_dims).
        axes: the i-th axis of the input becomes axis axes[i] of the output;
        ndim: number of dimensions of the output tensor.
        """
        assert len(axes) == tensor.dim()
        assert ndim or min(axes) >= 0
        ndim = ndim or max(axes) + 1
        # a[0, None, 1] = a[0, np.newaxis, 1]
        indices = [None] * ndim
        for i in axes:
            # slice(None) selects everything, a[0, slice(None), 1] = a[0, :, 1]
            indices[i] = slice(None)
        return tensor[indices]
    def apply_rotary_position_embeddings(self, sinusoidal, *tensors):
        """Apply RoPE to the given tensors.
        sinusoidal.shape = [B, T, D]; `tensors` is a list of tensors with
        tensor.shape = [B, T, ..., D] or (B, H, T, D/H).
        """
        assert len(tensors) > 0, 'at least one input tensor'
        assert all(
            [tensor.shape == tensors[0].shape
             for tensor in tensors[1:]]), 'all tensors must have the same shape'

        # (B,H,T,D)
        ndim = tensors[0].dim()
        _, H, T, D = tensors[0].shape

        # sinusoidal laid out like tensors[0]:
        # [B,T,D] -> [B,T,H,D/H] -> (B,H,T,D/H)
        # sinusoidal = self.align(sinusoidal, [0, 1, -1], ndim)
        sinusoidal = sinusoidal.reshape((1, T, H, D)).transpose([0, 2, 1, 3])

        # http://man.hubwiz.com/docset/TensorFlow.docset/Contents/Resources/Documents/api_docs/python/tf/keras/backend/repeat_elements.html
        # like np.repeat: x (s1, s2, s3), axis 1 -> (s1, s2*rep, s3)
        # [b,T, ..., d/2] -> [b,T, ..., d]
        cos_pos = paddle.repeat_interleave(sinusoidal[..., 1::2], 2, axis=-1)
        sin_pos = paddle.repeat_interleave(sinusoidal[..., 0::2], 2, axis=-1)

        outputs = []
        for tensor in tensors:
            # x2 = [-x2, x1, -x4, x3, ..., -x_d, x_{d-1}]
            tensor2 = paddle.stack([-tensor[..., 1::2], tensor[..., ::2]], ndim)
            tensor2 = paddle.reshape(tensor2, paddle.shape(tensor))
            # Eq. (34) of the RoFormer paper: out = x * cos_pos + x2 * sin_pos
            outputs.append(tensor * cos_pos + tensor2 * sin_pos)
        return outputs[0] if len(outputs) == 1 else outputs
    def forward(self,
                query: paddle.Tensor,
                key: paddle.Tensor,
                value: paddle.Tensor,
                mask: paddle.Tensor=paddle.ones([0, 0, 0], dtype=paddle.bool),
                pos_emb: paddle.Tensor=paddle.empty([0]),
                cache: paddle.Tensor=paddle.zeros([0, 0, 0, 0])
                ) -> Tuple[paddle.Tensor, paddle.Tensor]:
        """Compute 'Scaled Dot Product Attention' with rotary positional encoding.
        Ref: https://github.com/facebookresearch/llama/blob/main/llama/model.py
        Args:
            query (paddle.Tensor): Query tensor (#batch, time1, size).
            key (paddle.Tensor): Key tensor (#batch, time2, size).
            value (paddle.Tensor): Value tensor (#batch, time2, size).
            mask (paddle.Tensor): Mask tensor (#batch, 1, time2) or
                (#batch, time1, time2), (0, 0, 0) means fake mask.
            pos_emb (paddle.Tensor): Positional embedding tensor
                (#batch, time2, size).
            cache (paddle.Tensor): Cache tensor (1, head, cache_t, d_k * 2),
                where `cache_t == chunk_size * num_decoding_left_chunks`
                and `head * d_k == size`
        Returns:
            paddle.Tensor: Output tensor (#batch, time1, d_model).
            paddle.Tensor: Cache tensor (1, head, cache_t + time1, d_k * 2)
                where `cache_t == chunk_size * num_decoding_left_chunks`
                and `head * d_k == size`
        """
        q, k, v = self.forward_qkv(query, key, value)
        # q = q.transpose([0, 2, 1, 3])  # (batch, time1, head, d_k)

        # f_{q,k}(x_m, m) = R^d_{\theta,m} W_{q,k} x_m, where m is the position index

        # q_t is always chunk_size
        q_t = q.shape[2]
        q = self.apply_rotary_position_embeddings(pos_emb[:, -q_t:, :], q)
        # k grows during streaming decoding.
        k = self.apply_rotary_position_embeddings(pos_emb[:, -q_t:, :], k)

        # when exporting an onnx model, for the 1st chunk, we feed
        #   cache(1, head, 0, d_k * 2) (16/-1, -1/-1, 16/0 mode)
        #   or cache(1, head, real_cache_t, d_k * 2) (16/4 mode).
        #   In all modes, `if cache.size(0) > 0` will always be `True`,
        #   so we always do splitting and concatenation
        #   (this simplifies onnx export). Note that it's OK to
        #   concat & split zero-shaped tensors (see the code below).
        # when exporting a jit model, for the 1st chunk, we always feed
        #   cache(0, 0, 0, 0) since jit supports dynamic if-branches.
        # >>> a = torch.ones((1, 2, 0, 4))
        # >>> b = torch.ones((1, 2, 3, 4))
        # >>> c = torch.cat((a, b), dim=2)
        # >>> torch.equal(b, c)         # True
        # >>> d = torch.split(a, 2, dim=-1)
        # >>> torch.equal(d[0], d[1])   # True
        if cache.shape[0] > 0:
            # the last dim `d_k * 2` holds (key, val)
            key_cache, value_cache = paddle.split(cache, 2, axis=-1)
            k = paddle.concat([key_cache, k], axis=2)
            v = paddle.concat([value_cache, v], axis=2)
        # We do cache slicing in encoder.forward_chunk, since it's
        # non-trivial to calculate `next_cache_start` here.
        new_cache = paddle.concat((k, v), axis=-1)

        # dot(q, k)
        scores = paddle.matmul(q, k, transpose_y=True) / math.sqrt(self.d_k)
        return self.forward_attention(v, scores, mask), new_cache
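For intuition, a self-contained numpy sketch of the interleaved rotation used above, including the relative-position property that makes RoPE attractive for streaming (illustrative only, not the class API):

```python
import numpy as np

def rope(x: np.ndarray, pos: np.ndarray, base: float = 10000.0) -> np.ndarray:
    """Rotate feature pairs (x[2i], x[2i+1]) by position-dependent angles.

    x: (T, D) with D even; pos: (T,) positions.
    """
    D = x.shape[-1]
    inv_freq = base ** (-np.arange(0, D, 2) / D)  # theta_i, shape (D/2,)
    angle = pos[:, None] * inv_freq[None, :]      # (T, D/2)
    cos = np.repeat(np.cos(angle), 2, axis=-1)    # (T, D)
    sin = np.repeat(np.sin(angle), 2, axis=-1)
    # interleaved [-x1, x0, -x3, x2, ...], as in the code above
    x2 = np.stack([-x[:, 1::2], x[:, 0::2]], axis=-1).reshape(x.shape)
    return x * cos + x2 * sin

# <rope(q, m), rope(k, n)> depends only on the relative offset m - n.
rng = np.random.default_rng(0)
q, k = rng.normal(size=(2, 1, 8))
a = rope(q, np.array([3])) @ rope(k, np.array([5])).T
b = rope(q, np.array([10])) @ rope(k, np.array([12])).T
assert np.allclose(a, b)
```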
......@@ -85,18 +85,21 @@ class PositionalEncoding(nn.Layer, PositionalEncodingInterface):
reverse (bool, optional): Not used. Defaults to False.
"""
        nn.Layer.__init__(self)
-        self.d_model = d_model
+        self.d_model = paddle.to_tensor(d_model)
        self.max_len = max_len
        self.xscale = paddle.to_tensor(math.sqrt(self.d_model))
        self.dropout = nn.Dropout(p=dropout_rate)
+        self.base = paddle.to_tensor(10000.0)
        self.pe = paddle.zeros([1, self.max_len, self.d_model])  # [B=1,T,D]

        position = paddle.arange(
            0, self.max_len, dtype=paddle.float32).unsqueeze(1)  # [T, 1]
        # base^{-2(i-1)/d}, i \in (1, 2, ..., d/2)
        div_term = paddle.exp(
-            paddle.arange(0, self.d_model, 2, dtype=paddle.float32) *
-            -(math.log(10000.0) / self.d_model))
+            -paddle.arange(0, self.d_model, 2, dtype=paddle.float32) *
+            (paddle.log(self.base) / self.d_model))

        # [B,T,D]
        self.pe[:, :, 0::2] = paddle.sin(position * div_term)
        self.pe[:, :, 1::2] = paddle.cos(position * div_term)
......@@ -161,6 +164,98 @@ class RelPositionalEncoding(PositionalEncoding):
        assert offset + x.shape[1] < self.max_len, \
            "offset: {} + x.shape[1]: {} is larger than the max_len: {}".format(
                offset, x.shape[1], self.max_len)
        x = x * self.xscale
        pos_emb = self.pe[:, offset:offset + x.shape[1]]
        return self.dropout(x), self.dropout(pos_emb)
# RotaryRelPositionalEncoding is the same as RelPositionalEncoding


class ScaledRotaryRelPositionalEncoding(RelPositionalEncoding):
    """Scaled rotary relative positional encoding module.
    Position Interpolation: https://arxiv.org/pdf/2306.15595v2.pdf
    """

    def __init__(self,
                 d_model: int,
                 dropout_rate: float,
                 max_len: int=5000,
                 scale=1):
        """
        Args:
            d_model (int): Embedding dimension.
            dropout_rate (float): Dropout rate.
            max_len (int, optional): Maximum input length. Defaults to 5000.
            scale (int): Interpolate up to `scale * max_len` input positions.
        """
        super().__init__(d_model, dropout_rate, max_len, reverse=True)
        self.pscale = paddle.to_tensor(scale)
        self.max_len = max_len * scale

    def sinusoidal_embeddings(self,
                              pos: paddle.Tensor,
                              dim: paddle.Tensor,
                              base=10000) -> paddle.Tensor:
        """Compute dim-dimensional sinusoidal embeddings for positions `pos`."""
        assert dim % 2 == 0
        # (d/2,)
        indices = paddle.arange(0, dim // 2, dtype=pos.dtype)
        indices = paddle.pow(paddle.cast(base, pos.dtype), -2 * indices / dim)
        # pos (1, T), indices (d/2,) -> (1, T, d/2)
        embeddings = paddle.einsum('...,d->...d', pos, indices)
        # (1, T, d/2, 2)
        embeddings = paddle.stack(
            [paddle.sin(embeddings), paddle.cos(embeddings)], axis=-1)
        # (1, T, d)
        embeddings = paddle.flatten(embeddings, start_axis=-2, stop_axis=-1)
        return embeddings

    def forward(self, x: paddle.Tensor,
                offset: int=0) -> Tuple[paddle.Tensor, paddle.Tensor]:
        """Compute positional encoding.
        Args:
            x (paddle.Tensor): Input tensor (batch, time, `*`).
        Returns:
            paddle.Tensor: Encoded tensor (batch, time, `*`).
            paddle.Tensor: Positional embedding tensor (1, time, `*`).
        """
        x = x * self.xscale

        B, T, D = x.shape
        assert D == self.d_model

        # position interpolation
        start = 0
        end = T * self.pscale
        assert end <= self.max_len
        position = paddle.arange(start, end, dtype=x.dtype).unsqueeze(0)
        position *= 1.0 / self.pscale

        pe = self.sinusoidal_embeddings(position, self.d_model, base=self.base)
        pos_emb = pe[:, offset:offset + x.shape[1]]
        return self.dropout(x), self.dropout(pos_emb)

    def position_encoding(self, offset: int, size: int) -> paddle.Tensor:
        """For getting encoding in a streaming fashion.
        Attention!!!!!
        We apply dropout only once at the whole utterance level in a
        non-streaming way, but this function will be called several times
        with an increasing input size in a streaming scenario, so the
        dropout will be applied several times.
        Args:
            offset (int): start offset
            size (int): required size of position encoding
        Returns:
            paddle.Tensor: Corresponding position encoding, #[1, T, D].
        """
        # position interpolation
        start = offset
        end = (offset + size) * self.pscale
        assert end <= self.max_len
        position = paddle.arange(
            start, end, dtype=paddle.get_default_dtype()).unsqueeze(0)
        position *= 1.0 / self.pscale

        pe = self.sinusoidal_embeddings(position, self.d_model, base=self.base)
        return self.dropout(pe)
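A compact numpy sketch of the position-interpolation idea above: positions are compressed back into the trained range before the sinusoids are computed, so a model trained on `max_len` positions can address `scale * max_len` steps (names are illustrative; like the code above, it generates `size * scale` points from `offset`):

```python
import numpy as np

def interpolated_positions(offset: int, size: int, scale: int) -> np.ndarray:
    """Fractional positions starting at `offset`, compressed by `scale`."""
    end = (offset + size) * scale
    return np.arange(offset, end, dtype=np.float32) / scale

print(interpolated_positions(0, 4, 2))
# [0.   0.5  1.   1.5  2.   2.5  3.   3.5]
```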
......@@ -28,6 +28,7 @@ from paddlespeech.s2t.modules.align import LayerNorm
from paddlespeech.s2t.modules.align import Linear
from paddlespeech.s2t.modules.attention import MultiHeadedAttention
from paddlespeech.s2t.modules.attention import RelPositionMultiHeadedAttention
from paddlespeech.s2t.modules.attention import RoPERelPositionMultiHeadedAttention
from paddlespeech.s2t.modules.conformer_convolution import ConvolutionModule
from paddlespeech.s2t.modules.embedding import NoPositionalEncoding
from paddlespeech.s2t.modules.embedding import PositionalEncoding
......@@ -115,6 +116,8 @@ class BaseEncoder(nn.Layer):
            pos_enc_class = PositionalEncoding
        elif pos_enc_layer_type == "rel_pos":
            pos_enc_class = RelPositionalEncoding
+        elif pos_enc_layer_type == "rope_pos":
+            pos_enc_class = RelPositionalEncoding
        elif pos_enc_layer_type == "no_pos":
            pos_enc_class = NoPositionalEncoding
        else:
......@@ -230,14 +233,14 @@ class BaseEncoder(nn.Layer):
            xs = self.global_cmvn(xs)
        # before embed, xs=(B, T, D1), pos_emb=(B=1, T, D)
-        xs, pos_emb, _ = self.embed(xs, tmp_masks, offset=offset)
+        xs, _, _ = self.embed(xs, tmp_masks, offset=offset)
        # after embed, xs=(B=1, chunk_size, hidden-dim)

        elayers, _, cache_t1, _ = att_cache.shape
        chunk_size = xs.shape[1]
        attention_key_size = cache_t1 + chunk_size

-        # only used when using `RelPositionMultiHeadedAttention`
+        # only used when using `RelPositionMultiHeadedAttention` and `RoPERelPositionMultiHeadedAttention`
        pos_emb = self.embed.position_encoding(
            offset=offset - cache_t1, size=attention_key_size)
......@@ -474,21 +477,35 @@ class ConformerEncoder(BaseEncoder):
        activation = get_activation(activation_type)

        # self-attention module definition
-        encoder_selfattn_layer = RelPositionMultiHeadedAttention
-        encoder_selfattn_layer_args = (attention_heads, output_size,
-                                       attention_dropout_rate)
+        encoder_dim = output_size
+        if pos_enc_layer_type == "abs_pos":
+            encoder_selfattn_layer = MultiHeadedAttention
+            encoder_selfattn_layer_args = (attention_heads, encoder_dim,
+                                           attention_dropout_rate)
+        elif pos_enc_layer_type == "rel_pos":
+            encoder_selfattn_layer = RelPositionMultiHeadedAttention
+            encoder_selfattn_layer_args = (attention_heads, encoder_dim,
+                                           attention_dropout_rate)
+        elif pos_enc_layer_type == "rope_pos":
+            encoder_selfattn_layer = RoPERelPositionMultiHeadedAttention
+            encoder_selfattn_layer_args = (attention_heads, encoder_dim,
+                                           attention_dropout_rate)
+        else:
+            raise ValueError(
+                f"pos_enc_layer_type {pos_enc_layer_type} not supported.")
        # feed-forward module definition
        positionwise_layer = PositionwiseFeedForward
-        positionwise_layer_args = (output_size, linear_units, dropout_rate,
+        positionwise_layer_args = (encoder_dim, linear_units, dropout_rate,
                                   activation)
        # convolution module definition
        convolution_layer = ConvolutionModule
-        convolution_layer_args = (output_size, cnn_module_kernel, activation,
+        convolution_layer_args = (encoder_dim, cnn_module_kernel, activation,
                                  cnn_module_norm, causal)

        self.encoders = nn.LayerList([
            ConformerEncoderLayer(
-                size=output_size,
+                size=encoder_dim,
                self_attn=encoder_selfattn_layer(*encoder_selfattn_layer_args),
                feed_forward=positionwise_layer(*positionwise_layer_args),
                feed_forward_macaron=positionwise_layer(
......@@ -580,15 +597,23 @@ class SqueezeformerEncoder(nn.Layer):
        activation = get_activation(activation_type)

        # self-attention module definition
-        if pos_enc_layer_type != "rel_pos":
+        if pos_enc_layer_type == "abs_pos":
            encoder_selfattn_layer = MultiHeadedAttention
            encoder_selfattn_layer_args = (attention_heads, output_size,
                                           attention_dropout_rate)
-        else:
+        elif pos_enc_layer_type == "rel_pos":
            encoder_selfattn_layer = RelPositionMultiHeadedAttention
            encoder_selfattn_layer_args = (attention_heads, encoder_dim,
                                           attention_dropout_rate,
                                           adaptive_scale, init_weights)
+        elif pos_enc_layer_type == "rope_pos":
+            encoder_selfattn_layer = RoPERelPositionMultiHeadedAttention
+            encoder_selfattn_layer_args = (attention_heads, encoder_dim,
+                                           attention_dropout_rate,
+                                           adaptive_scale, init_weights)
+        else:
+            raise ValueError(
+                f"pos_enc_layer_type {pos_enc_layer_type} not supported.")

        # feed-forward module definition
        positionwise_layer = PositionwiseFeedForward
......
......@@ -48,7 +48,7 @@ class TransformerEncoderLayer(nn.Layer):
Args:
size (int): Input dimension.
        self_attn (nn.Layer): Self-attention module instance.
-            `MultiHeadedAttention` or `RelPositionMultiHeadedAttention`
+            `MultiHeadedAttention`, `RelPositionMultiHeadedAttention` or
+            `RoPERelPositionMultiHeadedAttention`
            instance can be used as the argument.
feed_forward (nn.Layer): Feed-forward module instance.
`PositionwiseFeedForward`, instance can be used as the argument.
......@@ -147,7 +147,7 @@ class ConformerEncoderLayer(nn.Layer):
Args:
size (int): Input dimension.
        self_attn (nn.Layer): Self-attention module instance.
-            `MultiHeadedAttention` or `RelPositionMultiHeadedAttention`
+            `MultiHeadedAttention`, `RelPositionMultiHeadedAttention` or
+            `RoPERelPositionMultiHeadedAttention`
            instance can be used as the argument.
feed_forward (nn.Layer): Feed-forward module instance.
`PositionwiseFeedForward` instance can be used as the argument.
......@@ -298,7 +298,7 @@ class SqueezeformerEncoderLayer(nn.Layer):
Args:
size (int): Input dimension.
        self_attn (paddle.nn.Layer): Self-attention module instance.
-            `MultiHeadedAttention` or `RelPositionMultiHeadedAttention`
+            `MultiHeadedAttention`, `RelPositionMultiHeadedAttention` or
+            `RoPERelPositionMultiHeadedAttention`
            instance can be used as the argument.
feed_forward1 (paddle.nn.Layer): Feed-forward module instance.
`PositionwiseFeedForward` instance can be used as the argument.
......
......@@ -102,8 +102,7 @@ class OptimizerFactory():
        grad_clip = paddle.nn.ClipGradByGlobalNorm(
            args['grad_clip']) if "grad_clip" in args else None
-        weight_decay = L2Decay(
-            args['weight_decay']) if "weight_decay" in args else None
+        weight_decay = args.get("weight_decay", None)
        if weight_decay:
            logger.info(f'<WeightDecay - {weight_decay}>')

        if grad_clip:
......
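The change above passes the raw config value through instead of wrapping it in an `L2Decay` regularizer. A minimal sketch of the resulting optimizer construction (parameter values mirror the config above; the module is illustrative):

```python
import paddle

linear = paddle.nn.Linear(4, 4)  # toy parameters to optimize
# paddle optimizers accept either a float coefficient or a regularizer
# object for weight_decay; the fix forwards the float from the config.
opt = paddle.optimizer.Adam(
    learning_rate=0.001,
    parameters=linear.parameters(),
    weight_decay=1.0e-6,
    grad_clip=paddle.nn.ClipGradByGlobalNorm(5.0))
```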