Commit 111a4523 authored by Jerryuhoo

Fix the code format, test=tts

Parent 1e710ef5
@@ -15,21 +15,22 @@
 # for mb melgan finetune
 # 长度和原本的 mel 不一致怎么办?
 import argparse
+import os
 from pathlib import Path
 import numpy as np
 import paddle
 import yaml
-from yacs.config import CfgNode
 from tqdm import tqdm
-import os
+from yacs.config import CfgNode
 from paddlespeech.t2s.datasets.preprocess_utils import get_phn_dur
 from paddlespeech.t2s.datasets.preprocess_utils import merge_silence
+from paddlespeech.t2s.frontend.zh_frontend import Frontend
 from paddlespeech.t2s.models.speedyspeech import SpeedySpeech
 from paddlespeech.t2s.models.speedyspeech import SpeedySpeechInference
 from paddlespeech.t2s.modules.normalizer import ZScore
-from paddlespeech.t2s.frontend.zh_frontend import Frontend


 def evaluate(args, speedyspeech_config):
     rootdir = Path(args.rootdir).expanduser()
@@ -50,17 +51,21 @@ def evaluate(args, speedyspeech_config):
     tone_size = len(tone_id)
     print("tone_size:", tone_size)

-    frontend = Frontend(phone_vocab_path=args.phones_dict, tone_vocab_path=args.tones_dict)
+    frontend = Frontend(
+        phone_vocab_path=args.phones_dict, tone_vocab_path=args.tones_dict)

     if args.speaker_dict:
         with open(args.speaker_dict, 'rt') as f:
             spk_id_list = [line.strip().split() for line in f.readlines()]
         spk_num = len(spk_id_list)
     else:
-        spk_num=None
+        spk_num = None

     model = SpeedySpeech(
-        vocab_size=vocab_size, tone_size=tone_size, **speedyspeech_config["model"], spk_num=spk_num)
+        vocab_size=vocab_size,
+        tone_size=tone_size,
+        **speedyspeech_config["model"],
+        spk_num=spk_num)

     model.set_state_dict(
         paddle.load(args.speedyspeech_checkpoint)["main_params"])
@@ -105,9 +110,15 @@ def evaluate(args, speedyspeech_config):
         else:
             train_wav_files += wav_files

-    train_wav_files = [os.path.basename(str(str_path)) for str_path in train_wav_files]
-    dev_wav_files = [os.path.basename(str(str_path)) for str_path in dev_wav_files]
-    test_wav_files = [os.path.basename(str(str_path)) for str_path in test_wav_files]
+    train_wav_files = [
+        os.path.basename(str(str_path)) for str_path in train_wav_files
+    ]
+    dev_wav_files = [
+        os.path.basename(str(str_path)) for str_path in dev_wav_files
+    ]
+    test_wav_files = [
+        os.path.basename(str(str_path)) for str_path in test_wav_files
+    ]

     for i, utt_id in enumerate(tqdm(sentences)):
         phones = sentences[utt_id][0]
@@ -122,8 +133,7 @@ def evaluate(args, speedyspeech_config):
             durations = durations[:-1]
             phones = phones[:-1]

-        phones, tones = frontend._get_phone_tone(
-            phones, get_tone_ids=True)
+        phones, tones = frontend._get_phone_tone(phones, get_tone_ids=True)
         if tones:
             tone_ids = frontend._t2id(tones)
             tone_ids = paddle.to_tensor(tone_ids)
@@ -132,7 +142,8 @@ def evaluate(args, speedyspeech_config):
         phone_ids = paddle.to_tensor(phone_ids)

         if args.speaker_dict:
-            speaker_id = int([item[1] for item in spk_id_list if speaker == item[0]][0])
+            speaker_id = int(
+                [item[1] for item in spk_id_list if speaker == item[0]][0])
             speaker_id = paddle.to_tensor(speaker_id)
         else:
             speaker_id = None
@@ -155,7 +166,8 @@ def evaluate(args, speedyspeech_config):
         sub_output_dir.mkdir(parents=True, exist_ok=True)

         with paddle.no_grad():
-            mel = speedyspeech_inference(phone_ids, tone_ids, durations=durations, spk_id=speaker_id)
+            mel = speedyspeech_inference(
+                phone_ids, tone_ids, durations=durations, spk_id=speaker_id)

         np.save(sub_output_dir / (utt_id + "_feats.npy"), mel)
@@ -193,10 +205,7 @@ def main():
         default="tone_id_map.txt",
         help="tone vocabulary file.")
     parser.add_argument(
-        "--speaker-dict",
-        type=str,
-        default=None,
-        help="speaker id map file.")
+        "--speaker-dict", type=str, default=None, help="speaker id map file.")
     parser.add_argument(
         "--dur-file", default=None, type=str, help="path to durations.txt.")
......
@@ -272,9 +272,6 @@ class SpeedySpeechInference(nn.Layer):

     def forward(self, phones, tones, durations=None, spk_id=None):
         normalized_mel = self.acoustic_model.inference(
-            phones,
-            tones,
-            durations=durations,
-            spk_id=spk_id)
+            phones, tones, durations=durations, spk_id=spk_id)
         logmel = self.normalizer.inverse(normalized_mel)
         return logmel
\ No newline at end of file
@@ -20,6 +20,7 @@ import jsonlines
 import numpy as np
 from tqdm import tqdm

+
 def main():
     # parse config and args
     parser = argparse.ArgumentParser(
@@ -61,9 +62,10 @@ def main():
         try:
             wav = np.load(old_dump_dir / sub / ("raw/" + wave_name))
             os.symlink(old_dump_dir / sub / ("raw/" + wave_name),
                        output_dir / ("raw/" + wave_name))
         except FileNotFoundError:
-            print("delete " + name + " because it cannot be found in the dump folder")
+            print("delete " + name +
+                  " because it cannot be found in the dump folder")
             os.remove(output_dir / "raw" / name)
             continue
         except FileExistsError:
......