提交 7be6b0e8 编写于 作者: H Hui Zhang

unify name style & frame with abs timestamp

上级 15b25199
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import uvicorn
from fastapi import FastAPI
from paddlespeech.server.engine.engine_pool import init_engine_pool
from paddlespeech.server.restful.api import setup_router as setup_http_router
from paddlespeech.server.utils.config import get_config
from paddlespeech.server.ws.api import setup_router as setup_ws_router
# Module-level FastAPI application object; init() attaches the
# protocol-specific router to it and main() hands it to uvicorn.
app = FastAPI(
title="PaddleSpeech Serving API", description="Api", version="0.0.1")
def init(config):
    """Set up the API routes on the global app and initialize the engines.

    Args:
        config (CfgNode): config object; ``engine_list`` and ``protocol``
            are read here and the whole object is passed on to
            ``init_engine_pool``.

    Returns:
        bool: False when ``init_engine_pool(config)`` reports failure,
            True otherwise.

    Raises:
        Exception: if ``config.protocol`` is neither "websocket" nor "http".
    """
    # the api name is the part of each engine name before the first "_"
    api_list = [name.split("_")[0] for name in config.engine_list]

    # choose the router factory that matches the configured protocol
    protocol = config.protocol
    if protocol == "websocket":
        build_router = setup_ws_router
    elif protocol == "http":
        build_router = setup_http_router
    else:
        raise Exception("unsupported protocol")

    api_router = build_router(api_list)
    app.include_router(api_router)

    # the routes are registered first, then the engines are brought up
    if init_engine_pool(config):
        return True
    return False
def main(args):
    """Load the configuration, initialize the app and start the server.

    Args:
        args: parsed command line arguments; only ``args.config_file`` is
            read here.

    The server is started only when ``init(config)`` returns True;
    otherwise the function returns without serving.
    """
    config = get_config(args.config_file)
    if not init(config):
        return

    # `debug=True` is intentionally not passed: newer uvicorn releases do not
    # accept a `debug` option in uvicorn.run() (it fails with a TypeError
    # before the server starts), and debug mode is not something a serving
    # entry point should force on.
    uvicorn.run(app, host=config.host, port=config.port)
if __name__ == "__main__":
    # Command line entry point: collect the options and hand them to main().
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument(
        "--config_file",
        default="./conf/application.yaml",
        action="store",
        help="yaml file of the app")
    arg_parser.add_argument(
        "--log_file",
        default="./log/paddlespeech.log",
        action="store",
        help="log file")
    main(arg_parser.parse_args())
......@@ -29,19 +29,19 @@ def setup_router(api_list: List):
"""setup router for fastapi
Args:
api_list (List): [asr, tts, cls]
api_list (List): [asr, tts, cls, text, vector]
Returns:
APIRouter
"""
for api_name in api_list:
if api_name == 'asr':
if api_name.lower() == 'asr':
_router.include_router(asr_router)
elif api_name == 'tts':
elif api_name.lower() == 'tts':
_router.include_router(tts_router)
elif api_name == 'cls':
elif api_name.lower() == 'cls':
_router.include_router(cls_router)
elif api_name == 'text':
elif api_name.lower() == 'text':
_router.include_router(text_router)
elif api_name.lower() == 'vector':
_router.include_router(vec_router)
......
......@@ -43,6 +43,7 @@ class TextHttpHandler:
else:
self.url = 'http://' + self.server_ip + ":" + str(
self.port) + '/paddlespeech/text'
logger.info(f"endpoint: {self.url}")
def run(self, text):
"""Call the text server to process the specific text
......@@ -107,8 +108,10 @@ class ASRWsAudioHandler:
"""
samples, sample_rate = soundfile.read(wavfile_path, dtype='int16')
x_len = len(samples)
assert sample_rate == 16000
chunk_size = int(85 * sample_rate / 1000) # 85ms, sample_rate = 16kHz
chunk_size = 85 * 16  # 85ms, sample_rate = 16kHz
if x_len % chunk_size != 0:
padding_len_x = chunk_size - x_len % chunk_size
else:
......@@ -217,6 +220,7 @@ class ASRHttpHandler:
else:
self.url = 'http://' + self.server_ip + ":" + str(
self.port) + '/paddlespeech/asr'
logger.info(f"endpoint: {self.url}")
def run(self, input, audio_format, sample_rate, lang):
"""Call the http asr to process the audio
......@@ -275,6 +279,7 @@ class TTSWsHandler:
self.start_play = True
self.t = threading.Thread(target=self.play_audio)
self.max_fail = 50
logger.info(f"endpoint: {self.url}")
def play_audio(self):
while True:
......@@ -383,6 +388,7 @@ class TTSHttpHandler:
self.start_play = True
self.t = threading.Thread(target=self.play_audio)
self.max_fail = 50
logger.info(f"endpoint: {self.url}")
def play_audio(self):
while True:
......@@ -483,6 +489,7 @@ class VectorHttpHandler:
else:
self.url = 'http://' + self.server_ip + ":" + str(
self.port) + '/paddlespeech/vector'
logger.info(f"endpoint: {self.url}")
def run(self, input, audio_format, sample_rate, task="spk"):
"""Call the http asr to process the audio
......@@ -529,6 +536,7 @@ class VectorScoreHttpHandler:
else:
self.url = 'http://' + self.server_ip + ":" + str(
self.port) + '/paddlespeech/vector/score'
logger.info(f"endpoint: {self.url}")
def run(self, enroll_audio, test_audio, audio_format, sample_rate):
"""Call the http asr to process the audio
......
......@@ -107,7 +107,7 @@ def change_speed(sample_raw, speed_rate, sample_rate):
def float2pcm(sig, dtype='int16'):
"""Convert floating point signal with a range from -1 to 1 to PCM.
"""Convert floating point signal with a range from -1 to 1 to PCM16.
Args:
sig (array): Input array, must have floating point type.
......
......@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
class Frame(object):
"""Represents a "frame" of audio data."""
......@@ -46,8 +45,7 @@ class ChunkBuffer(object):
self.shift_ms = shift_ms
self.sample_rate = sample_rate
self.sample_width = sample_width # int16 = 2; float32 = 4
self.remained_audio = b''
self.window_sec = float((self.window_n - 1) * self.shift_ms +
self.window_ms) / 1000.0
self.shift_sec = float(self.shift_n * self.shift_ms / 1000.0)
......@@ -57,22 +55,31 @@ class ChunkBuffer(object):
self.shift_bytes = int(self.shift_sec * self.sample_rate *
self.sample_width)
self.remained_audio = b''
# abs timestamp from `start` or latest `reset`
self.timestamp = 0.0
def reset(self):
"""
reset buffer state.
"""
self.timestamp = 0.0
self.remained_audio = b''
def frame_generator(self, audio):
"""Generates audio frames from PCM audio data.
Takes the desired frame duration in milliseconds, the PCM data, and
the sample rate.
Yields Frames of the requested duration.
"""
audio = self.remained_audio + audio
self.remained_audio = b''
offset = 0
timestamp = 0.0
while offset + self.window_bytes <= len(audio):
yield Frame(audio[offset:offset + self.window_bytes], timestamp,
yield Frame(audio[offset:offset + self.window_bytes], self.timestamp,
self.window_sec)
timestamp += self.shift_sec
self.timestamp += self.shift_sec
offset += self.shift_bytes
self.remained_audio += audio[offset:]
......@@ -15,8 +15,8 @@ from typing import List
from fastapi import APIRouter
from paddlespeech.server.ws.asr_socket import router as asr_router
from paddlespeech.server.ws.tts_socket import router as tts_router
from paddlespeech.server.ws.asr_api import router as asr_router
from paddlespeech.server.ws.tts_api import router as tts_router
_router = APIRouter()
......
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
from fastapi import APIRouter
from fastapi import WebSocket
from fastapi import WebSocketDisconnect
from starlette.websockets import WebSocketState as WebSocketState
from paddlespeech.server.engine.asr.online.asr_engine import PaddleASRConnectionHanddler
from paddlespeech.server.engine.engine_pool import get_engine_pool
# Router on which the streaming ASR websocket route below is registered.
router = APIRouter()
@router.websocket('/paddlespeech/asr/streaming')
async def websocket_endpoint(websocket: WebSocket):
    """PaddleSpeech Online ASR Server api

    Protocol as implemented here:
      * a text frame carries a JSON object with a ``signal`` key:
        ``start`` creates the per-connection handler, ``end`` finalizes
        decoding, sends the final result and closes the loop;
      * a binary frame carries audio data that is fed to the handler and
        answered with the current result.

    Frames that cannot be handled (malformed JSON, missing ``signal``,
    audio or ``end`` before ``start``) are answered with a message and the
    loop keeps running instead of failing the connection.

    Args:
        websocket (WebSocket): the websocket instance
    """
    # 1. accept the websocket handshake; nothing is processed before that
    await websocket.accept()

    # 2. fetch the online asr engine from the engine pool
    engine_pool = get_engine_pool()
    asr_engine = engine_pool['asr']

    # 3. every connection gets its own PaddleASRConnectionHanddler, and it is
    #    only created once the client has sent the start signal
    connection_handler = None

    # reply used when audio or the end signal arrives before `start`
    not_started = {
        "status": "failed",
        "message": "no start signal received"
    }

    try:
        # 4. process the stream message by message until the client sends
        #    the end signal
        while True:
            # 4.1 wait for the next client message; these three statements
            #     are adapted from starlette.websockets so that both text
            #     and binary frames can be received through one call
            assert websocket.application_state == WebSocketState.CONNECTED
            message = await websocket.receive()
            websocket._raise_on_disconnect(message)

            # 4.2 text frames are commands, binary frames are audio data
            if "text" in message:
                try:
                    message = json.loads(message["text"])
                except (TypeError, ValueError):
                    # not parseable as JSON; handled by the check below
                    message = None

                if not isinstance(message, dict) or 'signal' not in message:
                    resp = {"status": "ok", "message": "no valid json data"}
                    await websocket.send_json(resp)
                    # nothing to dispatch on, wait for the next message
                    continue

                signal = message['signal']
                if signal == 'start':
                    resp = {"status": "ok", "signal": "server_ready"}
                    # create the handler that processes this connection's audio
                    connection_handler = PaddleASRConnectionHanddler(asr_engine)
                    await websocket.send_json(resp)
                elif signal == 'end':
                    if connection_handler is None:
                        # nothing was started, so there is nothing to finalize
                        await websocket.send_json(not_started)
                        continue

                    # flush the remaining audio, collect the final result,
                    # reset the handler and leave the loop
                    connection_handler.decode(is_finished=True)
                    connection_handler.rescoring()
                    asr_results = connection_handler.get_result()
                    word_time_stamp = connection_handler.get_word_time_stamp()
                    connection_handler.reset()

                    resp = {
                        "status": "ok",
                        "signal": "finished",
                        'result': asr_results,
                        'times': word_time_stamp
                    }
                    await websocket.send_json(resp)
                    break
                else:
                    resp = {"status": "ok", "message": "no valid json data"}
                    await websocket.send_json(resp)
            elif "bytes" in message:
                if connection_handler is None:
                    # audio before the start signal cannot be processed
                    await websocket.send_json(not_started)
                    continue

                # binary payload: the audio data of this package
                message = message["bytes"]

                # feed the audio to the handler and decode what is available
                connection_handler.extract_feat(message)
                connection_handler.decode(is_finished=False)
                asr_results = connection_handler.get_result()

                # return the result obtained so far
                resp = {'result': asr_results}
                await websocket.send_json(resp)
    except WebSocketDisconnect:
        # the client went away; there is nobody left to answer
        pass
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
from fastapi import APIRouter
from fastapi import WebSocket
from fastapi import WebSocketDisconnect
from starlette.websockets import WebSocketState as WebSocketState
from paddlespeech.cli.log import logger
from paddlespeech.server.engine.engine_pool import get_engine_pool
# Router on which the streaming TTS websocket route below is registered.
router = APIRouter()
@router.websocket('/paddlespeech/tts/streaming')
async def websocket_endpoint(websocket: WebSocket):
    """Streaming TTS websocket endpoint.

    Receives a single JSON text message, passes its ``text`` field to the
    tts engine (presumably base64-encoded text, judging by the
    ``text_bese64`` keyword -- confirm against the engine), and sends every
    chunk produced by the engine as ``{"status": 1, "audio": ...}``.
    A final ``{"status": 2, "audio": ''}`` message marks the end of the
    stream.

    Args:
        websocket (WebSocket): the websocket instance
    """
    await websocket.accept()

    try:
        # adapted from starlette.websockets: receive one raw message and
        # raise WebSocketDisconnect if the client has gone away
        assert websocket.application_state == WebSocketState.CONNECTED
        request = await websocket.receive()
        websocket._raise_on_disconnect(request)

        # look up the tts engine
        tts_engine = get_engine_pool()['tts']

        # parse the message and extract the text to synthesize
        payload = json.loads(request["text"])
        sentence = tts_engine.preprocess(text_bese64=payload["text"])

        # stream every synthesized chunk as soon as the engine yields it
        for wav_chunk in tts_engine.run(sentence):
            await websocket.send_json({"status": 1, "audio": wav_chunk})

        # tell the client that no more audio will follow
        await websocket.send_json({"status": 2, "audio": ''})
        logger.info("Complete the transmission of audio streams")
    except WebSocketDisconnect:
        pass
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册