diff --git a/MANIFEST.in b/MANIFEST.in new file mode 100644 index 0000000000000000000000000000000000000000..7c9f4165c70971010d320eeacf85fcd85e43bff1 --- /dev/null +++ b/MANIFEST.in @@ -0,0 +1,2 @@ +include paddlespeech/t2s/exps/*.txt +include paddlespeech/t2s/frontend/*.yaml \ No newline at end of file diff --git a/README.md b/README.md index e93aa1d9ccbe37410b826bd78c586deff36874d8..e35289e2b2c9cbf209e11a24c441da94a775cf28 100644 --- a/README.md +++ b/README.md @@ -699,6 +699,7 @@ You are warmly welcome to submit questions in [discussions](https://github.com/P ## Acknowledgement +- Many thanks to [BarryKCL](https://github.com/BarryKCL) improved TTS Chinses frontend based on [G2PW](https://github.com/GitYCC/g2pW) - Many thanks to [yeyupiaoling](https://github.com/yeyupiaoling)/[PPASR](https://github.com/yeyupiaoling/PPASR)/[PaddlePaddle-DeepSpeech](https://github.com/yeyupiaoling/PaddlePaddle-DeepSpeech)/[VoiceprintRecognition-PaddlePaddle](https://github.com/yeyupiaoling/VoiceprintRecognition-PaddlePaddle)/[AudioClassification-PaddlePaddle](https://github.com/yeyupiaoling/AudioClassification-PaddlePaddle) for years of attention, constructive advice and great help. - Many thanks to [mymagicpower](https://github.com/mymagicpower) for the Java implementation of ASR upon [short](https://github.com/mymagicpower/AIAS/tree/main/3_audio_sdks/asr_sdk) and [long](https://github.com/mymagicpower/AIAS/tree/main/3_audio_sdks/asr_long_audio_sdk) audio files. - Many thanks to [JiehangXie](https://github.com/JiehangXie)/[PaddleBoBo](https://github.com/JiehangXie/PaddleBoBo) for developing Virtual Uploader(VUP)/Virtual YouTuber(VTuber) with PaddleSpeech TTS function. diff --git a/README_cn.md b/README_cn.md index 896c575ce9b12cb71c21d3e49792d74d442d5c5a..1c6a949fd78569ddfcafa0823c6d6b401f069292 100644 --- a/README_cn.md +++ b/README_cn.md @@ -833,6 +833,7 @@ PaddleSpeech 的 **语音合成** 主要包含三个模块:文本前端、声 ## 致谢 +- 非常感谢 [BarryKCL](https://github.com/BarryKCL)基于[G2PW](https://github.com/GitYCC/g2pW)对TTS中文文本前端的优化。 - 非常感谢 [yeyupiaoling](https://github.com/yeyupiaoling)/[PPASR](https://github.com/yeyupiaoling/PPASR)/[PaddlePaddle-DeepSpeech](https://github.com/yeyupiaoling/PaddlePaddle-DeepSpeech)/[VoiceprintRecognition-PaddlePaddle](https://github.com/yeyupiaoling/VoiceprintRecognition-PaddlePaddle)/[AudioClassification-PaddlePaddle](https://github.com/yeyupiaoling/AudioClassification-PaddlePaddle) 多年来的关注和建议,以及在诸多问题上的帮助。 - 非常感谢 [mymagicpower](https://github.com/mymagicpower) 采用PaddleSpeech 对 ASR 的[短语音](https://github.com/mymagicpower/AIAS/tree/main/3_audio_sdks/asr_sdk)及[长语音](https://github.com/mymagicpower/AIAS/tree/main/3_audio_sdks/asr_long_audio_sdk)进行 Java 实现。 - 非常感谢 [JiehangXie](https://github.com/JiehangXie)/[PaddleBoBo](https://github.com/JiehangXie/PaddleBoBo) 采用 PaddleSpeech 语音合成功能实现 Virtual Uploader(VUP)/Virtual YouTuber(VTuber) 虚拟主播。 diff --git a/docs/requirements.txt b/docs/requirements.txt index d6e27e226e24f5fc93b554fa7da5f632b67bb8cf..ee116a9b686494bf813b73fde0ed91ca914de4ba 100644 --- a/docs/requirements.txt +++ b/docs/requirements.txt @@ -19,6 +19,7 @@ loguru matplotlib nara_wpe onnxruntime==1.10.0 +opencc pandas paddlenlp paddlespeech_feat diff --git a/docs/source/reference.md b/docs/source/reference.md index ed91c2066f2fe8c4a6470e45ac094d6187472a41..0d36d96f7404ec3a794b0f2c7594b392b29b0e04 100644 --- a/docs/source/reference.md +++ b/docs/source/reference.md @@ -40,3 +40,6 @@ We borrowed a lot of code from these repos to build `model` and `engine`, thanks * [ThreadPool](https://github.com/progschj/ThreadPool/blob/master/COPYING) - zlib License - ThreadPool + +* [g2pW](https://github.com/GitYCC/g2pW/blob/master/LICENCE) +- Apache-2.0 license diff --git a/paddlespeech/resource/pretrained_models.py b/paddlespeech/resource/pretrained_models.py index bfe2bc7ecf14bc7f7601d53dfa78072db4bccdc6..e39f7721a723301fead1738cf9c0cb185abc4f8a 100644 --- a/paddlespeech/resource/pretrained_models.py +++ b/paddlespeech/resource/pretrained_models.py @@ -1335,3 +1335,17 @@ kws_dynamic_pretrained_models = { }, }, } + +# --------------------------------- +# ------------- G2PW --------------- +# --------------------------------- +g2pw_onnx_models = { + 'G2PWModel': { + '1.0': { + 'url': + 'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/g2p/G2PWModel.tar', + 'md5': + '63bc0894af15a5a591e58b2130a2bcac', + }, + }, +} diff --git a/paddlespeech/t2s/frontend/g2pw/__init__.py b/paddlespeech/t2s/frontend/g2pw/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..6e1ee0db86715ea72579f3926e5f185fe5c36880 --- /dev/null +++ b/paddlespeech/t2s/frontend/g2pw/__init__.py @@ -0,0 +1,2 @@ +from paddlespeech.t2s.frontend.g2pw.onnx_api import G2PWOnnxConverter + diff --git a/paddlespeech/t2s/frontend/g2pw/dataset.py b/paddlespeech/t2s/frontend/g2pw/dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..8125f71f0f0227422d3408926e219bf5bb6e065d --- /dev/null +++ b/paddlespeech/t2s/frontend/g2pw/dataset.py @@ -0,0 +1,134 @@ +""" +Credits + This code is modified from https://github.com/GitYCC/g2pW +""" +import numpy as np +from paddlespeech.t2s.frontend.g2pw.utils import tokenize_and_map + +ANCHOR_CHAR = '▁' + + +def prepare_onnx_input(tokenizer, labels, char2phonemes, chars, texts, query_ids, phonemes=None, pos_tags=None, + use_mask=False, use_char_phoneme=False, use_pos=False, window_size=None, max_len=512): + if window_size is not None: + truncated_texts, truncated_query_ids = _truncate_texts(window_size, texts, query_ids) + + input_ids = [] + token_type_ids = [] + attention_masks = [] + phoneme_masks = [] + char_ids = [] + position_ids = [] + + for idx in range(len(texts)): + text = (truncated_texts if window_size else texts)[idx].lower() + query_id = (truncated_query_ids if window_size else query_ids)[idx] + + try: + tokens, text2token, token2text = tokenize_and_map(tokenizer, text) + except Exception: + print(f'warning: text "{text}" is invalid') + return {} + + text, query_id, tokens, text2token, token2text = _truncate(max_len, text, query_id, tokens, text2token, token2text) + + processed_tokens = ['[CLS]'] + tokens + ['[SEP]'] + + input_id = list(np.array(tokenizer.convert_tokens_to_ids(processed_tokens))) + token_type_id = list(np.zeros((len(processed_tokens),), dtype=int)) + attention_mask = list(np.ones((len(processed_tokens),), dtype=int)) + + query_char = text[query_id] + phoneme_mask = [1 if i in char2phonemes[query_char] else 0 for i in range(len(labels))] \ + if use_mask else [1] * len(labels) + char_id = chars.index(query_char) + position_id = text2token[query_id] + 1 # [CLS] token locate at first place + + input_ids.append(input_id) + token_type_ids.append(token_type_id) + attention_masks.append(attention_mask) + phoneme_masks.append(phoneme_mask) + char_ids.append(char_id) + position_ids.append(position_id) + + outputs = { + 'input_ids': np.array(input_ids), + 'token_type_ids': np.array(token_type_ids), + 'attention_masks': np.array(attention_masks), + 'phoneme_masks': np.array(phoneme_masks).astype(np.float32), + 'char_ids': np.array(char_ids), + 'position_ids': np.array(position_ids), + } + return outputs + +def _truncate_texts(window_size, texts, query_ids): + truncated_texts = [] + truncated_query_ids = [] + for text, query_id in zip(texts, query_ids): + start = max(0, query_id - window_size // 2) + end = min(len(text), query_id + window_size // 2) + truncated_text = text[start:end] + truncated_texts.append(truncated_text) + + truncated_query_id = query_id - start + truncated_query_ids.append(truncated_query_id) + return truncated_texts, truncated_query_ids + +def _truncate(max_len, text, query_id, tokens, text2token, token2text): + truncate_len = max_len - 2 + if len(tokens) <= truncate_len: + return (text, query_id, tokens, text2token, token2text) + + token_position = text2token[query_id] + + token_start = token_position - truncate_len // 2 + token_end = token_start + truncate_len + font_exceed_dist = -token_start + back_exceed_dist = token_end - len(tokens) + if font_exceed_dist > 0: + token_start += font_exceed_dist + token_end += font_exceed_dist + elif back_exceed_dist > 0: + token_start -= back_exceed_dist + token_end -= back_exceed_dist + + start = token2text[token_start][0] + end = token2text[token_end - 1][1] + + return ( + text[start:end], + query_id - start, + tokens[token_start:token_end], + [i - token_start if i is not None else None for i in text2token[start:end]], + [(s - start, e - start) for s, e in token2text[token_start:token_end]] + ) + +def prepare_data(sent_path, lb_path=None): + raw_texts = open(sent_path).read().rstrip().split('\n') + query_ids = [raw.index(ANCHOR_CHAR) for raw in raw_texts] + texts = [raw.replace(ANCHOR_CHAR, '') for raw in raw_texts] + if lb_path is None: + return texts, query_ids + else: + phonemes = open(lb_path).read().rstrip().split('\n') + return texts, query_ids, phonemes + + +def get_phoneme_labels(polyphonic_chars): + labels = sorted(list(set([phoneme for char, phoneme in polyphonic_chars]))) + char2phonemes = {} + for char, phoneme in polyphonic_chars: + if char not in char2phonemes: + char2phonemes[char] = [] + char2phonemes[char].append(labels.index(phoneme)) + return labels, char2phonemes + + +def get_char_phoneme_labels(polyphonic_chars): + labels = sorted(list(set([f'{char} {phoneme}' for char, phoneme in polyphonic_chars]))) + char2phonemes = {} + for char, phoneme in polyphonic_chars: + if char not in char2phonemes: + char2phonemes[char] = [] + char2phonemes[char].append(labels.index(f'{char} {phoneme}')) + return labels, char2phonemes diff --git a/paddlespeech/t2s/frontend/g2pw/onnx_api.py b/paddlespeech/t2s/frontend/g2pw/onnx_api.py new file mode 100644 index 0000000000000000000000000000000000000000..ace943f2ba2a23dda150193fa1d1096472f78ea4 --- /dev/null +++ b/paddlespeech/t2s/frontend/g2pw/onnx_api.py @@ -0,0 +1,143 @@ +""" +Credits + This code is modified from https://github.com/GitYCC/g2pW +""" +import os +import json +import onnxruntime +import numpy as np + +from opencc import OpenCC +from pypinyin import pinyin, lazy_pinyin, Style +from paddlenlp.transformers import BertTokenizer +from paddlespeech.utils.env import MODEL_HOME +from paddlespeech.t2s.frontend.g2pw.dataset import prepare_data,\ + prepare_onnx_input,\ + get_phoneme_labels,\ + get_char_phoneme_labels +from paddlespeech.t2s.frontend.g2pw.utils import load_config +from paddlespeech.cli.utils import download_and_decompress +from paddlespeech.resource.pretrained_models import g2pw_onnx_models + + +def predict(session, onnx_input, labels): + all_preds = [] + all_confidences = [] + probs = session.run([],{"input_ids": onnx_input['input_ids'], + "token_type_ids":onnx_input['token_type_ids'], + "attention_mask":onnx_input['attention_masks'], + "phoneme_mask":onnx_input['phoneme_masks'], + "char_ids":onnx_input['char_ids'], + "position_ids":onnx_input['position_ids']})[0] + + preds = np.argmax(probs,axis=1).tolist() + max_probs = [] + for index,arr in zip(preds,probs.tolist()): + max_probs.append(arr[index]) + all_preds += [labels[pred] for pred in preds] + all_confidences += max_probs + + return all_preds, all_confidences + + +class G2PWOnnxConverter: + def __init__(self, model_dir = MODEL_HOME, style='bopomofo', model_source=None, enable_non_tradional_chinese=False): + if not os.path.exists(os.path.join(model_dir, 'G2PWModel/g2pW.onnx')): + uncompress_path = download_and_decompress(g2pw_onnx_models['G2PWModel']['1.0'],model_dir) + + sess_options = onnxruntime.SessionOptions() + sess_options.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL + sess_options.execution_mode = onnxruntime.ExecutionMode.ORT_SEQUENTIAL + sess_options.intra_op_num_threads = 2 + self.session_g2pW = onnxruntime.InferenceSession(os.path.join(model_dir, 'G2PWModel/g2pW.onnx'),sess_options=sess_options) + self.config = load_config(os.path.join(model_dir, 'G2PWModel/config.py'), use_default=True) + + self.model_source = model_source if model_source else self.config.model_source + self.enable_opencc = enable_non_tradional_chinese + + self.tokenizer = BertTokenizer.from_pretrained(self.config.model_source) + + polyphonic_chars_path = os.path.join(model_dir, 'G2PWModel/POLYPHONIC_CHARS.txt') + monophonic_chars_path = os.path.join(model_dir, 'G2PWModel/MONOPHONIC_CHARS.txt') + self.polyphonic_chars = [line.split('\t') for line in open(polyphonic_chars_path,encoding='utf-8').read().strip().split('\n')] + self.monophonic_chars = [line.split('\t') for line in open(monophonic_chars_path,encoding='utf-8').read().strip().split('\n')] + self.labels, self.char2phonemes = get_char_phoneme_labels(self.polyphonic_chars) if self.config.use_char_phoneme else get_phoneme_labels(self.polyphonic_chars) + + self.chars = sorted(list(self.char2phonemes.keys())) + self.pos_tags = ['UNK', 'A', 'C', 'D', 'I', 'N', 'P', 'T', 'V', 'DE', 'SHI'] + + with open(os.path.join(model_dir,'G2PWModel/bopomofo_to_pinyin_wo_tune_dict.json'), 'r',encoding='utf-8') as fr: + self.bopomofo_convert_dict = json.load(fr) + self.style_convert_func = { + 'bopomofo': lambda x: x, + 'pinyin': self._convert_bopomofo_to_pinyin, + }[style] + + with open(os.path.join(model_dir,'G2PWModel/char_bopomofo_dict.json'), 'r',encoding='utf-8') as fr: + self.char_bopomofo_dict = json.load(fr) + + if self.enable_opencc: + self.cc = OpenCC('s2tw') + + def _convert_bopomofo_to_pinyin(self, bopomofo): + tone = bopomofo[-1] + assert tone in '12345' + component = self.bopomofo_convert_dict.get(bopomofo[:-1]) + if component: + return component + tone + else: + print(f'Warning: "{bopomofo}" cannot convert to pinyin') + return None + + def __call__(self, sentences): + if isinstance(sentences, str): + sentences = [sentences] + + if self.enable_opencc: + translated_sentences = [] + for sent in sentences: + translated_sent = self.cc.convert(sent) + assert len(translated_sent) == len(sent) + translated_sentences.append(translated_sent) + sentences = translated_sentences + + texts, query_ids, sent_ids, partial_results = self._prepare_data(sentences) + if len(texts) == 0: + # sentences no polyphonic words + return partial_results + + onnx_input = prepare_onnx_input(self.tokenizer, self.labels, self.char2phonemes, self.chars, texts, query_ids, + use_mask=self.config.use_mask, use_char_phoneme=self.config.use_char_phoneme, + window_size=None) + + preds, confidences = predict(self.session_g2pW, onnx_input, self.labels) + if self.config.use_char_phoneme: + preds = [pred.split(' ')[1] for pred in preds] + + results = partial_results + for sent_id, query_id, pred in zip(sent_ids, query_ids, preds): + results[sent_id][query_id] = self.style_convert_func(pred) + + return results + + def _prepare_data(self, sentences): + polyphonic_chars = set(self.chars) + monophonic_chars_dict = { + char: phoneme for char, phoneme in self.monophonic_chars + } + texts, query_ids, sent_ids, partial_results = [], [], [], [] + for sent_id, sent in enumerate(sentences): + pypinyin_result = pinyin(sent,style=Style.TONE3) + partial_result = [None] * len(sent) + for i, char in enumerate(sent): + if char in polyphonic_chars: + texts.append(sent) + query_ids.append(i) + sent_ids.append(sent_id) + elif char in monophonic_chars_dict: + partial_result[i] = self.style_convert_func(monophonic_chars_dict[char]) + elif char in self.char_bopomofo_dict: + partial_result[i] = pypinyin_result[i][0] + # partial_result[i] = self.style_convert_func(self.char_bopomofo_dict[char][0]) + partial_results.append(partial_result) + return texts, query_ids, sent_ids, partial_results diff --git a/paddlespeech/t2s/frontend/g2pw/utils.py b/paddlespeech/t2s/frontend/g2pw/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..771e900795dcadb0286ff2437582b19915cc36be --- /dev/null +++ b/paddlespeech/t2s/frontend/g2pw/utils.py @@ -0,0 +1,133 @@ + +""" +Credits + This code is modified from https://github.com/GitYCC/g2pW +""" +import re +import sys + +def wordize_and_map(text): + words = [] + index_map_from_text_to_word = [] + index_map_from_word_to_text = [] + while len(text) > 0: + match_space = re.match(r'^ +', text) + if match_space: + space_str = match_space.group(0) + index_map_from_text_to_word += [None] * len(space_str) + text = text[len(space_str):] + continue + + match_en = re.match(r'^[a-zA-Z0-9]+', text) + if match_en: + en_word = match_en.group(0) + + word_start_pos = len(index_map_from_text_to_word) + word_end_pos = word_start_pos + len(en_word) + index_map_from_word_to_text.append((word_start_pos, word_end_pos)) + + index_map_from_text_to_word += [len(words)] * len(en_word) + + words.append(en_word) + text = text[len(en_word):] + else: + word_start_pos = len(index_map_from_text_to_word) + word_end_pos = word_start_pos + 1 + index_map_from_word_to_text.append((word_start_pos, word_end_pos)) + + index_map_from_text_to_word += [len(words)] + + words.append(text[0]) + text = text[1:] + return words, index_map_from_text_to_word, index_map_from_word_to_text + + +def tokenize_and_map(tokenizer, text): + words, text2word, word2text = wordize_and_map(text) + + tokens = [] + index_map_from_token_to_text = [] + for word, (word_start, word_end) in zip(words, word2text): + word_tokens = tokenizer.tokenize(word) + + if len(word_tokens) == 0 or word_tokens == ['[UNK]']: + index_map_from_token_to_text.append((word_start, word_end)) + tokens.append('[UNK]') + else: + current_word_start = word_start + for word_token in word_tokens: + word_token_len = len(re.sub(r'^##', '', word_token)) + index_map_from_token_to_text.append( + (current_word_start, current_word_start + word_token_len)) + current_word_start = current_word_start + word_token_len + tokens.append(word_token) + + index_map_from_text_to_token = text2word + for i, (token_start, token_end) in enumerate(index_map_from_token_to_text): + for token_pos in range(token_start, token_end): + index_map_from_text_to_token[token_pos] = i + + return tokens, index_map_from_text_to_token, index_map_from_token_to_text + + +def _load_config(config_path): + import importlib.util + spec = importlib.util.spec_from_file_location('__init__', config_path) + config = importlib.util.module_from_spec(spec) + spec.loader.exec_module(config) + return config + + +default_config_dict = { + 'manual_seed': 1313, + 'model_source': 'bert-base-chinese', + 'window_size': 32, + 'num_workers': 2, + 'use_mask': True, + 'use_char_phoneme': False, + 'use_conditional': True, + 'param_conditional': { + 'affect_location': 'softmax', + 'bias': True, + 'char-linear': True, + 'pos-linear': False, + 'char+pos-second': True, + + 'char+pos-second_lowrank': False, + 'lowrank_size': 0, + 'char+pos-second_fm': False, + 'fm_size': 0, + 'fix_mode': None, + 'count_json': 'train.count.json' + }, + 'lr': 5e-5, + 'val_interval': 200, + 'num_iter': 10000, + 'use_focal': False, + 'param_focal': { + 'alpha': 0.0, + 'gamma': 0.7 + }, + 'use_pos': True, + 'param_pos ': { + 'weight': 0.1, + 'pos_joint_training': True, + 'train_pos_path': 'train.pos', + 'valid_pos_path': 'dev.pos', + 'test_pos_path': 'test.pos' + } +} + + +def load_config(config_path, use_default=False): + config = _load_config(config_path) + if use_default: + for attr, val in default_config_dict.items(): + if not hasattr(config, attr): + setattr(config, attr, val) + elif isinstance(val, dict): + d = getattr(config, attr) + for dict_k, dict_v in val.items(): + if dict_k not in d: + d[dict_k] = dict_v + return config \ No newline at end of file diff --git a/paddlespeech/t2s/frontend/polyphonic.yaml b/paddlespeech/t2s/frontend/polyphonic.yaml new file mode 100644 index 0000000000000000000000000000000000000000..629bcd262f55881b87e55dcfbba85fc38fff8f68 --- /dev/null +++ b/paddlespeech/t2s/frontend/polyphonic.yaml @@ -0,0 +1,26 @@ +polyphonic: + 湖泊: ['hu2','po1'] + 地壳: ['di4','qiao4'] + 柏树: ['bai3','shu4'] + 曝光: ['bao4','guang1'] + 弹力: ['tan2','li4'] + 字帖: ['zi4','tie4'] + 口吃: ['kou3','chi1'] + 包扎: ['bao1','za1'] + 哪吒: ['ne2','zha1'] + 说服: ['shuo1','fu2'] + 识字: ['shi2','zi4'] + 骨头: ['gu3','tou5'] + 对称: ['dui4','chen4'] + 口供: ['kou3','gong4'] + 抹布: ['ma1','bu4'] + 露背: ['lu4','bei4'] + 圈养: ['juan4', 'yang3'] + 眼眶: ['yan3', 'kuang4'] + 品行: ['pin3','xing2'] + 颤抖: ['chan4','dou3'] + 差不多: ['cha4','bu5','duo1'] + 鸭绿江: ['ya1','lu4','jiang1'] + 撒切尔: ['sa4','qie4','er3'] + 比比皆是: ['bi3','bi3','jie1','shi4'] + 身无长物: ['shen1','wu2','chang2','wu4'] \ No newline at end of file diff --git a/paddlespeech/t2s/frontend/zh_frontend.py b/paddlespeech/t2s/frontend/zh_frontend.py index ef8963c085921e9a40aad57eb2d41bbc01a70281..e612999921df32485187911120d7e4fa28427328 100644 --- a/paddlespeech/t2s/frontend/zh_frontend.py +++ b/paddlespeech/t2s/frontend/zh_frontend.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. import re +import os +import yaml from typing import Dict from typing import List @@ -25,6 +27,7 @@ from pypinyin import load_single_dict from pypinyin import Style from pypinyin_dict.phrase_pinyin_data import large_pinyin +from paddlespeech.t2s.frontend.g2pw import G2PWOnnxConverter from paddlespeech.t2s.frontend.generate_lexicon import generate_lexicon from paddlespeech.t2s.frontend.tone_sandhi import ToneSandhi from paddlespeech.t2s.frontend.zh_normalization.text_normlization import TextNormalizer @@ -53,20 +56,42 @@ def insert_after_character(lst, item): return result +class Polyphonic(): + def __init__(self): + with open(os.path.join(os.path.dirname(os.path.abspath(__file__)), + 'polyphonic.yaml'), 'r',encoding='utf-8') as polyphonic_file: + # 解析yaml + polyphonic_dict = yaml.load(polyphonic_file, Loader=yaml.FullLoader) + self.polyphonic_words = polyphonic_dict["polyphonic"] + + def correct_pronunciation(self,word,pinyin): + # 词汇被词典收录则返回纠正后的读音 + if word in self.polyphonic_words.keys(): + pinyin = self.polyphonic_words[word] + # 否则返回原读音 + return pinyin + class Frontend(): def __init__(self, - g2p_model="pypinyin", + g2p_model="g2pW", phone_vocab_path=None, tone_vocab_path=None): self.tone_modifier = ToneSandhi() self.text_normalizer = TextNormalizer() self.punc = ":,;。?!“”‘’':,;.?!" - # g2p_model can be pypinyin and g2pM + # g2p_model can be pypinyin and g2pM and g2pW self.g2p_model = g2p_model if self.g2p_model == "g2pM": self.g2pM_model = G2pM() self.pinyin2phone = generate_lexicon( with_tone=True, with_erhua=False) + elif self.g2p_model == "g2pW": + self.corrector = Polyphonic() + self.g2pM_model = G2pM() + self.g2pW_model = G2PWOnnxConverter(style='pinyin', enable_non_tradional_chinese=True) + self.pinyin2phone = generate_lexicon( + with_tone=True, with_erhua=False) + else: self.__init__pypinyin() self.must_erhua = {"小院儿", "胡同儿", "范儿", "老汉儿", "撒欢儿", "寻老礼儿", "妥妥儿"} @@ -156,18 +181,63 @@ class Frontend(): initials = [] finals = [] seg_cut = self.tone_modifier.pre_merge_for_modify(seg_cut) - for word, pos in seg_cut: - if pos == 'eng': - continue - sub_initials, sub_finals = self._get_initials_finals(word) - sub_finals = self.tone_modifier.modified_tone(word, pos, - sub_finals) - if with_erhua: - sub_initials, sub_finals = self._merge_erhua( - sub_initials, sub_finals, word, pos) - initials.append(sub_initials) - finals.append(sub_finals) - # assert len(sub_initials) == len(sub_finals) == len(word) + # 为了多音词获得更好的效果,这里采用整句预测 + if self.g2p_model == "g2pW": + try: + pinyins = self.g2pW_model(seg)[0] + except Exception: + # g2pW采用模型采用繁体输入,如果有cover不了的简体词,采用g2pM预测 + print("[%s] not in g2pW dict,use g2pM"%seg) + pinyins = self.g2pM_model(seg, tone=True, char_split=False) + pre_word_length = 0 + for word, pos in seg_cut: + sub_initials = [] + sub_finals = [] + now_word_length = pre_word_length + len(word) + if pos == 'eng': + pre_word_length = now_word_length + continue + word_pinyins = pinyins[pre_word_length:now_word_length] + # 矫正发音 + word_pinyins = self.corrector.correct_pronunciation(word,word_pinyins) + for pinyin,char in zip(word_pinyins,word): + if pinyin == None: + pinyin = char + pinyin = pinyin.replace("u:", "v") + if pinyin in self.pinyin2phone: + initial_final_list = self.pinyin2phone[pinyin].split(" ") + if len(initial_final_list) == 2: + sub_initials.append(initial_final_list[0]) + sub_finals.append(initial_final_list[1]) + elif len(initial_final_list) == 1: + sub_initials.append('') + sub_finals.append(initial_final_list[1]) + else: + # If it's not pinyin (possibly punctuation) or no conversion is required + sub_initials.append(pinyin) + sub_finals.append(pinyin) + pre_word_length = now_word_length + sub_finals = self.tone_modifier.modified_tone(word, pos, + sub_finals) + if with_erhua: + sub_initials, sub_finals = self._merge_erhua( + sub_initials, sub_finals, word, pos) + initials.append(sub_initials) + finals.append(sub_finals) + # assert len(sub_initials) == len(sub_finals) == len(word) + else: + for word, pos in seg_cut: + if pos == 'eng': + continue + sub_initials, sub_finals = self._get_initials_finals(word) + sub_finals = self.tone_modifier.modified_tone(word, pos, + sub_finals) + if with_erhua: + sub_initials, sub_finals = self._merge_erhua( + sub_initials, sub_finals, word, pos) + initials.append(sub_initials) + finals.append(sub_finals) + # assert len(sub_initials) == len(sub_finals) == len(word) initials = sum(initials, []) finals = sum(finals, []) diff --git a/setup.py b/setup.py index 56d9e19e8ae7e012c9a281368733ddb44bb11539..079803b7effb514a5c9a589afd9bd2385496cd75 100644 --- a/setup.py +++ b/setup.py @@ -45,6 +45,7 @@ base = [ "matplotlib", "nara_wpe", "onnxruntime==1.10.0", + "opencc", "pandas", "paddlenlp", "paddlespeech_feat", @@ -329,4 +330,4 @@ setup_info = dict( }) with version_info(): - setup(**setup_info) + setup(**setup_info,include_package_data=True)