diff --git a/paddlespeech/cli/tts/infer.py b/paddlespeech/cli/tts/infer.py index 4e0337bccea500c382f0782860cec36ad4897c46..5468f257f220770f5a6cb152351cfa221ec34ef5 100644 --- a/paddlespeech/cli/tts/infer.py +++ b/paddlespeech/cli/tts/infer.py @@ -175,14 +175,21 @@ class TTSExecutor(BaseExecutor): if hasattr(self, 'am_inference') and hasattr(self, 'voc_inference'): logger.info('Models had been initialized.') return + # am + if am_ckpt is None or am_config is None or am_stat is None or phones_dict is None: + use_pretrained_am = True + else: + use_pretrained_am = False + am_tag = am + '-' + lang self.task_resource.set_task_model( model_tag=am_tag, model_type=0, # am + skip_download=not use_pretrained_am, version=None, # default version ) - if am_ckpt is None or am_config is None or am_stat is None or phones_dict is None: + if use_pretrained_am: self.am_res_path = self.task_resource.res_dir self.am_config = os.path.join(self.am_res_path, self.task_resource.res_dict['config']) @@ -220,13 +227,19 @@ class TTSExecutor(BaseExecutor): self.speaker_dict = speaker_dict # voc + if voc_ckpt is None or voc_config is None or voc_stat is None: + use_pretrained_voc = True + else: + use_pretrained_voc = False + voc_tag = voc + '-' + lang self.task_resource.set_task_model( model_tag=voc_tag, model_type=1, # vocoder + skip_download=not use_pretrained_voc, version=None, # default version ) - if voc_ckpt is None or voc_config is None or voc_stat is None: + if use_pretrained_voc: self.voc_res_path = self.task_resource.voc_res_dir self.voc_config = os.path.join( self.voc_res_path, self.task_resource.voc_res_dict['config']) diff --git a/paddlespeech/resource/resource.py b/paddlespeech/resource/resource.py index 70f12b64c2dc5bbf6ef508b41872e0504855d6fb..8e9914b2e13912d34413f92ff042cd1f3cbd95d0 100644 --- a/paddlespeech/resource/resource.py +++ b/paddlespeech/resource/resource.py @@ -60,6 +60,7 @@ class CommonTaskResource: def set_task_model(self, model_tag: str, model_type: int=0, + skip_download: bool=False, version: Optional[str]=None): """Set model tag and version of current task. @@ -83,16 +84,18 @@ class CommonTaskResource: self.version = version self.res_dict = self.pretrained_models[model_tag][version] self._format_path(self.res_dict) - self.res_dir = self._fetch(self.res_dict, - self._get_model_dir(model_type)) + if not skip_download: + self.res_dir = self._fetch(self.res_dict, + self._get_model_dir(model_type)) else: assert self.task == 'tts', 'Vocoder will only be used in tts task.' self.voc_model_tag = model_tag self.voc_version = version self.voc_res_dict = self.pretrained_models[model_tag][version] self._format_path(self.voc_res_dict) - self.voc_res_dir = self._fetch(self.voc_res_dict, - self._get_model_dir(model_type)) + if not skip_download: + self.voc_res_dir = self._fetch(self.voc_res_dict, + self._get_model_dir(model_type)) @staticmethod def get_model_class(model_name) -> List[object]: