“5c18decf3c680f44121c77682f884be04791f30b”上不存在“develop/api_doc/fluid/nets.html”
未验证 提交 bd66c7a8 编写于 作者: Honei_X's avatar Honei_X 提交者: GitHub

Merge pull request #1905 from Honei/develop

[asr][server]add check asr server model type
...@@ -29,7 +29,8 @@ asr_online: ...@@ -29,7 +29,8 @@ asr_online:
cfg_path: cfg_path:
decode_method: decode_method:
force_yes: True force_yes: True
device: cpu # cpu or gpu:id device: 'cpu' # cpu or gpu:id
decode_method: "attention_rescoring"
am_predictor_conf: am_predictor_conf:
device: # set 'gpu:id' or 'cpu' device: # set 'gpu:id' or 'cpu'
switch_ir_optim: True switch_ir_optim: True
......
...@@ -13,6 +13,7 @@ ...@@ -13,6 +13,7 @@
# limitations under the License. # limitations under the License.
import copy import copy
import os import os
import sys
from typing import Optional from typing import Optional
import numpy as np import numpy as np
...@@ -588,7 +589,7 @@ class ASRServerExecutor(ASRExecutor): ...@@ -588,7 +589,7 @@ class ASRServerExecutor(ASRExecutor):
self.pretrained_models = pretrained_models self.pretrained_models = pretrained_models
def _init_from_path(self, def _init_from_path(self,
model_type: str='deepspeech2online_aishell', model_type: str=None,
am_model: Optional[os.PathLike]=None, am_model: Optional[os.PathLike]=None,
am_params: Optional[os.PathLike]=None, am_params: Optional[os.PathLike]=None,
lang: str='zh', lang: str='zh',
...@@ -599,6 +600,12 @@ class ASRServerExecutor(ASRExecutor): ...@@ -599,6 +600,12 @@ class ASRServerExecutor(ASRExecutor):
""" """
Init model and other resources from a specific path. Init model and other resources from a specific path.
""" """
if not model_type or not lang or not sample_rate:
logger.error(
"The model type or lang or sample rate is None, please input an valid server parameter yaml"
)
return False
self.model_type = model_type self.model_type = model_type
self.sample_rate = sample_rate self.sample_rate = sample_rate
sample_rate_str = '16k' if sample_rate == 16000 else '8k' sample_rate_str = '16k' if sample_rate == 16000 else '8k'
...@@ -731,6 +738,8 @@ class ASRServerExecutor(ASRExecutor): ...@@ -731,6 +738,8 @@ class ASRServerExecutor(ASRExecutor):
self.searcher = CTCPrefixBeamSearch(self.config.decode) self.searcher = CTCPrefixBeamSearch(self.config.decode)
self.transformer_decode_reset() self.transformer_decode_reset()
return True
def reset_decoder_and_chunk(self): def reset_decoder_and_chunk(self):
"""reset decoder and chunk state for an new audio """reset decoder and chunk state for an new audio
""" """
...@@ -1028,12 +1037,15 @@ class ASREngine(BaseEngine): ...@@ -1028,12 +1037,15 @@ class ASREngine(BaseEngine):
self.device = paddle.get_device() self.device = paddle.get_device()
logger.info(f"paddlespeech_server set the device: {self.device}") logger.info(f"paddlespeech_server set the device: {self.device}")
paddle.set_device(self.device) paddle.set_device(self.device)
except BaseException: except BaseException as e:
logger.error( logger.error(
"Set device failed, please check if device is already used and the parameter 'device' in the yaml file" f"Set device failed, please check if device '{self.device}' is already used and the parameter 'device' in the yaml file"
) )
logger.error(
"If all GPU or XPU is used, you can set the server to 'cpu'")
sys.exit(-1)
self.executor._init_from_path( if not self.executor._init_from_path(
model_type=self.config.model_type, model_type=self.config.model_type,
am_model=self.config.am_model, am_model=self.config.am_model,
am_params=self.config.am_params, am_params=self.config.am_params,
...@@ -1041,7 +1053,11 @@ class ASREngine(BaseEngine): ...@@ -1041,7 +1053,11 @@ class ASREngine(BaseEngine):
sample_rate=self.config.sample_rate, sample_rate=self.config.sample_rate,
cfg_path=self.config.cfg_path, cfg_path=self.config.cfg_path,
decode_method=self.config.decode_method, decode_method=self.config.decode_method,
am_predictor_conf=self.config.am_predictor_conf) am_predictor_conf=self.config.am_predictor_conf):
logger.error(
"Init the ASR server occurs error, please check the server configuration yaml"
)
return False
logger.info("Initialize ASR server engine successfully.") logger.info("Initialize ASR server engine successfully.")
return True return True
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册