提交 d4f863dc 编写于 作者: L lym0302

improve, test=doc

上级 018dda6e
......@@ -18,6 +18,7 @@ import io
import json
import os
import random
import sys
import time
from typing import List
......@@ -32,6 +33,7 @@ from paddlespeech.cli.log import logger
from paddlespeech.server.utils.audio_handler import ASRWsAudioHandler
from paddlespeech.server.utils.audio_process import wav2pcm
from paddlespeech.server.utils.util import compute_delay
from paddlespeech.server.utils.util import network_reachable
from paddlespeech.server.utils.util import wav2base64
__all__ = [
......@@ -128,6 +130,7 @@ class TTSClientExecutor(BaseExecutor):
return True
except Exception as e:
logger.error("Failed to synthesized audio.")
logger.error(e)
return False
@stats_wrapper
......@@ -154,6 +157,12 @@ class TTSClientExecutor(BaseExecutor):
"save_path": output
}
# Check if the network is reachable
network = 'http://' + server_ip + ":" + str(port)
if network_reachable(network) is not True:
logger.error(f"{network} unreachable, please check the ip address.")
sys.exit(-1)
res = requests.post(url, json.dumps(request))
response_dict = res.json()
if output is not None:
......@@ -236,6 +245,7 @@ class TTSOnlineClientExecutor(BaseExecutor):
return True
except Exception as e:
logger.error("Failed to synthesized audio.")
logger.error(e)
return False
@stats_wrapper
......@@ -254,6 +264,12 @@ class TTSOnlineClientExecutor(BaseExecutor):
Python API to call an executor.
"""
# Check if the network is reachable
network = 'http://' + server_ip + ":" + str(port)
if network_reachable(network) is not True:
logger.error(f"{network} unreachable, please check the ip address.")
sys.exit(-1)
if protocol == "http":
logger.info("tts http client start")
from paddlespeech.server.utils.audio_handler import TTSHttpHandler
......@@ -275,7 +291,7 @@ class TTSOnlineClientExecutor(BaseExecutor):
else:
logger.error("Please set correct protocol, http or websocket")
return False
sys.exit(-1)
logger.info(f"sentence: {input}")
logger.info(f"duration: {duration} s")
......@@ -399,6 +415,13 @@ class ASRClientExecutor(BaseExecutor):
# and paddlespeech_client asr only support http protocol
protocol = "http"
if protocol.lower() == "http":
# Check if the network is reachable
network = 'http://' + server_ip + ":" + str(port)
if network_reachable(network) is not True:
logger.error(
f"{network} unreachable, please check the ip address.")
sys.exit(-1)
from paddlespeech.server.utils.audio_handler import ASRHttpHandler
logger.info("asr http client start")
handler = ASRHttpHandler(server_ip=server_ip, port=port)
......@@ -503,6 +526,13 @@ class ASROnlineClientExecutor(BaseExecutor):
Returns:
str: the audio text
"""
# Check if the network is reachable
network = 'http://' + server_ip + ":" + str(port)
if network_reachable(network) is not True:
logger.error(f"{network} unreachable, please check the ip address.")
sys.exit(-1)
logger.info("asr websocket client start")
handler = ASRWsAudioHandler(
server_ip,
......@@ -555,6 +585,7 @@ class CLSClientExecutor(BaseExecutor):
return True
except Exception as e:
logger.error("Failed to speech classification.")
logger.error(e)
return False
@stats_wrapper
......@@ -567,6 +598,12 @@ class CLSClientExecutor(BaseExecutor):
Python API to call an executor.
"""
# Check if the network is reachable
network = 'http://' + server_ip + ":" + str(port)
if network_reachable(network) is not True:
logger.error(f"{network} unreachable, please check the ip address.")
sys.exit(-1)
url = 'http://' + server_ip + ":" + str(port) + '/paddlespeech/cls'
audio = wav2base64(input)
data = {"audio": audio, "topk": topk}
......@@ -632,6 +669,12 @@ class TextClientExecutor(BaseExecutor):
str: the punctuation text
"""
# Check if the network is reachable
network = 'http://' + server_ip + ":" + str(port)
if network_reachable(network) is not True:
logger.error(f"{network} unreachable, please check the ip address.")
sys.exit(-1)
url = 'http://' + server_ip + ":" + str(port) + '/paddlespeech/text'
request = {
"text": input,
......@@ -728,6 +771,13 @@ class VectorClientExecutor(BaseExecutor):
Returns:
str: the audio embedding or score between enroll and test audio
"""
# Check if the network is reachable
network = 'http://' + server_ip + ":" + str(port)
if network_reachable(network) is not True:
logger.error(f"{network} unreachable, please check the ip address.")
sys.exit(-1)
if task == "spk":
from paddlespeech.server.utils.audio_handler import VectorHttpHandler
logger.info("vector http client start")
......
......@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import sys
from typing import List
import uvicorn
......@@ -79,10 +80,12 @@ class ServerExecutor(BaseExecutor):
def execute(self, argv: List[str]) -> bool:
args = self.parser.parse_args(argv)
config = get_config(args.config_file)
if self.init(config):
uvicorn.run(app, host=config.host, port=config.port, debug=True)
try:
self(args.config_file, args.log_file)
except Exception as e:
logger.error("Failed to start server.")
logger.error(e)
sys.exit(-1)
@stats_wrapper
def __call__(self,
......
......@@ -304,18 +304,24 @@ class TTSWsHandler:
receive_time_list = []
chunk_duration_list = []
# 1. Send websocket handshake protocal
# 1. Send websocket handshake request
async with websockets.connect(self.url) as ws:
# 2. Server has already received handshake protocal
# send text to engine
# 2. Server has already received handshake response, send start request
start_request = json.dumps({"task": "tts", "signal": "start"})
await ws.send(start_request)
msg = await ws.recv()
logger.info(f"client receive msg={msg}")
msg = json.loads(msg)
session = msg["session"]
# 3. send speech synthesis request
text_base64 = str(base64.b64encode((text).encode('utf-8')), "UTF8")
d = {"text": text_base64}
d = json.dumps(d)
request = json.dumps({"text": text_base64})
st = time.time()
await ws.send(d)
await ws.send(request)
logging.info("send a message to the server")
# 3. Process the received response
# Process the received response
message = await ws.recv()
first_response = time.time() - st
message = json.loads(message)
......@@ -348,6 +354,15 @@ class TTSWsHandler:
save_audio_success = save_audio(all_bytes, output)
else:
save_audio_success = False
# 5. send end request
end_request = json.dumps({
"task": "tts",
"signal": "end",
"session": session
})
await ws.send(end_request)
else:
logger.error("infer error")
......@@ -458,6 +473,7 @@ class TTSHttpHandler:
final_response = time.time() - st
duration = len(all_bytes) / 2.0 / 24000
html.close() # when stream=True
if output is not None:
save_audio_success = save_audio(all_bytes, output)
......
......@@ -13,6 +13,8 @@
import base64
import math
import requests
def wav2base64(wav_file: str):
"""
......@@ -146,3 +148,21 @@ def count_engine(logfile: str="./nohup.out"):
print(
f"max final response: {max(final_response_list)} s, min final response: {min(final_response_list)} s"
)
def network_reachable(url: str, timeout: int=5) -> bool:
"""Check if the network is reachable
Args:
url (str): http://server_ip:port or ws://server_ip:port
timeout (int, optional): timeout. Defaults to 5.
Returns:
bool: Whether the network is reachable.
"""
try:
request = requests.get(url, timeout=timeout)
return True
except (requests.ConnectionError, requests.Timeout) as exception:
print(exception)
return False
......@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import uuid
from fastapi import APIRouter
from fastapi import WebSocket
......@@ -26,36 +27,71 @@ router = APIRouter()
@router.websocket('/paddlespeech/tts/streaming')
async def websocket_endpoint(websocket: WebSocket):
"""PaddleSpeech Online TTS Server api
Args:
websocket (WebSocket): the websocket instance
"""
#1. the interface wait to accept the websocket protocal header
# and only we receive the header, it establish the connection with specific thread
await websocket.accept()
#2. if we accept the websocket headers, we will get the online tts engine instance
engine_pool = get_engine_pool()
tts_engine = engine_pool['tts']
try:
# careful here, changed the source code from starlette.websockets
assert websocket.application_state == WebSocketState.CONNECTED
message = await websocket.receive()
websocket._raise_on_disconnect(message)
while True:
# careful here, changed the source code from starlette.websockets
assert websocket.application_state == WebSocketState.CONNECTED
message = await websocket.receive()
websocket._raise_on_disconnect(message)
message = json.loads(message["text"])
# get engine
engine_pool = get_engine_pool()
tts_engine = engine_pool['tts']
if 'signal' in message:
# start request
if message['signal'] == 'start':
session = uuid.uuid1().hex
resp = {
"status": 0,
"signal": "server ready",
"session": session
}
await websocket.send_json(resp)
# 获取 message 并转文本
message = json.loads(message["text"])
text_bese64 = message["text"]
sentence = tts_engine.preprocess(text_bese64=text_bese64)
# end request
elif message['signal'] == 'end':
resp = {
"status": 0,
"signal": "connection will be closed",
"session": session
}
await websocket.send_json(resp)
# run
wav_generator = tts_engine.run(sentence)
# speech synthesis request
elif 'text' in message:
text_bese64 = message["text"]
sentence = tts_engine.preprocess(text_bese64=text_bese64)
while True:
try:
tts_results = next(wav_generator)
resp = {"status": 1, "audio": tts_results}
await websocket.send_json(resp)
except StopIteration as e:
resp = {"status": 2, "audio": ''}
await websocket.send_json(resp)
logger.info("Complete the transmission of audio streams")
break
# run
wav_generator = tts_engine.run(sentence)
while True:
try:
tts_results = next(wav_generator)
resp = {"status": 1, "audio": tts_results}
await websocket.send_json(resp)
except StopIteration as e:
resp = {"status": 2, "audio": ''}
await websocket.send_json(resp)
logger.info(
"Complete the transmission of audio streams")
break
else:
logger.error(
"Invalid request, please check if the request is correct.")
except WebSocketDisconnect:
pass
\ No newline at end of file
pass
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册