未验证 提交 1f5f34a8 编写于 作者: H Hui Zhang 提交者: GitHub

Merge pull request #2016 from Jackwaterveg/develop_dev

[ASR] Support editing num_decode_left_chunks in cli and server
......@@ -28,6 +28,7 @@ asr_online:
sample_rate: 16000
cfg_path:
decode_method:
num_decoding_left_chunks: -1
force_yes: True
device: 'cpu' # cpu or gpu:id
decode_method: "attention_rescoring"
......
......@@ -32,7 +32,7 @@ asr_online:
device: 'cpu' # cpu or gpu:id
decode_method: "attention_rescoring"
continuous_decoding: True # enable continue decoding when endpoint detected
num_decoding_left_chunks: -1
am_predictor_conf:
device: # set 'gpu:id' or 'cpu'
switch_ir_optim: True
......
......@@ -28,6 +28,7 @@ asr_online:
sample_rate: 16000
cfg_path:
decode_method:
num_decoding_left_chunks:
force_yes: True
device: 'cpu' # cpu or gpu:id
......
......@@ -83,6 +83,12 @@ class ASRExecutor(BaseExecutor):
'attention_rescoring'
],
help='only support transformer and conformer model')
self.parser.add_argument(
'--num_decoding_left_chunks',
'-num_left',
type=str,
default=-1,
help='only support transformer and conformer model')
self.parser.add_argument(
'--ckpt_path',
type=str,
......@@ -122,6 +128,7 @@ class ASRExecutor(BaseExecutor):
sample_rate: int=16000,
cfg_path: Optional[os.PathLike]=None,
decode_method: str='attention_rescoring',
num_decoding_left_chunks: int=-1,
ckpt_path: Optional[os.PathLike]=None):
"""
Init model and other resources from a specific path.
......@@ -129,6 +136,7 @@ class ASRExecutor(BaseExecutor):
logger.info("start to init the model")
# default max_len: unit:second
self.max_len = 50
assert num_decoding_left_chunks == -1 or num_decoding_left_chunks >= 0
if hasattr(self, 'model'):
logger.info('Model had been initialized.')
return
......@@ -179,6 +187,7 @@ class ASRExecutor(BaseExecutor):
elif "conformer" in model_type or "transformer" in model_type:
self.config.decode.decoding_method = decode_method
self.config.num_decoding_left_chunks = num_decoding_left_chunks
else:
raise Exception("wrong type")
......@@ -451,6 +460,7 @@ class ASRExecutor(BaseExecutor):
config: os.PathLike=None,
ckpt_path: os.PathLike=None,
decode_method: str='attention_rescoring',
num_decoding_left_chunks: int=-1,
force_yes: bool=False,
rtf: bool=False,
device=paddle.get_device()):
......@@ -460,7 +470,7 @@ class ASRExecutor(BaseExecutor):
audio_file = os.path.abspath(audio_file)
paddle.set_device(device)
self._init_from_path(model, lang, sample_rate, config, decode_method,
ckpt_path)
num_decoding_left_chunks, ckpt_path)
if not self._check(audio_file, sample_rate, force_yes):
sys.exit(-1)
if rtf:
......
......@@ -28,6 +28,7 @@ asr_online:
sample_rate: 16000
cfg_path:
decode_method:
num_decoding_left_chunks:
force_yes: True
device: # cpu or gpu:id
......
......@@ -28,6 +28,7 @@ asr_online:
sample_rate: 16000
cfg_path:
decode_method:
num_decoding_left_chunks: -1
force_yes: True
device: # cpu or gpu:id
continuous_decoding: True # enable continue decoding when endpoint detected
......@@ -44,4 +45,4 @@ asr_online:
window_ms: 25 # ms
shift_ms: 10 # ms
sample_rate: 16000
sample_width: 2
\ No newline at end of file
sample_width: 2
......@@ -705,6 +705,7 @@ class ASRServerExecutor(ASRExecutor):
sample_rate: int=16000,
cfg_path: Optional[os.PathLike]=None,
decode_method: str='attention_rescoring',
num_decoding_left_chunks: int=-1,
am_predictor_conf: dict=None):
"""
Init model and other resources from a specific path.
......@@ -790,7 +791,10 @@ class ASRServerExecutor(ASRExecutor):
# update the decoding method
if decode_method:
self.config.decode.decoding_method = decode_method
# update num_decoding_left_chunks
if num_decoding_left_chunks:
self.config.decode.num_decoding_left_chunks = num_decoding_left_chunks
assert self.config.decode.num_decoding_left_chunks == -1 or self.config.decode.num_decoding_left_chunks >= 0, f"num_decoding_left_chunks should be -1 or >=0"
# we only support ctc_prefix_beam_search and attention_rescoring dedoding method
# Generally we set the decoding_method to attention_rescoring
if self.config.decode.decoding_method not in [
......@@ -864,6 +868,7 @@ class ASREngine(BaseEngine):
sample_rate=self.config.sample_rate,
cfg_path=self.config.cfg_path,
decode_method=self.config.decode_method,
num_decoding_left_chunks=self.config.num_decoding_left_chunks,
am_predictor_conf=self.config.am_predictor_conf):
logger.error(
"Init the ASR server occurs error, please check the server configuration yaml"
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册