diff --git a/paddlespeech/cli/asr/infer.py b/paddlespeech/cli/asr/infer.py
index 09e8202fd7d54e59b9c535b6b5a598123147f539..842acf5ce65a00057fad00def8050ca7dcaee48e 100644
--- a/paddlespeech/cli/asr/infer.py
+++ b/paddlespeech/cli/asr/infer.py
@@ -33,11 +33,8 @@ from ..utils import CLI_TIMER
 from ..utils import MODEL_HOME
 from ..utils import stats_wrapper
 from ..utils import timer_register
-from .pretrained_models import model_alias
-from .pretrained_models import pretrained_models
 from paddlespeech.s2t.frontend.featurizer.text_featurizer import TextFeaturizer
 from paddlespeech.s2t.transform.transformation import Transformation
-from paddlespeech.s2t.utils.dynamic_import import dynamic_import
 from paddlespeech.s2t.utils.utility import UpdateConfig
 
 __all__ = ['ASRExecutor']
@@ -46,10 +43,7 @@ __all__ = ['ASRExecutor']
 @timer_register
 class ASRExecutor(BaseExecutor):
     def __init__(self):
-        super().__init__()
-        self.model_alias = model_alias
-        self.pretrained_models = pretrained_models
-
+        super().__init__(task='asr', inference_type='offline')
         self.parser = argparse.ArgumentParser(
             prog='paddlespeech.asr', add_help=True)
         self.parser.add_argument(
@@ -59,7 +53,8 @@ class ASRExecutor(BaseExecutor):
             type=str,
             default='conformer_wenetspeech',
             choices=[
-                tag[:tag.index('-')] for tag in self.pretrained_models.keys()
+                tag[:tag.index('-')]
+                for tag in self.task_resource.pretrained_models.keys()
             ],
             help='Choose model type of asr task.')
         self.parser.add_argument(
@@ -141,14 +136,14 @@ class ASRExecutor(BaseExecutor):
         if cfg_path is None or ckpt_path is None:
             sample_rate_str = '16k' if sample_rate == 16000 else '8k'
             tag = model_type + '-' + lang + '-' + sample_rate_str
-            res_path = self._get_pretrained_path(tag)  # wenetspeech_zh
-            self.res_path = res_path
+            self.task_resource.set_task_model(tag, version=None)
+            self.res_path = self.task_resource.res_dir
             self.cfg_path = os.path.join(
-                res_path, self.pretrained_models[tag]['cfg_path'])
+                self.res_path, self.task_resource.res_dict['cfg_path'])
             self.ckpt_path = os.path.join(
-                res_path,
-                self.pretrained_models[tag]['ckpt_path'] + ".pdparams")
-            logger.info(res_path)
+                self.res_path,
+                self.task_resource.res_dict['ckpt_path'] + ".pdparams")
+            logger.info(self.res_path)
         else:
             self.cfg_path = os.path.abspath(cfg_path)
@@ -172,8 +167,8 @@ class ASRExecutor(BaseExecutor):
             self.collate_fn_test = SpeechCollator.from_config(self.config)
             self.text_feature = TextFeaturizer(
                 unit_type=self.config.unit_type, vocab=self.vocab)
-            lm_url = self.pretrained_models[tag]['lm_url']
-            lm_md5 = self.pretrained_models[tag]['lm_md5']
+            lm_url = self.task_resource.res_dict['lm_url']
+            lm_md5 = self.task_resource.res_dict['lm_md5']
             self.download_lm(
                 lm_url,
                 os.path.dirname(self.config.decode.lang_model_path), lm_md5)
@@ -191,7 +186,7 @@ class ASRExecutor(BaseExecutor):
             raise Exception("wrong type")
         model_name = model_type[:model_type.rindex(
             '_')]  # model_type: {model_name}_{dataset}
-        model_class = dynamic_import(model_name, self.model_alias)
+        model_class = self.task_resource.get_model_class(model_name)
         model_conf = self.config
         model = model_class.from_config(model_conf)
         self.model = model
@@ -438,7 +433,7 @@ class ASRExecutor(BaseExecutor):
         if not parser_args.verbose:
             self.disable_task_loggers()
 
-        task_source = self.get_task_source(parser_args.input)
+        task_source = self.get_input_source(parser_args.input)
         task_results = OrderedDict()
         has_exceptions = False
 
diff --git a/paddlespeech/cli/asr/pretrained_models.py b/paddlespeech/cli/asr/pretrained_models.py
deleted file mode 100644
index 0f521884020b039a074ad302100a58a59e4d77b1..0000000000000000000000000000000000000000
--- a/paddlespeech/cli/asr/pretrained_models.py
+++ /dev/null
@@ -1,151 +0,0 @@
-# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-pretrained_models = {
-    # The tags for pretrained_models should be "{model_name}[_{dataset}][-{lang}][-...]".
-    # e.g. "conformer_wenetspeech-zh-16k" and "panns_cnn6-32k".
-    # Command line and python api use "{model_name}[_{dataset}]" as --model, usage:
-    # "paddlespeech asr --model conformer_wenetspeech --lang zh --sr 16000 --input ./input.wav"
-    "conformer_wenetspeech-zh-16k": {
-        'url':
-        'https://paddlespeech.bj.bcebos.com/s2t/wenetspeech/asr1_conformer_wenetspeech_ckpt_0.1.1.model.tar.gz',
-        'md5':
-        '76cb19ed857e6623856b7cd7ebbfeda4',
-        'cfg_path':
-        'model.yaml',
-        'ckpt_path':
-        'exp/conformer/checkpoints/wenetspeech',
-    },
-    "conformer_online_wenetspeech-zh-16k": {
-        'url':
-        'https://paddlespeech.bj.bcebos.com/s2t/wenetspeech/asr1/asr1_chunk_conformer_wenetspeech_ckpt_1.0.0a.model.tar.gz',
-        'md5':
-        'b8c02632b04da34aca88459835be54a6',
-        'cfg_path':
-        'model.yaml',
-        'ckpt_path':
-        'exp/chunk_conformer/checkpoints/avg_10',
-    },
-    "conformer_online_multicn-zh-16k": {
-        'url':
-        'https://paddlespeech.bj.bcebos.com/s2t/multi_cn/asr1/asr1_chunk_conformer_multi_cn_ckpt_0.2.0.model.tar.gz',
-        'md5':
-        '7989b3248c898070904cf042fd656003',
-        'cfg_path':
-        'model.yaml',
-        'ckpt_path':
-        'exp/chunk_conformer/checkpoints/multi_cn',
-    },
-    "conformer_aishell-zh-16k": {
-        'url':
-        'https://paddlespeech.bj.bcebos.com/s2t/aishell/asr1/asr1_conformer_aishell_ckpt_0.1.2.model.tar.gz',
-        'md5':
-        '3f073eccfa7bb14e0c6867d65fc0dc3a',
-        'cfg_path':
-        'model.yaml',
-        'ckpt_path':
-        'exp/conformer/checkpoints/avg_30',
-    },
-    "conformer_online_aishell-zh-16k": {
-        'url':
-        'https://paddlespeech.bj.bcebos.com/s2t/aishell/asr1/asr1_chunk_conformer_aishell_ckpt_0.2.0.model.tar.gz',
-        'md5':
-        'b374cfb93537761270b6224fb0bfc26a',
-        'cfg_path':
-        'model.yaml',
-        'ckpt_path':
-        'exp/chunk_conformer/checkpoints/avg_30',
-    },
-    "transformer_librispeech-en-16k": {
-        'url':
-        'https://paddlespeech.bj.bcebos.com/s2t/librispeech/asr1/asr1_transformer_librispeech_ckpt_0.1.1.model.tar.gz',
-        'md5':
-        '2c667da24922aad391eacafe37bc1660',
-        'cfg_path':
-        'model.yaml',
-        'ckpt_path':
-        'exp/transformer/checkpoints/avg_10',
-    },
-    "deepspeech2online_wenetspeech-zh-16k": {
-        'url':
-        'https://paddlespeech.bj.bcebos.com/s2t/wenetspeech/asr0/asr0_deepspeech2_online_wenetspeech_ckpt_1.0.0a.model.tar.gz',
-        'md5':
-        'e393d4d274af0f6967db24fc146e8074',
-        'cfg_path':
-        'model.yaml',
-        'ckpt_path':
-        'exp/deepspeech2_online/checkpoints/avg_10',
-        'lm_url':
-        'https://deepspeech.bj.bcebos.com/zh_lm/zh_giga.no_cna_cmn.prune01244.klm',
-        'lm_md5':
-        '29e02312deb2e59b3c8686c7966d4fe3'
-    },
-    "deepspeech2offline_aishell-zh-16k": {
-        'url':
-        'https://paddlespeech.bj.bcebos.com/s2t/aishell/asr0/asr0_deepspeech2_aishell_ckpt_0.1.1.model.tar.gz',
-        'md5':
-        '932c3593d62fe5c741b59b31318aa314',
-        'cfg_path':
-        'model.yaml',
-        'ckpt_path':
-        'exp/deepspeech2/checkpoints/avg_1',
-        'lm_url':
-        'https://deepspeech.bj.bcebos.com/zh_lm/zh_giga.no_cna_cmn.prune01244.klm',
-        'lm_md5':
-        '29e02312deb2e59b3c8686c7966d4fe3'
-    },
-    "deepspeech2online_aishell-zh-16k": {
-        'url':
-        'https://paddlespeech.bj.bcebos.com/s2t/aishell/asr0/asr0_deepspeech2_online_aishell_fbank161_ckpt_0.2.1.model.tar.gz',
-        'md5':
-        '98b87b171b7240b7cae6e07d8d0bc9be',
-        'cfg_path':
-        'model.yaml',
-        'ckpt_path':
-        'exp/deepspeech2_online/checkpoints/avg_1',
-        'lm_url':
-        'https://deepspeech.bj.bcebos.com/zh_lm/zh_giga.no_cna_cmn.prune01244.klm',
-        'lm_md5':
-        '29e02312deb2e59b3c8686c7966d4fe3'
-    },
-    "deepspeech2offline_librispeech-en-16k": {
-        'url':
-        'https://paddlespeech.bj.bcebos.com/s2t/librispeech/asr0/asr0_deepspeech2_librispeech_ckpt_0.1.1.model.tar.gz',
-        'md5':
-        'f5666c81ad015c8de03aac2bc92e5762',
-        'cfg_path':
-        'model.yaml',
-        'ckpt_path':
-        'exp/deepspeech2/checkpoints/avg_1',
-        'lm_url':
-        'https://deepspeech.bj.bcebos.com/en_lm/common_crawl_00.prune01111.trie.klm',
-        'lm_md5':
-        '099a601759d467cd0a8523ff939819c5'
-    },
-}
-
-model_alias = {
-    "deepspeech2offline":
-    "paddlespeech.s2t.models.ds2:DeepSpeech2Model",
-    "deepspeech2online":
-    "paddlespeech.s2t.models.ds2_online:DeepSpeech2ModelOnline",
-    "conformer":
-    "paddlespeech.s2t.models.u2:U2Model",
-    "conformer_online":
-    "paddlespeech.s2t.models.u2:U2Model",
-    "transformer":
-    "paddlespeech.s2t.models.u2:U2Model",
-    "wenetspeech":
-    "paddlespeech.s2t.models.u2:U2Model",
-}
diff --git a/paddlespeech/cli/base_commands.py b/paddlespeech/cli/base_commands.py
index 4d4d2cc69b0eec34b626a84ee237c1c4c4c540a2..39bf24524d27318de2af8d519076f92da4e3db01 100644
--- a/paddlespeech/cli/base_commands.py
+++ b/paddlespeech/cli/base_commands.py
@@ -11,17 +11,18 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import argparse
 from typing import List
 
+from prettytable import PrettyTable
+
+from ..resource import CommonTaskResource
 from .entry import commands
 from .utils import cli_register
 from .utils import explicit_command_register
 from .utils import get_command
 
-__all__ = [
-    'BaseCommand',
-    'HelpCommand',
-]
+__all__ = ['BaseCommand', 'HelpCommand', 'StatsCommand']
 
 
 @cli_register(name='paddlespeech')
@@ -76,6 +77,61 @@ class VersionCommand:
         return True
 
 
+model_name_format = {
+    'asr': 'Model-Language-Sample Rate',
+    'cls': 'Model-Sample Rate',
+    'st': 'Model-Source language-Target language',
+    'text': 'Model-Task-Language',
+    'tts': 'Model-Language',
+    'vector': 'Model-Sample Rate'
+}
+
+
+@cli_register(
+    name='paddlespeech.stats',
+    description='Get speech tasks support models list.')
+class StatsCommand:
+    def __init__(self):
+        self.parser = argparse.ArgumentParser(
+            prog='paddlespeech.stats', add_help=True)
+        self.task_choices = ['asr', 'cls', 'st', 'text', 'tts', 'vector']
+        self.parser.add_argument(
+            '--task',
+            type=str,
+            default='asr',
+            choices=self.task_choices,
+            help='Choose speech task.',
+            required=True)
+
+    def show_support_models(self, pretrained_models: dict):
+        fields = model_name_format[self.task].split("-")
+        table = PrettyTable(fields)
+        for key in pretrained_models:
+            table.add_row(key.split("-"))
+        print(table)
+
+    def execute(self, argv: List[str]) -> bool:
+        parser_args = self.parser.parse_args(argv)
+        self.task = parser_args.task
+        if self.task not in self.task_choices:
+            print("Please input correct speech task, choices = " + str(
+                self.task_choices))
+            return False
+
+        pretrained_models = CommonTaskResource(task=self.task).pretrained_models
+
+        try:
+            print(
+                "Here is the list of {} pretrained models released by PaddleSpeech that can be used by command line and python API"
+                .format(self.task.upper()))
+            self.show_support_models(pretrained_models)
+        except BaseException:
+            print("Failed to get the list of {} pretrained models.".format(
+                self.task.upper()))
+            return False
+        return True
+
+
 # Dynamic import when running specific command
 _commands = {
     'asr': ['Speech to text infer command.', 'ASRExecutor'],
@@ -91,3 +147,4 @@ for com, info in _commands.items():
         name='paddlespeech.{}'.format(com),
         description=info[0],
         cls='paddlespeech.cli.{}.{}'.format(com, info[1]))
+ 
\ No newline at end of file
diff --git a/paddlespeech/cli/cls/infer.py b/paddlespeech/cli/cls/infer.py
index 3d807b60b3d03d4620875582b41f26f5b699c45b..1a9949748f339e838ac1bed7400308aedd8eb1c9 100644
--- a/paddlespeech/cli/cls/infer.py
+++ b/paddlespeech/cli/cls/infer.py
@@ -21,26 +21,19 @@ from typing import Union
 import numpy as np
 import paddle
 import yaml
-from paddleaudio import load
-from paddleaudio.features import LogMelSpectrogram
-from paddlespeech.utils.dynamic_import import dynamic_import
 
 from ..executor import BaseExecutor
 from ..log import logger
 from ..utils import stats_wrapper
-from .pretrained_models import model_alias
-from .pretrained_models import pretrained_models
-
+from paddleaudio import load
+from paddleaudio.features import LogMelSpectrogram
 
 __all__ = ['CLSExecutor']
 
 
 class CLSExecutor(BaseExecutor):
     def __init__(self):
-        super().__init__()
-        self.model_alias = model_alias
-        self.pretrained_models = pretrained_models
-
+        super().__init__(task='cls')
         self.parser = argparse.ArgumentParser(
             prog='paddlespeech.cls', add_help=True)
         self.parser.add_argument(
@@ -50,7 +43,8 @@ class CLSExecutor(BaseExecutor):
             type=str,
             default='panns_cnn14',
             choices=[
-                tag[:tag.index('-')] for tag in self.pretrained_models.keys()
+                tag[:tag.index('-')]
+                for tag in self.task_resource.pretrained_models.keys()
             ],
             help='Choose model type of cls task.')
         self.parser.add_argument(
@@ -103,13 +97,16 @@ class CLSExecutor(BaseExecutor):
 
         if label_file is None or ckpt_path is None:
             tag = model_type + '-' + '32k'  # panns_cnn14-32k
-            self.res_path = self._get_pretrained_path(tag)
+            self.task_resource.set_task_model(tag, version=None)
             self.cfg_path = os.path.join(
-                self.res_path, self.pretrained_models[tag]['cfg_path'])
+                self.task_resource.res_dir,
+                self.task_resource.res_dict['cfg_path'])
             self.label_file = os.path.join(
-                self.res_path, self.pretrained_models[tag]['label_file'])
+                self.task_resource.res_dir,
+                self.task_resource.res_dict['label_file'])
             self.ckpt_path = os.path.join(
-                self.res_path, self.pretrained_models[tag]['ckpt_path'])
+                self.task_resource.res_dir,
+                self.task_resource.res_dict['ckpt_path'])
         else:
             self.cfg_path = os.path.abspath(cfg_path)
             self.label_file = os.path.abspath(label_file)
@@ -126,7 +123,7 @@ class CLSExecutor(BaseExecutor):
             self._label_list.append(line.strip())
 
         # model
-        model_class = dynamic_import(model_type, self.model_alias)
+        model_class = self.task_resource.get_model_class(model_type)
         model_dict = paddle.load(self.ckpt_path)
         self.model = model_class(extract_embedding=False)
         self.model.set_state_dict(model_dict)
@@ -203,7 +200,7 @@ class CLSExecutor(BaseExecutor):
         if not parser_args.verbose:
             self.disable_task_loggers()
 
-        task_source = self.get_task_source(parser_args.input)
+        task_source = self.get_input_source(parser_args.input)
         task_results = OrderedDict()
         has_exceptions = False
 
diff --git a/paddlespeech/cli/cls/pretrained_models.py b/paddlespeech/cli/cls/pretrained_models.py
deleted file mode 100644
index 1d66850aa7fa55733c8a0680889906894e235126..0000000000000000000000000000000000000000
--- a/paddlespeech/cli/cls/pretrained_models.py
+++ /dev/null
@@ -1,47 +0,0 @@
-# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-pretrained_models = {
-    # The tags for pretrained_models should be "{model_name}[_{dataset}][-{lang}][-...]".
-    # e.g. "conformer_wenetspeech-zh-16k", "transformer_aishell-zh-16k" and "panns_cnn6-32k".
-    # Command line and python api use "{model_name}[_{dataset}]" as --model, usage:
-    # "paddlespeech asr --model conformer_wenetspeech --lang zh --sr 16000 --input ./input.wav"
-    "panns_cnn6-32k": {
-        'url': 'https://paddlespeech.bj.bcebos.com/cls/panns_cnn6.tar.gz',
-        'md5': '4cf09194a95df024fd12f84712cf0f9c',
-        'cfg_path': 'panns.yaml',
-        'ckpt_path': 'cnn6.pdparams',
-        'label_file': 'audioset_labels.txt',
-    },
-    "panns_cnn10-32k": {
-        'url': 'https://paddlespeech.bj.bcebos.com/cls/panns_cnn10.tar.gz',
-        'md5': 'cb8427b22176cc2116367d14847f5413',
-        'cfg_path': 'panns.yaml',
-        'ckpt_path': 'cnn10.pdparams',
-        'label_file': 'audioset_labels.txt',
-    },
-    "panns_cnn14-32k": {
-        'url': 'https://paddlespeech.bj.bcebos.com/cls/panns_cnn14.tar.gz',
-        'md5': 'e3b9b5614a1595001161d0ab95edee97',
-        'cfg_path': 'panns.yaml',
-        'ckpt_path': 'cnn14.pdparams',
-        'label_file': 'audioset_labels.txt',
-    },
-}
-
-model_alias = {
-    "panns_cnn6": "paddlespeech.cls.models.panns:CNN6",
-    "panns_cnn10": "paddlespeech.cls.models.panns:CNN10",
-    "panns_cnn14": "paddlespeech.cls.models.panns:CNN14",
-}
diff --git a/paddlespeech/cli/executor.py b/paddlespeech/cli/executor.py
index 4a631c7f5b8d73ad1095a3d00b8ddbbc9615a8e5..d390f947d17cccc99a12eee75f634242e4bac9bb 100644
--- a/paddlespeech/cli/executor.py
+++ b/paddlespeech/cli/executor.py
@@ -24,9 +24,8 @@ from typing import Union
 
 import paddle
 
+from ..resource import CommonTaskResource
 from .log import logger
-from .utils import download_and_decompress
-from .utils import MODEL_HOME
 
 
 class BaseExecutor(ABC):
@@ -34,11 +33,10 @@
     An abstract executor of paddlespeech tasks.
     """
 
-    def __init__(self):
+    def __init__(self, task: str, **kwargs):
         self._inputs = OrderedDict()
         self._outputs = OrderedDict()
-        self.pretrained_models = OrderedDict()
-        self.model_alias = OrderedDict()
+        self.task_resource = CommonTaskResource(task=task, **kwargs)
 
     @abstractmethod
     def _init_from_path(self, *args, **kwargs):
@@ -98,8 +96,8 @@
         """
         pass
 
-    def get_task_source(self, input_: Union[str, os.PathLike, None]
-                        ) -> Dict[str, Union[str, os.PathLike]]:
+    def get_input_source(self, input_: Union[str, os.PathLike, None]
+                         ) -> Dict[str, Union[str, os.PathLike]]:
         """
         Get task input source from command line input.
 
@@ -115,15 +113,17 @@
         ret = OrderedDict()
         if input_ is None:  # Take input from stdin
-            for i, line in enumerate(sys.stdin):
-                line = line.strip()
-                if len(line.split(' ')) == 1:
-                    ret[str(i + 1)] = line
-                elif len(line.split(' ')) == 2:
-                    id_, info = line.split(' ')
-                    ret[id_] = info
-                else:  # No valid input info from one line.
-                    continue
+            if not sys.stdin.isatty(
+            ):  # Avoid getting stuck when stdin is empty.
+                for i, line in enumerate(sys.stdin):
+                    line = line.strip()
+                    if len(line.split(' ')) == 1:
+                        ret[str(i + 1)] = line
+                    elif len(line.split(' ')) == 2:
+                        id_, info = line.split(' ')
+                        ret[id_] = info
+                    else:  # No valid input info from one line.
+                        continue
         else:
             ret[1] = input_
         return ret
@@ -219,23 +219,6 @@
         for l in loggers:
             l.disabled = True
 
-    def _get_pretrained_path(self, tag: str) -> os.PathLike:
-        """
-        Download and returns pretrained resources path of current task.
-        """
-        support_models = list(self.pretrained_models.keys())
-        assert tag in self.pretrained_models, 'The model "{}" you want to use has not been supported, please choose other models.\nThe support models includes:\n\t\t{}\n'.format(
-            tag, '\n\t\t'.join(support_models))
-
-        res_path = os.path.join(MODEL_HOME, tag)
-        decompressed_path = download_and_decompress(self.pretrained_models[tag],
-                                                    res_path)
-        decompressed_path = os.path.abspath(decompressed_path)
-        logger.info(
-            'Use pretrained model stored in: {}'.format(decompressed_path))
-
-        return decompressed_path
-
     def show_rtf(self, info: Dict[str, List[float]]):
         """
         Calculate rft of current task and show results.
diff --git a/paddlespeech/cli/st/infer.py b/paddlespeech/cli/st/infer.py
index ae188b349632bd7af811b363143ed845db012392..e1ce181af351c4bf651a913d2de7005c5dc37e51 100644
--- a/paddlespeech/cli/st/infer.py
+++ b/paddlespeech/cli/st/infer.py
@@ -31,21 +31,22 @@ from ..log import logger
 from ..utils import download_and_decompress
 from ..utils import MODEL_HOME
 from ..utils import stats_wrapper
-from .pretrained_models import kaldi_bins
-from .pretrained_models import model_alias
-from .pretrained_models import pretrained_models
 from paddlespeech.s2t.frontend.featurizer.text_featurizer import TextFeaturizer
 from paddlespeech.s2t.utils.utility import UpdateConfig
-from paddlespeech.utils.dynamic_import import dynamic_import
 
 __all__ = ["STExecutor"]
 
+kaldi_bins = {
+    "url":
+    "https://paddlespeech.bj.bcebos.com/s2t/ted_en_zh/st1/kaldi_bins.tar.gz",
+    "md5":
+    "c0682303b3f3393dbf6ed4c4e35a53eb",
+}
+
 
 class STExecutor(BaseExecutor):
     def __init__(self):
-        super().__init__()
-        self.model_alias = model_alias
-        self.pretrained_models = pretrained_models
+        super().__init__(task='st')
         self.kaldi_bins = kaldi_bins
 
         self.parser = argparse.ArgumentParser(
@@ -57,7 +58,8 @@ class STExecutor(BaseExecutor):
             type=str,
             default="fat_st_ted",
             choices=[
-                tag[:tag.index('-')] for tag in self.pretrained_models.keys()
+                tag[:tag.index('-')]
+                for tag in self.task_resource.pretrained_models.keys()
             ],
             help="Choose model type of st task.")
         self.parser.add_argument(
@@ -131,14 +133,16 @@ class STExecutor(BaseExecutor):
         if cfg_path is None or ckpt_path is None:
             tag = model_type + "-" + src_lang + "-" + tgt_lang
-            res_path = self._get_pretrained_path(tag)
-            self.cfg_path = os.path.join(res_path,
-                                         pretrained_models[tag]["cfg_path"])
-            self.ckpt_path = os.path.join(res_path,
-                                          pretrained_models[tag]["ckpt_path"])
-            logger.info(res_path)
+            self.task_resource.set_task_model(tag, version=None)
+            self.cfg_path = os.path.join(
+                self.task_resource.res_dir,
+                self.task_resource.res_dict['cfg_path'])
+            self.ckpt_path = os.path.join(
+                self.task_resource.res_dir,
+                self.task_resource.res_dict['ckpt_path'])
             logger.info(self.cfg_path)
             logger.info(self.ckpt_path)
+            res_path = self.task_resource.res_dir
         else:
             self.cfg_path = os.path.abspath(cfg_path)
             self.ckpt_path = os.path.abspath(ckpt_path)
@@ -163,7 +167,7 @@ class STExecutor(BaseExecutor):
         model_conf = self.config
         model_name = model_type[:model_type.rindex(
             '_')]  # model_type: {model_name}_{dataset}
-        model_class = dynamic_import(model_name, self.model_alias)
+        model_class = self.task_resource.get_model_class(model_name)
         self.model = model_class.from_config(model_conf)
         self.model.eval()
 
@@ -301,7 +305,7 @@ class STExecutor(BaseExecutor):
         if not parser_args.verbose:
             self.disable_task_loggers()
 
-        task_source = self.get_task_source(parser_args.input)
+        task_source = self.get_input_source(parser_args.input)
         task_results =
OrderedDict() has_exceptions = False diff --git a/paddlespeech/cli/st/pretrained_models.py b/paddlespeech/cli/st/pretrained_models.py deleted file mode 100644 index cc7410d253f34109424e49ea0d2622e12ce93ea5..0000000000000000000000000000000000000000 --- a/paddlespeech/cli/st/pretrained_models.py +++ /dev/null @@ -1,35 +0,0 @@ -# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -pretrained_models = { - "fat_st_ted-en-zh": { - "url": - "https://paddlespeech.bj.bcebos.com/s2t/ted_en_zh/st1/st1_transformer_mtl_noam_ted-en-zh_ckpt_0.1.1.model.tar.gz", - "md5": - "d62063f35a16d91210a71081bd2dd557", - "cfg_path": - "model.yaml", - "ckpt_path": - "exp/transformer_mtl_noam/checkpoints/fat_st_ted-en-zh.pdparams", - } -} - -model_alias = {"fat_st": "paddlespeech.s2t.models.u2_st:U2STModel"} - -kaldi_bins = { - "url": - "https://paddlespeech.bj.bcebos.com/s2t/ted_en_zh/st1/kaldi_bins.tar.gz", - "md5": - "c0682303b3f3393dbf6ed4c4e35a53eb", -} diff --git a/paddlespeech/cli/stats/infer.py b/paddlespeech/cli/stats/infer.py deleted file mode 100644 index 7cf4f2368cbced90bac54cb61bdc1bd8fc3d07f8..0000000000000000000000000000000000000000 --- a/paddlespeech/cli/stats/infer.py +++ /dev/null @@ -1,146 +0,0 @@ -# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-import argparse -from typing import List - -from prettytable import PrettyTable - -from ..utils import cli_register -from ..utils import stats_wrapper - -__all__ = ['StatsExecutor'] - -model_name_format = { - 'asr': 'Model-Language-Sample Rate', - 'cls': 'Model-Sample Rate', - 'st': 'Model-Source language-Target language', - 'text': 'Model-Task-Language', - 'tts': 'Model-Language', - 'vector': 'Model-Sample Rate' -} - - -@cli_register( - name='paddlespeech.stats', - description='Get speech tasks support models list.') -class StatsExecutor(): - def __init__(self): - super().__init__() - - self.parser = argparse.ArgumentParser( - prog='paddlespeech.stats', add_help=True) - self.task_choices = ['asr', 'cls', 'st', 'text', 'tts', 'vector'] - self.parser.add_argument( - '--task', - type=str, - default='asr', - choices=self.task_choices, - help='Choose speech task.', - required=True) - - def show_support_models(self, pretrained_models: dict): - fields = model_name_format[self.task].split("-") - table = PrettyTable(fields) - for key in pretrained_models: - table.add_row(key.split("-")) - print(table) - - def execute(self, argv: List[str]) -> bool: - """ - Command line entry. - """ - parser_args = self.parser.parse_args(argv) - has_exceptions = False - try: - self(parser_args.task) - except Exception as e: - has_exceptions = True - if has_exceptions: - return False - else: - return True - - @stats_wrapper - def __call__( - self, - task: str=None, ): - """ - Python API to call an executor. - """ - self.task = task - if self.task not in self.task_choices: - print("Please input correct speech task, choices = " + str( - self.task_choices)) - - elif self.task == 'asr': - try: - from ..asr.pretrained_models import pretrained_models - print( - "Here is the list of ASR pretrained models released by PaddleSpeech that can be used by command line and python API" - ) - self.show_support_models(pretrained_models) - except BaseException: - print("Failed to get the list of ASR pretrained models.") - - elif self.task == 'cls': - try: - from ..cls.pretrained_models import pretrained_models - print( - "Here is the list of CLS pretrained models released by PaddleSpeech that can be used by command line and python API" - ) - self.show_support_models(pretrained_models) - except BaseException: - print("Failed to get the list of CLS pretrained models.") - - elif self.task == 'st': - try: - from ..st.pretrained_models import pretrained_models - print( - "Here is the list of ST pretrained models released by PaddleSpeech that can be used by command line and python API" - ) - self.show_support_models(pretrained_models) - except BaseException: - print("Failed to get the list of ST pretrained models.") - - elif self.task == 'text': - try: - from ..text.pretrained_models import pretrained_models - print( - "Here is the list of TEXT pretrained models released by PaddleSpeech that can be used by command line and python API" - ) - self.show_support_models(pretrained_models) - except BaseException: - print("Failed to get the list of TEXT pretrained models.") - - elif self.task == 'tts': - try: - from ..tts.pretrained_models import pretrained_models - print( - "Here is the list of TTS pretrained models released by PaddleSpeech that can be used by command line and python API" - ) - self.show_support_models(pretrained_models) - except BaseException: - print("Failed to get the list of TTS pretrained models.") - - elif self.task == 'vector': - try: - from ..vector.pretrained_models import pretrained_models - print( - "Here is the list 
of Speaker Recognition pretrained models released by PaddleSpeech that can be used by command line and python API" - ) - self.show_support_models(pretrained_models) - except BaseException: - print( - "Failed to get the list of Speaker Recognition pretrained models." - ) diff --git a/paddlespeech/cli/text/infer.py b/paddlespeech/cli/text/infer.py index be5b5a10d474c2a50d448d955ae8cead3a13202b..7b8faf99c84691971744fbef291a714900dc60bc 100644 --- a/paddlespeech/cli/text/infer.py +++ b/paddlespeech/cli/text/infer.py @@ -24,21 +24,13 @@ import paddle from ..executor import BaseExecutor from ..log import logger from ..utils import stats_wrapper -from .pretrained_models import model_alias -from .pretrained_models import pretrained_models -from .pretrained_models import tokenizer_alias -from paddlespeech.utils.dynamic_import import dynamic_import __all__ = ['TextExecutor'] class TextExecutor(BaseExecutor): def __init__(self): - super().__init__() - self.model_alias = model_alias - self.pretrained_models = pretrained_models - self.tokenizer_alias = tokenizer_alias - + super().__init__(task='text') self.parser = argparse.ArgumentParser( prog='paddlespeech.text', add_help=True) self.parser.add_argument( @@ -54,7 +46,8 @@ class TextExecutor(BaseExecutor): type=str, default='ernie_linear_p7_wudao', choices=[ - tag[:tag.index('-')] for tag in self.pretrained_models.keys() + tag[:tag.index('-')] + for tag in self.task_resource.pretrained_models.keys() ], help='Choose model type of text task.') self.parser.add_argument( @@ -112,13 +105,16 @@ class TextExecutor(BaseExecutor): if cfg_path is None or ckpt_path is None or vocab_file is None: tag = '-'.join([model_type, task, lang]) - self.res_path = self._get_pretrained_path(tag) + self.task_resource.set_task_model(tag, version=None) self.cfg_path = os.path.join( - self.res_path, self.pretrained_models[tag]['cfg_path']) + self.task_resource.res_dir, + self.task_resource.res_dict['cfg_path']) self.ckpt_path = os.path.join( - self.res_path, self.pretrained_models[tag]['ckpt_path']) + self.task_resource.res_dir, + self.task_resource.res_dict['ckpt_path']) self.vocab_file = os.path.join( - self.res_path, self.pretrained_models[tag]['vocab_file']) + self.task_resource.res_dir, + self.task_resource.res_dict['vocab_file']) else: self.cfg_path = os.path.abspath(cfg_path) self.ckpt_path = os.path.abspath(ckpt_path) @@ -133,8 +129,8 @@ class TextExecutor(BaseExecutor): self._punc_list.append(line.strip()) # model - model_class = dynamic_import(model_name, self.model_alias) - tokenizer_class = dynamic_import(model_name, self.tokenizer_alias) + model_class, tokenizer_class = self.task_resource.get_model_class( + model_name) self.model = model_class( cfg_path=self.cfg_path, ckpt_path=self.ckpt_path) self.tokenizer = tokenizer_class.from_pretrained('ernie-1.0') @@ -224,7 +220,7 @@ class TextExecutor(BaseExecutor): if not parser_args.verbose: self.disable_task_loggers() - task_source = self.get_task_source(parser_args.input) + task_source = self.get_input_source(parser_args.input) task_results = OrderedDict() has_exceptions = False diff --git a/paddlespeech/cli/text/pretrained_models.py b/paddlespeech/cli/text/pretrained_models.py deleted file mode 100644 index 817d3caa3cdc634a202703d4885796b21eee4f56..0000000000000000000000000000000000000000 --- a/paddlespeech/cli/text/pretrained_models.py +++ /dev/null @@ -1,54 +0,0 @@ -# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. 
-# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -pretrained_models = { - # The tags for pretrained_models should be "{model_name}[_{dataset}][-{lang}][-...]". - # e.g. "conformer_wenetspeech-zh-16k", "transformer_aishell-zh-16k" and "panns_cnn6-32k". - # Command line and python api use "{model_name}[_{dataset}]" as --model, usage: - # "paddlespeech asr --model conformer_wenetspeech --lang zh --sr 16000 --input ./input.wav" - "ernie_linear_p7_wudao-punc-zh": { - 'url': - 'https://paddlespeech.bj.bcebos.com/text/ernie_linear_p7_wudao-punc-zh.tar.gz', - 'md5': - '12283e2ddde1797c5d1e57036b512746', - 'cfg_path': - 'ckpt/model_config.json', - 'ckpt_path': - 'ckpt/model_state.pdparams', - 'vocab_file': - 'punc_vocab.txt', - }, - "ernie_linear_p3_wudao-punc-zh": { - 'url': - 'https://paddlespeech.bj.bcebos.com/text/ernie_linear_p3_wudao-punc-zh.tar.gz', - 'md5': - '448eb2fdf85b6a997e7e652e80c51dd2', - 'cfg_path': - 'ckpt/model_config.json', - 'ckpt_path': - 'ckpt/model_state.pdparams', - 'vocab_file': - 'punc_vocab.txt', - }, -} - -model_alias = { - "ernie_linear_p7": "paddlespeech.text.models:ErnieLinear", - "ernie_linear_p3": "paddlespeech.text.models:ErnieLinear", -} - -tokenizer_alias = { - "ernie_linear_p7": "paddlenlp.transformers:ErnieTokenizer", - "ernie_linear_p3": "paddlenlp.transformers:ErnieTokenizer", -} diff --git a/paddlespeech/cli/tts/infer.py b/paddlespeech/cli/tts/infer.py index 879d4a4db1f37d35df8e27e77f759b65acc7ec03..4e0337bccea500c382f0782860cec36ad4897c46 100644 --- a/paddlespeech/cli/tts/infer.py +++ b/paddlespeech/cli/tts/infer.py @@ -29,22 +29,16 @@ from yacs.config import CfgNode from ..executor import BaseExecutor from ..log import logger from ..utils import stats_wrapper -from .pretrained_models import model_alias -from .pretrained_models import pretrained_models from paddlespeech.t2s.frontend import English from paddlespeech.t2s.frontend.zh_frontend import Frontend from paddlespeech.t2s.modules.normalizer import ZScore -from paddlespeech.utils.dynamic_import import dynamic_import __all__ = ['TTSExecutor'] class TTSExecutor(BaseExecutor): def __init__(self): - super().__init__() - self.model_alias = model_alias - self.pretrained_models = pretrained_models - + super().__init__('tts') self.parser = argparse.ArgumentParser( prog='paddlespeech.tts', add_help=True) self.parser.add_argument( @@ -183,19 +177,23 @@ class TTSExecutor(BaseExecutor): return # am am_tag = am + '-' + lang + self.task_resource.set_task_model( + model_tag=am_tag, + model_type=0, # am + version=None, # default version + ) if am_ckpt is None or am_config is None or am_stat is None or phones_dict is None: - am_res_path = self._get_pretrained_path(am_tag) - self.am_res_path = am_res_path - self.am_config = os.path.join( - am_res_path, self.pretrained_models[am_tag]['config']) - self.am_ckpt = os.path.join(am_res_path, - self.pretrained_models[am_tag]['ckpt']) + self.am_res_path = self.task_resource.res_dir + self.am_config = os.path.join(self.am_res_path, + self.task_resource.res_dict['config']) + 
self.am_ckpt = os.path.join(self.am_res_path, + self.task_resource.res_dict['ckpt']) self.am_stat = os.path.join( - am_res_path, self.pretrained_models[am_tag]['speech_stats']) + self.am_res_path, self.task_resource.res_dict['speech_stats']) # must have phones_dict in acoustic self.phones_dict = os.path.join( - am_res_path, self.pretrained_models[am_tag]['phones_dict']) - logger.info(am_res_path) + self.am_res_path, self.task_resource.res_dict['phones_dict']) + logger.info(self.am_res_path) logger.info(self.am_config) logger.info(self.am_ckpt) else: @@ -207,33 +205,37 @@ class TTSExecutor(BaseExecutor): # for speedyspeech self.tones_dict = None - if 'tones_dict' in self.pretrained_models[am_tag]: + if 'tones_dict' in self.task_resource.res_dict: self.tones_dict = os.path.join( - self.am_res_path, self.pretrained_models[am_tag]['tones_dict']) + self.am_res_path, self.task_resource.res_dict['tones_dict']) if tones_dict: self.tones_dict = tones_dict # for multi speaker fastspeech2 self.speaker_dict = None - if 'speaker_dict' in self.pretrained_models[am_tag]: + if 'speaker_dict' in self.task_resource.res_dict: self.speaker_dict = os.path.join( - self.am_res_path, - self.pretrained_models[am_tag]['speaker_dict']) + self.am_res_path, self.task_resource.res_dict['speaker_dict']) if speaker_dict: self.speaker_dict = speaker_dict # voc voc_tag = voc + '-' + lang + self.task_resource.set_task_model( + model_tag=voc_tag, + model_type=1, # vocoder + version=None, # default version + ) if voc_ckpt is None or voc_config is None or voc_stat is None: - voc_res_path = self._get_pretrained_path(voc_tag) - self.voc_res_path = voc_res_path + self.voc_res_path = self.task_resource.voc_res_dir self.voc_config = os.path.join( - voc_res_path, self.pretrained_models[voc_tag]['config']) + self.voc_res_path, self.task_resource.voc_res_dict['config']) self.voc_ckpt = os.path.join( - voc_res_path, self.pretrained_models[voc_tag]['ckpt']) + self.voc_res_path, self.task_resource.voc_res_dict['ckpt']) self.voc_stat = os.path.join( - voc_res_path, self.pretrained_models[voc_tag]['speech_stats']) - logger.info(voc_res_path) + self.voc_res_path, + self.task_resource.voc_res_dict['speech_stats']) + logger.info(self.voc_res_path) logger.info(self.voc_config) logger.info(self.voc_ckpt) else: @@ -283,9 +285,9 @@ class TTSExecutor(BaseExecutor): # model: {model_name}_{dataset} am_name = am[:am.rindex('_')] - am_class = dynamic_import(am_name, self.model_alias) - am_inference_class = dynamic_import(am_name + '_inference', - self.model_alias) + am_class = self.task_resource.get_model_class(am_name) + am_inference_class = self.task_resource.get_model_class(am_name + + '_inference') if am_name == 'fastspeech2': am = am_class( @@ -314,9 +316,9 @@ class TTSExecutor(BaseExecutor): # vocoder # model: {model_name}_{dataset} voc_name = voc[:voc.rindex('_')] - voc_class = dynamic_import(voc_name, self.model_alias) - voc_inference_class = dynamic_import(voc_name + '_inference', - self.model_alias) + voc_class = self.task_resource.get_model_class(voc_name) + voc_inference_class = self.task_resource.get_model_class(voc_name + + '_inference') if voc_name != 'wavernn': voc = voc_class(**self.voc_config["generator_params"]) voc.set_state_dict(paddle.load(self.voc_ckpt)["generator_params"]) @@ -444,7 +446,7 @@ class TTSExecutor(BaseExecutor): if not args.verbose: self.disable_task_loggers() - task_source = self.get_task_source(args.input) + task_source = self.get_input_source(args.input) task_results = OrderedDict() has_exceptions = False diff 
--git a/paddlespeech/cli/tts/pretrained_models.py b/paddlespeech/cli/tts/pretrained_models.py deleted file mode 100644 index 65254a9353fc997038d84368d3918f055d2ccee0..0000000000000000000000000000000000000000 --- a/paddlespeech/cli/tts/pretrained_models.py +++ /dev/null @@ -1,300 +0,0 @@ -# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -pretrained_models = { - # speedyspeech - "speedyspeech_csmsc-zh": { - 'url': - 'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/speedyspeech/speedyspeech_csmsc_ckpt_0.2.0.zip', - 'md5': - '6f6fa967b408454b6662c8c00c0027cb', - 'config': - 'default.yaml', - 'ckpt': - 'snapshot_iter_30600.pdz', - 'speech_stats': - 'feats_stats.npy', - 'phones_dict': - 'phone_id_map.txt', - 'tones_dict': - 'tone_id_map.txt', - }, - - # fastspeech2 - "fastspeech2_csmsc-zh": { - 'url': - 'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/fastspeech2/fastspeech2_nosil_baker_ckpt_0.4.zip', - 'md5': - '637d28a5e53aa60275612ba4393d5f22', - 'config': - 'default.yaml', - 'ckpt': - 'snapshot_iter_76000.pdz', - 'speech_stats': - 'speech_stats.npy', - 'phones_dict': - 'phone_id_map.txt', - }, - "fastspeech2_ljspeech-en": { - 'url': - 'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/fastspeech2/fastspeech2_nosil_ljspeech_ckpt_0.5.zip', - 'md5': - 'ffed800c93deaf16ca9b3af89bfcd747', - 'config': - 'default.yaml', - 'ckpt': - 'snapshot_iter_100000.pdz', - 'speech_stats': - 'speech_stats.npy', - 'phones_dict': - 'phone_id_map.txt', - }, - "fastspeech2_aishell3-zh": { - 'url': - 'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/fastspeech2/fastspeech2_nosil_aishell3_ckpt_0.4.zip', - 'md5': - 'f4dd4a5f49a4552b77981f544ab3392e', - 'config': - 'default.yaml', - 'ckpt': - 'snapshot_iter_96400.pdz', - 'speech_stats': - 'speech_stats.npy', - 'phones_dict': - 'phone_id_map.txt', - 'speaker_dict': - 'speaker_id_map.txt', - }, - "fastspeech2_vctk-en": { - 'url': - 'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/fastspeech2/fastspeech2_nosil_vctk_ckpt_0.5.zip', - 'md5': - '743e5024ca1e17a88c5c271db9779ba4', - 'config': - 'default.yaml', - 'ckpt': - 'snapshot_iter_66200.pdz', - 'speech_stats': - 'speech_stats.npy', - 'phones_dict': - 'phone_id_map.txt', - 'speaker_dict': - 'speaker_id_map.txt', - }, - # tacotron2 - "tacotron2_csmsc-zh": { - 'url': - 'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/tacotron2/tacotron2_csmsc_ckpt_0.2.0.zip', - 'md5': - '0df4b6f0bcbe0d73c5ed6df8867ab91a', - 'config': - 'default.yaml', - 'ckpt': - 'snapshot_iter_30600.pdz', - 'speech_stats': - 'speech_stats.npy', - 'phones_dict': - 'phone_id_map.txt', - }, - "tacotron2_ljspeech-en": { - 'url': - 'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/tacotron2/tacotron2_ljspeech_ckpt_0.2.0.zip', - 'md5': - '6a5eddd81ae0e81d16959b97481135f3', - 'config': - 'default.yaml', - 'ckpt': - 'snapshot_iter_60300.pdz', - 'speech_stats': - 'speech_stats.npy', - 'phones_dict': - 'phone_id_map.txt', - }, 
- - # pwgan - "pwgan_csmsc-zh": { - 'url': - 'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/pwgan/pwg_baker_ckpt_0.4.zip', - 'md5': - '2e481633325b5bdf0a3823c714d2c117', - 'config': - 'pwg_default.yaml', - 'ckpt': - 'pwg_snapshot_iter_400000.pdz', - 'speech_stats': - 'pwg_stats.npy', - }, - "pwgan_ljspeech-en": { - 'url': - 'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/pwgan/pwg_ljspeech_ckpt_0.5.zip', - 'md5': - '53610ba9708fd3008ccaf8e99dacbaf0', - 'config': - 'pwg_default.yaml', - 'ckpt': - 'pwg_snapshot_iter_400000.pdz', - 'speech_stats': - 'pwg_stats.npy', - }, - "pwgan_aishell3-zh": { - 'url': - 'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/pwgan/pwg_aishell3_ckpt_0.5.zip', - 'md5': - 'd7598fa41ad362d62f85ffc0f07e3d84', - 'config': - 'default.yaml', - 'ckpt': - 'snapshot_iter_1000000.pdz', - 'speech_stats': - 'feats_stats.npy', - }, - "pwgan_vctk-en": { - 'url': - 'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/pwgan/pwg_vctk_ckpt_0.1.1.zip', - 'md5': - 'b3da1defcde3e578be71eb284cb89f2c', - 'config': - 'default.yaml', - 'ckpt': - 'snapshot_iter_1500000.pdz', - 'speech_stats': - 'feats_stats.npy', - }, - # mb_melgan - "mb_melgan_csmsc-zh": { - 'url': - 'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/mb_melgan/mb_melgan_csmsc_ckpt_0.1.1.zip', - 'md5': - 'ee5f0604e20091f0d495b6ec4618b90d', - 'config': - 'default.yaml', - 'ckpt': - 'snapshot_iter_1000000.pdz', - 'speech_stats': - 'feats_stats.npy', - }, - # style_melgan - "style_melgan_csmsc-zh": { - 'url': - 'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/style_melgan/style_melgan_csmsc_ckpt_0.1.1.zip', - 'md5': - '5de2d5348f396de0c966926b8c462755', - 'config': - 'default.yaml', - 'ckpt': - 'snapshot_iter_1500000.pdz', - 'speech_stats': - 'feats_stats.npy', - }, - # hifigan - "hifigan_csmsc-zh": { - 'url': - 'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/hifigan/hifigan_csmsc_ckpt_0.1.1.zip', - 'md5': - 'dd40a3d88dfcf64513fba2f0f961ada6', - 'config': - 'default.yaml', - 'ckpt': - 'snapshot_iter_2500000.pdz', - 'speech_stats': - 'feats_stats.npy', - }, - "hifigan_ljspeech-en": { - 'url': - 'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/hifigan/hifigan_ljspeech_ckpt_0.2.0.zip', - 'md5': - '70e9131695decbca06a65fe51ed38a72', - 'config': - 'default.yaml', - 'ckpt': - 'snapshot_iter_2500000.pdz', - 'speech_stats': - 'feats_stats.npy', - }, - "hifigan_aishell3-zh": { - 'url': - 'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/hifigan/hifigan_aishell3_ckpt_0.2.0.zip', - 'md5': - '3bb49bc75032ed12f79c00c8cc79a09a', - 'config': - 'default.yaml', - 'ckpt': - 'snapshot_iter_2500000.pdz', - 'speech_stats': - 'feats_stats.npy', - }, - "hifigan_vctk-en": { - 'url': - 'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/hifigan/hifigan_vctk_ckpt_0.2.0.zip', - 'md5': - '7da8f88359bca2457e705d924cf27bd4', - 'config': - 'default.yaml', - 'ckpt': - 'snapshot_iter_2500000.pdz', - 'speech_stats': - 'feats_stats.npy', - }, - - # wavernn - "wavernn_csmsc-zh": { - 'url': - 'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/wavernn/wavernn_csmsc_ckpt_0.2.0.zip', - 'md5': - 'ee37b752f09bcba8f2af3b777ca38e13', - 'config': - 'default.yaml', - 'ckpt': - 'snapshot_iter_400000.pdz', - 'speech_stats': - 'feats_stats.npy', - } -} - -model_alias = { - # acoustic model - "speedyspeech": - "paddlespeech.t2s.models.speedyspeech:SpeedySpeech", - "speedyspeech_inference": - "paddlespeech.t2s.models.speedyspeech:SpeedySpeechInference", - 
"fastspeech2": - "paddlespeech.t2s.models.fastspeech2:FastSpeech2", - "fastspeech2_inference": - "paddlespeech.t2s.models.fastspeech2:FastSpeech2Inference", - "tacotron2": - "paddlespeech.t2s.models.tacotron2:Tacotron2", - "tacotron2_inference": - "paddlespeech.t2s.models.tacotron2:Tacotron2Inference", - # voc - "pwgan": - "paddlespeech.t2s.models.parallel_wavegan:PWGGenerator", - "pwgan_inference": - "paddlespeech.t2s.models.parallel_wavegan:PWGInference", - "mb_melgan": - "paddlespeech.t2s.models.melgan:MelGANGenerator", - "mb_melgan_inference": - "paddlespeech.t2s.models.melgan:MelGANInference", - "style_melgan": - "paddlespeech.t2s.models.melgan:StyleMelGANGenerator", - "style_melgan_inference": - "paddlespeech.t2s.models.melgan:StyleMelGANInference", - "hifigan": - "paddlespeech.t2s.models.hifigan:HiFiGANGenerator", - "hifigan_inference": - "paddlespeech.t2s.models.hifigan:HiFiGANInference", - "wavernn": - "paddlespeech.t2s.models.wavernn:WaveRNN", - "wavernn_inference": - "paddlespeech.t2s.models.wavernn:WaveRNNInference", -} diff --git a/paddlespeech/cli/vector/infer.py b/paddlespeech/cli/vector/infer.py index 07fb73a4c839864a63a1561324a67da55e9df80d..8bf09001397898fe81cfb82cd13e145cff2ca32e 100644 --- a/paddlespeech/cli/vector/infer.py +++ b/paddlespeech/cli/vector/infer.py @@ -22,26 +22,20 @@ from typing import Union import paddle import soundfile -from paddleaudio.backends import load as load_audio -from paddleaudio.compliance.librosa import melspectrogram from yacs.config import CfgNode from ..executor import BaseExecutor from ..log import logger from ..utils import stats_wrapper -from .pretrained_models import model_alias -from .pretrained_models import pretrained_models -from paddlespeech.utils.dynamic_import import dynamic_import +from paddleaudio.backends import load as load_audio +from paddleaudio.compliance.librosa import melspectrogram from paddlespeech.vector.io.batch import feature_normalize from paddlespeech.vector.modules.sid_model import SpeakerIdetification class VectorExecutor(BaseExecutor): def __init__(self): - super().__init__() - self.model_alias = model_alias - self.pretrained_models = pretrained_models - + super().__init__('vector') self.parser = argparse.ArgumentParser( prog="paddlespeech.vector", add_help=True) @@ -49,7 +43,10 @@ class VectorExecutor(BaseExecutor): "--model", type=str, default="ecapatdnn_voxceleb12", - choices=["ecapatdnn_voxceleb12"], + choices=[ + tag[:tag.index('-')] + for tag in self.task_resource.pretrained_models.keys() + ], help="Choose model type of vector task.") self.parser.add_argument( "--task", @@ -119,7 +116,7 @@ class VectorExecutor(BaseExecutor): self.disable_task_loggers() # stage 2: read the input data and store them as a list - task_source = self.get_task_source(parser_args.input) + task_source = self.get_input_source(parser_args.input) logger.info(f"task source: {task_source}") # stage 3: process the audio one by one @@ -296,6 +293,7 @@ class VectorExecutor(BaseExecutor): # get the mode from pretrained list sample_rate_str = "16k" if sample_rate == 16000 else "8k" tag = model_type + "-" + sample_rate_str + self.task_resource.set_task_model(tag, version=None) logger.info(f"load the pretrained model: {tag}") # get the model from the pretrained list # we download the pretrained model and store it in the res_path @@ -303,10 +301,11 @@ class VectorExecutor(BaseExecutor): self.res_path = res_path self.cfg_path = os.path.join( - res_path, self.pretrained_models[tag]['cfg_path']) + self.task_resource.res_dir, + 
self.task_resource.res_dict['cfg_path']) self.ckpt_path = os.path.join( - res_path, - self.pretrained_models[tag]['ckpt_path'] + '.pdparams') + self.task_resource.res_dir, + self.task_resource.res_dict['ckpt_path'] + '.pdparams') else: # get the model from disk self.cfg_path = os.path.abspath(cfg_path) @@ -325,8 +324,8 @@ class VectorExecutor(BaseExecutor): # stage 3: get the model name to instance the model network with dynamic_import logger.info("start to dynamic import the model class") model_name = model_type[:model_type.rindex('_')] + model_class = self.task_resource.get_model_class(model_name) logger.info(f"model name {model_name}") - model_class = dynamic_import(model_name, self.model_alias) model_conf = self.config.model backbone = model_class(**model_conf) model = SpeakerIdetification( diff --git a/paddlespeech/cli/vector/pretrained_models.py b/paddlespeech/cli/vector/pretrained_models.py deleted file mode 100644 index 4d1d3a048b22550fa85d77c4a8d5fae5b39a56e2..0000000000000000000000000000000000000000 --- a/paddlespeech/cli/vector/pretrained_models.py +++ /dev/null @@ -1,36 +0,0 @@ -# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -pretrained_models = { - # The tags for pretrained_models should be "{model_name}[-{dataset}][-{sr}][-...]". - # e.g. "ecapatdnn_voxceleb12-16k". - # Command line and python api use "{model_name}[-{dataset}]" as --model, usage: - # "paddlespeech vector --task spk --model ecapatdnn_voxceleb12-16k --sr 16000 --input ./input.wav" - "ecapatdnn_voxceleb12-16k": { - 'url': - 'https://paddlespeech.bj.bcebos.com/vector/voxceleb/sv0_ecapa_tdnn_voxceleb12_ckpt_0_2_1.tar.gz', - 'md5': - '67c7ff8885d5246bd16e0f5ac1cba99f', - 'cfg_path': - 'conf/model.yaml', # the yaml config path - 'ckpt_path': - 'model/model', # the format is ${dir}/{model_name}, - # so the first 'model' is dir, the second 'model' is the name - # this means we have a model stored as model/model.pdparams - }, -} - -model_alias = { - "ecapatdnn": "paddlespeech.vector.models.ecapa_tdnn:EcapaTdnn", -} diff --git a/paddlespeech/cli/stats/__init__.py b/paddlespeech/resource/__init__.py similarity index 83% rename from paddlespeech/cli/stats/__init__.py rename to paddlespeech/resource/__init__.py index 9fe6c4abaf10de2f24f751ddd62f456768a82475..e143413af7a7cecb59b10b80296ec8d95490b14a 100644 --- a/paddlespeech/cli/stats/__init__.py +++ b/paddlespeech/resource/__init__.py @@ -1,4 +1,4 @@ -# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -11,4 +11,4 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-from .infer import StatsExecutor +from .resource import CommonTaskResource diff --git a/paddlespeech/resource/pretrained_models.py b/paddlespeech/resource/pretrained_models.py new file mode 100644 index 0000000000000000000000000000000000000000..9441a2805b763bc55471380ec9217a93311aceb6 --- /dev/null +++ b/paddlespeech/resource/pretrained_models.py @@ -0,0 +1,822 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +__all__ = [ + 'asr_dynamic_pretrained_models', + 'asr_static_pretrained_models', + 'cls_dynamic_pretrained_models', + 'cls_static_pretrained_models', + 'st_dynamic_pretrained_models', + 'st_kaldi_bins', + 'text_dynamic_pretrained_models', + 'tts_dynamic_pretrained_models', + 'tts_static_pretrained_models', + 'tts_onnx_pretrained_models', + 'vector_dynamic_pretrained_models', +] + +# The tags for pretrained_models should be "{model_name}[_{dataset}][-{lang}][-...]". +# e.g. "conformer_wenetspeech-zh-16k" and "panns_cnn6-32k". +# Command line and python api use "{model_name}[_{dataset}]" as --model, usage: +# "paddlespeech asr --model conformer_wenetspeech --lang zh --sr 16000 --input ./input.wav" + +# --------------------------------- +# -------------- ASR -------------- +# --------------------------------- +asr_dynamic_pretrained_models = { + "conformer_wenetspeech-zh-16k": { + '1.0': { + 'url': + 'https://paddlespeech.bj.bcebos.com/s2t/wenetspeech/asr1_conformer_wenetspeech_ckpt_0.1.1.model.tar.gz', + 'md5': + '76cb19ed857e6623856b7cd7ebbfeda4', + 'cfg_path': + 'model.yaml', + 'ckpt_path': + 'exp/conformer/checkpoints/wenetspeech', + }, + }, + "conformer_online_wenetspeech-zh-16k": { + '1.0': { + 'url': + 'https://paddlespeech.bj.bcebos.com/s2t/wenetspeech/asr1/asr1_chunk_conformer_wenetspeech_ckpt_1.0.0a.model.tar.gz', + 'md5': + 'b8c02632b04da34aca88459835be54a6', + 'cfg_path': + 'model.yaml', + 'ckpt_path': + 'exp/chunk_conformer/checkpoints/avg_10', + }, + }, + "conformer_online_multicn-zh-16k": { + '1.0': { + 'url': + 'https://paddlespeech.bj.bcebos.com/s2t/multi_cn/asr1/asr1_chunk_conformer_multi_cn_ckpt_0.2.0.model.tar.gz', + 'md5': + '7989b3248c898070904cf042fd656003', + 'cfg_path': + 'model.yaml', + 'ckpt_path': + 'exp/chunk_conformer/checkpoints/multi_cn', + }, + '2.0': { + 'url': + 'https://paddlespeech.bj.bcebos.com/s2t/multi_cn/asr1/asr1_chunk_conformer_multi_cn_ckpt_0.2.3.model.tar.gz', + 'md5': + '0ac93d390552336f2a906aec9e33c5fa', + 'cfg_path': + 'model.yaml', + 'ckpt_path': + 'exp/chunk_conformer/checkpoints/multi_cn', + 'lm_url': + 'https://deepspeech.bj.bcebos.com/zh_lm/zh_giga.no_cna_cmn.prune01244.klm', + 'lm_md5': + '29e02312deb2e59b3c8686c7966d4fe3' + }, + }, + "conformer_aishell-zh-16k": { + '1.0': { + 'url': + 'https://paddlespeech.bj.bcebos.com/s2t/aishell/asr1/asr1_conformer_aishell_ckpt_0.1.2.model.tar.gz', + 'md5': + '3f073eccfa7bb14e0c6867d65fc0dc3a', + 'cfg_path': + 'model.yaml', + 'ckpt_path': + 'exp/conformer/checkpoints/avg_30', + }, + }, + "conformer_online_aishell-zh-16k": { + 
'1.0': { + 'url': + 'https://paddlespeech.bj.bcebos.com/s2t/aishell/asr1/asr1_chunk_conformer_aishell_ckpt_0.2.0.model.tar.gz', + 'md5': + 'b374cfb93537761270b6224fb0bfc26a', + 'cfg_path': + 'model.yaml', + 'ckpt_path': + 'exp/chunk_conformer/checkpoints/avg_30', + }, + }, + "transformer_librispeech-en-16k": { + '1.0': { + 'url': + 'https://paddlespeech.bj.bcebos.com/s2t/librispeech/asr1/asr1_transformer_librispeech_ckpt_0.1.1.model.tar.gz', + 'md5': + '2c667da24922aad391eacafe37bc1660', + 'cfg_path': + 'model.yaml', + 'ckpt_path': + 'exp/transformer/checkpoints/avg_10', + }, + }, + "deepspeech2online_wenetspeech-zh-16k": { + '1.0': { + 'url': + 'https://paddlespeech.bj.bcebos.com/s2t/wenetspeech/asr0/asr0_deepspeech2_online_wenetspeech_ckpt_1.0.0a.model.tar.gz', + 'md5': + 'e393d4d274af0f6967db24fc146e8074', + 'cfg_path': + 'model.yaml', + 'ckpt_path': + 'exp/deepspeech2_online/checkpoints/avg_10', + 'lm_url': + 'https://deepspeech.bj.bcebos.com/zh_lm/zh_giga.no_cna_cmn.prune01244.klm', + 'lm_md5': + '29e02312deb2e59b3c8686c7966d4fe3' + }, + }, + "deepspeech2offline_aishell-zh-16k": { + '1.0': { + 'url': + 'https://paddlespeech.bj.bcebos.com/s2t/aishell/asr0/asr0_deepspeech2_aishell_ckpt_0.1.1.model.tar.gz', + 'md5': + '932c3593d62fe5c741b59b31318aa314', + 'cfg_path': + 'model.yaml', + 'ckpt_path': + 'exp/deepspeech2/checkpoints/avg_1', + 'lm_url': + 'https://deepspeech.bj.bcebos.com/zh_lm/zh_giga.no_cna_cmn.prune01244.klm', + 'lm_md5': + '29e02312deb2e59b3c8686c7966d4fe3' + }, + }, + "deepspeech2online_aishell-zh-16k": { + '1.0': { + 'url': + 'https://paddlespeech.bj.bcebos.com/s2t/aishell/asr0/asr0_deepspeech2_online_aishell_fbank161_ckpt_0.2.1.model.tar.gz', + 'md5': + '98b87b171b7240b7cae6e07d8d0bc9be', + 'cfg_path': + 'model.yaml', + 'ckpt_path': + 'exp/deepspeech2_online/checkpoints/avg_1', + 'lm_url': + 'https://deepspeech.bj.bcebos.com/zh_lm/zh_giga.no_cna_cmn.prune01244.klm', + 'lm_md5': + '29e02312deb2e59b3c8686c7966d4fe3' + }, + }, + "deepspeech2offline_librispeech-en-16k": { + '1.0': { + 'url': + 'https://paddlespeech.bj.bcebos.com/s2t/librispeech/asr0/asr0_deepspeech2_librispeech_ckpt_0.1.1.model.tar.gz', + 'md5': + 'f5666c81ad015c8de03aac2bc92e5762', + 'cfg_path': + 'model.yaml', + 'ckpt_path': + 'exp/deepspeech2/checkpoints/avg_1', + 'lm_url': + 'https://deepspeech.bj.bcebos.com/en_lm/common_crawl_00.prune01111.trie.klm', + 'lm_md5': + '099a601759d467cd0a8523ff939819c5' + }, + }, +} + +asr_static_pretrained_models = { + "deepspeech2offline_aishell-zh-16k": { + '1.0': { + 'url': + 'https://paddlespeech.bj.bcebos.com/s2t/aishell/asr0/asr0_deepspeech2_aishell_ckpt_0.1.1.model.tar.gz', + 'md5': + '932c3593d62fe5c741b59b31318aa314', + 'cfg_path': + 'model.yaml', + 'ckpt_path': + 'exp/deepspeech2/checkpoints/avg_1', + 'model': + 'exp/deepspeech2/checkpoints/avg_1.jit.pdmodel', + 'params': + 'exp/deepspeech2/checkpoints/avg_1.jit.pdiparams', + 'lm_url': + 'https://deepspeech.bj.bcebos.com/zh_lm/zh_giga.no_cna_cmn.prune01244.klm', + 'lm_md5': + '29e02312deb2e59b3c8686c7966d4fe3' + } + }, +} + +# --------------------------------- +# -------------- CLS -------------- +# --------------------------------- +cls_dynamic_pretrained_models = { + "panns_cnn6-32k": { + '1.0': { + 'url': 'https://paddlespeech.bj.bcebos.com/cls/panns_cnn6.tar.gz', + 'md5': '4cf09194a95df024fd12f84712cf0f9c', + 'cfg_path': 'panns.yaml', + 'ckpt_path': 'cnn6.pdparams', + 'label_file': 'audioset_labels.txt', + }, + }, + "panns_cnn10-32k": { + '1.0': { + 'url': 
'https://paddlespeech.bj.bcebos.com/cls/panns_cnn10.tar.gz', + 'md5': 'cb8427b22176cc2116367d14847f5413', + 'cfg_path': 'panns.yaml', + 'ckpt_path': 'cnn10.pdparams', + 'label_file': 'audioset_labels.txt', + }, + }, + "panns_cnn14-32k": { + '1.0': { + 'url': 'https://paddlespeech.bj.bcebos.com/cls/panns_cnn14.tar.gz', + 'md5': 'e3b9b5614a1595001161d0ab95edee97', + 'cfg_path': 'panns.yaml', + 'ckpt_path': 'cnn14.pdparams', + 'label_file': 'audioset_labels.txt', + }, + }, +} + +cls_static_pretrained_models = { + "panns_cnn6-32k": { + '1.0': { + 'url': + 'https://paddlespeech.bj.bcebos.com/cls/inference_model/panns_cnn6_static.tar.gz', + 'md5': + 'da087c31046d23281d8ec5188c1967da', + 'cfg_path': + 'panns.yaml', + 'model_path': + 'inference.pdmodel', + 'params_path': + 'inference.pdiparams', + 'label_file': + 'audioset_labels.txt', + }, + }, + "panns_cnn10-32k": { + '1.0': { + 'url': + 'https://paddlespeech.bj.bcebos.com/cls/inference_model/panns_cnn10_static.tar.gz', + 'md5': + '5460cc6eafbfaf0f261cc75b90284ae1', + 'cfg_path': + 'panns.yaml', + 'model_path': + 'inference.pdmodel', + 'params_path': + 'inference.pdiparams', + 'label_file': + 'audioset_labels.txt', + }, + }, + "panns_cnn14-32k": { + '1.0': { + 'url': + 'https://paddlespeech.bj.bcebos.com/cls/inference_model/panns_cnn14_static.tar.gz', + 'md5': + 'ccc80b194821274da79466862b2ab00f', + 'cfg_path': + 'panns.yaml', + 'model_path': + 'inference.pdmodel', + 'params_path': + 'inference.pdiparams', + 'label_file': + 'audioset_labels.txt', + }, + }, +} + +# --------------------------------- +# -------------- ST --------------- +# --------------------------------- +st_dynamic_pretrained_models = { + "fat_st_ted-en-zh": { + '1.0': { + "url": + "https://paddlespeech.bj.bcebos.com/s2t/ted_en_zh/st1/st1_transformer_mtl_noam_ted-en-zh_ckpt_0.1.1.model.tar.gz", + "md5": + "d62063f35a16d91210a71081bd2dd557", + "cfg_path": + "model.yaml", + "ckpt_path": + "exp/transformer_mtl_noam/checkpoints/fat_st_ted-en-zh.pdparams", + }, + }, +} + +st_kaldi_bins = { + "url": + "https://paddlespeech.bj.bcebos.com/s2t/ted_en_zh/st1/kaldi_bins.tar.gz", + "md5": + "c0682303b3f3393dbf6ed4c4e35a53eb", +} + +# --------------------------------- +# -------------- TEXT ------------- +# --------------------------------- +text_dynamic_pretrained_models = { + "ernie_linear_p7_wudao-punc-zh": { + '1.0': { + 'url': + 'https://paddlespeech.bj.bcebos.com/text/ernie_linear_p7_wudao-punc-zh.tar.gz', + 'md5': + '12283e2ddde1797c5d1e57036b512746', + 'cfg_path': + 'ckpt/model_config.json', + 'ckpt_path': + 'ckpt/model_state.pdparams', + 'vocab_file': + 'punc_vocab.txt', + }, + }, + "ernie_linear_p3_wudao-punc-zh": { + '1.0': { + 'url': + 'https://paddlespeech.bj.bcebos.com/text/ernie_linear_p3_wudao-punc-zh.tar.gz', + 'md5': + '448eb2fdf85b6a997e7e652e80c51dd2', + 'cfg_path': + 'ckpt/model_config.json', + 'ckpt_path': + 'ckpt/model_state.pdparams', + 'vocab_file': + 'punc_vocab.txt', + }, + }, +} + +# --------------------------------- +# -------------- TTS -------------- +# --------------------------------- +tts_dynamic_pretrained_models = { + # speedyspeech + "speedyspeech_csmsc-zh": { + '1.0': { + 'url': + 'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/speedyspeech/speedyspeech_csmsc_ckpt_0.2.0.zip', + 'md5': + '6f6fa967b408454b6662c8c00c0027cb', + 'config': + 'default.yaml', + 'ckpt': + 'snapshot_iter_30600.pdz', + 'speech_stats': + 'feats_stats.npy', + 'phones_dict': + 'phone_id_map.txt', + 'tones_dict': + 'tone_id_map.txt', + }, + }, + # fastspeech2 + 
"fastspeech2_csmsc-zh": { + '1.0': { + 'url': + 'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/fastspeech2/fastspeech2_nosil_baker_ckpt_0.4.zip', + 'md5': + '637d28a5e53aa60275612ba4393d5f22', + 'config': + 'default.yaml', + 'ckpt': + 'snapshot_iter_76000.pdz', + 'speech_stats': + 'speech_stats.npy', + 'phones_dict': + 'phone_id_map.txt', + }, + }, + "fastspeech2_ljspeech-en": { + '1.0': { + 'url': + 'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/fastspeech2/fastspeech2_nosil_ljspeech_ckpt_0.5.zip', + 'md5': + 'ffed800c93deaf16ca9b3af89bfcd747', + 'config': + 'default.yaml', + 'ckpt': + 'snapshot_iter_100000.pdz', + 'speech_stats': + 'speech_stats.npy', + 'phones_dict': + 'phone_id_map.txt', + }, + }, + "fastspeech2_aishell3-zh": { + '1.0': { + 'url': + 'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/fastspeech2/fastspeech2_nosil_aishell3_ckpt_0.4.zip', + 'md5': + 'f4dd4a5f49a4552b77981f544ab3392e', + 'config': + 'default.yaml', + 'ckpt': + 'snapshot_iter_96400.pdz', + 'speech_stats': + 'speech_stats.npy', + 'phones_dict': + 'phone_id_map.txt', + 'speaker_dict': + 'speaker_id_map.txt', + }, + }, + "fastspeech2_vctk-en": { + '1.0': { + 'url': + 'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/fastspeech2/fastspeech2_nosil_vctk_ckpt_0.5.zip', + 'md5': + '743e5024ca1e17a88c5c271db9779ba4', + 'config': + 'default.yaml', + 'ckpt': + 'snapshot_iter_66200.pdz', + 'speech_stats': + 'speech_stats.npy', + 'phones_dict': + 'phone_id_map.txt', + 'speaker_dict': + 'speaker_id_map.txt', + }, + }, + # tacotron2 + "tacotron2_csmsc-zh": { + '1.0': { + 'url': + 'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/tacotron2/tacotron2_csmsc_ckpt_0.2.0.zip', + 'md5': + '0df4b6f0bcbe0d73c5ed6df8867ab91a', + 'config': + 'default.yaml', + 'ckpt': + 'snapshot_iter_30600.pdz', + 'speech_stats': + 'speech_stats.npy', + 'phones_dict': + 'phone_id_map.txt', + }, + }, + "tacotron2_ljspeech-en": { + '1.0': { + 'url': + 'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/tacotron2/tacotron2_ljspeech_ckpt_0.2.0.zip', + 'md5': + '6a5eddd81ae0e81d16959b97481135f3', + 'config': + 'default.yaml', + 'ckpt': + 'snapshot_iter_60300.pdz', + 'speech_stats': + 'speech_stats.npy', + 'phones_dict': + 'phone_id_map.txt', + }, + }, + # pwgan + "pwgan_csmsc-zh": { + '1.0': { + 'url': + 'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/pwgan/pwg_baker_ckpt_0.4.zip', + 'md5': + '2e481633325b5bdf0a3823c714d2c117', + 'config': + 'pwg_default.yaml', + 'ckpt': + 'pwg_snapshot_iter_400000.pdz', + 'speech_stats': + 'pwg_stats.npy', + }, + }, + "pwgan_ljspeech-en": { + '1.0': { + 'url': + 'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/pwgan/pwg_ljspeech_ckpt_0.5.zip', + 'md5': + '53610ba9708fd3008ccaf8e99dacbaf0', + 'config': + 'pwg_default.yaml', + 'ckpt': + 'pwg_snapshot_iter_400000.pdz', + 'speech_stats': + 'pwg_stats.npy', + }, + }, + "pwgan_aishell3-zh": { + '1.0': { + 'url': + 'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/pwgan/pwg_aishell3_ckpt_0.5.zip', + 'md5': + 'd7598fa41ad362d62f85ffc0f07e3d84', + 'config': + 'default.yaml', + 'ckpt': + 'snapshot_iter_1000000.pdz', + 'speech_stats': + 'feats_stats.npy', + }, + }, + "pwgan_vctk-en": { + '1.0': { + 'url': + 'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/pwgan/pwg_vctk_ckpt_0.1.1.zip', + 'md5': + 'b3da1defcde3e578be71eb284cb89f2c', + 'config': + 'default.yaml', + 'ckpt': + 'snapshot_iter_1500000.pdz', + 'speech_stats': + 'feats_stats.npy', + }, + }, + # mb_melgan + 
"mb_melgan_csmsc-zh": { + '1.0': { + 'url': + 'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/mb_melgan/mb_melgan_csmsc_ckpt_0.1.1.zip', + 'md5': + 'ee5f0604e20091f0d495b6ec4618b90d', + 'config': + 'default.yaml', + 'ckpt': + 'snapshot_iter_1000000.pdz', + 'speech_stats': + 'feats_stats.npy', + }, + }, + # style_melgan + "style_melgan_csmsc-zh": { + '1.0': { + 'url': + 'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/style_melgan/style_melgan_csmsc_ckpt_0.1.1.zip', + 'md5': + '5de2d5348f396de0c966926b8c462755', + 'config': + 'default.yaml', + 'ckpt': + 'snapshot_iter_1500000.pdz', + 'speech_stats': + 'feats_stats.npy', + }, + }, + # hifigan + "hifigan_csmsc-zh": { + '1.0': { + 'url': + 'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/hifigan/hifigan_csmsc_ckpt_0.1.1.zip', + 'md5': + 'dd40a3d88dfcf64513fba2f0f961ada6', + 'config': + 'default.yaml', + 'ckpt': + 'snapshot_iter_2500000.pdz', + 'speech_stats': + 'feats_stats.npy', + }, + }, + "hifigan_ljspeech-en": { + '1.0': { + 'url': + 'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/hifigan/hifigan_ljspeech_ckpt_0.2.0.zip', + 'md5': + '70e9131695decbca06a65fe51ed38a72', + 'config': + 'default.yaml', + 'ckpt': + 'snapshot_iter_2500000.pdz', + 'speech_stats': + 'feats_stats.npy', + }, + }, + "hifigan_aishell3-zh": { + '1.0': { + 'url': + 'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/hifigan/hifigan_aishell3_ckpt_0.2.0.zip', + 'md5': + '3bb49bc75032ed12f79c00c8cc79a09a', + 'config': + 'default.yaml', + 'ckpt': + 'snapshot_iter_2500000.pdz', + 'speech_stats': + 'feats_stats.npy', + }, + }, + "hifigan_vctk-en": { + '1.0': { + 'url': + 'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/hifigan/hifigan_vctk_ckpt_0.2.0.zip', + 'md5': + '7da8f88359bca2457e705d924cf27bd4', + 'config': + 'default.yaml', + 'ckpt': + 'snapshot_iter_2500000.pdz', + 'speech_stats': + 'feats_stats.npy', + }, + }, + # wavernn + "wavernn_csmsc-zh": { + '1.0': { + 'url': + 'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/wavernn/wavernn_csmsc_ckpt_0.2.0.zip', + 'md5': + 'ee37b752f09bcba8f2af3b777ca38e13', + 'config': + 'default.yaml', + 'ckpt': + 'snapshot_iter_400000.pdz', + 'speech_stats': + 'feats_stats.npy', + }, + }, + "fastspeech2_cnndecoder_csmsc-zh": { + '1.0': { + 'url': + 'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/fastspeech2/fastspeech2_cnndecoder_csmsc_ckpt_1.0.0.zip', + 'md5': + '6eb28e22ace73e0ebe7845f86478f89f', + 'config': + 'cnndecoder.yaml', + 'ckpt': + 'snapshot_iter_153000.pdz', + 'speech_stats': + 'speech_stats.npy', + 'phones_dict': + 'phone_id_map.txt', + }, + }, +} + +tts_static_pretrained_models = { + # speedyspeech + "speedyspeech_csmsc-zh": { + '1.0': { + 'url': + 'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/speedyspeech/speedyspeech_nosil_baker_static_0.5.zip', + 'md5': + 'f10cbdedf47dc7a9668d2264494e1823', + 'model': + 'speedyspeech_csmsc.pdmodel', + 'params': + 'speedyspeech_csmsc.pdiparams', + 'phones_dict': + 'phone_id_map.txt', + 'tones_dict': + 'tone_id_map.txt', + 'sample_rate': + 24000, + }, + }, + # fastspeech2 + "fastspeech2_csmsc-zh": { + '1.0': { + 'url': + 'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/fastspeech2/fastspeech2_nosil_baker_static_0.4.zip', + 'md5': + '9788cd9745e14c7a5d12d32670b2a5a7', + 'model': + 'fastspeech2_csmsc.pdmodel', + 'params': + 'fastspeech2_csmsc.pdiparams', + 'phones_dict': + 'phone_id_map.txt', + 'sample_rate': + 24000, + }, + }, + # pwgan + "pwgan_csmsc-zh": { + '1.0': { + 
'url': + 'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/pwgan/pwg_baker_static_0.4.zip', + 'md5': + 'e3504aed9c5a290be12d1347836d2742', + 'model': + 'pwgan_csmsc.pdmodel', + 'params': + 'pwgan_csmsc.pdiparams', + 'sample_rate': + 24000, + }, + }, + # mb_melgan + "mb_melgan_csmsc-zh": { + '1.0': { + 'url': + 'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/mb_melgan/mb_melgan_csmsc_static_0.1.1.zip', + 'md5': + 'ac6eee94ba483421d750433f4c3b8d36', + 'model': + 'mb_melgan_csmsc.pdmodel', + 'params': + 'mb_melgan_csmsc.pdiparams', + 'sample_rate': + 24000, + }, + }, + # hifigan + "hifigan_csmsc-zh": { + '1.0': { + 'url': + 'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/hifigan/hifigan_csmsc_static_0.1.1.zip', + 'md5': + '7edd8c436b3a5546b3a7cb8cff9d5a0c', + 'model': + 'hifigan_csmsc.pdmodel', + 'params': + 'hifigan_csmsc.pdiparams', + 'sample_rate': + 24000, + }, + }, +} + +tts_onnx_pretrained_models = { + # fastspeech2 + "fastspeech2_csmsc_onnx-zh": { + '1.0': { + 'url': + 'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/fastspeech2/fastspeech2_csmsc_onnx_0.2.0.zip', + 'md5': + 'fd3ad38d83273ad51f0ea4f4abf3ab4e', + 'ckpt': ['fastspeech2_csmsc.onnx'], + 'phones_dict': + 'phone_id_map.txt', + 'sample_rate': + 24000, + }, + }, + "fastspeech2_cnndecoder_csmsc_onnx-zh": { + '1.0': { + 'url': + 'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/fastspeech2/fastspeech2_cnndecoder_csmsc_streaming_onnx_1.0.0.zip', + 'md5': + '5f70e1a6bcd29d72d54e7931aa86f266', + 'ckpt': [ + 'fastspeech2_csmsc_am_encoder_infer.onnx', + 'fastspeech2_csmsc_am_decoder.onnx', + 'fastspeech2_csmsc_am_postnet.onnx', + ], + 'speech_stats': + 'speech_stats.npy', + 'phones_dict': + 'phone_id_map.txt', + 'sample_rate': + 24000, + }, + }, + # mb_melgan + "mb_melgan_csmsc_onnx-zh": { + '1.0': { + 'url': + 'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/mb_melgan/mb_melgan_csmsc_onnx_0.2.0.zip', + 'md5': + '5b83ec746e8414bc29032d954ffd07ec', + 'ckpt': + 'mb_melgan_csmsc.onnx', + 'sample_rate': + 24000, + }, + }, + # hifigan + "hifigan_csmsc_onnx-zh": { + '1.0': { + 'url': + 'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/hifigan/hifigan_csmsc_onnx_0.2.0.zip', + 'md5': + '1a7dc0385875889e46952e50c0994a6b', + 'ckpt': + 'hifigan_csmsc.onnx', + 'sample_rate': + 24000, + }, + }, +} + +# --------------------------------- +# ------------ Vector ------------- +# --------------------------------- +vector_dynamic_pretrained_models = { + "ecapatdnn_voxceleb12-16k": { + '1.0': { + 'url': + 'https://paddlespeech.bj.bcebos.com/vector/voxceleb/sv0_ecapa_tdnn_voxceleb12_ckpt_0_2_0.tar.gz', + 'md5': + 'cc33023c54ab346cd318408f43fcaf95', + 'cfg_path': + 'conf/model.yaml', # the yaml config path + 'ckpt_path': + 'model/model', # the format is ${dir}/{model_name}, + # so the first 'model' is dir, the second 'model' is the name + # this means we have a model stored as model/model.pdparams + }, + }, +} diff --git a/paddlespeech/resource/resource.py b/paddlespeech/resource/resource.py new file mode 100644 index 0000000000000000000000000000000000000000..f00b1b3b0592aff006337b41ededcbfde79fcded --- /dev/null +++ b/paddlespeech/resource/resource.py @@ -0,0 +1,222 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os +from collections import OrderedDict +from typing import Dict +from typing import List +from typing import Optional + +from ..cli.utils import download_and_decompress +from ..cli.utils import MODEL_HOME +from ..utils.dynamic_import import dynamic_import +from .model_alias import model_alias + +task_supported = ['asr', 'cls', 'st', 'text', 'tts', 'vector'] +model_format_supported = ['dynamic', 'static', 'onnx'] +inference_mode_supported = ['online', 'offline'] + + +class CommonTaskResource: + def __init__(self, task: str, model_format: str='dynamic', **kwargs): + assert task in task_supported, 'Arg "task" must be one of {}.'.format( + task_supported) + assert model_format in model_format_supported, 'Arg "model_format" must be one of {}.'.format( + model_format_supported) + + self.task = task + self.model_format = model_format + self.pretrained_models = self._get_pretrained_models() + + if 'inference_mode' in kwargs: + assert kwargs[ + 'inference_mode'] in inference_mode_supported, 'Arg "inference_mode" must be one of {}.'.format( + inference_mode_supported) + self._inference_mode_filter(kwargs['inference_mode']) + + # Initialize after model and version had been set. + self.model_tag = None + self.version = None + self.res_dict = None + self.res_dir = None + + if self.task == 'tts': + # For vocoder + self.voc_model_tag = None + self.voc_version = None + self.voc_res_dict = None + self.voc_res_dir = None + + def set_task_model(self, + model_tag: str, + model_type: int=0, + version: Optional[str]=None): + """Set model tag and version of current task. + + Args: + model_tag (str): Model tag. + model_type (int): 0 for acoustic model otherwise vocoder in tts task. + version (Optional[str], optional): Version of pretrained model. Defaults to None. + """ + assert model_tag in self.pretrained_models, \ + "Can't find \"{}\" in resource. Model name must be one of {}".format(model_tag, list(self.pretrained_models.keys())) + + if version is None: + version = self._get_default_version(model_tag) + + assert version in self.pretrained_models[model_tag], \ + "Can't find version \"{}\" in \"{}\". Model name must be one of {}".format( + version, model_tag, list(self.pretrained_models[model_tag].keys())) + + if model_type == 0: + self.model_tag = model_tag + self.version = version + self.res_dict = self.pretrained_models[model_tag][version] + self.res_dir = self._fetch(self.res_dict, + self._get_model_dir(model_type)) + else: + assert self.task == 'tts', 'Vocoder will only be used in tts task.' + self.voc_model_tag = model_tag + self.voc_version = version + self.voc_res_dict = self.pretrained_models[model_tag][version] + self.voc_res_dir = self._fetch(self.voc_res_dict, + self._get_model_dir(model_type)) + + @staticmethod + def get_model_class(model_name) -> List[object]: + """Dynamic import model class. + Args: + model_name (str): Model name. + + Returns: + List[object]: Return a list of model class. 
+ """ + assert model_name in model_alias, 'No model classes found for "{}"'.format( + model_name) + + ret = [] + for import_path in model_alias[model_name]: + ret.append(dynamic_import(import_path)) + + if len(ret) == 1: + return ret[0] + else: + return ret + + def get_versions(self, model_tag: str) -> List[str]: + """List all available versions. + + Args: + model_tag (str): Model tag. + + Returns: + List[str]: Version list of model. + """ + return list(self.pretrained_models[model_tag].keys()) + + def _get_default_version(self, model_tag: str) -> str: + """Get default version of model. + + Args: + model_tag (str): Model tag. + + Returns: + str: Default version. + """ + return self.get_versions(model_tag)[-1] # get latest version + + def _get_model_dir(self, model_type: int=0) -> os.PathLike: + """Get resource directory. + + Args: + model_type (int): 0 for acoustic model otherwise vocoder in tts task. + + Returns: + os.PathLike: Directory of model resource. + """ + if model_type == 0: + model_tag = self.model_tag + version = self.version + else: + model_tag = self.voc_model_tag + version = self.voc_version + + return os.path.join(MODEL_HOME, model_tag, version) + + def _get_pretrained_models(self) -> Dict[str, str]: + """Get all available models for current task. + + Returns: + Dict[str, str]: A dictionary with model tag and resources info. + """ + try: + import_models = '{}_{}_pretrained_models'.format(self.task, + self.model_format) + exec('from .pretrained_models import {}'.format(import_models)) + models = OrderedDict(locals()[import_models]) + except ImportError: + models = OrderedDict({}) # no models. + finally: + return models + + def _inference_mode_filter(self, inference_mode: Optional[str]): + """Filter models dict based on inference_mode. + + Args: + inference_mode (Optional[str]): 'online', 'offline' or None. + """ + if inference_mode is None: + return + + if self.task == 'asr': + online_flags = [ + 'online' in model_tag + for model_tag in self.pretrained_models.keys() + ] + for online_flag, model_tag in zip( + online_flags, list(self.pretrained_models.keys())): + if inference_mode == 'online' and online_flag: + continue + elif inference_mode == 'offline' and not online_flag: + continue + else: + del self.pretrained_models[model_tag] + elif self.task == 'tts': + # Hardcode for tts online models. + tts_online_models = [ + 'fastspeech2_csmsc-zh', 'fastspeech2_cnndecoder_csmsc-zh', + 'mb_melgan_csmsc-zh', 'hifigan_csmsc-zh' + ] + for model_tag in list(self.pretrained_models.keys()): + if inference_mode == 'online' and model_tag in tts_online_models: + continue + elif inference_mode == 'offline': + continue + else: + del self.pretrained_models[model_tag] + else: + raise NotImplementedError('Only supports asr and tts task.') + + @staticmethod + def _fetch(res_dict: Dict[str, str], + target_dir: os.PathLike) -> os.PathLike: + """Fetch archive from url. + + Args: + res_dict (Dict[str, str]): Info dict of a resource. + target_dir (os.PathLike): Directory to save archives. + + Returns: + os.PathLike: Directory of model resource. 
+ """ + return download_and_decompress(res_dict, target_dir) diff --git a/paddlespeech/server/bin/paddlespeech_server.py b/paddlespeech/server/bin/paddlespeech_server.py index e59f17d38820374a620e8d9eb78daf060412130d..f1c6b4f8963ed0130097dd3363ef88b5c4db5d8c 100644 --- a/paddlespeech/server/bin/paddlespeech_server.py +++ b/paddlespeech/server/bin/paddlespeech_server.py @@ -25,6 +25,7 @@ from ..executor import BaseExecutor from ..util import cli_server_register from ..util import stats_wrapper from paddlespeech.cli.log import logger +from paddlespeech.resource import CommonTaskResource from paddlespeech.server.engine.engine_pool import init_engine_pool from paddlespeech.server.restful.api import setup_router as setup_http_router from paddlespeech.server.utils.config import get_config @@ -152,101 +153,30 @@ class ServerStatsExecutor(): "Please input correct speech task, choices = ['asr', 'tts']") return False - elif self.task.lower() == 'asr': - try: - from paddlespeech.cli.asr.infer import pretrained_models - logger.info( - "Here is the table of ASR pretrained models supported in the service." - ) - self.show_support_models(pretrained_models) - - # show ASR static pretrained model - from paddlespeech.server.engine.asr.paddleinference.asr_engine import pretrained_models - logger.info( - "Here is the table of ASR static pretrained models supported in the service." - ) - self.show_support_models(pretrained_models) - - return True - except BaseException: - logger.error( - "Failed to get the table of ASR pretrained models supported in the service." - ) - return False - - elif self.task.lower() == 'tts': - try: - from paddlespeech.cli.tts.infer import pretrained_models - logger.info( - "Here is the table of TTS pretrained models supported in the service." - ) - self.show_support_models(pretrained_models) - - # show TTS static pretrained model - from paddlespeech.server.engine.tts.paddleinference.tts_engine import pretrained_models - logger.info( - "Here is the table of TTS static pretrained models supported in the service." - ) - self.show_support_models(pretrained_models) - - return True - except BaseException: - logger.error( - "Failed to get the table of TTS pretrained models supported in the service." - ) - return False + try: + # Dynamic models + dynamic_pretrained_models = CommonTaskResource( + task=self.task, model_format='dynamic').pretrained_models - elif self.task.lower() == 'cls': - try: - from paddlespeech.cli.cls.infer import pretrained_models + if len(dynamic_pretrained_models) > 0: logger.info( - "Here is the table of CLS pretrained models supported in the service." - ) - self.show_support_models(pretrained_models) - - # show CLS static pretrained model - from paddlespeech.server.engine.cls.paddleinference.cls_engine import pretrained_models + "Here is the table of {} pretrained models supported in the service.". + format(self.task.upper())) + self.show_support_models(dynamic_pretrained_models) + + # Static models + static_pretrained_models = CommonTaskResource( + task=self.task, model_format='static').pretrained_models + if len(static_pretrained_models) > 0: logger.info( - "Here is the table of CLS static pretrained models supported in the service." - ) + "Here is the table of {} static pretrained models supported in the service.". + format(self.task.upper())) self.show_support_models(pretrained_models) - return True - except BaseException: - logger.error( - "Failed to get the table of CLS pretrained models supported in the service." 
- ) - return False - elif self.task.lower() == 'text': - try: - from paddlespeech.cli.text.infer import pretrained_models - logger.info( - "Here is the table of Text pretrained models supported in the service." - ) - self.show_support_models(pretrained_models) + return True - return True - except BaseException: - logger.error( - "Failed to get the table of Text pretrained models supported in the service." - ) - return False - elif self.task.lower() == 'vector': - try: - from paddlespeech.cli.vector.infer import pretrained_models - logger.info( - "Here is the table of Vector pretrained models supported in the service." - ) - self.show_support_models(pretrained_models) - - return True - except BaseException: - logger.error( - "Failed to get the table of Vector pretrained models supported in the service." - ) - return False - else: + except BaseException: logger.error( - f"Failed to get the table of {self.task} pretrained models supported in the service." - ) + "Failed to get the table of {} pretrained models supported in the service.". + format(self.task.upper())) return False diff --git a/paddlespeech/server/engine/asr/online/asr_engine.py b/paddlespeech/server/engine/asr/online/asr_engine.py index d7bd458f8124e2ac8c018b474e29242d478cf3b5..14715bf35ca5ad463ad75c51d1fa6bb6e4aec041 100644 --- a/paddlespeech/server/engine/asr/online/asr_engine.py +++ b/paddlespeech/server/engine/asr/online/asr_engine.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import copy import os import sys from typing import Optional @@ -21,15 +20,14 @@ import paddle from numpy import float32 from yacs.config import CfgNode -from .pretrained_models import pretrained_models from paddlespeech.cli.asr.infer import ASRExecutor from paddlespeech.cli.log import logger from paddlespeech.cli.utils import MODEL_HOME +from paddlespeech.resource import CommonTaskResource from paddlespeech.s2t.frontend.featurizer.text_featurizer import TextFeaturizer from paddlespeech.s2t.frontend.speech import SpeechSegment from paddlespeech.s2t.modules.ctc import CTCDecoder from paddlespeech.s2t.transform.transformation import Transformation -from paddlespeech.s2t.utils.dynamic_import import dynamic_import from paddlespeech.s2t.utils.tensor_utils import add_sos_eos from paddlespeech.s2t.utils.tensor_utils import pad_sequence from paddlespeech.s2t.utils.utility import UpdateConfig @@ -53,7 +51,7 @@ class PaddleASRConnectionHanddler: logger.info( "create an paddle asr connection handler to process the websocket connection" ) - self.config = asr_engine.config # server config + self.config = asr_engine.config # server config self.model_config = asr_engine.executor.config self.asr_engine = asr_engine @@ -251,10 +249,12 @@ class PaddleASRConnectionHanddler: # for deepspeech2 # init state self.chunk_state_h_box = np.zeros( - (self.model_config .num_rnn_layers, 1, self.model_config.rnn_layer_size), + (self.model_config.num_rnn_layers, 1, + self.model_config.rnn_layer_size), dtype=float32) self.chunk_state_c_box = np.zeros( - (self.model_config.num_rnn_layers, 1, self.model_config.rnn_layer_size), + (self.model_config.num_rnn_layers, 1, + self.model_config.rnn_layer_size), dtype=float32) self.decoder.reset_decoder(batch_size=1) @@ -699,7 +699,8 @@ class PaddleASRConnectionHanddler: class ASRServerExecutor(ASRExecutor): def __init__(self): super().__init__() - self.pretrained_models = pretrained_models + 
self.task_resource = CommonTaskResource( + task='asr', model_format='dynamic', inference_mode='online') def _init_from_path(self, model_type: str=None, @@ -723,20 +724,19 @@ class ASRServerExecutor(ASRExecutor): self.sample_rate = sample_rate sample_rate_str = '16k' if sample_rate == 16000 else '8k' tag = model_type + '-' + lang + '-' + sample_rate_str - + self.task_resource.set_task_model(model_tag=tag) if cfg_path is None or am_model is None or am_params is None: logger.info(f"Load the pretrained model, tag = {tag}") - res_path = self._get_pretrained_path(tag) # wenetspeech_zh - self.res_path = res_path + self.res_path = self.task_resource.res_dir self.cfg_path = os.path.join( - res_path, self.pretrained_models[tag]['cfg_path']) + self.res_path, self.task_resource.res_dict['cfg_path']) - self.am_model = os.path.join(res_path, - self.pretrained_models[tag]['model']) - self.am_params = os.path.join(res_path, - self.pretrained_models[tag]['params']) - logger.info(res_path) + self.am_model = os.path.join(self.res_path, + self.task_resource.res_dict['model']) + self.am_params = os.path.join(self.res_path, + self.task_resource.res_dict['params']) + logger.info(self.res_path) else: self.cfg_path = os.path.abspath(cfg_path) self.am_model = os.path.abspath(am_model) @@ -763,8 +763,8 @@ class ASRServerExecutor(ASRExecutor): self.text_feature = TextFeaturizer( unit_type=self.config.unit_type, vocab=self.vocab) - lm_url = self.pretrained_models[tag]['lm_url'] - lm_md5 = self.pretrained_models[tag]['lm_md5'] + lm_url = self.task_resource.res_dict['lm_url'] + lm_md5 = self.task_resource.res_dict['lm_md5'] logger.info(f"Start to load language model {lm_url}") self.download_lm( lm_url, @@ -810,7 +810,7 @@ class ASRServerExecutor(ASRExecutor): model_name = model_type[:model_type.rindex( '_')] # model_type: {model_name}_{dataset} logger.info(f"model name: {model_name}") - model_class = dynamic_import(model_name, self.model_alias) + model_class = self.task_resource.get_model_class(model_name) model_conf = self.config model = model_class.from_config(model_conf) self.model = model @@ -824,7 +824,7 @@ class ASRServerExecutor(ASRExecutor): raise ValueError(f"Not support: {model_type}") return True - + class ASREngine(BaseEngine): """ASR server resource diff --git a/paddlespeech/server/engine/asr/online/pretrained_models.py b/paddlespeech/server/engine/asr/online/pretrained_models.py deleted file mode 100644 index ff3778657e85efe1808b1cdb8e34d33ebad862d3..0000000000000000000000000000000000000000 --- a/paddlespeech/server/engine/asr/online/pretrained_models.py +++ /dev/null @@ -1,70 +0,0 @@ -# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
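The asr_engine.py changes above reduce executor bootstrap to one fixed pattern: construct a CommonTaskResource, bind a tag with set_task_model(), then read res_dir and res_dict. A minimal standalone sketch of that flow, assuming paddlespeech is importable and that set_task_model() may download the tagged archive into MODEL_HOME on first use (tag and keys taken from the asr_dynamic_pretrained_models table registered earlier in this diff):

```python
from paddlespeech.resource import CommonTaskResource

# The online ASR engine keeps only streaming-capable tags via the
# inference_mode filter.
resource = CommonTaskResource(
    task='asr', model_format='dynamic', inference_mode='online')

# Resolves the default (latest) version, then downloads and unpacks
# the archive under MODEL_HOME.
resource.set_task_model(model_tag='conformer_online_wenetspeech-zh-16k')

print(resource.res_dir)                # unpack directory under MODEL_HOME
print(resource.res_dict['cfg_path'])   # 'model.yaml'
print(resource.res_dict['ckpt_path'])  # 'exp/chunk_conformer/checkpoints/avg_10'
```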
- -pretrained_models = { - "deepspeech2online_aishell-zh-16k": { - 'url': - 'https://paddlespeech.bj.bcebos.com/s2t/aishell/asr0/asr0_deepspeech2_online_aishell_fbank161_ckpt_0.2.1.model.tar.gz', - 'md5': - '98b87b171b7240b7cae6e07d8d0bc9be', - 'cfg_path': - 'model.yaml', - 'ckpt_path': - 'exp/deepspeech2_online/checkpoints/avg_1', - 'model': - 'exp/deepspeech2_online/checkpoints/avg_1.jit.pdmodel', - 'params': - 'exp/deepspeech2_online/checkpoints/avg_1.jit.pdiparams', - 'lm_url': - 'https://deepspeech.bj.bcebos.com/zh_lm/zh_giga.no_cna_cmn.prune01244.klm', - 'lm_md5': - '29e02312deb2e59b3c8686c7966d4fe3' - }, - "conformer_online_multicn-zh-16k": { - 'url': - 'https://paddlespeech.bj.bcebos.com/s2t/multi_cn/asr1/asr1_chunk_conformer_multi_cn_ckpt_0.2.3.model.tar.gz', - 'md5': - '0ac93d390552336f2a906aec9e33c5fa', - 'cfg_path': - 'model.yaml', - 'ckpt_path': - 'exp/chunk_conformer/checkpoints/multi_cn', - 'model': - 'exp/chunk_conformer/checkpoints/multi_cn.pdparams', - 'params': - 'exp/chunk_conformer/checkpoints/multi_cn.pdparams', - 'lm_url': - 'https://deepspeech.bj.bcebos.com/zh_lm/zh_giga.no_cna_cmn.prune01244.klm', - 'lm_md5': - '29e02312deb2e59b3c8686c7966d4fe3' - }, - "conformer_online_wenetspeech-zh-16k": { - 'url': - 'https://paddlespeech.bj.bcebos.com/s2t/wenetspeech/asr1/asr1_chunk_conformer_wenetspeech_ckpt_1.0.0a.model.tar.gz', - 'md5': - 'b8c02632b04da34aca88459835be54a6', - 'cfg_path': - 'model.yaml', - 'ckpt_path': - 'exp/chunk_conformer/checkpoints/avg_10', - 'model': - 'exp/chunk_conformer/checkpoints/avg_10.pdparams', - 'params': - 'exp/chunk_conformer/checkpoints/avg_10.pdparams', - 'lm_url': - '', - 'lm_md5': - '', - }, -} diff --git a/paddlespeech/server/engine/asr/paddleinference/asr_engine.py b/paddlespeech/server/engine/asr/paddleinference/asr_engine.py index e275f1088f648df62947ded43f297cbb8d2c70c2..80e323fa04c7baca448df635f3d653a9bc5b0801 100644 --- a/paddlespeech/server/engine/asr/paddleinference/asr_engine.py +++ b/paddlespeech/server/engine/asr/paddleinference/asr_engine.py @@ -19,10 +19,10 @@ from typing import Optional import paddle from yacs.config import CfgNode -from .pretrained_models import pretrained_models from paddlespeech.cli.asr.infer import ASRExecutor from paddlespeech.cli.log import logger from paddlespeech.cli.utils import MODEL_HOME +from paddlespeech.resource import CommonTaskResource from paddlespeech.s2t.frontend.featurizer.text_featurizer import TextFeaturizer from paddlespeech.s2t.modules.ctc import CTCDecoder from paddlespeech.s2t.utils.utility import UpdateConfig @@ -36,7 +36,8 @@ __all__ = ['ASREngine'] class ASRServerExecutor(ASRExecutor): def __init__(self): super().__init__() - self.pretrained_models = pretrained_models + self.task_resource = CommonTaskResource( + task='asr', model_format='static', inference_mode='offline') def _init_from_path(self, model_type: str='wenetspeech', @@ -53,17 +54,17 @@ sample_rate_str = '16k' if sample_rate == 16000 else '8k' tag = model_type + '-' + lang + '-' + sample_rate_str + self.task_resource.set_task_model(model_tag=tag) if cfg_path is None or am_model is None or am_params is None: - res_path = self._get_pretrained_path(tag) # wenetspeech_zh - self.res_path = res_path + self.res_path = self.task_resource.res_dir self.cfg_path = os.path.join( - res_path, self.pretrained_models[tag]['cfg_path']) + self.res_path, self.task_resource.res_dict['cfg_path']) - self.am_model = os.path.join(res_path, - self.pretrained_models[tag]['model']) - self.am_params =
os.path.join(res_path, - self.pretrained_models[tag]['params']) - logger.info(res_path) + self.am_model = os.path.join(self.res_path, + self.task_resource.res_dict['model']) + self.am_params = os.path.join(self.res_path, + self.task_resource.res_dict['params']) + logger.info(self.res_path) logger.info(self.cfg_path) logger.info(self.am_model) logger.info(self.am_params) @@ -89,8 +90,8 @@ class ASRServerExecutor(ASRExecutor): self.text_feature = TextFeaturizer( unit_type=self.config.unit_type, vocab=self.vocab) - lm_url = self.pretrained_models[tag]['lm_url'] - lm_md5 = self.pretrained_models[tag]['lm_md5'] + lm_url = self.task_resource.res_dict['lm_url'] + lm_md5 = self.task_resource.res_dict['lm_md5'] self.download_lm( lm_url, os.path.dirname(self.config.decode.lang_model_path), lm_md5) diff --git a/paddlespeech/server/engine/asr/paddleinference/pretrained_models.py b/paddlespeech/server/engine/asr/paddleinference/pretrained_models.py deleted file mode 100644 index c4c23e38cfb0b126e91090053054bcc50dc733e1..0000000000000000000000000000000000000000 --- a/paddlespeech/server/engine/asr/paddleinference/pretrained_models.py +++ /dev/null @@ -1,34 +0,0 @@ -# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
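Worth noting while reading these engine changes: when set_task_model() is called without a version, _get_default_version() takes the last key registered for the tag, so the insertion order of the dict literals in pretrained_models.py decides the default. A small sketch of what that means for conformer_online_multicn-zh-16k, the one ASR tag above with two versions (set_task_model() would fetch the archive as a side effect):

```python
from paddlespeech.resource import CommonTaskResource

resource = CommonTaskResource(task='asr', model_format='dynamic')
print(resource.get_versions('conformer_online_multicn-zh-16k'))  # ['1.0', '2.0']

# version=None falls through to the newest entry, '2.0', which is also the
# variant that carries lm_url/lm_md5 for the external language model.
resource.set_task_model('conformer_online_multicn-zh-16k', version=None)
print(resource.version)  # '2.0'
```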
- -pretrained_models = { - "deepspeech2offline_aishell-zh-16k": { - 'url': - 'https://paddlespeech.bj.bcebos.com/s2t/aishell/asr0/asr0_deepspeech2_aishell_ckpt_0.1.1.model.tar.gz', - 'md5': - '932c3593d62fe5c741b59b31318aa314', - 'cfg_path': - 'model.yaml', - 'ckpt_path': - 'exp/deepspeech2/checkpoints/avg_1', - 'model': - 'exp/deepspeech2/checkpoints/avg_1.jit.pdmodel', - 'params': - 'exp/deepspeech2/checkpoints/avg_1.jit.pdiparams', - 'lm_url': - 'https://deepspeech.bj.bcebos.com/zh_lm/zh_giga.no_cna_cmn.prune01244.klm', - 'lm_md5': - '29e02312deb2e59b3c8686c7966d4fe3' - }, -} diff --git a/paddlespeech/server/engine/cls/paddleinference/cls_engine.py b/paddlespeech/server/engine/cls/paddleinference/cls_engine.py index 0906c2412d36f2d27393731da18e994772c2addd..48792c883aa4d0a5833f4d2cfbd456598dd084b2 100644 --- a/paddlespeech/server/engine/cls/paddleinference/cls_engine.py +++ b/paddlespeech/server/engine/cls/paddleinference/cls_engine.py @@ -20,9 +20,9 @@ import numpy as np import paddle import yaml -from .pretrained_models import pretrained_models from paddlespeech.cli.cls.infer import CLSExecutor from paddlespeech.cli.log import logger +from paddlespeech.resource import CommonTaskResource from paddlespeech.server.engine.base_engine import BaseEngine from paddlespeech.server.utils.paddle_predictor import init_predictor from paddlespeech.server.utils.paddle_predictor import run_model @@ -33,11 +33,12 @@ __all__ = ['CLSEngine'] class CLSServerExecutor(CLSExecutor): def __init__(self): super().__init__() - self.pretrained_models = pretrained_models + self.task_resource = CommonTaskResource( + task='cls', model_format='static') def _init_from_path( self, model_type: str='panns_cnn14', cfg_path: Optional[os.PathLike]=None, model_path: Optional[os.PathLike]=None, params_path: Optional[os.PathLike]=None, @@ -49,15 +50,16 @@ if cfg_path is None or model_path is None or params_path is None or label_file is None: tag = model_type + '-' + '32k' - self.res_path = self._get_pretrained_path(tag) + self.task_resource.set_task_model(model_tag=tag) + self.res_path = self.task_resource.res_dir self.cfg_path = os.path.join( - self.res_path, self.pretrained_models[tag]['cfg_path']) + self.res_path, self.task_resource.res_dict['cfg_path']) self.model_path = os.path.join( - self.res_path, self.pretrained_models[tag]['model_path']) + self.res_path, self.task_resource.res_dict['model_path']) self.params_path = os.path.join( - self.res_path, self.pretrained_models[tag]['params_path']) + self.res_path, self.task_resource.res_dict['params_path']) self.label_file = os.path.join( - self.res_path, self.pretrained_models[tag]['label_file']) + self.res_path, self.task_resource.res_dict['label_file']) else: self.cfg_path = os.path.abspath(cfg_path) self.model_path = os.path.abspath(model_path) diff --git a/paddlespeech/server/engine/cls/paddleinference/pretrained_models.py b/paddlespeech/server/engine/cls/paddleinference/pretrained_models.py deleted file mode 100644 index e4914874600c2198e434d267c775dea66f3f252a..0000000000000000000000000000000000000000 --- a/paddlespeech/server/engine/cls/paddleinference/pretrained_models.py +++ /dev/null @@ -1,58 +0,0 @@ -# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -pretrained_models = { - "panns_cnn6-32k": { - 'url': - 'https://paddlespeech.bj.bcebos.com/cls/inference_model/panns_cnn6_static.tar.gz', - 'md5': - 'da087c31046d23281d8ec5188c1967da', - 'cfg_path': - 'panns.yaml', - 'model_path': - 'inference.pdmodel', - 'params_path': - 'inference.pdiparams', - 'label_file': - 'audioset_labels.txt', - }, - "panns_cnn10-32k": { - 'url': - 'https://paddlespeech.bj.bcebos.com/cls/inference_model/panns_cnn10_static.tar.gz', - 'md5': - '5460cc6eafbfaf0f261cc75b90284ae1', - 'cfg_path': - 'panns.yaml', - 'model_path': - 'inference.pdmodel', - 'params_path': - 'inference.pdiparams', - 'label_file': - 'audioset_labels.txt', - }, - "panns_cnn14-32k": { - 'url': - 'https://paddlespeech.bj.bcebos.com/cls/inference_model/panns_cnn14_static.tar.gz', - 'md5': - 'ccc80b194821274da79466862b2ab00f', - 'cfg_path': - 'panns.yaml', - 'model_path': - 'inference.pdmodel', - 'params_path': - 'inference.pdiparams', - 'label_file': - 'audioset_labels.txt', - }, -} diff --git a/paddlespeech/server/engine/tts/online/onnx/pretrained_models.py b/paddlespeech/server/engine/tts/online/onnx/pretrained_models.py deleted file mode 100644 index 789f5be7d7ca16965459fec6df7e40f7713ee104..0000000000000000000000000000000000000000 --- a/paddlespeech/server/engine/tts/online/onnx/pretrained_models.py +++ /dev/null @@ -1,69 +0,0 @@ -# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
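Also visible in the cls engine above: it constructs its CommonTaskResource without inference_mode. That is deliberate, since _inference_mode_filter() only implements asr and tts branches and raises NotImplementedError for any other task. A short sketch of the filter's effect for ASR (the printed tags depend on the tables registered earlier in this diff):

```python
from paddlespeech.resource import CommonTaskResource

# inference_mode='online' keeps only tags whose name contains 'online';
# inference_mode='offline' keeps the complement.
online_asr = CommonTaskResource(
    task='asr', model_format='dynamic', inference_mode='online')
print(list(online_asr.pretrained_models.keys()))

offline_asr = CommonTaskResource(
    task='asr', model_format='dynamic', inference_mode='offline')
print(list(offline_asr.pretrained_models.keys()))
```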
-# support online model -pretrained_models = { - # fastspeech2 - "fastspeech2_csmsc_onnx-zh": { - 'url': - 'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/fastspeech2/fastspeech2_csmsc_onnx_0.2.0.zip', - 'md5': - 'fd3ad38d83273ad51f0ea4f4abf3ab4e', - 'ckpt': ['fastspeech2_csmsc.onnx'], - 'phones_dict': - 'phone_id_map.txt', - 'sample_rate': - 24000, - }, - "fastspeech2_cnndecoder_csmsc_onnx-zh": { - 'url': - 'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/fastspeech2/fastspeech2_cnndecoder_csmsc_streaming_onnx_1.0.0.zip', - 'md5': - '5f70e1a6bcd29d72d54e7931aa86f266', - 'ckpt': [ - 'fastspeech2_csmsc_am_encoder_infer.onnx', - 'fastspeech2_csmsc_am_decoder.onnx', - 'fastspeech2_csmsc_am_postnet.onnx', - ], - 'speech_stats': - 'speech_stats.npy', - 'phones_dict': - 'phone_id_map.txt', - 'sample_rate': - 24000, - }, - - # mb_melgan - "mb_melgan_csmsc_onnx-zh": { - 'url': - 'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/mb_melgan/mb_melgan_csmsc_onnx_0.2.0.zip', - 'md5': - '5b83ec746e8414bc29032d954ffd07ec', - 'ckpt': - 'mb_melgan_csmsc.onnx', - 'sample_rate': - 24000, - }, - - # hifigan - "hifigan_csmsc_onnx-zh": { - 'url': - 'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/hifigan/hifigan_csmsc_onnx_0.2.0.zip', - 'md5': - '1a7dc0385875889e46952e50c0994a6b', - 'ckpt': - 'hifigan_csmsc.onnx', - 'sample_rate': - 24000, - }, -} diff --git a/paddlespeech/server/engine/tts/online/onnx/tts_engine.py b/paddlespeech/server/engine/tts/online/onnx/tts_engine.py index 792442065074af9168f84b1ce695bb484b01e388..6453f1ae7065e25e3e318a9ebfa2616c6eeaa81b 100644 --- a/paddlespeech/server/engine/tts/online/onnx/tts_engine.py +++ b/paddlespeech/server/engine/tts/online/onnx/tts_engine.py @@ -20,9 +20,9 @@ from typing import Optional import numpy as np import paddle -from .pretrained_models import pretrained_models from paddlespeech.cli.log import logger from paddlespeech.cli.tts.infer import TTSExecutor +from paddlespeech.resource import CommonTaskResource from paddlespeech.server.engine.base_engine import BaseEngine from paddlespeech.server.utils.audio_process import float2pcm from paddlespeech.server.utils.onnx_infer import get_sess @@ -43,7 +43,7 @@ class TTSServerExecutor(TTSExecutor): self.voc_pad = voc_pad self.voc_upsample = voc_upsample - self.pretrained_models = pretrained_models + self.task_resource = CommonTaskResource(task='tts', model_format='onnx') def _init_from_path( self, @@ -72,16 +72,21 @@ class TTSServerExecutor(TTSExecutor): return # am am_tag = am + '-' + lang + self.task_resource.set_task_model( + model_tag=am_tag, + model_type=0, # am + version=None, # default version + ) + self.am_res_path = self.task_resource.res_dir if am == "fastspeech2_csmsc_onnx": # get model info if am_ckpt is None or phones_dict is None: - am_res_path = self._get_pretrained_path(am_tag) - self.am_res_path = am_res_path self.am_ckpt = os.path.join( - am_res_path, self.pretrained_models[am_tag]['ckpt'][0]) + self.am_res_path, self.task_resource.res_dict['ckpt'][0]) # must have phones_dict in acoustic self.phones_dict = os.path.join( - am_res_path, self.pretrained_models[am_tag]['phones_dict']) + self.am_res_path, + self.task_resource.res_dict['phones_dict']) else: self.am_ckpt = os.path.abspath(am_ckpt[0]) @@ -94,19 +99,19 @@ class TTSServerExecutor(TTSExecutor): elif am == "fastspeech2_cnndecoder_csmsc_onnx": if am_ckpt is None or am_stat is None or phones_dict is None: - am_res_path = self._get_pretrained_path(am_tag) - self.am_res_path = am_res_path 
self.am_encoder_infer = os.path.join( - am_res_path, self.pretrained_models[am_tag]['ckpt'][0]) + self.am_res_path, self.task_resource.res_dict['ckpt'][0]) self.am_decoder = os.path.join( - am_res_path, self.pretrained_models[am_tag]['ckpt'][1]) + self.am_res_path, self.task_resource.res_dict['ckpt'][1]) self.am_postnet = os.path.join( - am_res_path, self.pretrained_models[am_tag]['ckpt'][2]) + self.am_res_path, self.task_resource.res_dict['ckpt'][2]) # must have phones_dict in acoustic self.phones_dict = os.path.join( - am_res_path, self.pretrained_models[am_tag]['phones_dict']) + self.am_res_path, + self.task_resource.res_dict['phones_dict']) self.am_stat = os.path.join( - am_res_path, self.pretrained_models[am_tag]['speech_stats']) + self.am_res_path, + self.task_resource.res_dict['speech_stats']) else: self.am_encoder_infer = os.path.abspath(am_ckpt[0]) @@ -131,11 +136,15 @@ class TTSServerExecutor(TTSExecutor): # voc model info voc_tag = voc + '-' + lang + self.task_resource.set_task_model( + model_tag=voc_tag, + model_type=1, # vocoder + version=None, # default version + ) if voc_ckpt is None: - voc_res_path = self._get_pretrained_path(voc_tag) - self.voc_res_path = voc_res_path + self.voc_res_path = self.task_resource.voc_res_dir self.voc_ckpt = os.path.join( - voc_res_path, self.pretrained_models[voc_tag]['ckpt']) + self.voc_res_path, self.task_resource.voc_res_dict['ckpt']) else: self.voc_ckpt = os.path.abspath(voc_ckpt) self.voc_res_path = os.path.dirname(os.path.abspath(self.voc_ckpt)) diff --git a/paddlespeech/server/engine/tts/online/python/pretrained_models.py b/paddlespeech/server/engine/tts/online/python/pretrained_models.py deleted file mode 100644 index bf6aded51168c2c21172ec8101413b4cb0e05154..0000000000000000000000000000000000000000 --- a/paddlespeech/server/engine/tts/online/python/pretrained_models.py +++ /dev/null @@ -1,73 +0,0 @@ -# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
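The TTS engines above exercise the one task with two resource slots, selected by the model_type argument of set_task_model(): 0 fills res_dir/res_dict for the acoustic model, while any other value fills voc_res_dir/voc_res_dict for the vocoder. A sketch against the ONNX tables registered earlier in this diff (both archives would be downloaded on first call):

```python
from paddlespeech.resource import CommonTaskResource

tts_res = CommonTaskResource(task='tts', model_format='onnx')
tts_res.set_task_model(model_tag='fastspeech2_csmsc_onnx-zh', model_type=0)  # am
tts_res.set_task_model(model_tag='hifigan_csmsc_onnx-zh', model_type=1)      # vocoder

print(tts_res.res_dict['ckpt'])      # ['fastspeech2_csmsc.onnx']
print(tts_res.voc_res_dict['ckpt'])  # 'hifigan_csmsc.onnx'
```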
-# support online model -pretrained_models = { - # fastspeech2 - "fastspeech2_csmsc-zh": { - 'url': - 'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/fastspeech2/fastspeech2_nosil_baker_ckpt_0.4.zip', - 'md5': - '637d28a5e53aa60275612ba4393d5f22', - 'config': - 'default.yaml', - 'ckpt': - 'snapshot_iter_76000.pdz', - 'speech_stats': - 'speech_stats.npy', - 'phones_dict': - 'phone_id_map.txt', - }, - "fastspeech2_cnndecoder_csmsc-zh": { - 'url': - 'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/fastspeech2/fastspeech2_cnndecoder_csmsc_ckpt_1.0.0.zip', - 'md5': - '6eb28e22ace73e0ebe7845f86478f89f', - 'config': - 'cnndecoder.yaml', - 'ckpt': - 'snapshot_iter_153000.pdz', - 'speech_stats': - 'speech_stats.npy', - 'phones_dict': - 'phone_id_map.txt', - }, - - # mb_melgan - "mb_melgan_csmsc-zh": { - 'url': - 'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/mb_melgan/mb_melgan_csmsc_ckpt_0.1.1.zip', - 'md5': - 'ee5f0604e20091f0d495b6ec4618b90d', - 'config': - 'default.yaml', - 'ckpt': - 'snapshot_iter_1000000.pdz', - 'speech_stats': - 'feats_stats.npy', - }, - - # hifigan - "hifigan_csmsc-zh": { - 'url': - 'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/hifigan/hifigan_csmsc_ckpt_0.1.1.zip', - 'md5': - 'dd40a3d88dfcf64513fba2f0f961ada6', - 'config': - 'default.yaml', - 'ckpt': - 'snapshot_iter_2500000.pdz', - 'speech_stats': - 'feats_stats.npy', - }, -} diff --git a/paddlespeech/server/engine/tts/online/python/tts_engine.py b/paddlespeech/server/engine/tts/online/python/tts_engine.py index 8dc36f8ef8f6d0d2316e59e8090f43aa2702f8e2..2c08521de3adb68b6e8824dd485536e6ebed8d4b 100644 --- a/paddlespeech/server/engine/tts/online/python/tts_engine.py +++ b/paddlespeech/server/engine/tts/online/python/tts_engine.py @@ -22,9 +22,9 @@ import paddle import yaml from yacs.config import CfgNode -from .pretrained_models import pretrained_models from paddlespeech.cli.log import logger from paddlespeech.cli.tts.infer import TTSExecutor +from paddlespeech.resource import CommonTaskResource from paddlespeech.server.engine.base_engine import BaseEngine from paddlespeech.server.utils.audio_process import float2pcm from paddlespeech.server.utils.util import denorm @@ -32,7 +32,6 @@ from paddlespeech.server.utils.util import get_chunks from paddlespeech.t2s.frontend import English from paddlespeech.t2s.frontend.zh_frontend import Frontend from paddlespeech.t2s.modules.normalizer import ZScore -from paddlespeech.utils.dynamic_import import dynamic_import __all__ = ['TTSEngine'] @@ -44,7 +43,8 @@ class TTSServerExecutor(TTSExecutor): self.am_pad = am_pad self.voc_block = voc_block self.voc_pad = voc_pad - self.pretrained_models = pretrained_models + self.task_resource = CommonTaskResource( + task='tts', model_format='dynamic', inference_mode='online') def get_model_info(self, field: str, @@ -65,7 +65,7 @@ [Tensor]: standard deviation """ - model_class = dynamic_import(model_name, self.model_alias) + model_class = self.task_resource.get_model_class(model_name) if field == "am": odim = self.am_config.n_mels @@ -110,20 +110,24 @@ return # am model info am_tag = am + '-' + lang + self.task_resource.set_task_model( + model_tag=am_tag, + model_type=0,  # am + version=None,  # default version + ) if am_ckpt is None or am_config is None or am_stat is None or phones_dict is None: - am_res_path = self._get_pretrained_path(am_tag) - self.am_res_path = am_res_path - self.am_config = os.path.join( - am_res_path,
self.pretrained_models[am_tag]['config']) - self.am_ckpt = os.path.join(am_res_path, - self.pretrained_models[am_tag]['ckpt']) + self.am_res_path = self.task_resource.res_dir + self.am_config = os.path.join(self.am_res_path, + self.task_resource.res_dict['config']) + self.am_ckpt = os.path.join(self.am_res_path, + self.task_resource.res_dict['ckpt']) self.am_stat = os.path.join( - am_res_path, self.pretrained_models[am_tag]['speech_stats']) + self.am_res_path, self.task_resource.res_dict['speech_stats']) # must have phones_dict in acoustic self.phones_dict = os.path.join( - am_res_path, self.pretrained_models[am_tag]['phones_dict']) + self.am_res_path, self.task_resource.res_dict['phones_dict']) print("self.phones_dict:", self.phones_dict) - logger.info(am_res_path) + logger.info(self.am_res_path) logger.info(self.am_config) logger.info(self.am_ckpt) else: @@ -139,16 +143,21 @@ class TTSServerExecutor(TTSExecutor): # voc model info voc_tag = voc + '-' + lang + self.task_resource.set_task_model( + model_tag=voc_tag, + model_type=1, # vocoder + version=None, # default version + ) if voc_ckpt is None or voc_config is None or voc_stat is None: - voc_res_path = self._get_pretrained_path(voc_tag) - self.voc_res_path = voc_res_path + self.voc_res_path = self.task_resource.voc_res_dir self.voc_config = os.path.join( - voc_res_path, self.pretrained_models[voc_tag]['config']) + self.voc_res_path, self.task_resource.voc_res_dict['config']) self.voc_ckpt = os.path.join( - voc_res_path, self.pretrained_models[voc_tag]['ckpt']) + self.voc_res_path, self.task_resource.voc_res_dict['ckpt']) self.voc_stat = os.path.join( - voc_res_path, self.pretrained_models[voc_tag]['speech_stats']) - logger.info(voc_res_path) + self.voc_res_path, + self.task_resource.voc_res_dict['speech_stats']) + logger.info(self.voc_res_path) logger.info(self.voc_config) logger.info(self.voc_ckpt) else: @@ -188,8 +197,8 @@ class TTSServerExecutor(TTSExecutor): am, am_mu, am_std = self.get_model_info("am", self.am_name, self.am_ckpt, self.am_stat) am_normalizer = ZScore(am_mu, am_std) - am_inference_class = dynamic_import(self.am_name + '_inference', - self.model_alias) + am_inference_class = self.task_resource.get_model_class( + self.am_name + '_inference') self.am_inference = am_inference_class(am_normalizer, am) self.am_inference.eval() print("acoustic model done!") @@ -199,8 +208,8 @@ class TTSServerExecutor(TTSExecutor): voc, voc_mu, voc_std = self.get_model_info("voc", self.voc_name, self.voc_ckpt, self.voc_stat) voc_normalizer = ZScore(voc_mu, voc_std) - voc_inference_class = dynamic_import(self.voc_name + '_inference', - self.model_alias) + voc_inference_class = self.task_resource.get_model_class(self.voc_name + + '_inference') self.voc_inference = voc_inference_class(voc_normalizer, voc) self.voc_inference.eval() print("voc done!") @@ -505,4 +514,4 @@ class TTSEngine(BaseEngine): logger.info(f"RTF: {self.executor.final_response_time / duration}") logger.info( f"Other info: front time: {self.executor.frontend_time} s, first am infer time: {self.executor.first_am_infer} s, first voc infer time: {self.executor.first_voc_infer} s," - ) \ No newline at end of file + ) diff --git a/paddlespeech/server/engine/tts/paddleinference/pretrained_models.py b/paddlespeech/server/engine/tts/paddleinference/pretrained_models.py deleted file mode 100644 index 9618a7a697765f532a172c551b6be733a68a1bec..0000000000000000000000000000000000000000 --- a/paddlespeech/server/engine/tts/paddleinference/pretrained_models.py +++ /dev/null @@ -1,87 +0,0 @@ 
-# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# Static model applied on paddle inference -pretrained_models = { - # speedyspeech - "speedyspeech_csmsc-zh": { - 'url': - 'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/speedyspeech/speedyspeech_nosil_baker_static_0.5.zip', - 'md5': - 'f10cbdedf47dc7a9668d2264494e1823', - 'model': - 'speedyspeech_csmsc.pdmodel', - 'params': - 'speedyspeech_csmsc.pdiparams', - 'phones_dict': - 'phone_id_map.txt', - 'tones_dict': - 'tone_id_map.txt', - 'sample_rate': - 24000, - }, - # fastspeech2 - "fastspeech2_csmsc-zh": { - 'url': - 'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/fastspeech2/fastspeech2_nosil_baker_static_0.4.zip', - 'md5': - '9788cd9745e14c7a5d12d32670b2a5a7', - 'model': - 'fastspeech2_csmsc.pdmodel', - 'params': - 'fastspeech2_csmsc.pdiparams', - 'phones_dict': - 'phone_id_map.txt', - 'sample_rate': - 24000, - }, - # pwgan - "pwgan_csmsc-zh": { - 'url': - 'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/pwgan/pwg_baker_static_0.4.zip', - 'md5': - 'e3504aed9c5a290be12d1347836d2742', - 'model': - 'pwgan_csmsc.pdmodel', - 'params': - 'pwgan_csmsc.pdiparams', - 'sample_rate': - 24000, - }, - # mb_melgan - "mb_melgan_csmsc-zh": { - 'url': - 'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/mb_melgan/mb_melgan_csmsc_static_0.1.1.zip', - 'md5': - 'ac6eee94ba483421d750433f4c3b8d36', - 'model': - 'mb_melgan_csmsc.pdmodel', - 'params': - 'mb_melgan_csmsc.pdiparams', - 'sample_rate': - 24000, - }, - # hifigan - "hifigan_csmsc-zh": { - 'url': - 'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/hifigan/hifigan_csmsc_static_0.1.1.zip', - 'md5': - '7edd8c436b3a5546b3a7cb8cff9d5a0c', - 'model': - 'hifigan_csmsc.pdmodel', - 'params': - 'hifigan_csmsc.pdiparams', - 'sample_rate': - 24000, - }, -} diff --git a/paddlespeech/server/engine/tts/paddleinference/tts_engine.py b/paddlespeech/server/engine/tts/paddleinference/tts_engine.py index f1ce8b76e2eacd378ccb8657486716ffb5ad4036..44e564983c942d0bebd4dd94a3e6e7accec5e69d 100644 --- a/paddlespeech/server/engine/tts/paddleinference/tts_engine.py +++ b/paddlespeech/server/engine/tts/paddleinference/tts_engine.py @@ -23,9 +23,9 @@ import paddle import soundfile as sf from scipy.io import wavfile -from .pretrained_models import pretrained_models from paddlespeech.cli.log import logger from paddlespeech.cli.tts.infer import TTSExecutor +from paddlespeech.resource import CommonTaskResource from paddlespeech.server.engine.base_engine import BaseEngine from paddlespeech.server.utils.audio_process import change_speed from paddlespeech.server.utils.errors import ErrorCode @@ -41,7 +41,8 @@ __all__ = ['TTSEngine'] class TTSServerExecutor(TTSExecutor): def __init__(self): super().__init__() - self.pretrained_models = pretrained_models + self.task_resource = CommonTaskResource( + task='tts', model_format='static') def _init_from_path( self, @@ -67,19 +68,23 @@ class TTSServerExecutor(TTSExecutor): 
return # am am_tag = am + '-' + lang + self.task_resource.set_task_model( + model_tag=am_tag, + model_type=0, # am + version=None, # default version + ) if am_model is None or am_params is None or phones_dict is None: - am_res_path = self._get_pretrained_path(am_tag) - self.am_res_path = am_res_path - self.am_model = os.path.join( - am_res_path, self.pretrained_models[am_tag]['model']) - self.am_params = os.path.join( - am_res_path, self.pretrained_models[am_tag]['params']) + self.am_res_path = self.task_resource.res_dir + self.am_model = os.path.join(self.am_res_path, + self.task_resource.res_dict['model']) + self.am_params = os.path.join(self.am_res_path, + self.task_resource.res_dict['params']) # must have phones_dict in acoustic self.phones_dict = os.path.join( - am_res_path, self.pretrained_models[am_tag]['phones_dict']) - self.am_sample_rate = self.pretrained_models[am_tag]['sample_rate'] + self.am_res_path, self.task_resource.res_dict['phones_dict']) + self.am_sample_rate = self.task_resource.res_dict['sample_rate'] - logger.info(am_res_path) + logger.info(self.am_res_path) logger.info(self.am_model) logger.info(self.am_params) else: @@ -92,32 +97,36 @@ class TTSServerExecutor(TTSExecutor): # for speedyspeech self.tones_dict = None - if 'tones_dict' in self.pretrained_models[am_tag]: + if 'tones_dict' in self.task_resource.res_dict: self.tones_dict = os.path.join( - am_res_path, self.pretrained_models[am_tag]['tones_dict']) + self.am_res_path, self.task_resource.res_dict['tones_dict']) if tones_dict: self.tones_dict = tones_dict # for multi speaker fastspeech2 self.speaker_dict = None - if 'speaker_dict' in self.pretrained_models[am_tag]: + if 'speaker_dict' in self.task_resource.res_dict: self.speaker_dict = os.path.join( - am_res_path, self.pretrained_models[am_tag]['speaker_dict']) + self.am_res_path, self.task_resource.res_dict['speaker_dict']) if speaker_dict: self.speaker_dict = speaker_dict # voc voc_tag = voc + '-' + lang + self.task_resource.set_task_model( + model_tag=voc_tag, + model_type=1, # vocoder + version=None, # default version + ) if voc_model is None or voc_params is None: - voc_res_path = self._get_pretrained_path(voc_tag) - self.voc_res_path = voc_res_path + self.voc_res_path = self.task_resource.voc_res_dir self.voc_model = os.path.join( - voc_res_path, self.pretrained_models[voc_tag]['model']) + self.voc_res_path, self.task_resource.voc_res_dict['model']) self.voc_params = os.path.join( - voc_res_path, self.pretrained_models[voc_tag]['params']) - self.voc_sample_rate = self.pretrained_models[voc_tag][ + self.voc_res_path, self.task_resource.voc_res_dict['params']) + self.voc_sample_rate = self.task_resource.voc_res_dict[ 'sample_rate'] - logger.info(voc_res_path) + logger.info(self.voc_res_path) logger.info(self.voc_model) logger.info(self.voc_params) else:
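Finally, the ASR engines in this diff all assemble the lookup tag inline as model_type + '-' + lang + '-' + sample_rate_str, following the "{model_name}[_{dataset}]-{lang}-{sr}" convention documented at the top of the new pretrained_models.py. A hypothetical helper (not part of the diff) that makes the convention explicit:

```python
def build_asr_tag(model_type: str, lang: str, sample_rate: int) -> str:
    """Mirror the inline tag construction used by the ASR executors."""
    sample_rate_str = '16k' if sample_rate == 16000 else '8k'
    return model_type + '-' + lang + '-' + sample_rate_str


assert build_asr_tag('conformer_wenetspeech', 'zh', 16000) == 'conformer_wenetspeech-zh-16k'
assert build_asr_tag('deepspeech2offline_aishell', 'zh', 16000) == 'deepspeech2offline_aishell-zh-16k'
```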