# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json

import numpy as np
from fastapi import APIRouter
from fastapi import WebSocket
from fastapi import WebSocketDisconnect
from starlette.websockets import WebSocketState as WebSocketState

from paddlespeech.server.engine.engine_pool import get_engine_pool
from paddlespeech.server.utils.buffer import ChunkBuffer
from paddlespeech.server.utils.vad import VADAudio

# Router that carries the streaming ASR websocket route defined in this module.
router: APIRouter = APIRouter()

def _parse_control_message(text):
    """Decode a websocket text frame into a JSON object.

    Args:
        text: the raw text payload of a websocket frame.

    Returns:
        The decoded ``dict``, or ``None`` when ``text`` is not valid JSON or
        decodes to something other than a JSON object. Callers can then reply
        with an error instead of letting the exception end the connection.
    """
    try:
        payload = json.loads(text)
    except ValueError:  # json.JSONDecodeError is a subclass of ValueError
        return None
    return payload if isinstance(payload, dict) else None


def _recognize_chunk(asr_engine, audio_bytes):
    """Feed one binary audio frame through the ASR engine.

    Args:
        asr_engine: engine taken from the engine pool. It must expose
            ``config.sample_rate``, ``preprocess``, ``run`` and
            ``postprocess``.
        audio_bytes: raw audio payload, decoded here as native-endian int16
            samples. Presumably mono PCM at the engine sample rate; confirm
            with the client. ``np.frombuffer`` raises ``ValueError`` if the
            length is not a multiple of 2.

    Returns:
        Whatever ``asr_engine.postprocess()`` returns after this chunk has
        been run.
    """
    samples = np.frombuffer(audio_bytes, dtype=np.int16)
    sample_rate = asr_engine.config.sample_rate
    x_chunk, x_chunk_lens = asr_engine.preprocess(samples, sample_rate)
    asr_engine.run(x_chunk, x_chunk_lens)
    return asr_engine.postprocess()


@router.websocket('/ws/asr')
async def websocket_endpoint(websocket: WebSocket):
    """Serve streaming speech recognition over a websocket.

    Message protocol handled below:

    * text ``{"signal": "start"}`` replies with
      ``{"status": "ok", "signal": "server_ready"}``.
    * text ``{"signal": "end"}`` replies with
      ``{"status": "ok", "signal": "finished"}`` and the handler returns.
    * any other text frame (unknown or missing ``signal``, malformed JSON,
      JSON that is not an object) replies with
      ``{"status": "ok", "message": "no valid json data"}``.
    * a binary frame is run through the ``asr`` engine and replies with
      ``{"asr_results": <engine postprocess output>}``.

    A client disconnect ends the handler without raising.
    """
    print("websocket protocal receive the dataset")

    await websocket.accept()

    # The same engine serves every message on this connection.
    engine_pool = get_engine_pool()
    asr_engine = engine_pool['asr']

    # init buffer
    # NOTE(review): chunk_buffer is built from config but is not used below.
    # It is intended for splitting incoming audio into frames before
    # recognition. TODO: wire it in or remove it.
    chunk_buffer_conf = asr_engine.config.chunk_buffer_conf
    chunk_buffer = ChunkBuffer(
        window_n=7,
        shift_n=4,
        window_ms=20,
        shift_ms=10,
        sample_rate=chunk_buffer_conf['sample_rate'],
        sample_width=chunk_buffer_conf['sample_width'])

    # init vad (optional; only when the engine config provides "vad_conf")
    # NOTE(review): the VAD object is not used below yet.
    vad_conf = asr_engine.config.get('vad_conf', None)
    if vad_conf:
        vad = VADAudio(
            aggressiveness=vad_conf['aggressiveness'],
            rate=vad_conf['sample_rate'],
            frame_duration_ms=vad_conf['frame_duration_ms'])

    try:
        # Receive/disconnect handling is inlined here instead of using
        # receive_text()/receive_bytes(). That lets one connection carry both
        # JSON control frames and binary audio frames.
        while websocket.application_state == WebSocketState.CONNECTED:
            message = await websocket.receive()
            if message["type"] == "websocket.disconnect":
                break

            # The ASGI spec allows a receive event to carry both "text" and
            # "bytes" with one of them None, so test values rather than keys.
            text = message.get("text")
            if text is not None:
                control = _parse_control_message(text)
                signal = None if control is None else control.get('signal')
                if signal == 'start':
                    # do something at beginning here
                    resp = {"status": "ok", "signal": "server_ready"}
                    await websocket.send_json(resp)
                elif signal == 'end':
                    # reset the single engine for a new connection (disabled):
                    # asr_engine.reset()
                    resp = {"status": "ok", "signal": "finished"}
                    await websocket.send_json(resp)
                    break
                else:
                    # Missing/unknown signal or malformed JSON: reply once and
                    # keep the connection open for the next frame.
                    resp = {"status": "ok", "message": "no valid json data"}
                    await websocket.send_json(resp)
            elif message.get("bytes") is not None:
                asr_results = _recognize_chunk(asr_engine, message["bytes"])
                resp = {'asr_results': asr_results}
                await websocket.send_json(resp)
    except WebSocketDisconnect:
        # A disconnect surfaced while sending; end the session quietly.
        pass