提交 55870ffb 编写于 作者: H Hui Zhang

fix bugs

上级 03e9ea9e
......@@ -18,7 +18,7 @@ encoder_conf:
cnn_module_kernel: 15
use_cnn_module: True
activation_type: 'swish'
pos_enc_layer_type: 'rpoe_pos' # abs_pos, rel_pos, rope_pos
pos_enc_layer_type: 'rope_pos' # abs_pos, rel_pos, rope_pos
selfattention_layer_type: 'rel_selfattn' # unused
causal: true
use_dynamic_chunk: true
......@@ -30,7 +30,7 @@ decoder_conf:
attention_heads: 4
linear_units: 2048
num_blocks: 6
r_num_blocks: 3 # only for bitransformer
r_num_blocks: 0 # only for bitransformer
dropout_rate: 0.1 # sublayer output dropout
positional_dropout_rate: 0.1
self_attention_dropout_rate: 0.0
......@@ -39,7 +39,7 @@ decoder_conf:
model_conf:
ctc_weight: 0.3
lsm_weight: 0.1 # label smoothing option
reverse_weight: 0.3 # only for bitransformer
reverse_weight: 0.0 # only for bitransformer
length_normalized_loss: false
init_type: 'kaiming_uniform' # !Warning: need to convergence
......
......@@ -18,7 +18,7 @@ encoder_conf:
cnn_module_kernel: 15
use_cnn_module: True
activation_type: 'swish'
pos_enc_layer_type: 'rpoe_pos' # abs_pos, rel_pos, rope_pos
pos_enc_layer_type: 'rope_pos' # abs_pos, rel_pos, rope_pos
selfattention_layer_type: 'rel_selfattn' # unused
causal: true
use_dynamic_chunk: true
......
......@@ -145,7 +145,6 @@ class U2BaseModel(ASRInterface, nn.Layer):
text_lengths)
ctc_time = time.time() - start
#logger.debug(f"ctc time: {ctc_time}")
if loss_ctc is None:
loss = loss_att
elif loss_att is None:
......@@ -916,6 +915,8 @@ class U2Model(U2DecodeModel):
decoder_type = configs.get('decoder', 'transformer')
logger.debug(f"U2 Decoder type: {decoder_type}")
if decoder_type == 'transformer':
configs['model_conf'].pop('reverse_weight', None)
configs['decoder_conf'].pop('r_num_blocks', None)
decoder = TransformerDecoder(vocab_size,
encoder.output_size(),
**configs['decoder_conf'])
......
......@@ -16,6 +16,7 @@
"""Multi-Head Attention layer definition."""
import math
from typing import Tuple
from typing import List
import paddle
from paddle import nn
......@@ -418,25 +419,27 @@ class RoPERelPositionMultiHeadedAttention(MultiHeadedAttention):
def apply_rotary_position_embeddings(self, sinusoidal, *tensors):
"""应用RoPE到tensors中
其中,sinusoidal.shape=[B, T, D],tensors为tensor的列表,而
tensor.shape=[B, T, ..., D], or (B,T,H,D/H)
tensor.shape=[B, T, ..., D], or (B,H,T,D/H)
"""
assert len(tensors) > 0, 'at least one input tensor'
assert all(
[tensor.shape == tensors[0].shape
for tensor in tensors[1:]]), 'all tensors must have the same shape'
# (B,H,T,D)
ndim = tensors[0].dim()
_,H,T,D = tensors[0].shape
# sinusoidal shape same with tensors[0]
# [B,T,D] -> [B,T,1,D]
sinusoidal = self.align(sinusoidal, [0, 1, -1], ndim)
# [B,T,D] -> [B,T,H,D/H] -> (B,H,T,D/H)
# sinusoidal = self.align(sinusoidal, [0, 1, -1], ndim)
sinusoidal = sinusoidal.reshape((1, T, H, D)).transpose([0, 2, 1, 3])
# http://man.hubwiz.com/docset/TensorFlow.docset/Contents/Resources/Documents/api_docs/python/tf/keras/backend/repeat_elements.html
# like np.repeat, x (s1, s2, s3), axis 1, (s1, s2*rep, s3)
# [b,T, ..., d/2] -> [b,T, ..., d]
cos_pos = paddle.repeat_interleave(sinusoidal[..., 1::2], 2, axis=-1)
sin_pos = paddle.repeat_interleave(sinusoidal[..., 0::2], 2, axis=-1)
outputs = []
for tensor in tensors:
# x2 = [-x2, x1, -x4, x3, ..., -x_d, x_{d-1}]
......@@ -501,7 +504,7 @@ class RoPERelPositionMultiHeadedAttention(MultiHeadedAttention):
new_cache = paddle.concat((k, v), axis=-1)
# f{q,k}(x_m, m) = R^d_{\theta, m} W_{q,k} x_m, m is position index
q, k = self.apply_rotary_position_embeddings(pos_emb, [q, k])
q, k = self.apply_rotary_position_embeddings(pos_emb, q, k)
# dot(q, k)
scores = paddle.matmul(q, k, transpose_y=True) / math.sqrt(self.d_k)
return self.forward_attention(v, scores, mask), new_cache
......@@ -477,9 +477,10 @@ class ConformerEncoder(BaseEncoder):
activation = get_activation(activation_type)
# self-attention module definition
encoder_dim = output_size
if pos_enc_layer_type == "abs_pos":
encoder_selfattn_layer = MultiHeadedAttention
encoder_selfattn_layer_args = (attention_heads, output_size,
encoder_selfattn_layer_args = (attention_heads, encoder_dim,
attention_dropout_rate)
elif pos_enc_layer_type == "rel_pos":
encoder_selfattn_layer = RelPositionMultiHeadedAttention
......@@ -495,16 +496,16 @@ class ConformerEncoder(BaseEncoder):
# feed-forward module definition
positionwise_layer = PositionwiseFeedForward
positionwise_layer_args = (output_size, linear_units, dropout_rate,
positionwise_layer_args = (encoder_dim, linear_units, dropout_rate,
activation)
# convolution module definition
convolution_layer = ConvolutionModule
convolution_layer_args = (output_size, cnn_module_kernel, activation,
convolution_layer_args = (encoder_dim, cnn_module_kernel, activation,
cnn_module_norm, causal)
self.encoders = nn.LayerList([
ConformerEncoderLayer(
size=output_size,
size=encoder_dim,
self_attn=encoder_selfattn_layer(*encoder_selfattn_layer_args),
feed_forward=positionwise_layer(*positionwise_layer_args),
feed_forward_macaron=positionwise_layer(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册