diff --git a/demos/streaming_asr_server/conf/ws_conformer_application.yaml b/demos/streaming_asr_server/conf/ws_conformer_application.yaml
index 9dbc82b6f4e6963c3726e3d789c195f56110e732..01bb1e9c90a2f268d070ac84b4aec9c5476bbccb 100644
--- a/demos/streaming_asr_server/conf/ws_conformer_application.yaml
+++ b/demos/streaming_asr_server/conf/ws_conformer_application.yaml
@@ -28,6 +28,7 @@ asr_online:
     sample_rate: 16000
     cfg_path:
     decode_method:
+    num_decoding_left_chunks: -1
     force_yes: True
     device: 'cpu' # cpu or gpu:id
     decode_method: "attention_rescoring"
diff --git a/demos/streaming_asr_server/conf/ws_conformer_wenetspeech_application.yaml b/demos/streaming_asr_server/conf/ws_conformer_wenetspeech_application.yaml
index 683d86f03cdc6a9ad0cf740ef9b4d491dd5d2ac9..d30bcd0252eeffa871ee92c11b77699d6f839093 100644
--- a/demos/streaming_asr_server/conf/ws_conformer_wenetspeech_application.yaml
+++ b/demos/streaming_asr_server/conf/ws_conformer_wenetspeech_application.yaml
@@ -32,7 +32,7 @@ asr_online:
     device: 'cpu' # cpu or gpu:id
     decode_method: "attention_rescoring"
     continuous_decoding: True # enable continue decoding when endpoint detected
-
+    num_decoding_left_chunks: -1
     am_predictor_conf:
         device: # set 'gpu:id' or 'cpu'
         switch_ir_optim: True
diff --git a/demos/streaming_asr_server/conf/ws_ds2_application.yaml b/demos/streaming_asr_server/conf/ws_ds2_application.yaml
index f2ea6330f690801182f457ba1170207a12e14b18..d19bd26dc1b4d1a45f9fb797f9f4e749099948ad 100644
--- a/demos/streaming_asr_server/conf/ws_ds2_application.yaml
+++ b/demos/streaming_asr_server/conf/ws_ds2_application.yaml
@@ -28,6 +28,7 @@ asr_online:
     sample_rate: 16000
     cfg_path:
     decode_method:
+    num_decoding_left_chunks:
     force_yes: True
     device: 'cpu' # cpu or gpu:id

diff --git a/paddlespeech/cli/asr/infer.py b/paddlespeech/cli/asr/infer.py
index f26901a1e42e3d9598eb71417961756ced9ef515..ad83bc20e79d08c5211b2a0f774099ac195d3846 100644
--- a/paddlespeech/cli/asr/infer.py
+++ b/paddlespeech/cli/asr/infer.py
@@ -83,6 +83,12 @@ class ASRExecutor(BaseExecutor):
                 'attention_rescoring'
             ],
             help='only support transformer and conformer model')
+        self.parser.add_argument(
+            '--num_decoding_left_chunks',
+            '-num_left',
+            type=int,
+            default=-1,
+            help='only support transformer and conformer model')
         self.parser.add_argument(
             '--ckpt_path',
             type=str,
@@ -122,6 +128,7 @@ class ASRExecutor(BaseExecutor):
                         sample_rate: int=16000,
                         cfg_path: Optional[os.PathLike]=None,
                         decode_method: str='attention_rescoring',
+                        num_decoding_left_chunks: int=-1,
                         ckpt_path: Optional[os.PathLike]=None):
         """
         Init model and other resources from a specific path.
@@ -129,6 +136,7 @@ class ASRExecutor(BaseExecutor):
         logger.info("start to init the model")
         # default max_len: unit:second
         self.max_len = 50
+        assert num_decoding_left_chunks == -1 or num_decoding_left_chunks >= 0
         if hasattr(self, 'model'):
             logger.info('Model had been initialized.')
             return
@@ -179,6 +187,7 @@ class ASRExecutor(BaseExecutor):

         elif "conformer" in model_type or "transformer" in model_type:
             self.config.decode.decoding_method = decode_method
+            self.config.decode.num_decoding_left_chunks = num_decoding_left_chunks

         else:
             raise Exception("wrong type")
@@ -451,6 +460,7 @@ class ASRExecutor(BaseExecutor):
                  config: os.PathLike=None,
                  ckpt_path: os.PathLike=None,
                  decode_method: str='attention_rescoring',
+                 num_decoding_left_chunks: int=-1,
                  force_yes: bool=False,
                  rtf: bool=False,
                  device=paddle.get_device()):
@@ -460,7 +470,7 @@ class ASRExecutor(BaseExecutor):
         audio_file = os.path.abspath(audio_file)
         paddle.set_device(device)
         self._init_from_path(model, lang, sample_rate, config, decode_method,
-                             ckpt_path)
+                             num_decoding_left_chunks, ckpt_path)
         if not self._check(audio_file, sample_rate, force_yes):
             sys.exit(-1)
         if rtf:
diff --git a/paddlespeech/server/conf/ws_application.yaml b/paddlespeech/server/conf/ws_application.yaml
index d6f5a227c5359b9c4b8dc91f897619962cba8706..43d83f2d46bac910f018bb2b6270b25a84944480 100644
--- a/paddlespeech/server/conf/ws_application.yaml
+++ b/paddlespeech/server/conf/ws_application.yaml
@@ -28,6 +28,7 @@ asr_online:
     sample_rate: 16000
     cfg_path:
     decode_method:
+    num_decoding_left_chunks:
     force_yes: True
     device: # cpu or gpu:id

diff --git a/paddlespeech/server/conf/ws_conformer_application.yaml b/paddlespeech/server/conf/ws_conformer_application.yaml
index dd5e67ca35a05243f7c5f722dda62b8b802f4578..d72eb2379ad291a552e551101c894c2429c7738a 100644
--- a/paddlespeech/server/conf/ws_conformer_application.yaml
+++ b/paddlespeech/server/conf/ws_conformer_application.yaml
@@ -28,6 +28,7 @@ asr_online:
     sample_rate: 16000
     cfg_path:
     decode_method:
+    num_decoding_left_chunks: -1
     force_yes: True
     device: # cpu or gpu:id
     continuous_decoding: True # enable continue decoding when endpoint detected
@@ -44,4 +45,4 @@ asr_online:
         window_ms: 25 # ms
         shift_ms: 10 # ms
         sample_rate: 16000
-        sample_width: 2
\ No newline at end of file
+        sample_width: 2
diff --git a/paddlespeech/server/engine/asr/online/asr_engine.py b/paddlespeech/server/engine/asr/online/asr_engine.py
index 8fc210e5a32b3c4e67192d47842d87f5a1bb6b1b..3eefa9d72782578b4e2b2eebf15f202f4419f0db 100644
--- a/paddlespeech/server/engine/asr/online/asr_engine.py
+++ b/paddlespeech/server/engine/asr/online/asr_engine.py
@@ -703,6 +703,7 @@ class ASRServerExecutor(ASRExecutor):
                         sample_rate: int=16000,
                         cfg_path: Optional[os.PathLike]=None,
                         decode_method: str='attention_rescoring',
+                        num_decoding_left_chunks: int=-1,
                         am_predictor_conf: dict=None):
         """
         Init model and other resources from a specific path.
@@ -788,7 +789,10 @@ class ASRServerExecutor(ASRExecutor):
         # update the decoding method
         if decode_method:
             self.config.decode.decoding_method = decode_method
-
+        # update num_decoding_left_chunks
+        if num_decoding_left_chunks:
+            self.config.decode.num_decoding_left_chunks = num_decoding_left_chunks
+            assert self.config.decode.num_decoding_left_chunks == -1 or self.config.decode.num_decoding_left_chunks >= 0, "num_decoding_left_chunks should be -1 or >=0"
         # we only support ctc_prefix_beam_search and attention_rescoring dedoding method
         # Generally we set the decoding_method to attention_rescoring
         if self.config.decode.decoding_method not in [
@@ -862,6 +866,7 @@ class ASREngine(BaseEngine):
                 sample_rate=self.config.sample_rate,
                 cfg_path=self.config.cfg_path,
                 decode_method=self.config.decode_method,
+                num_decoding_left_chunks=self.config.num_decoding_left_chunks,
                 am_predictor_conf=self.config.am_predictor_conf):
             logger.error(
                 "Init the ASR server occurs error, please check the server configuration yaml"
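
For reference, a minimal usage sketch of the new option through the Python API of `ASRExecutor` (not part of the diff). The audio path `zh.wav` is an illustrative placeholder; following the configs above, `-1` is taken to mean "use the full left context", while a value `>= 0` limits how many left chunks the streaming decoder attends to.

```python
# Sketch only: exercise the num_decoding_left_chunks argument added in this patch.
# "zh.wav" is a placeholder input; -1 keeps the full left context (per the
# defaults in the application.yaml files above), a value >= 0 restricts it.
import paddle

from paddlespeech.cli.asr.infer import ASRExecutor

asr = ASRExecutor()
text = asr(
    audio_file="zh.wav",
    lang="zh",
    sample_rate=16000,
    decode_method="attention_rescoring",
    num_decoding_left_chunks=-1,
    force_yes=True,
    device=paddle.get_device())
print(text)
```

For the streaming server, the same knob is read from `asr_online.num_decoding_left_chunks` in the application.yaml files touched above and forwarded to `ASRServerExecutor._init_from_path` by `ASREngine.init`.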