From 6f7917b7f2b489b8341aeda2c8ff318975b84f78 Mon Sep 17 00:00:00 2001 From: Hui Zhang Date: Wed, 25 May 2022 09:25:17 +0000 Subject: [PATCH] fix streaming asr --- .../conf/ws_conformer_application.yaml | 2 +- ...plication.yaml => ws_ds2_application.yaml} | 0 .../server/engine/asr/online/asr_engine.py | 53 ++++--------------- 3 files changed, 12 insertions(+), 43 deletions(-) rename demos/streaming_asr_server/conf/{ws_application.yaml => ws_ds2_application.yaml} (100%) diff --git a/demos/streaming_asr_server/conf/ws_conformer_application.yaml b/demos/streaming_asr_server/conf/ws_conformer_application.yaml index 2affde07..6a10741b 100644 --- a/demos/streaming_asr_server/conf/ws_conformer_application.yaml +++ b/demos/streaming_asr_server/conf/ws_conformer_application.yaml @@ -4,7 +4,7 @@ # SERVER SETTING # ################################################################################# host: 0.0.0.0 -port: 8090 +port: 8091 # The task format in the engin_list is: _ # task choices = ['asr_online'] diff --git a/demos/streaming_asr_server/conf/ws_application.yaml b/demos/streaming_asr_server/conf/ws_ds2_application.yaml similarity index 100% rename from demos/streaming_asr_server/conf/ws_application.yaml rename to demos/streaming_asr_server/conf/ws_ds2_application.yaml diff --git a/paddlespeech/server/engine/asr/online/asr_engine.py b/paddlespeech/server/engine/asr/online/asr_engine.py index 70bfcfb6..d7bd458f 100644 --- a/paddlespeech/server/engine/asr/online/asr_engine.py +++ b/paddlespeech/server/engine/asr/online/asr_engine.py @@ -53,7 +53,7 @@ class PaddleASRConnectionHanddler: logger.info( "create an paddle asr connection handler to process the websocket connection" ) - self.config = asr_engine.config + self.config = asr_engine.config # server config self.model_config = asr_engine.executor.config self.asr_engine = asr_engine @@ -249,10 +249,13 @@ class PaddleASRConnectionHanddler: def reset(self): if "deepspeech2" in self.model_type: # for deepspeech2 - self.chunk_state_h_box = copy.deepcopy( - self.asr_engine.executor.chunk_state_h_box) - self.chunk_state_c_box = copy.deepcopy( - self.asr_engine.executor.chunk_state_c_box) + # init state + self.chunk_state_h_box = np.zeros( + (self.model_config .num_rnn_layers, 1, self.model_config.rnn_layer_size), + dtype=float32) + self.chunk_state_c_box = np.zeros( + (self.model_config.num_rnn_layers, 1, self.model_config.rnn_layer_size), + dtype=float32) self.decoder.reset_decoder(batch_size=1) self.device = None @@ -803,36 +806,6 @@ class ASRServerExecutor(ASRExecutor): model_file=self.am_model, params_file=self.am_params, predictor_conf=self.am_predictor_conf) - - # decoder - logger.info("ASR engine start to create the ctc decoder instance") - self.decoder = CTCDecoder( - odim=self.config.output_dim, # is in vocab - enc_n_units=self.config.rnn_layer_size * 2, - blank_id=self.config.blank_id, - dropout_rate=0.0, - reduction=True, # sum - batch_average=True, # sum / batch_size - grad_norm_type=self.config.get('ctc_grad_norm_type', None)) - - # init decoder - logger.info("ASR engine start to init the ctc decoder") - cfg = self.config.decode - decode_batch_size = 1 # for online - self.decoder.init_decoder( - decode_batch_size, self.text_feature.vocab_list, - cfg.decoding_method, cfg.lang_model_path, cfg.alpha, cfg.beta, - cfg.beam_size, cfg.cutoff_prob, cfg.cutoff_top_n, - cfg.num_proc_bsearch) - - # init state box - self.chunk_state_h_box = np.zeros( - (self.config.num_rnn_layers, 1, self.config.rnn_layer_size), - dtype=float32) - self.chunk_state_c_box = np.zeros( - (self.config.num_rnn_layers, 1, self.config.rnn_layer_size), - dtype=float32) - elif "conformer" in model_type or "transformer" in model_type: model_name = model_type[:model_type.rindex( '_')] # model_type: {model_name}_{dataset} @@ -847,15 +820,11 @@ class ASRServerExecutor(ASRExecutor): model_dict = paddle.load(self.am_model) self.model.set_state_dict(model_dict) logger.info("create the transformer like model success") - - # update the ctc decoding - self.searcher = CTCPrefixBeamSearch(self.config.decode) - self.transformer_decode_reset() else: raise ValueError(f"Not support: {model_type}") return True - + class ASREngine(BaseEngine): """ASR server resource @@ -881,8 +850,8 @@ class ASREngine(BaseEngine): self.executor = ASRServerExecutor() try: - default_dev = paddle.get_device() - paddle.set_device(self.config.get("device", default_dev)) + self.device = self.config.get("device", paddle.get_device()) + paddle.set_device(self.device) except BaseException as e: logger.error( f"Set device failed, please check if device '{self.device}' is already used and the parameter 'device' in the yaml file" -- GitLab