提交 94327238 编写于 作者: H Hui Zhang

refactor asr online

上级 e6ddb0cc
......@@ -38,7 +38,7 @@ from paddlespeech.server.engine.base_engine import BaseEngine
from paddlespeech.server.utils.audio_process import pcm2float
from paddlespeech.server.utils.paddle_predictor import init_predictor
__all__ = ['ASREngine']
__all__ = ['PaddleASRConnectionHanddler', 'ASRServerExecutor', 'ASREngine']
# ASR server connection process class
......@@ -67,7 +67,7 @@ class PaddleASRConnectionHanddler:
# tokens to text
self.text_feature = self.asr_engine.executor.text_feature
if "deepspeech2online" in self.model_type or "deepspeech2offline" in self.model_type:
if "deepspeech2" in self.model_type:
from paddlespeech.s2t.io.collator import SpeechCollator
self.am_predictor = self.asr_engine.executor.am_predictor
......@@ -89,8 +89,8 @@ class PaddleASRConnectionHanddler:
cfg.decoding_method, cfg.lang_model_path, cfg.alpha, cfg.beta,
cfg.beam_size, cfg.cutoff_prob, cfg.cutoff_top_n,
cfg.num_proc_bsearch)
# frame window samples length and frame shift samples length
# frame window and frame shift, in samples unit
self.win_length = int(self.model_config.window_ms / 1000 *
self.sample_rate)
self.n_shift = int(self.model_config.stride_ms / 1000 *
......@@ -109,16 +109,15 @@ class PaddleASRConnectionHanddler:
self.preprocess_args = {"train": False}
self.preprocessing = Transformation(self.preprocess_conf)
# frame window samples length and frame shift samples length
# frame window and frame shift, in samples unit
self.win_length = self.preprocess_conf.process[0]['win_length']
self.n_shift = self.preprocess_conf.process[0]['n_shift']
else:
raise ValueError(f"Not supported: {self.model_type}")
def extract_feat(self, samples):
# we compute the elapsed time of first char occuring
# and we record the start time at the first pcm sample arraving
# if self.first_char_occur_elapsed is not None:
# self.first_char_occur_elapsed = time.time()
if "deepspeech2online" in self.model_type:
# self.reamined_wav stores all the samples,
......@@ -154,28 +153,28 @@ class PaddleASRConnectionHanddler:
spectrum = self.collate_fn_test._normalizer.apply(spectrum)
# spectrum augment
audio = self.collate_fn_test.augmentation.transform_feature(
feat = self.collate_fn_test.augmentation.transform_feature(
spectrum)
audio_len = audio.shape[0]
audio = paddle.to_tensor(audio, dtype='float32')
# audio_len = paddle.to_tensor(audio_len)
audio = paddle.unsqueeze(audio, axis=0)
# audio_len is frame num
frame_num = feat.shape[0]
feat = paddle.to_tensor(feat, dtype='float32')
feat = paddle.unsqueeze(feat, axis=0)
if self.cached_feat is None:
self.cached_feat = audio
self.cached_feat = feat
else:
assert (len(audio.shape) == 3)
assert (len(feat.shape) == 3)
assert (len(self.cached_feat.shape) == 3)
self.cached_feat = paddle.concat(
[self.cached_feat, audio], axis=1)
[self.cached_feat, feat], axis=1)
# set the feat device
if self.device is None:
self.device = self.cached_feat.place
self.num_frames += audio_len
self.remained_wav = self.remained_wav[self.n_shift * audio_len:]
self.num_frames += frame_num
self.remained_wav = self.remained_wav[self.n_shift * frame_num:]
logger.info(
f"process the audio feature success, the connection feat shape: {self.cached_feat.shape}"
......@@ -183,25 +182,28 @@ class PaddleASRConnectionHanddler:
logger.info(
f"After extract feat, the connection remain the audio samples: {self.remained_wav.shape}"
)
elif "conformer_online" in self.model_type:
logger.info("Online ASR extract the feat")
samples = np.frombuffer(samples, dtype=np.int16)
assert samples.ndim == 1
logger.info(f"This package receive {samples.shape[0]} pcm data")
self.num_samples += samples.shape[0]
logger.info(f"This package receive {samples.shape[0]} pcm data. Global samples:{self.num_samples}")
# self.reamined_wav stores all the samples,
# include the original remained_wav and this package samples
if self.remained_wav is None:
self.remained_wav = samples
else:
assert self.remained_wav.ndim == 1
assert self.remained_wav.ndim == 1 # (T,)
self.remained_wav = np.concatenate([self.remained_wav, samples])
logger.info(
f"The connection remain the audio samples: {self.remained_wav.shape}"
f"The concatenation of remain and now audio samples length is: {self.remained_wav.shape}"
)
if len(self.remained_wav) < self.win_length:
# samples not enough for feature window
return 0
# fbank
......@@ -209,11 +211,13 @@ class PaddleASRConnectionHanddler:
**self.preprocess_args)
x_chunk = paddle.to_tensor(
x_chunk, dtype="float32").unsqueeze(axis=0)
# feature cache
if self.cached_feat is None:
self.cached_feat = x_chunk
else:
assert (len(x_chunk.shape) == 3)
assert (len(self.cached_feat.shape) == 3)
assert (len(x_chunk.shape) == 3) # (B,T,D)
assert (len(self.cached_feat.shape) == 3) # (B,T,D)
self.cached_feat = paddle.concat(
[self.cached_feat, x_chunk], axis=1)
......@@ -221,20 +225,30 @@ class PaddleASRConnectionHanddler:
if self.device is None:
self.device = self.cached_feat.place
# cur frame step
num_frames = x_chunk.shape[1]
# global frame step
self.num_frames += num_frames
# update remained wav
self.remained_wav = self.remained_wav[self.n_shift * num_frames:]
logger.info(
f"process the audio feature success, the connection feat shape: {self.cached_feat.shape}"
f"process the audio feature success, the cached feat shape: {self.cached_feat.shape}"
)
logger.info(
f"After extract feat, the connection remain the audio samples: {self.remained_wav.shape}"
f"After extract feat, the cached remain the audio samples: {self.remained_wav.shape}"
)
# logger.info(f"accumulate samples: {self.num_samples}")
logger.info(f"global samples: {self.num_samples}")
logger.info(f"global frames: {self.num_frames}")
else:
raise ValueError(f"not supported: {self.model_type}")
def reset(self):
if "deepspeech2online" in self.model_type or "deepspeech2offline" in self.model_type:
if "deepspeech2" in self.model_type:
# for deepspeech2
self.chunk_state_h_box = copy.deepcopy(
self.asr_engine.executor.chunk_state_h_box)
......@@ -242,35 +256,63 @@ class PaddleASRConnectionHanddler:
self.asr_engine.executor.chunk_state_c_box)
self.decoder.reset_decoder(batch_size=1)
# for conformer online
self.device = None
## common
# global sample and frame step
self.num_samples = 0
self.num_frames = 0
# cache for audio and feat
self.remained_wav = None
self.cached_feat = None
# partial/ending decoding results
self.result_transcripts = ['']
## conformer
# cache for conformer online
self.subsampling_cache = None
self.elayers_output_cache = None
self.conformer_cnn_cache = None
self.encoder_out = None
self.cached_feat = None
self.remained_wav = None
self.offset = 0
self.num_samples = 0
self.device = None
# conformer decoding state
self.chunk_num = 0 # globa decoding chunk num
self.offset = 0 # global offset in decoding frame unit
self.hyps = []
self.num_frames = 0
self.chunk_num = 0
self.global_frame_offset = 0
self.result_transcripts = ['']
# token timestamp result
self.word_time_stamp = []
# one best timestamp viterbi prob is large.
self.time_stamp = []
self.first_char_occur_elapsed = None
def decode(self, is_finished=False):
"""advance decoding
Args:
is_finished (bool, optional): Is last frame or not. Defaults to False.
Raises:
Exception: when not support model.
Returns:
None: nothing
"""
if "deepspeech2online" in self.model_type:
# x_chunk 是特征数据
decoding_chunk_size = 1 # decoding_chunk_size=1 in deepspeech2 model
context = 7 # context=7 in deepspeech2 model
subsampling = 4 # subsampling=4 in deepspeech2 model
stride = subsampling * decoding_chunk_size
decoding_chunk_size = 1 # decoding chunk size = 1. int decoding frame unit
context = 7 # context=7, in audio frame unit
subsampling = 4 # subsampling=4, in audio frame unit
cached_feature_num = context - subsampling
# decoding window for model
# decoding window for model, in audio frame unit
decoding_window = (decoding_chunk_size - 1) * subsampling + context
# decoding stride for model, in audio frame unit
stride = subsampling * decoding_chunk_size
if self.cached_feat is None:
logger.info("no audio feat, please input more pcm data")
......@@ -280,6 +322,7 @@ class PaddleASRConnectionHanddler:
logger.info(
f"Required decoding window {decoding_window} frames, and the connection has {num_frames} frames"
)
# the cached feat must be larger decoding_window
if num_frames < decoding_window and not is_finished:
logger.info(
......@@ -293,6 +336,7 @@ class PaddleASRConnectionHanddler:
"flast {num_frames} is less than context {context} frames, and we cannot do model forward"
)
return None, None
logger.info("start to do model forward")
# num_frames - context + 1 ensure that current frame can get context window
if is_finished:
......@@ -302,6 +346,7 @@ class PaddleASRConnectionHanddler:
# we only process decoding_window frames for one chunk
left_frames = decoding_window
end = None
for cur in range(0, num_frames - left_frames + 1, stride):
end = min(cur + decoding_window, num_frames)
# extract the audio
......@@ -311,7 +356,9 @@ class PaddleASRConnectionHanddler:
self.result_transcripts = [trans_best]
# update feat cache
self.cached_feat = self.cached_feat[:, end - cached_feature_num:, :]
# return trans_best[0]
elif "conformer" in self.model_type or "transformer" in self.model_type:
try:
......@@ -326,9 +373,19 @@ class PaddleASRConnectionHanddler:
else:
raise Exception("invalid model name")
@paddle.no_grad()
def decode_one_chunk(self, x_chunk, x_chunk_lens):
logger.info("start to decoce one chunk with deepspeech2 model")
"""forward one chunk frames
Args:
x_chunk (np.ndarray): (B,T,D), audio frames.
x_chunk_lens ([type]): (B,), audio frame lens
Returns:
logprob: poster probability.
"""
logger.info("start to decoce one chunk for deepspeech2")
input_names = self.am_predictor.get_input_names()
audio_handle = self.am_predictor.get_input_handle(input_names[0])
audio_len_handle = self.am_predictor.get_input_handle(input_names[1])
......@@ -365,24 +422,31 @@ class PaddleASRConnectionHanddler:
self.decoder.next(output_chunk_probs, output_chunk_lens)
trans_best, trans_beam = self.decoder.decode()
logger.info(f"decode one best result: {trans_best[0]}")
logger.info(f"decode one best result for deepspeech2: {trans_best[0]}")
return trans_best[0]
@paddle.no_grad()
def advance_decoding(self, is_finished=False):
logger.info("start to decode with advanced_decoding method")
logger.info("Conformer/Transformer: start to decode with advanced_decoding method")
cfg = self.ctc_decode_config
# cur chunk size, in decoding frame unit
decoding_chunk_size = cfg.decoding_chunk_size
# using num of history chunks
num_decoding_left_chunks = cfg.num_decoding_left_chunks
assert decoding_chunk_size > 0
subsampling = self.model.encoder.embed.subsampling_rate
context = self.model.encoder.embed.right_context + 1
stride = subsampling * decoding_chunk_size
cached_feature_num = context - subsampling # processed chunk feature cached for next chunk
# decoding window for model
# processed chunk feature cached for next chunk
cached_feature_num = context - subsampling
# decoding stride, in audio frame unit
stride = subsampling * decoding_chunk_size
# decoding window, in audio frame unit
decoding_window = (decoding_chunk_size - 1) * subsampling + context
if self.cached_feat is None:
logger.info("no audio feat, please input more pcm data")
return
......@@ -407,6 +471,7 @@ class PaddleASRConnectionHanddler:
return None, None
logger.info("start to do model forward")
# hist of chunks, in deocding frame unit
required_cache_size = decoding_chunk_size * num_decoding_left_chunks
outputs = []
......@@ -423,8 +488,11 @@ class PaddleASRConnectionHanddler:
for cur in range(0, num_frames - left_frames + 1, stride):
end = min(cur + decoding_window, num_frames)
# global chunk_num
self.chunk_num += 1
# cur chunk
chunk_xs = self.cached_feat[:, cur:end, :]
# forward chunk
(y, self.subsampling_cache, self.elayers_output_cache,
self.conformer_cnn_cache) = self.model.encoder.forward_chunk(
chunk_xs, self.offset, required_cache_size,
......@@ -432,7 +500,7 @@ class PaddleASRConnectionHanddler:
self.conformer_cnn_cache)
outputs.append(y)
# update the offset
# update the global offset, in decoding frame unit
self.offset += y.shape[1]
ys = paddle.cat(outputs, 1)
......@@ -445,12 +513,15 @@ class PaddleASRConnectionHanddler:
ctc_probs = self.model.ctc.log_softmax(ys) # (1, maxlen, vocab_size)
ctc_probs = ctc_probs.squeeze(0)
# advance decoding
self.searcher.search(ctc_probs, self.cached_feat.place)
# get one best hyps
self.hyps = self.searcher.get_one_best_hyps()
assert self.cached_feat.shape[0] == 1
assert end >= cached_feature_num
# advance cache of feat
self.cached_feat = self.cached_feat[0, end -
cached_feature_num:, :].unsqueeze(0)
assert len(
......@@ -462,50 +533,79 @@ class PaddleASRConnectionHanddler:
)
def update_result(self):
"""Conformer/Transformer hyps to result.
"""
logger.info("update the final result")
hyps = self.hyps
# output results and tokenids
self.result_transcripts = [
self.text_feature.defeaturize(hyp) for hyp in hyps
]
self.result_tokenids = [hyp for hyp in hyps]
def get_result(self):
"""return partial/ending asr result.
Returns:
str: one best result of partial/ending.
"""
if len(self.result_transcripts) > 0:
return self.result_transcripts[0]
else:
return ''
def get_word_time_stamp(self):
"""return token timestamp result.
Returns:
list: List of ('w':token, 'bg':time, 'ed':time)
"""
return self.word_time_stamp
@paddle.no_grad()
def rescoring(self):
if "deepspeech2online" in self.model_type or "deepspeech2offline" in self.model_type:
"""Second-Pass Decoding,
only for conformer and transformer model.
"""
if "deepspeech2" in self.model_type:
logger.info("deepspeech2 not support rescoring decoding.")
return
logger.info("rescoring the final result")
if "attention_rescoring" != self.ctc_decode_config.decoding_method:
logger.info(f"decoding method not match: {self.ctc_decode_config.decoding_method}, need attention_rescoring")
return
logger.info("rescoring the final result")
# last decoding for last audio
self.searcher.finalize_search()
# update beam search results
self.update_result()
beam_size = self.ctc_decode_config.beam_size
hyps = self.searcher.get_hyps()
if hyps is None or len(hyps) == 0:
logger.info("No Hyps!")
return
# rescore by decoder post probability
# assert len(hyps) == beam_size
# list of Tensor
hyp_list = []
for hyp in hyps:
hyp_content = hyp[0]
# Prevent the hyp is empty
if len(hyp_content) == 0:
hyp_content = (self.model.ctc.blank_id, )
hyp_content = paddle.to_tensor(
hyp_content, place=self.device, dtype=paddle.long)
hyp_list.append(hyp_content)
hyps_pad = pad_sequence(hyp_list, True, self.model.ignore_id)
hyps_pad = pad_sequence(hyp_list, batch_first=True, padding_value=self.model.ignore_id)
hyps_lens = paddle.to_tensor(
[len(hyp[0]) for hyp in hyps], place=self.device,
dtype=paddle.long) # (beam_size,)
......@@ -531,10 +631,12 @@ class PaddleASRConnectionHanddler:
score = 0.0
for j, w in enumerate(hyp[0]):
score += decoder_out[i][j][w]
# last decoder output token is `eos`, for laste decoder input token.
score += decoder_out[i][len(hyp[0])][self.model.eos]
# add ctc score (which in ln domain)
score += hyp[1] * self.ctc_decode_config.ctc_weight
if score > best_score:
best_score = score
best_index = i
......@@ -542,47 +644,57 @@ class PaddleASRConnectionHanddler:
# update the one best result
# hyps stored the beam results and each fields is:
logger.info(f"best index: {best_index}")
logger.info(f"best hyp index: {best_index}")
# logger.info(f'best result: {hyps[best_index]}')
# the field of the hyps is:
## asr results
# hyps[0][0]: the sentence word-id in the vocab with a tuple
# hyps[0][1]: the sentence decoding probability with all paths
## timestamp
# hyps[0][2]: viterbi_blank ending probability
# hyps[0][3]: viterbi_non_blank probability
# hyps[0][3]: viterbi_non_blank dending probability
# hyps[0][4]: current_token_prob,
# hyps[0][5]: times_viterbi_blank,
# hyps[0][6]: times_titerbi_non_blank
# hyps[0][5]: times_viterbi_blank ending timestamp,
# hyps[0][6]: times_titerbi_non_blank encding timestamp.
self.hyps = [hyps[best_index][0]]
logger.info(f"best hyp ids: {self.hyps}")
# update the hyps time stamp
self.time_stamp = hyps[best_index][5] if hyps[best_index][2] > hyps[
best_index][3] else hyps[best_index][6]
logger.info(f"time stamp: {self.time_stamp}")
# update one best result
self.update_result()
# update each word start and end time stamp
frame_shift_in_ms = self.model.encoder.embed.subsampling_rate * self.n_shift / self.sample_rate
logger.info(f"frame shift ms: {frame_shift_in_ms}")
# decoding frame to audio frame
frame_shift = self.model.encoder.embed.subsampling_rate
frame_shift_in_sec = frame_shift * (self.n_shift / self.sample_rate)
logger.info(f"frame shift sec: {frame_shift_in_sec}")
word_time_stamp = []
for idx, _ in enumerate(self.time_stamp):
start = (self.time_stamp[idx - 1] + self.time_stamp[idx]
) / 2.0 if idx > 0 else 0
start = start * frame_shift_in_ms
start = start * frame_shift_in_sec
end = (self.time_stamp[idx] + self.time_stamp[idx + 1]
) / 2.0 if idx < len(self.time_stamp) - 1 else self.offset
end = end * frame_shift_in_ms
end = end * frame_shift_in_sec
word_time_stamp.append({
"w": self.result_transcripts[0][idx],
"bg": start,
"ed": end
})
# logger.info(f"{self.result_transcripts[0][idx]}, start: {start}, end: {end}")
# logger.info(f"{word_time_stamp[-1]}")
self.word_time_stamp = word_time_stamp
logger.info(f"word time stamp: {self.word_time_stamp}")
class ASRServerExecutor(ASRExecutor):
def __init__(self):
super().__init__()
......@@ -610,6 +722,7 @@ class ASRServerExecutor(ASRExecutor):
self.sample_rate = sample_rate
sample_rate_str = '16k' if sample_rate == 16000 else '8k'
tag = model_type + '-' + lang + '-' + sample_rate_str
if cfg_path is None or am_model is None or am_params is None:
logger.info(f"Load the pretrained model, tag = {tag}")
res_path = self._get_pretrained_path(tag) # wenetspeech_zh
......@@ -628,7 +741,7 @@ class ASRServerExecutor(ASRExecutor):
self.am_model = os.path.abspath(am_model)
self.am_params = os.path.abspath(am_params)
self.res_path = os.path.dirname(
os.path.dirname(os.path.abspath(self.cfg_path)))
os.path.dirname(os.path.abspath(self.cfg_path)))
logger.info(self.cfg_path)
logger.info(self.am_model)
......@@ -639,7 +752,7 @@ class ASRServerExecutor(ASRExecutor):
self.config.merge_from_file(self.cfg_path)
with UpdateConfig(self.config):
if "deepspeech2online" in model_type or "deepspeech2offline" in model_type:
if "deepspeech2" in model_type:
from paddlespeech.s2t.io.collator import SpeechCollator
self.vocab = self.config.vocab_filepath
self.config.decode.lang_model_path = os.path.join(
......@@ -655,6 +768,7 @@ class ASRServerExecutor(ASRExecutor):
self.download_lm(
lm_url,
os.path.dirname(self.config.decode.lang_model_path), lm_md5)
elif "conformer" in model_type or "transformer" in model_type:
logger.info("start to create the stream conformer asr engine")
if self.config.spm_model_prefix:
......@@ -682,7 +796,8 @@ class ASRServerExecutor(ASRExecutor):
], f"we only support ctc_prefix_beam_search and attention_rescoring dedoding method, current decoding method is {self.config.decode.decoding_method}"
else:
raise Exception("wrong type")
if "deepspeech2online" in model_type or "deepspeech2offline" in model_type:
if "deepspeech2" in model_type:
# AM predictor
logger.info("ASR engine start to init the am predictor")
self.am_predictor_conf = am_predictor_conf
......@@ -719,6 +834,7 @@ class ASRServerExecutor(ASRExecutor):
self.chunk_state_c_box = np.zeros(
(self.config.num_rnn_layers, 1, self.config.rnn_layer_size),
dtype=float32)
elif "conformer" in model_type or "transformer" in model_type:
model_name = model_type[:model_type.rindex(
'_')] # model_type: {model_name}_{dataset}
......@@ -737,277 +853,14 @@ class ASRServerExecutor(ASRExecutor):
# update the ctc decoding
self.searcher = CTCPrefixBeamSearch(self.config.decode)
self.transformer_decode_reset()
return True
def reset_decoder_and_chunk(self):
"""reset decoder and chunk state for an new audio
"""
if "deepspeech2online" in self.model_type or "deepspeech2offline" in self.model_type:
self.decoder.reset_decoder(batch_size=1)
# init state box, for new audio request
self.chunk_state_h_box = np.zeros(
(self.config.num_rnn_layers, 1, self.config.rnn_layer_size),
dtype=float32)
self.chunk_state_c_box = np.zeros(
(self.config.num_rnn_layers, 1, self.config.rnn_layer_size),
dtype=float32)
elif "conformer" in self.model_type or "transformer" in self.model_type:
self.transformer_decode_reset()
def decode_one_chunk(self, x_chunk, x_chunk_lens, model_type: str):
"""decode one chunk
Args:
x_chunk (numpy.array): shape[B, T, D]
x_chunk_lens (numpy.array): shape[B]
model_type (str): online model type
Returns:
str: one best result
"""
logger.info("start to decoce chunk by chunk")
if "deepspeech2online" in model_type:
input_names = self.am_predictor.get_input_names()
audio_handle = self.am_predictor.get_input_handle(input_names[0])
audio_len_handle = self.am_predictor.get_input_handle(
input_names[1])
h_box_handle = self.am_predictor.get_input_handle(input_names[2])
c_box_handle = self.am_predictor.get_input_handle(input_names[3])
audio_handle.reshape(x_chunk.shape)
audio_handle.copy_from_cpu(x_chunk)
audio_len_handle.reshape(x_chunk_lens.shape)
audio_len_handle.copy_from_cpu(x_chunk_lens)
h_box_handle.reshape(self.chunk_state_h_box.shape)
h_box_handle.copy_from_cpu(self.chunk_state_h_box)
c_box_handle.reshape(self.chunk_state_c_box.shape)
c_box_handle.copy_from_cpu(self.chunk_state_c_box)
output_names = self.am_predictor.get_output_names()
output_handle = self.am_predictor.get_output_handle(output_names[0])
output_lens_handle = self.am_predictor.get_output_handle(
output_names[1])
output_state_h_handle = self.am_predictor.get_output_handle(
output_names[2])
output_state_c_handle = self.am_predictor.get_output_handle(
output_names[3])
self.am_predictor.run()
output_chunk_probs = output_handle.copy_to_cpu()
output_chunk_lens = output_lens_handle.copy_to_cpu()
self.chunk_state_h_box = output_state_h_handle.copy_to_cpu()
self.chunk_state_c_box = output_state_c_handle.copy_to_cpu()
self.decoder.next(output_chunk_probs, output_chunk_lens)
trans_best, trans_beam = self.decoder.decode()
logger.info(f"decode one best result: {trans_best[0]}")
return trans_best[0]
elif "conformer" in model_type or "transformer" in model_type:
try:
logger.info(
f"we will use the transformer like model : {self.model_type}"
)
self.advanced_decoding(x_chunk, x_chunk_lens)
self.update_result()
return self.result_transcripts[0]
except Exception as e:
logger.exception(e)
else:
raise Exception("invalid model name")
def advanced_decoding(self, xs: paddle.Tensor, x_chunk_lens):
logger.info("start to decode with advanced_decoding method")
encoder_out, encoder_mask = self.encoder_forward(xs)
ctc_probs = self.model.ctc.log_softmax(
encoder_out) # (1, maxlen, vocab_size)
ctc_probs = ctc_probs.squeeze(0)
self.searcher.search(ctc_probs, xs.place)
# update the one best result
self.hyps = self.searcher.get_one_best_hyps()
# now we supprot ctc_prefix_beam_search and attention_rescoring
if "attention_rescoring" in self.config.decode.decoding_method:
self.rescoring(encoder_out, xs.place)
def encoder_forward(self, xs):
logger.info("get the model out from the feat")
cfg = self.config.decode
decoding_chunk_size = cfg.decoding_chunk_size
num_decoding_left_chunks = cfg.num_decoding_left_chunks
assert decoding_chunk_size > 0
subsampling = self.model.encoder.embed.subsampling_rate
context = self.model.encoder.embed.right_context + 1
stride = subsampling * decoding_chunk_size
# decoding window for model
decoding_window = (decoding_chunk_size - 1) * subsampling + context
num_frames = xs.shape[1]
required_cache_size = decoding_chunk_size * num_decoding_left_chunks
logger.info("start to do model forward")
outputs = []
# num_frames - context + 1 ensure that current frame can get context window
for cur in range(0, num_frames - context + 1, stride):
end = min(cur + decoding_window, num_frames)
chunk_xs = xs[:, cur:end, :]
(y, self.subsampling_cache, self.elayers_output_cache,
self.conformer_cnn_cache) = self.model.encoder.forward_chunk(
chunk_xs, self.offset, required_cache_size,
self.subsampling_cache, self.elayers_output_cache,
self.conformer_cnn_cache)
outputs.append(y)
self.offset += y.shape[1]
ys = paddle.cat(outputs, 1)
masks = paddle.ones([1, ys.shape[1]], dtype=paddle.bool)
masks = masks.unsqueeze(1)
return ys, masks
def rescoring(self, encoder_out, device):
logger.info("start to rescoring the hyps")
beam_size = self.config.decode.beam_size
hyps = self.searcher.get_hyps()
assert len(hyps) == beam_size
hyp_list = []
for hyp in hyps:
hyp_content = hyp[0]
# Prevent the hyp is empty
if len(hyp_content) == 0:
hyp_content = (self.model.ctc.blank_id, )
hyp_content = paddle.to_tensor(
hyp_content, place=device, dtype=paddle.long)
hyp_list.append(hyp_content)
hyps_pad = pad_sequence(hyp_list, True, self.model.ignore_id)
hyps_lens = paddle.to_tensor(
[len(hyp[0]) for hyp in hyps], place=device,
dtype=paddle.long) # (beam_size,)
hyps_pad, _ = add_sos_eos(hyps_pad, self.model.sos, self.model.eos,
self.model.ignore_id)
hyps_lens = hyps_lens + 1 # Add <sos> at begining
encoder_out = encoder_out.repeat(beam_size, 1, 1)
encoder_mask = paddle.ones(
(beam_size, 1, encoder_out.shape[1]), dtype=paddle.bool)
decoder_out, _ = self.model.decoder(
encoder_out, encoder_mask, hyps_pad,
hyps_lens) # (beam_size, max_hyps_len, vocab_size)
# ctc score in ln domain
decoder_out = paddle.nn.functional.log_softmax(decoder_out, axis=-1)
decoder_out = decoder_out.numpy()
# Only use decoder score for rescoring
best_score = -float('inf')
best_index = 0
# hyps is List[(Text=List[int], Score=float)], len(hyps)=beam_size
for i, hyp in enumerate(hyps):
score = 0.0
for j, w in enumerate(hyp[0]):
score += decoder_out[i][j][w]
# last decoder output token is `eos`, for laste decoder input token.
score += decoder_out[i][len(hyp[0])][self.model.eos]
# add ctc score (which in ln domain)
score += hyp[1] * self.config.decode.ctc_weight
if score > best_score:
best_score = score
best_index = i
# update the one best result
self.hyps = [hyps[best_index][0]]
return hyps[best_index][0]
def transformer_decode_reset(self):
self.subsampling_cache = None
self.elayers_output_cache = None
self.conformer_cnn_cache = None
self.offset = 0
# decoding reset
self.searcher.reset()
def update_result(self):
logger.info("update the final result")
hyps = self.hyps
self.result_transcripts = [
self.text_feature.defeaturize(hyp) for hyp in hyps
]
self.result_tokenids = [hyp for hyp in hyps]
def extract_feat(self, samples, sample_rate):
"""extract feat
Args:
samples (numpy.array): numpy.float32
sample_rate (int): sample rate
Returns:
x_chunk (numpy.array): shape[B, T, D]
x_chunk_lens (numpy.array): shape[B]
"""
if "deepspeech2online" in self.model_type:
# pcm16 -> pcm 32
samples = pcm2float(samples)
# read audio
speech_segment = SpeechSegment.from_pcm(
samples, sample_rate, transcript=" ")
# audio augment
self.collate_fn_test.augmentation.transform_audio(speech_segment)
# extract speech feature
spectrum, transcript_part = self.collate_fn_test._speech_featurizer.featurize(
speech_segment, self.collate_fn_test.keep_transcription_text)
# CMVN spectrum
if self.collate_fn_test._normalizer:
spectrum = self.collate_fn_test._normalizer.apply(spectrum)
# spectrum augment
audio = self.collate_fn_test.augmentation.transform_feature(
spectrum)
audio_len = audio.shape[0]
audio = paddle.to_tensor(audio, dtype='float32')
# audio_len = paddle.to_tensor(audio_len)
audio = paddle.unsqueeze(audio, axis=0)
x_chunk = audio.numpy()
x_chunk_lens = np.array([audio_len])
return x_chunk, x_chunk_lens
elif "conformer_online" in self.model_type:
if sample_rate != self.sample_rate:
logger.info(f"audio sample rate {sample_rate} is not match,"
"the model sample_rate is {self.sample_rate}")
logger.info(f"ASR Engine use the {self.model_type} to process")
logger.info("Create the preprocess instance")
preprocess_conf = self.config.preprocess_config
preprocess_args = {"train": False}
preprocessing = Transformation(preprocess_conf)
logger.info("Read the audio file")
logger.info(f"audio shape: {samples.shape}")
# fbank
x_chunk = preprocessing(samples, **preprocess_args)
x_chunk_lens = paddle.to_tensor(x_chunk.shape[0])
x_chunk = paddle.to_tensor(
x_chunk, dtype="float32").unsqueeze(axis=0)
logger.info(
f"process the audio feature success, feat shape: {x_chunk.shape}"
)
return x_chunk, x_chunk_lens
raise ValueError(f"Not support: {model_type}")
return True
class ASREngine(BaseEngine):
"""ASR server engine
"""ASR server resource
Args:
metaclass: Defaults to Singleton.
......@@ -1015,7 +868,7 @@ class ASREngine(BaseEngine):
def __init__(self):
super(ASREngine, self).__init__()
logger.info("create the online asr engine instance")
logger.info("create the online asr engine resource instance")
def init(self, config: dict) -> bool:
"""init engine resource
......@@ -1026,17 +879,12 @@ class ASREngine(BaseEngine):
Returns:
bool: init failed or success
"""
self.input = None
self.output = ""
self.executor = ASRServerExecutor()
self.config = config
self.executor = ASRServerExecutor()
try:
if self.config.get("device", None):
self.device = self.config.device
else:
self.device = paddle.get_device()
logger.info(f"paddlespeech_server set the device: {self.device}")
paddle.set_device(self.device)
default_dev = paddle.get_device()
paddle.set_device(self.config.get("device", default_dev))
except BaseException as e:
logger.error(
f"Set device failed, please check if device '{self.device}' is already used and the parameter 'device' in the yaml file"
......@@ -1045,6 +893,8 @@ class ASREngine(BaseEngine):
"If all GPU or XPU is used, you can set the server to 'cpu'")
sys.exit(-1)
logger.info(f"paddlespeech_server set the device: {self.device}")
if not self.executor._init_from_path(
model_type=self.config.model_type,
am_model=self.config.am_model,
......@@ -1062,42 +912,11 @@ class ASREngine(BaseEngine):
logger.info("Initialize ASR server engine successfully.")
return True
def preprocess(self,
samples,
sample_rate,
model_type="deepspeech2online_aishell-zh-16k"):
"""preprocess
Args:
samples (numpy.array): numpy.float32
sample_rate (int): sample rate
Returns:
x_chunk (numpy.array): shape[B, T, D]
x_chunk_lens (numpy.array): shape[B]
"""
# if "deepspeech" in model_type:
x_chunk, x_chunk_lens = self.executor.extract_feat(samples, sample_rate)
return x_chunk, x_chunk_lens
def preprocess(self, *args, **kwargs):
raise NotImplementedError("Online not using this.")
def run(self, x_chunk, x_chunk_lens, decoder_chunk_size=1):
"""run online engine
Args:
x_chunk (numpy.array): shape[B, T, D]
x_chunk_lens (numpy.array): shape[B]
decoder_chunk_size(int)
"""
self.output = self.executor.decode_one_chunk(x_chunk, x_chunk_lens,
self.config.model_type)
def run(self, *args, **kwargs):
raise NotImplementedError("Online not using this.")
def postprocess(self):
"""postprocess
"""
return self.output
def reset(self):
"""reset engine decoder and inference state
"""
self.executor.reset_decoder_and_chunk()
self.output = ""
raise NotImplementedError("Online not using this.")
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册