paddlespeech_server.py 6.4 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
L
lym0302 已提交
15
import sys
L
lym0302 已提交
16
import warnings
17 18 19 20
from typing import List

import uvicorn
from fastapi import FastAPI
21
from prettytable import PrettyTable
L
lym0302 已提交
22
from starlette.middleware.cors import CORSMiddleware
23

L
lym0302 已提交
24
from ..executor import BaseExecutor
25
from ..util import cli_server_register
L
lym0302 已提交
26
from ..util import stats_wrapper
27
from paddlespeech.cli.log import logger
K
KP 已提交
28
from paddlespeech.resource import CommonTaskResource
L
lym0302 已提交
29
from paddlespeech.server.engine.engine_pool import init_engine_pool
L
lym0302 已提交
30
from paddlespeech.server.engine.engine_warmup import warm_up
L
lym0302 已提交
31
from paddlespeech.server.restful.api import setup_router as setup_http_router
32
from paddlespeech.server.utils.config import get_config
L
lym0302 已提交
33
from paddlespeech.server.ws.api import setup_router as setup_ws_router
L
lym0302 已提交
34
__all__ = ['ServerExecutor', 'ServerStatsExecutor']

# Suppress every Python warning for the whole process (side effect of importing
# this module).
warnings.filterwarnings("ignore")

# The FastAPI application served by ServerExecutor; API routers are attached to
# it in ServerExecutor.init() before uvicorn.run() serves it.
app = FastAPI(
    title="PaddleSpeech Serving API",
    description="Api",
    version="0.0.1", )

# Cross-origin settings: every origin, method and header is allowed, and
# credentials are permitted.
# NOTE(review): wildcard origins together with allow_credentials=True is very
# permissive; consider restricting allow_origins for production deployments.
_CORS_SETTINGS = {
    "allow_origins": ["*"],
    "allow_credentials": True,
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **_CORS_SETTINGS)
@cli_server_register(
    name='paddlespeech_server.start', description='Start the service')
class ServerExecutor(BaseExecutor):
    """Executor behind ``paddlespeech_server.start``.

    Reads a yaml config and attaches the protocol-specific API router to the
    module-level ``app``. It then initializes and warms up the configured
    engines and serves ``app`` with uvicorn.
    """

    def __init__(self):
        super(ServerExecutor, self).__init__()
        self.parser = argparse.ArgumentParser(
            prog='paddlespeech_server.start', add_help=True)
        self.parser.add_argument(
            "--config_file",
            action="store",
            help="yaml file of the app",
            default=None,
            required=True)

        # NOTE(review): forwarded to __call__, which does not currently use it.
        self.parser.add_argument(
            "--log_file",
            action="store",
            help="log file",
            default="./log/paddlespeech.log")

    def init(self, config) -> bool:
        """System initialization: API routing, engine pool and warm-up.

        Args:
            config (CfgNode): parsed server config. ``engine_list`` and
                ``protocol`` are read here, and the whole object is passed
                to ``init_engine_pool``.

        Returns:
            bool: True if the engine pool was initialized and every engine
            warmed up successfully; False on the first failure.

        Raises:
            Exception: if ``config.protocol`` is neither "websocket" nor
                "http".
        """
        # Each engine entry is split on "_"; only the leading part (the API
        # name) is needed to build the router.
        api_list = [engine.split("_")[0] for engine in config.engine_list]
        if config.protocol == "websocket":
            api_router = setup_ws_router(api_list)
        elif config.protocol == "http":
            api_router = setup_http_router(api_list)
        else:
            raise Exception("unsupported protocol")
        app.include_router(api_router)

        logger.info("start to init the engine")
        if not init_engine_pool(config):
            return False

        # Warm up every configured engine; abort on the first failure.
        for engine_and_type in config.engine_list:
            if not warm_up(engine_and_type):
                return False

        return True

    def execute(self, argv: List[str]) -> bool:
        """Command line entry.

        Args:
            argv: command line arguments for this sub-command.

        Returns:
            bool: True once the server has been started and uvicorn returned;
            False if initialization failed. If an exception escapes while
            starting the server, it is logged and the process exits with
            status -1.
        """
        args = self.parser.parse_args(argv)
        try:
            started = self(args.config_file, args.log_file)
        except Exception as e:
            logger.error("Failed to start server.")
            logger.error(e)
            sys.exit(-1)
        # Honour the declared ``-> bool`` contract, as
        # ServerStatsExecutor.execute does.
        return bool(started)

    @stats_wrapper
    def __call__(self,
                 config_file: str="./conf/application.yaml",
                 log_file: str="./log/paddlespeech.log") -> bool:
        """
        Python API to call an executor.

        Args:
            config_file: path of the yaml config loaded with ``get_config``.
            log_file: accepted for CLI compatibility; currently unused here.

        Returns:
            bool: True after ``uvicorn.run`` returns; False if ``init`` failed
            and the server was therefore not started.
        """
        config = get_config(config_file)
        if not self.init(config):
            # Previously this path returned silently, hiding the failure.
            logger.error("Failed to init the server, it is not started.")
            return False
        uvicorn.run(app, host=config.host, port=config.port)
        return True


@cli_server_register(
    name='paddlespeech_server.stats',
    description='Get the models supported by each speech task in the service.')
class ServerStatsExecutor():
    """Executor behind ``paddlespeech_server.stats``.

    For the chosen speech task it prints one table of dynamic pretrained
    models and one of static pretrained models. Each table is printed only
    if that list is non-empty.
    """

    def __init__(self):
        super(ServerStatsExecutor, self).__init__()

        # Single source of truth for the supported tasks; argparse's
        # `choices` and the check in execute() both use it.
        self.task_choices = ['asr', 'tts', 'cls', 'text', 'vector']

        self.parser = argparse.ArgumentParser(
            prog='paddlespeech_server.stats', add_help=True)
        self.parser.add_argument(
            '--task',
            type=str,
            default=None,
            choices=self.task_choices,
            help='Choose speech task.',
            required=True)

        # Table column headers per task, joined by "-". Model keys are split
        # on "-" into cells; presumably they have the same number of parts
        # as the header (not checked here).
        self.model_name_format = {
            'asr': 'Model-Language-Sample Rate',
            'tts': 'Model-Language',
            'cls': 'Model-Sample Rate',
            'text': 'Model-Task-Language',
            'vector': 'Model-Sample Rate'
        }

    def show_support_models(self, pretrained_models: dict):
        """Print the keys of ``pretrained_models`` as a table.

        The header comes from ``self.model_name_format[self.task]``. Each key
        is split on "-" into one row. Requires ``self.task`` to have been set
        (``execute`` does this).
        """
        fields = self.model_name_format[self.task].split("-")
        table = PrettyTable(fields)
        for key in pretrained_models:
            table.add_row(key.split("-"))
        print(table)

    def execute(self, argv: List[str]) -> bool:
        """
            Command line entry.

            Returns True if the model tables were produced, False if the task
            is invalid or fetching the model lists failed.
        """
        parser_args = self.parser.parse_args(argv)
        self.task = parser_args.task
        # argparse already enforces `choices`; kept as a defensive check.
        if self.task not in self.task_choices:
            logger.error("Please input correct speech task, choices = {}".
                         format(self.task_choices))
            return False

        try:
            # Dynamic models
            dynamic_pretrained_models = CommonTaskResource(
                task=self.task, model_format='dynamic').pretrained_models
            if len(dynamic_pretrained_models) > 0:
                logger.info(
                    "Here is the table of {} pretrained models supported in the service.".
                    format(self.task.upper()))
                self.show_support_models(dynamic_pretrained_models)

            # Static models
            static_pretrained_models = CommonTaskResource(
                task=self.task, model_format='static').pretrained_models
            if len(static_pretrained_models) > 0:
                logger.info(
                    "Here is the table of {} static pretrained models supported in the service.".
                    format(self.task.upper()))
                self.show_support_models(static_pretrained_models)

            return True

        # Catch Exception (not BaseException) so Ctrl-C / SystemExit still
        # propagate, and log the cause instead of discarding it.
        except Exception as e:
            logger.error(
                "Failed to get the table of {} pretrained models supported in the service.".
                format(self.task.upper()))
            logger.error(e)
            return False