提交 2993ff58 编写于 作者: W WilliamZhang06

added engine factory and config, test=doc

...@@ -11,33 +11,3 @@ ...@@ -11,33 +11,3 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import argparse
def init(args):
""" 系统初始化
"""
def main(args):
"""主程序入口"""
if init(args):
app.run(host='0.0.0.0', port=conf.port)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--config_file",
action="store",
help="yaml file of the app",
default="./conf/application.yaml")
parser.add_argument(
"--log_file",
action="store",
help="log file",
default="./log/paddlespeech.log")
args = parser.parse_args()
main(args)
...@@ -11,33 +11,3 @@ ...@@ -11,33 +11,3 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import argparse
def init(args):
""" 系统初始化
"""
def main(args):
"""主程序入口"""
if init(args):
app.run(host='0.0.0.0', port=conf.port)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--config_file",
action="store",
help="yaml file of the app",
default="./conf/application.yaml")
parser.add_argument(
"--log_file",
action="store",
help="log file",
default="./log/paddlespeech.log")
args = parser.parse_args()
main(args)
...@@ -11,33 +11,3 @@ ...@@ -11,33 +11,3 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import argparse
def init(args):
""" 系统初始化
"""
def main(args):
"""主程序入口"""
if init(args):
app.run(host='0.0.0.0', port=conf.port)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--config_file",
action="store",
help="yaml file of the app",
default="./conf/application.yaml")
parser.add_argument(
"--log_file",
action="store",
help="log file",
default="./log/paddlespeech.log")
args = parser.parse_args()
main(args)
...@@ -6,3 +6,9 @@ ...@@ -6,3 +6,9 @@
host: '0.0.0.0' host: '0.0.0.0'
port: 8090 port: 8090
##################################################################
# CONFIG FILE #
##################################################################
# add engine type (Options: asr, tts) and config file here.
engine_backend:
asr: 'conf/asr/asr.yaml'
\ No newline at end of file
model: 'conformer_wenetspeech'
lang: 'conformer_wenetspeech'
lang: 'zh'
sample_rate: 16000
decode_method: 'attention_rescoring'
...@@ -14,18 +14,21 @@ ...@@ -14,18 +14,21 @@
from engine.base_engine import BaseEngine from engine.base_engine import BaseEngine
from utils.log import logger from utils.log import logger
from utils.config import get_config
__all__ = ['ASREngine'] __all__ = ['ASREngine']
class ASREngine(BaseEngine): class ASREngine(BaseEngine):
def __init__(self, name=None): def __init__(self):
super(ASREngine, self).__init__() super(ASREngine, self).__init__()
self.executor = name
def init(self, config_file: str):
self.config_file = config_file
self.executor = None
self.input = None self.input = None
self.output = None self.output = None
config = get_config(self.config_file)
def init(self):
pass pass
def postprocess(self): def postprocess(self):
...@@ -34,12 +37,3 @@ class ASREngine(BaseEngine): ...@@ -34,12 +37,3 @@ class ASREngine(BaseEngine):
def run(self): def run(self):
logger.info("start run asr engine") logger.info("start run asr engine")
return "hello world" return "hello world"
if __name__ == "__main__":
# test Singleton
class1 = ASREngine("ASREngine")
class2 = ASREngine()
print(class1 is class2)
print(id(class1))
print(id(class2))
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from engine.asr.python.asr_engine import ASREngine
from engine.tts.python.tts_engine import TTSEngine
class EngineFactory(object):
@staticmethod
def get_engine(engine_name):
if engine_name == 'asr':
return ASREngine()
elif engine_name == 'tts':
return TTSEngine()
else:
return None
...@@ -12,43 +12,42 @@ ...@@ -12,43 +12,42 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import argparse import argparse
import uvicorn import uvicorn
import yaml import yaml
from engine.tts.python.tts_engine import TTSEngine
from fastapi import FastAPI from fastapi import FastAPI
from restful.api import router as api_router
from paddlespeech.cli.log import logger from restful.api import setup_router
from utils.log import logger
from utils.config import get_config
from engine.engine_factory import EngineFactory
app = FastAPI( app = FastAPI(
title="PaddleSpeech Serving API", description="Api", version="0.0.1") title="PaddleSpeech Serving API", description="Api", version="0.0.1")
def init(args): def init(config):
""" 系统初始化 """ system initialization
""" """
# init api
api_list = list(config.engine_backend)
api_router = setup_router(api_list)
app.include_router(api_router) app.include_router(api_router)
# engine single # init engine
engine_list = []
TTS_ENGINE = TTSEngine() for engine in config.engine_backend:
engine_list.append(EngineFactory.get_engine(engine_name=engine))
# todo others engine_list[-1].init(config_file=config.engine_backend[engine])
return True return True
def main(args): def main(args):
"""主程序入口""" """main function"""
#TODO configuration config = get_config(args.config_file)
from yacs.config import CfgNode
with open(args.config_file, 'rt') as f:
config = CfgNode(yaml.safe_load(f))
if init(args): if init(config):
uvicorn.run(app, host=config.host, port=config.port, debug=True) uvicorn.run(app, host=config.host, port=config.port, debug=True)
...@@ -58,7 +57,7 @@ if __name__ == "__main__": ...@@ -58,7 +57,7 @@ if __name__ == "__main__":
"--config_file", "--config_file",
action="store", action="store",
help="yaml file of the app", help="yaml file of the app",
default="./conf/tts/tts.yaml") default="./conf/application.yaml")
parser.add_argument( parser.add_argument(
"--log_file", "--log_file",
......
...@@ -14,8 +14,19 @@ ...@@ -14,8 +14,19 @@
from fastapi import APIRouter from fastapi import APIRouter
from .tts_api import router as tts_router from .tts_api import router as tts_router
#from .asr_api import router as asr_router from .asr_api import router as asr_router
_router = APIRouter()
def setup_router(api_list: list):
for api_name in api_list:
if api_name == 'asr':
_router.include_router(asr_router)
elif api_name == 'tts':
_router.include_router(tts_router)
else:
pass
return _router
router = APIRouter()
#router.include_router(asr_router)
router.include_router(tts_router)
...@@ -14,13 +14,12 @@ ...@@ -14,13 +14,12 @@
from fastapi import APIRouter from fastapi import APIRouter
import base64 import base64
from engine.asr.python.asr_engine import ASREngine from engine.asr.python.asr_engine import ASREngine
from .response import ASRResponse from .response import ASRResponse
from .request import ASRRequest from .request import ASRRequest
router = APIRouter()
router = APIRouter()
@router.get('/paddlespeech/asr/help') @router.get('/paddlespeech/asr/help')
def help(): def help():
...@@ -44,8 +43,8 @@ def asr(request_body: ASRRequest): ...@@ -44,8 +43,8 @@ def asr(request_body: ASRRequest):
""" """
# single # single
asr_engine = ASREngine() asr_engine = ASREngine()
print("asr_engine id :" ,id(asr_engine))
asr_engine.init()
asr_results = asr_engine.run() asr_results = asr_engine.run()
asr_engine.postprocess() asr_engine.postprocess()
......
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import yaml
from yacs.config import CfgNode
def get_config(config_file):
"""[summary]
Args:
config_file (str): config_file
Returns:
CfgNode:
"""
with open(config_file, 'rt') as f:
config = CfgNode(yaml.safe_load(f))
return config
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the
import base64
def readwav2base64(wav_file):
"""
read wave file and covert to base64 string
"""
with open(wav_file, 'rb') as f:
base64_bytes = base64.b64encode(f.read())
base64_string = base64_bytes.decode('utf-8')
return base64_string
def readbase64towav(base64_string):
pass
def self_check():
""" self check resource
"""
return True
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册