diff --git a/paddlespeech/cli/tts/infer.py b/paddlespeech/cli/tts/infer.py
index 2684e9edfddc8974390f5ce8a64bd766d4dd077e..65d6d5282590ba55a2dcb52fe1dd3a1b01ba103d 100644
--- a/paddlespeech/cli/tts/infer.py
+++ b/paddlespeech/cli/tts/infer.py
@@ -403,8 +403,6 @@ class TTSExecutor(BaseExecutor):
         with open(self.voc_config) as f:
             self.voc_config = CfgNode(yaml.safe_load(f))
 
-        # Enter the path of model root
-
         with open(self.phones_dict, "r") as f:
             phn_id = [line.strip().split() for line in f.readlines()]
         vocab_size = len(phn_id)
@@ -499,10 +497,10 @@ class TTSExecutor(BaseExecutor):
         """
         Model inference and result stored in self.output.
         """
-        model_name = am[:am.rindex('_')]
-        dataset = am[am.rindex('_') + 1:]
+        am_name = am[:am.rindex('_')]
+        am_dataset = am[am.rindex('_') + 1:]
         get_tone_ids = False
-        if 'speedyspeech' in model_name:
+        if am_name == 'speedyspeech':
             get_tone_ids = True
         if lang == 'zh':
             input_ids = self.frontend.get_input_ids(
@@ -519,15 +517,14 @@ class TTSExecutor(BaseExecutor):
             print("lang should in {'zh', 'en'}!")
 
         # am
-        if 'speedyspeech' in model_name:
+        if am_name == 'speedyspeech':
             mel = self.am_inference(phone_ids, tone_ids)
         # fastspeech2
         else:
             # multi speaker
-            if dataset in {"aishell3", "vctk"}:
+            if am_dataset in {"aishell3", "vctk"}:
                 mel = self.am_inference(
                     phone_ids, spk_id=paddle.to_tensor(spk_id))
-
             else:
                 mel = self.am_inference(phone_ids)
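
Note on the rename (a minimal standalone sketch, not part of the patch): `am` tags in this CLI follow a `{am_name}_{am_dataset}` convention (e.g. `fastspeech2_csmsc`). The `split_am_tag` helper below is hypothetical, added only to illustrate what the two `rindex('_')` slices produce and why the exact match `am_name == 'speedyspeech'` is safer than the old substring test `'speedyspeech' in model_name`.

```python
# Hypothetical helper mirroring the two rindex('_') slices in infer().
def split_am_tag(am: str):
    """Split an acoustic-model tag into its model-name and dataset parts."""
    am_name = am[:am.rindex('_')]         # e.g. "fastspeech2"
    am_dataset = am[am.rindex('_') + 1:]  # e.g. "csmsc"
    return am_name, am_dataset

# Illustrative tags only; the real set comes from the pretrained-model table.
for tag in ["speedyspeech_csmsc", "fastspeech2_aishell3", "fastspeech2_vctk"]:
    am_name, am_dataset = split_am_tag(tag)
    get_tone_ids = am_name == 'speedyspeech'            # exact match, not substring
    multi_speaker = am_dataset in {"aishell3", "vctk"}  # speaker-id models
    print(f"{tag}: tones={get_tone_ids}, multi_speaker={multi_speaker}")
```

The exact comparison also guards against false positives: a hypothetical future tag like `speedyspeech_v2_csmsc` would split into `am_name == 'speedyspeech_v2'`, which the old `'speedyspeech' in model_name` test would still (wrongly) match.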