diff --git a/paddlespeech/cli/tts/infer.py b/paddlespeech/cli/tts/infer.py
index eb68df5d616653000178ad5846a503cf092e5afd..fdf93e2c27180d6ffd87abec57f1cdd5e28e2026 100644
--- a/paddlespeech/cli/tts/infer.py
+++ b/paddlespeech/cli/tts/infer.py
@@ -405,8 +405,6 @@ class TTSExecutor(BaseExecutor):
         with open(self.voc_config) as f:
             self.voc_config = CfgNode(yaml.safe_load(f))
 
-        # Enter the path of model root
-
         with open(self.phones_dict, "r") as f:
             phn_id = [line.strip().split() for line in f.readlines()]
         vocab_size = len(phn_id)
@@ -501,10 +499,10 @@ class TTSExecutor(BaseExecutor):
         """
         Model inference and result stored in self.output.
         """
-        model_name = am[:am.rindex('_')]
-        dataset = am[am.rindex('_') + 1:]
+        am_name = am[:am.rindex('_')]
+        am_dataset = am[am.rindex('_') + 1:]
         get_tone_ids = False
-        if 'speedyspeech' in model_name:
+        if am_name == 'speedyspeech':
             get_tone_ids = True
         if lang == 'zh':
             input_ids = self.frontend.get_input_ids(
@@ -521,15 +519,15 @@ class TTSExecutor(BaseExecutor):
             print("lang should in {'zh', 'en'}!")
 
         # am
-        if 'speedyspeech' in model_name:
+        if am_name == 'speedyspeech':
             mel = self.am_inference(phone_ids, tone_ids)
         # fastspeech2
         else:
             # multi speaker
-            if dataset in {"aishell3", "vctk"}:
+            if am_dataset in {"aishell3", "vctk"}:
                 mel = self.am_inference(
                     phone_ids, spk_id=paddle.to_tensor(spk_id))
+            # single speaker: keep the else branch so `mel` is always assigned
             else:
                 mel = self.am_inference(phone_ids)
 
 