Unverified commit cf9a590f, authored by Hui Zhang, committed by GitHub

Merge pull request #1704 from Honei/server

[asr][websocket] add asr conformer websocket server
@@ -40,7 +40,6 @@ from paddlespeech.s2t.utils.utility import UpdateConfig
 __all__ = ['ASRExecutor']
-
 @cli_register(
     name='paddlespeech.asr', description='Speech to text infer command.')
 class ASRExecutor(BaseExecutor):
@@ -125,6 +124,7 @@ class ASRExecutor(BaseExecutor):
         """
         Init model and other resources from a specific path.
         """
+        logger.info("start to init the model")
         if hasattr(self, 'model'):
             logger.info('Model had been initialized.')
             return
@@ -140,14 +140,15 @@ class ASRExecutor(BaseExecutor):
                 res_path,
                 self.pretrained_models[tag]['ckpt_path'] + ".pdparams")
             logger.info(res_path)
-            logger.info(self.cfg_path)
-            logger.info(self.ckpt_path)
         else:
             self.cfg_path = os.path.abspath(cfg_path)
             self.ckpt_path = os.path.abspath(ckpt_path + ".pdparams")
             self.res_path = os.path.dirname(
                 os.path.dirname(os.path.abspath(self.cfg_path)))
+        logger.info(self.cfg_path)
+        logger.info(self.ckpt_path)

         #Init body.
         self.config = CfgNode(new_allowed=True)
         self.config.merge_from_file(self.cfg_path)
@@ -176,7 +177,6 @@ class ASRExecutor(BaseExecutor):
                 vocab=self.config.vocab_filepath,
                 spm_model_prefix=self.config.spm_model_prefix)
             self.config.decode.decoding_method = decode_method
-
         else:
             raise Exception("wrong type")
         model_name = model_type[:model_type.rindex(
@@ -254,12 +254,14 @@ class ASRExecutor(BaseExecutor):
         else:
             raise Exception("wrong type")
+        logger.info("audio feat process success")

     @paddle.no_grad()
     def infer(self, model_type: str):
         """
         Model inference and result stored in self.output.
         """
+        logger.info("start to infer the model to get the output")
         cfg = self.config.decode
         audio = self._inputs["audio"]
         audio_len = self._inputs["audio_len"]
@@ -276,17 +278,22 @@ class ASRExecutor(BaseExecutor):
             self._outputs["result"] = result_transcripts[0]
         elif "conformer" in model_type or "transformer" in model_type:
-            result_transcripts = self.model.decode(
-                audio,
-                audio_len,
-                text_feature=self.text_feature,
-                decoding_method=cfg.decoding_method,
-                beam_size=cfg.beam_size,
-                ctc_weight=cfg.ctc_weight,
-                decoding_chunk_size=cfg.decoding_chunk_size,
-                num_decoding_left_chunks=cfg.num_decoding_left_chunks,
-                simulate_streaming=cfg.simulate_streaming)
-            self._outputs["result"] = result_transcripts[0][0]
+            logger.info(
+                f"we will use the transformer like model : {model_type}")
+            try:
+                result_transcripts = self.model.decode(
+                    audio,
+                    audio_len,
+                    text_feature=self.text_feature,
+                    decoding_method=cfg.decoding_method,
+                    beam_size=cfg.beam_size,
+                    ctc_weight=cfg.ctc_weight,
+                    decoding_chunk_size=cfg.decoding_chunk_size,
+                    num_decoding_left_chunks=cfg.num_decoding_left_chunks,
+                    simulate_streaming=cfg.simulate_streaming)
+                self._outputs["result"] = result_transcripts[0][0]
+            except Exception as e:
+                logger.exception(e)
         else:
             raise Exception("invalid model name")
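For orientation, the executor changed above is also usable directly from Python; a minimal sketch, where the keyword names (`model`, `lang`, `audio_file`) are assumptions for illustration rather than taken from this diff:

```python
# hypothetical usage sketch of the ASRExecutor shown above;
# the keyword names are assumed, not verified against this revision
from paddlespeech.cli.asr.infer import ASRExecutor

asr = ASRExecutor()
text = asr(model='conformer_wenetspeech', lang='zh', audio_file='./zh.wav')
print(text)
```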
@@ -88,6 +88,8 @@ model_alias = {
     "paddlespeech.s2t.models.ds2_online:DeepSpeech2ModelOnline",
     "conformer":
     "paddlespeech.s2t.models.u2:U2Model",
+    "conformer_online":
+    "paddlespeech.s2t.models.u2:U2Model",
     "transformer":
     "paddlespeech.s2t.models.u2:U2Model",
     "wenetspeech":
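Each alias value is a `module:Class` path; a sketch of how such a string can be resolved at runtime (paddlespeech ships its own dynamic-import helper, this only shows the mechanism):

```python
# resolve a "module:Class" alias like the model_alias entries above;
# illustrative only, not the project's actual helper
import importlib

def resolve_alias(import_path: str):
    module_name, class_name = import_path.split(':')
    return getattr(importlib.import_module(module_name), class_name)

U2Model = resolve_alias("paddlespeech.s2t.models.u2:U2Model")
```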
@@ -279,14 +279,13 @@ class U2BaseModel(ASRInterface, nn.Layer):
             # TODO(Hui Zhang): if end_flag.sum() == running_size:
             if end_flag.cast(paddle.int64).sum() == running_size:
                 break
-
             # 2.1 Forward decoder step
             hyps_mask = subsequent_mask(i).unsqueeze(0).repeat(
                 running_size, 1, 1).to(device)  # (B*N, i, i)
             # logp: (B*N, vocab)
             logp, cache = self.decoder.forward_one_step(
                 encoder_out, encoder_mask, hyps, hyps_mask, cache)

             # 2.2 First beam prune: select topk best prob at current time
             top_k_logp, top_k_index = logp.topk(beam_size)  # (B*N, N)
             top_k_logp = mask_finished_scores(top_k_logp, end_flag)
@@ -708,11 +707,11 @@ class U2BaseModel(ASRInterface, nn.Layer):
         batch_size = feats.shape[0]
         if decoding_method in ['ctc_prefix_beam_search',
                                'attention_rescoring'] and batch_size > 1:
-            logger.fatal(
+            logger.error(
                 f'decoding mode {decoding_method} must be running with batch_size == 1'
             )
+            logger.error(f"current batch_size is {batch_size}")
             sys.exit(1)
-
         if decoding_method == 'attention':
             hyps = self.recognize(
                 feats,
@@ -180,7 +180,7 @@ class CTCDecoder(CTCDecoderBase):
         # init once
         if self._ext_scorer is not None:
             return

         if language_model_path != '':
             logger.info("begin to initialize the external scorer "
                         "for decoding")
@@ -35,3 +35,16 @@
 ```bash
 paddlespeech_client cls --server_ip 127.0.0.1 --port 8090 --input input.wav
 ```
+
+## Online ASR Server
+
+### Launch the online ASR server
+```bash
+paddlespeech_server start --config_file conf/ws_conformer_application.yaml
+```
+
+### Access the online ASR server
+```bash
+paddlespeech_client asr_online --server_ip 127.0.0.1 --port 8090 --input input_16k.wav
+```
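The same service can also be driven from Python instead of the CLI; a minimal sketch mirroring the `ASRClientExecutor` change later in this diff (the import path of `ASRAudioHandler` is an assumption):

```python
# drive the streaming server programmatically; mirrors ASRClientExecutor below
# NOTE: the ASRAudioHandler module path is assumed for illustration
import asyncio

from paddlespeech.server.tests.asr.online.websocket_client import ASRAudioHandler

handler = ASRAudioHandler("127.0.0.1", 8090)
loop = asyncio.get_event_loop()
res = loop.run_until_complete(handler.run("input_16k.wav"))
print(res['asr_results'])
```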
@@ -35,3 +35,17 @@
 ```bash
 paddlespeech_client cls --server_ip 127.0.0.1 --port 8090 --input input.wav
 ```
+
+## Streaming ASR
+
+### Launch the streaming ASR server
+```bash
+paddlespeech_server start --config_file conf/ws_conformer_application.yaml
+```
+
+### Access the streaming ASR service
+```bash
+paddlespeech_client asr_online --server_ip 127.0.0.1 --port 8090 --input zh.wav
+```
@@ -277,11 +277,12 @@ class ASRClientExecutor(BaseExecutor):
                 lang=lang,
                 audio_format=audio_format)
             time_end = time.time()
-            logger.info(res.json())
+            logger.info(res)
             logger.info("Response time %f s." % (time_end - time_start))
             return True
         except Exception as e:
             logger.error("Failed to speech recognition.")
+            logger.error(e)
             return False

     @stats_wrapper
@@ -299,9 +300,10 @@ class ASRClientExecutor(BaseExecutor):
         logging.info("asr websocket client start")
         handler = ASRAudioHandler(server_ip, port)
         loop = asyncio.get_event_loop()
-        loop.run_until_complete(handler.run(input))
+        res = loop.run_until_complete(handler.run(input))
         logging.info("asr websocket client finished")
+        return res['asr_results']

 @cli_client_register(
     name='paddlespeech_client.cls', description='visit cls service')
@@ -41,11 +41,7 @@ asr_online:
     shift_ms: 40
     sample_rate: 16000
     sample_width: 2
-    vad_conf:
-        aggressiveness: 2
-        sample_rate: 16000
-        frame_duration_ms: 20
-        sample_width: 2
-        padding_ms: 200
-        padding_ratio: 0.9
+    window_n: 7    # frame
+    shift_n: 4     # frame
+    window_ms: 20  # ms
+    shift_ms: 10   # ms
# This is the parameter configuration file for PaddleSpeech Serving.

#################################################################################
#                             SERVER SETTING                                    #
#################################################################################
host: 0.0.0.0
port: 8090

# The task format in the engine_list is: <speech task>_<engine type>
# task choices = ['asr_online', 'tts_online']
# protocol = ['websocket', 'http'] (only one can be selected).
# websocket only support online engine type.
protocol: 'websocket'
engine_list: ['asr_online']

#################################################################################
#                               ENGINE CONFIG                                   #
#################################################################################

################################### ASR #########################################
################### speech task: asr; engine_type: online #######################
asr_online:
    model_type: 'conformer_online_multicn'
    am_model:   # the pdmodel file of am static model [optional]
    am_params:  # the pdiparams file of am static model [optional]
    lang: 'zh'
    sample_rate: 16000
    cfg_path:
    decode_method:
    force_yes: True

    am_predictor_conf:
        device:    # set 'gpu:id' or 'cpu'
        switch_ir_optim: True
        glog_info: False    # True -> print glog
        summary: True       # False -> do not show predictor config

    chunk_buffer_conf:
        window_n: 7    # frame
        shift_n: 4     # frame
        window_ms: 25  # ms
        shift_ms: 10   # ms
        sample_rate: 16000
        sample_width: 2
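As a sanity check, this file loads with the same yacs `CfgNode` pattern the executor uses; a small sketch, assuming the file name is the one passed to `paddlespeech_server` above:

```python
# load the serving config above with yacs, as ASRExecutor does for model configs
from yacs.config import CfgNode

config = CfgNode(new_allowed=True)
config.merge_from_file("conf/ws_conformer_application.yaml")
assert 'asr_online' in config.engine_list
print(config.asr_online.chunk_buffer_conf.window_ms)  # -> 25
```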
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections import defaultdict

import paddle

from paddlespeech.cli.log import logger
from paddlespeech.s2t.utils.utility import log_add

__all__ = ['CTCPrefixBeamSearch']


class CTCPrefixBeamSearch:
    def __init__(self, config):
        """Implement the CTC prefix beam search.

        Args:
            config (yacs.config.CfgNode): the decoding config, which
                provides `beam_size`.
        """
        self.config = config
        self.reset()

    @paddle.no_grad()
    def search(self, ctc_probs, device, blank_id=0):
        """Decode one chunk of features with CTC prefix beam search.

        Args:
            ctc_probs (paddle.Tensor): the ctc probability of all the tokens
            device (paddle.fluid.core_avx.Place): the feature host device, such as CUDAPlace(0).
            blank_id (int, optional): the blank id in the vocab. Defaults to 0.

        Returns:
            list: the search result
        """
        # decode
        logger.info("start to ctc prefix search")

        batch_size = 1
        beam_size = self.config.beam_size
        maxlen = ctc_probs.shape[0]
        assert len(ctc_probs.shape) == 2

        # cur_hyps: (prefix, (blank_ending_score, none_blank_ending_score))
        # blank_ending_score and none_blank_ending_score in ln domain
        if self.cur_hyps is None:
            self.cur_hyps = [(tuple(), (0.0, -float('inf')))]

        # 2. CTC beam search step by step
        for t in range(0, maxlen):
            logp = ctc_probs[t]  # (vocab_size,)
            # key: prefix, value (pb, pnb), default value(-inf, -inf)
            next_hyps = defaultdict(lambda: (-float('inf'), -float('inf')))

            # 2.1 First beam prune: select topk best
            # do token passing process
            top_k_logp, top_k_index = logp.topk(beam_size)  # (beam_size,)
            for s in top_k_index:
                s = s.item()
                ps = logp[s].item()
                for prefix, (pb, pnb) in self.cur_hyps:
                    last = prefix[-1] if len(prefix) > 0 else None
                    if s == blank_id:  # blank
                        n_pb, n_pnb = next_hyps[prefix]
                        n_pb = log_add([n_pb, pb + ps, pnb + ps])
                        next_hyps[prefix] = (n_pb, n_pnb)
                    elif s == last:
                        # Update *ss -> *s;
                        n_pb, n_pnb = next_hyps[prefix]
                        n_pnb = log_add([n_pnb, pnb + ps])
                        next_hyps[prefix] = (n_pb, n_pnb)

                        # Update *s-s -> *ss, - is for blank
                        n_prefix = prefix + (s, )
                        n_pb, n_pnb = next_hyps[n_prefix]
                        n_pnb = log_add([n_pnb, pb + ps])
                        next_hyps[n_prefix] = (n_pb, n_pnb)
                    else:
                        n_prefix = prefix + (s, )
                        n_pb, n_pnb = next_hyps[n_prefix]
                        n_pnb = log_add([n_pnb, pb + ps, pnb + ps])
                        next_hyps[n_prefix] = (n_pb, n_pnb)

            # 2.2 Second beam prune
            next_hyps = sorted(
                next_hyps.items(),
                key=lambda x: log_add(list(x[1])),
                reverse=True)
            self.cur_hyps = next_hyps[:beam_size]

        self.hyps = [(y[0], log_add([y[1][0], y[1][1]])) for y in self.cur_hyps]
        logger.info("ctc prefix search success")
        return self.hyps

    def get_one_best_hyps(self):
        """Return the one-best result.

        Returns:
            list: the one-best token sequence
        """
        return [self.hyps[0][0]]

    def get_hyps(self):
        """Return all search hypotheses.

        Returns:
            list: the search hypotheses
        """
        return self.hyps

    def reset(self):
        """Reset the search cache."""
        self.cur_hyps = None
        self.hyps = None

    def finalize_search(self):
        """Do nothing in ctc_prefix_beam_search."""
        pass
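A usage sketch under stated assumptions: each chunk's `ctc_probs` is a `(T, vocab_size)` log-softmax output, and `cur_hyps` carries state across chunks until `reset()`; the random tensors below stand in for real CTC output:

```python
# hypothetical driver for the CTCPrefixBeamSearch class above
import paddle
from yacs.config import CfgNode

config = CfgNode(new_allowed=True)
config.beam_size = 10

searcher = CTCPrefixBeamSearch(config)
vocab_size = 100
for _ in range(3):  # three feature chunks of 16 frames each
    ctc_probs = paddle.nn.functional.log_softmax(
        paddle.randn([16, vocab_size]), axis=-1)
    searcher.search(ctc_probs, ctc_probs.place, blank_id=0)
best_tokens = searcher.get_one_best_hyps()[0]  # best prefix (tuple of token ids)
searcher.reset()                               # ready for the next utterance
```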
@@ -34,10 +34,9 @@ class ASRAudioHandler:
     def read_wave(self, wavfile_path: str):
         samples, sample_rate = soundfile.read(wavfile_path, dtype='int16')
         x_len = len(samples)
-        # chunk_stride = 40 * 16 #40ms, sample_rate = 16kHz
-        chunk_size = 80 * 16  #80ms, sample_rate = 16kHz
+        chunk_size = 85 * 16  #85ms, sample_rate = 16kHz
         if x_len % chunk_size != 0:
             padding_len_x = chunk_size - x_len % chunk_size
         else:
             padding_len_x = 0
@@ -48,7 +47,6 @@ class ASRAudioHandler:
         assert (x_len + padding_len_x) % chunk_size == 0
         num_chunk = (x_len + padding_len_x) / chunk_size
         num_chunk = int(num_chunk)
-
         for i in range(0, num_chunk):
             start = i * chunk_size
             end = start + chunk_size
@@ -57,7 +55,11 @@
     async def run(self, wavfile_path: str):
         logging.info("send a message to the server")
-        # self.read_wave()
+        # send websocket handshake protocol
         async with websockets.connect(self.url) as ws:
+            # server has already received handshake protocol
+            # client start to send the command
             audio_info = json.dumps(
                 {
                     "name": "test.wav",
@@ -78,7 +80,6 @@ class ASRAudioHandler:
                 msg = json.loads(msg)
                 logging.info("receive msg={}".format(msg))
-                result = msg

             # finished
             audio_info = json.dumps(
                 {
@@ -91,10 +92,12 @@ class ASRAudioHandler:
                 separators=(',', ': '))
             await ws.send(audio_info)
             msg = await ws.recv()
+            # decode the bytes to str
             msg = json.loads(msg)
-            logging.info("receive msg={}".format(msg))
+            logging.info("final receive msg={}".format(msg))
+            result = msg

         return result

 def main(args):
@@ -63,12 +63,12 @@ class ChunkBuffer(object):
         the sample rate.
         Yields Frames of the requested duration.
         """

         audio = self.remained_audio + audio
         self.remained_audio = b''

         offset = 0
         timestamp = 0.0

         while offset + self.window_bytes <= len(audio):
             yield Frame(audio[offset:offset + self.window_bytes], timestamp,
                         self.window_sec)
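For scale, the byte sizes the buffer works with follow directly from `chunk_buffer_conf` above; a back-of-envelope sketch, with the window/shift semantics assumed from the parameter names:

```python
# frame sizes implied by chunk_buffer_conf: 16 kHz int16 audio
sample_rate = 16000           # samples per second
sample_width = 2              # bytes per sample (int16)
window_ms, shift_ms = 25, 10  # per-frame window and hop from the config

window_bytes = window_ms * sample_rate // 1000 * sample_width  # 800 bytes
shift_bytes = shift_ms * sample_rate // 1000 * sample_width    # 320 bytes
print(window_bytes, shift_bytes)
```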
@@ -13,12 +13,12 @@
 # limitations under the License.
 import json

-import numpy as np
 from fastapi import APIRouter
 from fastapi import WebSocket
 from fastapi import WebSocketDisconnect
 from starlette.websockets import WebSocketState as WebSocketState

+from paddlespeech.server.engine.asr.online.asr_engine import PaddleASRConnectionHanddler
 from paddlespeech.server.engine.engine_pool import get_engine_pool
 from paddlespeech.server.utils.buffer import ChunkBuffer
 from paddlespeech.server.utils.vad import VADAudio
@@ -28,26 +28,29 @@ router = APIRouter()
 @router.websocket('/ws/asr')
 async def websocket_endpoint(websocket: WebSocket):
     await websocket.accept()

     engine_pool = get_engine_pool()
     asr_engine = engine_pool['asr']
+    connection_handler = None

     # init buffer
+    # each websocket connection has its own chunk buffer
     chunk_buffer_conf = asr_engine.config.chunk_buffer_conf
     chunk_buffer = ChunkBuffer(
-        window_n=7,
-        shift_n=4,
-        window_ms=20,
-        shift_ms=10,
-        sample_rate=chunk_buffer_conf['sample_rate'],
-        sample_width=chunk_buffer_conf['sample_width'])
+        window_n=chunk_buffer_conf.window_n,
+        shift_n=chunk_buffer_conf.shift_n,
+        window_ms=chunk_buffer_conf.window_ms,
+        shift_ms=chunk_buffer_conf.shift_ms,
+        sample_rate=chunk_buffer_conf.sample_rate,
+        sample_width=chunk_buffer_conf.sample_width)

     # init vad
-    vad_conf = asr_engine.config.vad_conf
-    vad = VADAudio(
-        aggressiveness=vad_conf['aggressiveness'],
-        rate=vad_conf['sample_rate'],
-        frame_duration_ms=vad_conf['frame_duration_ms'])
+    vad_conf = asr_engine.config.get('vad_conf', None)
+    if vad_conf:
+        vad = VADAudio(
+            aggressiveness=vad_conf['aggressiveness'],
+            rate=vad_conf['sample_rate'],
+            frame_duration_ms=vad_conf['frame_duration_ms'])

     try:
         while True:
@@ -64,13 +67,21 @@ async def websocket_endpoint(websocket: WebSocket):
             if message['signal'] == 'start':
                 resp = {"status": "ok", "signal": "server_ready"}
                 # do something at beginning here
+                # create the instance to process the audio
+                connection_handler = PaddleASRConnectionHanddler(asr_engine)
                 await websocket.send_json(resp)
             elif message['signal'] == 'end':
-                engine_pool = get_engine_pool()
-                asr_engine = engine_pool['asr']
                 # reset single engine for a new connection
-                asr_engine.reset()
-                resp = {"status": "ok", "signal": "finished"}
+                connection_handler.decode(is_finished=True)
+                connection_handler.rescoring()
+                asr_results = connection_handler.get_result()
+                connection_handler.reset()
+                resp = {
+                    "status": "ok",
+                    "signal": "finished",
+                    'asr_results': asr_results
+                }
                 await websocket.send_json(resp)
                 break
             else:
@@ -79,21 +90,11 @@ async def websocket_endpoint(websocket: WebSocket):
             elif "bytes" in message:
                 message = message["bytes"]

-                engine_pool = get_engine_pool()
-                asr_engine = engine_pool['asr']
-                asr_results = ""
-                frames = chunk_buffer.frame_generator(message)
-                for frame in frames:
-                    samples = np.frombuffer(frame.bytes, dtype=np.int16)
-                    sample_rate = asr_engine.config.sample_rate
-                    x_chunk, x_chunk_lens = asr_engine.preprocess(samples,
-                                                                  sample_rate)
-                    asr_engine.run(x_chunk, x_chunk_lens)
-                    asr_results = asr_engine.postprocess()
-                asr_results = asr_engine.postprocess()
+                connection_handler.extract_feat(message)
+                connection_handler.decode(is_finished=False)
+                asr_results = connection_handler.get_result()

                 resp = {'asr_results': asr_results}
                 await websocket.send_json(resp)
     except WebSocketDisconnect:
         pass
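Putting the endpoint together, the wire protocol a client must speak is: a JSON `start` signal, raw PCM chunks, then a JSON `end` signal. A minimal raw-client sketch; only the `name`/`signal` fields are taken from this diff, everything else is assumed:

```python
# raw client for the /ws/asr endpoint above; 16 kHz int16 PCM chunks assumed
import asyncio
import json

import websockets

async def raw_client(chunks, url='ws://127.0.0.1:8090/ws/asr'):
    async with websockets.connect(url) as ws:
        await ws.send(json.dumps({"name": "test.wav", "signal": "start"}))
        print(await ws.recv())                  # {"status": "ok", "signal": "server_ready"}
        for chunk in chunks:                    # each chunk: bytes of PCM audio
            await ws.send(chunk)
            print(json.loads(await ws.recv()))  # partial {"asr_results": ...}
        await ws.send(json.dumps({"name": "test.wav", "signal": "end"}))
        print(json.loads(await ws.recv()))      # final "finished" result
```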