提交 67939d0d 编写于 作者: X xiongxinlei

add check asr server model type, test=doc

上级 15271445
......@@ -29,7 +29,8 @@ asr_online:
cfg_path:
decode_method:
force_yes: True
device: cpu # cpu or gpu:id
device: 'cpu' # cpu or gpu:id
decode_method: "attention_rescoring"
am_predictor_conf:
device: # set 'gpu:id' or 'cpu'
switch_ir_optim: True
......@@ -42,4 +43,4 @@ asr_online:
window_ms: 25 # ms
shift_ms: 10 # ms
sample_rate: 16000
sample_width: 2
\ No newline at end of file
sample_width: 2
......@@ -13,6 +13,7 @@
# limitations under the License.
import copy
import os
import sys
from typing import Optional
import numpy as np
......@@ -588,7 +589,7 @@ class ASRServerExecutor(ASRExecutor):
self.pretrained_models = pretrained_models
def _init_from_path(self,
model_type: str='deepspeech2online_aishell',
model_type: str=None,
am_model: Optional[os.PathLike]=None,
am_params: Optional[os.PathLike]=None,
lang: str='zh',
......@@ -599,6 +600,12 @@ class ASRServerExecutor(ASRExecutor):
"""
Init model and other resources from a specific path.
"""
if not model_type or not lang or not sample_rate:
logger.error(
"The model type or lang or sample rate is None, please input an valid server parameter yaml"
)
return False
self.model_type = model_type
self.sample_rate = sample_rate
sample_rate_str = '16k' if sample_rate == 16000 else '8k'
......@@ -1028,20 +1035,27 @@ class ASREngine(BaseEngine):
self.device = paddle.get_device()
logger.info(f"paddlespeech_server set the device: {self.device}")
paddle.set_device(self.device)
except BaseException:
except BaseException as e:
logger.error(
"Set device failed, please check if device is already used and the parameter 'device' in the yaml file"
f"Set device failed, please check if device '{self.device}' is already used and the parameter 'device' in the yaml file"
)
self.executor._init_from_path(
model_type=self.config.model_type,
am_model=self.config.am_model,
am_params=self.config.am_params,
lang=self.config.lang,
sample_rate=self.config.sample_rate,
cfg_path=self.config.cfg_path,
decode_method=self.config.decode_method,
am_predictor_conf=self.config.am_predictor_conf)
logger.error(
"If all GPU or XPU is used, you can set the server to 'cpu'")
sys.exit(-1)
if not self.executor._init_from_path(
model_type=self.config.model_type,
am_model=self.config.am_model,
am_params=self.config.am_params,
lang=self.config.lang,
sample_rate=self.config.sample_rate,
cfg_path=self.config.cfg_path,
decode_method=self.config.decode_method,
am_predictor_conf=self.config.am_predictor_conf):
logger.error(
"Init the ASR server occurs error, please check the server configuration yaml"
)
return False
logger.info("Initialize ASR server engine successfully.")
return True
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册