提交 c8574c7e 编写于 作者: H Hui Zhang

ds2 inference as separate engine for streaming asr

上级 b9e3e493
......@@ -11,7 +11,7 @@ port: 8090
# protocol = ['websocket'] (only one can be selected).
# websocket only support online engine type.
protocol: 'websocket'
engine_list: ['asr_online']
engine_list: ['asr_online-inference']
#################################################################################
......@@ -20,7 +20,7 @@ engine_list: ['asr_online']
################################### ASR #########################################
################### speech task: asr; engine_type: online #######################
asr_online:
asr_online-inference:
model_type: 'deepspeech2online_aishell'
am_model: # the pdmodel file of am static model [optional]
am_params: # the pdiparams file of am static model [optional]
......
......@@ -187,7 +187,7 @@ class ASRExecutor(BaseExecutor):
elif "conformer" in model_type or "transformer" in model_type:
self.config.decode.decoding_method = decode_method
if num_decoding_left_chunks:
assert num_decoding_left_chunks == -1 or num_decoding_left_chunks >= 0, f"num_decoding_left_chunks should be -1 or >=0"
assert num_decoding_left_chunks == -1 or num_decoding_left_chunks >= 0, "num_decoding_left_chunks should be -1 or >=0"
self.config.num_decoding_left_chunks = num_decoding_left_chunks
else:
......
......@@ -224,6 +224,26 @@ asr_static_pretrained_models = {
'29e02312deb2e59b3c8686c7966d4fe3'
}
},
"deepspeech2online_aishell-zh-16k": {
'1.0': {
'url':
'https://paddlespeech.bj.bcebos.com/s2t/aishell/asr0/asr0_deepspeech2_online_aishell_fbank161_ckpt_1.0.1.model.tar.gz',
'md5':
'df5ddeac8b679a470176649ac4b78726',
'cfg_path':
'model.yaml',
'ckpt_path':
'exp/deepspeech2_online/checkpoints/avg_1',
'model':
'exp/deepspeech2_online/checkpoints/avg_1.jit.pdmodel',
'params':
'exp/deepspeech2_online/checkpoints/avg_1.jit.pdiparams',
'lm_url':
'https://deepspeech.bj.bcebos.com/zh_lm/zh_giga.no_cna_cmn.prune01244.klm',
'lm_md5':
'29e02312deb2e59b3c8686c7966d4fe3'
},
},
}
# ---------------------------------
......
......@@ -11,7 +11,7 @@ port: 8090
# protocol = ['websocket', 'http'] (only one can be selected).
# websocket only support online engine type.
protocol: 'websocket'
engine_list: ['asr_online']
engine_list: ['asr_online-inference']
#################################################################################
......@@ -20,7 +20,7 @@ engine_list: ['asr_online']
################################### ASR #########################################
################### speech task: asr; engine_type: online #######################
asr_online:
asr_online-inference:
model_type: 'deepspeech2online_aishell'
am_model: # the pdmodel file of am static model [optional]
am_params: # the pdiparams file of am static model [optional]
......
......@@ -125,7 +125,6 @@ class PaddleASRConnectionHanddler:
self.remained_wav = None
self.cached_feat = None
if "deepspeech2" in self.model_type:
return
......@@ -698,7 +697,7 @@ class ASRServerExecutor(ASRExecutor):
self.task_resource = CommonTaskResource(
task='asr', model_format='dynamic', inference_mode='online')
def update_config(self)->None:
def update_config(self) -> None:
if "deepspeech2" in self.model_type:
with UpdateConfig(self.config):
# download lm
......@@ -720,7 +719,7 @@ class ASRServerExecutor(ASRExecutor):
self.config.decode.decoding_method = self.decode_method
# update num_decoding_left_chunks
if self.num_decoding_left_chunks:
assert self.num_decoding_left_chunks == -1 or self.num_decoding_left_chunks >= 0, f"num_decoding_left_chunks should be -1 or >=0"
assert self.num_decoding_left_chunks == -1 or self.num_decoding_left_chunks >= 0, "num_decoding_left_chunks should be -1 or >=0"
self.config.decode.num_decoding_left_chunks = self.num_decoding_left_chunks
# we only support ctc_prefix_beam_search and attention_rescoring decoding method
# Generally we set the decoding_method to attention_rescoring
......@@ -738,17 +737,17 @@ class ASRServerExecutor(ASRExecutor):
raise Exception(f"not support: {self.model_type}")
def init_model(self) -> None:
if "deepspeech2" in self.model_type :
if "deepspeech2" in self.model_type:
# AM predictor
logger.info("ASR engine start to init the am predictor")
self.am_predictor = init_predictor(
model_file=self.am_model,
params_file=self.am_params,
predictor_conf=self.am_predictor_conf)
elif "conformer" in self.model_type or "transformer" in self.model_type :
elif "conformer" in self.model_type or "transformer" in self.model_type:
# load model
# model_type: {model_name}_{dataset}
model_name = self.model_type[:self.model_type.rindex('_')]
model_name = self.model_type[:self.model_type.rindex('_')]
logger.info(f"model name: {model_name}")
model_class = self.task_resource.get_model_class(model_name)
model = model_class.from_config(self.config)
......@@ -758,7 +757,6 @@ class ASRServerExecutor(ASRExecutor):
else:
raise Exception(f"not support: {self.model_type}")
def _init_from_path(self,
model_type: str=None,
am_model: Optional[os.PathLike]=None,
......@@ -786,7 +784,6 @@ class ASRServerExecutor(ASRExecutor):
self.am_predictor_conf = am_predictor_conf
logger.info(f"model_type: {self.model_type}")
sample_rate_str = '16k' if sample_rate == 16000 else '8k'
tag = model_type + '-' + lang + '-' + sample_rate_str
self.task_resource.set_task_model(model_tag=tag)
......@@ -831,7 +828,7 @@ class ASRServerExecutor(ASRExecutor):
spm_model_prefix=self.config.spm_model_prefix)
self.update_config()
# AM predictor
self.init_model()
......@@ -850,7 +847,6 @@ class ASREngine(BaseEngine):
super(ASREngine, self).__init__()
logger.info("create the online asr engine resource instance")
def init_model(self) -> bool:
if not self.executor._init_from_path(
model_type=self.config.model_type,
......@@ -865,7 +861,6 @@ class ASREngine(BaseEngine):
return False
return True
def init(self, config: dict) -> bool:
"""init engine resource
......
......@@ -28,6 +28,9 @@ class EngineFactory(object):
elif engine_name == 'asr' and engine_type == 'online':
from paddlespeech.server.engine.asr.online.python.asr_engine import ASREngine
return ASREngine()
elif engine_name == 'asr' and engine_type == 'online-inference':
from paddlespeech.server.engine.asr.online.paddleinference.asr_engine import ASREngine
return ASREngine()
elif engine_name == 'asr' and engine_type == 'online-onnx':
from paddlespeech.server.engine.asr.online.onnx.asr_engine import ASREngine
return ASREngine()
......
......@@ -92,7 +92,7 @@ async def websocket_endpoint(websocket: WebSocket):
else:
resp = {"status": "ok", "message": "no valid json data"}
await websocket.send_json(resp)
elif "bytes" in message:
# bytes for the pcm data
message = message["bytes"]
......
......@@ -747,7 +747,7 @@ def num2chn(number_string,
previous_symbol, (CNU, type(None))):
if next_symbol.power != 1 and (
(previous_symbol is None) or
(previous_symbol.power != 1)):
(previous_symbol.power != 1)): # noqa: E129
result_symbols[i] = liang
# if big is True, '两' will not be used and `alt_two` has no impact on output
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册