asr_socket.py 4.0 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json

from fastapi import APIRouter
from fastapi import WebSocket
from fastapi import WebSocketDisconnect
from starlette.websockets import WebSocketState as WebSocketState

X
xiongxinlei 已提交
21
from paddlespeech.server.engine.asr.online.asr_engine import PaddleASRConnectionHanddler
22 23 24 25 26 27
from paddlespeech.server.engine.engine_pool import get_engine_pool
from paddlespeech.server.utils.buffer import ChunkBuffer
from paddlespeech.server.utils.vad import VADAudio

# Router on which this module's websocket endpoint ('/ws/asr') is registered.
router = APIRouter()

W
WilliamZhang06 已提交
28

29 30 31 32
@router.websocket('/ws/asr')
async def websocket_endpoint(websocket: WebSocket):
    """Serve one streaming ASR session over the ``/ws/asr`` websocket.

    Message handling, as implemented below:

    * text frame ``{"signal": "start"}``: create a new
      ``PaddleASRConnectionHanddler`` for this connection and reply with
      ``{"status": "ok", "signal": "server_ready"}``.
    * binary frame: pass the bytes to the handler (``extract_feat`` then
      ``decode(is_finished=False)``) and reply with
      ``{"asr_results": <handler.get_result()>}``.
    * text frame ``{"signal": "end"}``: finish decoding, run rescoring,
      reply with the result and ``"signal": "finished"``, then leave the
      receive loop.
    * any other text frame (malformed JSON, non-object JSON, missing or
      unknown ``signal``): reply with the "no valid json data" message and
      keep the connection open.
    * ``end`` or binary data received before ``start``: reply with a
      "no start signal received" error and keep the connection open.

    Args:
        websocket: the connection supplied by FastAPI for this route.

    A ``WebSocketDisconnect`` raised while receiving ends the session
    silently.
    """
    await websocket.accept()

    engine_pool = get_engine_pool()
    asr_engine = engine_pool['asr']
    # Created on the "start" signal; None means no session has been started.
    connection_handler = None

    # init buffer
    # each websocket connection has its own chunk buffer
    # NOTE(review): chunk_buffer and vad are not used further in this
    # handler; they are kept because their constructors may validate the
    # engine config -- confirm before removing.
    chunk_buffer_conf = asr_engine.config.chunk_buffer_conf
    chunk_buffer = ChunkBuffer(
        window_n=chunk_buffer_conf.window_n,
        shift_n=chunk_buffer_conf.shift_n,
        window_ms=chunk_buffer_conf.window_ms,
        shift_ms=chunk_buffer_conf.shift_ms,
        sample_rate=chunk_buffer_conf.sample_rate,
        sample_width=chunk_buffer_conf.sample_width)

    # init vad
    vad_conf = asr_engine.config.get('vad_conf', None)
    if vad_conf:
        vad = VADAudio(
            aggressiveness=vad_conf['aggressiveness'],
            rate=vad_conf['sample_rate'],
            frame_duration_ms=vad_conf['frame_duration_ms'])

    try:
        while True:
            # careful here, changed the source code from starlette.websockets
            assert websocket.application_state == WebSocketState.CONNECTED
            message = await websocket.receive()
            websocket._raise_on_disconnect(message)

            # ASGI allows the "text"/"bytes" keys to be present with a None
            # value, so test the value rather than key membership.
            text = message.get("text")
            audio = message.get("bytes")

            if text is not None:
                try:
                    payload = json.loads(text)
                except json.JSONDecodeError:
                    payload = None
                # Only a JSON object can carry a signal; anything else is
                # treated as invalid instead of raising on the lookup.
                signal = payload.get('signal') if isinstance(payload,
                                                             dict) else None

                if signal == 'start':
                    resp = {"status": "ok", "signal": "server_ready"}
                    # do something at beginning here
                    # create the instance to process the audio
                    connection_handler = PaddleASRConnectionHanddler(asr_engine)
                    await websocket.send_json(resp)
                elif signal == 'end':
                    if connection_handler is None:
                        # "end" without a preceding "start": nothing to
                        # finalize, report it instead of crashing on None.
                        resp = {
                            "status": "failed",
                            "message": "no start signal received"
                        }
                        await websocket.send_json(resp)
                        continue

                    # reset single engine for a new connection
                    connection_handler.decode(is_finished=True)
                    connection_handler.rescoring()
                    asr_results = connection_handler.get_result()
                    connection_handler.reset()

                    resp = {
                        "status": "ok",
                        "signal": "finished",
                        'asr_results': asr_results
                    }
                    await websocket.send_json(resp)
                    break
                else:
                    # Covers malformed JSON, non-object JSON, a missing
                    # 'signal' key and unknown signal values.
                    resp = {"status": "ok", "message": "no valid json data"}
                    await websocket.send_json(resp)
            elif audio is not None:
                if connection_handler is None:
                    # Audio sent before "start": there is no handler yet.
                    resp = {
                        "status": "failed",
                        "message": "no start signal received"
                    }
                    await websocket.send_json(resp)
                    continue

                connection_handler.extract_feat(audio)
                connection_handler.decode(is_finished=False)
                asr_results = connection_handler.get_result()

                resp = {'asr_results': asr_results}
                await websocket.send_json(resp)
    except WebSocketDisconnect:
        # The client closed the connection; end the session quietly.
        pass