PaddlePaddle / DeepSpeech
Commit d2640c14
Authored Apr 18, 2022 by xiongxinlei

add mult sesssion process, test=doc

Parent: 97d31f9a
Showing 1 changed file with 189 additions and 1 deletion:

paddlespeech/server/engine/asr/online/asr_engine.py (+189 -1)
paddlespeech/server/engine/asr/online/asr_engine.py @ d2640c14

@@ -78,6 +78,194 @@ pretrained_models = {
    },
}

(The whole PaddleASRConnectionHanddler class below is new in this commit.)
# ASR server connection process class
class PaddleASRConnectionHanddler:
    def __init__(self, asr_engine):
        super().__init__()
        self.config = asr_engine.config
        self.model_config = asr_engine.executor.config
        self.asr_engine = asr_engine

        self.init()
        self.reset()
    def init(self):
        self.model_type = self.asr_engine.executor.model_type
        if "deepspeech2online" in self.model_type or "deepspeech2offline" in self.model_type:
            pass
        elif "conformer" in self.model_type or "transformer" in self.model_type or "wenetspeech" in self.model_type:
            self.sample_rate = self.asr_engine.executor.sample_rate

            # acoustic model
            self.model = self.asr_engine.executor.model

            # tokens to text
            self.text_feature = self.asr_engine.executor.text_feature

            # ctc decoding
            self.ctc_decode_config = self.asr_engine.executor.config.decode
            self.searcher = CTCPrefixBeamSearch(self.ctc_decode_config)

            # extract fbank
            self.preprocess_conf = self.model_config.preprocess_config
            self.preprocess_args = {"train": False}
            self.preprocessing = Transformation(self.preprocess_conf)
            self.win_length = self.preprocess_conf.process[0]['win_length']
            self.n_shift = self.preprocess_conf.process[0]['n_shift']
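init() reads the fbank window parameters from the first stage of the preprocess pipeline. A minimal sketch of an object with that shape; apart from the win_length and n_shift keys, every field name and value here is an illustrative assumption, not something taken from this diff:

from types import SimpleNamespace

# Illustrative stand-in for model_config.preprocess_config; shaped only so that
# .process[0]['win_length'] and .process[0]['n_shift'] resolve as in init().
preprocess_conf = SimpleNamespace(process=[{
    "type": "fbank_kaldi",   # assumed stage name
    "fs": 16000,             # assumed 16 kHz input
    "n_mels": 80,            # assumed filterbank size
    "win_length": 400,       # samples per analysis window (25 ms at 16 kHz)
    "n_shift": 160,          # samples per hop (10 ms at 16 kHz)
}])

win_length = preprocess_conf.process[0]['win_length']   # 400
n_shift = preprocess_conf.process[0]['n_shift']         # 160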
    def extract_feat(self, samples):
        if "deepspeech2online" in self.model_type:
            pass
        elif "conformer2online" in self.model_type:
            logger.info("Online ASR extract the feat")
            samples = np.frombuffer(samples, dtype=np.int16)
            assert samples.ndim == 1

            logger.info(f"This package receive {samples.shape[0]} pcm data")
            self.num_samples += samples.shape[0]

            # self.remained_wav stores all the samples,
            # include the original remained_wav and this package samples
            if self.remained_wav is None:
                self.remained_wav = samples
            else:
                assert self.remained_wav.ndim == 1
                self.remained_wav = np.concatenate([self.remained_wav, samples])
            logger.info(f"The connection remain the audio samples: {self.remained_wav.shape}")

            if len(self.remained_wav) < self.win_length:
                return 0

            # fbank
            x_chunk = self.preprocessing(self.remained_wav, **self.preprocess_args)
            x_chunk = paddle.to_tensor(x_chunk, dtype="float32").unsqueeze(axis=0)
            if self.cached_feat is None:
                self.cached_feat = x_chunk
            else:
                self.cached_feat = paddle.concat([self.cached_feat, x_chunk], axis=1)

            num_frames = x_chunk.shape[1]
            self.num_frames += num_frames
            self.remained_wav = self.remained_wav[self.n_shift * num_frames:]

            logger.info(f"process the audio feature success, the connection feat shape: {self.cached_feat.shape}")
            logger.info(f"After extract feat, the connection remain the audio samples: {self.remained_wav.shape}")
        # logger.info(f"accumulate samples: {self.num_samples}")
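The tail-keeping arithmetic in extract_feat consumes exactly n_shift * num_frames samples, so the kept remainder is always shorter than one full window and frame boundaries stay aligned across packets. A worked example with assumed framing values (at runtime they come from the preprocess config, and the engine reads num_frames off x_chunk.shape[1] rather than computing it):

win_length = 400   # assumed: 25 ms window at 16 kHz
n_shift = 160      # assumed: 10 ms hop

remained = 4000    # samples buffered on this connection
# standard STFT frame count for a buffer of this length
num_frames = 1 + (remained - win_length) // n_shift   # -> 23
consumed = n_shift * num_frames                       # -> 3680 samples dropped
left = remained - consumed                            # -> 320 samples kept
assert left < win_length   # too short for another frame; wait for the next packet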
    def reset(self):
        self.subsampling_cache = None
        self.elayers_output_cache = None
        self.conformer_cnn_cache = None
        self.encoder_outs_ = None
        self.cached_feat = None
        self.remained_wav = None
        self.offset = 0
        self.num_samples = 0
        self.num_frames = 0
        self.global_frame_offset = 0
        self.result = []
    def decode(self, is_finished=False):
        if "deepspeech2online" in self.model_type:
            pass
        elif "conformer" in self.model_type or "transformer" in self.model_type:
            try:
                logger.info(f"we will use the transformer like model : {self.model_type}")
                self.advance_decoding(is_finished)
                # self.update_result()
                # return self.result_transcripts[0]
            except Exception as e:
                logger.exception(e)
        else:
            raise Exception("invalid model name")
    def advance_decoding(self, is_finished=False):
        logger.info("start to decode with advanced_decoding method")
        cfg = self.ctc_decode_config
        decoding_chunk_size = cfg.decoding_chunk_size
        num_decoding_left_chunks = cfg.num_decoding_left_chunks
        assert decoding_chunk_size > 0
        subsampling = self.model.encoder.embed.subsampling_rate
        context = self.model.encoder.embed.right_context + 1
        stride = subsampling * decoding_chunk_size

        # decoding window for model
        decoding_window = (decoding_chunk_size - 1) * subsampling + context
        num_frames = self.cached_feat.shape[1]
        logger.info(f"Required decoding window {decoding_window} frames, and the connection has {num_frames} frames")

        # the cached feat must be larger decoding_window
        if num_frames < decoding_window and not is_finished:
            return None, None

        # logger.info("start to do model forward")
        # required_cache_size = decoding_chunk_size * num_decoding_left_chunks
        # outputs = []

        # # num_frames - context + 1 ensure that current frame can get context window
        # if is_finished:
        #     # if get the finished chunk, we need process the last context
        #     left_frames = context
        # else:
        #     # we only process decoding_window frames for one chunk
        #     left_frames = decoding_window
        # logger.info(f"")
        # end = None
        # for cur in range(0, num_frames - left_frames + 1, stride):
        #     end = min(cur + decoding_window, num_frames)
        #     print(f"cur: {cur}, end: {end}")
        #     chunk_xs = self.cached_feat[:, cur:end, :]
        #     (y, self.subsampling_cache, self.elayers_output_cache,
        #      self.conformer_cnn_cache) = self.model.encoder.forward_chunk(
        #          chunk_xs, self.offset, required_cache_size,
        #          self.subsampling_cache, self.elayers_output_cache,
        #          self.conformer_cnn_cache)
        #     outputs.append(y)

        # update the offset
        # self.offset += y.shape[1]
        # self.cached_feat = self.cached_feat[end:]

        # ys = paddle.cat(outputs, 1)
        # masks = paddle.ones([1, ys.shape[1]], dtype=paddle.bool)
        # masks = masks.unsqueeze(1)

        # # get the ctc probs
        # ctc_probs = self.model.ctc.log_softmax(ys)  # (1, maxlen, vocab_size)
        # ctc_probs = ctc_probs.squeeze(0)

        # # self.searcher.search(xs, ctc_probs, xs.place)
        # self.searcher.search(None, ctc_probs, self.cached_feat.place)
        # self.hyps = self.searcher.get_one_best_hyps()

        # ys for rescoring
        # return ys, masks
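The decoding_window formula maps encoder-side chunking back to input feature frames: a chunk of decoding_chunk_size subsampled frames needs (decoding_chunk_size - 1) * subsampling + context raw frames, where context covers the right context of the subsampling front end. A quick check with hypothetical values for a typical streaming conformer (none of these numbers appear in this diff):

decoding_chunk_size = 16   # hypothetical: encoder frames decoded per step
subsampling = 4            # hypothetical: model.encoder.embed.subsampling_rate
right_context = 6          # hypothetical: model.encoder.embed.right_context
context = right_context + 1                                          # 7

stride = subsampling * decoding_chunk_size                           # 64
decoding_window = (decoding_chunk_size - 1) * subsampling + context  # 67
# advance_decoding() returns early until cached_feat holds at least 67
# frames, unless is_finished forces the final flush.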
    def update_result(self):
        logger.info("update the final result")
        hyps = self.hyps
        self.result_transcripts = [
            self.text_feature.defeaturize(hyp) for hyp in hyps
        ]
        self.result_tokenids = [hyp for hyp in hyps]
    def rescoring(self):
        pass
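Putting the pieces together, one handler instance serves one client connection. A hypothetical driver loop, assuming an asr_engine whose executor carries the attributes used above (the real server wiring is outside this diff):

handler = PaddleASRConnectionHanddler(asr_engine)

for pcm_bytes in incoming_audio_packets():   # hypothetical source of int16 PCM
    handler.extract_feat(pcm_bytes)          # buffer samples, extend cached_feat
    handler.decode(is_finished=False)        # decodes once a full window is cached

handler.decode(is_finished=True)             # flush the final partial window
handler.rescoring()                          # still a stub in this commit
handler.reset()                              # clear caches before reuse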
class ASRServerExecutor(ASRExecutor):
    def __init__(self):
...

@@ -492,7 +680,7 @@ class ASRServerExecutor(ASRExecutor):
        if sample_rate != self.sample_rate:
            logger.info(f"audio sample rate {sample_rate} is not match,"
                        "the model sample_rate is {self.sample_rate}")
-       logger.info("ASR Engine use the {self.model_type} to process")
+       logger.info(f"ASR Engine use the {self.model_type} to process")
        logger.info("Create the preprocess instance")
        preprocess_conf = self.config.preprocess_config
        preprocess_args = {"train": False}
...
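The only change in this hunk is the added f prefix, so {self.model_type} is now interpolated instead of logged literally. The warning just above still has the same bug: in an implicitly concatenated string pair, only the piece that carries the f prefix is formatted, so the second line prints {self.sample_rate} verbatim:

sample_rate = 8000
msg = (f"audio sample rate {sample_rate} is not match,"
       "the model sample_rate is {self.sample_rate}")   # no f: not interpolated
print(msg)
# audio sample rate 8000 is not match,the model sample_rate is {self.sample_rate}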