提交 380afbbc 编写于 作者: X xiongxinlei

add ds2 model multi session, test=doc

上级 5acb0b52
......@@ -18,44 +18,10 @@ engine_list: ['asr_online']
# ENGINE CONFIG #
#################################################################################
# ################################### ASR #########################################
# ################### speech task: asr; engine_type: online #######################
# asr_online:
# model_type: 'deepspeech2online_aishell'
# am_model: # the pdmodel file of am static model [optional]
# am_params: # the pdiparams file of am static model [optional]
# lang: 'zh'
# sample_rate: 16000
# cfg_path:
# decode_method:
# force_yes: True
# am_predictor_conf:
# device: # set 'gpu:id' or 'cpu'
# switch_ir_optim: True
# glog_info: False # True -> print glog
# summary: True # False -> do not show predictor config
# chunk_buffer_conf:
# frame_duration_ms: 80
# shift_ms: 40
# sample_rate: 16000
# sample_width: 2
# vad_conf:
# aggressiveness: 2
# sample_rate: 16000
# frame_duration_ms: 20
# sample_width: 2
# padding_ms: 200
# padding_ratio: 0.9
################################### ASR #########################################
################### speech task: asr; engine_type: online #######################
asr_online:
model_type: 'conformer2online_aishell'
model_type: 'deepspeech2online_aishell'
am_model: # the pdmodel file of am static model [optional]
am_params: # the pdiparams file of am static model [optional]
lang: 'zh'
......@@ -71,9 +37,19 @@ asr_online:
summary: True # False -> do not show predictor config
chunk_buffer_conf:
frame_duration_ms: 80
shift_ms: 40
sample_rate: 16000
sample_width: 2
window_n: 7 # frame
shift_n: 4 # frame
window_ms: 25 # ms
window_ms: 20 # ms
shift_ms: 10 # ms
vad_conf:
aggressiveness: 2
sample_rate: 16000
sample_width: 2
\ No newline at end of file
frame_duration_ms: 20
sample_width: 2
padding_ms: 200
padding_ratio: 0.9
\ No newline at end of file
# This is the parameter configuration file for PaddleSpeech Serving.
#################################################################################
# SERVER SETTING #
#################################################################################
host: 0.0.0.0
port: 8090
# The task format in the engin_list is: <speech task>_<engine type>
# task choices = ['asr_online', 'tts_online']
# protocol = ['websocket', 'http'] (only one can be selected).
# websocket only support online engine type.
protocol: 'websocket'
engine_list: ['asr_online']
#################################################################################
# ENGINE CONFIG #
#################################################################################
################################### ASR #########################################
################### speech task: asr; engine_type: online #######################
asr_online:
model_type: 'conformer2online_aishell'
am_model: # the pdmodel file of am static model [optional]
am_params: # the pdiparams file of am static model [optional]
lang: 'zh'
sample_rate: 16000
cfg_path:
decode_method:
force_yes: True
am_predictor_conf:
device: # set 'gpu:id' or 'cpu'
switch_ir_optim: True
glog_info: False # True -> print glog
summary: True # False -> do not show predictor config
chunk_buffer_conf:
window_n: 7 # frame
shift_n: 4 # frame
window_ms: 25 # ms
shift_ms: 10 # ms
sample_rate: 16000
sample_width: 2
\ No newline at end of file
......@@ -13,7 +13,7 @@
# limitations under the License.
import os
from typing import Optional
import copy
import numpy as np
import paddle
from numpy import float32
......@@ -93,7 +93,7 @@ class PaddleASRConnectionHanddler:
)
self.config = asr_engine.config
self.model_config = asr_engine.executor.config
self.model = asr_engine.executor.model
# self.model = asr_engine.executor.model
self.asr_engine = asr_engine
self.init()
......@@ -102,7 +102,32 @@ class PaddleASRConnectionHanddler:
def init(self):
self.model_type = self.asr_engine.executor.model_type
if "deepspeech2online" in self.model_type or "deepspeech2offline" in self.model_type:
pass
from paddlespeech.s2t.io.collator import SpeechCollator
self.sample_rate = self.asr_engine.executor.sample_rate
self.am_predictor = self.asr_engine.executor.am_predictor
self.text_feature = self.asr_engine.executor.text_feature
self.collate_fn_test = SpeechCollator.from_config(self.model_config)
self.decoder = CTCDecoder(
odim=self.model_config.output_dim, # <blank> is in vocab
enc_n_units=self.model_config.rnn_layer_size * 2,
blank_id=self.model_config.blank_id,
dropout_rate=0.0,
reduction=True, # sum
batch_average=True, # sum / batch_size
grad_norm_type=self.model_config.get('ctc_grad_norm_type', None))
cfg = self.model_config.decode
decode_batch_size = 1 # for online
self.decoder.init_decoder(
decode_batch_size, self.text_feature.vocab_list,
cfg.decoding_method, cfg.lang_model_path, cfg.alpha, cfg.beta,
cfg.beam_size, cfg.cutoff_prob, cfg.cutoff_top_n,
cfg.num_proc_bsearch)
# frame window samples length and frame shift samples length
self.win_length = int(self.model_config.window_ms * self.sample_rate)
self.n_shift = int(self.model_config.stride_ms * self.sample_rate)
elif "conformer" in self.model_type or "transformer" in self.model_type or "wenetspeech" in self.model_type:
self.sample_rate = self.asr_engine.executor.sample_rate
......@@ -127,7 +152,65 @@ class PaddleASRConnectionHanddler:
def extract_feat(self, samples):
if "deepspeech2online" in self.model_type:
pass
# self.reamined_wav stores all the samples,
# include the original remained_wav and this package samples
samples = np.frombuffer(samples, dtype=np.int16)
assert samples.ndim == 1
if self.remained_wav is None:
self.remained_wav = samples
else:
assert self.remained_wav.ndim == 1
self.remained_wav = np.concatenate([self.remained_wav, samples])
logger.info(
f"The connection remain the audio samples: {self.remained_wav.shape}"
)
# pcm16 -> pcm 32
samples = pcm2float(self.remained_wav)
# read audio
speech_segment = SpeechSegment.from_pcm(
samples, self.sample_rate, transcript=" ")
# audio augment
self.collate_fn_test.augmentation.transform_audio(speech_segment)
# extract speech feature
spectrum, transcript_part = self.collate_fn_test._speech_featurizer.featurize(
speech_segment, self.collate_fn_test.keep_transcription_text)
# CMVN spectrum
if self.collate_fn_test._normalizer:
spectrum = self.collate_fn_test._normalizer.apply(spectrum)
# spectrum augment
audio = self.collate_fn_test.augmentation.transform_feature(
spectrum)
audio_len = audio.shape[0]
audio = paddle.to_tensor(audio, dtype='float32')
# audio_len = paddle.to_tensor(audio_len)
audio = paddle.unsqueeze(audio, axis=0)
if self.cached_feat is None:
self.cached_feat = audio
else:
assert (len(audio.shape) == 3)
assert (len(self.cached_feat.shape) == 3)
self.cached_feat = paddle.concat(
[self.cached_feat, audio], axis=1)
# set the feat device
if self.device is None:
self.device = self.cached_feat.place
self.num_frames += audio_len
self.remained_wav = self.remained_wav[self.n_shift * audio_len:]
logger.info(
f"process the audio feature success, the connection feat shape: {self.cached_feat.shape}"
)
logger.info(
f"After extract feat, the connection remain the audio samples: {self.remained_wav.shape}"
)
elif "conformer2online" in self.model_type:
logger.info("Online ASR extract the feat")
samples = np.frombuffer(samples, dtype=np.int16)
......@@ -179,24 +262,81 @@ class PaddleASRConnectionHanddler:
# logger.info(f"accumulate samples: {self.num_samples}")
def reset(self):
self.subsampling_cache = None
self.elayers_output_cache = None
self.conformer_cnn_cache = None
self.encoder_out = None
self.cached_feat = None
self.remained_wav = None
self.offset = 0
self.num_samples = 0
self.device = None
self.hyps = []
self.num_frames = 0
self.chunk_num = 0
self.global_frame_offset = 0
self.result_transcripts = ['']
if "deepspeech2online" in self.model_type or "deepspeech2offline" in self.model_type:
# for deepspeech2
self.chunk_state_h_box = copy.deepcopy(self.asr_engine.executor.chunk_state_h_box)
self.chunk_state_c_box = copy.deepcopy(self.asr_engine.executor.chunk_state_c_box)
self.decoder.reset_decoder(batch_size=1)
elif "conformer" in self.model_type or "transformer" in self.model_type or "wenetspeech" in self.model_type:
# for conformer online
self.subsampling_cache = None
self.elayers_output_cache = None
self.conformer_cnn_cache = None
self.encoder_out = None
self.cached_feat = None
self.remained_wav = None
self.offset = 0
self.num_samples = 0
self.device = None
self.hyps = []
self.num_frames = 0
self.chunk_num = 0
self.global_frame_offset = 0
self.result_transcripts = ['']
def decode(self, is_finished=False):
if "deepspeech2online" in self.model_type:
pass
# x_chunk 是特征数据
decoding_chunk_size = 1 # decoding_chunk_size=1 in deepspeech2 model
context = 7 # context=7 in deepspeech2 model
subsampling = 4 # subsampling=4 in deepspeech2 model
stride = subsampling * decoding_chunk_size
cached_feature_num = context - subsampling
# decoding window for model
decoding_window = (decoding_chunk_size - 1) * subsampling + context
if self.cached_feat is None:
logger.info("no audio feat, please input more pcm data")
return
num_frames = self.cached_feat.shape[1]
logger.info(
f"Required decoding window {decoding_window} frames, and the connection has {num_frames} frames"
)
# the cached feat must be larger decoding_window
if num_frames < decoding_window and not is_finished:
logger.info(
f"frame feat num is less than {decoding_window}, please input more pcm data"
)
return None, None
# if is_finished=True, we need at least context frames
if num_frames < context:
logger.info(
"flast {num_frames} is less than context {context} frames, and we cannot do model forward"
)
return None, None
logger.info("start to do model forward")
# num_frames - context + 1 ensure that current frame can get context window
if is_finished:
# if get the finished chunk, we need process the last context
left_frames = context
else:
# we only process decoding_window frames for one chunk
left_frames = decoding_window
for cur in range(0, num_frames - left_frames + 1, stride):
end = min(cur + decoding_window, num_frames)
# extract the audio
x_chunk = self.cached_feat[:, cur:end, :].numpy()
x_chunk_lens = np.array([x_chunk.shape[1]])
trans_best = self.decode_one_chunk(x_chunk, x_chunk_lens)
self.result_transcripts = [trans_best]
self.cached_feat = self.cached_feat[:, end -
cached_feature_num:, :]
# return trans_best[0]
elif "conformer" in self.model_type or "transformer" in self.model_type:
try:
logger.info(
......@@ -210,6 +350,48 @@ class PaddleASRConnectionHanddler:
else:
raise Exception("invalid model name")
def decode_one_chunk(self, x_chunk, x_chunk_lens):
logger.info("start to decoce one chunk with deepspeech2 model")
input_names = self.am_predictor.get_input_names()
audio_handle = self.am_predictor.get_input_handle(input_names[0])
audio_len_handle = self.am_predictor.get_input_handle(
input_names[1])
h_box_handle = self.am_predictor.get_input_handle(input_names[2])
c_box_handle = self.am_predictor.get_input_handle(input_names[3])
audio_handle.reshape(x_chunk.shape)
audio_handle.copy_from_cpu(x_chunk)
audio_len_handle.reshape(x_chunk_lens.shape)
audio_len_handle.copy_from_cpu(x_chunk_lens)
h_box_handle.reshape(self.chunk_state_h_box.shape)
h_box_handle.copy_from_cpu(self.chunk_state_h_box)
c_box_handle.reshape(self.chunk_state_c_box.shape)
c_box_handle.copy_from_cpu(self.chunk_state_c_box)
output_names = self.am_predictor.get_output_names()
output_handle = self.am_predictor.get_output_handle(output_names[0])
output_lens_handle = self.am_predictor.get_output_handle(
output_names[1])
output_state_h_handle = self.am_predictor.get_output_handle(
output_names[2])
output_state_c_handle = self.am_predictor.get_output_handle(
output_names[3])
self.am_predictor.run()
output_chunk_probs = output_handle.copy_to_cpu()
output_chunk_lens = output_lens_handle.copy_to_cpu()
self.chunk_state_h_box = output_state_h_handle.copy_to_cpu()
self.chunk_state_c_box = output_state_c_handle.copy_to_cpu()
self.decoder.next(output_chunk_probs, output_chunk_lens)
trans_best, trans_beam = self.decoder.decode()
logger.info(f"decode one one best result: {trans_best[0]}")
return trans_best[0]
def advance_decoding(self, is_finished=False):
logger.info("start to decode with advanced_decoding method")
cfg = self.ctc_decode_config
......@@ -240,6 +422,7 @@ class PaddleASRConnectionHanddler:
)
return None, None
# if is_finished=True, we need at least context frames
if num_frames < context:
logger.info(
"flast {num_frames} is less than context {context} frames, and we cannot do model forward"
......@@ -315,6 +498,9 @@ class PaddleASRConnectionHanddler:
return ''
def rescoring(self):
if "deepspeech2online" in self.model_type or "deepspeech2offline" in self.model_type:
return
logger.info("rescoring the final result")
if "attention_rescoring" != self.ctc_decode_config.decoding_method:
return
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册