diff --git a/paddlespeech/cli/stats/infer.py b/paddlespeech/cli/stats/infer.py index d60a6691e51dce4baa2170a24f4537e81ba2798d..4ef50449c37e08c1a3c5f9b8894a5b4141e1c33f 100644 --- a/paddlespeech/cli/stats/infer.py +++ b/paddlespeech/cli/stats/infer.py @@ -68,7 +68,7 @@ class StatsExecutor(): ) return False - if self.task == 'asr': + elif self.task == 'asr': try: from ..asr.infer import pretrained_models logger.info( diff --git a/paddlespeech/server/bin/__init__.py b/paddlespeech/server/bin/__init__.py index bd75747f79948ea42229b8c164174dbe4240d4b1..025aab098f2b6d56ced56d499ce619feb190ab2d 100644 --- a/paddlespeech/server/bin/__init__.py +++ b/paddlespeech/server/bin/__init__.py @@ -14,3 +14,4 @@ from .paddlespeech_client import ASRClientExecutor from .paddlespeech_client import TTSClientExecutor from .paddlespeech_server import ServerExecutor +from .paddlespeech_server import ServerStatsExecutor diff --git a/paddlespeech/server/bin/paddlespeech_server.py b/paddlespeech/server/bin/paddlespeech_server.py index aff77d54436eac55fda46c8e2ed218cc115a0085..21fc5c65e965a87c483046d66e45036d1b091b5d 100644 --- a/paddlespeech/server/bin/paddlespeech_server.py +++ b/paddlespeech/server/bin/paddlespeech_server.py @@ -16,15 +16,17 @@ from typing import List import uvicorn from fastapi import FastAPI +from prettytable import PrettyTable from ..executor import BaseExecutor from ..util import cli_server_register from ..util import stats_wrapper +from paddlespeech.cli.log import logger from paddlespeech.server.engine.engine_pool import init_engine_pool from paddlespeech.server.restful.api import setup_router from paddlespeech.server.utils.config import get_config -__all__ = ['ServerExecutor'] +__all__ = ['ServerExecutor', 'ServerStatsExecutor'] app = FastAPI( title="PaddleSpeech Serving API", description="Api", version="0.0.1") @@ -86,3 +88,139 @@ class ServerExecutor(BaseExecutor): config = get_config(config_file) if self.init(config): uvicorn.run(app, host=config.host, port=config.port, debug=True) + + +@cli_server_register( + name='paddlespeech_server.stats', + description='Get the models supported by each speech task in the service.') +class ServerStatsExecutor(): + def __init__(self): + super(ServerStatsExecutor, self).__init__() + + self.parser = argparse.ArgumentParser( + prog='paddlespeech_server.stats', add_help=True) + self.parser.add_argument( + '--task', + type=str, + default=None, + choices=['asr', 'tts'], + help='Choose speech task.', + required=True) + self.task_choices = ['asr', 'tts'] + self.model_name_format = { + 'asr': 'Model-Language-Sample Rate', + 'tts': 'Model-Language' + } + + def show_support_models(self, pretrained_models: dict): + fields = self.model_name_format[self.task].split("-") + table = PrettyTable(fields) + for key in pretrained_models: + table.add_row(key.split("-")) + print(table) + + def execute(self, argv: List[str]) -> bool: + """ + Command line entry. + """ + parser_args = self.parser.parse_args(argv) + self.task = parser_args.task + if self.task not in self.task_choices: + logger.error( + "Please input correct speech task, choices = ['asr', 'tts']") + return False + + elif self.task == 'asr': + try: + from paddlespeech.cli.asr.infer import pretrained_models + logger.info( + "Here is the table of ASR pretrained models supported in the service." + ) + self.show_support_models(pretrained_models) + + # show ASR static pretrained model + from paddlespeech.server.engine.asr.paddleinference.asr_engine import pretrained_models + logger.info( + "Here is the table of ASR static pretrained models supported in the service." + ) + self.show_support_models(pretrained_models) + + return True + except BaseException: + logger.error( + "Failed to get the table of ASR pretrained models supported in the service." + ) + return False + + elif self.task == 'tts': + try: + from paddlespeech.cli.tts.infer import pretrained_models + logger.info( + "Here is the table of TTS pretrained models supported in the service." + ) + self.show_support_models(pretrained_models) + + # show TTS static pretrained model + from paddlespeech.server.engine.tts.paddleinference.tts_engine import pretrained_models + logger.info( + "Here is the table of TTS static pretrained models supported in the service." + ) + self.show_support_models(pretrained_models) + + return True + except BaseException: + logger.error( + "Failed to get the table of TTS pretrained models supported in the service." + ) + return False + + @stats_wrapper + def __call__( + self, + task: str=None, ): + """ + Python API to call an executor. + """ + self.task = task + if self.task not in self.task_choices: + print("Please input correct speech task, choices = ['asr', 'tts']") + + elif self.task == 'asr': + try: + from paddlespeech.cli.asr.infer import pretrained_models + print( + "Here is the table of ASR pretrained models supported in the service." + ) + self.show_support_models(pretrained_models) + + # show ASR static pretrained model + from paddlespeech.server.engine.asr.paddleinference.asr_engine import pretrained_models + print( + "Here is the table of ASR static pretrained models supported in the service." + ) + self.show_support_models(pretrained_models) + + except BaseException: + print( + "Failed to get the table of ASR pretrained models supported in the service." + ) + + elif self.task == 'tts': + try: + from paddlespeech.cli.tts.infer import pretrained_models + print( + "Here is the table of TTS pretrained models supported in the service." + ) + self.show_support_models(pretrained_models) + + # show TTS static pretrained model + from paddlespeech.server.engine.tts.paddleinference.tts_engine import pretrained_models + print( + "Here is the table of TTS static pretrained models supported in the service." + ) + self.show_support_models(pretrained_models) + + except BaseException: + print( + "Failed to get the table of TTS pretrained models supported in the service." + )