Commit f56dba0c authored by xiongxinlei

fix the code format, test=doc

Parent 380afbbc
@@ -129,7 +129,7 @@ model_alias = {
"paddlespeech.s2t.models.ds2_online:DeepSpeech2ModelOnline",
"conformer":
"paddlespeech.s2t.models.u2:U2Model",
"conformer2online":
"conformer_online":
"paddlespeech.s2t.models.u2:U2Model",
"transformer":
"paddlespeech.s2t.models.u2:U2Model",
......
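For context, each entry in model_alias maps a short model name to a "module:Class" import path that is resolved at runtime. A minimal sketch of how such a string can be resolved (illustrative only; resolve_alias is a hypothetical stand-in for the project's own dynamic-import utility):

import importlib

def resolve_alias(alias: str):
    """Resolve a "package.module:ClassName" string to the class object."""
    module_path, class_name = alias.split(":")
    return getattr(importlib.import_module(module_path), class_name)

# e.g. resolve_alias("paddlespeech.s2t.models.u2:U2Model") returns the U2Model class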
@@ -21,7 +21,7 @@ engine_list: ['asr_online']
################################### ASR #########################################
################### speech task: asr; engine_type: online #######################
asr_online:
model_type: 'conformer2online_aishell'
model_type: 'conformer_online_multi-cn'
am_model: # the pdmodel file of am static model [optional]
am_params: # the pdiparams file of am static model [optional]
lang: 'zh'
......
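The model_type from this config, together with lang and the sample rate, selects an entry in the pretrained_models table; the "{model_type}-{lang}-{sample_rate}" key format is inferred from the "conformer_online_multi-cn-zh-16k" entry in a later hunk:

# illustrative only; the tag format is inferred from the pretrained_models keys below
model_type = "conformer_online_multi-cn"
lang = "zh"
sample_rate = 16000
tag = f"{model_type}-{lang}-{sample_rate // 1000}k"
assert tag == "conformer_online_multi-cn-zh-16k"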
@@ -11,9 +11,10 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import os
from typing import Optional
import copy
import numpy as np
import paddle
from numpy import float32
@@ -58,7 +59,7 @@ pretrained_models = {
'lm_md5':
'29e02312deb2e59b3c8686c7966d4fe3'
},
"conformer2online_aishell-zh-16k": {
"conformer_online_multi-cn-zh-16k": {
'url':
'https://paddlespeech.bj.bcebos.com/s2t/multi_cn/asr1/asr1_chunk_conformer_multi_cn_ckpt_0.2.3.model.tar.gz',
'md5':
@@ -93,19 +94,22 @@ class PaddleASRConnectionHanddler:
)
self.config = asr_engine.config
self.model_config = asr_engine.executor.config
# self.model = asr_engine.executor.model
self.asr_engine = asr_engine
self.init()
self.reset()
def init(self):
# model_type, sample_rate and text_feature are shared for deepspeech2 and conformer
self.model_type = self.asr_engine.executor.model_type
self.sample_rate = self.asr_engine.executor.sample_rate
# tokens to text
self.text_feature = self.asr_engine.executor.text_feature
if "deepspeech2online" in self.model_type or "deepspeech2offline" in self.model_type:
from paddlespeech.s2t.io.collator import SpeechCollator
self.sample_rate = self.asr_engine.executor.sample_rate
self.am_predictor = self.asr_engine.executor.am_predictor
self.text_feature = self.asr_engine.executor.text_feature
self.collate_fn_test = SpeechCollator.from_config(self.model_config)
self.decoder = CTCDecoder(
odim=self.model_config.output_dim, # <blank> is in vocab
@@ -114,7 +118,8 @@ class PaddleASRConnectionHanddler:
dropout_rate=0.0,
reduction=True, # sum
batch_average=True, # sum / batch_size
grad_norm_type=self.model_config.get('ctc_grad_norm_type', None))
grad_norm_type=self.model_config.get('ctc_grad_norm_type',
None))
cfg = self.model_config.decode
decode_batch_size = 1 # for online
@@ -123,20 +128,16 @@ class PaddleASRConnectionHanddler:
cfg.decoding_method, cfg.lang_model_path, cfg.alpha, cfg.beta,
cfg.beam_size, cfg.cutoff_prob, cfg.cutoff_top_n,
cfg.num_proc_bsearch)
# frame window samples length and frame shift samples length
self.win_length = int(self.model_config.window_ms * self.sample_rate)
self.n_shift = int(self.model_config.stride_ms * self.sample_rate)
# frame window samples length and frame shift samples length
elif "conformer" in self.model_type or "transformer" in self.model_type or "wenetspeech" in self.model_type:
self.sample_rate = self.asr_engine.executor.sample_rate
self.win_length = int(self.model_config.window_ms *
self.sample_rate)
self.n_shift = int(self.model_config.stride_ms * self.sample_rate)
elif "conformer" in self.model_type or "transformer" in self.model_type:
# acoustic model
self.model = self.asr_engine.executor.model
# tokens to text
self.text_feature = self.asr_engine.executor.text_feature
# ctc decoding config
self.ctc_decode_config = self.asr_engine.executor.config.decode
self.searcher = CTCPrefixBeamSearch(self.ctc_decode_config)
@@ -189,7 +190,7 @@ class PaddleASRConnectionHanddler:
audio = paddle.to_tensor(audio, dtype='float32')
# audio_len = paddle.to_tensor(audio_len)
audio = paddle.unsqueeze(audio, axis=0)
if self.cached_feat is None:
self.cached_feat = audio
else:
@@ -211,7 +212,7 @@ class PaddleASRConnectionHanddler:
logger.info(
f"After extract feat, the connection remain the audio samples: {self.remained_wav.shape}"
)
elif "conformer2online" in self.model_type:
elif "conformer_online" in self.model_type:
logger.info("Online ASR extract the feat")
samples = np.frombuffer(samples, dtype=np.int16)
assert samples.ndim == 1
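The online branch receives raw 16-bit PCM bytes, and np.frombuffer reinterprets them as a 1-D int16 array without copying. A self-contained sketch of the same conversion (assumes little-endian samples, the usual case for WAV/PCM streams):

import numpy as np

pcm_bytes = b"\x00\x00\xff\x7f\x00\x80"  # three int16 samples: 0, 32767, -32768
samples = np.frombuffer(pcm_bytes, dtype=np.int16)
assert samples.tolist() == [0, 32767, -32768]
# feature pipelines typically scale to float32 in [-1, 1) afterwards:
waveform = samples.astype(np.float32) / 32768.0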
@@ -264,41 +265,43 @@ class PaddleASRConnectionHanddler:
def reset(self):
if "deepspeech2online" in self.model_type or "deepspeech2offline" in self.model_type:
# for deepspeech2
self.chunk_state_h_box = copy.deepcopy(self.asr_engine.executor.chunk_state_h_box)
self.chunk_state_c_box = copy.deepcopy(self.asr_engine.executor.chunk_state_c_box)
self.chunk_state_h_box = copy.deepcopy(
self.asr_engine.executor.chunk_state_h_box)
self.chunk_state_c_box = copy.deepcopy(
self.asr_engine.executor.chunk_state_c_box)
self.decoder.reset_decoder(batch_size=1)
elif "conformer" in self.model_type or "transformer" in self.model_type or "wenetspeech" in self.model_type:
# for conformer online
self.subsampling_cache = None
self.elayers_output_cache = None
self.conformer_cnn_cache = None
self.encoder_out = None
self.cached_feat = None
self.remained_wav = None
self.offset = 0
self.num_samples = 0
self.device = None
self.hyps = []
self.num_frames = 0
self.chunk_num = 0
self.global_frame_offset = 0
self.result_transcripts = ['']
def decode(self, is_finished=False):
if "deepspeech2online" in self.model_type:
# x_chunk is the feature data
decoding_chunk_size = 1 # decoding_chunk_size=1 in deepspeech2 model
context = 7 # context=7 in deepspeech2 model
subsampling = 4 # subsampling=4 in deepspeech2 model
stride = subsampling * decoding_chunk_size
cached_feature_num = context - subsampling
# decoding window for model
decoding_window = (decoding_chunk_size - 1) * subsampling + context
if self.cached_feat is None:
logger.info("no audio feat, please input more pcm data")
return
num_frames = self.cached_feat.shape[1]
logger.info(
f"Required decoding window {decoding_window} frames, and the connection has {num_frames} frames"
@@ -306,14 +309,14 @@ class PaddleASRConnectionHanddler:
# the cached feat must be larger than decoding_window
if num_frames < decoding_window and not is_finished:
logger.info(
f"frame feat num is less than {decoding_window}, please input more pcm data"
f"frame feat num is less than {decoding_window}, please input more pcm data"
)
return None, None
# if is_finished=True, we need at least context frames
if num_frames < context:
logger.info(
"flast {num_frames} is less than context {context} frames, and we cannot do model forward"
"flast {num_frames} is less than context {context} frames, and we cannot do model forward"
)
return None, None
logger.info("start to do model forward")
@@ -334,8 +337,7 @@ class PaddleASRConnectionHanddler:
self.result_transcripts = [trans_best]
self.cached_feat = self.cached_feat[:, end -
cached_feature_num:, :]
self.cached_feat = self.cached_feat[:, end - cached_feature_num:, :]
# return trans_best[0]
elif "conformer" in self.model_type or "transformer" in self.model_type:
try:
@@ -354,8 +356,7 @@ class PaddleASRConnectionHanddler:
logger.info("start to decoce one chunk with deepspeech2 model")
input_names = self.am_predictor.get_input_names()
audio_handle = self.am_predictor.get_input_handle(input_names[0])
audio_len_handle = self.am_predictor.get_input_handle(
input_names[1])
audio_len_handle = self.am_predictor.get_input_handle(input_names[1])
h_box_handle = self.am_predictor.get_input_handle(input_names[2])
c_box_handle = self.am_predictor.get_input_handle(input_names[3])
@@ -374,11 +375,11 @@ class PaddleASRConnectionHanddler:
output_names = self.am_predictor.get_output_names()
output_handle = self.am_predictor.get_output_handle(output_names[0])
output_lens_handle = self.am_predictor.get_output_handle(
output_names[1])
output_state_h_handle = self.am_predictor.get_output_handle(
output_names[2])
output_state_c_handle = self.am_predictor.get_output_handle(
output_names[3])
self.am_predictor.run()
@@ -389,7 +390,7 @@ class PaddleASRConnectionHanddler:
self.decoder.next(output_chunk_probs, output_chunk_lens)
trans_best, trans_beam = self.decoder.decode()
logger.info(f"decode one one best result: {trans_best[0]}")
logger.info(f"decode one best result: {trans_best[0]}")
return trans_best[0]
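The handle juggling above is the standard Paddle Inference pattern: bind inputs by name, run, copy outputs back to host. A generic sketch (predictor construction omitted; the feed dict keys are illustrative):

import numpy as np

def run_predictor(predictor, feed: dict):
    # bind each input ndarray to its named handle
    for name in predictor.get_input_names():
        handle = predictor.get_input_handle(name)
        arr = feed[name]
        handle.reshape(arr.shape)
        handle.copy_from_cpu(arr)
    predictor.run()
    # fetch every output back to host memory
    return [predictor.get_output_handle(name).copy_to_cpu()
            for name in predictor.get_output_names()]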
def advance_decoding(self, is_finished=False):
@@ -500,7 +501,7 @@ class PaddleASRConnectionHanddler:
def rescoring(self):
if "deepspeech2online" in self.model_type or "deepspeech2offline" in self.model_type:
return
logger.info("rescoring the final result")
if "attention_rescoring" != self.ctc_decode_config.decoding_method:
return
@@ -587,7 +588,7 @@ class ASRServerExecutor(ASRExecutor):
return decompressed_path
def _init_from_path(self,
model_type: str='wenetspeech',
model_type: str='deepspeech2online_aishell',
am_model: Optional[os.PathLike]=None,
am_params: Optional[os.PathLike]=None,
lang: str='zh',
@@ -647,7 +648,7 @@ class ASRServerExecutor(ASRExecutor):
self.download_lm(
lm_url,
os.path.dirname(self.config.decode.lang_model_path), lm_md5)
elif "conformer" in model_type or "transformer" in model_type or "wenetspeech" in model_type:
elif "conformer" in model_type or "transformer" in model_type:
logger.info("start to create the stream conformer asr engine")
if self.config.spm_model_prefix:
self.config.spm_model_prefix = os.path.join(
@@ -711,7 +712,7 @@ class ASRServerExecutor(ASRExecutor):
self.chunk_state_c_box = np.zeros(
(self.config.num_rnn_layers, 1, self.config.rnn_layer_size),
dtype=float32)
elif "conformer" in model_type or "transformer" in model_type or "wenetspeech" in model_type:
elif "conformer" in model_type or "transformer" in model_type:
model_name = model_type[:model_type.rindex(
'_')] # model_type: {model_name}_{dataset}
logger.info(f"model name: {model_name}")
@@ -742,7 +743,7 @@ class ASRServerExecutor(ASRExecutor):
self.chunk_state_c_box = np.zeros(
(self.config.num_rnn_layers, 1, self.config.rnn_layer_size),
dtype=float32)
elif "conformer" in self.model_type or "transformer" in self.model_type or "wenetspeech" in self.model_type:
elif "conformer" in self.model_type or "transformer" in self.model_type:
self.transformer_decode_reset()
def decode_one_chunk(self, x_chunk, x_chunk_lens, model_type: str):
@@ -754,7 +755,7 @@ class ASRServerExecutor(ASRExecutor):
model_type (str): online model type
Returns:
[type]: [description]
str: one best result
"""
logger.info("start to decoce chunk by chunk")
if "deepspeech2online" in model_type:
@@ -795,7 +796,7 @@ class ASRServerExecutor(ASRExecutor):
self.decoder.next(output_chunk_probs, output_chunk_lens)
trans_best, trans_beam = self.decoder.decode()
logger.info(f"decode one one best result: {trans_best[0]}")
logger.info(f"decode one best result: {trans_best[0]}")
return trans_best[0]
elif "conformer" in model_type or "transformer" in model_type:
@@ -972,7 +973,7 @@ class ASRServerExecutor(ASRExecutor):
x_chunk_lens = np.array([audio_len])
return x_chunk, x_chunk_lens
elif "conformer2online" in self.model_type:
elif "conformer_online" in self.model_type:
if sample_rate != self.sample_rate:
logger.info(f"audio sample rate {sample_rate} is not match,"
@@ -1005,7 +1006,7 @@ class ASREngine(BaseEngine):
def __init__(self):
super(ASREngine, self).__init__()
logger.info("create the online asr engine instache")
logger.info("create the online asr engine instance")
def init(self, config: dict) -> bool:
"""init engine resource
......
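For reference, a hypothetical caller would construct the engine and feed it the parsed application config; init returns a bool indicating success:

# illustrative usage only; config is the dict parsed from the application yaml
engine = ASREngine()
if not engine.init(config):
    raise RuntimeError("failed to initialize the online ASR engine")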