From 943272385a6b11fafa00e18ab13a30513e8ea9ec Mon Sep 17 00:00:00 2001 From: Hui Zhang Date: Mon, 23 May 2022 10:19:53 +0000 Subject: [PATCH] refactor asr online --- .../server/engine/asr/online/asr_engine.py | 581 ++++++------------ 1 file changed, 200 insertions(+), 381 deletions(-) diff --git a/paddlespeech/server/engine/asr/online/asr_engine.py b/paddlespeech/server/engine/asr/online/asr_engine.py index fd57a3d5..3b88cabb 100644 --- a/paddlespeech/server/engine/asr/online/asr_engine.py +++ b/paddlespeech/server/engine/asr/online/asr_engine.py @@ -38,7 +38,7 @@ from paddlespeech.server.engine.base_engine import BaseEngine from paddlespeech.server.utils.audio_process import pcm2float from paddlespeech.server.utils.paddle_predictor import init_predictor -__all__ = ['ASREngine'] +__all__ = ['PaddleASRConnectionHanddler', 'ASRServerExecutor', 'ASREngine'] # ASR server connection process class @@ -67,7 +67,7 @@ class PaddleASRConnectionHanddler: # tokens to text self.text_feature = self.asr_engine.executor.text_feature - if "deepspeech2online" in self.model_type or "deepspeech2offline" in self.model_type: + if "deepspeech2" in self.model_type: from paddlespeech.s2t.io.collator import SpeechCollator self.am_predictor = self.asr_engine.executor.am_predictor @@ -89,8 +89,8 @@ class PaddleASRConnectionHanddler: cfg.decoding_method, cfg.lang_model_path, cfg.alpha, cfg.beta, cfg.beam_size, cfg.cutoff_prob, cfg.cutoff_top_n, cfg.num_proc_bsearch) - # frame window samples length and frame shift samples length + # frame window and frame shift, in samples unit self.win_length = int(self.model_config.window_ms / 1000 * self.sample_rate) self.n_shift = int(self.model_config.stride_ms / 1000 * @@ -109,16 +109,15 @@ class PaddleASRConnectionHanddler: self.preprocess_args = {"train": False} self.preprocessing = Transformation(self.preprocess_conf) - # frame window samples length and frame shift samples length + # frame window and frame shift, in samples unit self.win_length = self.preprocess_conf.process[0]['win_length'] self.n_shift = self.preprocess_conf.process[0]['n_shift'] + else: + raise ValueError(f"Not supported: {self.model_type}") def extract_feat(self, samples): - # we compute the elapsed time of first char occuring # and we record the start time at the first pcm sample arraving - # if self.first_char_occur_elapsed is not None: - # self.first_char_occur_elapsed = time.time() if "deepspeech2online" in self.model_type: # self.reamined_wav stores all the samples, @@ -154,28 +153,28 @@ class PaddleASRConnectionHanddler: spectrum = self.collate_fn_test._normalizer.apply(spectrum) # spectrum augment - audio = self.collate_fn_test.augmentation.transform_feature( + feat = self.collate_fn_test.augmentation.transform_feature( spectrum) - audio_len = audio.shape[0] - audio = paddle.to_tensor(audio, dtype='float32') - # audio_len = paddle.to_tensor(audio_len) - audio = paddle.unsqueeze(audio, axis=0) + # audio_len is frame num + frame_num = feat.shape[0] + feat = paddle.to_tensor(feat, dtype='float32') + feat = paddle.unsqueeze(feat, axis=0) if self.cached_feat is None: - self.cached_feat = audio + self.cached_feat = feat else: - assert (len(audio.shape) == 3) + assert (len(feat.shape) == 3) assert (len(self.cached_feat.shape) == 3) self.cached_feat = paddle.concat( - [self.cached_feat, audio], axis=1) + [self.cached_feat, feat], axis=1) # set the feat device if self.device is None: self.device = self.cached_feat.place - self.num_frames += audio_len - self.remained_wav = self.remained_wav[self.n_shift * 
audio_len:] + self.num_frames += frame_num + self.remained_wav = self.remained_wav[self.n_shift * frame_num:] logger.info( f"process the audio feature success, the connection feat shape: {self.cached_feat.shape}" @@ -183,25 +182,28 @@ class PaddleASRConnectionHanddler: logger.info( f"After extract feat, the connection remain the audio samples: {self.remained_wav.shape}" ) + elif "conformer_online" in self.model_type: logger.info("Online ASR extract the feat") samples = np.frombuffer(samples, dtype=np.int16) assert samples.ndim == 1 - logger.info(f"This package receive {samples.shape[0]} pcm data") self.num_samples += samples.shape[0] + logger.info(f"This package receive {samples.shape[0]} pcm data. Global samples:{self.num_samples}") # self.reamined_wav stores all the samples, # include the original remained_wav and this package samples if self.remained_wav is None: self.remained_wav = samples else: - assert self.remained_wav.ndim == 1 + assert self.remained_wav.ndim == 1 # (T,) self.remained_wav = np.concatenate([self.remained_wav, samples]) logger.info( - f"The connection remain the audio samples: {self.remained_wav.shape}" + f"The concatenation of remain and now audio samples length is: {self.remained_wav.shape}" ) + if len(self.remained_wav) < self.win_length: + # samples not enough for feature window return 0 # fbank @@ -209,11 +211,13 @@ class PaddleASRConnectionHanddler: **self.preprocess_args) x_chunk = paddle.to_tensor( x_chunk, dtype="float32").unsqueeze(axis=0) + + # feature cache if self.cached_feat is None: self.cached_feat = x_chunk else: - assert (len(x_chunk.shape) == 3) - assert (len(self.cached_feat.shape) == 3) + assert (len(x_chunk.shape) == 3) # (B,T,D) + assert (len(self.cached_feat.shape) == 3) # (B,T,D) self.cached_feat = paddle.concat( [self.cached_feat, x_chunk], axis=1) @@ -221,20 +225,30 @@ class PaddleASRConnectionHanddler: if self.device is None: self.device = self.cached_feat.place + # cur frame step num_frames = x_chunk.shape[1] + + # global frame step self.num_frames += num_frames + + # update remained wav self.remained_wav = self.remained_wav[self.n_shift * num_frames:] + logger.info( - f"process the audio feature success, the connection feat shape: {self.cached_feat.shape}" + f"process the audio feature success, the cached feat shape: {self.cached_feat.shape}" ) logger.info( - f"After extract feat, the connection remain the audio samples: {self.remained_wav.shape}" + f"After extract feat, the cached remain the audio samples: {self.remained_wav.shape}" ) - # logger.info(f"accumulate samples: {self.num_samples}") + logger.info(f"global samples: {self.num_samples}") + logger.info(f"global frames: {self.num_frames}") + else: + raise ValueError(f"not supported: {self.model_type}") + def reset(self): - if "deepspeech2online" in self.model_type or "deepspeech2offline" in self.model_type: + if "deepspeech2" in self.model_type: # for deepspeech2 self.chunk_state_h_box = copy.deepcopy( self.asr_engine.executor.chunk_state_h_box) @@ -242,35 +256,63 @@ class PaddleASRConnectionHanddler: self.asr_engine.executor.chunk_state_c_box) self.decoder.reset_decoder(batch_size=1) - # for conformer online + self.device = None + + ## common + + # global sample and frame step + self.num_samples = 0 + self.num_frames = 0 + + # cache for audio and feat + self.remained_wav = None + self.cached_feat = None + + + # partial/ending decoding results + self.result_transcripts = [''] + + ## conformer + + # cache for conformer online self.subsampling_cache = None self.elayers_output_cache = 
None self.conformer_cnn_cache = None self.encoder_out = None - self.cached_feat = None - self.remained_wav = None - self.offset = 0 - self.num_samples = 0 - self.device = None + # conformer decoding state + self.chunk_num = 0 # globa decoding chunk num + self.offset = 0 # global offset in decoding frame unit self.hyps = [] - self.num_frames = 0 - self.chunk_num = 0 - self.global_frame_offset = 0 - self.result_transcripts = [''] + + # token timestamp result self.word_time_stamp = [] + + # one best timestamp viterbi prob is large. self.time_stamp = [] - self.first_char_occur_elapsed = None + def decode(self, is_finished=False): + """advance decoding + + Args: + is_finished (bool, optional): Is last frame or not. Defaults to False. + + Raises: + Exception: when not support model. + + Returns: + None: nothing + """ if "deepspeech2online" in self.model_type: - # x_chunk 是特征数据 - decoding_chunk_size = 1 # decoding_chunk_size=1 in deepspeech2 model - context = 7 # context=7 in deepspeech2 model - subsampling = 4 # subsampling=4 in deepspeech2 model - stride = subsampling * decoding_chunk_size + decoding_chunk_size = 1 # decoding chunk size = 1. int decoding frame unit + context = 7 # context=7, in audio frame unit + subsampling = 4 # subsampling=4, in audio frame unit + cached_feature_num = context - subsampling - # decoding window for model + # decoding window for model, in audio frame unit decoding_window = (decoding_chunk_size - 1) * subsampling + context + # decoding stride for model, in audio frame unit + stride = subsampling * decoding_chunk_size if self.cached_feat is None: logger.info("no audio feat, please input more pcm data") @@ -280,6 +322,7 @@ class PaddleASRConnectionHanddler: logger.info( f"Required decoding window {decoding_window} frames, and the connection has {num_frames} frames" ) + # the cached feat must be larger decoding_window if num_frames < decoding_window and not is_finished: logger.info( @@ -293,6 +336,7 @@ class PaddleASRConnectionHanddler: "flast {num_frames} is less than context {context} frames, and we cannot do model forward" ) return None, None + logger.info("start to do model forward") # num_frames - context + 1 ensure that current frame can get context window if is_finished: @@ -302,6 +346,7 @@ class PaddleASRConnectionHanddler: # we only process decoding_window frames for one chunk left_frames = decoding_window + end = None for cur in range(0, num_frames - left_frames + 1, stride): end = min(cur + decoding_window, num_frames) # extract the audio @@ -311,7 +356,9 @@ class PaddleASRConnectionHanddler: self.result_transcripts = [trans_best] + # update feat cache self.cached_feat = self.cached_feat[:, end - cached_feature_num:, :] + # return trans_best[0] elif "conformer" in self.model_type or "transformer" in self.model_type: try: @@ -326,9 +373,19 @@ class PaddleASRConnectionHanddler: else: raise Exception("invalid model name") + @paddle.no_grad() def decode_one_chunk(self, x_chunk, x_chunk_lens): - logger.info("start to decoce one chunk with deepspeech2 model") + """forward one chunk frames + + Args: + x_chunk (np.ndarray): (B,T,D), audio frames. + x_chunk_lens ([type]): (B,), audio frame lens + + Returns: + logprob: poster probability. 
+ """ + logger.info("start to decoce one chunk for deepspeech2") input_names = self.am_predictor.get_input_names() audio_handle = self.am_predictor.get_input_handle(input_names[0]) audio_len_handle = self.am_predictor.get_input_handle(input_names[1]) @@ -365,24 +422,31 @@ class PaddleASRConnectionHanddler: self.decoder.next(output_chunk_probs, output_chunk_lens) trans_best, trans_beam = self.decoder.decode() - logger.info(f"decode one best result: {trans_best[0]}") + logger.info(f"decode one best result for deepspeech2: {trans_best[0]}") return trans_best[0] + @paddle.no_grad() def advance_decoding(self, is_finished=False): - logger.info("start to decode with advanced_decoding method") + logger.info("Conformer/Transformer: start to decode with advanced_decoding method") cfg = self.ctc_decode_config + + # cur chunk size, in decoding frame unit decoding_chunk_size = cfg.decoding_chunk_size + # using num of history chunks num_decoding_left_chunks = cfg.num_decoding_left_chunks - assert decoding_chunk_size > 0 + subsampling = self.model.encoder.embed.subsampling_rate context = self.model.encoder.embed.right_context + 1 - stride = subsampling * decoding_chunk_size - cached_feature_num = context - subsampling # processed chunk feature cached for next chunk - # decoding window for model + # processed chunk feature cached for next chunk + cached_feature_num = context - subsampling + # decoding stride, in audio frame unit + stride = subsampling * decoding_chunk_size + # decoding window, in audio frame unit decoding_window = (decoding_chunk_size - 1) * subsampling + context + if self.cached_feat is None: logger.info("no audio feat, please input more pcm data") return @@ -407,6 +471,7 @@ class PaddleASRConnectionHanddler: return None, None logger.info("start to do model forward") + # hist of chunks, in deocding frame unit required_cache_size = decoding_chunk_size * num_decoding_left_chunks outputs = [] @@ -423,8 +488,11 @@ class PaddleASRConnectionHanddler: for cur in range(0, num_frames - left_frames + 1, stride): end = min(cur + decoding_window, num_frames) + # global chunk_num self.chunk_num += 1 + # cur chunk chunk_xs = self.cached_feat[:, cur:end, :] + # forward chunk (y, self.subsampling_cache, self.elayers_output_cache, self.conformer_cnn_cache) = self.model.encoder.forward_chunk( chunk_xs, self.offset, required_cache_size, @@ -432,7 +500,7 @@ class PaddleASRConnectionHanddler: self.conformer_cnn_cache) outputs.append(y) - # update the offset + # update the global offset, in decoding frame unit self.offset += y.shape[1] ys = paddle.cat(outputs, 1) @@ -445,12 +513,15 @@ class PaddleASRConnectionHanddler: ctc_probs = self.model.ctc.log_softmax(ys) # (1, maxlen, vocab_size) ctc_probs = ctc_probs.squeeze(0) + # advance decoding self.searcher.search(ctc_probs, self.cached_feat.place) - + # get one best hyps self.hyps = self.searcher.get_one_best_hyps() + assert self.cached_feat.shape[0] == 1 assert end >= cached_feature_num + # advance cache of feat self.cached_feat = self.cached_feat[0, end - cached_feature_num:, :].unsqueeze(0) assert len( @@ -462,50 +533,79 @@ class PaddleASRConnectionHanddler: ) def update_result(self): + """Conformer/Transformer hyps to result. + """ logger.info("update the final result") hyps = self.hyps + + # output results and tokenids self.result_transcripts = [ self.text_feature.defeaturize(hyp) for hyp in hyps ] self.result_tokenids = [hyp for hyp in hyps] def get_result(self): + """return partial/ending asr result. + + Returns: + str: one best result of partial/ending. 
+ """ if len(self.result_transcripts) > 0: return self.result_transcripts[0] else: return '' def get_word_time_stamp(self): + """return token timestamp result. + + Returns: + list: List of ('w':token, 'bg':time, 'ed':time) + """ return self.word_time_stamp + @paddle.no_grad() def rescoring(self): - if "deepspeech2online" in self.model_type or "deepspeech2offline" in self.model_type: + """Second-Pass Decoding, + only for conformer and transformer model. + """ + if "deepspeech2" in self.model_type: + logger.info("deepspeech2 not support rescoring decoding.") return - - logger.info("rescoring the final result") + if "attention_rescoring" != self.ctc_decode_config.decoding_method: + logger.info(f"decoding method not match: {self.ctc_decode_config.decoding_method}, need attention_rescoring") return + logger.info("rescoring the final result") + + # last decoding for last audio self.searcher.finalize_search() + # update beam search results self.update_result() beam_size = self.ctc_decode_config.beam_size hyps = self.searcher.get_hyps() if hyps is None or len(hyps) == 0: + logger.info("No Hyps!") return + # rescore by decoder post probability + # assert len(hyps) == beam_size + # list of Tensor hyp_list = [] for hyp in hyps: hyp_content = hyp[0] # Prevent the hyp is empty if len(hyp_content) == 0: hyp_content = (self.model.ctc.blank_id, ) + hyp_content = paddle.to_tensor( hyp_content, place=self.device, dtype=paddle.long) hyp_list.append(hyp_content) - hyps_pad = pad_sequence(hyp_list, True, self.model.ignore_id) + + hyps_pad = pad_sequence(hyp_list, batch_first=True, padding_value=self.model.ignore_id) hyps_lens = paddle.to_tensor( [len(hyp[0]) for hyp in hyps], place=self.device, dtype=paddle.long) # (beam_size,) @@ -531,10 +631,12 @@ class PaddleASRConnectionHanddler: score = 0.0 for j, w in enumerate(hyp[0]): score += decoder_out[i][j][w] + # last decoder output token is `eos`, for laste decoder input token. score += decoder_out[i][len(hyp[0])][self.model.eos] # add ctc score (which in ln domain) score += hyp[1] * self.ctc_decode_config.ctc_weight + if score > best_score: best_score = score best_index = i @@ -542,47 +644,57 @@ class PaddleASRConnectionHanddler: # update the one best result # hyps stored the beam results and each fields is: - logger.info(f"best index: {best_index}") + logger.info(f"best hyp index: {best_index}") # logger.info(f'best result: {hyps[best_index]}') # the field of the hyps is: + ## asr results # hyps[0][0]: the sentence word-id in the vocab with a tuple # hyps[0][1]: the sentence decoding probability with all paths + ## timestamp # hyps[0][2]: viterbi_blank ending probability - # hyps[0][3]: viterbi_non_blank probability + # hyps[0][3]: viterbi_non_blank dending probability # hyps[0][4]: current_token_prob, - # hyps[0][5]: times_viterbi_blank, - # hyps[0][6]: times_titerbi_non_blank + # hyps[0][5]: times_viterbi_blank ending timestamp, + # hyps[0][6]: times_titerbi_non_blank encding timestamp. 
self.hyps = [hyps[best_index][0]] + logger.info(f"best hyp ids: {self.hyps}") # update the hyps time stamp self.time_stamp = hyps[best_index][5] if hyps[best_index][2] > hyps[ best_index][3] else hyps[best_index][6] logger.info(f"time stamp: {self.time_stamp}") + # update one best result self.update_result() # update each word start and end time stamp - frame_shift_in_ms = self.model.encoder.embed.subsampling_rate * self.n_shift / self.sample_rate - logger.info(f"frame shift ms: {frame_shift_in_ms}") + # decoding frame to audio frame + frame_shift = self.model.encoder.embed.subsampling_rate + frame_shift_in_sec = frame_shift * (self.n_shift / self.sample_rate) + logger.info(f"frame shift sec: {frame_shift_in_sec}") + word_time_stamp = [] for idx, _ in enumerate(self.time_stamp): start = (self.time_stamp[idx - 1] + self.time_stamp[idx] ) / 2.0 if idx > 0 else 0 - start = start * frame_shift_in_ms + start = start * frame_shift_in_sec end = (self.time_stamp[idx] + self.time_stamp[idx + 1] ) / 2.0 if idx < len(self.time_stamp) - 1 else self.offset - end = end * frame_shift_in_ms + + end = end * frame_shift_in_sec word_time_stamp.append({ "w": self.result_transcripts[0][idx], "bg": start, "ed": end }) - # logger.info(f"{self.result_transcripts[0][idx]}, start: {start}, end: {end}") + # logger.info(f"{word_time_stamp[-1]}") + self.word_time_stamp = word_time_stamp logger.info(f"word time stamp: {self.word_time_stamp}") + class ASRServerExecutor(ASRExecutor): def __init__(self): super().__init__() @@ -610,6 +722,7 @@ class ASRServerExecutor(ASRExecutor): self.sample_rate = sample_rate sample_rate_str = '16k' if sample_rate == 16000 else '8k' tag = model_type + '-' + lang + '-' + sample_rate_str + if cfg_path is None or am_model is None or am_params is None: logger.info(f"Load the pretrained model, tag = {tag}") res_path = self._get_pretrained_path(tag) # wenetspeech_zh @@ -628,7 +741,7 @@ class ASRServerExecutor(ASRExecutor): self.am_model = os.path.abspath(am_model) self.am_params = os.path.abspath(am_params) self.res_path = os.path.dirname( - os.path.dirname(os.path.abspath(self.cfg_path))) + os.path.dirname(os.path.abspath(self.cfg_path))) logger.info(self.cfg_path) logger.info(self.am_model) @@ -639,7 +752,7 @@ class ASRServerExecutor(ASRExecutor): self.config.merge_from_file(self.cfg_path) with UpdateConfig(self.config): - if "deepspeech2online" in model_type or "deepspeech2offline" in model_type: + if "deepspeech2" in model_type: from paddlespeech.s2t.io.collator import SpeechCollator self.vocab = self.config.vocab_filepath self.config.decode.lang_model_path = os.path.join( @@ -655,6 +768,7 @@ class ASRServerExecutor(ASRExecutor): self.download_lm( lm_url, os.path.dirname(self.config.decode.lang_model_path), lm_md5) + elif "conformer" in model_type or "transformer" in model_type: logger.info("start to create the stream conformer asr engine") if self.config.spm_model_prefix: @@ -682,7 +796,8 @@ class ASRServerExecutor(ASRExecutor): ], f"we only support ctc_prefix_beam_search and attention_rescoring dedoding method, current decoding method is {self.config.decode.decoding_method}" else: raise Exception("wrong type") - if "deepspeech2online" in model_type or "deepspeech2offline" in model_type: + + if "deepspeech2" in model_type: # AM predictor logger.info("ASR engine start to init the am predictor") self.am_predictor_conf = am_predictor_conf @@ -719,6 +834,7 @@ class ASRServerExecutor(ASRExecutor): self.chunk_state_c_box = np.zeros( (self.config.num_rnn_layers, 1, self.config.rnn_layer_size), 
dtype=float32) + elif "conformer" in model_type or "transformer" in model_type: model_name = model_type[:model_type.rindex( '_')] # model_type: {model_name}_{dataset} @@ -737,277 +853,14 @@ class ASRServerExecutor(ASRExecutor): # update the ctc decoding self.searcher = CTCPrefixBeamSearch(self.config.decode) self.transformer_decode_reset() - - return True - - def reset_decoder_and_chunk(self): - """reset decoder and chunk state for an new audio - """ - if "deepspeech2online" in self.model_type or "deepspeech2offline" in self.model_type: - self.decoder.reset_decoder(batch_size=1) - # init state box, for new audio request - self.chunk_state_h_box = np.zeros( - (self.config.num_rnn_layers, 1, self.config.rnn_layer_size), - dtype=float32) - self.chunk_state_c_box = np.zeros( - (self.config.num_rnn_layers, 1, self.config.rnn_layer_size), - dtype=float32) - elif "conformer" in self.model_type or "transformer" in self.model_type: - self.transformer_decode_reset() - - def decode_one_chunk(self, x_chunk, x_chunk_lens, model_type: str): - """decode one chunk - - Args: - x_chunk (numpy.array): shape[B, T, D] - x_chunk_lens (numpy.array): shape[B] - model_type (str): online model type - - Returns: - str: one best result - """ - logger.info("start to decoce chunk by chunk") - if "deepspeech2online" in model_type: - input_names = self.am_predictor.get_input_names() - audio_handle = self.am_predictor.get_input_handle(input_names[0]) - audio_len_handle = self.am_predictor.get_input_handle( - input_names[1]) - h_box_handle = self.am_predictor.get_input_handle(input_names[2]) - c_box_handle = self.am_predictor.get_input_handle(input_names[3]) - - audio_handle.reshape(x_chunk.shape) - audio_handle.copy_from_cpu(x_chunk) - - audio_len_handle.reshape(x_chunk_lens.shape) - audio_len_handle.copy_from_cpu(x_chunk_lens) - - h_box_handle.reshape(self.chunk_state_h_box.shape) - h_box_handle.copy_from_cpu(self.chunk_state_h_box) - - c_box_handle.reshape(self.chunk_state_c_box.shape) - c_box_handle.copy_from_cpu(self.chunk_state_c_box) - - output_names = self.am_predictor.get_output_names() - output_handle = self.am_predictor.get_output_handle(output_names[0]) - output_lens_handle = self.am_predictor.get_output_handle( - output_names[1]) - output_state_h_handle = self.am_predictor.get_output_handle( - output_names[2]) - output_state_c_handle = self.am_predictor.get_output_handle( - output_names[3]) - - self.am_predictor.run() - - output_chunk_probs = output_handle.copy_to_cpu() - output_chunk_lens = output_lens_handle.copy_to_cpu() - self.chunk_state_h_box = output_state_h_handle.copy_to_cpu() - self.chunk_state_c_box = output_state_c_handle.copy_to_cpu() - - self.decoder.next(output_chunk_probs, output_chunk_lens) - trans_best, trans_beam = self.decoder.decode() - logger.info(f"decode one best result: {trans_best[0]}") - return trans_best[0] - - elif "conformer" in model_type or "transformer" in model_type: - try: - logger.info( - f"we will use the transformer like model : {self.model_type}" - ) - self.advanced_decoding(x_chunk, x_chunk_lens) - self.update_result() - - return self.result_transcripts[0] - except Exception as e: - logger.exception(e) else: - raise Exception("invalid model name") - - def advanced_decoding(self, xs: paddle.Tensor, x_chunk_lens): - logger.info("start to decode with advanced_decoding method") - encoder_out, encoder_mask = self.encoder_forward(xs) - ctc_probs = self.model.ctc.log_softmax( - encoder_out) # (1, maxlen, vocab_size) - ctc_probs = ctc_probs.squeeze(0) - 
self.searcher.search(ctc_probs, xs.place) - # update the one best result - self.hyps = self.searcher.get_one_best_hyps() - - # now we supprot ctc_prefix_beam_search and attention_rescoring - if "attention_rescoring" in self.config.decode.decoding_method: - self.rescoring(encoder_out, xs.place) - - def encoder_forward(self, xs): - logger.info("get the model out from the feat") - cfg = self.config.decode - decoding_chunk_size = cfg.decoding_chunk_size - num_decoding_left_chunks = cfg.num_decoding_left_chunks - - assert decoding_chunk_size > 0 - subsampling = self.model.encoder.embed.subsampling_rate - context = self.model.encoder.embed.right_context + 1 - stride = subsampling * decoding_chunk_size - - # decoding window for model - decoding_window = (decoding_chunk_size - 1) * subsampling + context - num_frames = xs.shape[1] - required_cache_size = decoding_chunk_size * num_decoding_left_chunks - - logger.info("start to do model forward") - outputs = [] - - # num_frames - context + 1 ensure that current frame can get context window - for cur in range(0, num_frames - context + 1, stride): - end = min(cur + decoding_window, num_frames) - chunk_xs = xs[:, cur:end, :] - (y, self.subsampling_cache, self.elayers_output_cache, - self.conformer_cnn_cache) = self.model.encoder.forward_chunk( - chunk_xs, self.offset, required_cache_size, - self.subsampling_cache, self.elayers_output_cache, - self.conformer_cnn_cache) - outputs.append(y) - self.offset += y.shape[1] - - ys = paddle.cat(outputs, 1) - masks = paddle.ones([1, ys.shape[1]], dtype=paddle.bool) - masks = masks.unsqueeze(1) - return ys, masks - - def rescoring(self, encoder_out, device): - logger.info("start to rescoring the hyps") - beam_size = self.config.decode.beam_size - hyps = self.searcher.get_hyps() - assert len(hyps) == beam_size - - hyp_list = [] - for hyp in hyps: - hyp_content = hyp[0] - # Prevent the hyp is empty - if len(hyp_content) == 0: - hyp_content = (self.model.ctc.blank_id, ) - hyp_content = paddle.to_tensor( - hyp_content, place=device, dtype=paddle.long) - hyp_list.append(hyp_content) - hyps_pad = pad_sequence(hyp_list, True, self.model.ignore_id) - hyps_lens = paddle.to_tensor( - [len(hyp[0]) for hyp in hyps], place=device, - dtype=paddle.long) # (beam_size,) - hyps_pad, _ = add_sos_eos(hyps_pad, self.model.sos, self.model.eos, - self.model.ignore_id) - hyps_lens = hyps_lens + 1 # Add at begining - - encoder_out = encoder_out.repeat(beam_size, 1, 1) - encoder_mask = paddle.ones( - (beam_size, 1, encoder_out.shape[1]), dtype=paddle.bool) - decoder_out, _ = self.model.decoder( - encoder_out, encoder_mask, hyps_pad, - hyps_lens) # (beam_size, max_hyps_len, vocab_size) - # ctc score in ln domain - decoder_out = paddle.nn.functional.log_softmax(decoder_out, axis=-1) - decoder_out = decoder_out.numpy() - - # Only use decoder score for rescoring - best_score = -float('inf') - best_index = 0 - # hyps is List[(Text=List[int], Score=float)], len(hyps)=beam_size - for i, hyp in enumerate(hyps): - score = 0.0 - for j, w in enumerate(hyp[0]): - score += decoder_out[i][j][w] - # last decoder output token is `eos`, for laste decoder input token. 
- score += decoder_out[i][len(hyp[0])][self.model.eos] - # add ctc score (which in ln domain) - score += hyp[1] * self.config.decode.ctc_weight - if score > best_score: - best_score = score - best_index = i - - # update the one best result - self.hyps = [hyps[best_index][0]] - return hyps[best_index][0] - - def transformer_decode_reset(self): - self.subsampling_cache = None - self.elayers_output_cache = None - self.conformer_cnn_cache = None - self.offset = 0 - # decoding reset - self.searcher.reset() - - def update_result(self): - logger.info("update the final result") - hyps = self.hyps - self.result_transcripts = [ - self.text_feature.defeaturize(hyp) for hyp in hyps - ] - self.result_tokenids = [hyp for hyp in hyps] - - def extract_feat(self, samples, sample_rate): - """extract feat - - Args: - samples (numpy.array): numpy.float32 - sample_rate (int): sample rate - - Returns: - x_chunk (numpy.array): shape[B, T, D] - x_chunk_lens (numpy.array): shape[B] - """ - - if "deepspeech2online" in self.model_type: - # pcm16 -> pcm 32 - samples = pcm2float(samples) - # read audio - speech_segment = SpeechSegment.from_pcm( - samples, sample_rate, transcript=" ") - # audio augment - self.collate_fn_test.augmentation.transform_audio(speech_segment) - - # extract speech feature - spectrum, transcript_part = self.collate_fn_test._speech_featurizer.featurize( - speech_segment, self.collate_fn_test.keep_transcription_text) - # CMVN spectrum - if self.collate_fn_test._normalizer: - spectrum = self.collate_fn_test._normalizer.apply(spectrum) - - # spectrum augment - audio = self.collate_fn_test.augmentation.transform_feature( - spectrum) - - audio_len = audio.shape[0] - audio = paddle.to_tensor(audio, dtype='float32') - # audio_len = paddle.to_tensor(audio_len) - audio = paddle.unsqueeze(audio, axis=0) - - x_chunk = audio.numpy() - x_chunk_lens = np.array([audio_len]) - - return x_chunk, x_chunk_lens - elif "conformer_online" in self.model_type: - - if sample_rate != self.sample_rate: - logger.info(f"audio sample rate {sample_rate} is not match," - "the model sample_rate is {self.sample_rate}") - logger.info(f"ASR Engine use the {self.model_type} to process") - logger.info("Create the preprocess instance") - preprocess_conf = self.config.preprocess_config - preprocess_args = {"train": False} - preprocessing = Transformation(preprocess_conf) - - logger.info("Read the audio file") - logger.info(f"audio shape: {samples.shape}") - # fbank - x_chunk = preprocessing(samples, **preprocess_args) - x_chunk_lens = paddle.to_tensor(x_chunk.shape[0]) - x_chunk = paddle.to_tensor( - x_chunk, dtype="float32").unsqueeze(axis=0) - logger.info( - f"process the audio feature success, feat shape: {x_chunk.shape}" - ) - return x_chunk, x_chunk_lens + raise ValueError(f"Not support: {model_type}") + + return True class ASREngine(BaseEngine): - """ASR server engine + """ASR server resource Args: metaclass: Defaults to Singleton. 
@@ -1015,7 +868,7 @@ class ASREngine(BaseEngine): def __init__(self): super(ASREngine, self).__init__() - logger.info("create the online asr engine instance") + logger.info("create the online asr engine resource instance") def init(self, config: dict) -> bool: """init engine resource @@ -1026,17 +879,12 @@ Args: config (dict): config file Returns: bool: init failed or success """ - self.input = None - self.output = "" - self.executor = ASRServerExecutor() self.config = config + self.executor = ASRServerExecutor() + try: - if self.config.get("device", None): - self.device = self.config.device - else: - self.device = paddle.get_device() - logger.info(f"paddlespeech_server set the device: {self.device}") - paddle.set_device(self.device) + self.device = self.config.get("device", paddle.get_device()) + paddle.set_device(self.device) except BaseException as e: logger.error( f"Set device failed, please check if device '{self.device}' is already used and the parameter 'device' in the yaml file" ) logger.error( "If all GPU or XPU is used, you can set the server to 'cpu'") sys.exit(-1) + logger.info(f"paddlespeech_server set the device: {self.device}") + if not self.executor._init_from_path( model_type=self.config.model_type, am_model=self.config.am_model, am_params=self.config.am_params, lang=self.config.lang, sample_rate=self.config.sample_rate, cfg_path=self.config.cfg_path, decode_method=self.config.decode_method, am_predictor_conf=self.config.am_predictor_conf): logger.error( "Init the ASR server occurs error, please check the server configuration yaml" ) return False logger.info("Initialize ASR server engine successfully.") return True - def preprocess(self, - samples, - sample_rate, - model_type="deepspeech2online_aishell-zh-16k"): - """preprocess - - Args: - samples (numpy.array): numpy.float32 - sample_rate (int): sample rate - - Returns: - x_chunk (numpy.array): shape[B, T, D] - x_chunk_lens (numpy.array): shape[B] - """ - # if "deepspeech" in model_type: - x_chunk, x_chunk_lens = self.executor.extract_feat(samples, sample_rate) - return x_chunk, x_chunk_lens + def preprocess(self, *args, **kwargs): + raise NotImplementedError("The online engine does not use this method.") - def run(self, x_chunk, x_chunk_lens, decoder_chunk_size=1): - """run online engine - - Args: - x_chunk (numpy.array): shape[B, T, D] - x_chunk_lens (numpy.array): shape[B] - decoder_chunk_size(int) - """ - self.output = self.executor.decode_one_chunk(x_chunk, x_chunk_lens, - self.config.model_type) + def run(self, *args, **kwargs): + raise NotImplementedError("The online engine does not use this method.") def postprocess(self): - """postprocess - """ - return self.output - - def reset(self): - """reset engine decoder and inference state - """ - self.executor.reset_decoder_and_chunk() - self.output = "" + raise NotImplementedError("The online engine does not use this method.") -- GitLab
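
Reviewer note (not part of the diff): a minimal sketch of how the refactored per-connection handler is intended to be driven. The driver function, the constructor argument, and the `pcm_chunks` iterator are assumptions for illustration only; the handler methods themselves (`reset`, `extract_feat`, `decode`, `rescoring`, `get_result`, `get_word_time_stamp`) are the ones introduced or reorganized by this patch.

    # Hypothetical driver for the refactored connection handler (illustration only).
    # Assumes `asr_engine` is an ASREngine that has already been init()-ed with a
    # server config, and `pcm_chunks` yields int16 PCM byte buffers from a websocket.
    from paddlespeech.server.engine.asr.online.asr_engine import PaddleASRConnectionHanddler


    def transcribe_stream(asr_engine, pcm_chunks):
        connection = PaddleASRConnectionHanddler(asr_engine)
        connection.reset()  # fresh caches, decoder state and timestamps for a new utterance

        for chunk in pcm_chunks:
            connection.extract_feat(chunk)        # buffer samples and grow the feature cache
            connection.decode(is_finished=False)  # advance streaming decoding over cached frames
            print("partial:", connection.get_result())

        # flush the tail frames; for conformer/transformer models run the
        # second-pass attention rescoring to pick the final one-best hypothesis
        connection.decode(is_finished=True)
        connection.rescoring()
        return connection.get_result(), connection.get_word_time_stamp()

The call order mirrors the method flow in the patch: extract_feat then decode and get_result for partial results, then decode(is_finished=True) followed by rescoring for the final result and token timestamps.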