From d2640c14064058c5283830fd2046d1788e800046 Mon Sep 17 00:00:00 2001
From: xiongxinlei
Date: Mon, 18 Apr 2022 12:58:40 +0800
Subject: [PATCH] add multi-session process, test=doc

---
 .../server/engine/asr/online/asr_engine.py | 190 +++++++++++++++++-
 1 file changed, 189 insertions(+), 1 deletion(-)

diff --git a/paddlespeech/server/engine/asr/online/asr_engine.py b/paddlespeech/server/engine/asr/online/asr_engine.py
index e292f9cf..3546e598 100644
--- a/paddlespeech/server/engine/asr/online/asr_engine.py
+++ b/paddlespeech/server/engine/asr/online/asr_engine.py
@@ -78,6 +78,194 @@ pretrained_models = {
     },
 }
 
+# ASR server connection handler: one instance per client connection
+
+class PaddleASRConnectionHanddler:
+    def __init__(self, asr_engine):
+        super().__init__()
+        self.config = asr_engine.config
+        self.model_config = asr_engine.executor.config
+        self.asr_engine = asr_engine
+
+        self.init()
+        self.reset()
+
+    def init(self):
+        self.model_type = self.asr_engine.executor.model_type
+        if "deepspeech2online" in self.model_type or "deepspeech2offline" in self.model_type:
+            pass
+        elif "conformer" in self.model_type or "transformer" in self.model_type or "wenetspeech" in self.model_type:
+            self.sample_rate = self.asr_engine.executor.sample_rate
+
+            # acoustic model
+            self.model = self.asr_engine.executor.model
+
+            # tokens to text
+            self.text_feature = self.asr_engine.executor.text_feature
+
+            # ctc decoding
+            self.ctc_decode_config = self.asr_engine.executor.config.decode
+            self.searcher = CTCPrefixBeamSearch(self.ctc_decode_config)
+
+            # extract fbank
+            self.preprocess_conf = self.model_config.preprocess_config
+            self.preprocess_args = {"train": False}
+            self.preprocessing = Transformation(self.preprocess_conf)
+            self.win_length = self.preprocess_conf.process[0]['win_length']
+            self.n_shift = self.preprocess_conf.process[0]['n_shift']
+
+    def extract_feat(self, samples):
+        if "deepspeech2online" in self.model_type:
+            pass
+        elif "conformer2online" in self.model_type:
+            logger.info("Online ASR is extracting the feature")
+            samples = np.frombuffer(samples, dtype=np.int16)
+            assert samples.ndim == 1
+
+            logger.info(f"This package contains {samples.shape[0]} pcm samples")
+            self.num_samples += samples.shape[0]
+
+            # self.remained_wav stores all the unprocessed samples:
+            # the previously remained samples plus this package's samples
+            if self.remained_wav is None:
+                self.remained_wav = samples
+            else:
+                assert self.remained_wav.ndim == 1
+                self.remained_wav = np.concatenate([self.remained_wav, samples])
+            logger.info(
+                f"The connection retains audio samples: {self.remained_wav.shape}")
+
+            # not enough samples for even one analysis window yet
+            if len(self.remained_wav) < self.win_length:
+                return 0
+
+            # fbank
+            x_chunk = self.preprocessing(self.remained_wav,
+                                         **self.preprocess_args)
+            x_chunk = paddle.to_tensor(
+                x_chunk, dtype="float32").unsqueeze(axis=0)
+            if self.cached_feat is None:
+                self.cached_feat = x_chunk
+            else:
+                self.cached_feat = paddle.concat(
+                    [self.cached_feat, x_chunk], axis=1)
+
+            # drop the samples that have already been consumed by whole frames
+            num_frames = x_chunk.shape[1]
+            self.num_frames += num_frames
+            self.remained_wav = self.remained_wav[self.n_shift * num_frames:]
+
+            logger.info(
+                f"Feature extraction succeeded; the connection feat shape: {self.cached_feat.shape}")
+            logger.info(
+                f"After feature extraction, the connection retains audio samples: {self.remained_wav.shape}")
+
+    def reset(self):
+        self.subsampling_cache = None
+        self.elayers_output_cache = None
+        self.conformer_cnn_cache = None
+        self.encoder_outs_ = None
+        self.cached_feat = None
+        self.remained_wav = None
+        self.offset = 0
+        self.num_samples = 0
+
+        self.num_frames = 0
+        self.global_frame_offset = 0
+        self.result = []
+        # hypotheses produced by the ctc prefix beam search
+        self.hyps = []
+
+    def decode(self, is_finished=False):
+        if "deepspeech2online" in self.model_type:
+            pass
+        elif "conformer" in self.model_type or "transformer" in self.model_type:
+            try:
+                logger.info(
+                    f"Decoding with the transformer-like model: {self.model_type}")
+                self.advance_decoding(is_finished)
+                self.update_result()
+            except Exception as e:
+                logger.exception(e)
+        else:
+            raise ValueError(f"Invalid model type: {self.model_type}")
+
+    def advance_decoding(self, is_finished=False):
+        logger.info("start to decode with the advance_decoding method")
+        cfg = self.ctc_decode_config
+        decoding_chunk_size = cfg.decoding_chunk_size
+        num_decoding_left_chunks = cfg.num_decoding_left_chunks
+
+        assert decoding_chunk_size > 0
+        subsampling = self.model.encoder.embed.subsampling_rate
+        context = self.model.encoder.embed.right_context + 1
+        stride = subsampling * decoding_chunk_size
+
+        # decoding window for the model
+        decoding_window = (decoding_chunk_size - 1) * subsampling + context
+        num_frames = self.cached_feat.shape[1]
+        logger.info(
+            f"Required decoding window: {decoding_window} frames; the connection has {num_frames} frames")
+
+        # the cached feat must cover at least one decoding window
+        if num_frames < decoding_window and not is_finished:
+            return None, None
+
+        required_cache_size = decoding_chunk_size * num_decoding_left_chunks
+        outputs = []
+
+        # num_frames - left_frames + 1 ensures the current frame can get a
+        # complete context window
+        if is_finished:
+            # for the last package, only the right context must remain
+            left_frames = context
+        else:
+            # otherwise, only decode whole decoding windows
+            left_frames = decoding_window
+
+        end = None
+        for cur in range(0, num_frames - left_frames + 1, stride):
+            end = min(cur + decoding_window, num_frames)
+            chunk_xs = self.cached_feat[:, cur:end, :]
+            (y, self.subsampling_cache, self.elayers_output_cache,
+             self.conformer_cnn_cache) = self.model.encoder.forward_chunk(
+                 chunk_xs, self.offset, required_cache_size,
+                 self.subsampling_cache, self.elayers_output_cache,
+                 self.conformer_cnn_cache)
+            outputs.append(y)
+            # advance the offset by the subsampled frames just produced
+            self.offset += y.shape[1]
+
+        # nothing was decoded in this call
+        if not outputs:
+            return None, None
+
+        # drop the feature frames that have been fully consumed
+        self.cached_feat = self.cached_feat[:, end:, :]
+        ys = paddle.concat(outputs, 1)
+        masks = paddle.ones([1, ys.shape[1]], dtype=paddle.bool)
+        masks = masks.unsqueeze(1)
+
+        # get the ctc probs
+        ctc_probs = self.model.ctc.log_softmax(ys)  # (1, maxlen, vocab_size)
+        ctc_probs = ctc_probs.squeeze(0)
+        self.searcher.search(None, ctc_probs, self.cached_feat.place)
+
+        self.hyps = self.searcher.get_one_best_hyps()
+
+        # ys and masks are returned for attention rescoring
+        return ys, masks
+
+    def update_result(self):
+        logger.info("update the final result")
+        hyps = self.hyps
+        self.result_transcripts = [
+            self.text_feature.defeaturize(hyp) for hyp in hyps
+        ]
+        self.result_tokenids = [hyp for hyp in hyps]
+
+    def rescoring(self):
+        pass
+
+
 
 class ASRServerExecutor(ASRExecutor):
     def __init__(self):
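The sample bookkeeping in `extract_feat` above decides how much audio each connection keeps between packages: once enough samples arrive for a full analysis window, whole frames are emitted and the buffer is trimmed by `n_shift` samples per frame. A minimal, self-contained sketch of that arithmetic follows; the window constants, the helper name `buffer_and_count_frames`, and the closed-form frame count are illustrative assumptions (the patch takes `num_frames` from the featurizer output instead of computing it).

    import numpy as np

    # Editor's sketch, not part of the patch: hypothetical 16 kHz framing values.
    WIN_LENGTH = 400  # 25 ms analysis window, like the preprocess 'win_length'
    N_SHIFT = 160     # 10 ms hop, like the preprocess 'n_shift'

    def buffer_and_count_frames(remained_wav, new_samples):
        """Append one pcm package; return (kept_samples, num_complete_frames)."""
        buf = np.concatenate([remained_wav, new_samples])
        if len(buf) < WIN_LENGTH:
            return buf, 0  # not enough data for a single window yet
        # each extra n_shift samples beyond the first window yields one more frame
        num_frames = 1 + (len(buf) - WIN_LENGTH) // N_SHIFT
        # the frames consumed n_shift * num_frames samples; keep the remainder
        return buf[N_SHIFT * num_frames:], num_frames

    buf = np.zeros(0, dtype=np.int16)
    for package in (np.zeros(300, np.int16), np.zeros(500, np.int16)):
        buf, frames = buffer_and_count_frames(buf, package)
        print(f"emitted {frames} frames, kept {len(buf)} samples")

With these numbers the first 300-sample package emits nothing, and the combined 800 samples then yield 3 frames while 320 samples stay buffered, matching the `remained_wav[self.n_shift * num_frames:]` trim in the patch.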
@@ -492,7 +680,7 @@ class ASRServerExecutor(ASRExecutor):
         if sample_rate != self.sample_rate:
             logger.info(f"audio sample rate {sample_rate} is not match,"
                         "the model sample_rate is {self.sample_rate}")
-        logger.info("ASR Engine use the {self.model_type} to process")
+        logger.info(f"ASR Engine use the {self.model_type} to process")
         logger.info("Create the preprocess instance")
         preprocess_conf = self.config.preprocess_config
         preprocess_args = {"train": False}
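The window arithmetic in `advance_decoding` drives the streaming behaviour: the encoder reads `decoding_window` feature frames per chunk and advances by `stride` frames, and a final flush (`is_finished=True`) only needs `context` frames left over. The sketch below replays just that arithmetic; `SUBSAMPLING`, `RIGHT_CONTEXT`, and `DECODING_CHUNK_SIZE` are typical conformer settings assumed for illustration, not values read from this config.

    # Editor's sketch, not part of the patch: assumed encoder/decoding settings.
    SUBSAMPLING = 4           # encoder.embed.subsampling_rate
    RIGHT_CONTEXT = 6         # encoder.embed.right_context
    DECODING_CHUNK_SIZE = 16  # cfg.decoding_chunk_size

    context = RIGHT_CONTEXT + 1
    stride = SUBSAMPLING * DECODING_CHUNK_SIZE
    decoding_window = (DECODING_CHUNK_SIZE - 1) * SUBSAMPLING + context

    def chunk_ranges(num_frames, is_finished):
        # mirror advance_decoding: a finished stream only needs `context`
        # frames of lookahead, so the tail of the utterance is not dropped
        left_frames = context if is_finished else decoding_window
        return [(cur, min(cur + decoding_window, num_frames))
                for cur in range(0, num_frames - left_frames + 1, stride)]

    print(f"decoding_window={decoding_window}, stride={stride}")
    print(chunk_ranges(150, is_finished=False))  # [(0, 67), (64, 131)]
    print(chunk_ranges(150, is_finished=True))   # [(0, 67), (64, 131), (128, 150)]

Each full chunk produces `DECODING_CHUNK_SIZE` subsampled encoder frames, which is why the patch advances `self.offset` by `y.shape[1]` per iteration.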
-- 
GitLab
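For context, this is roughly how a server would drive the new handler, one instance per websocket connection, which is what makes multiple concurrent sessions possible: each connection keeps its own feature buffer and encoder caches instead of sharing the executor's state. The driver function and the `pcm_packages` iterable are hypothetical; only the handler methods come from the patch.

    # Editor's sketch, not part of the patch: a hypothetical per-connection driver.
    def handle_connection(asr_engine, pcm_packages):
        handler = PaddleASRConnectionHanddler(asr_engine)
        for package in pcm_packages:           # int16 pcm bytes from one client
            handler.extract_feat(package)      # buffer samples, grow cached_feat
            handler.decode(is_finished=False)  # decode only whole windows
        handler.decode(is_finished=True)       # flush the final partial chunk
        result = handler.result_transcripts[0] if handler.result_transcripts else ""
        handler.reset()                        # ready for the next utterance
        return result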