提交 c8574c7e 编写于 作者: H Hui Zhang

ds2 inference as separate engine for streaming asr

上级 b9e3e493
...@@ -11,7 +11,7 @@ port: 8090 ...@@ -11,7 +11,7 @@ port: 8090
# protocol = ['websocket'] (only one can be selected). # protocol = ['websocket'] (only one can be selected).
# websocket only support online engine type. # websocket only support online engine type.
protocol: 'websocket' protocol: 'websocket'
engine_list: ['asr_online'] engine_list: ['asr_online-inference']
################################################################################# #################################################################################
...@@ -20,7 +20,7 @@ engine_list: ['asr_online'] ...@@ -20,7 +20,7 @@ engine_list: ['asr_online']
################################### ASR ######################################### ################################### ASR #########################################
################### speech task: asr; engine_type: online ####################### ################### speech task: asr; engine_type: online #######################
asr_online: asr_online-inference:
model_type: 'deepspeech2online_aishell' model_type: 'deepspeech2online_aishell'
am_model: # the pdmodel file of am static model [optional] am_model: # the pdmodel file of am static model [optional]
am_params: # the pdiparams file of am static model [optional] am_params: # the pdiparams file of am static model [optional]
......
...@@ -187,7 +187,7 @@ class ASRExecutor(BaseExecutor): ...@@ -187,7 +187,7 @@ class ASRExecutor(BaseExecutor):
elif "conformer" in model_type or "transformer" in model_type: elif "conformer" in model_type or "transformer" in model_type:
self.config.decode.decoding_method = decode_method self.config.decode.decoding_method = decode_method
if num_decoding_left_chunks: if num_decoding_left_chunks:
assert num_decoding_left_chunks == -1 or num_decoding_left_chunks >= 0, f"num_decoding_left_chunks should be -1 or >=0" assert num_decoding_left_chunks == -1 or num_decoding_left_chunks >= 0, "num_decoding_left_chunks should be -1 or >=0"
self.config.num_decoding_left_chunks = num_decoding_left_chunks self.config.num_decoding_left_chunks = num_decoding_left_chunks
else: else:
......
...@@ -224,6 +224,26 @@ asr_static_pretrained_models = { ...@@ -224,6 +224,26 @@ asr_static_pretrained_models = {
'29e02312deb2e59b3c8686c7966d4fe3' '29e02312deb2e59b3c8686c7966d4fe3'
} }
}, },
"deepspeech2online_aishell-zh-16k": {
'1.0': {
'url':
'https://paddlespeech.bj.bcebos.com/s2t/aishell/asr0/asr0_deepspeech2_online_aishell_fbank161_ckpt_1.0.1.model.tar.gz',
'md5':
'df5ddeac8b679a470176649ac4b78726',
'cfg_path':
'model.yaml',
'ckpt_path':
'exp/deepspeech2_online/checkpoints/avg_1',
'model':
'exp/deepspeech2_online/checkpoints/avg_1.jit.pdmodel',
'params':
'exp/deepspeech2_online/checkpoints/avg_1.jit.pdiparams',
'lm_url':
'https://deepspeech.bj.bcebos.com/zh_lm/zh_giga.no_cna_cmn.prune01244.klm',
'lm_md5':
'29e02312deb2e59b3c8686c7966d4fe3'
},
},
} }
# --------------------------------- # ---------------------------------
......
...@@ -11,7 +11,7 @@ port: 8090 ...@@ -11,7 +11,7 @@ port: 8090
# protocol = ['websocket', 'http'] (only one can be selected). # protocol = ['websocket', 'http'] (only one can be selected).
# websocket only support online engine type. # websocket only support online engine type.
protocol: 'websocket' protocol: 'websocket'
engine_list: ['asr_online'] engine_list: ['asr_online-inference']
################################################################################# #################################################################################
...@@ -20,7 +20,7 @@ engine_list: ['asr_online'] ...@@ -20,7 +20,7 @@ engine_list: ['asr_online']
################################### ASR ######################################### ################################### ASR #########################################
################### speech task: asr; engine_type: online ####################### ################### speech task: asr; engine_type: online #######################
asr_online: asr_online-inference:
model_type: 'deepspeech2online_aishell' model_type: 'deepspeech2online_aishell'
am_model: # the pdmodel file of am static model [optional] am_model: # the pdmodel file of am static model [optional]
am_params: # the pdiparams file of am static model [optional] am_params: # the pdiparams file of am static model [optional]
......
...@@ -125,7 +125,6 @@ class PaddleASRConnectionHanddler: ...@@ -125,7 +125,6 @@ class PaddleASRConnectionHanddler:
self.remained_wav = None self.remained_wav = None
self.cached_feat = None self.cached_feat = None
if "deepspeech2" in self.model_type: if "deepspeech2" in self.model_type:
return return
...@@ -698,7 +697,7 @@ class ASRServerExecutor(ASRExecutor): ...@@ -698,7 +697,7 @@ class ASRServerExecutor(ASRExecutor):
self.task_resource = CommonTaskResource( self.task_resource = CommonTaskResource(
task='asr', model_format='dynamic', inference_mode='online') task='asr', model_format='dynamic', inference_mode='online')
def update_config(self)->None: def update_config(self) -> None:
if "deepspeech2" in self.model_type: if "deepspeech2" in self.model_type:
with UpdateConfig(self.config): with UpdateConfig(self.config):
# download lm # download lm
...@@ -720,7 +719,7 @@ class ASRServerExecutor(ASRExecutor): ...@@ -720,7 +719,7 @@ class ASRServerExecutor(ASRExecutor):
self.config.decode.decoding_method = self.decode_method self.config.decode.decoding_method = self.decode_method
# update num_decoding_left_chunks # update num_decoding_left_chunks
if self.num_decoding_left_chunks: if self.num_decoding_left_chunks:
assert self.num_decoding_left_chunks == -1 or self.num_decoding_left_chunks >= 0, f"num_decoding_left_chunks should be -1 or >=0" assert self.num_decoding_left_chunks == -1 or self.num_decoding_left_chunks >= 0, "num_decoding_left_chunks should be -1 or >=0"
self.config.decode.num_decoding_left_chunks = self.num_decoding_left_chunks self.config.decode.num_decoding_left_chunks = self.num_decoding_left_chunks
# we only support ctc_prefix_beam_search and attention_rescoring dedoding method # we only support ctc_prefix_beam_search and attention_rescoring dedoding method
# Generally we set the decoding_method to attention_rescoring # Generally we set the decoding_method to attention_rescoring
...@@ -738,17 +737,17 @@ class ASRServerExecutor(ASRExecutor): ...@@ -738,17 +737,17 @@ class ASRServerExecutor(ASRExecutor):
raise Exception(f"not support: {self.model_type}") raise Exception(f"not support: {self.model_type}")
def init_model(self) -> None: def init_model(self) -> None:
if "deepspeech2" in self.model_type : if "deepspeech2" in self.model_type:
# AM predictor # AM predictor
logger.info("ASR engine start to init the am predictor") logger.info("ASR engine start to init the am predictor")
self.am_predictor = init_predictor( self.am_predictor = init_predictor(
model_file=self.am_model, model_file=self.am_model,
params_file=self.am_params, params_file=self.am_params,
predictor_conf=self.am_predictor_conf) predictor_conf=self.am_predictor_conf)
elif "conformer" in self.model_type or "transformer" in self.model_type : elif "conformer" in self.model_type or "transformer" in self.model_type:
# load model # load model
# model_type: {model_name}_{dataset} # model_type: {model_name}_{dataset}
model_name = self.model_type[:self.model_type.rindex('_')] model_name = self.model_type[:self.model_type.rindex('_')]
logger.info(f"model name: {model_name}") logger.info(f"model name: {model_name}")
model_class = self.task_resource.get_model_class(model_name) model_class = self.task_resource.get_model_class(model_name)
model = model_class.from_config(self.config) model = model_class.from_config(self.config)
...@@ -758,7 +757,6 @@ class ASRServerExecutor(ASRExecutor): ...@@ -758,7 +757,6 @@ class ASRServerExecutor(ASRExecutor):
else: else:
raise Exception(f"not support: {self.model_type}") raise Exception(f"not support: {self.model_type}")
def _init_from_path(self, def _init_from_path(self,
model_type: str=None, model_type: str=None,
am_model: Optional[os.PathLike]=None, am_model: Optional[os.PathLike]=None,
...@@ -786,7 +784,6 @@ class ASRServerExecutor(ASRExecutor): ...@@ -786,7 +784,6 @@ class ASRServerExecutor(ASRExecutor):
self.am_predictor_conf = am_predictor_conf self.am_predictor_conf = am_predictor_conf
logger.info(f"model_type: {self.model_type}") logger.info(f"model_type: {self.model_type}")
sample_rate_str = '16k' if sample_rate == 16000 else '8k' sample_rate_str = '16k' if sample_rate == 16000 else '8k'
tag = model_type + '-' + lang + '-' + sample_rate_str tag = model_type + '-' + lang + '-' + sample_rate_str
self.task_resource.set_task_model(model_tag=tag) self.task_resource.set_task_model(model_tag=tag)
...@@ -831,7 +828,7 @@ class ASRServerExecutor(ASRExecutor): ...@@ -831,7 +828,7 @@ class ASRServerExecutor(ASRExecutor):
spm_model_prefix=self.config.spm_model_prefix) spm_model_prefix=self.config.spm_model_prefix)
self.update_config() self.update_config()
# AM predictor # AM predictor
self.init_model() self.init_model()
...@@ -850,7 +847,6 @@ class ASREngine(BaseEngine): ...@@ -850,7 +847,6 @@ class ASREngine(BaseEngine):
super(ASREngine, self).__init__() super(ASREngine, self).__init__()
logger.info("create the online asr engine resource instance") logger.info("create the online asr engine resource instance")
def init_model(self) -> bool: def init_model(self) -> bool:
if not self.executor._init_from_path( if not self.executor._init_from_path(
model_type=self.config.model_type, model_type=self.config.model_type,
...@@ -865,7 +861,6 @@ class ASREngine(BaseEngine): ...@@ -865,7 +861,6 @@ class ASREngine(BaseEngine):
return False return False
return True return True
def init(self, config: dict) -> bool: def init(self, config: dict) -> bool:
"""init engine resource """init engine resource
......
...@@ -28,6 +28,9 @@ class EngineFactory(object): ...@@ -28,6 +28,9 @@ class EngineFactory(object):
elif engine_name == 'asr' and engine_type == 'online': elif engine_name == 'asr' and engine_type == 'online':
from paddlespeech.server.engine.asr.online.python.asr_engine import ASREngine from paddlespeech.server.engine.asr.online.python.asr_engine import ASREngine
return ASREngine() return ASREngine()
elif engine_name == 'asr' and engine_type == 'online-inference':
from paddlespeech.server.engine.asr.online.paddleinference.asr_engine import ASREngine
return ASREngine()
elif engine_name == 'asr' and engine_type == 'online-onnx': elif engine_name == 'asr' and engine_type == 'online-onnx':
from paddlespeech.server.engine.asr.online.onnx.asr_engine import ASREngine from paddlespeech.server.engine.asr.online.onnx.asr_engine import ASREngine
return ASREngine() return ASREngine()
......
...@@ -92,7 +92,7 @@ async def websocket_endpoint(websocket: WebSocket): ...@@ -92,7 +92,7 @@ async def websocket_endpoint(websocket: WebSocket):
else: else:
resp = {"status": "ok", "message": "no valid json data"} resp = {"status": "ok", "message": "no valid json data"}
await websocket.send_json(resp) await websocket.send_json(resp)
elif "bytes" in message: elif "bytes" in message:
# bytes for the pcm data # bytes for the pcm data
message = message["bytes"] message = message["bytes"]
......
...@@ -747,7 +747,7 @@ def num2chn(number_string, ...@@ -747,7 +747,7 @@ def num2chn(number_string,
previous_symbol, (CNU, type(None))): previous_symbol, (CNU, type(None))):
if next_symbol.power != 1 and ( if next_symbol.power != 1 and (
(previous_symbol is None) or (previous_symbol is None) or
(previous_symbol.power != 1)): (previous_symbol.power != 1)): # noqa: E129
result_symbols[i] = liang result_symbols[i] = liang
# if big is True, '两' will not be used and `alt_two` has no impact on output # if big is True, '两' will not be used and `alt_two` has no impact on output
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册