Unverified commit f07f57a3, authored by Hui Zhang, committed by GitHub

Merge pull request #1945 from PaddlePaddle/asr_line

[server][asr] refactor asr streaming server and remove useless code
@@ -26,9 +26,8 @@ def get_audios(path):
     """
     supported_formats = [".wav", ".mp3", ".ogg", ".flac", ".m4a"]
     return [
-        item
-        for sublist in [[os.path.join(dir, file) for file in files]
-                        for dir, _, files in list(os.walk(path))]
+        item for sublist in [[os.path.join(dir, file) for file in files]
+                             for dir, _, files in list(os.walk(path))]
         for item in sublist if os.path.splitext(item)[1] in supported_formats
     ]
...
@@ -13,9 +13,7 @@
 # limitations under the License.
 #!/usr/bin/python
 # -*- coding: UTF-8 -*-
 # script for calc RTF: grep -rn RTF log.txt | awk '{print $NF}' | awk -F "=" '{sum += $NF} END {print "all time",sum, "audio num", NR, "RTF", sum/NR}'
 import argparse
 import asyncio
 import codecs
...
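For readers who prefer Python over awk, an equivalent sketch of the RTF aggregation in the comment above (assumes each matched log line ends with something like `RTF=0.25`; the log path is a placeholder):

```python
import re

def rtf_from_log(path="log.txt"):
    # collect the trailing RTF=<number> value from every matching line
    vals = []
    with open(path) as f:
        for line in f:
            m = re.search(r"RTF=([0-9.]+)\s*$", line)
            if m:
                vals.append(float(m.group(1)))
    if vals:  # avoid division by zero on an empty log
        print("all time", sum(vals), "audio num", len(vals), "RTF", sum(vals) / len(vals))
```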
@@ -92,5 +92,3 @@ demo of the server: [streaming_asr_server](https://github.com/PaddlePaddle/Paddle
 ## 4. Quick Start
 For how to use PP-ASR, see the [install](https://github.com/PaddlePaddle/PaddleSpeech/blob/develop/docs/source/install_cn.md) guide, which offers three installation modes: **Easy**, **Medium**, and **Hard**. If you only want to try the paddlespeech inference features, the **Easy** installation is enough.
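As a quick sanity check after the Easy install, the Python CLI entry point can be used roughly like this (a minimal sketch; `zh.wav` is a placeholder 16 kHz Mandarin recording and default model settings are assumed):

```python
from paddlespeech.cli.asr.infer import ASRExecutor

asr = ASRExecutor()
# downloads the default pretrained model on first use
text = asr(audio_file="zh.wav")
print(text)
```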
@@ -24,11 +24,11 @@ from typing import Any
 from typing import Dict
 import paddle
+import paddleaudio
 import requests
 import yaml
 from paddle.framework import load
-import paddleaudio
 from . import download
 from .entry import commands
 try:
...
@@ -51,7 +51,7 @@ def _batch_shuffle(indices, batch_size, epoch, clipped=False):
     """
     rng = np.random.RandomState(epoch)
     shift_len = rng.randint(0, batch_size - 1)
-    batch_indices = list(zip(*[iter(indices[shift_len:])] * batch_size))
+    batch_indices = list(zip(* [iter(indices[shift_len:])] * batch_size))
     rng.shuffle(batch_indices)
     batch_indices = [item for batch in batch_indices for item in batch]
     assert clipped is False
...
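For context, a self-contained sketch of the batching trick `_batch_shuffle` relies on: `zip` over `batch_size` copies of one iterator pulls `batch_size` items at a time, so the randomly shifted index list is grouped into full batches, whole batches are shuffled, and the result is flattened (NumPy only):

```python
import numpy as np

def batch_shuffle_sketch(indices, batch_size, epoch):
    rng = np.random.RandomState(epoch)
    shift_len = rng.randint(0, batch_size - 1)
    # zip(*[it] * n) pulls n items at a time from the same iterator,
    # silently dropping the trailing partial batch
    batches = list(zip(*[iter(indices[shift_len:])] * batch_size))
    rng.shuffle(batches)
    return [item for batch in batches for item in batch]

print(batch_shuffle_sketch(list(range(10)), batch_size=3, epoch=0))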
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
@@ -38,7 +38,7 @@ from paddlespeech.server.engine.base_engine import BaseEngine
 from paddlespeech.server.utils.audio_process import pcm2float
 from paddlespeech.server.utils.paddle_predictor import init_predictor

-__all__ = ['ASREngine']
+__all__ = ['PaddleASRConnectionHanddler', 'ASRServerExecutor', 'ASREngine']

 # ASR server connection process class
@@ -67,7 +67,7 @@ class PaddleASRConnectionHanddler:
         # tokens to text
         self.text_feature = self.asr_engine.executor.text_feature

-        if "deepspeech2online" in self.model_type or "deepspeech2offline" in self.model_type:
+        if "deepspeech2" in self.model_type:
             from paddlespeech.s2t.io.collator import SpeechCollator
             self.am_predictor = self.asr_engine.executor.am_predictor
@@ -89,8 +89,8 @@ class PaddleASRConnectionHanddler:
                 cfg.decoding_method, cfg.lang_model_path, cfg.alpha, cfg.beta,
                 cfg.beam_size, cfg.cutoff_prob, cfg.cutoff_top_n,
                 cfg.num_proc_bsearch)
-            # frame window samples length and frame shift samples length
+            # frame window and frame shift, in samples unit
            self.win_length = int(self.model_config.window_ms / 1000 *
                                  self.sample_rate)
            self.n_shift = int(self.model_config.stride_ms / 1000 *
@@ -109,16 +109,15 @@ class PaddleASRConnectionHanddler:
             self.preprocess_args = {"train": False}
             self.preprocessing = Transformation(self.preprocess_conf)
-            # frame window samples length and frame shift samples length
+            # frame window and frame shift, in samples unit
             self.win_length = self.preprocess_conf.process[0]['win_length']
             self.n_shift = self.preprocess_conf.process[0]['n_shift']
+        else:
+            raise ValueError(f"Not supported: {self.model_type}")

     def extract_feat(self, samples):
         # we compute the elapsed time of first char occuring
         # and we record the start time at the first pcm sample arraving
-        # if self.first_char_occur_elapsed is not None:
-        #     self.first_char_occur_elapsed = time.time()

         if "deepspeech2online" in self.model_type:
             # self.reamined_wav stores all the samples,
@@ -154,28 +153,27 @@ class PaddleASRConnectionHanddler:
                 spectrum = self.collate_fn_test._normalizer.apply(spectrum)
             # spectrum augment
-            audio = self.collate_fn_test.augmentation.transform_feature(
-                spectrum)
-            audio_len = audio.shape[0]
-            audio = paddle.to_tensor(audio, dtype='float32')
-            # audio_len = paddle.to_tensor(audio_len)
-            audio = paddle.unsqueeze(audio, axis=0)
+            feat = self.collate_fn_test.augmentation.transform_feature(spectrum)
+            # audio_len is frame num
+            frame_num = feat.shape[0]
+            feat = paddle.to_tensor(feat, dtype='float32')
+            feat = paddle.unsqueeze(feat, axis=0)
             if self.cached_feat is None:
-                self.cached_feat = audio
+                self.cached_feat = feat
             else:
-                assert (len(audio.shape) == 3)
+                assert (len(feat.shape) == 3)
                 assert (len(self.cached_feat.shape) == 3)
                 self.cached_feat = paddle.concat(
-                    [self.cached_feat, audio], axis=1)
+                    [self.cached_feat, feat], axis=1)
             # set the feat device
             if self.device is None:
                 self.device = self.cached_feat.place
-            self.num_frames += audio_len
-            self.remained_wav = self.remained_wav[self.n_shift * audio_len:]
+            self.num_frames += frame_num
+            self.remained_wav = self.remained_wav[self.n_shift * frame_num:]
             logger.info(
                 f"process the audio feature success, the connection feat shape: {self.cached_feat.shape}"
elif "conformer_online" in self.model_type: elif "conformer_online" in self.model_type:
logger.info("Online ASR extract the feat") logger.info("Online ASR extract the feat")
samples = np.frombuffer(samples, dtype=np.int16) samples = np.frombuffer(samples, dtype=np.int16)
assert samples.ndim == 1 assert samples.ndim == 1
logger.info(f"This package receive {samples.shape[0]} pcm data")
self.num_samples += samples.shape[0] self.num_samples += samples.shape[0]
logger.info(
f"This package receive {samples.shape[0]} pcm data. Global samples:{self.num_samples}"
)
# self.reamined_wav stores all the samples, # self.reamined_wav stores all the samples,
# include the original remained_wav and this package samples # include the original remained_wav and this package samples
if self.remained_wav is None: if self.remained_wav is None:
self.remained_wav = samples self.remained_wav = samples
else: else:
assert self.remained_wav.ndim == 1 assert self.remained_wav.ndim == 1 # (T,)
self.remained_wav = np.concatenate([self.remained_wav, samples]) self.remained_wav = np.concatenate([self.remained_wav, samples])
logger.info( logger.info(
f"The connection remain the audio samples: {self.remained_wav.shape}" f"The concatenation of remain and now audio samples length is: {self.remained_wav.shape}"
) )
if len(self.remained_wav) < self.win_length: if len(self.remained_wav) < self.win_length:
# samples not enough for feature window
return 0 return 0
# fbank # fbank
@@ -209,11 +212,13 @@ class PaddleASRConnectionHanddler:
                 **self.preprocess_args)
             x_chunk = paddle.to_tensor(
                 x_chunk, dtype="float32").unsqueeze(axis=0)
+            # feature cache
             if self.cached_feat is None:
                 self.cached_feat = x_chunk
             else:
-                assert (len(x_chunk.shape) == 3)
-                assert (len(self.cached_feat.shape) == 3)
+                assert (len(x_chunk.shape) == 3)  # (B,T,D)
+                assert (len(self.cached_feat.shape) == 3)  # (B,T,D)
                 self.cached_feat = paddle.concat(
                     [self.cached_feat, x_chunk], axis=1)
@@ -221,20 +226,28 @@ class PaddleASRConnectionHanddler:
             if self.device is None:
                 self.device = self.cached_feat.place
+            # cur frame step
             num_frames = x_chunk.shape[1]
+            # global frame step
             self.num_frames += num_frames
+            # update remained wav
             self.remained_wav = self.remained_wav[self.n_shift * num_frames:]
             logger.info(
-                f"process the audio feature success, the connection feat shape: {self.cached_feat.shape}"
+                f"process the audio feature success, the cached feat shape: {self.cached_feat.shape}"
             )
             logger.info(
-                f"After extract feat, the connection remain the audio samples: {self.remained_wav.shape}"
+                f"After extract feat, the cached remain the audio samples: {self.remained_wav.shape}"
             )
-            # logger.info(f"accumulate samples: {self.num_samples}")
+            logger.info(f"global samples: {self.num_samples}")
+            logger.info(f"global frames: {self.num_frames}")
+        else:
+            raise ValueError(f"not supported: {self.model_type}")
     def reset(self):
-        if "deepspeech2online" in self.model_type or "deepspeech2offline" in self.model_type:
+        if "deepspeech2" in self.model_type:
             # for deepspeech2
             self.chunk_state_h_box = copy.deepcopy(
                 self.asr_engine.executor.chunk_state_h_box)
@@ -242,35 +255,61 @@ class PaddleASRConnectionHanddler:
                 self.asr_engine.executor.chunk_state_c_box)
             self.decoder.reset_decoder(batch_size=1)
-        # for conformer online
-        self.subsampling_cache = None
-        self.elayers_output_cache = None
-        self.conformer_cnn_cache = None
-        self.encoder_out = None
-        self.cached_feat = None
-        self.remained_wav = None
-        self.offset = 0
-        self.num_samples = 0
-        self.device = None
-        self.hyps = []
-        self.num_frames = 0
-        self.chunk_num = 0
-        self.global_frame_offset = 0
-        self.result_transcripts = ['']
-        self.word_time_stamp = []
-        self.time_stamp = []
-        self.first_char_occur_elapsed = None
+        self.device = None
+
+        ## common
+        # global sample and frame step
+        self.num_samples = 0
+        self.num_frames = 0
+
+        # cache for audio and feat
+        self.remained_wav = None
+        self.cached_feat = None
+
+        # partial/ending decoding results
+        self.result_transcripts = ['']
+
+        ## conformer
+        # cache for conformer online
+        self.subsampling_cache = None
+        self.elayers_output_cache = None
+        self.conformer_cnn_cache = None
+        self.encoder_out = None
+
+        # conformer decoding state
+        self.chunk_num = 0  # globa decoding chunk num
+        self.offset = 0  # global offset in decoding frame unit
+        self.hyps = []
+
+        # token timestamp result
+        self.word_time_stamp = []
+
+        # one best timestamp viterbi prob is large.
+        self.time_stamp = []
     def decode(self, is_finished=False):
+        """advance decoding
+
+        Args:
+            is_finished (bool, optional): Is last frame or not. Defaults to False.
+
+        Raises:
+            Exception: when not support model.
+
+        Returns:
+            None: nothing
+        """
         if "deepspeech2online" in self.model_type:
-            # x_chunk 是特征数据
-            decoding_chunk_size = 1  # decoding_chunk_size=1 in deepspeech2 model
-            context = 7  # context=7 in deepspeech2 model
-            subsampling = 4  # subsampling=4 in deepspeech2 model
-            stride = subsampling * decoding_chunk_size
+            decoding_chunk_size = 1  # decoding chunk size = 1. int decoding frame unit
+            context = 7  # context=7, in audio frame unit
+            subsampling = 4  # subsampling=4, in audio frame unit
             cached_feature_num = context - subsampling
-            # decoding window for model
+            # decoding window for model, in audio frame unit
             decoding_window = (decoding_chunk_size - 1) * subsampling + context
+            # decoding stride for model, in audio frame unit
+            stride = subsampling * decoding_chunk_size

            if self.cached_feat is None:
                logger.info("no audio feat, please input more pcm data")
@@ -280,6 +319,7 @@ class PaddleASRConnectionHanddler:
            logger.info(
                f"Required decoding window {decoding_window} frames, and the connection has {num_frames} frames"
            )
            # the cached feat must be larger decoding_window
            if num_frames < decoding_window and not is_finished:
                logger.info(
@@ -293,6 +333,7 @@ class PaddleASRConnectionHanddler:
                    "flast {num_frames} is less than context {context} frames, and we cannot do model forward"
                )
                return None, None
logger.info("start to do model forward") logger.info("start to do model forward")
# num_frames - context + 1 ensure that current frame can get context window # num_frames - context + 1 ensure that current frame can get context window
if is_finished: if is_finished:
...@@ -302,6 +343,7 @@ class PaddleASRConnectionHanddler: ...@@ -302,6 +343,7 @@ class PaddleASRConnectionHanddler:
# we only process decoding_window frames for one chunk # we only process decoding_window frames for one chunk
left_frames = decoding_window left_frames = decoding_window
end = None
for cur in range(0, num_frames - left_frames + 1, stride): for cur in range(0, num_frames - left_frames + 1, stride):
end = min(cur + decoding_window, num_frames) end = min(cur + decoding_window, num_frames)
# extract the audio # extract the audio
...@@ -311,7 +353,9 @@ class PaddleASRConnectionHanddler: ...@@ -311,7 +353,9 @@ class PaddleASRConnectionHanddler:
self.result_transcripts = [trans_best] self.result_transcripts = [trans_best]
# update feat cache
self.cached_feat = self.cached_feat[:, end - cached_feature_num:, :] self.cached_feat = self.cached_feat[:, end - cached_feature_num:, :]
# return trans_best[0] # return trans_best[0]
elif "conformer" in self.model_type or "transformer" in self.model_type: elif "conformer" in self.model_type or "transformer" in self.model_type:
try: try:
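For the deepspeech2 online branch above, the chunk loop can be reproduced in isolation. This sketch uses the same constants (context=7, subsampling=4, decoding_chunk_size=1) and shows which cached-frame ranges each forward step consumes:

```python
def chunk_ranges(num_frames, decoding_chunk_size=1, context=7, subsampling=4):
    decoding_window = (decoding_chunk_size - 1) * subsampling + context
    stride = subsampling * decoding_chunk_size
    ranges = []
    # num_frames - decoding_window + 1 ensures every chunk has a full window
    for cur in range(0, num_frames - decoding_window + 1, stride):
        end = min(cur + decoding_window, num_frames)
        ranges.append((cur, end))
    return ranges

print(chunk_ranges(16))  # [(0, 7), (4, 11), (8, 15)]
```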
@@ -328,7 +372,16 @@ class PaddleASRConnectionHanddler:
     @paddle.no_grad()
     def decode_one_chunk(self, x_chunk, x_chunk_lens):
-        logger.info("start to decoce one chunk with deepspeech2 model")
+        """forward one chunk frames
+
+        Args:
+            x_chunk (np.ndarray): (B,T,D), audio frames.
+            x_chunk_lens ([type]): (B,), audio frame lens
+
+        Returns:
+            logprob: poster probability.
+        """
+        logger.info("start to decoce one chunk for deepspeech2")
         input_names = self.am_predictor.get_input_names()
         audio_handle = self.am_predictor.get_input_handle(input_names[0])
         audio_len_handle = self.am_predictor.get_input_handle(input_names[1])
@@ -365,24 +418,32 @@ class PaddleASRConnectionHanddler:
         self.decoder.next(output_chunk_probs, output_chunk_lens)
         trans_best, trans_beam = self.decoder.decode()
-        logger.info(f"decode one best result: {trans_best[0]}")
+        logger.info(f"decode one best result for deepspeech2: {trans_best[0]}")
         return trans_best[0]
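The Paddle Inference call pattern inside `decode_one_chunk` (reshape the input handles, copy data in, run, copy outputs out) looks roughly like this sketch; `predictor` is a `paddle.inference` predictor and `x`, `x_lens` are NumPy arrays:

```python
def run_predictor(predictor, x, x_lens):
    # bind inputs
    names = predictor.get_input_names()
    h_x = predictor.get_input_handle(names[0])
    h_len = predictor.get_input_handle(names[1])
    h_x.reshape(x.shape)
    h_x.copy_from_cpu(x)
    h_len.reshape(x_lens.shape)
    h_len.copy_from_cpu(x_lens)

    # forward
    predictor.run()

    # fetch outputs
    out_names = predictor.get_output_names()
    probs = predictor.get_output_handle(out_names[0]).copy_to_cpu()
    lens = predictor.get_output_handle(out_names[1]).copy_to_cpu()
    return probs, lens
```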
     @paddle.no_grad()
     def advance_decoding(self, is_finished=False):
-        logger.info("start to decode with advanced_decoding method")
+        logger.info(
+            "Conformer/Transformer: start to decode with advanced_decoding method"
+        )
         cfg = self.ctc_decode_config
+        # cur chunk size, in decoding frame unit
         decoding_chunk_size = cfg.decoding_chunk_size
+        # using num of history chunks
         num_decoding_left_chunks = cfg.num_decoding_left_chunks
         assert decoding_chunk_size > 0
         subsampling = self.model.encoder.embed.subsampling_rate
         context = self.model.encoder.embed.right_context + 1
-        stride = subsampling * decoding_chunk_size
-        cached_feature_num = context - subsampling
-        # decoding window for model
+        # processed chunk feature cached for next chunk
+        cached_feature_num = context - subsampling
+        # decoding stride, in audio frame unit
+        stride = subsampling * decoding_chunk_size
+        # decoding window, in audio frame unit
         decoding_window = (decoding_chunk_size - 1) * subsampling + context
         if self.cached_feat is None:
             logger.info("no audio feat, please input more pcm data")
             return
@@ -407,6 +468,7 @@ class PaddleASRConnectionHanddler:
                 return None, None
         logger.info("start to do model forward")
+        # hist of chunks, in deocding frame unit
         required_cache_size = decoding_chunk_size * num_decoding_left_chunks
         outputs = []
@@ -423,8 +485,11 @@ class PaddleASRConnectionHanddler:
         for cur in range(0, num_frames - left_frames + 1, stride):
             end = min(cur + decoding_window, num_frames)
+            # global chunk_num
             self.chunk_num += 1
+            # cur chunk
             chunk_xs = self.cached_feat[:, cur:end, :]
+            # forward chunk
             (y, self.subsampling_cache, self.elayers_output_cache,
              self.conformer_cnn_cache) = self.model.encoder.forward_chunk(
                  chunk_xs, self.offset, required_cache_size,
@@ -432,7 +497,7 @@ class PaddleASRConnectionHanddler:
                  self.conformer_cnn_cache)
             outputs.append(y)
-            # update the offset
+            # update the global offset, in decoding frame unit
             self.offset += y.shape[1]
         ys = paddle.cat(outputs, 1)
@@ -445,12 +510,15 @@ class PaddleASRConnectionHanddler:
         ctc_probs = self.model.ctc.log_softmax(ys)  # (1, maxlen, vocab_size)
         ctc_probs = ctc_probs.squeeze(0)
+        # advance decoding
         self.searcher.search(ctc_probs, self.cached_feat.place)
+        # get one best hyps
         self.hyps = self.searcher.get_one_best_hyps()
         assert self.cached_feat.shape[0] == 1
         assert end >= cached_feature_num
+        # advance cache of feat
         self.cached_feat = self.cached_feat[0, end -
                                             cached_feature_num:, :].unsqueeze(0)
         assert len(
@@ -462,50 +530,81 @@ class PaddleASRConnectionHanddler:
         )
     def update_result(self):
+        """Conformer/Transformer hyps to result.
+        """
         logger.info("update the final result")
         hyps = self.hyps
+        # output results and tokenids
         self.result_transcripts = [
             self.text_feature.defeaturize(hyp) for hyp in hyps
         ]
         self.result_tokenids = [hyp for hyp in hyps]

     def get_result(self):
+        """return partial/ending asr result.
+
+        Returns:
+            str: one best result of partial/ending.
+        """
         if len(self.result_transcripts) > 0:
             return self.result_transcripts[0]
         else:
             return ''

     def get_word_time_stamp(self):
+        """return token timestamp result.
+
+        Returns:
+            list: List of ('w':token, 'bg':time, 'ed':time)
+        """
         return self.word_time_stamp
     @paddle.no_grad()
     def rescoring(self):
-        if "deepspeech2online" in self.model_type or "deepspeech2offline" in self.model_type:
+        """Second-Pass Decoding,
+        only for conformer and transformer model.
+        """
+        if "deepspeech2" in self.model_type:
+            logger.info("deepspeech2 not support rescoring decoding.")
             return
-        logger.info("rescoring the final result")
         if "attention_rescoring" != self.ctc_decode_config.decoding_method:
+            logger.info(
+                f"decoding method not match: {self.ctc_decode_config.decoding_method}, need attention_rescoring"
+            )
             return

+        logger.info("rescoring the final result")
+        # last decoding for last audio
         self.searcher.finalize_search()
+        # update beam search results
         self.update_result()

         beam_size = self.ctc_decode_config.beam_size
         hyps = self.searcher.get_hyps()
         if hyps is None or len(hyps) == 0:
+            logger.info("No Hyps!")
             return

+        # rescore by decoder post probability
         # assert len(hyps) == beam_size
+        # list of Tensor
         hyp_list = []
         for hyp in hyps:
             hyp_content = hyp[0]
             # Prevent the hyp is empty
             if len(hyp_content) == 0:
                 hyp_content = (self.model.ctc.blank_id, )
             hyp_content = paddle.to_tensor(
                 hyp_content, place=self.device, dtype=paddle.long)
             hyp_list.append(hyp_content)
-        hyps_pad = pad_sequence(hyp_list, True, self.model.ignore_id)
+        hyps_pad = pad_sequence(
+            hyp_list, batch_first=True, padding_value=self.model.ignore_id)
         hyps_lens = paddle.to_tensor(
             [len(hyp[0]) for hyp in hyps], place=self.device,
             dtype=paddle.long)  # (beam_size,)
@@ -531,10 +630,12 @@ class PaddleASRConnectionHanddler:
             score = 0.0
             for j, w in enumerate(hyp[0]):
                 score += decoder_out[i][j][w]
             # last decoder output token is `eos`, for laste decoder input token.
             score += decoder_out[i][len(hyp[0])][self.model.eos]
             # add ctc score (which in ln domain)
             score += hyp[1] * self.ctc_decode_config.ctc_weight
             if score > best_score:
                 best_score = score
                 best_index = i
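The selection loop above implements standard attention rescoring: each CTC prefix-beam hypothesis is re-scored with the attention decoder and interpolated with its CTC score, and the best-scoring index wins. A toy stand-alone version (`hyps` is a list of `(token_ids, ctc_score)` pairs, `decoder_out` a per-hypothesis matrix of log-probabilities):

```python
def pick_best(hyps, decoder_out, eos_id, ctc_weight=0.5):
    best_score, best_index = -float('inf'), 0
    for i, (tokens, ctc_score) in enumerate(hyps):
        # attention-decoder score: log-prob of every token in the hypothesis
        score = sum(decoder_out[i][j][w] for j, w in enumerate(tokens))
        # plus the log-prob of emitting <eos> after the last token
        score += decoder_out[i][len(tokens)][eos_id]
        # interpolate with the CTC prefix-beam-search score (ln domain)
        score += ctc_weight * ctc_score
        if score > best_score:
            best_score, best_index = score, i
    return best_index, best_score
```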
@@ -542,43 +643,52 @@ class PaddleASRConnectionHanddler:
         # update the one best result
         # hyps stored the beam results and each fields is:
-        logger.info(f"best index: {best_index}")
+        logger.info(f"best hyp index: {best_index}")
         # logger.info(f'best result: {hyps[best_index]}')
         # the field of the hyps is:
+        ## asr results
         # hyps[0][0]: the sentence word-id in the vocab with a tuple
         # hyps[0][1]: the sentence decoding probability with all paths
+        ## timestamp
         # hyps[0][2]: viterbi_blank ending probability
-        # hyps[0][3]: viterbi_non_blank probability
+        # hyps[0][3]: viterbi_non_blank dending probability
         # hyps[0][4]: current_token_prob,
-        # hyps[0][5]: times_viterbi_blank,
-        # hyps[0][6]: times_titerbi_non_blank
+        # hyps[0][5]: times_viterbi_blank ending timestamp,
+        # hyps[0][6]: times_titerbi_non_blank encding timestamp.
         self.hyps = [hyps[best_index][0]]
+        logger.info(f"best hyp ids: {self.hyps}")

         # update the hyps time stamp
         self.time_stamp = hyps[best_index][5] if hyps[best_index][2] > hyps[
             best_index][3] else hyps[best_index][6]
         logger.info(f"time stamp: {self.time_stamp}")

+        # update one best result
         self.update_result()

         # update each word start and end time stamp
-        frame_shift_in_ms = self.model.encoder.embed.subsampling_rate * self.n_shift / self.sample_rate
-        logger.info(f"frame shift ms: {frame_shift_in_ms}")
+        # decoding frame to audio frame
+        frame_shift = self.model.encoder.embed.subsampling_rate
+        frame_shift_in_sec = frame_shift * (self.n_shift / self.sample_rate)
+        logger.info(f"frame shift sec: {frame_shift_in_sec}")

         word_time_stamp = []
         for idx, _ in enumerate(self.time_stamp):
             start = (self.time_stamp[idx - 1] + self.time_stamp[idx]
                      ) / 2.0 if idx > 0 else 0
-            start = start * frame_shift_in_ms
+            start = start * frame_shift_in_sec
             end = (self.time_stamp[idx] + self.time_stamp[idx + 1]
                    ) / 2.0 if idx < len(self.time_stamp) - 1 else self.offset
-            end = end * frame_shift_in_ms
+            end = end * frame_shift_in_sec
             word_time_stamp.append({
                 "w": self.result_transcripts[0][idx],
                 "bg": start,
                 "ed": end
             })
-            # logger.info(f"{self.result_transcripts[0][idx]}, start: {start}, end: {end}")
+            # logger.info(f"{word_time_stamp[-1]}")

         self.word_time_stamp = word_time_stamp
         logger.info(f"word time stamp: {self.word_time_stamp}")
@@ -610,6 +720,7 @@ class ASRServerExecutor(ASRExecutor):
         self.sample_rate = sample_rate
         sample_rate_str = '16k' if sample_rate == 16000 else '8k'
         tag = model_type + '-' + lang + '-' + sample_rate_str

         if cfg_path is None or am_model is None or am_params is None:
             logger.info(f"Load the pretrained model, tag = {tag}")
             res_path = self._get_pretrained_path(tag)  # wenetspeech_zh
@@ -639,7 +750,7 @@ class ASRServerExecutor(ASRExecutor):
         self.config.merge_from_file(self.cfg_path)

         with UpdateConfig(self.config):
-            if "deepspeech2online" in model_type or "deepspeech2offline" in model_type:
+            if "deepspeech2" in model_type:
                 from paddlespeech.s2t.io.collator import SpeechCollator
                 self.vocab = self.config.vocab_filepath
                 self.config.decode.lang_model_path = os.path.join(
@@ -655,6 +766,7 @@ class ASRServerExecutor(ASRExecutor):
                 self.download_lm(
                     lm_url,
                     os.path.dirname(self.config.decode.lang_model_path), lm_md5)
elif "conformer" in model_type or "transformer" in model_type: elif "conformer" in model_type or "transformer" in model_type:
logger.info("start to create the stream conformer asr engine") logger.info("start to create the stream conformer asr engine")
if self.config.spm_model_prefix: if self.config.spm_model_prefix:
...@@ -682,7 +794,8 @@ class ASRServerExecutor(ASRExecutor): ...@@ -682,7 +794,8 @@ class ASRServerExecutor(ASRExecutor):
], f"we only support ctc_prefix_beam_search and attention_rescoring dedoding method, current decoding method is {self.config.decode.decoding_method}" ], f"we only support ctc_prefix_beam_search and attention_rescoring dedoding method, current decoding method is {self.config.decode.decoding_method}"
else: else:
raise Exception("wrong type") raise Exception("wrong type")
if "deepspeech2online" in model_type or "deepspeech2offline" in model_type:
if "deepspeech2" in model_type:
# AM predictor # AM predictor
logger.info("ASR engine start to init the am predictor") logger.info("ASR engine start to init the am predictor")
self.am_predictor_conf = am_predictor_conf self.am_predictor_conf = am_predictor_conf
...@@ -719,6 +832,7 @@ class ASRServerExecutor(ASRExecutor): ...@@ -719,6 +832,7 @@ class ASRServerExecutor(ASRExecutor):
self.chunk_state_c_box = np.zeros( self.chunk_state_c_box = np.zeros(
(self.config.num_rnn_layers, 1, self.config.rnn_layer_size), (self.config.num_rnn_layers, 1, self.config.rnn_layer_size),
dtype=float32) dtype=float32)
elif "conformer" in model_type or "transformer" in model_type: elif "conformer" in model_type or "transformer" in model_type:
model_name = model_type[:model_type.rindex( model_name = model_type[:model_type.rindex(
'_')] # model_type: {model_name}_{dataset} '_')] # model_type: {model_name}_{dataset}
...@@ -737,277 +851,14 @@ class ASRServerExecutor(ASRExecutor): ...@@ -737,277 +851,14 @@ class ASRServerExecutor(ASRExecutor):
# update the ctc decoding # update the ctc decoding
self.searcher = CTCPrefixBeamSearch(self.config.decode) self.searcher = CTCPrefixBeamSearch(self.config.decode)
self.transformer_decode_reset() self.transformer_decode_reset()
return True
def reset_decoder_and_chunk(self):
"""reset decoder and chunk state for an new audio
"""
if "deepspeech2online" in self.model_type or "deepspeech2offline" in self.model_type:
self.decoder.reset_decoder(batch_size=1)
# init state box, for new audio request
self.chunk_state_h_box = np.zeros(
(self.config.num_rnn_layers, 1, self.config.rnn_layer_size),
dtype=float32)
self.chunk_state_c_box = np.zeros(
(self.config.num_rnn_layers, 1, self.config.rnn_layer_size),
dtype=float32)
elif "conformer" in self.model_type or "transformer" in self.model_type:
self.transformer_decode_reset()
def decode_one_chunk(self, x_chunk, x_chunk_lens, model_type: str):
"""decode one chunk
Args:
x_chunk (numpy.array): shape[B, T, D]
x_chunk_lens (numpy.array): shape[B]
model_type (str): online model type
Returns:
str: one best result
"""
logger.info("start to decoce chunk by chunk")
if "deepspeech2online" in model_type:
input_names = self.am_predictor.get_input_names()
audio_handle = self.am_predictor.get_input_handle(input_names[0])
audio_len_handle = self.am_predictor.get_input_handle(
input_names[1])
h_box_handle = self.am_predictor.get_input_handle(input_names[2])
c_box_handle = self.am_predictor.get_input_handle(input_names[3])
audio_handle.reshape(x_chunk.shape)
audio_handle.copy_from_cpu(x_chunk)
audio_len_handle.reshape(x_chunk_lens.shape)
audio_len_handle.copy_from_cpu(x_chunk_lens)
h_box_handle.reshape(self.chunk_state_h_box.shape)
h_box_handle.copy_from_cpu(self.chunk_state_h_box)
c_box_handle.reshape(self.chunk_state_c_box.shape)
c_box_handle.copy_from_cpu(self.chunk_state_c_box)
output_names = self.am_predictor.get_output_names()
output_handle = self.am_predictor.get_output_handle(output_names[0])
output_lens_handle = self.am_predictor.get_output_handle(
output_names[1])
output_state_h_handle = self.am_predictor.get_output_handle(
output_names[2])
output_state_c_handle = self.am_predictor.get_output_handle(
output_names[3])
self.am_predictor.run()
output_chunk_probs = output_handle.copy_to_cpu()
output_chunk_lens = output_lens_handle.copy_to_cpu()
self.chunk_state_h_box = output_state_h_handle.copy_to_cpu()
self.chunk_state_c_box = output_state_c_handle.copy_to_cpu()
self.decoder.next(output_chunk_probs, output_chunk_lens)
trans_best, trans_beam = self.decoder.decode()
logger.info(f"decode one best result: {trans_best[0]}")
return trans_best[0]
elif "conformer" in model_type or "transformer" in model_type:
try:
logger.info(
f"we will use the transformer like model : {self.model_type}"
)
self.advanced_decoding(x_chunk, x_chunk_lens)
self.update_result()
return self.result_transcripts[0]
except Exception as e:
logger.exception(e)
         else:
-            raise Exception("invalid model name")
+            raise ValueError(f"Not support: {model_type}")
+        return True

     def advanced_decoding(self, xs: paddle.Tensor, x_chunk_lens):
logger.info("start to decode with advanced_decoding method")
encoder_out, encoder_mask = self.encoder_forward(xs)
ctc_probs = self.model.ctc.log_softmax(
encoder_out) # (1, maxlen, vocab_size)
ctc_probs = ctc_probs.squeeze(0)
self.searcher.search(ctc_probs, xs.place)
# update the one best result
self.hyps = self.searcher.get_one_best_hyps()
# now we supprot ctc_prefix_beam_search and attention_rescoring
if "attention_rescoring" in self.config.decode.decoding_method:
self.rescoring(encoder_out, xs.place)
def encoder_forward(self, xs):
logger.info("get the model out from the feat")
cfg = self.config.decode
decoding_chunk_size = cfg.decoding_chunk_size
num_decoding_left_chunks = cfg.num_decoding_left_chunks
assert decoding_chunk_size > 0
subsampling = self.model.encoder.embed.subsampling_rate
context = self.model.encoder.embed.right_context + 1
stride = subsampling * decoding_chunk_size
# decoding window for model
decoding_window = (decoding_chunk_size - 1) * subsampling + context
num_frames = xs.shape[1]
required_cache_size = decoding_chunk_size * num_decoding_left_chunks
logger.info("start to do model forward")
outputs = []
# num_frames - context + 1 ensure that current frame can get context window
for cur in range(0, num_frames - context + 1, stride):
end = min(cur + decoding_window, num_frames)
chunk_xs = xs[:, cur:end, :]
(y, self.subsampling_cache, self.elayers_output_cache,
self.conformer_cnn_cache) = self.model.encoder.forward_chunk(
chunk_xs, self.offset, required_cache_size,
self.subsampling_cache, self.elayers_output_cache,
self.conformer_cnn_cache)
outputs.append(y)
self.offset += y.shape[1]
ys = paddle.cat(outputs, 1)
masks = paddle.ones([1, ys.shape[1]], dtype=paddle.bool)
masks = masks.unsqueeze(1)
return ys, masks
def rescoring(self, encoder_out, device):
logger.info("start to rescoring the hyps")
beam_size = self.config.decode.beam_size
hyps = self.searcher.get_hyps()
assert len(hyps) == beam_size
hyp_list = []
for hyp in hyps:
hyp_content = hyp[0]
# Prevent the hyp is empty
if len(hyp_content) == 0:
hyp_content = (self.model.ctc.blank_id, )
hyp_content = paddle.to_tensor(
hyp_content, place=device, dtype=paddle.long)
hyp_list.append(hyp_content)
hyps_pad = pad_sequence(hyp_list, True, self.model.ignore_id)
hyps_lens = paddle.to_tensor(
[len(hyp[0]) for hyp in hyps], place=device,
dtype=paddle.long) # (beam_size,)
hyps_pad, _ = add_sos_eos(hyps_pad, self.model.sos, self.model.eos,
self.model.ignore_id)
hyps_lens = hyps_lens + 1 # Add <sos> at begining
encoder_out = encoder_out.repeat(beam_size, 1, 1)
encoder_mask = paddle.ones(
(beam_size, 1, encoder_out.shape[1]), dtype=paddle.bool)
decoder_out, _ = self.model.decoder(
encoder_out, encoder_mask, hyps_pad,
hyps_lens) # (beam_size, max_hyps_len, vocab_size)
# ctc score in ln domain
decoder_out = paddle.nn.functional.log_softmax(decoder_out, axis=-1)
decoder_out = decoder_out.numpy()
# Only use decoder score for rescoring
best_score = -float('inf')
best_index = 0
# hyps is List[(Text=List[int], Score=float)], len(hyps)=beam_size
for i, hyp in enumerate(hyps):
score = 0.0
for j, w in enumerate(hyp[0]):
score += decoder_out[i][j][w]
# last decoder output token is `eos`, for laste decoder input token.
score += decoder_out[i][len(hyp[0])][self.model.eos]
# add ctc score (which in ln domain)
score += hyp[1] * self.config.decode.ctc_weight
if score > best_score:
best_score = score
best_index = i
# update the one best result
self.hyps = [hyps[best_index][0]]
return hyps[best_index][0]
def transformer_decode_reset(self):
self.subsampling_cache = None
self.elayers_output_cache = None
self.conformer_cnn_cache = None
self.offset = 0
# decoding reset
self.searcher.reset()
def update_result(self):
logger.info("update the final result")
hyps = self.hyps
self.result_transcripts = [
self.text_feature.defeaturize(hyp) for hyp in hyps
]
self.result_tokenids = [hyp for hyp in hyps]
def extract_feat(self, samples, sample_rate):
"""extract feat
Args:
samples (numpy.array): numpy.float32
sample_rate (int): sample rate
Returns:
x_chunk (numpy.array): shape[B, T, D]
x_chunk_lens (numpy.array): shape[B]
"""
if "deepspeech2online" in self.model_type:
# pcm16 -> pcm 32
samples = pcm2float(samples)
# read audio
speech_segment = SpeechSegment.from_pcm(
samples, sample_rate, transcript=" ")
# audio augment
self.collate_fn_test.augmentation.transform_audio(speech_segment)
# extract speech feature
spectrum, transcript_part = self.collate_fn_test._speech_featurizer.featurize(
speech_segment, self.collate_fn_test.keep_transcription_text)
# CMVN spectrum
if self.collate_fn_test._normalizer:
spectrum = self.collate_fn_test._normalizer.apply(spectrum)
# spectrum augment
audio = self.collate_fn_test.augmentation.transform_feature(
spectrum)
audio_len = audio.shape[0]
audio = paddle.to_tensor(audio, dtype='float32')
# audio_len = paddle.to_tensor(audio_len)
audio = paddle.unsqueeze(audio, axis=0)
x_chunk = audio.numpy()
x_chunk_lens = np.array([audio_len])
return x_chunk, x_chunk_lens
elif "conformer_online" in self.model_type:
if sample_rate != self.sample_rate:
logger.info(f"audio sample rate {sample_rate} is not match,"
"the model sample_rate is {self.sample_rate}")
logger.info(f"ASR Engine use the {self.model_type} to process")
logger.info("Create the preprocess instance")
preprocess_conf = self.config.preprocess_config
preprocess_args = {"train": False}
preprocessing = Transformation(preprocess_conf)
logger.info("Read the audio file")
logger.info(f"audio shape: {samples.shape}")
# fbank
x_chunk = preprocessing(samples, **preprocess_args)
x_chunk_lens = paddle.to_tensor(x_chunk.shape[0])
x_chunk = paddle.to_tensor(
x_chunk, dtype="float32").unsqueeze(axis=0)
logger.info(
f"process the audio feature success, feat shape: {x_chunk.shape}"
)
return x_chunk, x_chunk_lens
 class ASREngine(BaseEngine):
-    """ASR server engine
+    """ASR server resource

     Args:
         metaclass: Defaults to Singleton.
@@ -1015,7 +866,7 @@ class ASREngine(BaseEngine):
     def __init__(self):
         super(ASREngine, self).__init__()
-        logger.info("create the online asr engine instance")
+        logger.info("create the online asr engine resource instance")
     def init(self, config: dict) -> bool:
         """init engine resource
@@ -1026,17 +877,12 @@ class ASREngine(BaseEngine):
         Returns:
             bool: init failed or success
         """
-        self.input = None
-        self.output = ""
-        self.executor = ASRServerExecutor()
         self.config = config
+        self.executor = ASRServerExecutor()

         try:
-            if self.config.get("device", None):
-                self.device = self.config.device
-            else:
-                self.device = paddle.get_device()
-            logger.info(f"paddlespeech_server set the device: {self.device}")
-            paddle.set_device(self.device)
+            default_dev = paddle.get_device()
+            paddle.set_device(self.config.get("device", default_dev))
         except BaseException as e:
             logger.error(
                 f"Set device failed, please check if device '{self.device}' is already used and the parameter 'device' in the yaml file"
@@ -1045,6 +891,8 @@ class ASREngine(BaseEngine):
                 "If all GPU or XPU is used, you can set the server to 'cpu'")
             sys.exit(-1)

+        logger.info(f"paddlespeech_server set the device: {self.device}")
+
         if not self.executor._init_from_path(
                 model_type=self.config.model_type,
                 am_model=self.config.am_model,
@@ -1062,42 +910,11 @@ class ASREngine(BaseEngine):
         logger.info("Initialize ASR server engine successfully.")
         return True
-    def preprocess(self,
-                   samples,
-                   sample_rate,
-                   model_type="deepspeech2online_aishell-zh-16k"):
-        """preprocess
-
-        Args:
-            samples (numpy.array): numpy.float32
-            sample_rate (int): sample rate
-
-        Returns:
-            x_chunk (numpy.array): shape[B, T, D]
-            x_chunk_lens (numpy.array): shape[B]
-        """
-        # if "deepspeech" in model_type:
-        x_chunk, x_chunk_lens = self.executor.extract_feat(samples, sample_rate)
-        return x_chunk, x_chunk_lens
+    def preprocess(self, *args, **kwargs):
+        raise NotImplementedError("Online not using this.")

-    def run(self, x_chunk, x_chunk_lens, decoder_chunk_size=1):
-        """run online engine
-
-        Args:
-            x_chunk (numpy.array): shape[B, T, D]
-            x_chunk_lens (numpy.array): shape[B]
-            decoder_chunk_size(int)
-        """
-        self.output = self.executor.decode_one_chunk(x_chunk, x_chunk_lens,
-                                                     self.config.model_type)
+    def run(self, *args, **kwargs):
+        raise NotImplementedError("Online not using this.")

     def postprocess(self):
-        """postprocess
-        """
-        return self.output
-
-    def reset(self):
-        """reset engine decoder and inference state
-        """
-        self.executor.reset_decoder_and_chunk()
-        self.output = ""
+        raise NotImplementedError("Online not using this.")
@@ -17,12 +17,12 @@ from typing import List
 from fastapi import APIRouter

 from paddlespeech.cli.log import logger
+from paddlespeech.server.restful.acs_api import router as acs_router
 from paddlespeech.server.restful.asr_api import router as asr_router
 from paddlespeech.server.restful.cls_api import router as cls_router
 from paddlespeech.server.restful.text_api import router as text_router
 from paddlespeech.server.restful.tts_api import router as tts_router
 from paddlespeech.server.restful.vector_api import router as vec_router
-from paddlespeech.server.restful.acs_api import router as acs_router

 _router = APIRouter()
...
@@ -248,7 +248,7 @@ class ASRHttpHandler:
         }

         res = requests.post(url=self.url, data=json.dumps(data))

         return res.json()
...
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

 class Frame(object):
     """Represents a "frame" of audio data."""
@@ -45,7 +46,7 @@ class ChunkBuffer(object):
         self.shift_ms = shift_ms
         self.sample_rate = sample_rate
         self.sample_width = sample_width  # int16 = 2; float32 = 4
         self.window_sec = float((self.window_n - 1) * self.shift_ms +
                                 self.window_ms) / 1000.0
         self.shift_sec = float(self.shift_n * self.shift_ms / 1000.0)
@@ -77,8 +78,8 @@ class ChunkBuffer(object):
         offset = 0
         while offset + self.window_bytes <= len(audio):
-            yield Frame(audio[offset:offset + self.window_bytes], self.timestamp,
-                        self.window_sec)
+            yield Frame(audio[offset:offset + self.window_bytes],
+                        self.timestamp, self.window_sec)
             self.timestamp += self.shift_sec
             offset += self.shift_bytes
...
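A simplified stand-alone version of the `ChunkBuffer` frame generator above, showing how millisecond window/shift settings translate into byte offsets over an int16 pcm buffer (assumed defaults: 16 kHz, 25 ms window, 10 ms shift; the real class also tracks window/shift frame counts):

```python
class MiniChunkBuffer:
    def __init__(self, window_ms=25, shift_ms=10, sample_rate=16000, sample_width=2):
        # samples per window/shift, times bytes per sample
        self.window_bytes = int(window_ms / 1000 * sample_rate) * sample_width
        self.shift_bytes = int(shift_ms / 1000 * sample_rate) * sample_width
        self.shift_sec = shift_ms / 1000.0
        self.timestamp = 0.0

    def frame_generator(self, audio: bytes):
        offset = 0
        while offset + self.window_bytes <= len(audio):
            yield audio[offset:offset + self.window_bytes], self.timestamp
            self.timestamp += self.shift_sec
            offset += self.shift_bytes

buf = MiniChunkBuffer()
frames = list(buf.frame_generator(b"\x00" * 16000))  # 0.5 s of silence
print(len(frames), buf.window_bytes, buf.shift_bytes)  # 48 800 320
```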
@@ -176,7 +176,10 @@ def main():
     parser.add_argument(
         "--ngpu", type=int, default=1, help="if ngpu == 0, use cpu or xpu.")
     parser.add_argument(
-        "--nxpu", type=int, default=0, help="if nxpu == 0 and ngpu == 0, use cpu.")
+        "--nxpu",
+        type=int,
+        default=0,
+        help="if nxpu == 0 and ngpu == 0, use cpu.")

     args, _ = parser.parse_known_args()
...
@@ -188,7 +188,10 @@ def main():
     parser.add_argument("--dev-metadata", type=str, help="dev data.")
     parser.add_argument("--output-dir", type=str, help="output dir.")
     parser.add_argument(
-        "--nxpu", type=int, default=0, help="if nxpu == 0 and ngpu == 0, use cpu.")
+        "--nxpu",
+        type=int,
+        default=0,
+        help="if nxpu == 0 and ngpu == 0, use cpu.")
     parser.add_argument(
         "--ngpu", type=int, default=1, help="if ngpu == 0, use cpu or xpu")
...
@@ -36,4 +36,4 @@ def repeat(N, fn):
     Returns:
         MultiSequential: Repeated model instance.
     """
-    return MultiSequential(*[fn(n) for n in range(N)])
+    return MultiSequential(* [fn(n) for n in range(N)])
@@ -98,7 +98,6 @@ requirements = {
 }

 def check_call(cmd: str, shell=False, executable=None):
     try:
         sp.check_call(
@@ -112,12 +111,13 @@ def check_call(cmd: str, shell=False, executable=None):
             file=sys.stderr)
         raise e

 def check_output(cmd: str, shell=False):
     try:
         out_bytes = sp.check_output(cmd.split())
     except sp.CalledProcessError as e:
         out_bytes = e.output  # Output generated before error
         code = e.returncode  # Return code
         print(
             f"{__file__}:{inspect.currentframe().f_lineno}: CMD: {cmd}, Error:",
             out_bytes,
@@ -146,6 +146,7 @@ def _remove(files: str):
     for f in files:
         f.unlink()

 ################################# Install ##################################
@@ -308,6 +309,5 @@ setup_info = dict(
     ]
 })

 with version_info():
     setup(**setup_info)
@@ -20,7 +20,6 @@ of each audio file in the data set.
 """
 import argparse
 import codecs
-import json
 import os
 from pathlib import Path
@@ -89,7 +88,7 @@ def create_manifest(data_dir, manifest_path_prefix):
             duration = float(len(audio_data) / samplerate)
             text = transcript_dict[audio_id]
             json_lines.append(audio_path)
-            reference_lines.append(str(total_num+1) + "\t" + text)
+            reference_lines.append(str(total_num + 1) + "\t" + text)

             total_sec += duration
             total_text += len(text)
@@ -106,6 +105,7 @@ def create_manifest(data_dir, manifest_path_prefix):
     manifest_dir = os.path.dirname(manifest_path_prefix)

 def prepare_dataset(url, md5sum, target_dir, manifest_path=None):
     """Download, unpack and create manifest file."""
     data_dir = os.path.join(target_dir, 'data_aishell')
...