PaddlePaddle / DeepSpeech

Commit f07f57a3 (unverified), authored May 25, 2022 by Hui Zhang, committed via GitHub on May 25, 2022

Merge pull request #1945 from PaddlePaddle/asr_line

[server][asr] refactor asr streaming server and remove useless code

Parents: 8f8239ad, c15278ed
Showing 16 changed files with 247 additions and 402 deletions (+247 -402)
demos/audio_searching/src/operations/load.py           +2    -3
demos/streaming_asr_server/websocket_client.py         +0    -2
docs/source/asr/PPASR_cn.md                            +0    -2
paddlespeech/cli/utils.py                              +1    -1
paddlespeech/s2t/io/sampler.py                         +1    -1
paddlespeech/server/engine/acs/__init__.py             +13   -0
paddlespeech/server/engine/acs/python/__init__.py      +13   -0
paddlespeech/server/engine/asr/online/asr_engine.py    +196  -379
paddlespeech/server/restful/api.py                     +1    -1
paddlespeech/server/utils/audio_handler.py             +1    -1
paddlespeech/server/utils/buffer.py                    +4    -3
paddlespeech/t2s/exps/speedyspeech/synthesize_e2e.py   +4    -1
paddlespeech/t2s/exps/speedyspeech/train.py            +4    -1
paddlespeech/t2s/modules/transformer/repeat.py         +1    -1
setup.py                                               +4    -4
tests/unit/cli/aishell_test_prepare.py                 +2    -2
demos/audio_searching/src/operations/load.py (+2 -3)

@@ -26,9 +26,8 @@ def get_audios(path):
    """
    supported_formats = [".wav", ".mp3", ".ogg", ".flac", ".m4a"]
    return [
        item for sublist in [[os.path.join(dir, file) for file in files]
                             for dir, _, files in list(os.walk(path))]
        for item in sublist if os.path.splitext(item)[1] in supported_formats
    ]
(the list comprehension is only re-wrapped in this hunk; the expression itself is unchanged)
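The nested comprehension above flattens the output of os.walk and keeps only files whose extension is in supported_formats. A minimal standalone sketch of the same idea (the function name and variable names here are illustrative, not part of the diff):

    import os

    SUPPORTED_FORMATS = [".wav", ".mp3", ".ogg", ".flac", ".m4a"]

    def list_audio_files(root):
        # Walk the tree, join directory and file name, and keep supported extensions.
        paths = []
        for dirpath, _, filenames in os.walk(root):
            for name in filenames:
                path = os.path.join(dirpath, name)
                if os.path.splitext(path)[1] in SUPPORTED_FORMATS:
                    paths.append(path)
        return paths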
demos/streaming_asr_server/websocket_client.py (+0 -2)

@@ -13,9 +13,7 @@
# limitations under the License.
#!/usr/bin/python
# -*- coding: UTF-8 -*-
# script for calc RTF: grep -rn RTF log.txt | awk '{print $NF}' | awk -F "=" '{sum += $NF} END {print "all time",sum, "audio num", NR, "RTF", sum/NR}'
import argparse
import asyncio
import codecs
docs/source/asr/PPASR_cn.md (+0 -2)

@@ -92,5 +92,3 @@ server 的 demo: [streaming_asr_server](https://github.com/PaddlePaddle/Paddle
## 4. 快速开始
关于如果使用 PP-ASR,可以看这里的 [install](https://github.com/PaddlePaddle/PaddleSpeech/blob/develop/docs/source/install_cn.md),其中提供了 **简单**、**中等**、**困难** 三种安装方式。如果想体验 paddlespeech 的推理功能,可以用 **简单** 安装方式。
(English gloss of the file content: "For how to use PP-ASR, see the [install](https://github.com/PaddlePaddle/PaddleSpeech/blob/develop/docs/source/install_cn.md) guide, which offers easy, medium, and hard installation modes. To try paddlespeech inference, the easy installation is enough.")
paddlespeech/cli/utils.py (+1 -1)

@@ -24,11 +24,11 @@ from typing import Any
from typing import Dict

import paddle
+import paddleaudio
import requests
import yaml
from paddle.framework import load
-import paddleaudio

from . import download
from .entry import commands

try:
paddlespeech/s2t/io/sampler.py (+1 -1)

@@ -51,7 +51,7 @@ def _batch_shuffle(indices, batch_size, epoch, clipped=False):
    """
    rng = np.random.RandomState(epoch)
    shift_len = rng.randint(0, batch_size - 1)
    batch_indices = list(zip(*[iter(indices[shift_len:])] * batch_size))
    rng.shuffle(batch_indices)
    batch_indices = [item for batch in batch_indices for item in batch]
    assert clipped is False
(the batch_indices = list(zip(...)) line changes only in formatting; the statement is identical before and after)
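The zip(*[iter(seq)] * batch_size) idiom used in _batch_shuffle groups a flat index list into fixed-size batches by handing the same iterator to zip batch_size times. A small illustration (the values below are chosen only for demonstration):

    indices = list(range(10))
    batch_size = 3
    shift_len = 1

    # The same iterator object is consumed batch_size times per zip step,
    # so consecutive elements end up in the same tuple; a tail that does
    # not fill a whole batch is silently dropped by zip.
    batches = list(zip(*[iter(indices[shift_len:])] * batch_size))
    print(batches)  # [(1, 2, 3), (4, 5, 6), (7, 8, 9)]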
paddlespeech/server/engine/acs/__init__.py (+13 -0)
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
paddlespeech/server/engine/acs/python/__init__.py (+13 -0)
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
paddlespeech/server/engine/asr/online/asr_engine.py (+196 -379)

@@ -38,7 +38,7 @@ from paddlespeech.server.engine.base_engine import BaseEngine
from paddlespeech.server.utils.audio_process import pcm2float
from paddlespeech.server.utils.paddle_predictor import init_predictor

-__all__ = ['ASREngine']
+__all__ = ['PaddleASRConnectionHanddler', 'ASRServerExecutor', 'ASREngine']

# ASR server connection process class
@@ -67,7 +67,7 @@ class PaddleASRConnectionHanddler:
        # tokens to text
        self.text_feature = self.asr_engine.executor.text_feature

-       if "deepspeech2online" in self.model_type or "deepspeech2offline" in self.model_type:
+       if "deepspeech2" in self.model_type:
            from paddlespeech.s2t.io.collator import SpeechCollator
            self.am_predictor = self.asr_engine.executor.am_predictor
@@ -89,8 +89,8 @@ class PaddleASRConnectionHanddler:
                cfg.decoding_method, cfg.lang_model_path, cfg.alpha, cfg.beta, cfg.beam_size,
                cfg.cutoff_prob, cfg.cutoff_top_n, cfg.num_proc_bsearch)

-           # frame window samples length and frame shift samples length
+           # frame window and frame shift, in samples unit
            self.win_length = int(self.model_config.window_ms / 1000 * self.sample_rate)
            self.n_shift = int(self.model_config.stride_ms / 1000 *
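The window and shift are converted from milliseconds to samples using the model config and the audio sample rate. A quick hedged check with typical framing values (25 ms window, 10 ms stride, 16 kHz audio; these concrete numbers are assumptions, not read from this diff):

    sample_rate = 16000   # Hz (assumed)
    window_ms = 25        # frame window in ms (assumed)
    stride_ms = 10        # frame shift in ms (assumed)

    win_length = int(window_ms / 1000 * sample_rate)  # 400 samples
    n_shift = int(stride_ms / 1000 * sample_rate)     # 160 samples
    print(win_length, n_shift)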
@@ -109,16 +109,15 @@ class PaddleASRConnectionHanddler:
            self.preprocess_args = {"train": False}
            self.preprocessing = Transformation(self.preprocess_conf)
-           # frame window samples length and frame shift samples length
+           # frame window and frame shift, in samples unit
            self.win_length = self.preprocess_conf.process[0]['win_length']
            self.n_shift = self.preprocess_conf.process[0]['n_shift']
        else:
            raise ValueError(f"Not supported: {self.model_type}")

    def extract_feat(self, samples):
        # we compute the elapsed time of first char occuring
        # and we record the start time at the first pcm sample arraving
        # if self.first_char_occur_elapsed is not None:
        #     self.first_char_occur_elapsed = time.time()
        if "deepspeech2online" in self.model_type:
            # self.reamined_wav stores all the samples,
@@ -154,28 +153,27 @@ class PaddleASRConnectionHanddler:
            spectrum = self.collate_fn_test._normalizer.apply(spectrum)

            # spectrum augment
-           audio = self.collate_fn_test.augmentation.transform_feature(spectrum)
+           feat = self.collate_fn_test.augmentation.transform_feature(spectrum)

-           audio_len = audio.shape[0]
-           audio = paddle.to_tensor(audio, dtype='float32')
-           # audio_len = paddle.to_tensor(audio_len)
-           audio = paddle.unsqueeze(audio, axis=0)
+           # audio_len is frame num
+           frame_num = feat.shape[0]
+           feat = paddle.to_tensor(feat, dtype='float32')
+           feat = paddle.unsqueeze(feat, axis=0)

            if self.cached_feat is None:
-               self.cached_feat = audio
+               self.cached_feat = feat
            else:
-               assert (len(audio.shape) == 3)
+               assert (len(feat.shape) == 3)
                assert (len(self.cached_feat.shape) == 3)
                self.cached_feat = paddle.concat(
-                   [self.cached_feat, audio], axis=1)
+                   [self.cached_feat, feat], axis=1)

            # set the feat device
            if self.device is None:
                self.device = self.cached_feat.place

-           self.num_frames += audio_len
-           self.remained_wav = self.remained_wav[self.n_shift * audio_len:]
+           self.num_frames += frame_num
+           self.remained_wav = self.remained_wav[self.n_shift * frame_num:]

            logger.info(
                f"process the audio feature success, the connection feat shape: {self.cached_feat.shape}"
@@ -183,25 +181,30 @@ class PaddleASRConnectionHanddler:
            logger.info(
                f"After extract feat, the connection remain the audio samples: {self.remained_wav.shape}")
        elif "conformer_online" in self.model_type:
            logger.info("Online ASR extract the feat")
            samples = np.frombuffer(samples, dtype=np.int16)
            assert samples.ndim == 1

-           logger.info(f"This package receive {samples.shape[0]} pcm data")
            self.num_samples += samples.shape[0]
+           logger.info(
+               f"This package receive {samples.shape[0]} pcm data. Global samples: {self.num_samples}")

            # self.reamined_wav stores all the samples,
            # include the original remained_wav and this package samples
            if self.remained_wav is None:
                self.remained_wav = samples
            else:
-               assert self.remained_wav.ndim == 1
+               assert self.remained_wav.ndim == 1  # (T,)
                self.remained_wav = np.concatenate([self.remained_wav, samples])
            logger.info(
-               f"The connection remain the audio samples: {self.remained_wav.shape}")
+               f"The concatenation of remain and now audio samples length is: {self.remained_wav.shape}")

            if len(self.remained_wav) < self.win_length:
                # samples not enough for feature window
                return 0

            # fbank
@@ -209,11 +212,13 @@ class PaddleASRConnectionHanddler:
                **self.preprocess_args)
            x_chunk = paddle.to_tensor(
                x_chunk, dtype="float32").unsqueeze(axis=0)

            # feature cache
            if self.cached_feat is None:
                self.cached_feat = x_chunk
            else:
-               assert (len(x_chunk.shape) == 3)
-               assert (len(self.cached_feat.shape) == 3)
+               assert (len(x_chunk.shape) == 3)  # (B,T,D)
+               assert (len(self.cached_feat.shape) == 3)  # (B,T,D)
                self.cached_feat = paddle.concat(
                    [self.cached_feat, x_chunk], axis=1)
@@ -221,20 +226,28 @@ class PaddleASRConnectionHanddler:
            if self.device is None:
                self.device = self.cached_feat.place

            # cur frame step
            num_frames = x_chunk.shape[1]

            # global frame step
            self.num_frames += num_frames

            # update remained wav
            self.remained_wav = self.remained_wav[self.n_shift * num_frames:]

            logger.info(
-               f"process the audio feature success, the connection feat shape: {self.cached_feat.shape}")
+               f"process the audio feature success, the cached feat shape: {self.cached_feat.shape}")
            logger.info(
-               f"After extract feat, the connection remain the audio samples: {self.remained_wav.shape}")
+               f"After extract feat, the cached remain the audio samples: {self.remained_wav.shape}")
            # logger.info(f"accumulate samples: {self.num_samples}")
            logger.info(f"global samples: {self.num_samples}")
            logger.info(f"global frames: {self.num_frames}")
        else:
            raise ValueError(f"not supported: {self.model_type}")

    def reset(self):
-       if "deepspeech2online" in self.model_type or "deepspeech2offline" in self.model_type:
+       if "deepspeech2" in self.model_type:
            # for deepspeech2
            self.chunk_state_h_box = copy.deepcopy(
                self.asr_engine.executor.chunk_state_h_box)
@@ -242,35 +255,61 @@ class PaddleASRConnectionHanddler:
                self.asr_engine.executor.chunk_state_c_box)
            self.decoder.reset_decoder(batch_size=1)

        # for conformer online
        self.device = None

        ## common
        # global sample and frame step
        self.num_samples = 0
        self.num_frames = 0

        # cache for audio and feat
        self.remained_wav = None
        self.cached_feat = None

        # partial/ending decoding results
        self.result_transcripts = ['']

        ## conformer
        # cache for conformer online
        self.subsampling_cache = None
        self.elayers_output_cache = None
        self.conformer_cnn_cache = None
        self.encoder_out = None
        self.cached_feat = None
        self.remained_wav = None
        self.offset = 0
        self.num_samples = 0
        self.device = None

        # conformer decoding state
        self.chunk_num = 0  # globa decoding chunk num
        self.offset = 0  # global offset in decoding frame unit
        self.hyps = []
        self.num_frames = 0
        self.chunk_num = 0
        self.global_frame_offset = 0
        self.result_transcripts = ['']

        # token timestamp result
        self.word_time_stamp = []

        # one best timestamp viterbi prob is large.
        self.time_stamp = []
        self.first_char_occur_elapsed = None

    def decode(self, is_finished=False):
        """advance decoding

        Args:
            is_finished (bool, optional): Is last frame or not. Defaults to False.

        Raises:
            Exception: when not support model.

        Returns:
            None: nothing
        """
        if "deepspeech2online" in self.model_type:
            # x_chunk 是特征数据
-           decoding_chunk_size = 1  # decoding_chunk_size=1 in deepspeech2 model
-           context = 7  # context=7 in deepspeech2 model
-           subsampling = 4  # subsampling=4 in deepspeech2 model
-           stride = subsampling * decoding_chunk_size
+           decoding_chunk_size = 1  # decoding chunk size = 1. int decoding frame unit
+           context = 7  # context=7, in audio frame unit
+           subsampling = 4  # subsampling=4, in audio frame unit
            cached_feature_num = context - subsampling
-           # decoding window for model
+           # decoding window for model, in audio frame unit
            decoding_window = (decoding_chunk_size - 1) * subsampling + context
+           # decoding stride for model, in audio frame unit
+           stride = subsampling * decoding_chunk_size

            if self.cached_feat is None:
                logger.info("no audio feat, please input more pcm data")
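With the deepspeech2 constants set in the hunk above, the window and stride arithmetic works out as follows (a plain restatement of the formulas already shown, not new behavior):

    decoding_chunk_size = 1   # in decoding frame unit
    context = 7               # in audio frame unit
    subsampling = 4           # in audio frame unit

    cached_feature_num = context - subsampling                            # 3 frames kept for the next chunk
    decoding_window = (decoding_chunk_size - 1) * subsampling + context   # 7 audio frames per model call
    stride = subsampling * decoding_chunk_size                            # advance 4 audio frames per call
    print(cached_feature_num, decoding_window, stride)  # 3 7 4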
@@ -280,6 +319,7 @@ class PaddleASRConnectionHanddler:
            logger.info(
                f"Required decoding window {decoding_window} frames, and the connection has {num_frames} frames")
+           # the cached feat must be larger decoding_window
            if num_frames < decoding_window and not is_finished:
                logger.info(

@@ -293,6 +333,7 @@ class PaddleASRConnectionHanddler:
                    "flast {num_frames} is less than context {context} frames, and we cannot do model forward")
                return None, None

            logger.info("start to do model forward")
            # num_frames - context + 1 ensure that current frame can get context window
            if is_finished:
@@ -302,6 +343,7 @@ class PaddleASRConnectionHanddler:
                # we only process decoding_window frames for one chunk
                left_frames = decoding_window

            end = None
            for cur in range(0, num_frames - left_frames + 1, stride):
                end = min(cur + decoding_window, num_frames)
                # extract the audio

@@ -311,7 +353,9 @@ class PaddleASRConnectionHanddler:
                self.result_transcripts = [trans_best]

+           # update feat cache
            self.cached_feat = self.cached_feat[:, end - cached_feature_num:, :]

+           # return trans_best[0]
        elif "conformer" in self.model_type or "transformer" in self.model_type:
            try:
@@ -328,7 +372,16 @@ class PaddleASRConnectionHanddler:
    @paddle.no_grad()
    def decode_one_chunk(self, x_chunk, x_chunk_lens):
-       logger.info("start to decoce one chunk with deepspeech2 model")
+       """forward one chunk frames
+       Args:
+           x_chunk (np.ndarray): (B,T,D), audio frames.
+           x_chunk_lens ([type]): (B,), audio frame lens
+       Returns:
+           logprob: poster probability.
+       """
+       logger.info("start to decoce one chunk for deepspeech2")
        input_names = self.am_predictor.get_input_names()
        audio_handle = self.am_predictor.get_input_handle(input_names[0])
        audio_len_handle = self.am_predictor.get_input_handle(input_names[1])
@@ -365,24 +418,32 @@ class PaddleASRConnectionHanddler:
        self.decoder.next(output_chunk_probs, output_chunk_lens)
        trans_best, trans_beam = self.decoder.decode()
-       logger.info(f"decode one best result: {trans_best[0]}")
+       logger.info(f"decode one best result for deepspeech2: {trans_best[0]}")
        return trans_best[0]

    @paddle.no_grad()
    def advance_decoding(self, is_finished=False):
-       logger.info("start to decode with advanced_decoding method")
+       logger.info(
+           "Conformer/Transformer: start to decode with advanced_decoding method")

        cfg = self.ctc_decode_config
        # cur chunk size, in decoding frame unit
        decoding_chunk_size = cfg.decoding_chunk_size
        # using num of history chunks
        num_decoding_left_chunks = cfg.num_decoding_left_chunks

        assert decoding_chunk_size > 0
        subsampling = self.model.encoder.embed.subsampling_rate
        context = self.model.encoder.embed.right_context + 1

-       stride = subsampling * decoding_chunk_size
-       cached_feature_num = context - subsampling  # processed chunk feature cached for next chunk
-       # decoding window for model
+       # processed chunk feature cached for next chunk
+       cached_feature_num = context - subsampling
+       # decoding stride, in audio frame unit
+       stride = subsampling * decoding_chunk_size
+       # decoding window, in audio frame unit
        decoding_window = (decoding_chunk_size - 1) * subsampling + context

        if self.cached_feat is None:
            logger.info("no audio feat, please input more pcm data")
            return

@@ -407,6 +468,7 @@ class PaddleASRConnectionHanddler:
            return None, None

        logger.info("start to do model forward")
+       # hist of chunks, in deocding frame unit
        required_cache_size = decoding_chunk_size * num_decoding_left_chunks
        outputs = []
@@ -423,8 +485,11 @@ class PaddleASRConnectionHanddler:
        for cur in range(0, num_frames - left_frames + 1, stride):
            end = min(cur + decoding_window, num_frames)
+           # global chunk_num
            self.chunk_num += 1
+           # cur chunk
            chunk_xs = self.cached_feat[:, cur:end, :]
+           # forward chunk
            (y, self.subsampling_cache, self.elayers_output_cache,
             self.conformer_cnn_cache) = self.model.encoder.forward_chunk(
                 chunk_xs, self.offset, required_cache_size,

@@ -432,7 +497,7 @@ class PaddleASRConnectionHanddler:
                 self.conformer_cnn_cache)
            outputs.append(y)

-           # update the offset
+           # update the global offset, in decoding frame unit
            self.offset += y.shape[1]

        ys = paddle.cat(outputs, 1)
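The loop above walks the cached feature tensor in fixed strides, feeding one decoding window at a time to the encoder and advancing a global offset by however many frames the encoder emits. A hedged sketch of just the windowing logic (NumPy stand-in with no Paddle dependency; the names and shapes are illustrative):

    import numpy as np

    def iter_chunks(num_frames, decoding_window, stride, left_frames):
        # Yield (start, end) ranges the same way the advance_decoding loop does:
        # the last start position still leaves at least `left_frames` frames.
        for cur in range(0, num_frames - left_frames + 1, stride):
            yield cur, min(cur + decoding_window, num_frames)

    feats = np.zeros((1, 67, 80), dtype=np.float32)  # (B, T, D), dummy features
    for start, end in iter_chunks(feats.shape[1], decoding_window=7, stride=4, left_frames=7):
        chunk = feats[:, start:end, :]  # what would be passed to encoder.forward_chunk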
@@ -445,12 +510,15 @@ class PaddleASRConnectionHanddler:
        ctc_probs = self.model.ctc.log_softmax(ys)  # (1, maxlen, vocab_size)
        ctc_probs = ctc_probs.squeeze(0)

        # advance decoding
        self.searcher.search(ctc_probs, self.cached_feat.place)
+       # get one best hyps
        self.hyps = self.searcher.get_one_best_hyps()

        assert self.cached_feat.shape[0] == 1
        assert end >= cached_feature_num

+       # advance cache of feat
        self.cached_feat = self.cached_feat[0, end - cached_feature_num:, :].unsqueeze(0)
        assert len(
@@ -462,50 +530,81 @@ class PaddleASRConnectionHanddler:
        )

    def update_result(self):
+       """Conformer/Transformer hyps to result.
+       """
        logger.info("update the final result")
        hyps = self.hyps

        # output results and tokenids
        self.result_transcripts = [
            self.text_feature.defeaturize(hyp) for hyp in hyps
        ]
        self.result_tokenids = [hyp for hyp in hyps]

    def get_result(self):
+       """return partial/ending asr result.
+
+       Returns:
+           str: one best result of partial/ending.
+       """
        if len(self.result_transcripts) > 0:
            return self.result_transcripts[0]
        else:
            return ''

+   def get_word_time_stamp(self):
+       """return token timestamp result.
+
+       Returns:
+           list: List of ('w':token, 'bg':time, 'ed':time)
+       """
+       return self.word_time_stamp

    @paddle.no_grad()
    def rescoring(self):
-       if "deepspeech2online" in self.model_type or "deepspeech2offline" in self.model_type:
+       """Second-Pass Decoding,
+       only for conformer and transformer model.
+       """
+       if "deepspeech2" in self.model_type:
            logger.info("deepspeech2 not support rescoring decoding.")
            return

-       logger.info("rescoring the final result")
        if "attention_rescoring" != self.ctc_decode_config.decoding_method:
            logger.info(
                f"decoding method not match: {self.ctc_decode_config.decoding_method}, need attention_rescoring")
            return
+       logger.info("rescoring the final result")

        # last decoding for last audio
        self.searcher.finalize_search()
        # update beam search results
        self.update_result()

        beam_size = self.ctc_decode_config.beam_size
        hyps = self.searcher.get_hyps()
        if hyps is None or len(hyps) == 0:
            logger.info("No Hyps!")
            return

        # rescore by decoder post probability
        # assert len(hyps) == beam_size
        # list of Tensor
        hyp_list = []
        for hyp in hyps:
            hyp_content = hyp[0]
            # Prevent the hyp is empty
            if len(hyp_content) == 0:
                hyp_content = (self.model.ctc.blank_id, )
            hyp_content = paddle.to_tensor(
                hyp_content, place=self.device, dtype=paddle.long)
            hyp_list.append(hyp_content)

-       hyps_pad = pad_sequence(hyp_list, True, self.model.ignore_id)
+       hyps_pad = pad_sequence(
+           hyp_list, batch_first=True, padding_value=self.model.ignore_id)
        hyps_lens = paddle.to_tensor(
            [len(hyp[0]) for hyp in hyps], place=self.device,
            dtype=paddle.long)  # (beam_size,)
@@ -531,10 +630,12 @@ class PaddleASRConnectionHanddler:
            score = 0.0
            for j, w in enumerate(hyp[0]):
                score += decoder_out[i][j][w]
+           # last decoder output token is `eos`, for laste decoder input token.
            score += decoder_out[i][len(hyp[0])][self.model.eos]
+           # add ctc score (which in ln domain)
            score += hyp[1] * self.ctc_decode_config.ctc_weight
            if score > best_score:
                best_score = score
                best_index = i
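Attention rescoring, as implemented in the loop above, scores each CTC prefix beam hypothesis with the decoder's log-probabilities plus a weighted CTC score and keeps the best one. A compact hedged sketch of that scoring rule (pure NumPy-style indexing; the function name, shapes, and weights are placeholders, not the project API):

    def rescore(hyps, decoder_out, eos_id, ctc_weight):
        # hyps: list of (token_id_tuple, ctc_score); decoder_out: (beam, max_len + 1, vocab) log-probs.
        best_score, best_index = -float("inf"), 0
        for i, (tokens, ctc_score) in enumerate(hyps):
            score = sum(decoder_out[i][j][w] for j, w in enumerate(tokens))
            score += decoder_out[i][len(tokens)][eos_id]  # decoder must also predict <eos>
            score += ctc_weight * ctc_score               # both scores are in the log domain
            if score > best_score:
                best_score, best_index = score, i
        return best_index, best_score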
@@ -542,43 +643,52 @@ class PaddleASRConnectionHanddler:
        # update the one best result
        # hyps stored the beam results and each fields is:
-       logger.info(f"best index: {best_index}")
+       logger.info(f"best hyp index: {best_index}")
        # logger.info(f'best result: {hyps[best_index]}')
        # the field of the hyps is:
+       ## asr results
        # hyps[0][0]: the sentence word-id in the vocab with a tuple
        # hyps[0][1]: the sentence decoding probability with all paths
+       ## timestamp
        # hyps[0][2]: viterbi_blank ending probability
-       # hyps[0][3]: viterbi_non_blank probability
+       # hyps[0][3]: viterbi_non_blank dending probability
        # hyps[0][4]: current_token_prob,
-       # hyps[0][5]: times_viterbi_blank,
-       # hyps[0][6]: times_titerbi_non_blank
+       # hyps[0][5]: times_viterbi_blank ending timestamp,
+       # hyps[0][6]: times_titerbi_non_blank encding timestamp.
        self.hyps = [hyps[best_index][0]]
        logger.info(f"best hyp ids: {self.hyps}")

        # update the hyps time stamp
        self.time_stamp = hyps[best_index][5] if hyps[best_index][2] > hyps[best_index][3] else hyps[best_index][6]
        logger.info(f"time stamp: {self.time_stamp}")

        # update one best result
        self.update_result()

        # update each word start and end time stamp
-       frame_shift_in_ms = self.model.encoder.embed.subsampling_rate * self.n_shift / self.sample_rate
-       logger.info(f"frame shift ms: {frame_shift_in_ms}")
+       # decoding frame to audio frame
+       frame_shift = self.model.encoder.embed.subsampling_rate
+       frame_shift_in_sec = frame_shift * (self.n_shift / self.sample_rate)
+       logger.info(f"frame shift sec: {frame_shift_in_sec}")

        word_time_stamp = []
        for idx, _ in enumerate(self.time_stamp):
            start = (self.time_stamp[idx - 1] + self.time_stamp[idx]) / 2.0 if idx > 0 else 0
-           start = start * frame_shift_in_ms
+           start = start * frame_shift_in_sec

            end = (self.time_stamp[idx] + self.time_stamp[idx + 1]) / 2.0 if idx < len(self.time_stamp) - 1 else self.offset
-           end = end * frame_shift_in_ms
+           end = end * frame_shift_in_sec

            word_time_stamp.append({
                "w": self.result_transcripts[0][idx],
                "bg": start,
                "ed": end
            })
            # logger.info(f"{self.result_transcripts[0][idx]}, start: {start}, end: {end}")
+           # logger.info(f"{word_time_stamp[-1]}")

        self.word_time_stamp = word_time_stamp
        logger.info(f"word time stamp: {self.word_time_stamp}")
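The token times collected by the Viterbi pass are in decoding-frame units; the hunk above converts them to seconds via subsampling_rate * n_shift / sample_rate and places each token boundary at the midpoint between neighbouring times. A small hedged recap with made-up numbers (the frame settings and timestamps below are illustrative only):

    subsampling_rate = 4
    n_shift = 160          # samples per audio frame (assumed 10 ms at 16 kHz)
    sample_rate = 16000

    frame_shift_in_sec = subsampling_rate * (n_shift / sample_rate)  # 0.04 s per decoding frame

    time_stamp = [5, 12, 20]   # decoding-frame indices of three tokens (made up)
    offset = 25                # total decoding frames seen so far (made up)
    stamps = []
    for idx, t in enumerate(time_stamp):
        start = (time_stamp[idx - 1] + t) / 2.0 if idx > 0 else 0
        end = (t + time_stamp[idx + 1]) / 2.0 if idx < len(time_stamp) - 1 else offset
        stamps.append({"bg": start * frame_shift_in_sec, "ed": end * frame_shift_in_sec})
    print(stamps)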
@@ -610,6 +720,7 @@ class ASRServerExecutor(ASRExecutor):
        self.sample_rate = sample_rate
        sample_rate_str = '16k' if sample_rate == 16000 else '8k'
        tag = model_type + '-' + lang + '-' + sample_rate_str

        if cfg_path is None or am_model is None or am_params is None:
            logger.info(f"Load the pretrained model, tag = {tag}")
            res_path = self._get_pretrained_path(tag)  # wenetspeech_zh

@@ -639,7 +750,7 @@ class ASRServerExecutor(ASRExecutor):
        self.config.merge_from_file(self.cfg_path)

        with UpdateConfig(self.config):
-           if "deepspeech2online" in model_type or "deepspeech2offline" in model_type:
+           if "deepspeech2" in model_type:
                from paddlespeech.s2t.io.collator import SpeechCollator
                self.vocab = self.config.vocab_filepath
                self.config.decode.lang_model_path = os.path.join(
@@ -655,6 +766,7 @@ class ASRServerExecutor(ASRExecutor):
                self.download_lm(
                    lm_url,
                    os.path.dirname(self.config.decode.lang_model_path), lm_md5)

            elif "conformer" in model_type or "transformer" in model_type:
                logger.info("start to create the stream conformer asr engine")
                if self.config.spm_model_prefix:

@@ -682,7 +794,8 @@ class ASRServerExecutor(ASRExecutor):
                ], f"we only support ctc_prefix_beam_search and attention_rescoring dedoding method, current decoding method is {self.config.decode.decoding_method}"
            else:
                raise Exception("wrong type")

-       if "deepspeech2online" in model_type or "deepspeech2offline" in model_type:
+       if "deepspeech2" in model_type:
            # AM predictor
            logger.info("ASR engine start to init the am predictor")
            self.am_predictor_conf = am_predictor_conf

@@ -719,6 +832,7 @@ class ASRServerExecutor(ASRExecutor):
            self.chunk_state_c_box = np.zeros(
                (self.config.num_rnn_layers, 1, self.config.rnn_layer_size),
                dtype=float32)

        elif "conformer" in model_type or "transformer" in model_type:
            model_name = model_type[:model_type.rindex('_')]  # model_type: {model_name}_{dataset}
@@ -737,277 +851,14 @@ class ASRServerExecutor(ASRExecutor):
            # update the ctc decoding
            self.searcher = CTCPrefixBeamSearch(self.config.decode)
            self.transformer_decode_reset()

        return True

-   def reset_decoder_and_chunk(self):
-       """reset decoder and chunk state for an new audio
-       """
-       if "deepspeech2online" in self.model_type or "deepspeech2offline" in self.model_type:
-           self.decoder.reset_decoder(batch_size=1)
-           # init state box, for new audio request
-           self.chunk_state_h_box = np.zeros(
-               (self.config.num_rnn_layers, 1, self.config.rnn_layer_size), dtype=float32)
-           self.chunk_state_c_box = np.zeros(
-               (self.config.num_rnn_layers, 1, self.config.rnn_layer_size), dtype=float32)
-       elif "conformer" in self.model_type or "transformer" in self.model_type:
-           self.transformer_decode_reset()

-   def decode_one_chunk(self, x_chunk, x_chunk_lens, model_type: str):
-       """decode one chunk
-       Args:
-           x_chunk (numpy.array): shape[B, T, D]
-           x_chunk_lens (numpy.array): shape[B]
-           model_type (str): online model type
-       Returns:
-           str: one best result
-       """
-       logger.info("start to decoce chunk by chunk")
-       if "deepspeech2online" in model_type:
-           input_names = self.am_predictor.get_input_names()
-           audio_handle = self.am_predictor.get_input_handle(input_names[0])
-           audio_len_handle = self.am_predictor.get_input_handle(input_names[1])
-           h_box_handle = self.am_predictor.get_input_handle(input_names[2])
-           c_box_handle = self.am_predictor.get_input_handle(input_names[3])
-           audio_handle.reshape(x_chunk.shape)
-           audio_handle.copy_from_cpu(x_chunk)
-           audio_len_handle.reshape(x_chunk_lens.shape)
-           audio_len_handle.copy_from_cpu(x_chunk_lens)
-           h_box_handle.reshape(self.chunk_state_h_box.shape)
-           h_box_handle.copy_from_cpu(self.chunk_state_h_box)
-           c_box_handle.reshape(self.chunk_state_c_box.shape)
-           c_box_handle.copy_from_cpu(self.chunk_state_c_box)
-           output_names = self.am_predictor.get_output_names()
-           output_handle = self.am_predictor.get_output_handle(output_names[0])
-           output_lens_handle = self.am_predictor.get_output_handle(output_names[1])
-           output_state_h_handle = self.am_predictor.get_output_handle(output_names[2])
-           output_state_c_handle = self.am_predictor.get_output_handle(output_names[3])
-           self.am_predictor.run()
-           output_chunk_probs = output_handle.copy_to_cpu()
-           output_chunk_lens = output_lens_handle.copy_to_cpu()
-           self.chunk_state_h_box = output_state_h_handle.copy_to_cpu()
-           self.chunk_state_c_box = output_state_c_handle.copy_to_cpu()
-           self.decoder.next(output_chunk_probs, output_chunk_lens)
-           trans_best, trans_beam = self.decoder.decode()
-           logger.info(f"decode one best result: {trans_best[0]}")
-           return trans_best[0]
-       elif "conformer" in model_type or "transformer" in model_type:
-           try:
-               logger.info(f"we will use the transformer like model : {self.model_type}")
-               self.advanced_decoding(x_chunk, x_chunk_lens)
-               self.update_result()
-               return self.result_transcripts[0]
-           except Exception as e:
-               logger.exception(e)
-       else:
-           raise Exception("invalid model name")
+           raise ValueError(f"Not support: {model_type}")

-   def advanced_decoding(self, xs: paddle.Tensor, x_chunk_lens):
-       logger.info("start to decode with advanced_decoding method")
-       encoder_out, encoder_mask = self.encoder_forward(xs)
-       ctc_probs = self.model.ctc.log_softmax(encoder_out)  # (1, maxlen, vocab_size)
-       ctc_probs = ctc_probs.squeeze(0)
-       self.searcher.search(ctc_probs, xs.place)
-       # update the one best result
-       self.hyps = self.searcher.get_one_best_hyps()
-       # now we supprot ctc_prefix_beam_search and attention_rescoring
-       if "attention_rescoring" in self.config.decode.decoding_method:
-           self.rescoring(encoder_out, xs.place)

-   def encoder_forward(self, xs):
-       logger.info("get the model out from the feat")
-       cfg = self.config.decode
-       decoding_chunk_size = cfg.decoding_chunk_size
-       num_decoding_left_chunks = cfg.num_decoding_left_chunks
-       assert decoding_chunk_size > 0
-       subsampling = self.model.encoder.embed.subsampling_rate
-       context = self.model.encoder.embed.right_context + 1
-       stride = subsampling * decoding_chunk_size
-       # decoding window for model
-       decoding_window = (decoding_chunk_size - 1) * subsampling + context
-       num_frames = xs.shape[1]
-       required_cache_size = decoding_chunk_size * num_decoding_left_chunks
-       logger.info("start to do model forward")
-       outputs = []
-       # num_frames - context + 1 ensure that current frame can get context window
-       for cur in range(0, num_frames - context + 1, stride):
-           end = min(cur + decoding_window, num_frames)
-           chunk_xs = xs[:, cur:end, :]
-           (y, self.subsampling_cache, self.elayers_output_cache,
-            self.conformer_cnn_cache) = self.model.encoder.forward_chunk(
-                chunk_xs, self.offset, required_cache_size,
-                self.subsampling_cache, self.elayers_output_cache,
-                self.conformer_cnn_cache)
-           outputs.append(y)
-           self.offset += y.shape[1]
-       ys = paddle.cat(outputs, 1)
-       masks = paddle.ones([1, ys.shape[1]], dtype=paddle.bool)
-       masks = masks.unsqueeze(1)
-       return ys, masks

-   def rescoring(self, encoder_out, device):
-       logger.info("start to rescoring the hyps")
-       beam_size = self.config.decode.beam_size
-       hyps = self.searcher.get_hyps()
-       assert len(hyps) == beam_size
-       hyp_list = []
-       for hyp in hyps:
-           hyp_content = hyp[0]
-           # Prevent the hyp is empty
-           if len(hyp_content) == 0:
-               hyp_content = (self.model.ctc.blank_id, )
-           hyp_content = paddle.to_tensor(hyp_content, place=device, dtype=paddle.long)
-           hyp_list.append(hyp_content)
-       hyps_pad = pad_sequence(hyp_list, True, self.model.ignore_id)
-       hyps_lens = paddle.to_tensor(
-           [len(hyp[0]) for hyp in hyps], place=device, dtype=paddle.long)  # (beam_size,)
-       hyps_pad, _ = add_sos_eos(hyps_pad, self.model.sos, self.model.eos, self.model.ignore_id)
-       hyps_lens = hyps_lens + 1  # Add <sos> at begining
-       encoder_out = encoder_out.repeat(beam_size, 1, 1)
-       encoder_mask = paddle.ones((beam_size, 1, encoder_out.shape[1]), dtype=paddle.bool)
-       decoder_out, _ = self.model.decoder(
-           encoder_out, encoder_mask, hyps_pad, hyps_lens)  # (beam_size, max_hyps_len, vocab_size)
-       # ctc score in ln domain
-       decoder_out = paddle.nn.functional.log_softmax(decoder_out, axis=-1)
-       decoder_out = decoder_out.numpy()
-       # Only use decoder score for rescoring
-       best_score = -float('inf')
-       best_index = 0
-       # hyps is List[(Text=List[int], Score=float)], len(hyps)=beam_size
-       for i, hyp in enumerate(hyps):
-           score = 0.0
-           for j, w in enumerate(hyp[0]):
-               score += decoder_out[i][j][w]
-           # last decoder output token is `eos`, for laste decoder input token.
-           score += decoder_out[i][len(hyp[0])][self.model.eos]
-           # add ctc score (which in ln domain)
-           score += hyp[1] * self.config.decode.ctc_weight
-           if score > best_score:
-               best_score = score
-               best_index = i
-       # update the one best result
-       self.hyps = [hyps[best_index][0]]
-       return hyps[best_index][0]

-   def transformer_decode_reset(self):
-       self.subsampling_cache = None
-       self.elayers_output_cache = None
-       self.conformer_cnn_cache = None
-       self.offset = 0
-       # decoding reset
-       self.searcher.reset()

-   def update_result(self):
-       logger.info("update the final result")
-       hyps = self.hyps
-       self.result_transcripts = [
-           self.text_feature.defeaturize(hyp) for hyp in hyps
-       ]
-       self.result_tokenids = [hyp for hyp in hyps]

-   def extract_feat(self, samples, sample_rate):
-       """extract feat
-       Args:
-           samples (numpy.array): numpy.float32
-           sample_rate (int): sample rate
-       Returns:
-           x_chunk (numpy.array): shape[B, T, D]
-           x_chunk_lens (numpy.array): shape[B]
-       """
-       if "deepspeech2online" in self.model_type:
-           # pcm16 -> pcm 32
-           samples = pcm2float(samples)
-           # read audio
-           speech_segment = SpeechSegment.from_pcm(samples, sample_rate, transcript=" ")
-           # audio augment
-           self.collate_fn_test.augmentation.transform_audio(speech_segment)
-           # extract speech feature
-           spectrum, transcript_part = self.collate_fn_test._speech_featurizer.featurize(
-               speech_segment, self.collate_fn_test.keep_transcription_text)
-           # CMVN spectrum
-           if self.collate_fn_test._normalizer:
-               spectrum = self.collate_fn_test._normalizer.apply(spectrum)
-           # spectrum augment
-           audio = self.collate_fn_test.augmentation.transform_feature(spectrum)
-           audio_len = audio.shape[0]
-           audio = paddle.to_tensor(audio, dtype='float32')
-           # audio_len = paddle.to_tensor(audio_len)
-           audio = paddle.unsqueeze(audio, axis=0)
-           x_chunk = audio.numpy()
-           x_chunk_lens = np.array([audio_len])
-           return x_chunk, x_chunk_lens
-       elif "conformer_online" in self.model_type:
-           if sample_rate != self.sample_rate:
-               logger.info(f"audio sample rate {sample_rate} is not match,"
-                           "the model sample_rate is {self.sample_rate}")
-           logger.info(f"ASR Engine use the {self.model_type} to process")
-           logger.info("Create the preprocess instance")
-           preprocess_conf = self.config.preprocess_config
-           preprocess_args = {"train": False}
-           preprocessing = Transformation(preprocess_conf)
-           logger.info("Read the audio file")
-           logger.info(f"audio shape: {samples.shape}")
-           # fbank
-           x_chunk = preprocessing(samples, **preprocess_args)
-           x_chunk_lens = paddle.to_tensor(x_chunk.shape[0])
-           x_chunk = paddle.to_tensor(x_chunk, dtype="float32").unsqueeze(axis=0)
-           logger.info(f"process the audio feature success, feat shape: {x_chunk.shape}")
-           return x_chunk, x_chunk_lens
-       return True


class ASREngine(BaseEngine):
-   """ASR server engine
+   """ASR server resource

    Args:
        metaclass: Defaults to Singleton.
@@ -1015,7 +866,7 @@ class ASREngine(BaseEngine):
    def __init__(self):
        super(ASREngine, self).__init__()
-       logger.info("create the online asr engine instance")
+       logger.info("create the online asr engine resource instance")

    def init(self, config: dict) -> bool:
        """init engine resource

@@ -1026,17 +877,12 @@ class ASREngine(BaseEngine):
        Returns:
            bool: init failed or success
        """
-       self.input = None
-       self.output = ""
-       self.executor = ASRServerExecutor()
        self.config = config
+       self.executor = ASRServerExecutor()

        try:
-           if self.config.get("device", None):
-               self.device = self.config.device
-           else:
-               self.device = paddle.get_device()
-           logger.info(f"paddlespeech_server set the device: {self.device}")
-           paddle.set_device(self.device)
+           default_dev = paddle.get_device()
+           paddle.set_device(self.config.get("device", default_dev))
        except BaseException as e:
            logger.error(
                f"Set device failed, please check if device '{self.device}' is already used and the parameter 'device' in the yaml file"
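The refactored init reads the device from the server config and otherwise falls back to whatever Paddle reports as the current device. Roughly (a hedged sketch that needs a working Paddle install; config here is a plain dict rather than the server's config object):

    import paddle

    config = {}  # e.g. {"device": "gpu:0"}; an empty dict means "use the default"

    default_dev = paddle.get_device()                    # e.g. "cpu" or "gpu:0"
    paddle.set_device(config.get("device", default_dev))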
@@ -1045,6 +891,8 @@ class ASREngine(BaseEngine):
                "If all GPU or XPU is used, you can set the server to 'cpu'")
            sys.exit(-1)

+       logger.info(f"paddlespeech_server set the device: {self.device}")
+
        if not self.executor._init_from_path(
                model_type=self.config.model_type,
                am_model=self.config.am_model,
@@ -1062,42 +910,11 @@ class ASREngine(BaseEngine):
        logger.info("Initialize ASR server engine successfully.")
        return True

-   def preprocess(self, samples, sample_rate, model_type="deepspeech2online_aishell-zh-16k"):
-       """preprocess
-       Args:
-           samples (numpy.array): numpy.float32
-           sample_rate (int): sample rate
-       Returns:
-           x_chunk (numpy.array): shape[B, T, D]
-           x_chunk_lens (numpy.array): shape[B]
-       """
-       # if "deepspeech" in model_type:
-       x_chunk, x_chunk_lens = self.executor.extract_feat(samples, sample_rate)
-       return x_chunk, x_chunk_lens
+   def preprocess(self, *args, **kwargs):
+       raise NotImplementedError("Online not using this.")

-   def run(self, x_chunk, x_chunk_lens, decoder_chunk_size=1):
-       """run online engine
-       Args:
-           x_chunk (numpy.array): shape[B, T, D]
-           x_chunk_lens (numpy.array): shape[B]
-           decoder_chunk_size(int)
-       """
-       self.output = self.executor.decode_one_chunk(x_chunk, x_chunk_lens, self.config.model_type)
+   def run(self, *args, **kwargs):
+       raise NotImplementedError("Online not using this.")

    def postprocess(self):
        """postprocess
        """
        return self.output

    def reset(self):
        """reset engine decoder and inference state
        """
-       self.executor.reset_decoder_and_chunk()
-       self.output = ""
+       raise NotImplementedError("Online not using this.")
paddlespeech/server/restful/api.py (+1 -1)

@@ -17,12 +17,12 @@ from typing import List
from fastapi import APIRouter

from paddlespeech.cli.log import logger
+from paddlespeech.server.restful.acs_api import router as acs_router
from paddlespeech.server.restful.asr_api import router as asr_router
from paddlespeech.server.restful.cls_api import router as cls_router
from paddlespeech.server.restful.text_api import router as text_router
from paddlespeech.server.restful.tts_api import router as tts_router
from paddlespeech.server.restful.vector_api import router as vec_router
-from paddlespeech.server.restful.acs_api import router as acs_router

_router = APIRouter()
paddlespeech/server/utils/audio_handler.py (+1 -1)

@@ -248,7 +248,7 @@ class ASRHttpHandler:
        }

        res = requests.post(url=self.url, data=json.dumps(data))

        return res.json()
paddlespeech/server/utils/buffer.py (+4 -3)

@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.


class Frame(object):
    """Represents a "frame" of audio data."""

@@ -45,7 +46,7 @@ class ChunkBuffer(object):
        self.shift_ms = shift_ms
        self.sample_rate = sample_rate
        self.sample_width = sample_width  # int16 = 2; float32 = 4

        self.window_sec = float((self.window_n - 1) * self.shift_ms + self.window_ms) / 1000.0
        self.shift_sec = float(self.shift_n * self.shift_ms / 1000.0)

@@ -77,8 +78,8 @@ class ChunkBuffer(object):
        offset = 0

        while offset + self.window_bytes <= len(audio):
            yield Frame(audio[offset:offset + self.window_bytes], self.timestamp, self.window_sec)
            self.timestamp += self.shift_sec
            offset += self.shift_bytes
(the yield statement is only re-wrapped in this hunk; the statement itself is unchanged)
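ChunkBuffer derives its window and shift durations in seconds from the millisecond settings shown above. With a common 25 ms / 10 ms framing (the concrete values below are assumed for illustration) the arithmetic looks like this:

    window_ms, shift_ms = 25, 10   # assumed framing
    window_n, shift_n = 7, 4       # assumed frame counts per chunk

    window_sec = float((window_n - 1) * shift_ms + window_ms) / 1000.0  # 0.085 s
    shift_sec = float(shift_n * shift_ms / 1000.0)                      # 0.04 s
    print(window_sec, shift_sec)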
paddlespeech/t2s/exps/speedyspeech/synthesize_e2e.py (+4 -1)

@@ -176,7 +176,10 @@ def main():
    parser.add_argument("--ngpu", type=int, default=1, help="if ngpu == 0, use cpu or xpu.")
    parser.add_argument(
        "--nxpu", type=int, default=0, help="if nxpu == 0 and ngpu == 0, use cpu.")

    args, _ = parser.parse_known_args()
(the --nxpu argument line appears twice in the original view; the two copies differ only in line wrapping)
paddlespeech/t2s/exps/speedyspeech/train.py (+4 -1)

@@ -188,7 +188,10 @@ def main():
    parser.add_argument("--dev-metadata", type=str, help="dev data.")
    parser.add_argument("--output-dir", type=str, help="output dir.")
    parser.add_argument(
        "--nxpu", type=int, default=0, help="if nxpu == 0 and ngpu == 0, use cpu.")
    parser.add_argument("--ngpu", type=int, default=1, help="if ngpu == 0, use cpu or xpu")
(the --nxpu argument line appears twice in the original view; the two copies differ only in line wrapping)
paddlespeech/t2s/modules/transformer/repeat.py (+1 -1)

@@ -36,4 +36,4 @@ def repeat(N, fn):
    Returns:
        MultiSequential: Repeated model instance.
    """
    return MultiSequential(*[fn(n) for n in range(N)])
(the return statement changes only in whitespace; it is otherwise identical before and after)
setup.py (+4 -4)

@@ -98,7 +98,6 @@ requirements = {
}


def check_call(cmd: str, shell=False, executable=None):
    try:
        sp.check_call(

@@ -112,12 +111,13 @@ def check_call(cmd: str, shell=False, executable=None):
            file=sys.stderr)
        raise e


def check_output(cmd: str, shell=False):
    try:
        out_bytes = sp.check_output(cmd.split())
    except sp.CalledProcessError as e:
        out_bytes = e.output  # Output generated before error
        code = e.returncode  # Return code
        print(
            f"{__file__}:{inspect.currentframe().f_lineno}: CMD: {cmd}, Error:",
            out_bytes,
(the out_bytes / code assignment lines appear twice in the original view; the copies differ only in formatting)

@@ -146,6 +146,7 @@ def _remove(files: str):
    for f in files:
        f.unlink()


################################# Install ##################################

@@ -308,6 +309,5 @@ setup_info = dict(
    ]
})

with version_info():
    setup(**setup_info)
tests/unit/cli/aishell_test_prepare.py (+2 -2)

@@ -20,7 +20,6 @@ of each audio file in the data set.
"""
import argparse
import codecs
import json
import os
from pathlib import Path

@@ -89,7 +88,7 @@ def create_manifest(data_dir, manifest_path_prefix):
            duration = float(len(audio_data) / samplerate)
            text = transcript_dict[audio_id]
            json_lines.append(audio_path)
            reference_lines.append(str(total_num + 1) + "\t" + text)

            total_sec += duration
            total_text += len(text)
(the reference_lines.append line appears twice in the original view; the copies differ only in formatting)

@@ -106,6 +105,7 @@ def create_manifest(data_dir, manifest_path_prefix):
    manifest_dir = os.path.dirname(manifest_path_prefix)


def prepare_dataset(url, md5sum, target_dir, manifest_path=None):
    """Download, unpack and create manifest file."""
    data_dir = os.path.join(target_dir, 'data_aishell')