Commit 55870ffb authored by H Hui Zhang

fix bugs

Parent 03e9ea9e
@@ -18,7 +18,7 @@ encoder_conf:
     cnn_module_kernel: 15
     use_cnn_module: True
     activation_type: 'swish'
-    pos_enc_layer_type: 'rpoe_pos' # abs_pos, rel_pos, rope_pos
+    pos_enc_layer_type: 'rope_pos' # abs_pos, rel_pos, rope_pos
     selfattention_layer_type: 'rel_selfattn' # unused
     causal: true
     use_dynamic_chunk: true
@@ -30,7 +30,7 @@ decoder_conf:
     attention_heads: 4
     linear_units: 2048
     num_blocks: 6
-    r_num_blocks: 3 # only for bitransformer
+    r_num_blocks: 0 # only for bitransformer
     dropout_rate: 0.1 # sublayer output dropout
     positional_dropout_rate: 0.1
     self_attention_dropout_rate: 0.0
@@ -39,7 +39,7 @@ decoder_conf:
 model_conf:
     ctc_weight: 0.3
     lsm_weight: 0.1 # label smoothing option
-    reverse_weight: 0.3 # only for bitransformer
+    reverse_weight: 0.0 # only for bitransformer
     length_normalized_loss: false
     init_type: 'kaiming_uniform' # !Warning: need to convergence
......
@@ -18,7 +18,7 @@ encoder_conf:
     cnn_module_kernel: 15
     use_cnn_module: True
     activation_type: 'swish'
-    pos_enc_layer_type: 'rpoe_pos' # abs_pos, rel_pos, rope_pos
+    pos_enc_layer_type: 'rope_pos' # abs_pos, rel_pos, rope_pos
     selfattention_layer_type: 'rel_selfattn' # unused
     causal: true
     use_dynamic_chunk: true
......
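The two recipe hunks above fix the same typo: the encoder's pos_enc_layer_type was spelled 'rpoe_pos', which matches none of the supported values (abs_pos, rel_pos, rope_pos). The first recipe also zeroes r_num_blocks and reverse_weight, which only apply to a bitransformer decoder. As a small illustration only (not part of this commit; the helper name and file path are hypothetical), a config loader could reject such a typo immediately instead of letting it surface deep inside the encoder constructor:

import yaml

ALLOWED_POS_ENC = {"abs_pos", "rel_pos", "rope_pos"}

def load_encoder_conf(path):
    # Hypothetical guard: fail fast on an unknown positional-encoding type
    # such as the 'rpoe_pos' typo fixed above.
    with open(path) as f:
        conf = yaml.safe_load(f)
    pos_enc = conf["encoder_conf"]["pos_enc_layer_type"]
    if pos_enc not in ALLOWED_POS_ENC:
        raise ValueError(f"unknown pos_enc_layer_type {pos_enc!r}, "
                         f"expected one of {sorted(ALLOWED_POS_ENC)}")
    return conf

conf = load_encoder_conf("conf/chunk_conformer.yaml")  # hypothetical path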
@@ -145,7 +145,6 @@ class U2BaseModel(ASRInterface, nn.Layer):
                                 text_lengths)
         ctc_time = time.time() - start
         #logger.debug(f"ctc time: {ctc_time}")
         if loss_ctc is None:
             loss = loss_att
         elif loss_att is None:
@@ -916,6 +915,8 @@ class U2Model(U2DecodeModel):
         decoder_type = configs.get('decoder', 'transformer')
         logger.debug(f"U2 Decoder type: {decoder_type}")
         if decoder_type == 'transformer':
+            configs['model_conf'].pop('reverse_weight', None)
+            configs['decoder_conf'].pop('r_num_blocks', None)
             decoder = TransformerDecoder(vocab_size,
                                          encoder.output_size(),
                                          **configs['decoder_conf'])
......
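The U2Model change above strips the two bitransformer-only keys before the decoder config is **-unpacked into TransformerDecoder, whose constructor does not accept them. The snippet below is a minimal, self-contained illustration of the failure mode and the fix (DummyTransformerDecoder is a stand-in, not the real class): an unexpected keyword argument raises TypeError at construction time, while pop(key, None) is a harmless no-op when the key is already absent.

class DummyTransformerDecoder:
    # Stand-in constructor: like a plain transformer decoder, it knows nothing
    # about bitransformer options such as r_num_blocks or reverse_weight.
    def __init__(self, vocab_size, encoder_output_size, num_blocks=6):
        self.num_blocks = num_blocks

decoder_conf = {"num_blocks": 6, "r_num_blocks": 0}

try:
    DummyTransformerDecoder(5000, 256, **decoder_conf)
except TypeError as err:
    print(err)  # ... got an unexpected keyword argument 'r_num_blocks'

# Mirror of the fix: drop keys the plain decoder does not understand.
decoder_conf.pop("r_num_blocks", None)
decoder = DummyTransformerDecoder(5000, 256, **decoder_conf)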
@@ -16,6 +16,7 @@
 """Multi-Head Attention layer definition."""
 import math
 from typing import Tuple
+from typing import List
 
 import paddle
 from paddle import nn
@@ -418,25 +419,27 @@ class RoPERelPositionMultiHeadedAttention(MultiHeadedAttention):
     def apply_rotary_position_embeddings(self, sinusoidal, *tensors):
         """Apply RoPE to the given tensors,
         where sinusoidal.shape=[B, T, D], tensors is a list of tensors, and
-        tensor.shape=[B, T, ..., D], or (B,T,H,D/H)
+        tensor.shape=[B, T, ..., D], or (B,H,T,D/H)
         """
         assert len(tensors) > 0, 'at least one input tensor'
         assert all(
             [tensor.shape == tensors[0].shape
              for tensor in tensors[1:]]), 'all tensors must have the same shape'
+        # (B,H,T,D)
         ndim = tensors[0].dim()
+        _,H,T,D = tensors[0].shape
         # sinusoidal shape same with tensors[0]
-        # [B,T,D] -> [B,T,1,D]
-        sinusoidal = self.align(sinusoidal, [0, 1, -1], ndim)
+        # [B,T,D] -> [B,T,H,D/H] -> (B,H,T,D/H)
+        # sinusoidal = self.align(sinusoidal, [0, 1, -1], ndim)
+        sinusoidal = sinusoidal.reshape((1, T, H, D)).transpose([0, 2, 1, 3])
         # http://man.hubwiz.com/docset/TensorFlow.docset/Contents/Resources/Documents/api_docs/python/tf/keras/backend/repeat_elements.html
         # like np.repeat, x (s1, s2, s3), axis 1, (s1, s2*rep, s3)
         # [b,T, ..., d/2] -> [b,T, ..., d]
         cos_pos = paddle.repeat_interleave(sinusoidal[..., 1::2], 2, axis=-1)
         sin_pos = paddle.repeat_interleave(sinusoidal[..., 0::2], 2, axis=-1)
         outputs = []
         for tensor in tensors:
             # x2 = [-x2, x1, -x4, x3, ..., -x_d, x_{d-1}]
@@ -501,7 +504,7 @@ class RoPERelPositionMultiHeadedAttention(MultiHeadedAttention):
         new_cache = paddle.concat((k, v), axis=-1)
 
         # f{q,k}(x_m, m) = R^d_{\theta, m} W_{q,k} x_m, m is position index
-        q, k = self.apply_rotary_position_embeddings(pos_emb, [q, k])
+        q, k = self.apply_rotary_position_embeddings(pos_emb, q, k)
         # dot(q, k)
         scores = paddle.matmul(q, k, transpose_y=True) / math.sqrt(self.d_k)
         return self.forward_attention(v, scores, mask), new_cache
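The two attention hunks are the core of the fix. apply_rotary_position_embeddings now reshapes the interleaved sin/cos table from [B, T, D] to the per-head layout (1, H, T, D/H) so it broadcasts against q and k, which at this point have shape (B, H, T, D/H); the call site also passes q and k as separate varargs, since the old call handed in a single Python list, so *tensors received one list element rather than the two tensors the function expects. Below is a standalone sketch of the rotation, mirroring the code above but written as a free function for clarity (the apply_rope name is ours, and it assumes sin values sit at even feature indices and cos at odd ones, as in the repo's positional encoding):

import paddle

def apply_rope(sinusoidal, x):
    # sinusoidal: [1, T, D] with interleaved (sin, cos) pairs; x: [B, H, T, D/H].
    B, H, T, Dh = x.shape
    # [1, T, D] -> [1, T, H, D/H] -> [1, H, T, D/H], matching the per-head layout of x
    sinusoidal = sinusoidal.reshape((1, T, H, Dh)).transpose([0, 2, 1, 3])
    # duplicate each sin/cos value so the tables line up with every feature dim
    cos_pos = paddle.repeat_interleave(sinusoidal[..., 1::2], 2, axis=-1)
    sin_pos = paddle.repeat_interleave(sinusoidal[..., 0::2], 2, axis=-1)
    # rotate_half: (x1, x2, x3, x4, ...) -> (-x2, x1, -x4, x3, ...)
    x2 = paddle.stack([-x[..., 1::2], x[..., 0::2]], axis=-1).reshape(x.shape)
    return x * cos_pos + x2 * sin_pos

B, H, T, Dh = 2, 4, 16, 64
q = paddle.randn([B, H, T, Dh])
pos_emb = paddle.randn([1, T, H * Dh])  # stands in for the real sinusoidal table
print(apply_rope(pos_emb, q).shape)     # [2, 4, 16, 64]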
@@ -477,9 +477,10 @@ class ConformerEncoder(BaseEncoder):
         activation = get_activation(activation_type)
 
         # self-attention module definition
+        encoder_dim = output_size
         if pos_enc_layer_type == "abs_pos":
             encoder_selfattn_layer = MultiHeadedAttention
-            encoder_selfattn_layer_args = (attention_heads, output_size,
+            encoder_selfattn_layer_args = (attention_heads, encoder_dim,
                                            attention_dropout_rate)
         elif pos_enc_layer_type == "rel_pos":
             encoder_selfattn_layer = RelPositionMultiHeadedAttention
@@ -495,16 +496,16 @@ class ConformerEncoder(BaseEncoder):
 
         # feed-forward module definition
         positionwise_layer = PositionwiseFeedForward
-        positionwise_layer_args = (output_size, linear_units, dropout_rate,
+        positionwise_layer_args = (encoder_dim, linear_units, dropout_rate,
                                    activation)
         # convolution module definition
         convolution_layer = ConvolutionModule
-        convolution_layer_args = (output_size, cnn_module_kernel, activation,
+        convolution_layer_args = (encoder_dim, cnn_module_kernel, activation,
                                   cnn_module_norm, causal)
 
         self.encoders = nn.LayerList([
             ConformerEncoderLayer(
-                size=output_size,
+                size=encoder_dim,
                 self_attn=encoder_selfattn_layer(*encoder_selfattn_layer_args),
                 feed_forward=positionwise_layer(*positionwise_layer_args),
                 feed_forward_macaron=positionwise_layer(
......