未验证 提交 f4ac0c79 编写于 作者: 小湉湉 提交者: GitHub

Merge pull request #2143 from lym0302/mix_front

[tts] add mix frontend
...@@ -29,6 +29,7 @@ from yacs.config import CfgNode ...@@ -29,6 +29,7 @@ from yacs.config import CfgNode
from paddlespeech.t2s.datasets.data_table import DataTable from paddlespeech.t2s.datasets.data_table import DataTable
from paddlespeech.t2s.frontend import English from paddlespeech.t2s.frontend import English
from paddlespeech.t2s.frontend.mix_frontend import MixFrontend
from paddlespeech.t2s.frontend.zh_frontend import Frontend from paddlespeech.t2s.frontend.zh_frontend import Frontend
from paddlespeech.t2s.modules.normalizer import ZScore from paddlespeech.t2s.modules.normalizer import ZScore
from paddlespeech.utils.dynamic_import import dynamic_import from paddlespeech.utils.dynamic_import import dynamic_import
...@@ -98,6 +99,8 @@ def get_sentences(text_file: Optional[os.PathLike], lang: str='zh'): ...@@ -98,6 +99,8 @@ def get_sentences(text_file: Optional[os.PathLike], lang: str='zh'):
sentence = "".join(items[1:]) sentence = "".join(items[1:])
elif lang == 'en': elif lang == 'en':
sentence = " ".join(items[1:]) sentence = " ".join(items[1:])
elif lang == 'mix':
sentence = " ".join(items[1:])
sentences.append((utt_id, sentence)) sentences.append((utt_id, sentence))
return sentences return sentences
...@@ -111,7 +114,8 @@ def get_test_dataset(test_metadata: List[Dict[str, Any]], ...@@ -111,7 +114,8 @@ def get_test_dataset(test_metadata: List[Dict[str, Any]],
am_dataset = am[am.rindex('_') + 1:] am_dataset = am[am.rindex('_') + 1:]
if am_name == 'fastspeech2': if am_name == 'fastspeech2':
fields = ["utt_id", "text"] fields = ["utt_id", "text"]
if am_dataset in {"aishell3", "vctk"} and speaker_dict is not None: if am_dataset in {"aishell3", "vctk",
"mix"} and speaker_dict is not None:
print("multiple speaker fastspeech2!") print("multiple speaker fastspeech2!")
fields += ["spk_id"] fields += ["spk_id"]
elif voice_cloning: elif voice_cloning:
...@@ -140,6 +144,10 @@ def get_frontend(lang: str='zh', ...@@ -140,6 +144,10 @@ def get_frontend(lang: str='zh',
phone_vocab_path=phones_dict, tone_vocab_path=tones_dict) phone_vocab_path=phones_dict, tone_vocab_path=tones_dict)
elif lang == 'en': elif lang == 'en':
frontend = English(phone_vocab_path=phones_dict) frontend = English(phone_vocab_path=phones_dict)
elif lang == 'mix':
frontend = MixFrontend(
phone_vocab_path=phones_dict, tone_vocab_path=tones_dict)
else: else:
print("wrong lang!") print("wrong lang!")
print("frontend done!") print("frontend done!")
...@@ -341,8 +349,12 @@ def get_am_output( ...@@ -341,8 +349,12 @@ def get_am_output(
input_ids = frontend.get_input_ids( input_ids = frontend.get_input_ids(
input, merge_sentences=merge_sentences) input, merge_sentences=merge_sentences)
phone_ids = input_ids["phone_ids"] phone_ids = input_ids["phone_ids"]
elif lang == 'mix':
input_ids = frontend.get_input_ids(
input, merge_sentences=merge_sentences)
phone_ids = input_ids["phone_ids"]
else: else:
print("lang should in {'zh', 'en'}!") print("lang should in {'zh', 'en', 'mix'}!")
if get_tone_ids: if get_tone_ids:
tone_ids = input_ids["tone_ids"] tone_ids = input_ids["tone_ids"]
......
...@@ -113,8 +113,12 @@ def evaluate(args): ...@@ -113,8 +113,12 @@ def evaluate(args):
input_ids = frontend.get_input_ids( input_ids = frontend.get_input_ids(
sentence, merge_sentences=merge_sentences) sentence, merge_sentences=merge_sentences)
phone_ids = input_ids["phone_ids"] phone_ids = input_ids["phone_ids"]
elif args.lang == 'mix':
input_ids = frontend.get_input_ids(
sentence, merge_sentences=merge_sentences)
phone_ids = input_ids["phone_ids"]
else: else:
print("lang should in {'zh', 'en'}!") print("lang should in {'zh', 'en', 'mix'}!")
with paddle.no_grad(): with paddle.no_grad():
flags = 0 flags = 0
for i in range(len(phone_ids)): for i in range(len(phone_ids)):
...@@ -122,7 +126,7 @@ def evaluate(args): ...@@ -122,7 +126,7 @@ def evaluate(args):
# acoustic model # acoustic model
if am_name == 'fastspeech2': if am_name == 'fastspeech2':
# multi speaker # multi speaker
if am_dataset in {"aishell3", "vctk"}: if am_dataset in {"aishell3", "vctk", "mix"}:
spk_id = paddle.to_tensor(args.spk_id) spk_id = paddle.to_tensor(args.spk_id)
mel = am_inference(part_phone_ids, spk_id) mel = am_inference(part_phone_ids, spk_id)
else: else:
...@@ -170,7 +174,7 @@ def parse_args(): ...@@ -170,7 +174,7 @@ def parse_args():
choices=[ choices=[
'speedyspeech_csmsc', 'speedyspeech_aishell3', 'fastspeech2_csmsc', 'speedyspeech_csmsc', 'speedyspeech_aishell3', 'fastspeech2_csmsc',
'fastspeech2_ljspeech', 'fastspeech2_aishell3', 'fastspeech2_vctk', 'fastspeech2_ljspeech', 'fastspeech2_aishell3', 'fastspeech2_vctk',
'tacotron2_csmsc', 'tacotron2_ljspeech' 'tacotron2_csmsc', 'tacotron2_ljspeech', 'fastspeech2_mix'
], ],
help='Choose acoustic model type of tts task.') help='Choose acoustic model type of tts task.')
parser.add_argument( parser.add_argument(
...@@ -231,7 +235,7 @@ def parse_args(): ...@@ -231,7 +235,7 @@ def parse_args():
'--lang', '--lang',
type=str, type=str,
default='zh', default='zh',
help='Choose model language. zh or en') help='Choose model language. zh or en or mix')
parser.add_argument( parser.add_argument(
"--inference_dir", "--inference_dir",
......
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import re
from typing import Dict
from typing import List
from typing import Tuple

import paddle

from paddlespeech.t2s.frontend import English
from paddlespeech.t2s.frontend.zh_frontend import Frontend
class MixFrontend():
    """Frontend for mixed Chinese/English text-to-speech input.

    A raw sentence is first split into sub-sentences at punctuation, each
    sub-sentence is partitioned into language-homogeneous segments, and every
    segment is converted to phone ids by the matching monolingual frontend.
    An "sp" (silence) phone can be inserted after each segment.
    """

    def __init__(self,
                 g2p_model="pypinyin",
                 phone_vocab_path=None,
                 tone_vocab_path=None):
        """
        Args:
            g2p_model: kept for interface compatibility; the Chinese g2p
                backend is selected inside ``Frontend`` itself.
            phone_vocab_path: path of the shared phone vocabulary file.
            tone_vocab_path: path of the tone vocabulary file (Chinese only).
        """
        self.zh_frontend = Frontend(
            phone_vocab_path=phone_vocab_path, tone_vocab_path=tone_vocab_path)
        self.en_frontend = English(phone_vocab_path=phone_vocab_path)
        # Split after Chinese/ASCII sentence punctuation, keeping the
        # punctuation (and an optional closing quote) with its sentence.
        self.SENTENCE_SPLITOR = re.compile(r'([:、,;。?!,;?!][”’]?)')
        # Silence phone used as an inter-segment separator.
        self.sp_id = self.zh_frontend.vocab_phones["sp"]
        self.sp_id_tensor = paddle.to_tensor([self.sp_id])

    def is_chinese(self, char):
        """Return True if ``char`` is in the CJK Unified Ideographs range."""
        return '\u4e00' <= char <= '\u9fa5'

    def is_alphabet(self, char):
        """Return True if ``char`` is an ASCII letter (A-Z or a-z)."""
        return ('\u0041' <= char <= '\u005a') or ('\u0061' <= char <= '\u007a')

    def is_number(self, char):
        """Return True if ``char`` is an ASCII digit (0-9)."""
        return '\u0030' <= char <= '\u0039'

    def is_other(self, char):
        """Return True if ``char`` is neither Chinese, a letter, nor a digit."""
        return not (self.is_chinese(char) or self.is_number(char) or
                    self.is_alphabet(char))

    def _split(self, text: str) -> List[str]:
        """Split ``text`` into sub-sentences at punctuation boundaries."""
        # Drop bracket/quote/special characters that carry no pronunciation.
        text = re.sub(r'[《》【】<=>{}()()#&@“”^_|…\\]', '', text)
        text = self.SENTENCE_SPLITOR.sub(r'\1\n', text)
        text = text.strip()
        sentences = [sentence.strip() for sentence in re.split(r'\n+', text)]
        return sentences

    def _distinguish(self, text: str) -> List[Tuple[str, str]]:
        """Partition ``text`` into ``(segment, lang)`` pairs.

        ``lang`` is "zh", "en" or "num".  Digits attach to an adjacent
        zh/en run when one exists; a run that stays purely numeric is
        tagged "num".  Unknown characters are dropped.  Unlike the
        original, no empty ``("", "")`` segment is ever emitted.
        """
        segments = []
        types = []
        flag = 0
        temp_seg = ""
        temp_lang = ""
        # Classify every character: zh / en / blank / num / unk.
        for ch in text:
            if self.is_chinese(ch):
                types.append("zh")
            elif self.is_alphabet(ch):
                types.append("en")
            elif ch == " ":
                types.append("blank")
            elif self.is_number(ch):
                types.append("num")
            else:
                types.append("unk")
        assert len(types) == len(text)
        for i in range(len(types)):
            # find the first char of the seg
            if flag == 0:
                if types[i] != "unk" and types[i] != "blank":
                    temp_seg += text[i]
                    temp_lang = types[i]
                    flag = 1
            else:
                if types[i] == temp_lang or types[i] == "num":
                    # same language, or a digit embedded in the current run
                    temp_seg += text[i]
                elif temp_lang == "num" and types[i] != "unk":
                    # a leading digit run adopts the language that follows it
                    temp_seg += text[i]
                    if types[i] == "zh" or types[i] == "en":
                        temp_lang = types[i]
                elif temp_lang == "en" and types[i] == "blank":
                    # spaces stay inside English segments (word boundaries)
                    temp_seg += text[i]
                elif types[i] == "unk":
                    pass
                else:
                    # language switch: flush the current segment
                    segments.append((temp_seg, temp_lang))
                    if types[i] != "unk" and types[i] != "blank":
                        temp_seg = text[i]
                        temp_lang = types[i]
                        flag = 1
                    else:
                        flag = 0
                        temp_seg = ""
                        temp_lang = ""
        # BUGFIX: flush only a non-empty tail; the original unconditionally
        # appended ("", "") for punctuation-only input.
        if temp_seg:
            segments.append((temp_seg, temp_lang))
        return segments

    def get_input_ids(self,
                      sentence: str,
                      merge_sentences: bool=True,
                      get_tone_ids: bool=False,
                      add_sp: bool=True) -> Dict[str, List[paddle.Tensor]]:
        """Convert a mixed-language sentence into phone-id tensors.

        Args:
            sentence: raw text, possibly mixing Chinese and English.
            merge_sentences: if True, concatenate all sub-sentences into a
                single tensor (returned list has length 1).
            get_tone_ids: forwarded to the Chinese frontend.
            add_sp: append an "sp" silence phone after every segment.

        Returns:
            dict with key "phone_ids" mapping to a list of 1-D tensors.
        """
        sentences = self._split(sentence)
        phones_list = []
        result = {}
        for text in sentences:
            phones_seg = []
            segments = self._distinguish(text)
            for content, lang in segments:
                if not content:
                    # defensive: nothing to synthesize
                    continue
                if lang == "en":
                    input_ids = self.en_frontend.get_input_ids(
                        content, merge_sentences=True)
                else:
                    # "zh" and pure-digit ("num") segments both go through the
                    # Chinese frontend.  BUGFIX: the original handled only
                    # "zh"/"en", so a "num" segment reused — or, for the first
                    # segment, never defined — `input_ids`.
                    # NOTE(review): assumes the zh frontend's text normalizer
                    # verbalizes Arabic numerals — confirm against Frontend.
                    input_ids = self.zh_frontend.get_input_ids(
                        content,
                        merge_sentences=True,
                        get_tone_ids=get_tone_ids)
                phones_seg.append(input_ids["phone_ids"][0])
                if add_sp:
                    phones_seg.append(self.sp_id_tensor)
            if not phones_seg:
                # punctuation-only sentence: paddle.concat([]) would raise
                continue
            phones_list.append(paddle.concat(phones_seg))
        if merge_sentences and phones_list:
            merge_list = paddle.concat(phones_list)
            # rm the last 'sp' to avoid the noise at the end
            # cause in the training data, no 'sp' in the end
            if merge_list[-1] == self.sp_id_tensor:
                merge_list = merge_list[:-1]
            phones_list = [merge_list]
        result["phone_ids"] = phones_list
        return result
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册