From adc7c9b4aac2825e0abe56df9566a02d6d85b969 Mon Sep 17 00:00:00 2001 From: KP <109694228@qq.com> Date: Wed, 29 Jun 2022 17:57:21 +0800 Subject: [PATCH] Fix unnecessary download present in issue #2067. --- paddlespeech/cli/tts/infer.py | 17 +++++++++++++++-- paddlespeech/resource/resource.py | 11 +++++++---- 2 files changed, 22 insertions(+), 6 deletions(-) diff --git a/paddlespeech/cli/tts/infer.py b/paddlespeech/cli/tts/infer.py index 4e0337bc..5468f257 100644 --- a/paddlespeech/cli/tts/infer.py +++ b/paddlespeech/cli/tts/infer.py @@ -175,14 +175,21 @@ class TTSExecutor(BaseExecutor): if hasattr(self, 'am_inference') and hasattr(self, 'voc_inference'): logger.info('Models had been initialized.') return + # am + if am_ckpt is None or am_config is None or am_stat is None or phones_dict is None: + use_pretrained_am = True + else: + use_pretrained_am = False + am_tag = am + '-' + lang self.task_resource.set_task_model( model_tag=am_tag, model_type=0, # am + skip_download=not use_pretrained_am, version=None, # default version ) - if am_ckpt is None or am_config is None or am_stat is None or phones_dict is None: + if use_pretrained_am: self.am_res_path = self.task_resource.res_dir self.am_config = os.path.join(self.am_res_path, self.task_resource.res_dict['config']) @@ -220,13 +227,19 @@ class TTSExecutor(BaseExecutor): self.speaker_dict = speaker_dict # voc + if voc_ckpt is None or voc_config is None or voc_stat is None: + use_pretrained_voc = True + else: + use_pretrained_voc = False + voc_tag = voc + '-' + lang self.task_resource.set_task_model( model_tag=voc_tag, model_type=1, # vocoder + skip_download=not use_pretrained_voc, version=None, # default version ) - if voc_ckpt is None or voc_config is None or voc_stat is None: + if use_pretrained_voc: self.voc_res_path = self.task_resource.voc_res_dir self.voc_config = os.path.join( self.voc_res_path, self.task_resource.voc_res_dict['config']) diff --git a/paddlespeech/resource/resource.py b/paddlespeech/resource/resource.py index 70f12b64..8e9914b2 100644 --- a/paddlespeech/resource/resource.py +++ b/paddlespeech/resource/resource.py @@ -60,6 +60,7 @@ class CommonTaskResource: def set_task_model(self, model_tag: str, model_type: int=0, + skip_download: bool=False, version: Optional[str]=None): """Set model tag and version of current task. @@ -83,16 +84,18 @@ class CommonTaskResource: self.version = version self.res_dict = self.pretrained_models[model_tag][version] self._format_path(self.res_dict) - self.res_dir = self._fetch(self.res_dict, - self._get_model_dir(model_type)) + if not skip_download: + self.res_dir = self._fetch(self.res_dict, + self._get_model_dir(model_type)) else: assert self.task == 'tts', 'Vocoder will only be used in tts task.' self.voc_model_tag = model_tag self.voc_version = version self.voc_res_dict = self.pretrained_models[model_tag][version] self._format_path(self.voc_res_dict) - self.voc_res_dir = self._fetch(self.voc_res_dict, - self._get_model_dir(model_type)) + if not skip_download: + self.voc_res_dir = self._fetch(self.voc_res_dict, + self._get_model_dir(model_type)) @staticmethod def get_model_class(model_name) -> List[object]: -- GitLab