From 380afbbc5d828f81204a5b9ab9088d4491ba0b70 Mon Sep 17 00:00:00 2001
From: xiongxinlei
Date: Tue, 19 Apr 2022 16:18:42 +0800
Subject: [PATCH] add ds2 model multi session, test=doc

---
 paddlespeech/server/conf/ws_application.yaml  |  50 +---
 .../server/conf/ws_conformer_application.yaml |  45 ++++
 .../server/engine/asr/online/asr_engine.py    | 224 ++++++++++++++++--
 3 files changed, 263 insertions(+), 56 deletions(-)
 create mode 100644 paddlespeech/server/conf/ws_conformer_application.yaml

diff --git a/paddlespeech/server/conf/ws_application.yaml b/paddlespeech/server/conf/ws_application.yaml
index aa3c208b..dae4a3ff 100644
--- a/paddlespeech/server/conf/ws_application.yaml
+++ b/paddlespeech/server/conf/ws_application.yaml
@@ -18,44 +18,10 @@ engine_list: ['asr_online']
 #                                ENGINE CONFIG                                  #
 #################################################################################
-# ################################### ASR #########################################
-# ################### speech task: asr; engine_type: online #######################
-# asr_online:
-#     model_type: 'deepspeech2online_aishell'
-#     am_model:   # the pdmodel file of am static model [optional]
-#     am_params:  # the pdiparams file of am static model [optional]
-#     lang: 'zh'
-#     sample_rate: 16000
-#     cfg_path:
-#     decode_method:
-#     force_yes: True
-
-#     am_predictor_conf:
-#         device:  # set 'gpu:id' or 'cpu'
-#         switch_ir_optim: True
-#         glog_info: False  # True -> print glog
-#         summary: True  # False -> do not show predictor config
-
-#     chunk_buffer_conf:
-#         frame_duration_ms: 80
-#         shift_ms: 40
-#         sample_rate: 16000
-#         sample_width: 2
-
-#     vad_conf:
-#         aggressiveness: 2
-#         sample_rate: 16000
-#         frame_duration_ms: 20
-#         sample_width: 2
-#         padding_ms: 200
-#         padding_ratio: 0.9
-
-
-
 ################################### ASR #########################################
 ################### speech task: asr; engine_type: online #######################
 asr_online:
-    model_type: 'conformer2online_aishell'
+    model_type: 'deepspeech2online_aishell'
     am_model:   # the pdmodel file of am static model [optional]
     am_params:  # the pdiparams file of am static model [optional]
     lang: 'zh'
     sample_rate: 16000
     cfg_path:
     decode_method:
     force_yes: True
@@ -71,9 +37,19 @@ asr_online:
     summary: True     # False -> do not show predictor config
 
     chunk_buffer_conf:
+        frame_duration_ms: 80
+        shift_ms: 40
+        sample_rate: 16000
+        sample_width: 2
         window_n: 7    # frame
         shift_n: 4     # frame
-        window_ms: 25  # ms
+        window_ms: 20  # ms
         shift_ms: 10   # ms
+
+    vad_conf:
+        aggressiveness: 2
         sample_rate: 16000
-        sample_width: 2
\ No newline at end of file
+        frame_duration_ms: 20
+        sample_width: 2
+        padding_ms: 200
+        padding_ratio: 0.9
\ No newline at end of file
diff --git a/paddlespeech/server/conf/ws_conformer_application.yaml b/paddlespeech/server/conf/ws_conformer_application.yaml
new file mode 100644
index 00000000..1a775f85
--- /dev/null
+++ b/paddlespeech/server/conf/ws_conformer_application.yaml
@@ -0,0 +1,45 @@
+# This is the parameter configuration file for PaddleSpeech Serving.
+
+#################################################################################
+#                             SERVER SETTING                                    #
+#################################################################################
+host: 0.0.0.0
+port: 8090
+
+# The task format in the engine_list is: <speech task>_<engine type>
+# task choices = ['asr_online', 'tts_online']
+# protocol = ['websocket', 'http'] (only one can be selected).
+# websocket only support online engine type.
+protocol: 'websocket'
+engine_list: ['asr_online']
+
+
+#################################################################################
+#                                ENGINE CONFIG                                  #
+#################################################################################
+
+################################### ASR #########################################
+################### speech task: asr; engine_type: online #######################
+asr_online:
+    model_type: 'conformer2online_aishell'
+    am_model:   # the pdmodel file of am static model [optional]
+    am_params:  # the pdiparams file of am static model [optional]
+    lang: 'zh'
+    sample_rate: 16000
+    cfg_path:
+    decode_method:
+    force_yes: True
+
+    am_predictor_conf:
+        device:  # set 'gpu:id' or 'cpu'
+        switch_ir_optim: True
+        glog_info: False  # True -> print glog
+        summary: True  # False -> do not show predictor config
+
+    chunk_buffer_conf:
+        window_n: 7    # frame
+        shift_n: 4     # frame
+        window_ms: 25  # ms
+        shift_ms: 10   # ms
+        sample_rate: 16000
+        sample_width: 2
\ No newline at end of file
diff --git a/paddlespeech/server/engine/asr/online/asr_engine.py b/paddlespeech/server/engine/asr/online/asr_engine.py
index a8e25f4b..77eb5a21 100644
--- a/paddlespeech/server/engine/asr/online/asr_engine.py
+++ b/paddlespeech/server/engine/asr/online/asr_engine.py
@@ -13,7 +13,7 @@
 # limitations under the License.
 import os
 from typing import Optional
-
+import copy
 import numpy as np
 import paddle
 from numpy import float32
@@ -93,7 +93,7 @@ class PaddleASRConnectionHanddler:
         )
         self.config = asr_engine.config
         self.model_config = asr_engine.executor.config
-        self.model = asr_engine.executor.model
+        # self.model = asr_engine.executor.model
         self.asr_engine = asr_engine
         self.init()
 
@@ -102,7 +102,32 @@ class PaddleASRConnectionHanddler:
     def init(self):
         self.model_type = self.asr_engine.executor.model_type
         if "deepspeech2online" in self.model_type or "deepspeech2offline" in self.model_type:
-            pass
+            from paddlespeech.s2t.io.collator import SpeechCollator
+            self.sample_rate = self.asr_engine.executor.sample_rate
+            self.am_predictor = self.asr_engine.executor.am_predictor
+            self.text_feature = self.asr_engine.executor.text_feature
+            self.collate_fn_test = SpeechCollator.from_config(self.model_config)
+            self.decoder = CTCDecoder(
+                odim=self.model_config.output_dim,  # <blank> is in vocab
+                enc_n_units=self.model_config.rnn_layer_size * 2,
+                blank_id=self.model_config.blank_id,
+                dropout_rate=0.0,
+                reduction=True,  # sum
+                batch_average=True,  # sum / batch_size
+                grad_norm_type=self.model_config.get('ctc_grad_norm_type', None))
+
+            cfg = self.model_config.decode
+            decode_batch_size = 1  # for online
+            self.decoder.init_decoder(
+                decode_batch_size, self.text_feature.vocab_list,
+                cfg.decoding_method, cfg.lang_model_path, cfg.alpha, cfg.beta,
+                cfg.beam_size, cfg.cutoff_prob, cfg.cutoff_top_n,
+                cfg.num_proc_bsearch)
+
+            # frame window samples length and frame shift samples length
+            self.win_length = int(self.model_config.window_ms / 1000 * self.sample_rate)
+            self.n_shift = int(self.model_config.stride_ms / 1000 * self.sample_rate)
+
         elif "conformer" in self.model_type or "transformer" in self.model_type or "wenetspeech" in self.model_type:
             self.sample_rate = self.asr_engine.executor.sample_rate
@@ -127,7 +152,65 @@ class PaddleASRConnectionHanddler:
 
     def extract_feat(self, samples):
         if "deepspeech2online" in self.model_type:
-            pass
+            # self.remained_wav stores all the samples,
+            # including the original remained_wav and this package's samples
+            samples = np.frombuffer(samples, dtype=np.int16)
+            assert samples.ndim == 1
+
+            if self.remained_wav is None:
+                self.remained_wav = samples
+            else:
+                assert self.remained_wav.ndim == 1
+                self.remained_wav = np.concatenate([self.remained_wav, samples])
+            logger.info(
+                f"The connection remains the audio samples: {self.remained_wav.shape}"
+            )
+
+            # pcm16 -> pcm32
+            samples = pcm2float(self.remained_wav)
+            # read audio
+            speech_segment = SpeechSegment.from_pcm(
+                samples, self.sample_rate, transcript=" ")
+            # audio augment
+            self.collate_fn_test.augmentation.transform_audio(speech_segment)
+
+            # extract speech feature
+            spectrum, transcript_part = self.collate_fn_test._speech_featurizer.featurize(
+                speech_segment, self.collate_fn_test.keep_transcription_text)
+            # CMVN spectrum
+            if self.collate_fn_test._normalizer:
+                spectrum = self.collate_fn_test._normalizer.apply(spectrum)
+
+            # spectrum augment
+            audio = self.collate_fn_test.augmentation.transform_feature(
+                spectrum)
+
+            audio_len = audio.shape[0]
+            audio = paddle.to_tensor(audio, dtype='float32')
+            # audio_len = paddle.to_tensor(audio_len)
+            audio = paddle.unsqueeze(audio, axis=0)
+
+            if self.cached_feat is None:
+                self.cached_feat = audio
+            else:
+                assert (len(audio.shape) == 3)
+                assert (len(self.cached_feat.shape) == 3)
+                self.cached_feat = paddle.concat(
+                    [self.cached_feat, audio], axis=1)
+
+            # set the feat device
+            if self.device is None:
+                self.device = self.cached_feat.place
+
+            self.num_frames += audio_len
+            self.remained_wav = self.remained_wav[self.n_shift * audio_len:]
+
+            logger.info(
+                f"process the audio feature successfully, the connection feat shape: {self.cached_feat.shape}"
+            )
+            logger.info(
+                f"After extract feat, the connection remains the audio samples: {self.remained_wav.shape}"
+            )
         elif "conformer2online" in self.model_type:
             logger.info("Online ASR extract the feat")
             samples = np.frombuffer(samples, dtype=np.int16)
@@ -179,24 +262,81 @@ class PaddleASRConnectionHanddler:
             # logger.info(f"accumulate samples: {self.num_samples}")
 
     def reset(self):
-        self.subsampling_cache = None
-        self.elayers_output_cache = None
-        self.conformer_cnn_cache = None
-        self.encoder_out = None
-        self.cached_feat = None
-        self.remained_wav = None
-        self.offset = 0
-        self.num_samples = 0
-        self.device = None
-        self.hyps = []
-        self.num_frames = 0
-        self.chunk_num = 0
-        self.global_frame_offset = 0
-        self.result_transcripts = ['']
+        if "deepspeech2online" in self.model_type or "deepspeech2offline" in self.model_type:
+            # for deepspeech2
+            self.chunk_state_h_box = copy.deepcopy(self.asr_engine.executor.chunk_state_h_box)
+            self.chunk_state_c_box = copy.deepcopy(self.asr_engine.executor.chunk_state_c_box)
+            self.decoder.reset_decoder(batch_size=1)
+        elif "conformer" in self.model_type or "transformer" in self.model_type or "wenetspeech" in self.model_type:
+            # for conformer online
+            self.subsampling_cache = None
+            self.elayers_output_cache = None
+            self.conformer_cnn_cache = None
+            self.encoder_out = None
+            self.cached_feat = None
+            self.remained_wav = None
+            self.offset = 0
+            self.num_samples = 0
+            self.device = None
+            self.hyps = []
+            self.num_frames = 0
+            self.chunk_num = 0
+            self.global_frame_offset = 0
+            self.result_transcripts = ['']
 
     def decode(self, is_finished=False):
         if "deepspeech2online" in self.model_type:
-            pass
+            # x_chunk is the feature data
+            decoding_chunk_size = 1  # decoding_chunk_size=1 in deepspeech2 model
+            context = 7  # context=7 in deepspeech2 model
+            subsampling = 4  # subsampling=4 in deepspeech2 model
+            stride = subsampling * decoding_chunk_size
+            cached_feature_num = context - subsampling
+            # decoding window for model
+            decoding_window = (decoding_chunk_size - 1) * subsampling + context
+
+            if self.cached_feat is None:
+                logger.info("no audio feat, please input more pcm data")
+                return
+
+            num_frames = self.cached_feat.shape[1]
+            logger.info(
+                f"Required decoding window {decoding_window} frames, and the connection has {num_frames} frames"
+            )
+            # the cached feat must be larger than decoding_window
+            if num_frames < decoding_window and not is_finished:
+                logger.info(
+                    f"frame feat num is less than {decoding_window}, please input more pcm data"
+                )
+                return None, None
+
+            # if is_finished=True, we need at least context frames
+            if num_frames < context:
+                logger.info(
+                    f"last {num_frames} frames is less than context {context} frames, and we cannot do model forward"
+                )
+                return None, None
+
+            logger.info("start to do model forward")
+            # num_frames - context + 1 ensures that current frame can get context window
+            if is_finished:
+                # if get the finished chunk, we need process the last context
+                left_frames = context
+            else:
+                # we only process decoding_window frames for one chunk
+                left_frames = decoding_window
+
+            for cur in range(0, num_frames - left_frames + 1, stride):
+                end = min(cur + decoding_window, num_frames)
+                # extract the audio
+                x_chunk = self.cached_feat[:, cur:end, :].numpy()
+                x_chunk_lens = np.array([x_chunk.shape[1]])
+                trans_best = self.decode_one_chunk(x_chunk, x_chunk_lens)
+
+            self.result_transcripts = [trans_best]
+
+            self.cached_feat = self.cached_feat[:, end -
+                                                cached_feature_num:, :]
+            # return trans_best[0]
         elif "conformer" in self.model_type or "transformer" in self.model_type:
             try:
                 logger.info(
@@ -210,6 +350,48 @@ class PaddleASRConnectionHanddler:
         else:
             raise Exception("invalid model name")
 
+    def decode_one_chunk(self, x_chunk, x_chunk_lens):
+        logger.info("start to decode one chunk with deepspeech2 model")
+        input_names = self.am_predictor.get_input_names()
+        audio_handle = self.am_predictor.get_input_handle(input_names[0])
+        audio_len_handle = self.am_predictor.get_input_handle(
+            input_names[1])
+        h_box_handle = self.am_predictor.get_input_handle(input_names[2])
+        c_box_handle = self.am_predictor.get_input_handle(input_names[3])
+
+        audio_handle.reshape(x_chunk.shape)
+        audio_handle.copy_from_cpu(x_chunk)
+
+        audio_len_handle.reshape(x_chunk_lens.shape)
+        audio_len_handle.copy_from_cpu(x_chunk_lens)
+
+        h_box_handle.reshape(self.chunk_state_h_box.shape)
+        h_box_handle.copy_from_cpu(self.chunk_state_h_box)
+
+        c_box_handle.reshape(self.chunk_state_c_box.shape)
+        c_box_handle.copy_from_cpu(self.chunk_state_c_box)
+
+        output_names = self.am_predictor.get_output_names()
+        output_handle = self.am_predictor.get_output_handle(output_names[0])
+        output_lens_handle = self.am_predictor.get_output_handle(
+            output_names[1])
+        output_state_h_handle = self.am_predictor.get_output_handle(
+            output_names[2])
+        output_state_c_handle = self.am_predictor.get_output_handle(
+            output_names[3])
+
+        self.am_predictor.run()
+
+        output_chunk_probs = output_handle.copy_to_cpu()
+        output_chunk_lens = output_lens_handle.copy_to_cpu()
+        self.chunk_state_h_box = output_state_h_handle.copy_to_cpu()
+        self.chunk_state_c_box = output_state_c_handle.copy_to_cpu()
+
+        self.decoder.next(output_chunk_probs, output_chunk_lens)
+        trans_best, trans_beam = self.decoder.decode()
+        logger.info(f"decode one best result: {trans_best[0]}")
+        return trans_best[0]
+
     def advance_decoding(self, is_finished=False):
         logger.info("start to decode with advanced_decoding method")
         cfg = self.ctc_decode_config
@@ -240,6 +422,7 @@ class PaddleASRConnectionHanddler:
                )
                return None, None
 
+        # if is_finished=True, we need at least context frames
         if num_frames < context:
             logger.info(
                 "flast {num_frames} is less than context {context} frames, and we cannot do model forward"
             )
@@ -315,6 +498,9 @@ class PaddleASRConnectionHanddler:
         return ''
 
     def rescoring(self):
+        if "deepspeech2online" in self.model_type or "deepspeech2offline" in self.model_type:
+            return
+
         logger.info("rescoring the final result")
         if "attention_rescoring" != self.ctc_decode_config.decoding_method:
             return
-- 
GitLab
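Usage note (not part of the patch): the sketch below shows roughly how the websocket server side is expected to drive one PaddleASRConnectionHanddler per client connection once the deepspeech2online_aishell model is configured as above. Only the handler API (extract_feat, decode, rescoring, result_transcripts) comes from this change; the already-initialized `asr_engine`, the chunk size, and the PCM source are illustrative assumptions.

# Hypothetical per-connection driver loop; `asr_engine` is assumed to be an
# initialized online ASR engine and `pcm_bytes` 16 kHz, 16-bit mono PCM.
from paddlespeech.server.engine.asr.online.asr_engine import PaddleASRConnectionHanddler


def run_session(asr_engine, pcm_bytes, chunk_samples=5120):
    # one handler per websocket connection, so sessions do not share decoder state
    connection_handler = PaddleASRConnectionHanddler(asr_engine)
    connection_handler.reset()  # start from clean per-connection state

    bytes_per_chunk = chunk_samples * 2  # 16-bit PCM: 2 bytes per sample
    for start in range(0, len(pcm_bytes), bytes_per_chunk):
        chunk = pcm_bytes[start:start + bytes_per_chunk]
        # buffer the raw PCM and turn it into spectrogram frames
        connection_handler.extract_feat(chunk)
        # run the DS2 predictor over the frames buffered so far
        connection_handler.decode(is_finished=False)

    # flush whatever frames remain at the end of the utterance
    connection_handler.decode(is_finished=True)
    # no-op for deepspeech2 models after this patch; rescores conformer results
    connection_handler.rescoring()
    # best transcript of the chunks decoded so far
    return connection_handler.result_transcripts[0]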