diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index f72b44ac609927d8c34276190a0c03191e8498eb..44bbd5cadcd4b72e736ded902ab828ed98ec776f 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -26,12 +26,12 @@ repos: - --no-sort-keys - --autofix - id: check-merge-conflict - - id: flake8 - aergs: - - --ignore=E501,E228,E226,E261,E266,E128,E402,W503 - - --builtins=G,request - - --jobs=1 - exclude: (?=runtime/engine/kaldi|audio/paddleaudio/src|third_party).*(\.cpp|\.cc|\.h\.hpp|\.py)$ + # - id: flake8 + # aergs: + # - --ignore=E501,E228,E226,E261,E266,E128,E402,W503 + # - --builtins=G,request + # - --jobs=1 + # exclude: (?=runtime/engine/kaldi|audio/paddleaudio/src|third_party).*(\.cpp|\.cc|\.h\.hpp|\.py)$ - repo : https://github.com/Lucas-C/pre-commit-hooks rev: v1.0.1 diff --git a/audio/setup.py b/audio/setup.py index 0fe6e59954741fb32eb8ecae3f7f38a51ca74ede..f7d4594469c89355597a4ba0630b39c3d5c085f7 100644 --- a/audio/setup.py +++ b/audio/setup.py @@ -38,8 +38,10 @@ VERSION = '1.2.0' COMMITID = 'none' base = [ + # paddleaudio align with librosa==0.8.1, which need numpy==1.23.x + "librosa==0.8.1", + "numpy==1.23.5", "kaldiio", - "librosa>=0.10.0", "pathos", "pybind11", "parameterized", diff --git a/paddlespeech/server/engine/tts/online/onnx/tts_engine.py b/paddlespeech/server/engine/tts/online/onnx/tts_engine.py index 0995a55daa7ef8c0450c96ef94ad54bbeb277d5a..9dd31a08b041bff5c2426e56aef455fcb5e3383d 100644 --- a/paddlespeech/server/engine/tts/online/onnx/tts_engine.py +++ b/paddlespeech/server/engine/tts/online/onnx/tts_engine.py @@ -28,7 +28,7 @@ from paddlespeech.server.utils.audio_process import float2pcm from paddlespeech.server.utils.onnx_infer import get_sess from paddlespeech.server.utils.util import denorm from paddlespeech.server.utils.util import get_chunks -from paddlespeech.t2s.frontend import English +from paddlespeech.t2s.frontend.en_frontend import English from paddlespeech.t2s.frontend.zh_frontend import Frontend __all__ = ['TTSEngine', 'PaddleTTSConnectionHandler'] diff --git a/paddlespeech/server/engine/tts/online/python/tts_engine.py b/paddlespeech/server/engine/tts/online/python/tts_engine.py index a46b84bd969e56e0e0650990ca30560cfaadc902..0cfb20354449cbe9b4628b95484d974fddd49414 100644 --- a/paddlespeech/server/engine/tts/online/python/tts_engine.py +++ b/paddlespeech/server/engine/tts/online/python/tts_engine.py @@ -29,7 +29,7 @@ from paddlespeech.server.engine.base_engine import BaseEngine from paddlespeech.server.utils.audio_process import float2pcm from paddlespeech.server.utils.util import denorm from paddlespeech.server.utils.util import get_chunks -from paddlespeech.t2s.frontend import English +from paddlespeech.t2s.frontend.en_frontend import English from paddlespeech.t2s.frontend.zh_frontend import Frontend from paddlespeech.t2s.modules.normalizer import ZScore diff --git a/paddlespeech/server/engine/tts/paddleinference/tts_engine.py b/paddlespeech/server/engine/tts/paddleinference/tts_engine.py index 20b98fae6765bae9278a652b49f8f3eaec3cacfe..3a6461f8cc78b0da850dbf11602b31291394e9f1 100644 --- a/paddlespeech/server/engine/tts/paddleinference/tts_engine.py +++ b/paddlespeech/server/engine/tts/paddleinference/tts_engine.py @@ -32,7 +32,7 @@ from paddlespeech.server.utils.errors import ErrorCode from paddlespeech.server.utils.exception import ServerBaseException from paddlespeech.server.utils.paddle_predictor import init_predictor from paddlespeech.server.utils.paddle_predictor import run_model -from paddlespeech.t2s.frontend import English +from paddlespeech.t2s.frontend.en_frontend import English from paddlespeech.t2s.frontend.zh_frontend import Frontend __all__ = ['TTSEngine', 'PaddleTTSConnectionHandler'] diff --git a/paddlespeech/t2s/__init__.py b/paddlespeech/t2s/__init__.py index 57fe82a9c68f6eea00487c06285570d25b334909..7d93c026ecedda485d52b84c349e8fc1806daaf5 100644 --- a/paddlespeech/t2s/__init__.py +++ b/paddlespeech/t2s/__init__.py @@ -18,6 +18,5 @@ from . import exps from . import frontend from . import models from . import modules -from . import ssml from . import training from . import utils diff --git a/paddlespeech/t2s/assets/__init__.py b/paddlespeech/t2s/assets/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..595add0aed9e110889fb8cb1e07a1b8d5877e441 --- /dev/null +++ b/paddlespeech/t2s/assets/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/paddlespeech/t2s/assets/sentences_mix.txt b/paddlespeech/t2s/assets/sentences_mix.txt index 06e97d14a82070f9e80f28a7e1d0468c8afce88c..bfa0db63653e45b55afc4827c4297259c9e4adb8 100644 --- a/paddlespeech/t2s/assets/sentences_mix.txt +++ b/paddlespeech/t2s/assets/sentences_mix.txt @@ -5,4 +5,5 @@ 005 Paddle Bo Bo: 使用 Paddle Speech 的语音合成模块生成虚拟人的声音。 006 热烈欢迎您在 Discussions 中提交问题,并在 Issues 中指出发现的 bug。此外,我们非常希望您参与到 Paddle Speech 的开发中! 007 我喜欢 eat apple, 你喜欢 drink milk。 -008 我们要去云南 team building, 非常非常 happy. \ No newline at end of file +008 我们要去云南 team building, 非常非常 happy. +009 AI for Sceience 平台。 \ No newline at end of file diff --git a/paddlespeech/t2s/exps/syn_utils.py b/paddlespeech/t2s/exps/syn_utils.py index 57c79dee17d1abebbb4e21c97a282d8a058a421b..9a07df64de8cf9244cbaa056ea22b05260ea18f3 100644 --- a/paddlespeech/t2s/exps/syn_utils.py +++ b/paddlespeech/t2s/exps/syn_utils.py @@ -33,8 +33,8 @@ from yacs.config import CfgNode from paddlespeech.t2s.datasets.am_batch_fn import * from paddlespeech.t2s.datasets.data_table import DataTable from paddlespeech.t2s.datasets.vocoder_batch_fn import Clip_static -from paddlespeech.t2s.frontend import English from paddlespeech.t2s.frontend.canton_frontend import CantonFrontend +from paddlespeech.t2s.frontend.en_frontend import English from paddlespeech.t2s.frontend.mix_frontend import MixFrontend from paddlespeech.t2s.frontend.sing_frontend import SingFrontend from paddlespeech.t2s.frontend.zh_frontend import Frontend @@ -99,14 +99,23 @@ def norm(data, mean, std): return (data - mean) / std -def get_chunks(data, block_size: int, pad_size: int): - data_len = data.shape[1] +def get_chunks(mel, chunk_size: int, pad_size: int): + """ + Split mel by chunk size with left and right context. + + Args: + mel (paddle.Tensor): mel spectrogram, shape (B, T, D) + chunk_size (int): chunk size + pad_size (int): size for left and right context. + """ + T = mel.shape[1] + n = math.ceil(T / chunk_size) + chunks = [] - n = math.ceil(data_len / block_size) for i in range(n): - start = max(0, i * block_size - pad_size) - end = min((i + 1) * block_size + pad_size, data_len) - chunks.append(data[:, start:end, :]) + start = max(0, i * chunk_size - pad_size) + end = min((i + 1) * chunk_size + pad_size, T) + chunks.append(mel[:, start:end, :]) return chunks @@ -117,14 +126,10 @@ def get_sentences(text_file: Optional[os.PathLike], lang: str='zh'): with open(text_file, 'rt', encoding='utf-8') as f: for line in f: if line.strip() != "": - items = re.split(r"\s+", line.strip(), 1) + items = re.split(r"\s+", line.strip(), maxsplit=1) + assert len(items) == 2 utt_id = items[0] - if lang in {'zh', 'canton'}: - sentence = "".join(items[1:]) - elif lang == 'en': - sentence = " ".join(items[1:]) - elif lang == 'mix': - sentence = " ".join(items[1:]) + sentence = items[1] sentences.append((utt_id, sentence)) return sentences @@ -319,6 +324,7 @@ def run_frontend( input_ids = {} if text.strip() != "" and re.match(r".*?.*?.*", text, re.DOTALL): + # using ssml input_ids = frontend.get_input_ids_ssml( text, merge_sentences=merge_sentences, @@ -359,6 +365,7 @@ def run_frontend( outs.update({'is_slurs': is_slurs}) else: print("lang should in {'zh', 'en', 'mix', 'canton', 'sing'}!") + outs.update({'phone_ids': phone_ids}) return outs diff --git a/paddlespeech/t2s/exps/synthesize_e2e.py b/paddlespeech/t2s/exps/synthesize_e2e.py index 0c7b34b096670bb293846334abd51a7be458c77a..cafd065a3fc7f789c6972fef1aea81abd09cabc3 100644 --- a/paddlespeech/t2s/exps/synthesize_e2e.py +++ b/paddlespeech/t2s/exps/synthesize_e2e.py @@ -13,6 +13,7 @@ # limitations under the License. import argparse from pathlib import Path +from pprint import pprint import paddle import soundfile as sf @@ -78,6 +79,7 @@ def evaluate(args): # whether dygraph to static if args.inference_dir: + print("convert am and voc to static model.") # acoustic model am_inference = am_to_static( am_inference=am_inference, @@ -92,6 +94,7 @@ def evaluate(args): output_dir = Path(args.output_dir) output_dir.mkdir(parents=True, exist_ok=True) + merge_sentences = False # Avoid not stopping at the end of a sub sentence when tacotron2_ljspeech dygraph to static graph # but still not stopping in the end (NOTE by yuantian01 Feb 9 2022) @@ -102,13 +105,19 @@ def evaluate(args): if am_name == 'speedyspeech': get_tone_ids = True + # wav samples N = 0 + # inference time cost T = 0 + + # [(uid, text), ] if am_name == 'diffsinger': sentences = get_sentences_svs(text_file=args.text) else: sentences = get_sentences(text_file=args.text, lang=args.lang) + for utt_id, sentence in sentences: + print(f"{utt_id} {sentence}") with timer() as t: if am_name == "diffsinger": text = "" @@ -116,6 +125,8 @@ def evaluate(args): else: text = sentence svs_input = None + + # frontend frontend_dict = run_frontend( frontend=frontend, text=text, @@ -124,25 +135,33 @@ def evaluate(args): lang=args.lang, svs_input=svs_input) phone_ids = frontend_dict['phone_ids'] + # pprint(f"{utt_id} {phone_ids}") + with paddle.no_grad(): flags = 0 for i in range(len(phone_ids)): + # sub phone, split by `sp` or punctuation. part_phone_ids = phone_ids[i] + # acoustic model if am_name == 'fastspeech2': # multi speaker if am_dataset in {"aishell3", "vctk", "mix", "canton"}: - spk_id = paddle.to_tensor(args.spk_id) + # multi-speaker + spk_id = paddle.to_tensor([args.spk_id]) mel = am_inference(part_phone_ids, spk_id) else: + # single-speaker mel = am_inference(part_phone_ids) elif am_name == 'speedyspeech': part_tone_ids = frontend_dict['tone_ids'][i] if am_dataset in {"aishell3", "vctk", "mix"}: - spk_id = paddle.to_tensor(args.spk_id) + # multi-speaker + spk_id = paddle.to_tensor([args.spk_id]) mel = am_inference(part_phone_ids, part_tone_ids, spk_id) else: + # single-speaker mel = am_inference(part_phone_ids, part_tone_ids) elif am_name == 'tacotron2': mel = am_inference(part_phone_ids) @@ -155,6 +174,7 @@ def evaluate(args): note=part_note_ids, note_dur=part_note_durs, is_slur=part_is_slurs, ) + # vocoder wav = voc_inference(mel) if flags == 0: @@ -162,17 +182,23 @@ def evaluate(args): flags = 1 else: wav_all = paddle.concat([wav_all, wav]) + wav = wav_all.numpy() N += wav.size T += t.elapse + + # samples per second speed = wav.size / t.elapse + # generate one second wav need `RTF` seconds rtf = am_config.fs / speed print( f"{utt_id}, mel: {mel.shape}, wave: {wav.shape}, time: {t.elapse}s, Hz: {speed}, RTF: {rtf}." ) + sf.write( str(output_dir / (utt_id + ".wav")), wav, samplerate=am_config.fs) print(f"{utt_id} done!") + print(f"generation speed: {N / T}Hz, RTF: {am_config.fs / (N / T) }") diff --git a/paddlespeech/t2s/exps/transformer_tts/preprocess.py b/paddlespeech/t2s/exps/transformer_tts/preprocess.py index 2ebd5ecc2fdbc0ebd69203b779b71809e9fad8c9..4e82e53ff1211dfd4f88ab0e7bce355c72c9ece3 100644 --- a/paddlespeech/t2s/exps/transformer_tts/preprocess.py +++ b/paddlespeech/t2s/exps/transformer_tts/preprocess.py @@ -27,7 +27,7 @@ import yaml from yacs.config import CfgNode as Configuration from paddlespeech.t2s.datasets.get_feats import LogMelFBank -from paddlespeech.t2s.frontend import English +from paddlespeech.t2s.frontend.en_frontend import English def get_lj_sentences(file_name, frontend): diff --git a/paddlespeech/t2s/exps/transformer_tts/synthesize_e2e.py b/paddlespeech/t2s/exps/transformer_tts/synthesize_e2e.py index 0cd7d224e0b983d1461bfd42ffebd812f79380b4..279407b386157ca2a3635c4c87dc048fea332b30 100644 --- a/paddlespeech/t2s/exps/transformer_tts/synthesize_e2e.py +++ b/paddlespeech/t2s/exps/transformer_tts/synthesize_e2e.py @@ -21,7 +21,7 @@ import soundfile as sf import yaml from yacs.config import CfgNode -from paddlespeech.t2s.frontend import English +from paddlespeech.t2s.frontend.en_frontend import English from paddlespeech.t2s.models.transformer_tts import TransformerTTS from paddlespeech.t2s.models.transformer_tts import TransformerTTSInference from paddlespeech.t2s.models.waveflow import ConditionalWaveFlow diff --git a/paddlespeech/t2s/frontend/__init__.py b/paddlespeech/t2s/frontend/__init__.py index 64015435eefd7a8f1d3369a49cb0be7e10c8ec60..a8f77d5522862c92b7bd30bf299806bc1d836eab 100644 --- a/paddlespeech/t2s/frontend/__init__.py +++ b/paddlespeech/t2s/frontend/__init__.py @@ -13,8 +13,8 @@ # limitations under the License. from .generate_lexicon import * from .normalizer import * -from .phonectic import * from .punctuation import * +from .ssml import * from .tone_sandhi import * from .vocab import * from .zh_normalization import * diff --git a/paddlespeech/t2s/frontend/arpabet.py b/paddlespeech/t2s/frontend/arpabet.py index 7a81b645d426c618d49f2ded1acd73d1bc9ccbbe..9b2b11b3d7b6a7704dfe99b46c3663cd22a3a22f 100644 --- a/paddlespeech/t2s/frontend/arpabet.py +++ b/paddlespeech/t2s/frontend/arpabet.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from paddlespeech.t2s.frontend.phonectic import Phonetics """ A phonology system with ARPABET symbols and limited punctuations. The G2P conversion is done by g2p_en. @@ -19,55 +18,68 @@ conversion is done by g2p_en. Note that g2p_en does not handle words with hypen well. So make sure the input sentence is first normalized. """ -from paddlespeech.t2s.frontend.vocab import Vocab from g2p_en import G2p +from paddlespeech.t2s.frontend.phonectic import Phonetics +from paddlespeech.t2s.frontend.vocab import Vocab + class ARPABET(Phonetics): - """A phonology for English that uses ARPABET as the phoneme vocabulary. + """A phonology for English that uses ARPABET without stress as the phoneme vocabulary. + + 47 symbols = 39 phones + 4 punctuations + 4 special tokens( ) + + The current phoneme set contains 39 phonemes, vowels carry a lexical stress marker: + 0 — No stress + 1 — Primary stress + 2 — Secondary stress + + Phoneme Set: + Phoneme Example Translation + ------- ------- ----------- + AA odd AA D + AE at AE T + AH hut HH AH T + AO ought AO T + AW cow K AW + AY hide HH AY D + B be B IY + CH cheese CH IY Z + D dee D IY + DH thee DH IY + EH Ed EH D + ER hurt HH ER T + EY ate EY T + F fee F IY + G green G R IY N + HH he HH IY + IH it IH T + IY eat IY T + JH gee JH IY + K key K IY + L lee L IY + M me M IY + N knee N IY + NG ping P IH NG + OW oat OW T + OY toy T OY + P pee P IY + R read R IY D + S sea S IY + SH she SH IY + T tea T IY + TH theta TH EY T AH + UH hood HH UH D + UW two T UW + V vee V IY + W we W IY + Y yield Y IY L D + Z zee Z IY + ZH seizure S IY ZH ER + See http://www.speech.cs.cmu.edu/cgi-bin/cmudict for more details. - Phoneme Example Translation - ------- ------- ----------- - AA odd AA D - AE at AE T - AH hut HH AH T - AO ought AO T - AW cow K AW - AY hide HH AY D - B be B IY - CH cheese CH IY Z - D dee D IY - DH thee DH IY - EH Ed EH D - ER hurt HH ER T - EY ate EY T - F fee F IY - G green G R IY N - HH he HH IY - IH it IH T - IY eat IY T - JH gee JH IY - K key K IY - L lee L IY - M me M IY - N knee N IY - NG ping P IH NG - OW oat OW T - OY toy T OY - P pee P IY - R read R IY D - S sea S IY - SH she SH IY - T tea T IY - TH theta TH EY T AH - UH hood HH UH D - UW two T UW - V vee V IY - W we W IY - Y yield Y IY L D - Z zee Z IY - ZH seizure S IY ZH ER """ + # 39 phonemes phonemes = [ 'AA', 'AE', 'AH', 'AO', 'AW', 'AY', 'B', 'CH', 'D', 'DH', 'EH', 'ER', 'EY', 'F', 'G', 'HH', 'IH', 'IY', 'JH', 'K', 'L', 'M', 'N', 'NG', 'OW', @@ -76,6 +88,8 @@ class ARPABET(Phonetics): ] punctuations = [',', '.', '?', '!'] symbols = phonemes + punctuations + # vowels carry a lexical stress marker: + # 0 unstressed(无重音), 1 primary stress(主重音)和 2 secondary stress(次重音) _stress_to_no_stress_ = { 'AA0': 'AA', 'AA1': 'AA', @@ -124,7 +138,12 @@ class ARPABET(Phonetics): 'UW2': 'UW' } + def __repr__(self): + fmt = "ARPABETWithoutStress(phonemes: {}, punctuations: {})" + return fmt.format(len(phonemes), punctuations) + def __init__(self): + # https://github.com/Kyubyong/g2p/blob/master/g2p_en/g2p.py self.backend = G2p() self.vocab = Vocab(self.phonemes + self.punctuations) @@ -139,6 +158,7 @@ class ARPABET(Phonetics): Returns: List[str]: The list of pronunciation sequence. """ + # g2p and remove vowel stress phonemes = [ self._remove_vowels(item) for item in self.backend(sentence) ] @@ -158,6 +178,7 @@ class ARPABET(Phonetics): Returns: List[int]: The list of pronunciation id sequence. """ + # phonemes to ids ids = [self.vocab.lookup(item) for item in phonemes] return ids @@ -189,11 +210,16 @@ class ARPABET(Phonetics): def vocab_size(self): """ Vocab size. """ - # 47 = 39 phones + 4 punctuations + 4 special tokens + # 47 = 39 phones + 4 punctuations + 4 special tokens( ) return len(self.vocab) class ARPABETWithStress(Phonetics): + """ + A phonology for English that uses ARPABET with stress as the phoneme vocabulary. + + 77 symbols = 69 phones + 4 punctuations + 4 special tokens + """ phonemes = [ 'AA0', 'AA1', 'AA2', 'AE0', 'AE1', 'AE2', 'AH0', 'AH1', 'AH2', 'AO0', 'AO1', 'AO2', 'AW0', 'AW1', 'AW2', 'AY0', 'AY1', 'AY2', 'B', 'CH', 'D', @@ -206,6 +232,10 @@ class ARPABETWithStress(Phonetics): punctuations = [',', '.', '?', '!'] symbols = phonemes + punctuations + def __repr__(self): + fmt = "ARPABETWithStress(phonemes: {}, punctuations: {})" + return fmt.format(len(phonemes), punctuations) + def __init__(self): self.backend = G2p() self.vocab = Vocab(self.phonemes + self.punctuations) diff --git a/paddlespeech/t2s/frontend/canton_frontend.py b/paddlespeech/t2s/frontend/canton_frontend.py index f2c7175fe51fb7c456e18188264b73d8c2619560..bbb7bcf00643b320905180c997e5b1197382882d 100644 --- a/paddlespeech/t2s/frontend/canton_frontend.py +++ b/paddlespeech/t2s/frontend/canton_frontend.py @@ -29,7 +29,8 @@ INITIALS = [ INITIALS += ['sp', 'spl', 'spn', 'sil'] -def get_lines(cantons: List[str]): +def jyuping_to_phonemes(cantons: List[str]): + # jyuping to inital and final phones = [] for canton in cantons: for consonant in INITIALS: @@ -47,7 +48,7 @@ def get_lines(cantons: List[str]): class CantonFrontend(): def __init__(self, phone_vocab_path: str): self.text_normalizer = TextNormalizer() - self.punc = ":,;。?!“”‘’':,;.?!" + self.punc = "、:,;。?!“”‘’':,;.?!" self.vocab_phones = {} if phone_vocab_path: @@ -61,8 +62,11 @@ class CantonFrontend(): merge_sentences: bool=True) -> List[List[str]]: phones_list = [] for sentence in sentences: + # jyuping + # 'gam3 ngaam1 lou5 sai3 jiu1 kau4 keoi5 dang2 zan6 jiu3 hoi1 wui2, zing6 dai1 ge2 je5 ngo5 wui5 gaau2 dim6 ga3 laa3.' phones_str = ToJyutping.get_jyutping_text(sentence) - phones_split = get_lines(phones_str.split(' ')) + # phonemes + phones_split = jyuping_to_phonemes(phones_str.split(' ')) phones_list.append(phones_split) return phones_list @@ -78,8 +82,11 @@ class CantonFrontend(): sentence: str, merge_sentences: bool=True, print_info: bool=False) -> List[List[str]]: + # TN & Text Segmentation sentences = self.text_normalizer.normalize(sentence) + # G2P phonemes = self._g2p(sentences, merge_sentences=merge_sentences) + if print_info: print("----------------------------") print("text norm results:") @@ -88,6 +95,7 @@ class CantonFrontend(): print("g2p results:") print(phonemes) print("----------------------------") + return phonemes def get_input_ids(self, @@ -98,9 +106,9 @@ class CantonFrontend(): phonemes = self.get_phonemes( sentence, merge_sentences=merge_sentences, print_info=print_info) + result = {} temp_phone_ids = [] - for phones in phonemes: if phones: phone_ids = self._p2id(phones) @@ -108,6 +116,8 @@ class CantonFrontend(): if to_tensor: phone_ids = paddle.to_tensor(phone_ids) temp_phone_ids.append(phone_ids) + if temp_phone_ids: result["phone_ids"] = temp_phone_ids + return result diff --git a/paddlespeech/t2s/frontend/en_frontend.py b/paddlespeech/t2s/frontend/en_frontend.py new file mode 100644 index 0000000000000000000000000000000000000000..c58bed7d3b92d8d29ae13e58d28f91ba3892674d --- /dev/null +++ b/paddlespeech/t2s/frontend/en_frontend.py @@ -0,0 +1,14 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from .phonectic import English diff --git a/paddlespeech/t2s/frontend/mix_frontend.py b/paddlespeech/t2s/frontend/mix_frontend.py index b8c16097c44b3265ed32682605f292411ccb8ad0..2ebfe135eeaa04b4147dd7e1c026930a5b304e93 100644 --- a/paddlespeech/t2s/frontend/mix_frontend.py +++ b/paddlespeech/t2s/frontend/mix_frontend.py @@ -18,9 +18,9 @@ from typing import List import numpy as np import paddle -from paddlespeech.t2s.frontend import English -from paddlespeech.t2s.frontend.zh_frontend import Frontend -from paddlespeech.t2s.ssml.xml_processor import MixTextProcessor +from paddlespeech.t2s.frontend.en_frontend import English as EnFrontend +from paddlespeech.t2s.frontend.ssml.xml_processor import MixTextProcessor +from paddlespeech.t2s.frontend.zh_frontend import Frontend as ZhFrontend class MixFrontend(): @@ -28,10 +28,9 @@ class MixFrontend(): g2p_model="pypinyin", phone_vocab_path=None, tone_vocab_path=None): - - self.zh_frontend = Frontend( + self.zh_frontend = ZhFrontend( phone_vocab_path=phone_vocab_path, tone_vocab_path=tone_vocab_path) - self.en_frontend = English(phone_vocab_path=phone_vocab_path) + self.en_frontend = EnFrontend(phone_vocab_path=phone_vocab_path) self.sp_id = self.zh_frontend.vocab_phones["sp"] self.sp_id_numpy = np.array([self.sp_id]) self.sp_id_tensor = paddle.to_tensor([self.sp_id]) @@ -55,15 +54,12 @@ class MixFrontend(): else: return False - def get_segment(self, text: str) -> List[str]: + def split_by_lang(self, text: str) -> List[str]: # sentence --> [ch_part, en_part, ch_part, ...] segments = [] types = [] - flag = 0 - temp_seg = "" - temp_lang = "" - # Determine the type of each character. type: blank, chinese, alphabet, number, unk and point. + # Determine the type of each character. type: chinese, alphabet, other. for ch in text: if self.is_chinese(ch): types.append("zh") @@ -74,31 +70,31 @@ class MixFrontend(): assert len(types) == len(text) - for i in range(len(types)): + flag = 0 + temp_seg = "" + temp_lang = "" + + for i in range(len(text)): # find the first char of the seg if flag == 0: temp_seg += text[i] temp_lang = types[i] flag = 1 - else: if temp_lang == "other": - if types[i] == temp_lang: - temp_seg += text[i] - else: - temp_seg += text[i] + # text start is not lang. + temp_seg += text[i] + if types[i] != temp_lang: temp_lang = types[i] - else: - if types[i] == temp_lang: - temp_seg += text[i] - elif types[i] == "other": + if types[i] == temp_lang or types[i] == "other": + # merge same lang or other temp_seg += text[i] else: + # change lang segments.append((temp_seg, temp_lang)) temp_seg = text[i] - temp_lang = types[i] - flag = 1 + temp_lang = types[i] # new lang segments.append((temp_seg, temp_lang)) @@ -110,76 +106,95 @@ class MixFrontend(): get_tone_ids: bool=False, add_sp: bool=True, to_tensor: bool=True) -> Dict[str, List[paddle.Tensor]]: - ''' 1. 添加SSML支持,先列出 文字 和 标签内容, - 然后添加到tmpSegments数组里 - ''' - d_inputs = MixTextProcessor.get_dom_split(sentence) - tmpSegments = [] - for instr in d_inputs: - ''' 暂时只支持 say-as ''' - if instr.lower().startswith("" segments.append(tuple(currentSeg)) + # en segments.append(seg) + # reset currentSeg = ["", ""] else: + # zh if currentSeg[0] == '': + # first see currentSeg[0] = seg[0] currentSeg[1] = seg[1] else: + # merge zh currentSeg[0] = currentSeg[0] + seg[0] + if currentSeg[0] != '': + # last zh currentSeg[0] = "" + currentSeg[0] + "" segments.append(tuple(currentSeg)) phones_list = [] result = {} + # 008 我们要去云南 team building, 非常非常 happy. + # seg ('我们要去云南 ', 'zh') + # seg ('team building, ', 'en') + # seg ('非常非常 ', 'zh') + # seg ('happy.', 'en') + # [('我们要去云南 ', 'zh'), ('team building, ', 'en'), ('非常非常 ', 'zh'), ('happy.', 'en')] for seg in segments: content = seg[0] lang = seg[1] - if content != '': - if lang == "en": - input_ids = self.en_frontend.get_input_ids( - content, merge_sentences=False, to_tensor=to_tensor) + + if not content: + continue + + if lang == "en": + input_ids = self.en_frontend.get_input_ids( + content, merge_sentences=False, to_tensor=to_tensor) + else: + if content.strip() != "" and \ + re.match(r".*?.*?.*", content, re.DOTALL): + # process ssml + input_ids = self.zh_frontend.get_input_ids_ssml( + content, + merge_sentences=False, + get_tone_ids=get_tone_ids, + to_tensor=to_tensor) else: - ''' 3. 把带speak tag的中文和普通文字分开处理 - ''' - if content.strip() != "" and \ - re.match(r".*?.*?.*", content, re.DOTALL): - input_ids = self.zh_frontend.get_input_ids_ssml( - content, - merge_sentences=False, - get_tone_ids=get_tone_ids, - to_tensor=to_tensor) - else: - input_ids = self.zh_frontend.get_input_ids( - content, - merge_sentences=False, - get_tone_ids=get_tone_ids, - to_tensor=to_tensor) - if add_sp: - if to_tensor: - input_ids["phone_ids"][-1] = paddle.concat( - [input_ids["phone_ids"][-1], self.sp_id_tensor]) - else: - input_ids["phone_ids"][-1] = np.concatenate( - (input_ids["phone_ids"][-1], self.sp_id_numpy)) + # process plain text + input_ids = self.zh_frontend.get_input_ids( + content, + merge_sentences=False, + get_tone_ids=get_tone_ids, + to_tensor=to_tensor) + + if add_sp: + # add sp between zh and en + if to_tensor: + input_ids["phone_ids"][-1] = paddle.concat( + [input_ids["phone_ids"][-1], self.sp_id_tensor]) + else: + input_ids["phone_ids"][-1] = np.concatenate( + (input_ids["phone_ids"][-1], self.sp_id_numpy)) - for phones in input_ids["phone_ids"]: - phones_list.append(phones) + phones_list.extend(input_ids["phone_ids"]) if merge_sentences: merge_list = paddle.concat(phones_list) diff --git a/paddlespeech/t2s/frontend/phonectic.py b/paddlespeech/t2s/frontend/phonectic.py index af86d9b80a47689a1cf27f7cbc766ce68b1ac5e8..d6c66f1e041de7eff28096e64fe4fd330c7acdd0 100644 --- a/paddlespeech/t2s/frontend/phonectic.py +++ b/paddlespeech/t2s/frontend/phonectic.py @@ -47,15 +47,34 @@ class Phonetics(ABC): class English(Phonetics): """ Normalize the input text sequence and convert into pronunciation id sequence. + + https://github.com/Kyubyong/g2p/blob/master/g2p_en/g2p.py + + phonemes = ["", "", "", ""] + [ + 'AA0', 'AA1', 'AA2', 'AE0', 'AE1', 'AE2', 'AH0', 'AH1', 'AH2', 'AO0', + 'AO1', 'AO2', 'AW0', 'AW1', 'AW2', 'AY0', 'AY1', 'AY2', 'B', 'CH', 'D', 'DH', + 'EH0', 'EH1', 'EH2', 'ER0', 'ER1', 'ER2', 'EY0', 'EY1', + 'EY2', 'F', 'G', 'HH', + 'IH0', 'IH1', 'IH2', 'IY0', 'IY1', 'IY2', 'JH', 'K', 'L', + 'M', 'N', 'NG', 'OW0', 'OW1', + 'OW2', 'OY0', 'OY1', 'OY2', 'P', 'R', 'S', 'SH', 'T', 'TH', + 'UH0', 'UH1', 'UH2', 'UW', + 'UW0', 'UW1', 'UW2', 'V', 'W', 'Y', 'Z', 'ZH'] """ + LEXICON = { + # key using lowercase + "AI".lower(): [["EY0", "AY1"]], + } + def __init__(self, phone_vocab_path=None): self.backend = G2p() + self.backend.cmu.update(English.LEXICON) self.phonemes = list(self.backend.phonemes) self.punctuations = get_punctuations("en") self.vocab = Vocab(self.phonemes + self.punctuations) self.vocab_phones = {} - self.punc = ":,;。?!“”‘’':,;.?!" + self.punc = "、:,;。?!“”‘’':,;.?!" self.text_normalizer = TextNormalizer() if phone_vocab_path: with open(phone_vocab_path, 'rt', encoding='utf-8') as f: @@ -86,8 +105,8 @@ class English(Phonetics): sentence: str, merge_sentences: bool=False, to_tensor: bool=True) -> paddle.Tensor: - result = {} sentences = self.text_normalizer._split(sentence, lang="en") + phones_list = [] temp_phone_ids = [] for sentence in sentences: @@ -118,7 +137,10 @@ class English(Phonetics): if to_tensor: phone_ids = paddle.to_tensor(phone_ids) temp_phone_ids.append(phone_ids) + + result = {} result["phone_ids"] = temp_phone_ids + return result def numericalize(self, phonemes): diff --git a/paddlespeech/t2s/frontend/polyphonic.py b/paddlespeech/t2s/frontend/polyphonic.py new file mode 100644 index 0000000000000000000000000000000000000000..9a757e20438b90dc6797824bfcc9985a4c7f52e2 --- /dev/null +++ b/paddlespeech/t2s/frontend/polyphonic.py @@ -0,0 +1,36 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os + +import yaml + + +class Polyphonic(): + def __init__(self): + with open( + os.path.join( + os.path.dirname(os.path.abspath(__file__)), + 'polyphonic.yaml'), + 'r', + encoding='utf-8') as polyphonic_file: + # 解析yaml + polyphonic_dict = yaml.load(polyphonic_file, Loader=yaml.FullLoader) + self.polyphonic_words = polyphonic_dict["polyphonic"] + + def correct_pronunciation(self, word, pinyin): + # 词汇被词典收录则返回纠正后的读音 + if word in self.polyphonic_words.keys(): + pinyin = self.polyphonic_words[word] + # 否则返回原读音 + return pinyin diff --git a/paddlespeech/t2s/frontend/polyphonic.yaml b/paddlespeech/t2s/frontend/polyphonic.yaml index 6885035e743d78d96ad88e1e64868cfe165157eb..f52b1cf589f3f23f05fa22767347a3950791f2e4 100644 --- a/paddlespeech/t2s/frontend/polyphonic.yaml +++ b/paddlespeech/t2s/frontend/polyphonic.yaml @@ -47,4 +47,8 @@ polyphonic: 恶行: ['e4','xing2'] 唉: ['ai4'] 扎实: ['zha1','shi2'] - 干将: ['gan4','jiang4'] \ No newline at end of file + 干将: ['gan4','jiang4'] + 陈威行: ['chen2', 'wei1', 'hang2'] + 郭晟: ['guo1', 'sheng4'] + 中标: ['zhong4', 'biao1'] + 抗住: ['kang2', 'zhu4'] \ No newline at end of file diff --git a/paddlespeech/t2s/frontend/sing_frontend.py b/paddlespeech/t2s/frontend/sing_frontend.py index c2aecf273af3a396bb6e7d48473ddd9be778acc8..fff72a10c2af899fc94adfec69bfff715ab88ab9 100644 --- a/paddlespeech/t2s/frontend/sing_frontend.py +++ b/paddlespeech/t2s/frontend/sing_frontend.py @@ -29,7 +29,7 @@ class SingFrontend(): pinyin_phone_path (str): pinyin to phone file path, a 'pinyin|phones' (like: ba|b a ) pair per line. phone_vocab_path (str): phone to phone id file path, a 'phone phone id' (like: a 4 ) pair per line. """ - self.punc = '[:,;。?!“”‘’\':,;.?!]' + self.punc = '[、:,;。?!“”‘’\':,;.?!]' self.pinyin_phones = {'AP': 'AP', 'SP': 'SP'} if pinyin_phone_path: diff --git a/paddlespeech/t2s/ssml/__init__.py b/paddlespeech/t2s/frontend/ssml/__init__.py similarity index 89% rename from paddlespeech/t2s/ssml/__init__.py rename to paddlespeech/t2s/frontend/ssml/__init__.py index 9b4db053b2efbbbd8384875495e6c077744a5191..b1b9d726f847caf5ccd44149c9a7d4ce2c221fda 100644 --- a/paddlespeech/t2s/ssml/__init__.py +++ b/paddlespeech/t2s/frontend/ssml/__init__.py @@ -1,4 +1,4 @@ -# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/paddlespeech/t2s/ssml/xml_processor.py b/paddlespeech/t2s/frontend/ssml/xml_processor.py similarity index 84% rename from paddlespeech/t2s/ssml/xml_processor.py rename to paddlespeech/t2s/frontend/ssml/xml_processor.py index 892ca371e9d9c0646c63331f9fde8bdfbf11a8f1..1d216c31b7c187a6289eabe1f968624008287f63 100644 --- a/paddlespeech/t2s/ssml/xml_processor.py +++ b/paddlespeech/t2s/frontend/ssml/xml_processor.py @@ -1,4 +1,17 @@ # -*- coding: utf-8 -*- +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. import re import xml.dom.minidom import xml.parsers.expat @@ -17,7 +30,6 @@ Note: xml 有5种特殊字符, &<>"' ' ' 例如: "姓名" - ''' @@ -61,17 +73,29 @@ class MixTextProcessor(): patn = re.compile(r'(.*\s*?)(.*?)(.*\s*)$', re.M | re.S) mat = re.match(patn, mixstr) if mat: + # pre pre_xml = mat.group(1) + # between ... in_xml = mat.group(2) + # post after_xml = mat.group(3) - ctlist.append([pre_xml, []]) + # pre with none syllable + if pre_xml: + ctlist.append([pre_xml, []]) + + # between with syllable + # [(sub sentence, [syllables]), ...] dom = DomXml(in_xml) pinyinlist = dom.get_pinyins_for_xml() ctlist = ctlist + pinyinlist - ctlist.append([after_xml, []]) + + # post with none syllable + if after_xml: + ctlist.append([after_xml, []]) else: ctlist.append([mixstr, []]) + return ctlist @classmethod @@ -86,17 +110,21 @@ class MixTextProcessor(): in_xml = mat.group(2) after_xml = mat.group(3) - ctlist.append(pre_xml) + if pre_xml: + ctlist.append(pre_xml) + dom = DomXml(in_xml) tags = dom.get_text_and_sayas_tags() ctlist.extend(tags) - - ctlist.append(after_xml) - return ctlist + + if after_xml: + ctlist.append(after_xml) else: ctlist.append(mixstr) + return ctlist + class DomXml(): def __init__(self, xmlstr): self.tdom = parseString(xmlstr) #Document diff --git a/paddlespeech/t2s/frontend/tone_sandhi.py b/paddlespeech/t2s/frontend/tone_sandhi.py index 42f7b8b2fe714d8bd2f92a6190b2fb6d548cb389..690f69aa254662875a08c739c55659975fbc2581 100644 --- a/paddlespeech/t2s/frontend/tone_sandhi.py +++ b/paddlespeech/t2s/frontend/tone_sandhi.py @@ -20,6 +20,9 @@ from pypinyin import Style class ToneSandhi(): + def __repr__(self): + return "MandarinToneSandhi" + def __init__(self): self.must_neural_tone_words = { '麻烦', '麻利', '鸳鸯', '高粱', '骨头', '骆驼', '马虎', '首饰', '馒头', '馄饨', '风筝', @@ -65,9 +68,22 @@ class ToneSandhi(): '男子', '女子', '分子', '原子', '量子', '莲子', '石子', '瓜子', '电子', '人人', '虎虎', '幺幺', '干嘛', '学子', '哈哈', '数数', '袅袅', '局地', '以下', '娃哈哈', '花花草草', '留得', '耕地', '想想', '熙熙', '攘攘', '卵子', '死死', '冉冉', '恳恳', '佼佼', '吵吵', '打打', - '考考', '整整', '莘莘', '落地', '算子', '家家户户' + '考考', '整整', '莘莘', '落地', '算子', '家家户户', '青青' } - self.punc = ":,;。?!“”‘’':,;.?!" + self.punc = "、:,;。?!“”‘’':,;.?!" + + def _split_word(self, word: str) -> List[str]: + word_list = jieba.cut_for_search(word) + word_list = sorted(word_list, key=lambda i: len(i), reverse=False) + first_subword = word_list[0] + first_begin_idx = word.find(first_subword) + if first_begin_idx == 0: + second_subword = word[len(first_subword):] + new_word_list = [first_subword, second_subword] + else: + second_subword = word[:-len(first_subword)] + new_word_list = [second_subword, first_subword] + return new_word_list # the meaning of jieba pos tag: https://blog.csdn.net/weixin_44174352/article/details/113731041 # e.g. @@ -154,18 +170,8 @@ class ToneSandhi(): finals[i] = finals[i][:-1] + "4" return finals - def _split_word(self, word: str) -> List[str]: - word_list = jieba.cut_for_search(word) - word_list = sorted(word_list, key=lambda i: len(i), reverse=False) - first_subword = word_list[0] - first_begin_idx = word.find(first_subword) - if first_begin_idx == 0: - second_subword = word[len(first_subword):] - new_word_list = [first_subword, second_subword] - else: - second_subword = word[:-len(first_subword)] - new_word_list = [second_subword, first_subword] - return new_word_list + def _all_tone_three(self, finals: List[str]) -> bool: + return all(x[-1] == "3" for x in finals) def _three_sandhi(self, word: str, finals: List[str]) -> List[str]: @@ -207,9 +213,6 @@ class ToneSandhi(): return finals - def _all_tone_three(self, finals: List[str]) -> bool: - return all(x[-1] == "3" for x in finals) - # merge "不" and the word behind it # if don't merge, "不" sometimes appears alone according to jieba, which may occur sandhi error def _merge_bu(self, seg: List[Tuple[str, str]]) -> List[Tuple[str, str]]: @@ -336,6 +339,9 @@ class ToneSandhi(): def pre_merge_for_modify( self, seg: List[Tuple[str, str]]) -> List[Tuple[str, str]]: + """ + seg: [(word, pos), ...] + """ seg = self._merge_bu(seg) seg = self._merge_yi(seg) seg = self._merge_reduplication(seg) @@ -346,7 +352,11 @@ class ToneSandhi(): def modified_tone(self, word: str, pos: str, finals: List[str]) -> List[str]: - + """ + word: 分词 + pos: 词性 + finals: 带调韵母, [final1, ..., finaln] + """ finals = self._bu_sandhi(word, finals) finals = self._yi_sandhi(word, finals) finals = self._neural_sandhi(word, pos, finals) diff --git a/paddlespeech/t2s/frontend/zh_frontend.py b/paddlespeech/t2s/frontend/zh_frontend.py index 35b97a93ad3cc3c6ace96857b088e55148d22dbe..1431bc6d8501c2ab88dc0d56795b91e54db960e0 100644 --- a/paddlespeech/t2s/frontend/zh_frontend.py +++ b/paddlespeech/t2s/frontend/zh_frontend.py @@ -14,6 +14,7 @@ import os import re from operator import itemgetter +from pprint import pprint from typing import Dict from typing import List @@ -30,10 +31,11 @@ from pypinyin_dict.phrase_pinyin_data import large_pinyin from paddlespeech.t2s.frontend.g2pw import G2PWOnnxConverter from paddlespeech.t2s.frontend.generate_lexicon import generate_lexicon +from paddlespeech.t2s.frontend.polyphonic import Polyphonic from paddlespeech.t2s.frontend.rhy_prediction.rhy_predictor import RhyPredictor +from paddlespeech.t2s.frontend.ssml.xml_processor import MixTextProcessor from paddlespeech.t2s.frontend.tone_sandhi import ToneSandhi from paddlespeech.t2s.frontend.zh_normalization.text_normlization import TextNormalizer -from paddlespeech.t2s.ssml.xml_processor import MixTextProcessor INITIALS = [ 'b', 'p', 'm', 'f', 'd', 't', 'n', 'l', 'g', 'k', 'h', 'zh', 'ch', 'sh', @@ -41,6 +43,9 @@ INITIALS = [ ] INITIALS += ['y', 'w', 'sp', 'spl', 'spn', 'sil'] +# 0 for None, 5 for neutral +TONES = ["0", "1", "2", "3", "4", "5"] + def intersperse(lst, item): result = [item] * (len(lst) * 2 + 1) @@ -49,34 +54,19 @@ def intersperse(lst, item): def insert_after_character(lst, item): + """ + inset `item` after finals. + """ result = [item] + for phone in lst: result.append(phone) if phone not in INITIALS: # finals has tones # assert phone[-1] in "12345" result.append(item) - return result - - -class Polyphonic(): - def __init__(self): - with open( - os.path.join( - os.path.dirname(os.path.abspath(__file__)), - 'polyphonic.yaml'), - 'r', - encoding='utf-8') as polyphonic_file: - # 解析yaml - polyphonic_dict = yaml.load(polyphonic_file, Loader=yaml.FullLoader) - self.polyphonic_words = polyphonic_dict["polyphonic"] - def correct_pronunciation(self, word, pinyin): - # 词汇被词典收录则返回纠正后的读音 - if word in self.polyphonic_words.keys(): - pinyin = self.polyphonic_words[word] - # 否则返回原读音 - return pinyin + return result class Frontend(): @@ -85,10 +75,8 @@ class Frontend(): phone_vocab_path=None, tone_vocab_path=None, use_rhy=False): - self.mix_ssml_processor = MixTextProcessor() - self.tone_modifier = ToneSandhi() - self.text_normalizer = TextNormalizer() - self.punc = ":,;。?!“”‘’':,;.?!" + + self.punc = "、:,;。?!“”‘’':,;.?!" self.rhy_phns = ['sp1', 'sp2', 'sp3', 'sp4'] self.phrases_dict = { '开户行': [['ka1i'], ['hu4'], ['hang2']], @@ -108,28 +96,7 @@ class Frontend(): '嘞': [['lei5']], '掺和': [['chan1'], ['huo5']] } - self.use_rhy = use_rhy - if use_rhy: - self.rhy_predictor = RhyPredictor() - print("Rhythm predictor loaded.") - # g2p_model can be pypinyin and g2pM and g2pW - self.g2p_model = g2p_model - if self.g2p_model == "g2pM": - self.g2pM_model = G2pM() - self.pinyin2phone = generate_lexicon( - with_tone=True, with_erhua=False) - elif self.g2p_model == "g2pW": - # use pypinyin as backup for non polyphonic characters in g2pW - self._init_pypinyin() - self.corrector = Polyphonic() - self.g2pM_model = G2pM() - self.g2pW_model = G2PWOnnxConverter( - style='pinyin', enable_non_tradional_chinese=True) - self.pinyin2phone = generate_lexicon( - with_tone=True, with_erhua=False) - else: - self._init_pypinyin() self.must_erhua = { "小院儿", "胡同儿", "范儿", "老汉儿", "撒欢儿", "寻老礼儿", "妥妥儿", "媳妇儿" } @@ -154,13 +121,51 @@ class Frontend(): for tone, id in tone_id: self.vocab_tones[tone] = int(id) + # SSML + self.mix_ssml_processor = MixTextProcessor() + # tone sandhi + self.tone_modifier = ToneSandhi() + # TN + self.text_normalizer = TextNormalizer() + + # prosody + self.use_rhy = use_rhy + if use_rhy: + self.rhy_predictor = RhyPredictor() + print("Rhythm predictor loaded.") + + # g2p + assert g2p_model in ('pypinyin', 'g2pM', 'g2pW') + self.g2p_model = g2p_model + if self.g2p_model == "g2pM": + self.g2pM_model = G2pM() + self.pinyin2phone = generate_lexicon( + with_tone=True, with_erhua=False) + elif self.g2p_model == "g2pW": + # use pypinyin as backup for non polyphonic characters in g2pW + self._init_pypinyin() + self.corrector = Polyphonic() + self.g2pM_model = G2pM() + self.g2pW_model = G2PWOnnxConverter( + style='pinyin', enable_non_tradional_chinese=True) + self.pinyin2phone = generate_lexicon( + with_tone=True, with_erhua=False) + else: + self._init_pypinyin() + def _init_pypinyin(self): + """ + Load pypinyin G2P module. + """ large_pinyin.load() load_phrases_dict(self.phrases_dict) # 调整字的拼音顺序 load_single_dict({ord(u'地'): u'de,di4'}) def _get_initials_finals(self, word: str) -> List[List[str]]: + """ + Get word initial and final by pypinyin or g2pM + """ initials = [] finals = [] if self.g2p_model == "pypinyin": @@ -171,11 +176,14 @@ class Frontend(): for c, v in zip(orig_initials, orig_finals): if re.match(r'i\d', v): if c in ['z', 'c', 's']: + # zi, ci, si v = re.sub('i', 'ii', v) elif c in ['zh', 'ch', 'sh', 'r']: + # zhi, chi, shi v = re.sub('i', 'iii', v) initials.append(c) finals.append(v) + elif self.g2p_model == "g2pM": pinyins = self.g2pM_model(word, tone=True, char_split=False) for pinyin in pinyins: @@ -192,58 +200,123 @@ class Frontend(): # If it's not pinyin (possibly punctuation) or no conversion is required initials.append(pinyin) finals.append(pinyin) + return initials, finals + def _merge_erhua(self, + initials: List[str], + finals: List[str], + word: str, + pos: str) -> List[List[str]]: + """ + Do erhub. + """ + # fix er1 + for i, phn in enumerate(finals): + if i == len(finals) - 1 and word[i] == "儿" and phn == 'er1': + finals[i] = 'er2' + + # 发音 + if word not in self.must_erhua and (word in self.not_erhua or + pos in {"a", "j", "nr"}): + return initials, finals + + # "……" 等情况直接返回 + if len(finals) != len(word): + return initials, finals + + assert len(finals) == len(word) + + # 不发音 + new_initials = [] + new_finals = [] + for i, phn in enumerate(finals): + if i == len(finals) - 1 and word[i] == "儿" and phn in { + "er2", "er5" + } and word[-2:] not in self.not_erhua and new_finals: + new_finals[-1] = new_finals[-1][:-1] + "r" + new_finals[-1][-1] + else: + new_initials.append(initials[i]) + new_finals.append(phn) + + return new_initials, new_finals + # if merge_sentences, merge all sentences into one phone sequence def _g2p(self, sentences: List[str], merge_sentences: bool=True, with_erhua: bool=True) -> List[List[str]]: + """ + Return: list of list phonemes. + [['w', 'o3', 'm', 'en2', 'sp'], ...] + """ segments = sentences phones_list = [] + + # split by punctuation for seg in segments: if self.use_rhy: seg = self.rhy_predictor._clean_text(seg) - phones = [] - # Replace all English words in the sentence + + # remove all English words in the sentence seg = re.sub('[a-zA-Z]+', '', seg) + + # add prosody mark if self.use_rhy: seg = self.rhy_predictor.get_prediction(seg) + + # [(word, pos), ...] seg_cut = psg.lcut(seg) - initials = [] - finals = [] + # fix wordseg bad case for sandhi seg_cut = self.tone_modifier.pre_merge_for_modify(seg_cut) + # 为了多音词获得更好的效果,这里采用整句预测 + phones = [] + initials = [] + finals = [] if self.g2p_model == "g2pW": try: + # undo prosody if self.use_rhy: seg = self.rhy_predictor._clean_text(seg) + + # g2p pinyins = self.g2pW_model(seg)[0] except Exception: - # g2pW采用模型采用繁体输入,如果有cover不了的简体词,采用g2pM预测 + # g2pW 模型采用繁体输入,如果有cover不了的简体词,采用g2pM预测 print("[%s] not in g2pW dict,use g2pM" % seg) pinyins = self.g2pM_model(seg, tone=True, char_split=False) + + # do prosody if self.use_rhy: rhy_text = self.rhy_predictor.get_prediction(seg) final_py = self.rhy_predictor.pinyin_align(pinyins, rhy_text) pinyins = final_py + pre_word_length = 0 for word, pos in seg_cut: sub_initials = [] sub_finals = [] now_word_length = pre_word_length + len(word) + + # skip english word if pos == 'eng': pre_word_length = now_word_length continue + word_pinyins = pinyins[pre_word_length:now_word_length] - # 矫正发音 + + # 多音字消歧 word_pinyins = self.corrector.correct_pronunciation( word, word_pinyins) + for pinyin, char in zip(word_pinyins, word): if pinyin is None: pinyin = char + pinyin = pinyin.replace("u:", "v") + if pinyin in self.pinyin2phone: initial_final_list = self.pinyin2phone[ pinyin].split(" ") @@ -257,28 +330,41 @@ class Frontend(): # If it's not pinyin (possibly punctuation) or no conversion is required sub_initials.append(pinyin) sub_finals.append(pinyin) + pre_word_length = now_word_length + # tone sandhi sub_finals = self.tone_modifier.modified_tone(word, pos, sub_finals) + # er hua if with_erhua: sub_initials, sub_finals = self._merge_erhua( sub_initials, sub_finals, word, pos) + initials.append(sub_initials) finals.append(sub_finals) # assert len(sub_initials) == len(sub_finals) == len(word) else: + # pypinyin, g2pM for word, pos in seg_cut: if pos == 'eng': + # skip english word continue + + # g2p sub_initials, sub_finals = self._get_initials_finals(word) + # tone sandhi sub_finals = self.tone_modifier.modified_tone(word, pos, sub_finals) + # er hua if with_erhua: sub_initials, sub_finals = self._merge_erhua( sub_initials, sub_finals, word, pos) + initials.append(sub_initials) finals.append(sub_finals) # assert len(sub_initials) == len(sub_finals) == len(word) + + # sum(iterable[, start]) initials = sum(initials, []) finals = sum(finals, []) @@ -287,111 +373,34 @@ class Frontend(): # we discriminate i, ii and iii if c and c not in self.punc: phones.append(c) + # replace punctuation by `sp` if c and c in self.punc: phones.append('sp') + if v and v not in self.punc and v not in self.rhy_phns: phones.append(v) - phones_list.append(phones) - if merge_sentences: - merge_list = sum(phones_list, []) - # rm the last 'sp' to avoid the noise at the end - # cause in the training data, no 'sp' in the end - if merge_list[-1] == 'sp': - merge_list = merge_list[:-1] - phones_list = [] - phones_list.append(merge_list) - return phones_list - def _split_word_to_char(self, words): - res = [] - for x in words: - res.append(x) - return res - - # if using ssml, have pingyin specified, assign pinyin to words - def _g2p_assign(self, - words: List[str], - pinyin_spec: List[str], - merge_sentences: bool=True) -> List[List[str]]: - phones_list = [] - initials = [] - finals = [] - - words = self._split_word_to_char(words[0]) - for pinyin, char in zip(pinyin_spec, words): - sub_initials = [] - sub_finals = [] - pinyin = pinyin.replace("u:", "v") - #self.pinyin2phone: is a dict with all pinyin mapped with sheng_mu yun_mu - if pinyin in self.pinyin2phone: - initial_final_list = self.pinyin2phone[pinyin].split(" ") - if len(initial_final_list) == 2: - sub_initials.append(initial_final_list[0]) - sub_finals.append(initial_final_list[1]) - elif len(initial_final_list) == 1: - sub_initials.append('') - sub_finals.append(initial_final_list[1]) - else: - # If it's not pinyin (possibly punctuation) or no conversion is required - sub_initials.append(pinyin) - sub_finals.append(pinyin) - initials.append(sub_initials) - finals.append(sub_finals) + phones_list.append(phones) - initials = sum(initials, []) - finals = sum(finals, []) - phones = [] - for c, v in zip(initials, finals): - # NOTE: post process for pypinyin outputs - # we discriminate i, ii and iii - if c and c not in self.punc: - phones.append(c) - if c and c in self.punc: - phones.append('sp') - if v and v not in self.punc and v not in self.rhy_phns: - phones.append(v) - phones_list.append(phones) + # merge split sub sentence into one sentence. if merge_sentences: + # sub sentence phonemes merge_list = sum(phones_list, []) # rm the last 'sp' to avoid the noise at the end # cause in the training data, no 'sp' in the end if merge_list[-1] == 'sp': merge_list = merge_list[:-1] + + # sentence phonemes phones_list = [] phones_list.append(merge_list) - return phones_list - def _merge_erhua(self, - initials: List[str], - finals: List[str], - word: str, - pos: str) -> List[List[str]]: - # fix er1 - for i, phn in enumerate(finals): - if i == len(finals) - 1 and word[i] == "儿" and phn == 'er1': - finals[i] = 'er2' - if word not in self.must_erhua and (word in self.not_erhua or - pos in {"a", "j", "nr"}): - return initials, finals - # "……" 等情况直接返回 - if len(finals) != len(word): - return initials, finals - - assert len(finals) == len(word) - - new_initials = [] - new_finals = [] - for i, phn in enumerate(finals): - if i == len(finals) - 1 and word[i] == "儿" and phn in { - "er2", "er5" - } and word[-2:] not in self.not_erhua and new_finals: - new_finals[-1] = new_finals[-1][:-1] + "r" + new_finals[-1][-1] - else: - new_finals.append(phn) - new_initials.append(initials[i]) - return new_initials, new_finals + return phones_list def _p2id(self, phonemes: List[str]) -> np.ndarray: + """ + Phoneme to Index + """ # replace unk phone with sp phonemes = [ phn if phn in self.vocab_phones else "sp" for phn in phonemes @@ -400,6 +409,9 @@ class Frontend(): return np.array(phone_ids, np.int64) def _t2id(self, tones: List[str]) -> np.ndarray: + """ + Tone to Index. + """ # replace unk phone with sp tones = [tone if tone in self.vocab_tones else "0" for tone in tones] tone_ids = [self.vocab_tones[item] for item in tones] @@ -407,6 +419,9 @@ class Frontend(): def _get_phone_tone(self, phonemes: List[str], get_tone_ids: bool=False) -> List[List[str]]: + """ + Get tone from phonemes. + """ phones = [] tones = [] if get_tone_ids and self.vocab_tones: @@ -423,13 +438,14 @@ class Frontend(): -1] == 'r' and phone not in self.vocab_phones and phone[: -1] in self.vocab_phones: phones.append(phone[:-1]) - phones.append("er") tones.append(tone) + phones.append("er") tones.append("2") else: phones.append(phone) tones.append(tone) else: + # initals with 0 tone. phones.append(full_phone) tones.append('0') else: @@ -443,6 +459,7 @@ class Frontend(): phones.append("er2") else: phones.append(phone) + return phones, tones def get_phonemes(self, @@ -451,10 +468,16 @@ class Frontend(): with_erhua: bool=True, robot: bool=False, print_info: bool=False) -> List[List[str]]: + """ + Main function to do G2P + """ + # TN & Text Segmentation sentences = self.text_normalizer.normalize(sentence) + # Prosody & WS & g2p & tone sandhi phonemes = self._g2p( sentences, merge_sentences=merge_sentences, with_erhua=with_erhua) - # change all tones to `1` + + # simulate robot pronunciation, change all tones to `1` if robot: new_phonemes = [] for sentence in phonemes: @@ -466,6 +489,7 @@ class Frontend(): new_sentence.append(item) new_phonemes.append(new_sentence) phonemes = new_phonemes + if print_info: print("----------------------------") print("text norm results:") @@ -476,25 +500,104 @@ class Frontend(): print("----------------------------") return phonemes - #@an added for ssml pinyin + def _split_word_to_char(self, words): + res = [] + for x in words: + res.append(x) + return res + + # if using ssml, have pingyin specified, assign pinyin to words + def _g2p_assign(self, + words: List[str], + pinyin_spec: List[str], + merge_sentences: bool=True) -> List[List[str]]: + """ + Replace phoneme by SSML + """ + phones_list = [] + initials = [] + finals = [] + + # to charactor list + words = self._split_word_to_char(words[0]) + + for pinyin, char in zip(pinyin_spec, words): + sub_initials = [] + sub_finals = [] + pinyin = pinyin.replace("u:", "v") + + #self.pinyin2phone: is a dict with all pinyin mapped with sheng_mu yun_mu + if pinyin in self.pinyin2phone: + initial_final_list = self.pinyin2phone[pinyin].split(" ") + if len(initial_final_list) == 2: + sub_initials.append(initial_final_list[0]) + sub_finals.append(initial_final_list[1]) + elif len(initial_final_list) == 1: + sub_initials.append('') + sub_finals.append(initial_final_list[1]) + else: + # If it's not pinyin (possibly punctuation) or no conversion is required + sub_initials.append(pinyin) + sub_finals.append(pinyin) + + initials.append(sub_initials) + finals.append(sub_finals) + + initials = sum(initials, []) + finals = sum(finals, []) + + phones = [] + for c, v in zip(initials, finals): + # c for consonant, v for vowel + # NOTE: post process for pypinyin outputs + # we discriminate i, ii and iii + if c and c not in self.punc: + phones.append(c) + # replace punc to `sp` + if c and c in self.punc: + phones.append('sp') + if v and v not in self.punc and v not in self.rhy_phns: + phones.append(v) + phones_list.append(phones) + + if merge_sentences: + merge_list = sum(phones_list, []) + # rm the last 'sp' to avoid the noise at the end + # cause in the training data, no 'sp' in the end + if merge_list[-1] == 'sp': + merge_list = merge_list[:-1] + phones_list = [] + phones_list.append(merge_list) + + return phones_list + def get_phonemes_ssml(self, ssml_inputs: list, merge_sentences: bool=True, with_erhua: bool=True, robot: bool=False, print_info: bool=False) -> List[List[str]]: + """ + Main function to do G2P with SSML support. + """ all_phonemes = [] for word_pinyin_item in ssml_inputs: phonemes = [] + + # ['你喜欢', []] -> 你喜欢 [] sentence, pinyin_spec = itemgetter(0, 1)(word_pinyin_item) + + # TN & Text Segmentation sentences = self.text_normalizer.normalize(sentence) + if len(pinyin_spec) == 0: + # g2p word w/o specified phonemes = self._g2p( sentences, merge_sentences=merge_sentences, with_erhua=with_erhua) else: - # phonemes should be pinyin_spec + # word phonemes specified by phonemes = self._g2p_assign( sentences, pinyin_spec, merge_sentences=merge_sentences) @@ -512,17 +615,24 @@ class Frontend(): new_phonemes.append(new_sentence) all_phonemes = new_phonemes + if merge_sentences: + all_phonemes = [sum(all_phonemes, [])] + if print_info: print("----------------------------") print("text norm results:") print(sentences) print("----------------------------") print("g2p results:") - print(all_phonemes[0]) + print(all_phonemes) print("----------------------------") - return [sum(all_phonemes, [])] + + return all_phonemes def add_sp_if_no(self, phonemes): + """ + Prosody mark #4 added at sentence end. + """ if not phonemes[-1][-1].startswith('sp'): phonemes[-1].append('sp4') return phonemes @@ -542,8 +652,11 @@ class Frontend(): merge_sentences=merge_sentences, print_info=print_info, robot=robot) + + # add #4 for sentence end. if self.use_rhy: phonemes = self.add_sp_if_no(phonemes) + result = {} phones = [] tones = [] @@ -551,28 +664,33 @@ class Frontend(): temp_tone_ids = [] for part_phonemes in phonemes: + phones, tones = self._get_phone_tone( part_phonemes, get_tone_ids=get_tone_ids) + if add_blank: phones = insert_after_character(phones, blank_token) + if tones: tone_ids = self._t2id(tones) if to_tensor: tone_ids = paddle.to_tensor(tone_ids) temp_tone_ids.append(tone_ids) + if phones: phone_ids = self._p2id(phones) # if use paddle.to_tensor() in onnxruntime, the first time will be too low if to_tensor: phone_ids = paddle.to_tensor(phone_ids) temp_phone_ids.append(phone_ids) + if temp_tone_ids: result["tone_ids"] = temp_tone_ids if temp_phone_ids: result["phone_ids"] = temp_phone_ids + return result - # @an added for ssml def get_input_ids_ssml( self, sentence: str, @@ -584,12 +702,15 @@ class Frontend(): blank_token: str="", to_tensor: bool=True) -> Dict[str, List[paddle.Tensor]]: - l_inputs = MixTextProcessor.get_pinyin_split(sentence) + # split setence by SSML tag. + texts = MixTextProcessor.get_pinyin_split(sentence) + phonemes = self.get_phonemes_ssml( - l_inputs, + texts, merge_sentences=merge_sentences, print_info=print_info, robot=robot) + result = {} phones = [] tones = [] @@ -599,21 +720,26 @@ class Frontend(): for part_phonemes in phonemes: phones, tones = self._get_phone_tone( part_phonemes, get_tone_ids=get_tone_ids) + if add_blank: phones = insert_after_character(phones, blank_token) + if tones: tone_ids = self._t2id(tones) if to_tensor: tone_ids = paddle.to_tensor(tone_ids) temp_tone_ids.append(tone_ids) + if phones: phone_ids = self._p2id(phones) # if use paddle.to_tensor() in onnxruntime, the first time will be too low if to_tensor: phone_ids = paddle.to_tensor(phone_ids) temp_phone_ids.append(phone_ids) + if temp_tone_ids: result["tone_ids"] = temp_tone_ids if temp_phone_ids: result["phone_ids"] = temp_phone_ids + return result diff --git a/setup.py b/setup.py index 07b411bd048e965bf0107f53b749a2a66fa9c968..af7c4dc3dfc3f1dfb15ccf11c5d61080f196efdb 100644 --- a/setup.py +++ b/setup.py @@ -40,6 +40,8 @@ base = [ "hyperpyyaml", "inflect", "jsonlines", + # paddleaudio align with librosa==0.8.1, which need numpy==1.23.x + "numpy==1.23.5", "librosa==0.8.1", "scipy>=1.4.0", "loguru", @@ -260,6 +262,7 @@ setup_info = dict( long_description=read("README.md"), long_description_content_type="text/markdown", keywords=[ + "SSL" "speech", "asr", "tts", @@ -268,12 +271,19 @@ setup_info = dict( "text frontend", "MFA", "paddlepaddle", + "paddleaudio", + "streaming asr", + "streaming tts", "beam search", "ctcdecoder", "deepspeech2", + "wav2vec2", + "hubert", + "wavlm", "transformer", "conformer", "fastspeech2", + "hifigan", "gan vocoders", ], python_requires='>=3.7', diff --git a/tests/unit/tts/test_enfrontend.py b/tests/unit/tts/test_enfrontend.py new file mode 100644 index 0000000000000000000000000000000000000000..4f8c49305b4c772ecd0d9216731a7dd9839d6d74 --- /dev/null +++ b/tests/unit/tts/test_enfrontend.py @@ -0,0 +1,28 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from paddlespeech.t2s.frontend.en_frontend import English as EnFrontend + +if __name__ == '__main__': + + fe = EnFrontend() + + text = "AI for Sceience" + phonemes = fe.phoneticize(text) + print(text) + print(phonemes) + + text = "eight" + phonemes = fe.phoneticize(text) + print(text) + print(phonemes) diff --git a/tests/unit/tts/test_mixfrontend.py b/tests/unit/tts/test_mixfrontend.py new file mode 100644 index 0000000000000000000000000000000000000000..5751dd2a7524809679ef999b61defbde3145e1a5 --- /dev/null +++ b/tests/unit/tts/test_mixfrontend.py @@ -0,0 +1,444 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import re +import tempfile + +from paddlespeech.t2s.frontend.mix_frontend import MixFrontend + +# mix zh & en phonemes +phone_id_str = """ + 0 + 1 +AA0 2 +AA1 3 +AA2 4 +AE0 5 +AE1 6 +AE2 7 +AH0 8 +AH1 9 +AH2 10 +AO0 11 +AO1 12 +AO2 13 +AW0 14 +AW1 15 +AW2 16 +AY0 17 +AY1 18 +AY2 19 +B 20 +CH 21 +D 22 +DH 23 +EH0 24 +EH1 25 +EH2 26 +ER0 27 +ER1 28 +ER2 29 +EY0 30 +EY1 31 +EY2 32 +F 33 +G 34 +HH 35 +IH0 36 +IH1 37 +IH2 38 +IY0 39 +IY1 40 +IY2 41 +JH 42 +K 43 +L 44 +M 45 +N 46 +NG 47 +OW0 48 +OW1 49 +OW2 50 +OY0 51 +OY1 52 +OY2 53 +P 54 +R 55 +S 56 +SH 57 +T 58 +TH 59 +UH0 60 +UH1 61 +UH2 62 +UW0 63 +UW1 64 +UW2 65 +V 66 +W 67 +Y 68 +Z 69 +ZH 70 +a1 71 +a2 72 +a3 73 +a4 74 +a5 75 +ai1 76 +ai2 77 +ai3 78 +ai4 79 +ai5 80 +air2 81 +air3 82 +air4 83 +an1 84 +an2 85 +an3 86 +an4 87 +an5 88 +ang1 89 +ang2 90 +ang3 91 +ang4 92 +ang5 93 +angr2 94 +angr4 95 +anr1 96 +anr3 97 +anr4 98 +ao1 99 +ao2 100 +ao3 101 +ao4 102 +ao5 103 +aor1 104 +aor3 105 +aor4 106 +aor5 107 +ar2 108 +ar3 109 +ar4 110 +ar5 111 +b 112 +c 113 +ch 114 +d 115 +e1 116 +e2 117 +e3 118 +e4 119 +e5 120 +ei1 121 +ei2 122 +ei3 123 +ei4 124 +ei5 125 +eir4 126 +en1 127 +en2 128 +en3 129 +en4 130 +en5 131 +eng1 132 +eng2 133 +eng3 134 +eng4 135 +eng5 136 +engr4 137 +enr1 138 +enr2 139 +enr3 140 +enr4 141 +enr5 142 +er1 143 +er2 144 +er3 145 +er4 146 +er5 147 +f 148 +g 149 +h 150 +i1 151 +i2 152 +i3 153 +i4 154 +i5 155 +ia1 156 +ia2 157 +ia3 158 +ia4 159 +ia5 160 +ian1 161 +ian2 162 +ian3 163 +ian4 164 +ian5 165 +iang1 166 +iang2 167 +iang3 168 +iang4 169 +iang5 170 +iangr4 171 +ianr1 172 +ianr2 173 +ianr3 174 +ianr4 175 +ianr5 176 +iao1 177 +iao2 178 +iao3 179 +iao4 180 +iao5 181 +iaor1 182 +iaor2 183 +iaor3 184 +iaor4 185 +iar1 186 +iar3 187 +iar4 188 +ie1 189 +ie2 190 +ie3 191 +ie4 192 +ie5 193 +ii1 194 +ii2 195 +ii3 196 +ii4 197 +ii5 198 +iii1 199 +iii2 200 +iii3 201 +iii4 202 +iii5 203 +iiir1 204 +iiir4 205 +iir2 206 +in1 207 +in2 208 +in3 209 +in4 210 +in5 211 +ing1 212 +ing2 213 +ing3 214 +ing4 215 +ing5 216 +ingr1 217 +ingr2 218 +ingr3 219 +ingr4 220 +inr1 221 +inr4 222 +io1 223 +io3 224 +io5 225 +iong1 226 +iong2 227 +iong3 228 +iong4 229 +iong5 230 +iou1 231 +iou2 232 +iou3 233 +iou4 234 +iou5 235 +iour1 236 +iour2 237 +iour3 238 +iour4 239 +ir1 240 +ir2 241 +ir3 242 +ir4 243 +ir5 244 +j 245 +k 246 +l 247 +m 248 +n 249 +o1 250 +o2 251 +o3 252 +o4 253 +o5 254 +ong1 255 +ong2 256 +ong3 257 +ong4 258 +ong5 259 +ongr4 260 +or2 261 +ou1 262 +ou2 263 +ou3 264 +ou4 265 +ou5 266 +our2 267 +our3 268 +our4 269 +our5 270 +p 271 +q 272 +r 273 +s 274 +sh 275 +sil 276 +sp 277 +spl 278 +spn 279 +t 280 +u1 281 +u2 282 +u3 283 +u4 284 +u5 285 +ua1 286 +ua2 287 +ua3 288 +ua4 289 +ua5 290 +uai1 291 +uai2 292 +uai3 293 +uai4 294 +uai5 295 +uair4 296 +uan1 297 +uan2 298 +uan3 299 +uan4 300 +uan5 301 +uang1 302 +uang2 303 +uang3 304 +uang4 305 +uang5 306 +uangr4 307 +uanr1 308 +uanr2 309 +uanr3 310 +uanr4 311 +uanr5 312 +uar1 313 +uar2 314 +uar4 315 +uei1 316 +uei2 317 +uei3 318 +uei4 319 +uei5 320 +ueir1 321 +ueir2 322 +ueir3 323 +ueir4 324 +uen1 325 +uen2 326 +uen3 327 +uen4 328 +uen5 329 +ueng1 330 +ueng2 331 +ueng3 332 +ueng4 333 +uenr1 334 +uenr2 335 +uenr3 336 +uenr4 337 +uo1 338 +uo2 339 +uo3 340 +uo4 341 +uo5 342 +uor1 343 +uor2 344 +uor3 345 +uor5 346 +ur1 347 +ur2 348 +ur3 349 +ur4 350 +ur5 351 +v1 352 +v2 353 +v3 354 +v4 355 +v5 356 +van1 357 +van2 358 +van3 359 +van4 360 +van5 361 +vanr1 362 +vanr2 363 +vanr3 364 +vanr4 365 +ve1 366 +ve2 367 +ve3 368 +ve4 369 +ve5 370 +ver3 371 +ver4 372 +vn1 373 +vn2 374 +vn3 375 +vn4 376 +vn5 377 +vnr2 378 +vr3 379 +x 380 +z 381 +zh 382 +, 383 +. 384 +? 385 +! 386 + 387 +""" + +if __name__ == '__main__': + with tempfile.NamedTemporaryFile(mode='wt') as f: + phone_ids = phone_id_str.split() + for phone, id in zip(phone_ids[::2], phone_ids[1::2]): + f.write(f"{phone} {id}") + f.write('\n') + f.flush() + + frontend = MixFrontend(phone_vocab_path=f.name) + + text = "hello, 我爱北京天安们,what about you." + print(text) + # [('hello, ', 'en'), ('我爱北京天安们,', 'zh'), ('what about you.', 'en')] + segs = frontend.split_by_lang(text) + print(segs) + + text = "hello?!!我爱北京天安们,what about you." + print(text) + # [('hello?!!', 'en'), ('我爱北京天安们,', 'zh'), ('what about you.', 'en')] + segs = frontend.split_by_lang(text) + print(segs) + + text = " hello,我爱北京天安们,what about you." + print(text) + # [(' hello,', 'en'), ('我爱北京天安们,', 'zh'), ('what about you.', 'en')] + segs = frontend.split_by_lang(text) + print(segs) + + # 对于SSML的xml标记处理不好。需要先解析SSML,后处理中英的划分。 + text = "我们的声学模型使用了 Fast Speech Two。前浪在沙滩上,沙滩上倒了一堆。 想象干干的树干了, 里面有个干尸,不知是被谁死的。" + print(text) + # [('', 'en'), ('我们的声学模型使用了 ', 'zh'), ('Fast Speech Two。', 'en'), ('前浪<', 'zh'), ("say-as pinyin='dao3'>", 'en'), ('倒', 'en'), ('在沙滩上,沙滩上倒了一堆<', 'zh'), ("say-as pinyin='tu3'>", 'en'), ('土。 ', 'en'), ('想象<', 'zh'), ("say-as pinyin='gan1 gan1'>", 'en'), ('干干', 'en'), ('的树干<', 'zh'), ("say-as pinyin='dao3'>", 'en'), ('倒', 'en'), ('了, 里面有个干尸,不知是被谁<', 'zh'), ("say-as pinyin='gan4'>", 'en'), ('干', 'en'), ('死的。', 'en')] + segs = frontend.split_by_lang(text) + print(segs) diff --git a/tests/unit/tts/test_ssml.py b/tests/unit/tts/test_ssml.py new file mode 100644 index 0000000000000000000000000000000000000000..2c24018377a523f31bc7dd5f0eb231e1cbb2272c --- /dev/null +++ b/tests/unit/tts/test_ssml.py @@ -0,0 +1,74 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from paddlespeech.t2s.frontend.ssml.xml_processor import MixTextProcessor + +if __name__ == '__main__': + text = "你好吗,我们的声学模型使用了 Fast Speech Two。前浪在沙滩上,沙滩上倒了一堆。 想象干干的树干了, 里面有个干尸,不知是被谁死的。thank you." + + # SSML: 13 + # 0 ['你好吗,', []] + # 1 ['我们的声学模型使用了FastSpeechTwo。前浪', []] + # 2 ['倒', ['dao3']] + # 3 ['在沙滩上,沙滩上倒了一堆', []] + # 4 ['土', ['tu3']] + # 5 ['。想象', []] + # 6 ['干干', ['gan1', 'gan1']] + # 7 ['的树干', []] + # 8 ['倒', ['dao3']] + # 9 ['了,里面有个干尸,不知是被谁', []] + # 10 ['干', ['gan4']] + # 11 ['死的。', []] + # 12 ['thank you.', []] + inputs = MixTextProcessor.get_pinyin_split(text) + print(f"SSML get_pinyin_split: {len(inputs)}") + for i, sub in enumerate(inputs): + print(i, sub) + print() + + # SSML get_dom_split: 13 + # 0 你好吗, + # 1 我们的声学模型使用了 Fast Speech Two。前浪 + # 2 + # 3 在沙滩上,沙滩上倒了一堆 + # 4 + # 5 。 想象 + # 6 干干 + # 7 的树干 + # 8 + # 9 了, 里面有个干尸,不知是被谁 + # 10 + # 11 死的。 + # 12 thank you. + inputs = MixTextProcessor.get_dom_split(text) + print(f"SSML get_dom_split: {len(inputs)}") + for i, sub in enumerate(inputs): + print(i, sub) + print() + + # SSML object.get_pinyin_split: 246 + # 我们的声学模型使用了 Fast Speech Two。前浪在沙滩上,沙滩上倒了一堆。 想象干干的树干了, 里面有个干尸,不知是被谁死的。 + outs = MixTextProcessor().get_xml_content(text) + print(f"SSML object.get_pinyin_split: {len(outs)}") + print(outs) + print() + + # SSML object.get_content_split: 30 你好吗, + # 1 我们的声学模型使用了 Fast Speech Two。前浪在沙滩上,沙滩上倒了一堆。 想象干干的树干 + # 倒了, 里面有个干尸,不知是被谁死的。 + # 2 thank you. + outs = MixTextProcessor().get_content_split(text) + print(f"SSML object.get_content_split: {len(outs)}") + for i, sub in enumerate(outs): + print(i, sub) + print()