未验证 提交 fbe3c051 编写于 作者: 小湉湉's avatar 小湉湉 提交者: GitHub

add style_melgan and hifigan in tts cli, test=tts (#1241)

上级 a232cd8b
......@@ -178,6 +178,32 @@ pretrained_models = {
'speech_stats':
'feats_stats.npy',
},
# style_melgan
"style_melgan_csmsc-zh": {
'url':
'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/style_melgan/style_melgan_csmsc_ckpt_0.1.1.zip',
'md5':
'5de2d5348f396de0c966926b8c462755',
'config':
'default.yaml',
'ckpt':
'snapshot_iter_1500000.pdz',
'speech_stats':
'feats_stats.npy',
},
# hifigan
"hifigan_csmsc-zh": {
'url':
'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/hifigan/hifigan_csmsc_ckpt_0.1.1.zip',
'md5':
'dd40a3d88dfcf64513fba2f0f961ada6',
'config':
'default.yaml',
'ckpt':
'snapshot_iter_2500000.pdz',
'speech_stats':
'feats_stats.npy',
},
}
model_alias = {
......@@ -199,6 +225,14 @@ model_alias = {
"paddlespeech.t2s.models.melgan:MelGANGenerator",
"mb_melgan_inference":
"paddlespeech.t2s.models.melgan:MelGANInference",
"style_melgan":
"paddlespeech.t2s.models.melgan:StyleMelGANGenerator",
"style_melgan_inference":
"paddlespeech.t2s.models.melgan:StyleMelGANInference",
"hifigan":
"paddlespeech.t2s.models.hifigan:HiFiGANGenerator",
"hifigan_inference":
"paddlespeech.t2s.models.hifigan:HiFiGANInference",
}
......@@ -266,7 +300,7 @@ class TTSExecutor(BaseExecutor):
default='pwgan_csmsc',
choices=[
'pwgan_csmsc', 'pwgan_ljspeech', 'pwgan_aishell3', 'pwgan_vctk',
'mb_melgan_csmsc'
'mb_melgan_csmsc', 'style_melgan_csmsc', 'hifigan_csmsc'
],
help='Choose vocoder type of tts task.')
......@@ -504,37 +538,47 @@ class TTSExecutor(BaseExecutor):
am_name = am[:am.rindex('_')]
am_dataset = am[am.rindex('_') + 1:]
get_tone_ids = False
merge_sentences = False
if am_name == 'speedyspeech':
get_tone_ids = True
if lang == 'zh':
input_ids = self.frontend.get_input_ids(
text, merge_sentences=True, get_tone_ids=get_tone_ids)
text,
merge_sentences=merge_sentences,
get_tone_ids=get_tone_ids)
phone_ids = input_ids["phone_ids"]
phone_ids = phone_ids[0]
if get_tone_ids:
tone_ids = input_ids["tone_ids"]
tone_ids = tone_ids[0]
elif lang == 'en':
input_ids = self.frontend.get_input_ids(text)
input_ids = self.frontend.get_input_ids(
text, merge_sentences=merge_sentences)
phone_ids = input_ids["phone_ids"]
else:
print("lang should in {'zh', 'en'}!")
# am
if am_name == 'speedyspeech':
mel = self.am_inference(phone_ids, tone_ids)
# fastspeech2
else:
# multi speaker
if am_dataset in {"aishell3", "vctk"}:
mel = self.am_inference(
phone_ids, spk_id=paddle.to_tensor(spk_id))
flags = 0
for i in range(len(phone_ids)):
part_phone_ids = phone_ids[i]
# am
if am_name == 'speedyspeech':
part_tone_ids = tone_ids[i]
mel = self.am_inference(part_phone_ids, part_tone_ids)
# fastspeech2
else:
mel = self.am_inference(phone_ids)
# voc
wav = self.voc_inference(mel)
self._outputs['wav'] = wav
# multi speaker
if am_dataset in {"aishell3", "vctk"}:
mel = self.am_inference(
part_phone_ids, spk_id=paddle.to_tensor(spk_id))
else:
mel = self.am_inference(part_phone_ids)
# voc
wav = self.voc_inference(mel)
if flags == 0:
wav_all = wav
flags = 1
else:
wav_all = paddle.concat([wav_all, wav])
self._outputs['wav'] = wav_all
def postprocess(self, output: str='output.wav') -> Union[str, os.PathLike]:
"""
......
......@@ -196,41 +196,47 @@ def evaluate(args):
output_dir = Path(args.output_dir)
output_dir.mkdir(parents=True, exist_ok=True)
merge_sentences = False
for utt_id, sentence in sentences:
get_tone_ids = False
if am_name == 'speedyspeech':
get_tone_ids = True
if args.lang == 'zh':
input_ids = frontend.get_input_ids(
sentence, merge_sentences=True, get_tone_ids=get_tone_ids)
sentence, merge_sentences=merge_sentences, get_tone_ids=get_tone_ids)
phone_ids = input_ids["phone_ids"]
phone_ids = phone_ids[0]
if get_tone_ids:
tone_ids = input_ids["tone_ids"]
tone_ids = tone_ids[0]
elif args.lang == 'en':
input_ids = frontend.get_input_ids(sentence)
input_ids = frontend.get_input_ids(sentence, merge_sentences=merge_sentences)
phone_ids = input_ids["phone_ids"]
else:
print("lang should in {'zh', 'en'}!")
with paddle.no_grad():
# acoustic model
if am_name == 'fastspeech2':
# multi speaker
if am_dataset in {"aishell3", "vctk"}:
spk_id = paddle.to_tensor(args.spk_id)
mel = am_inference(phone_ids, spk_id)
flags = 0
for i in range(len(phone_ids)):
part_phone_ids = phone_ids[i]
# acoustic model
if am_name == 'fastspeech2':
# multi speaker
if am_dataset in {"aishell3", "vctk"}:
spk_id = paddle.to_tensor(args.spk_id)
mel = am_inference(part_phone_ids, spk_id)
else:
mel = am_inference(part_phone_ids)
elif am_name == 'speedyspeech':
part_tone_ids = tone_ids[i]
mel = am_inference(part_phone_ids, part_tone_ids)
# vocoder
wav = voc_inference(mel)
if flags == 0:
wav_all = wav
flags = 1
else:
mel = am_inference(phone_ids)
elif am_name == 'speedyspeech':
mel = am_inference(phone_ids, tone_ids)
# vocoder
wav = voc_inference(mel)
wav_all = paddle.concat([wav_all, wav])
sf.write(
str(output_dir / (utt_id + ".wav")),
wav.numpy(),
wav_all.numpy(),
samplerate=am_config.fs)
print(f"{utt_id} done!")
......
......@@ -13,7 +13,9 @@
# limitations under the License.
from abc import ABC
from abc import abstractmethod
from typing import List
import numpy as np
import paddle
from g2p_en import G2p
from g2pM import G2pM
......@@ -21,6 +23,7 @@ from g2pM import G2pM
from paddlespeech.t2s.frontend.normalizer.normalizer import normalize
from paddlespeech.t2s.frontend.punctuation import get_punctuations
from paddlespeech.t2s.frontend.vocab import Vocab
from paddlespeech.t2s.frontend.zh_normalization.text_normlization import TextNormalizer
# discard opencc untill we find an easy solution to install it on windows
# from opencc import OpenCC
......@@ -53,6 +56,7 @@ class English(Phonetics):
self.vocab = Vocab(self.phonemes + self.punctuations)
self.vocab_phones = {}
self.punc = ":,;。?!“”‘’':,;.?!"
self.text_normalizer = TextNormalizer()
if phone_vocab_path:
with open(phone_vocab_path, 'rt') as f:
phn_id = [line.strip().split() for line in f.readlines()]
......@@ -78,19 +82,42 @@ class English(Phonetics):
phonemes = [item for item in phonemes if item in self.vocab.stoi]
return phonemes
def _p2id(self, phonemes: List[str]) -> np.array:
    """Map a list of phoneme strings to their vocabulary ids.

    Parameters
    ----------
    phonemes : List[str]
        Phoneme symbols produced by ``phoneticize``.

    Returns
    ----------
    np.array
        1-D int64 array of phone ids, same length as ``phonemes``.
    """
    # replace unk phone with sp: anything missing from the vocab, or any
    # punctuation symbol, is mapped to the silence phone 'sp'
    phonemes = [
        phn if (phn in self.vocab_phones and phn not in self.punc) else "sp"
        for phn in phonemes
    ]
    phone_ids = [self.vocab_phones[item] for item in phonemes]
    return np.array(phone_ids, np.int64)
def get_input_ids(self, sentence: str,
                  merge_sentences: bool=False) -> dict:
    """Convert an English sentence into phone-id tensors.

    The text is first split into sub-sentences on sentence-splitting
    punctuation, each sub-sentence is phoneticized, and the phones are
    mapped to ids via ``_p2id``.

    Parameters
    ----------
    sentence : str
        Raw input text.
    merge_sentences : bool
        If True, concatenate the phones of all sub-sentences into a single
        sequence (one tensor); otherwise produce one tensor per sub-sentence.

    Returns
    ----------
    dict
        result["phone_ids"] is a list of 1-D int64 paddle tensors.
    """
    result = {}
    # NOTE(review): assumes TextNormalizer._split skips the
    # whitespace-stripping step when lang="en" — confirm against _split.
    sentences = self.text_normalizer._split(sentence, lang="en")
    phones_list = []
    # use a distinct loop name so the `sentence` parameter is not shadowed
    for part in sentences:
        phones = self.phoneticize(part)
        # remove start_symbol and end_symbol added by phoneticize
        phones = phones[1:-1]
        phones = [phn for phn in phones if not phn.isspace()]
        phones_list.append(phones)

    if merge_sentences:
        merge_list = sum(phones_list, [])
        # rm the last 'sp' to avoid the noise at the end
        # cause in the training data, no 'sp' in the end
        # (guard against an empty merged list to avoid IndexError)
        if merge_list and merge_list[-1] == 'sp':
            merge_list = merge_list[:-1]
        phones_list = [merge_list]

    temp_phone_ids = []
    for part_phones_list in phones_list:
        phone_ids = self._p2id(part_phones_list)
        phone_ids = paddle.to_tensor(phone_ids)
        temp_phone_ids.append(phone_ids)
    result["phone_ids"] = temp_phone_ids
    return result
def numericalize(self, phonemes):
......
......@@ -53,7 +53,7 @@ class TextNormalizer():
def __init__(self):
    # Regex matching sentence-splitting punctuation (Chinese and ASCII
    # colon/comma/semicolon/period/question/exclamation), optionally
    # followed by a closing quote; used by _split to cut long text.
    self.SENTENCE_SPLITOR = re.compile(r'([:,;。?!,;?!][”’]?)')
def _split(self, text: str) -> List[str]:
def _split(self, text: str, lang="zh") -> List[str]:
"""Split long text into sentences with sentence-splitting punctuations.
Parameters
----------
......@@ -65,7 +65,8 @@ class TextNormalizer():
Sentences.
"""
# Only for pure Chinese here
text = text.replace(" ", "")
if lang == "zh":
text = text.replace(" ", "")
text = self.SENTENCE_SPLITOR.sub(r'\1\n', text)
text = text.strip()
sentences = [sentence.strip() for sentence in re.split(r'\n+', text)]
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册