提交 001afee6 编写于 作者: 小湉湉's avatar 小湉湉

fix wavernn dygraph to static , test=tts

上级 2071774d
......@@ -48,4 +48,15 @@ if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
--text=${BIN_DIR}/../sentences.txt \
--output_dir=${train_output_path}/pd_infer_out \
--phones_dict=dump/phone_id_map.txt
fi
# wavernn
if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
python3 ${BIN_DIR}/../inference.py \
--inference_dir=${train_output_path}/inference \
--am=fastspeech2_csmsc \
--voc=wavernn_csmsc \
--text=${BIN_DIR}/../sentences.txt \
--output_dir=${train_output_path}/pd_infer_out \
--phones_dict=dump/phone_id_map.txt
fi
\ No newline at end of file
......@@ -108,5 +108,6 @@ if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
--lang=zh \
--text=${BIN_DIR}/../sentences.txt \
--output_dir=${train_output_path}/test_e2e \
--phones_dict=dump/phone_id_map.txt
--phones_dict=dump/phone_id_map.txt \
--inference_dir=${train_output_path}/inference
fi
......@@ -54,7 +54,7 @@ def main():
default='pwgan_csmsc',
choices=[
'pwgan_csmsc', 'mb_melgan_csmsc', 'hifigan_csmsc', 'pwgan_aishell3',
'pwgan_vctk'
'pwgan_vctk', 'wavernn_csmsc'
],
help='Choose vocoder type of tts task.')
# other
......
......@@ -76,6 +76,7 @@ class MelResNet(nn.Layer):
Tensor
Output tensor (B, res_out_dims, T).
'''
x = self.conv_in(x)
x = self.batch_norm(x)
x = F.relu(x)
......@@ -230,6 +231,7 @@ class WaveRNN(nn.Layer):
self.rnn1 = nn.GRU(rnn_dims, rnn_dims)
self.rnn2 = nn.GRU(rnn_dims + self.aux_dims, rnn_dims)
self._to_flatten += [self.rnn1, self.rnn2]
self.fc1 = nn.Linear(rnn_dims + self.aux_dims, fc_dims)
......@@ -326,17 +328,17 @@ class WaveRNN(nn.Layer):
output = []
start = time.time()
rnn1 = self.get_gru_cell(self.rnn1)
rnn2 = self.get_gru_cell(self.rnn2)
# pseudo batch
# (T, C_aux) -> (1, C_aux, T)
c = paddle.transpose(c, [1, 0]).unsqueeze(0)
wave_len = (paddle.shape(c)[-1] - 1) * self.hop_length
T = paddle.shape(c)[-1]
wave_len = (T - 1) * self.hop_length
# TODO remove two transpose op by modifying function pad_tensor
c = self.pad_tensor(
c.transpose([0, 2, 1]), pad=self.aux_context_window,
side='both').transpose([0, 2, 1])
c, aux = self.upsample(c)
if batched:
......@@ -344,7 +346,13 @@ class WaveRNN(nn.Layer):
c = self.fold_with_overlap(c, target, overlap)
aux = self.fold_with_overlap(aux, target, overlap)
b_size, seq_len, _ = paddle.shape(c)
# for dygraph to static graph, if use seq_len of `b_size, seq_len, _ = paddle.shape(c)` in for
# will not get TensorArray
# see https://www.paddlepaddle.org.cn/documentation/docs/zh/guides/04_dygraph_to_static/case_analysis_cn.html#list-lodtensorarray
# b_size, seq_len, _ = paddle.shape(c)
b_size = paddle.shape(c)[0]
seq_len = paddle.shape(c)[1]
h1 = paddle.zeros([b_size, self.rnn_dims])
h2 = paddle.zeros([b_size, self.rnn_dims])
x = paddle.zeros([b_size, 1])
......@@ -354,14 +362,20 @@ class WaveRNN(nn.Layer):
for i in range(seq_len):
m_t = c[:, i, :]
a1_t, a2_t, a3_t, a4_t = (a[:, i, :] for a in aux_split)
# for dygraph to static graph
# a1_t, a2_t, a3_t, a4_t = (a[:, i, :] for a in aux_split)
a1_t = aux_split[0][:, i, :]
a2_t = aux_split[1][:, i, :]
a3_t = aux_split[2][:, i, :]
a4_t = aux_split[3][:, i, :]
x = paddle.concat([x, m_t, a1_t], axis=1)
x = self.I(x)
h1, _ = rnn1(x, h1)
# use GRUCell here
h1, _ = self.rnn1[0].cell(x, h1)
x = x + h1
inp = paddle.concat([x, a2_t], axis=1)
h2, _ = rnn2(inp, h2)
# use GRUCell here
h2, _ = self.rnn2[0].cell(inp, h2)
x = x + h2
x = paddle.concat([x, a3_t], axis=1)
......@@ -413,15 +427,6 @@ class WaveRNN(nn.Layer):
# 增加 C_out 维度
return output.unsqueeze(-1)
def get_gru_cell(self, gru):
gru_cell = nn.GRUCell(gru.input_size, gru.hidden_size)
gru_cell.weight_hh = gru.weight_hh_l0
gru_cell.weight_ih = gru.weight_ih_l0
gru_cell.bias_hh = gru.bias_hh_l0
gru_cell.bias_ih = gru.bias_ih_l0
return gru_cell
def _flatten_parameters(self):
[m.flatten_parameters() for m in self._to_flatten]
......@@ -438,7 +443,9 @@ class WaveRNN(nn.Layer):
----------
Tensor
'''
b, t, c = paddle.shape(x)
b, t, _ = paddle.shape(x)
# for dygraph to static graph
c = x.shape[-1]
total = t + 2 * pad if side == 'both' else t + pad
padded = paddle.zeros([b, total, c])
if side == 'before' or side == 'both':
......@@ -516,7 +523,7 @@ class WaveRNN(nn.Layer):
y : Tensor
Batched sequences of audio samples
shape=(num_folds, target + 2 * overlap)
dtype=paddle.float64
dtype=paddle.float32
overlap : int
Timesteps for both xfade and rnn warmup
......@@ -525,7 +532,7 @@ class WaveRNN(nn.Layer):
Tensor
audio samples in a 1d array
shape=(total_len)
dtype=paddle.float64
dtype=paddle.float32
Details
----------
......@@ -545,19 +552,19 @@ class WaveRNN(nn.Layer):
'''
# num_folds = (total_len - overlap) // (target + overlap)
num_folds, length = y.shape
num_folds, length = paddle.shape(y)
target = length - 2 * overlap
total_len = num_folds * (target + overlap) + overlap
# Need some silence for the run warmup
slience_len = overlap // 2
fade_len = overlap - slience_len
slience = paddle.zeros([slience_len], dtype=paddle.float64)
linear = paddle.ones([fade_len], dtype=paddle.float64)
slience = paddle.zeros([slience_len], dtype=paddle.float32)
linear = paddle.ones([fade_len], dtype=paddle.float32)
# Equal power crossfade
# fade_in increase from 0 to 1, fade_out reduces from 1 to 0
t = paddle.linspace(-1, 1, fade_len, dtype=paddle.float64)
t = paddle.linspace(-1, 1, fade_len, dtype=paddle.float32)
fade_in = paddle.sqrt(0.5 * (1 + t))
fade_out = paddle.sqrt(0.5 * (1 - t))
# Concat the silence to the fades
......@@ -568,7 +575,7 @@ class WaveRNN(nn.Layer):
y[:, :overlap] *= fade_in
y[:, -overlap:] *= fade_out
unfolded = paddle.zeros([total_len], dtype=paddle.float64)
unfolded = paddle.zeros([total_len], dtype=paddle.float32)
# Loop to add up all the samples
for i in range(num_folds):
......@@ -606,11 +613,13 @@ class WaveRNNInference(nn.Layer):
mu_law: bool=True,
gen_display: bool=False):
normalized_mel = self.normalizer(logmel)
wav = self.wavernn.generate(
normalized_mel,
batched=batched,
target=target,
overlap=overlap,
mu_law=mu_law,
gen_display=gen_display)
normalized_mel, )
# batched=batched,
# target=target,
# overlap=overlap,
# mu_law=mu_law,
# gen_display=gen_display)
return wav
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册