diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index e3cc36e0094aa4f1c9c6bba529b0dcdfa9cbb2b8..6e7ae1fbf99df3a42c6691d9ca8499e040657ab3 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -51,12 +51,12 @@ repos: language: system files: \.(c|cc|cxx|cpp|cu|h|hpp|hxx|cuh|proto)$ exclude: (?=speechx/speechx/kaldi|speechx/patch|speechx/tools/fstbin|speechx/tools/lmbin).*(\.cpp|\.cc|\.h|\.py)$ - - id: copyright_checker - name: copyright_checker - entry: python .pre-commit-hooks/copyright-check.hook - language: system - files: \.(c|cc|cxx|cpp|cu|h|hpp|hxx|proto|py)$ - exclude: (?=third_party|pypinyin|speechx/speechx/kaldi|speechx/patch|speechx/tools/fstbin|speechx/tools/lmbin).*(\.cpp|\.cc|\.h|\.py)$ + #- id: copyright_checker + # name: copyright_checker + # entry: python .pre-commit-hooks/copyright-check.hook + # language: system + # files: \.(c|cc|cxx|cpp|cu|h|hpp|hxx|proto|py)$ + # exclude: (?=third_party|pypinyin|speechx/speechx/kaldi|speechx/patch|speechx/tools/fstbin|speechx/tools/lmbin).*(\.cpp|\.cc|\.h|\.py)$ - repo: https://github.com/asottile/reorder_python_imports rev: v2.4.0 hooks: diff --git a/demos/streaming_asr_server/server.sh b/demos/streaming_asr_server/server.sh index 4266f8c642c83ece8dc4a2dd29812acfad4d6f8a..0d255807c8841396528696198ce1d136b93c69ce 100755 --- a/demos/streaming_asr_server/server.sh +++ b/demos/streaming_asr_server/server.sh @@ -5,4 +5,5 @@ export CUDA_VISIBLE_DEVICE=0,1,2,3 paddlespeech_server start --config_file conf/punc_application.yaml &> punc.log & # nohup python3 streaming_asr_server.py --config_file conf/ws_conformer_application.yaml > streaming_asr.log 2>&1 & -paddlespeech_server start --config_file conf/ws_conformer_application.yaml &> streaming_asr.log & \ No newline at end of file +paddlespeech_server start --config_file conf/ws_conformer_application.yaml &> streaming_asr.log & + diff --git a/demos/streaming_asr_server/test.sh b/demos/streaming_asr_server/test.sh index 4f43c6534f078683329a287bb87a1c79cff15b8f..f3075454d6ccda411bf024d354b693b3625aa1fe 100755 --- a/demos/streaming_asr_server/test.sh +++ b/demos/streaming_asr_server/test.sh @@ -9,4 +9,5 @@ paddlespeech_client asr_online --server_ip 127.0.0.1 --port 8290 --input ./zh.wa # read the wav and call streaming and punc service # If `127.0.0.1` is not accessible, you need to use the actual service IP address. # python3 websocket_client.py --server_ip 127.0.0.1 --port 8290 --punc.server_ip 127.0.0.1 --punc.port 8190 --wavfile ./zh.wav -paddlespeech_client asr_online --server_ip 127.0.0.1 --port 8290 --punc.server_ip 127.0.0.1 --punc.port 8190 --input ./zh.wav \ No newline at end of file +paddlespeech_client asr_online --server_ip 127.0.0.1 --port 8290 --punc.server_ip 127.0.0.1 --punc.port 8190 --input ./zh.wav + diff --git a/paddlespeech/__init__.py b/paddlespeech/__init__.py index b781c4a8e5cc99590e179faf1c4c3989349d4216..29db45cc5dcc1dbb0354f1ccd96d815a0cf184ad 100644 --- a/paddlespeech/__init__.py +++ b/paddlespeech/__init__.py @@ -14,3 +14,7 @@ import _locale _locale._getdefaultlocale = (lambda *args: ['en_US', 'utf8']) + + + + diff --git a/paddlespeech/server/engine/asr/online/asr_engine.py b/paddlespeech/server/engine/asr/online/asr_engine.py index a52219730b441e8c40c7ab481976402224d4dfb5..2dce35cb59f3e3905aff24438423a0f7ea9a8365 100644 --- a/paddlespeech/server/engine/asr/online/asr_engine.py +++ b/paddlespeech/server/engine/asr/online/asr_engine.py @@ -13,6 +13,7 @@ # limitations under the License. 
import os import sys +from typing import ByteString from typing import Optional import numpy as np @@ -30,9 +31,10 @@ from paddlespeech.s2t.transform.transformation import Transformation from paddlespeech.s2t.utils.tensor_utils import add_sos_eos from paddlespeech.s2t.utils.tensor_utils import pad_sequence from paddlespeech.s2t.utils.utility import UpdateConfig +from paddlespeech.server.engine.asr.online.ctc_endpoint import OnlineCTCEndpoingOpt +from paddlespeech.server.engine.asr.online.ctc_endpoint import OnlineCTCEndpoint from paddlespeech.server.engine.asr.online.ctc_search import CTCPrefixBeamSearch from paddlespeech.server.engine.base_engine import BaseEngine -from paddlespeech.server.utils.audio_process import pcm2float from paddlespeech.server.utils.paddle_predictor import init_predictor __all__ = ['PaddleASRConnectionHanddler', 'ASRServerExecutor', 'ASREngine'] @@ -54,24 +56,33 @@ class PaddleASRConnectionHanddler: self.model_config = asr_engine.executor.config self.asr_engine = asr_engine - self.init() - self.reset() - - def init(self): # model_type, sample_rate and text_feature is shared for deepspeech2 and conformer self.model_type = self.asr_engine.executor.model_type self.sample_rate = self.asr_engine.executor.sample_rate # tokens to text self.text_feature = self.asr_engine.executor.text_feature + # extract feat, new only fbank in conformer model + self.preprocess_conf = self.model_config.preprocess_config + self.preprocess_args = {"train": False} + self.preprocessing = Transformation(self.preprocess_conf) + + # frame window and frame shift, in samples unit + self.win_length = self.preprocess_conf.process[0]['win_length'] + self.n_shift = self.preprocess_conf.process[0]['n_shift'] + + assert self.preprocess_conf.process[0]['fs'] == self.sample_rate, ( + self.sample_rate, self.preprocess_conf.process[0]['fs']) + self.frame_shift_in_ms = int( + self.n_shift / self.preprocess_conf.process[0]['fs'] * 1000) + + self.init_decoder() + self.reset() + + def init_decoder(self): if "deepspeech2" in self.model_type: self.am_predictor = self.asr_engine.executor.am_predictor - # extract feat, new only fbank in conformer model - self.preprocess_conf = self.model_config.preprocess_config - self.preprocess_args = {"train": False} - self.preprocessing = Transformation(self.preprocess_conf) - self.decoder = CTCDecoder( odim=self.model_config.output_dim, # is in vocab enc_n_units=self.model_config.rnn_layer_size * 2, @@ -90,10 +101,6 @@ class PaddleASRConnectionHanddler: cfg.beam_size, cfg.cutoff_prob, cfg.cutoff_top_n, cfg.num_proc_bsearch) - # frame window and frame shift, in samples unit - self.win_length = self.preprocess_conf.process[0]['win_length'] - self.n_shift = self.preprocess_conf.process[0]['n_shift'] - elif "conformer" in self.model_type or "transformer" in self.model_type: # acoustic model self.model = self.asr_engine.executor.model @@ -102,130 +109,40 @@ class PaddleASRConnectionHanddler: self.ctc_decode_config = self.asr_engine.executor.config.decode self.searcher = CTCPrefixBeamSearch(self.ctc_decode_config) - # extract feat, new only fbank in conformer model - self.preprocess_conf = self.model_config.preprocess_config - self.preprocess_args = {"train": False} - self.preprocessing = Transformation(self.preprocess_conf) - - # frame window and frame shift, in samples unit - self.win_length = self.preprocess_conf.process[0]['win_length'] - self.n_shift = self.preprocess_conf.process[0]['n_shift'] + # ctc endpoint + self.endpoint_opt = OnlineCTCEndpoingOpt( + 
frame_shift_in_ms=self.frame_shift_in_ms, blank=0) + self.endpointer = OnlineCTCEndpoint(self.endpoint_opt) else: raise ValueError(f"Not supported: {self.model_type}") - def extract_feat(self, samples): - # we compute the elapsed time of first char occuring - # and we record the start time at the first pcm sample arraving - - if "deepspeech2online" in self.model_type: - # self.reamined_wav stores all the samples, - # include the original remained_wav and this package samples - samples = np.frombuffer(samples, dtype=np.int16) - assert samples.ndim == 1 - - if self.remained_wav is None: - self.remained_wav = samples - else: - assert self.remained_wav.ndim == 1 - self.remained_wav = np.concatenate([self.remained_wav, samples]) - logger.info( - f"The connection remain the audio samples: {self.remained_wav.shape}" - ) - - # fbank - feat = self.preprocessing(self.remained_wav, - **self.preprocess_args) - feat = paddle.to_tensor( - feat, dtype="float32").unsqueeze(axis=0) - - if self.cached_feat is None: - self.cached_feat = feat - else: - assert (len(feat.shape) == 3) - assert (len(self.cached_feat.shape) == 3) - self.cached_feat = paddle.concat( - [self.cached_feat, feat], axis=1) - - # set the feat device - if self.device is None: - self.device = self.cached_feat.place - - # cur frame step - num_frames = feat.shape[1] - - self.num_frames += num_frames - self.remained_wav = self.remained_wav[self.n_shift * num_frames:] - - logger.info( - f"process the audio feature success, the connection feat shape: {self.cached_feat.shape}" - ) - logger.info( - f"After extract feat, the connection remain the audio samples: {self.remained_wav.shape}" - ) - - elif "conformer_online" in self.model_type: - logger.info("Online ASR extract the feat") - samples = np.frombuffer(samples, dtype=np.int16) - assert samples.ndim == 1 - - self.num_samples += samples.shape[0] - logger.info( - f"This package receive {samples.shape[0]} pcm data. 
Global samples:{self.num_samples}" - ) - - # self.reamined_wav stores all the samples, - # include the original remained_wav and this package samples - if self.remained_wav is None: - self.remained_wav = samples - else: - assert self.remained_wav.ndim == 1 # (T,) - self.remained_wav = np.concatenate([self.remained_wav, samples]) - logger.info( - f"The concatenation of remain and now audio samples length is: {self.remained_wav.shape}" - ) - - if len(self.remained_wav) < self.win_length: - # samples not enough for feature window - return 0 - - # fbank - x_chunk = self.preprocessing(self.remained_wav, - **self.preprocess_args) - x_chunk = paddle.to_tensor( - x_chunk, dtype="float32").unsqueeze(axis=0) - - # feature cache - if self.cached_feat is None: - self.cached_feat = x_chunk - else: - assert (len(x_chunk.shape) == 3) # (B,T,D) - assert (len(self.cached_feat.shape) == 3) # (B,T,D) - self.cached_feat = paddle.concat( - [self.cached_feat, x_chunk], axis=1) - - # set the feat device - if self.device is None: - self.device = self.cached_feat.place + def model_reset(self): + if "deepspeech2" in self.model_type: + return - # cur frame step - num_frames = x_chunk.shape[1] + # feature cache + self.cached_feat = None - # global frame step - self.num_frames += num_frames + ## conformer + # cache for conformer online + self.subsampling_cache = None + self.elayers_output_cache = None + self.conformer_cnn_cache = None + self.encoder_out = None + # conformer decoding state + self.offset = 0 # global offset in decoding frame unit - # update remained wav - self.remained_wav = self.remained_wav[self.n_shift * num_frames:] + ## just for record info + self.chunk_num = 0 # global decoding chunk num, not used - logger.info( - f"process the audio feature success, the cached feat shape: {self.cached_feat.shape}" - ) - logger.info( - f"After extract feat, the cached remain the audio samples: {self.remained_wav.shape}" - ) - logger.info(f"global samples: {self.num_samples}") - logger.info(f"global frames: {self.num_frames}") - else: - raise ValueError(f"not supported: {self.model_type}") + def reset_continuous_decoding(self): + """ + when in continous decoding, reset for next utterance. + """ + self.global_frame_offset = self.num_frames + self.model_reset() + self.searcher.reset() + self.endpointer.reset() def reset(self): if "deepspeech2" in self.model_type: @@ -241,53 +158,110 @@ class PaddleASRConnectionHanddler: dtype=float32) self.decoder.reset_decoder(batch_size=1) + if "conformer" in self.model_type or "transformer" in self.model_type: + self.searcher.reset() + self.endpointer.reset() + self.device = None ## common - # global sample and frame step self.num_samples = 0 + self.global_frame_offset = 0 + # frame step of cur utterance self.num_frames = 0 # cache for audio and feat self.remained_wav = None self.cached_feat = None - # partial/ending decoding results - self.result_transcripts = [''] - ## conformer + self.model_reset() - # cache for conformer online - self.subsampling_cache = None - self.elayers_output_cache = None - self.conformer_cnn_cache = None - self.encoder_out = None - # conformer decoding state - self.chunk_num = 0 # globa decoding chunk num - self.offset = 0 # global offset in decoding frame unit - self.hyps = [] - + ## outputs + # partial/ending decoding results + self.result_transcripts = [''] # token timestamp result self.word_time_stamp = [] + ## just for record + self.hyps = [] + # one best timestamp viterbi prob is large. 
self.time_stamp = [] + def extract_feat(self, samples: ByteString): + logger.info("Online ASR extract the feat") + samples = np.frombuffer(samples, dtype=np.int16) + assert samples.ndim == 1 + + self.num_samples += samples.shape[0] + logger.info( + f"This package receive {samples.shape[0]} pcm data. Global samples:{self.num_samples}" + ) + + # self.reamined_wav stores all the samples, + # include the original remained_wav and this package samples + if self.remained_wav is None: + self.remained_wav = samples + else: + assert self.remained_wav.ndim == 1 # (T,) + self.remained_wav = np.concatenate([self.remained_wav, samples]) + logger.info( + f"The concatenation of remain and now audio samples length is: {self.remained_wav.shape}" + ) + + if len(self.remained_wav) < self.win_length: + # samples not enough for feature window + return 0 + + # fbank + x_chunk = self.preprocessing(self.remained_wav, **self.preprocess_args) + x_chunk = paddle.to_tensor(x_chunk, dtype="float32").unsqueeze(axis=0) + + # feature cache + if self.cached_feat is None: + self.cached_feat = x_chunk + else: + assert (len(x_chunk.shape) == 3) # (B,T,D) + assert (len(self.cached_feat.shape) == 3) # (B,T,D) + self.cached_feat = paddle.concat( + [self.cached_feat, x_chunk], axis=1) + + # set the feat device + if self.device is None: + self.device = self.cached_feat.place + + # cur frame step + num_frames = x_chunk.shape[1] + + # global frame step + self.num_frames += num_frames + + # update remained wav + self.remained_wav = self.remained_wav[self.n_shift * num_frames:] + + logger.info( + f"process the audio feature success, the cached feat shape: {self.cached_feat.shape}" + ) + logger.info( + f"After extract feat, the cached remain the audio samples: {self.remained_wav.shape}" + ) + logger.info(f"global samples: {self.num_samples}") + logger.info(f"global frames: {self.num_frames}") + def decode(self, is_finished=False): """advance decoding Args: is_finished (bool, optional): Is last frame or not. Defaults to False. - Raises: - Exception: when not support model. - Returns: - None: nothing + None: """ - if "deepspeech2online" in self.model_type: + if "deepspeech2" in self.model_type: decoding_chunk_size = 1 # decoding chunk size = 1. int decoding frame unit + context = 7 # context=7, in audio frame unit subsampling = 4 # subsampling=4, in audio frame unit @@ -332,9 +306,11 @@ class PaddleASRConnectionHanddler: end = None for cur in range(0, num_frames - left_frames + 1, stride): end = min(cur + decoding_window, num_frames) + # extract the audio x_chunk = self.cached_feat[:, cur:end, :].numpy() x_chunk_lens = np.array([x_chunk.shape[1]]) + trans_best = self.decode_one_chunk(x_chunk, x_chunk_lens) self.result_transcripts = [trans_best] @@ -409,31 +385,38 @@ class PaddleASRConnectionHanddler: @paddle.no_grad() def advance_decoding(self, is_finished=False): + if "deepspeech" in self.model_type: + return + logger.info( "Conformer/Transformer: start to decode with advanced_decoding method" ) cfg = self.ctc_decode_config - # cur chunk size, in decoding frame unit + # cur chunk size, in decoding frame unit, e.g. 16 decoding_chunk_size = cfg.decoding_chunk_size - # using num of history chunks + # using num of history chunks, e.g -1 num_decoding_left_chunks = cfg.num_decoding_left_chunks assert decoding_chunk_size > 0 + # e.g. 4 subsampling = self.model.encoder.embed.subsampling_rate + # e.g. 
7 context = self.model.encoder.embed.right_context + 1 - # processed chunk feature cached for next chunk + # processed chunk feature cached for next chunk, e.g. 3 cached_feature_num = context - subsampling - # decoding stride, in audio frame unit - stride = subsampling * decoding_chunk_size + # decoding window, in audio frame unit decoding_window = (decoding_chunk_size - 1) * subsampling + context + # decoding stride, in audio frame unit + stride = subsampling * decoding_chunk_size if self.cached_feat is None: logger.info("no audio feat, please input more pcm data") return + # (B=1,T,D) num_frames = self.cached_feat.shape[1] logger.info( f"Required decoding window {decoding_window} frames, and the connection has {num_frames} frames" @@ -454,9 +437,6 @@ class PaddleASRConnectionHanddler: return None, None logger.info("start to do model forward") - # hist of chunks, in deocding frame unit - required_cache_size = decoding_chunk_size * num_decoding_left_chunks - outputs = [] # num_frames - context + 1 ensure that current frame can get context window if is_finished: @@ -466,7 +446,11 @@ class PaddleASRConnectionHanddler: # we only process decoding_window frames for one chunk left_frames = decoding_window + # hist of chunks, in deocding frame unit + required_cache_size = decoding_chunk_size * num_decoding_left_chunks + # record the end for removing the processed feat + outputs = [] end = None for cur in range(0, num_frames - left_frames + 1, stride): end = min(cur + decoding_window, num_frames) @@ -491,30 +475,28 @@ class PaddleASRConnectionHanddler: self.encoder_out = ys else: self.encoder_out = paddle.concat([self.encoder_out, ys], axis=1) + logger.info( + f"This connection handler encoder out shape: {self.encoder_out.shape}" + ) # get the ctc probs ctc_probs = self.model.ctc.log_softmax(ys) # (1, maxlen, vocab_size) ctc_probs = ctc_probs.squeeze(0) + ## decoding # advance decoding self.searcher.search(ctc_probs, self.cached_feat.place) # get one best hyps self.hyps = self.searcher.get_one_best_hyps() - assert self.cached_feat.shape[0] == 1 - assert end >= cached_feature_num - # advance cache of feat - self.cached_feat = self.cached_feat[0, end - - cached_feature_num:, :].unsqueeze(0) + assert self.cached_feat.shape[0] == 1 #(B=1,T,D) + assert end >= cached_feature_num + self.cached_feat = self.cached_feat[:, end - cached_feature_num:, :] assert len( self.cached_feat.shape ) == 3, f"current cache feat shape is: {self.cached_feat.shape}" - logger.info( - f"This connection handler encoder out shape: {self.encoder_out.shape}" - ) - def update_result(self): """Conformer/Transformer hyps to result. 
""" @@ -654,24 +636,28 @@ class PaddleASRConnectionHanddler: # update each word start and end time stamp # decoding frame to audio frame - frame_shift = self.model.encoder.embed.subsampling_rate - frame_shift_in_sec = frame_shift * (self.n_shift / self.sample_rate) - logger.info(f"frame shift sec: {frame_shift_in_sec}") + decode_frame_shift = self.model.encoder.embed.subsampling_rate + decode_frame_shift_in_sec = decode_frame_shift * (self.n_shift / + self.sample_rate) + logger.info(f"decode frame shift in sec: {decode_frame_shift_in_sec}") + + global_offset_in_sec = self.global_frame_offset * self.frame_shift_in_ms / 1000.0 + logger.info(f"global offset: {global_offset_in_sec} sec.") word_time_stamp = [] for idx, _ in enumerate(self.time_stamp): start = (self.time_stamp[idx - 1] + self.time_stamp[idx] ) / 2.0 if idx > 0 else 0 - start = start * frame_shift_in_sec + start = start * decode_frame_shift_in_sec end = (self.time_stamp[idx] + self.time_stamp[idx + 1] ) / 2.0 if idx < len(self.time_stamp) - 1 else self.offset - end = end * frame_shift_in_sec + end = end * decode_frame_shift_in_sec word_time_stamp.append({ "w": self.result_transcripts[0][idx], - "bg": start, - "ed": end + "bg": global_offset_in_sec + start, + "ed": global_offset_in_sec + end }) # logger.info(f"{word_time_stamp[-1]}") @@ -705,13 +691,14 @@ class ASRServerExecutor(ASRExecutor): self.model_type = model_type self.sample_rate = sample_rate + logger.info(f"model_type: {self.model_type}") + sample_rate_str = '16k' if sample_rate == 16000 else '8k' tag = model_type + '-' + lang + '-' + sample_rate_str self.task_resource.set_task_model(model_tag=tag) + if cfg_path is None or am_model is None or am_params is None: - logger.info(f"Load the pretrained model, tag = {tag}") self.res_path = self.task_resource.res_dir - self.cfg_path = os.path.join( self.res_path, self.task_resource.res_dict['cfg_path']) @@ -719,7 +706,6 @@ class ASRServerExecutor(ASRExecutor): self.task_resource.res_dict['model']) self.am_params = os.path.join(self.res_path, self.task_resource.res_dict['params']) - logger.info(self.res_path) else: self.cfg_path = os.path.abspath(cfg_path) self.am_model = os.path.abspath(am_model) @@ -727,9 +713,12 @@ class ASRServerExecutor(ASRExecutor): self.res_path = os.path.dirname( os.path.dirname(os.path.abspath(self.cfg_path))) - logger.info(self.cfg_path) - logger.info(self.am_model) - logger.info(self.am_params) + logger.info("Load the pretrained model:") + logger.info(f" tag = {tag}") + logger.info(f" res_path: {self.res_path}") + logger.info(f" cfg path: {self.cfg_path}") + logger.info(f" am_model path: {self.am_model}") + logger.info(f" am_params path: {self.am_params}") #Init body. 
self.config = CfgNode(new_allowed=True) @@ -738,25 +727,39 @@ class ASRServerExecutor(ASRExecutor): if self.config.spm_model_prefix: self.config.spm_model_prefix = os.path.join( self.res_path, self.config.spm_model_prefix) + logger.info(f"spm model path: {self.config.spm_model_prefix}") + + self.vocab = self.config.vocab_filepath + self.text_feature = TextFeaturizer( unit_type=self.config.unit_type, vocab=self.config.vocab_filepath, spm_model_prefix=self.config.spm_model_prefix) - self.vocab = self.config.vocab_filepath - with UpdateConfig(self.config): - if "deepspeech2" in model_type: + + if "deepspeech2" in model_type: + with UpdateConfig(self.config): + # download lm self.config.decode.lang_model_path = os.path.join( MODEL_HOME, 'language_model', self.config.decode.lang_model_path) - lm_url = self.task_resource.res_dict['lm_url'] - lm_md5 = self.task_resource.res_dict['lm_md5'] - logger.info(f"Start to load language model {lm_url}") - self.download_lm( - lm_url, - os.path.dirname(self.config.decode.lang_model_path), lm_md5) + lm_url = self.task_resource.res_dict['lm_url'] + lm_md5 = self.task_resource.res_dict['lm_md5'] + logger.info(f"Start to load language model {lm_url}") + self.download_lm( + lm_url, + os.path.dirname(self.config.decode.lang_model_path), lm_md5) + + # AM predictor + logger.info("ASR engine start to init the am predictor") + self.am_predictor_conf = am_predictor_conf + self.am_predictor = init_predictor( + model_file=self.am_model, + params_file=self.am_params, + predictor_conf=self.am_predictor_conf) - elif "conformer" in model_type or "transformer" in model_type: + elif "conformer" in model_type or "transformer" in model_type: + with UpdateConfig(self.config): logger.info("start to create the stream conformer asr engine") # update the decoding method if decode_method: @@ -770,37 +773,24 @@ class ASRServerExecutor(ASRExecutor): logger.info( "we set the decoding_method to attention_rescoring") self.config.decode.decoding_method = "attention_rescoring" + assert self.config.decode.decoding_method in [ "ctc_prefix_beam_search", "attention_rescoring" ], f"we only support ctc_prefix_beam_search and attention_rescoring dedoding method, current decoding method is {self.config.decode.decoding_method}" - else: - raise Exception("wrong type") - if "deepspeech2" in model_type: - # AM predictor - logger.info("ASR engine start to init the am predictor") - self.am_predictor_conf = am_predictor_conf - self.am_predictor = init_predictor( - model_file=self.am_model, - params_file=self.am_params, - predictor_conf=self.am_predictor_conf) - elif "conformer" in model_type or "transformer" in model_type: + # load model model_name = model_type[:model_type.rindex( '_')] # model_type: {model_name}_{dataset} logger.info(f"model name: {model_name}") model_class = self.task_resource.get_model_class(model_name) - model_conf = self.config - model = model_class.from_config(model_conf) + model = model_class.from_config(self.config) self.model = model + self.model.set_state_dict(paddle.load(self.am_model)) self.model.eval() - - # load model - model_dict = paddle.load(self.am_model) - self.model.set_state_dict(model_dict) - logger.info("create the transformer like model success") else: - raise ValueError(f"Not support: {model_type}") + raise Exception(f"not support: {model_type}") + logger.info(f"create the {model_type} model success") return True diff --git a/paddlespeech/server/engine/asr/online/ctc_endpoint.py b/paddlespeech/server/engine/asr/online/ctc_endpoint.py new file mode 100644 index 
0000000000000000000000000000000000000000..70146d6dc398233e57d0db4e021cb4f06794a530
--- /dev/null
+++ b/paddlespeech/server/engine/asr/online/ctc_endpoint.py
@@ -0,0 +1,109 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from dataclasses import dataclass
+from math import exp
+from typing import List
+
+from paddlespeech.cli.log import logger
+
+
+@dataclass
+class OnlineCTCEndpointRule:
+    must_contain_nonsilence: bool = True
+    min_trailing_silence: int = 1000
+    min_utterance_length: int = 0
+
+
+@dataclass
+class OnlineCTCEndpoingOpt:
+    frame_shift_in_ms: int = 10
+
+    blank: int = 0  # blank id, that we consider as silence for purposes of endpointing.
+    blank_threshold: float = 0.8  # above blank threshold is silence
+
+    # We support three rules. We terminate decoding if ANY of these rules
+    # evaluates to "true". If you want to add more rules, do it by changing this
+    # code. If you want to disable a rule, you can set the silence-timeout for
+    # that rule to a very large number.
+
+    # rule1 times out after 5 seconds of silence, even if we decoded nothing.
+    rule1: OnlineCTCEndpointRule = OnlineCTCEndpointRule(False, 5000, 0)
+    # rule2 times out after 1 second of silence after decoding something,
+    # even if we did not reach a final-state at all.
+    rule2: OnlineCTCEndpointRule = OnlineCTCEndpointRule(True, 1000, 0)
+    # rule3 times out after the utterance is 20 seconds long, regardless of
+    # anything else.
+    rule3: OnlineCTCEndpointRule = OnlineCTCEndpointRule(False, 0, 20000)
+
+
+class OnlineCTCEndpoint:
+    """
+    [END-TO-END AUTOMATIC SPEECH RECOGNITION INTEGRATED WITH CTC-BASED VOICE ACTIVITY DETECTION](https://arxiv.org/pdf/2002.00551.pdf)
+    """
+
+    def __init__(self, opts: OnlineCTCEndpoingOpt):
+        self.opts = opts
+        logger.info(f"Endpoint Opts: {opts}")
+        self.frame_shift_in_ms = opts.frame_shift_in_ms
+
+        self.num_frames_decoded = 0
+        self.trailing_silence_frames = 0
+
+        self.reset()
+
+    def reset(self):
+        self.num_frames_decoded = 0
+        self.trailing_silence_frames = 0
+
+    def rule_activated(self,
+                       rule: OnlineCTCEndpointRule,
+                       rule_name: str,
+                       decoding_something: bool,
+                       trailing_silence: int,
+                       utterance_length: int) -> bool:
+        ans = (
+            decoding_something or (not rule.must_contain_nonsilence)
+        ) and trailing_silence >= rule.min_trailing_silence and utterance_length >= rule.min_utterance_length
+        if ans:
+            logger.info(
+                f"Endpoint Rule: {rule_name} activated: {decoding_something}, {trailing_silence}, {utterance_length}"
+            )
+        return ans
+
+    def endpoint_detected(self, ctc_log_probs: List[List[float]],
+                          decoding_something: bool) -> bool:
+        for logprob in ctc_log_probs:
+            blank_prob = exp(logprob[self.opts.blank])
+
+            self.num_frames_decoded += 1
+            if blank_prob > self.opts.blank_threshold:
+                self.trailing_silence_frames += 1
+            else:
+                self.trailing_silence_frames = 0
+
+        assert self.num_frames_decoded >= self.trailing_silence_frames
+        assert self.frame_shift_in_ms > 0
+
+        utterance_length = self.num_frames_decoded * self.frame_shift_in_ms
+        trailing_silence = self.trailing_silence_frames * self.frame_shift_in_ms
+        if self.rule_activated(self.opts.rule1, 'rule1', decoding_something,
+                               trailing_silence, utterance_length):
+            return True
+        if self.rule_activated(self.opts.rule2, 'rule2', decoding_something,
+                               trailing_silence, utterance_length):
+            return True
+        if self.rule_activated(self.opts.rule3, 'rule3', decoding_something,
+                               trailing_silence, utterance_length):
+            return True
+        return False
diff --git a/paddlespeech/server/engine/asr/online/ctc_search.py b/paddlespeech/server/engine/asr/online/ctc_search.py
index 4c9ac3acbad1758b163e34f10ac46f6ae78e9b71..46f310c80fe96fcbfc1d9dbc83b9f6f8161e24a1 100644
--- a/paddlespeech/server/engine/asr/online/ctc_search.py
+++ b/paddlespeech/server/engine/asr/online/ctc_search.py
@@ -30,8 +30,29 @@ class CTCPrefixBeamSearch:
             config (yacs.config.CfgNode): the ctc prefix beam search configuration
         """
         self.config = config
+
+        # beam size
+        self.first_beam_size = self.config.beam_size
+        # TODO(support second beam size)
+        self.second_beam_size = int(self.first_beam_size * 1.0)
+        logger.info(
+            f"first and second beam size: {self.first_beam_size}, {self.second_beam_size}"
+        )
+
+        # state
+        self.cur_hyps = None
+        self.hyps = None
+        self.abs_time_step = 0
+
         self.reset()
+
+    def reset(self):
+        """Reset the search cache value
+        """
+        self.cur_hyps = None
+        self.hyps = None
+        self.abs_time_step = 0
 
     @paddle.no_grad()
     def search(self, ctc_probs, device, blank_id=0):
         """ctc prefix beam search method decode a chunk feature
@@ -47,12 +68,17 @@ class CTCPrefixBeamSearch:
         """
         # decode
         logger.info("start to ctc prefix search")
-
+        assert len(ctc_probs.shape) == 2
         batch_size = 1
-        beam_size = self.config.beam_size
-        maxlen = ctc_probs.shape[0]
-        assert len(ctc_probs.shape) == 2
+        vocab_size = ctc_probs.shape[1]
+        first_beam_size = min(self.first_beam_size, vocab_size)
+        second_beam_size = min(self.second_beam_size, vocab_size)
+        logger.info(
+            f"effective first and second beam size: {first_beam_size}, {second_beam_size}"
+        )
+
+        maxlen = ctc_probs.shape[0]
 
         # cur_hyps: (prefix, (blank_ending_score, none_blank_ending_score))
         # 0. blank_ending_score,
@@ -75,7 +101,8 @@ class CTCPrefixBeamSearch:
 
             # 2.1 First beam prune: select topk best
            # do token passing process
-            top_k_logp, top_k_index = logp.topk(beam_size)  # (beam_size,)
+            top_k_logp, top_k_index = logp.topk(
+                first_beam_size)  # (first_beam_size,)
             for s in top_k_index:
                 s = s.item()
                 ps = logp[s].item()
@@ -148,7 +175,7 @@ class CTCPrefixBeamSearch:
                 next_hyps.items(),
                 key=lambda x: log_add([x[1][0], x[1][1]]),
                 reverse=True)
-            self.cur_hyps = next_hyps[:beam_size]
+            self.cur_hyps = next_hyps[:second_beam_size]
 
             # 2.3 update the absolute time step
             self.abs_time_step += 1
@@ -163,7 +190,7 @@ class CTCPrefixBeamSearch:
         """Return the one best result
 
         Returns:
-            list: the one best result
+            list: the one best result, List[str]
         """
         return [self.hyps[0][0]]
 
@@ -171,17 +198,10 @@ class CTCPrefixBeamSearch:
         """Return the search hyps
 
         Returns:
-            list: return the search hyps
+            list: return the search hyps, List[Tuple[str, float, ...]]
         """
         return self.hyps
 
-    def reset(self):
-        """Rest the search cache value
-        """
-        self.cur_hyps = None
-        self.hyps = None
-        self.abs_time_step = 0
-
     def finalize_search(self):
         """do nothing in ctc_prefix_beam_search
         """
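
Illustrative usage of the new CTC endpointer (not part of the patch): a minimal standalone sketch, assuming the corrected endpoint_detected(self, ...) signature above. The 10 ms frame shift and blank=0 mirror what PaddleASRConnectionHanddler.init_decoder() passes for a 16 kHz fbank config; the chunk length, toy vocabulary size, and the decoding_something flag are made up for this example.

    import numpy as np

    from paddlespeech.server.engine.asr.online.ctc_endpoint import OnlineCTCEndpoingOpt
    from paddlespeech.server.engine.asr.online.ctc_endpoint import OnlineCTCEndpoint

    # Options as the connection handler would build them: frame shift of the
    # fbank frontend in ms, and the CTC blank id treated as silence.
    opts = OnlineCTCEndpoingOpt(frame_shift_in_ms=10, blank=0)
    endpointer = OnlineCTCEndpoint(opts)

    # Fake one decoded chunk of CTC log-probs: 120 frames, toy vocab of 10,
    # dominated by the blank id, i.e. pure trailing silence for the endpointer.
    ctc_log_probs = np.full((120, 10), -20.0)
    ctc_log_probs[:, 0] = -1e-3  # exp(-1e-3) ~= 1.0 > blank_threshold (0.8)

    decoding_something = True  # True once the searcher has produced any non-blank hypothesis
    if endpointer.endpoint_detected(ctc_log_probs.tolist(), decoding_something):
        # 120 frames * 10 ms = 1200 ms of trailing silence >= rule2's 1000 ms,
        # so rule2 fires here; a caller could then use
        # reset_continuous_decoding() to start the next utterance.
        print("endpoint detected")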