Unverified commit d368d57d, authored by 小湉湉, committed by GitHub

fix low ips bug of speedyspeech and fastspeech2, test=tts (#1349)

Parent: 5aff0bde
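For context: the low ips (iterations per second) came from building the duration-expansion matrix with a per-element paddle loop on every training step. This commit moves that loop into a shared LengthRegulator whose training path runs the loop in numpy on the host. A minimal self-contained numpy sketch of the expansion it performs (illustrative only, not part of the commit; all names are local to the example):

    import numpy as np

    # one utterance: 3 encoder frames with durations [1, 2, 3] -> 6 decoder frames
    encodings = np.arange(6, dtype=np.float32).reshape(1, 3, 2)  # (B, T_enc, C)
    durations = np.array([[1, 2, 3]])                            # (B, T_enc)

    B, t_enc = durations.shape
    t_dec = durations.sum(-1).max()
    M = np.zeros([B, t_dec, t_enc], dtype=np.float32)
    for i in range(B):
        k = 0
        for j in range(t_enc):
            d = durations[i, j]
            M[i, k:k + d, j] = 1  # decoder frames k .. k+d-1 copy encoder frame j
            k += d

    expanded = M @ encodings  # (B, T_dec, C): encoder frame j repeated d_j times
    print(expanded[0, :, 0])  # [0. 2. 2. 4. 4. 4.]

The pure-paddle version of the same loop is kept for the inference path, presumably because the .numpy() round-trip in the fast path cannot be traced when the model is exported.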
fastspeech2.py
@@ -627,7 +627,7 @@ class FastSpeech2(nn.Layer):
                 hs = hs + e_embs + p_embs
             # (B, Lmax, adim)
-            hs = self.length_regulator(hs, d_outs, alpha)
+            hs = self.length_regulator(hs, d_outs, alpha, is_inference=True)
         else:
             d_outs = self.duration_predictor(hs, d_masks)
             # use groundtruth in training
@@ -638,7 +638,7 @@ class FastSpeech2(nn.Layer):
                 hs = hs + e_embs + p_embs
             # (B, Lmax, adim)
-            hs = self.length_regulator(hs, ds)
+            hs = self.length_regulator(hs, ds, is_inference=False)
         # forward decoder
         if olens is not None and not is_inference:
...
speedyspeech.py
@@ -14,28 +14,9 @@
 import paddle
 from paddle import nn

+from paddlespeech.t2s.modules.nets_utils import initialize
 from paddlespeech.t2s.modules.positional_encoding import sinusoid_position_encoding
+from paddlespeech.t2s.modules.predictor.length_regulator import LengthRegulator


-def expand(encodings: paddle.Tensor, durations: paddle.Tensor) -> paddle.Tensor:
-    """
-    encodings: (B, T, C)
-    durations: (B, T)
-    """
-    batch_size, t_enc = paddle.shape(durations)
-    slens = paddle.sum(durations, -1)
-    t_dec = paddle.max(slens)
-    M = paddle.zeros([batch_size, t_dec, t_enc])
-    for i in range(batch_size):
-        k = 0
-        for j in range(t_enc):
-            d = durations[i, j]
-            # If the d == 0, slice action is meaningless and not supported
-            if d >= 1:
-                M[0, k:k + d, j] = 1
-            k += d
-    encodings = paddle.matmul(M, encodings)
-    return encodings


 class ResidualBlock(nn.Layer):
@@ -175,19 +156,25 @@ class SpeedySpeechDecoder(nn.Layer):
 class SpeedySpeech(nn.Layer):
-    def __init__(self,
-                 vocab_size,
-                 encoder_hidden_size,
-                 encoder_kernel_size,
-                 encoder_dilations,
-                 duration_predictor_hidden_size,
-                 decoder_hidden_size,
-                 decoder_output_size,
-                 decoder_kernel_size,
-                 decoder_dilations,
-                 tone_size=None,
-                 spk_num=None):
+    def __init__(
+            self,
+            vocab_size,
+            encoder_hidden_size,
+            encoder_kernel_size,
+            encoder_dilations,
+            duration_predictor_hidden_size,
+            decoder_hidden_size,
+            decoder_output_size,
+            decoder_kernel_size,
+            decoder_dilations,
+            tone_size=None,
+            spk_num=None,
+            init_type: str="xavier_uniform", ):
         super().__init__()
+
+        # initialize parameters
+        initialize(self, init_type)
+
         encoder = SpeedySpeechEncoder(vocab_size, tone_size,
                                       encoder_hidden_size, encoder_kernel_size,
                                       encoder_dilations, spk_num)
@@ -198,6 +185,10 @@ class SpeedySpeech(nn.Layer):
         self.encoder = encoder
         self.duration_predictor = duration_predictor
         self.decoder = decoder
+        # define length regulator
+        self.length_regulator = LengthRegulator()
+
+        nn.initializer.set_global_initializer(None)

     def forward(self, text, tones, durations, spk_id: paddle.Tensor=None):
         # input of embedding must be int64
@@ -212,7 +203,7 @@ class SpeedySpeech(nn.Layer):
         # expand encodings
         durations_to_expand = durations
-        encodings = expand(encodings, durations_to_expand)
+        encodings = self.length_regulator(encodings, durations_to_expand)
         # decode
         # remove positional encoding here
@@ -240,7 +231,8 @@ class SpeedySpeech(nn.Layer):
             durations_to_expand = durations_to_expand.astype(paddle.int64)
         else:
             durations_to_expand = durations
-        encodings = expand(encodings, durations_to_expand)
+        encodings = self.length_regulator(
+            encodings, durations_to_expand, is_inference=True)

         shape = paddle.shape(encodings)
         t_dec, feature_size = shape[1], shape[2]
...
length_regulator.py
@@ -13,6 +13,7 @@
 # limitations under the License.
 # Modified from espnet(https://github.com/espnet/espnet)
 """Length regulator related modules."""
+import numpy as np
 import paddle
 from paddle import nn

@@ -43,6 +44,28 @@ class LengthRegulator(nn.Layer):
         super().__init__()
         self.pad_value = pad_value

+    # expand_numpy is faster than expand
+    def expand_numpy(self, encodings: paddle.Tensor,
+                     durations: paddle.Tensor) -> paddle.Tensor:
+        """
+        encodings: (B, T, C)
+        durations: (B, T)
+        """
+        batch_size, t_enc = durations.shape
+        durations = durations.numpy()
+        slens = np.sum(durations, -1)
+        t_dec = np.max(slens)
+        M = np.zeros([batch_size, t_dec, t_enc])
+        for i in range(batch_size):
+            k = 0
+            for j in range(t_enc):
+                d = durations[i, j]
+                M[i, k:k + d, j] = 1
+                k += d
+        M = paddle.to_tensor(M, dtype=encodings.dtype)
+        encodings = paddle.matmul(M, encodings)
+        return encodings
+
     def expand(self, encodings: paddle.Tensor,
                durations: paddle.Tensor) -> paddle.Tensor:
         """
@@ -50,20 +73,21 @@ class LengthRegulator(nn.Layer):
         durations: (B, T)
         """
         batch_size, t_enc = paddle.shape(durations)
-        slens = durations.sum(-1)
-        t_dec = slens.max()
+        slens = paddle.sum(durations, -1)
+        t_dec = paddle.max(slens)
         M = paddle.zeros([batch_size, t_dec, t_enc])
         for i in range(batch_size):
             k = 0
             for j in range(t_enc):
                 d = durations[i, j]
+                # If the d == 0, slice action is meaningless and not supported in paddle
                 if d >= 1:
                     M[i, k:k + d, j] = 1
                 k += d
         encodings = paddle.matmul(M, encodings)
         return encodings

-    def forward(self, xs, ds, alpha=1.0):
+    def forward(self, xs, ds, alpha=1.0, is_inference=False):
         """Calculate forward propagation.

         Parameters
...
@@ -85,4 +109,7 @@ class LengthRegulator(nn.Layer):
         assert alpha > 0
         ds = paddle.round(ds.cast(dtype=paddle.float32) * alpha)
         ds = ds.cast(dtype=paddle.int64)
-        return self.expand(xs, ds)
+        if is_inference:
+            return self.expand(xs, ds)
+        else:
+            return self.expand_numpy(xs, ds)
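Taken together: forward() now routes on is_inference, so training defaults to the fast numpy path while FastSpeech2 and SpeedySpeech pass is_inference=True at synthesis time. A minimal usage sketch of the patched class (module path taken from the import added above; shapes and durations are made up for illustration):

    import paddle

    from paddlespeech.t2s.modules.predictor.length_regulator import LengthRegulator

    lr = LengthRegulator()
    xs = paddle.randn([2, 4, 8])                         # (B, T_enc, C) encoder outputs
    ds = paddle.to_tensor([[1, 2, 0, 3], [2, 2, 2, 0]])  # (B, T_enc) frame durations

    ys_train = lr(xs, ds)                     # is_inference=False -> expand_numpy (host loop)
    ys_infer = lr(xs, ds, is_inference=True)  # pure-paddle expand
    print(ys_train.shape)                     # [2, 6, 8]; T_dec = max(sum(ds, -1)) = 6

Defaulting is_inference to False keeps the fast numpy path on the training hot loop without touching the training call sites.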