提交 67939d0d 编写于 作者: X xiongxinlei

add check asr server model type, test=doc

上级 15271445
...@@ -29,7 +29,8 @@ asr_online: ...@@ -29,7 +29,8 @@ asr_online:
cfg_path: cfg_path:
decode_method: decode_method:
force_yes: True force_yes: True
device: cpu # cpu or gpu:id device: 'cpu' # cpu or gpu:id
decode_method: "attention_rescoring"
am_predictor_conf: am_predictor_conf:
device: # set 'gpu:id' or 'cpu' device: # set 'gpu:id' or 'cpu'
switch_ir_optim: True switch_ir_optim: True
...@@ -42,4 +43,4 @@ asr_online: ...@@ -42,4 +43,4 @@ asr_online:
window_ms: 25 # ms window_ms: 25 # ms
shift_ms: 10 # ms shift_ms: 10 # ms
sample_rate: 16000 sample_rate: 16000
sample_width: 2 sample_width: 2
\ No newline at end of file
...@@ -13,6 +13,7 @@ ...@@ -13,6 +13,7 @@
# limitations under the License. # limitations under the License.
import copy import copy
import os import os
import sys
from typing import Optional from typing import Optional
import numpy as np import numpy as np
...@@ -588,7 +589,7 @@ class ASRServerExecutor(ASRExecutor): ...@@ -588,7 +589,7 @@ class ASRServerExecutor(ASRExecutor):
self.pretrained_models = pretrained_models self.pretrained_models = pretrained_models
def _init_from_path(self, def _init_from_path(self,
model_type: str='deepspeech2online_aishell', model_type: str=None,
am_model: Optional[os.PathLike]=None, am_model: Optional[os.PathLike]=None,
am_params: Optional[os.PathLike]=None, am_params: Optional[os.PathLike]=None,
lang: str='zh', lang: str='zh',
...@@ -599,6 +600,12 @@ class ASRServerExecutor(ASRExecutor): ...@@ -599,6 +600,12 @@ class ASRServerExecutor(ASRExecutor):
""" """
Init model and other resources from a specific path. Init model and other resources from a specific path.
""" """
if not model_type or not lang or not sample_rate:
logger.error(
"The model type or lang or sample rate is None, please input an valid server parameter yaml"
)
return False
self.model_type = model_type self.model_type = model_type
self.sample_rate = sample_rate self.sample_rate = sample_rate
sample_rate_str = '16k' if sample_rate == 16000 else '8k' sample_rate_str = '16k' if sample_rate == 16000 else '8k'
...@@ -1028,20 +1035,27 @@ class ASREngine(BaseEngine): ...@@ -1028,20 +1035,27 @@ class ASREngine(BaseEngine):
self.device = paddle.get_device() self.device = paddle.get_device()
logger.info(f"paddlespeech_server set the device: {self.device}") logger.info(f"paddlespeech_server set the device: {self.device}")
paddle.set_device(self.device) paddle.set_device(self.device)
except BaseException: except BaseException as e:
logger.error( logger.error(
"Set device failed, please check if device is already used and the parameter 'device' in the yaml file" f"Set device failed, please check if device '{self.device}' is already used and the parameter 'device' in the yaml file"
) )
logger.error(
self.executor._init_from_path( "If all GPU or XPU is used, you can set the server to 'cpu'")
model_type=self.config.model_type, sys.exit(-1)
am_model=self.config.am_model,
am_params=self.config.am_params, if not self.executor._init_from_path(
lang=self.config.lang, model_type=self.config.model_type,
sample_rate=self.config.sample_rate, am_model=self.config.am_model,
cfg_path=self.config.cfg_path, am_params=self.config.am_params,
decode_method=self.config.decode_method, lang=self.config.lang,
am_predictor_conf=self.config.am_predictor_conf) sample_rate=self.config.sample_rate,
cfg_path=self.config.cfg_path,
decode_method=self.config.decode_method,
am_predictor_conf=self.config.am_predictor_conf):
logger.error(
"Init the ASR server occurs error, please check the server configuration yaml"
)
return False
logger.info("Initialize ASR server engine successfully.") logger.info("Initialize ASR server engine successfully.")
return True return True
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册