提交 111a4523 编写于 作者: J Jerryuhoo

Fix the code format, test=tts

上级 1e710ef5
......@@ -15,21 +15,22 @@
# for mb melgan finetune
# 长度和原本的 mel 不一致怎么办?
import argparse
import os
from pathlib import Path
import numpy as np
import paddle
import yaml
from yacs.config import CfgNode
from tqdm import tqdm
import os
from yacs.config import CfgNode
from paddlespeech.t2s.datasets.preprocess_utils import get_phn_dur
from paddlespeech.t2s.datasets.preprocess_utils import merge_silence
from paddlespeech.t2s.frontend.zh_frontend import Frontend
from paddlespeech.t2s.models.speedyspeech import SpeedySpeech
from paddlespeech.t2s.models.speedyspeech import SpeedySpeechInference
from paddlespeech.t2s.modules.normalizer import ZScore
from paddlespeech.t2s.frontend.zh_frontend import Frontend
def evaluate(args, speedyspeech_config):
rootdir = Path(args.rootdir).expanduser()
......@@ -50,17 +51,21 @@ def evaluate(args, speedyspeech_config):
tone_size = len(tone_id)
print("tone_size:", tone_size)
frontend = Frontend(phone_vocab_path=args.phones_dict, tone_vocab_path=args.tones_dict)
frontend = Frontend(
phone_vocab_path=args.phones_dict, tone_vocab_path=args.tones_dict)
if args.speaker_dict:
with open(args.speaker_dict, 'rt') as f:
spk_id_list = [line.strip().split() for line in f.readlines()]
spk_num = len(spk_id_list)
else:
spk_num=None
spk_num = None
model = SpeedySpeech(
vocab_size=vocab_size, tone_size=tone_size, **speedyspeech_config["model"], spk_num=spk_num)
vocab_size=vocab_size,
tone_size=tone_size,
**speedyspeech_config["model"],
spk_num=spk_num)
model.set_state_dict(
paddle.load(args.speedyspeech_checkpoint)["main_params"])
......@@ -105,9 +110,15 @@ def evaluate(args, speedyspeech_config):
else:
train_wav_files += wav_files
train_wav_files = [os.path.basename(str(str_path)) for str_path in train_wav_files]
dev_wav_files = [os.path.basename(str(str_path)) for str_path in dev_wav_files]
test_wav_files = [os.path.basename(str(str_path)) for str_path in test_wav_files]
train_wav_files = [
os.path.basename(str(str_path)) for str_path in train_wav_files
]
dev_wav_files = [
os.path.basename(str(str_path)) for str_path in dev_wav_files
]
test_wav_files = [
os.path.basename(str(str_path)) for str_path in test_wav_files
]
for i, utt_id in enumerate(tqdm(sentences)):
phones = sentences[utt_id][0]
......@@ -122,8 +133,7 @@ def evaluate(args, speedyspeech_config):
durations = durations[:-1]
phones = phones[:-1]
phones, tones = frontend._get_phone_tone(
phones, get_tone_ids=True)
phones, tones = frontend._get_phone_tone(phones, get_tone_ids=True)
if tones:
tone_ids = frontend._t2id(tones)
tone_ids = paddle.to_tensor(tone_ids)
......@@ -132,7 +142,8 @@ def evaluate(args, speedyspeech_config):
phone_ids = paddle.to_tensor(phone_ids)
if args.speaker_dict:
speaker_id = int([item[1] for item in spk_id_list if speaker == item[0]][0])
speaker_id = int(
[item[1] for item in spk_id_list if speaker == item[0]][0])
speaker_id = paddle.to_tensor(speaker_id)
else:
speaker_id = None
......@@ -155,7 +166,8 @@ def evaluate(args, speedyspeech_config):
sub_output_dir.mkdir(parents=True, exist_ok=True)
with paddle.no_grad():
mel = speedyspeech_inference(phone_ids, tone_ids, durations=durations, spk_id=speaker_id)
mel = speedyspeech_inference(
phone_ids, tone_ids, durations=durations, spk_id=speaker_id)
np.save(sub_output_dir / (utt_id + "_feats.npy"), mel)
......@@ -193,10 +205,7 @@ def main():
default="tone_id_map.txt",
help="tone vocabulary file.")
parser.add_argument(
"--speaker-dict",
type=str,
default=None,
help="speaker id map file.")
"--speaker-dict", type=str, default=None, help="speaker id map file.")
parser.add_argument(
"--dur-file", default=None, type=str, help="path to durations.txt.")
......
......@@ -272,9 +272,6 @@ class SpeedySpeechInference(nn.Layer):
def forward(self, phones, tones, durations=None, spk_id=None):
normalized_mel = self.acoustic_model.inference(
phones,
tones,
durations=durations,
spk_id=spk_id)
phones, tones, durations=durations, spk_id=spk_id)
logmel = self.normalizer.inverse(normalized_mel)
return logmel
......@@ -20,6 +20,7 @@ import jsonlines
import numpy as np
from tqdm import tqdm
def main():
# parse config and args
parser = argparse.ArgumentParser(
......@@ -63,7 +64,8 @@ def main():
os.symlink(old_dump_dir / sub / ("raw/" + wave_name),
output_dir / ("raw/" + wave_name))
except FileNotFoundError:
print("delete " + name + " because it cannot be found in the dump folder")
print("delete " + name +
" because it cannot be found in the dump folder")
os.remove(output_dir / "raw" / name)
continue
except FileExistsError:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册