diff --git a/examples/csmsc/tts3/local/inference.sh b/examples/csmsc/tts3/local/inference.sh
index 7c58980cdd1e9602743b13dccbaac09b0e3f443b..9322cfd697912100663457f2b9bcada543e27733 100755
--- a/examples/csmsc/tts3/local/inference.sh
+++ b/examples/csmsc/tts3/local/inference.sh
@@ -48,4 +48,15 @@ if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
         --text=${BIN_DIR}/../sentences.txt \
         --output_dir=${train_output_path}/pd_infer_out \
         --phones_dict=dump/phone_id_map.txt
+fi
+
+# wavernn
+if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
+    python3 ${BIN_DIR}/../inference.py \
+        --inference_dir=${train_output_path}/inference \
+        --am=fastspeech2_csmsc \
+        --voc=wavernn_csmsc \
+        --text=${BIN_DIR}/../sentences.txt \
+        --output_dir=${train_output_path}/pd_infer_out \
+        --phones_dict=dump/phone_id_map.txt
 fi
\ No newline at end of file
diff --git a/examples/csmsc/tts3/local/synthesize_e2e.sh b/examples/csmsc/tts3/local/synthesize_e2e.sh
index 49101ea0bafee10e317fa59872c99c377a311216..d1fadf77d9e14b0230f5d3547d3dcd7e8a221b7d 100755
--- a/examples/csmsc/tts3/local/synthesize_e2e.sh
+++ b/examples/csmsc/tts3/local/synthesize_e2e.sh
@@ -108,5 +108,6 @@ if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
         --lang=zh \
         --text=${BIN_DIR}/../sentences.txt \
         --output_dir=${train_output_path}/test_e2e \
-        --phones_dict=dump/phone_id_map.txt
+        --phones_dict=dump/phone_id_map.txt \
+        --inference_dir=${train_output_path}/inference
 fi
diff --git a/paddlespeech/t2s/exps/inference.py b/paddlespeech/t2s/exps/inference.py
index 37afd0abcf4bffc67de9a0bc6437e1a316a865c2..8044c445ef163cdaeb0ad6216f7c6b2e2fb3884d 100644
--- a/paddlespeech/t2s/exps/inference.py
+++ b/paddlespeech/t2s/exps/inference.py
@@ -54,7 +54,7 @@ def main():
         default='pwgan_csmsc',
         choices=[
             'pwgan_csmsc', 'mb_melgan_csmsc', 'hifigan_csmsc', 'pwgan_aishell3',
-            'pwgan_vctk'
+            'pwgan_vctk', 'wavernn_csmsc'
         ],
         help='Choose vocoder type of tts task.')
     # other
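Note: with the two script changes and the new `wavernn_csmsc` choice above, the
WaveRNN vocoder runs through the same Paddle Inference path as the other
vocoders. A minimal sketch of how the exported model under --inference_dir is
consumed; the file names follow the usual <model_name>.pdmodel /
<model_name>.pdiparams convention and are assumptions here, as is the toy mel
input:

    import numpy as np
    from paddle import inference

    config = inference.Config("inference/wavernn_csmsc.pdmodel",
                              "inference/wavernn_csmsc.pdiparams")
    predictor = inference.create_predictor(config)

    # feed a mel spectrogram (T, n_mels) and fetch the generated waveform
    mel = np.random.randn(100, 80).astype("float32")  # placeholder input
    in_handle = predictor.get_input_handle(predictor.get_input_names()[0])
    in_handle.copy_from_cpu(mel)
    predictor.run()
    wav = predictor.get_output_handle(predictor.get_output_names()[0]).copy_to_cpu()
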
diff --git a/paddlespeech/t2s/models/wavernn/wavernn.py b/paddlespeech/t2s/models/wavernn/wavernn.py
index 2c6941b04f682a05f1f06faf72ef7515805ef06b..f30879ed6c72dfde9ee4874c87117c740f8ed364 100644
--- a/paddlespeech/t2s/models/wavernn/wavernn.py
+++ b/paddlespeech/t2s/models/wavernn/wavernn.py
@@ -76,6 +76,7 @@ class MelResNet(nn.Layer):
         Tensor
             Output tensor (B, res_out_dims, T).
         '''
+
         x = self.conv_in(x)
         x = self.batch_norm(x)
         x = F.relu(x)
@@ -230,6 +231,7 @@ class WaveRNN(nn.Layer):
 
         self.rnn1 = nn.GRU(rnn_dims, rnn_dims)
         self.rnn2 = nn.GRU(rnn_dims + self.aux_dims, rnn_dims)
+        self._to_flatten += [self.rnn1, self.rnn2]
 
         self.fc1 = nn.Linear(rnn_dims + self.aux_dims, fc_dims)
@@ -326,17 +328,17 @@ class WaveRNN(nn.Layer):
         output = []
         start = time.time()
 
-        rnn1 = self.get_gru_cell(self.rnn1)
-        rnn2 = self.get_gru_cell(self.rnn2)
-
+        # pseudo batch
         # (T, C_aux) -> (1, C_aux, T)
         c = paddle.transpose(c, [1, 0]).unsqueeze(0)
-
-        wave_len = (paddle.shape(c)[-1] - 1) * self.hop_length
+        T = paddle.shape(c)[-1]
+        wave_len = (T - 1) * self.hop_length
         # TODO remove two transpose op by modifying function pad_tensor
         c = self.pad_tensor(
             c.transpose([0, 2, 1]), pad=self.aux_context_window,
             side='both').transpose([0, 2, 1])
+
         c, aux = self.upsample(c)
 
         if batched:
@@ -344,7 +346,13 @@ class WaveRNN(nn.Layer):
             c = self.fold_with_overlap(c, target, overlap)
             aux = self.fold_with_overlap(aux, target, overlap)
 
-        b_size, seq_len, _ = paddle.shape(c)
+        # for dygraph to static graph: if seq_len comes from
+        # `b_size, seq_len, _ = paddle.shape(c)`, the for loop below will not produce a TensorArray,
+        # see https://www.paddlepaddle.org.cn/documentation/docs/zh/guides/04_dygraph_to_static/case_analysis_cn.html#list-lodtensorarray
+        # b_size, seq_len, _ = paddle.shape(c)
+        b_size = paddle.shape(c)[0]
+        seq_len = paddle.shape(c)[1]
+
         h1 = paddle.zeros([b_size, self.rnn_dims])
         h2 = paddle.zeros([b_size, self.rnn_dims])
         x = paddle.zeros([b_size, 1])
@@ -354,14 +362,20 @@ class WaveRNN(nn.Layer):
 
         for i in range(seq_len):
             m_t = c[:, i, :]
-
-            a1_t, a2_t, a3_t, a4_t = (a[:, i, :] for a in aux_split)
+            # for dygraph to static graph
+            # a1_t, a2_t, a3_t, a4_t = (a[:, i, :] for a in aux_split)
+            a1_t = aux_split[0][:, i, :]
+            a2_t = aux_split[1][:, i, :]
+            a3_t = aux_split[2][:, i, :]
+            a4_t = aux_split[3][:, i, :]
 
             x = paddle.concat([x, m_t, a1_t], axis=1)
             x = self.I(x)
-            h1, _ = rnn1(x, h1)
+            # use GRUCell here
+            h1, _ = self.rnn1[0].cell(x, h1)
             x = x + h1
             inp = paddle.concat([x, a2_t], axis=1)
-            h2, _ = rnn2(inp, h2)
+            # use GRUCell here
+            h2, _ = self.rnn2[0].cell(inp, h2)
 
             x = x + h2
             x = paddle.concat([x, a3_t], axis=1)
@@ -413,15 +427,6 @@ class WaveRNN(nn.Layer):
             # add the C_out dimension
             return output.unsqueeze(-1)
 
-    def get_gru_cell(self, gru):
-        gru_cell = nn.GRUCell(gru.input_size, gru.hidden_size)
-        gru_cell.weight_hh = gru.weight_hh_l0
-        gru_cell.weight_ih = gru.weight_ih_l0
-        gru_cell.bias_hh = gru.bias_hh_l0
-        gru_cell.bias_ih = gru.bias_ih_l0
-
-        return gru_cell
-
     def _flatten_parameters(self):
         [m.flatten_parameters() for m in self._to_flatten]
 
@@ -438,7 +443,9 @@ class WaveRNN(nn.Layer):
         ----------
         Tensor
         '''
-        b, t, c = paddle.shape(x)
+        b, t, _ = paddle.shape(x)
+        # for dygraph to static graph
+        c = x.shape[-1]
         total = t + 2 * pad if side == 'both' else t + pad
         padded = paddle.zeros([b, total, c])
         if side == 'before' or side == 'both':
@@ -516,7 +523,7 @@ class WaveRNN(nn.Layer):
         y : Tensor
             Batched sequences of audio samples
             shape=(num_folds, target + 2 * overlap)
-            dtype=paddle.float64
+            dtype=paddle.float32
         overlap : int
             Timesteps for both xfade and rnn warmup
 
@@ -525,7 +532,7 @@ class WaveRNN(nn.Layer):
         Tensor
             audio samples in a 1d array
             shape=(total_len)
-            dtype=paddle.float64
+            dtype=paddle.float32
 
         Details
         ----------
@@ -545,19 +552,19 @@ class WaveRNN(nn.Layer):
         '''
         # num_folds = (total_len - overlap) // (target + overlap)
-        num_folds, length = y.shape
+        num_folds, length = paddle.shape(y)
         target = length - 2 * overlap
         total_len = num_folds * (target + overlap) + overlap
 
         # Need some silence for the run warmup
         slience_len = overlap // 2
         fade_len = overlap - slience_len
-        slience = paddle.zeros([slience_len], dtype=paddle.float64)
-        linear = paddle.ones([fade_len], dtype=paddle.float64)
+        slience = paddle.zeros([slience_len], dtype=paddle.float32)
+        linear = paddle.ones([fade_len], dtype=paddle.float32)
 
         # Equal power crossfade
         # fade_in increase from 0 to 1, fade_out reduces from 1 to 0
-        t = paddle.linspace(-1, 1, fade_len, dtype=paddle.float64)
+        t = paddle.linspace(-1, 1, fade_len, dtype=paddle.float32)
         fade_in = paddle.sqrt(0.5 * (1 + t))
         fade_out = paddle.sqrt(0.5 * (1 - t))
         # Concat the silence to the fades
@@ -568,7 +575,7 @@ class WaveRNN(nn.Layer):
         y[:, :overlap] *= fade_in
         y[:, -overlap:] *= fade_out
 
-        unfolded = paddle.zeros([total_len], dtype=paddle.float64)
+        unfolded = paddle.zeros([total_len], dtype=paddle.float32)
 
         # Loop to add up all the samples
         for i in range(num_folds):
@@ -606,11 +613,13 @@ class WaveRNNInference(nn.Layer):
                  mu_law: bool=True,
                  gen_display: bool=False):
         normalized_mel = self.normalizer(logmel)
+
         wav = self.wavernn.generate(
-            normalized_mel,
-            batched=batched,
-            target=target,
-            overlap=overlap,
-            mu_law=mu_law,
-            gen_display=gen_display)
+            normalized_mel, )
+        # batched=batched,
+        # target=target,
+        # overlap=overlap,
+        # mu_law=mu_law,
+        # gen_display=gen_display)
+
         return wav
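
Note: the wavernn.py edits follow the constraints in the dygraph-to-static
guide linked above: dynamic dims that drive control flow come from
paddle.shape(...), the generator expression over aux_split is unrolled into
explicit indexing, and the per-step GRU math reuses the trained nn.GRU's own
cell (rnn[0].cell) instead of copying weights into a fresh nn.GRUCell. A
self-contained sketch of both patterns, with a toy layer standing in for
WaveRNNInference (sizes and names are illustrative, not the real model):

    import paddle
    from paddle.static import InputSpec

    # stepping a trained nn.GRU one frame at a time: gru[0] is the
    # underlying single-layer RNN and .cell is its GRUCell, so no weight
    # copying is needed and the call traces cleanly under to_static
    gru = paddle.nn.GRU(16, 16)
    x_t = paddle.zeros([4, 16])  # (batch, features) for a single timestep
    h_t = paddle.zeros([4, 16])
    h_t, _ = gru[0].cell(x_t, h_t)

    class TinyVoc(paddle.nn.Layer):
        """Toy stand-in for WaveRNNInference."""

        def __init__(self):
            super().__init__()
            self.proj = paddle.nn.Linear(80, 1)

        def forward(self, mel):
            # dynamic dims must come from paddle.shape(...), as in the
            # patch, so they stay tensors after conversion
            T = paddle.shape(mel)[0]
            return self.proj(mel).squeeze(-1)[:T]

    # export with a dynamic time axis, mirroring what --inference_dir
    # consumes; the saved prefix is illustrative
    net = paddle.jit.to_static(
        TinyVoc(), input_spec=[InputSpec(shape=[None, 80], dtype="float32")])
    paddle.jit.save(net, "inference/tiny_voc")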