Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
DeepSpeech
提交
f07f57a3
D
DeepSpeech
项目概览
PaddlePaddle
/
DeepSpeech
大约 2 年 前同步成功
通知
210
Star
8425
Fork
1598
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
245
列表
看板
标记
里程碑
合并请求
3
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
D
DeepSpeech
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
245
Issue
245
列表
看板
标记
里程碑
合并请求
3
合并请求
3
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
f07f57a3
编写于
5月 25, 2022
作者:
H
Hui Zhang
提交者:
GitHub
5月 25, 2022
浏览文件
操作
浏览文件
下载
差异文件
Merge pull request #1945 from PaddlePaddle/asr_line
[server][asr] refactor asr streaming server and remove useless code
上级
8f8239ad
c15278ed
变更
16
隐藏空白更改
内联
并排
Showing
16 changed file
with
247 addition
and
402 deletion
+247
-402
demos/audio_searching/src/operations/load.py
demos/audio_searching/src/operations/load.py
+2
-3
demos/streaming_asr_server/websocket_client.py
demos/streaming_asr_server/websocket_client.py
+0
-2
docs/source/asr/PPASR_cn.md
docs/source/asr/PPASR_cn.md
+0
-2
paddlespeech/cli/utils.py
paddlespeech/cli/utils.py
+1
-1
paddlespeech/s2t/io/sampler.py
paddlespeech/s2t/io/sampler.py
+1
-1
paddlespeech/server/engine/acs/__init__.py
paddlespeech/server/engine/acs/__init__.py
+13
-0
paddlespeech/server/engine/acs/python/__init__.py
paddlespeech/server/engine/acs/python/__init__.py
+13
-0
paddlespeech/server/engine/asr/online/asr_engine.py
paddlespeech/server/engine/asr/online/asr_engine.py
+196
-379
paddlespeech/server/restful/api.py
paddlespeech/server/restful/api.py
+1
-1
paddlespeech/server/utils/audio_handler.py
paddlespeech/server/utils/audio_handler.py
+1
-1
paddlespeech/server/utils/buffer.py
paddlespeech/server/utils/buffer.py
+4
-3
paddlespeech/t2s/exps/speedyspeech/synthesize_e2e.py
paddlespeech/t2s/exps/speedyspeech/synthesize_e2e.py
+4
-1
paddlespeech/t2s/exps/speedyspeech/train.py
paddlespeech/t2s/exps/speedyspeech/train.py
+4
-1
paddlespeech/t2s/modules/transformer/repeat.py
paddlespeech/t2s/modules/transformer/repeat.py
+1
-1
setup.py
setup.py
+4
-4
tests/unit/cli/aishell_test_prepare.py
tests/unit/cli/aishell_test_prepare.py
+2
-2
未找到文件。
demos/audio_searching/src/operations/load.py
浏览文件 @
f07f57a3
...
@@ -26,9 +26,8 @@ def get_audios(path):
...
@@ -26,9 +26,8 @@ def get_audios(path):
"""
"""
supported_formats
=
[
".wav"
,
".mp3"
,
".ogg"
,
".flac"
,
".m4a"
]
supported_formats
=
[
".wav"
,
".mp3"
,
".ogg"
,
".flac"
,
".m4a"
]
return
[
return
[
item
item
for
sublist
in
[[
os
.
path
.
join
(
dir
,
file
)
for
file
in
files
]
for
sublist
in
[[
os
.
path
.
join
(
dir
,
file
)
for
file
in
files
]
for
dir
,
_
,
files
in
list
(
os
.
walk
(
path
))]
for
dir
,
_
,
files
in
list
(
os
.
walk
(
path
))]
for
item
in
sublist
if
os
.
path
.
splitext
(
item
)[
1
]
in
supported_formats
for
item
in
sublist
if
os
.
path
.
splitext
(
item
)[
1
]
in
supported_formats
]
]
...
...
demos/streaming_asr_server/websocket_client.py
浏览文件 @
f07f57a3
...
@@ -13,9 +13,7 @@
...
@@ -13,9 +13,7 @@
# limitations under the License.
# limitations under the License.
#!/usr/bin/python
#!/usr/bin/python
# -*- coding: UTF-8 -*-
# -*- coding: UTF-8 -*-
# script for calc RTF: grep -rn RTF log.txt | awk '{print $NF}' | awk -F "=" '{sum += $NF} END {print "all time",sum, "audio num", NR, "RTF", sum/NR}'
# script for calc RTF: grep -rn RTF log.txt | awk '{print $NF}' | awk -F "=" '{sum += $NF} END {print "all time",sum, "audio num", NR, "RTF", sum/NR}'
import
argparse
import
argparse
import
asyncio
import
asyncio
import
codecs
import
codecs
...
...
docs/source/asr/PPASR_cn.md
浏览文件 @
f07f57a3
...
@@ -92,5 +92,3 @@ server 的 demo: [streaming_asr_server](https://github.com/PaddlePaddle/Paddle
...
@@ -92,5 +92,3 @@ server 的 demo: [streaming_asr_server](https://github.com/PaddlePaddle/Paddle
## 4. 快速开始
## 4. 快速开始
关于如果使用 PP-ASR,可以看这里的
[
install
](
https://github.com/PaddlePaddle/PaddleSpeech/blob/develop/docs/source/install_cn.md
)
,其中提供了
**简单**
、
**中等**
、
**困难**
三种安装方式。如果想体验 paddlespeech 的推理功能,可以用
**简单**
安装方式。
关于如果使用 PP-ASR,可以看这里的
[
install
](
https://github.com/PaddlePaddle/PaddleSpeech/blob/develop/docs/source/install_cn.md
)
,其中提供了
**简单**
、
**中等**
、
**困难**
三种安装方式。如果想体验 paddlespeech 的推理功能,可以用
**简单**
安装方式。
paddlespeech/cli/utils.py
浏览文件 @
f07f57a3
...
@@ -24,11 +24,11 @@ from typing import Any
...
@@ -24,11 +24,11 @@ from typing import Any
from
typing
import
Dict
from
typing
import
Dict
import
paddle
import
paddle
import
paddleaudio
import
requests
import
requests
import
yaml
import
yaml
from
paddle.framework
import
load
from
paddle.framework
import
load
import
paddleaudio
from
.
import
download
from
.
import
download
from
.entry
import
commands
from
.entry
import
commands
try
:
try
:
...
...
paddlespeech/s2t/io/sampler.py
浏览文件 @
f07f57a3
...
@@ -51,7 +51,7 @@ def _batch_shuffle(indices, batch_size, epoch, clipped=False):
...
@@ -51,7 +51,7 @@ def _batch_shuffle(indices, batch_size, epoch, clipped=False):
"""
"""
rng
=
np
.
random
.
RandomState
(
epoch
)
rng
=
np
.
random
.
RandomState
(
epoch
)
shift_len
=
rng
.
randint
(
0
,
batch_size
-
1
)
shift_len
=
rng
.
randint
(
0
,
batch_size
-
1
)
batch_indices
=
list
(
zip
(
*
[
iter
(
indices
[
shift_len
:])]
*
batch_size
))
batch_indices
=
list
(
zip
(
*
[
iter
(
indices
[
shift_len
:])]
*
batch_size
))
rng
.
shuffle
(
batch_indices
)
rng
.
shuffle
(
batch_indices
)
batch_indices
=
[
item
for
batch
in
batch_indices
for
item
in
batch
]
batch_indices
=
[
item
for
batch
in
batch_indices
for
item
in
batch
]
assert
clipped
is
False
assert
clipped
is
False
...
...
paddlespeech/server/engine/acs/__init__.py
浏览文件 @
f07f57a3
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
paddlespeech/server/engine/acs/python/__init__.py
浏览文件 @
f07f57a3
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
paddlespeech/server/engine/asr/online/asr_engine.py
浏览文件 @
f07f57a3
...
@@ -38,7 +38,7 @@ from paddlespeech.server.engine.base_engine import BaseEngine
...
@@ -38,7 +38,7 @@ from paddlespeech.server.engine.base_engine import BaseEngine
from
paddlespeech.server.utils.audio_process
import
pcm2float
from
paddlespeech.server.utils.audio_process
import
pcm2float
from
paddlespeech.server.utils.paddle_predictor
import
init_predictor
from
paddlespeech.server.utils.paddle_predictor
import
init_predictor
__all__
=
[
'ASREngine'
]
__all__
=
[
'
PaddleASRConnectionHanddler'
,
'ASRServerExecutor'
,
'
ASREngine'
]
# ASR server connection process class
# ASR server connection process class
...
@@ -67,7 +67,7 @@ class PaddleASRConnectionHanddler:
...
@@ -67,7 +67,7 @@ class PaddleASRConnectionHanddler:
# tokens to text
# tokens to text
self
.
text_feature
=
self
.
asr_engine
.
executor
.
text_feature
self
.
text_feature
=
self
.
asr_engine
.
executor
.
text_feature
if
"deepspeech2
online"
in
self
.
model_type
or
"deepspeech2offline
"
in
self
.
model_type
:
if
"deepspeech2"
in
self
.
model_type
:
from
paddlespeech.s2t.io.collator
import
SpeechCollator
from
paddlespeech.s2t.io.collator
import
SpeechCollator
self
.
am_predictor
=
self
.
asr_engine
.
executor
.
am_predictor
self
.
am_predictor
=
self
.
asr_engine
.
executor
.
am_predictor
...
@@ -89,8 +89,8 @@ class PaddleASRConnectionHanddler:
...
@@ -89,8 +89,8 @@ class PaddleASRConnectionHanddler:
cfg
.
decoding_method
,
cfg
.
lang_model_path
,
cfg
.
alpha
,
cfg
.
beta
,
cfg
.
decoding_method
,
cfg
.
lang_model_path
,
cfg
.
alpha
,
cfg
.
beta
,
cfg
.
beam_size
,
cfg
.
cutoff_prob
,
cfg
.
cutoff_top_n
,
cfg
.
beam_size
,
cfg
.
cutoff_prob
,
cfg
.
cutoff_top_n
,
cfg
.
num_proc_bsearch
)
cfg
.
num_proc_bsearch
)
# frame window samples length and frame shift samples length
# frame window and frame shift, in samples unit
self
.
win_length
=
int
(
self
.
model_config
.
window_ms
/
1000
*
self
.
win_length
=
int
(
self
.
model_config
.
window_ms
/
1000
*
self
.
sample_rate
)
self
.
sample_rate
)
self
.
n_shift
=
int
(
self
.
model_config
.
stride_ms
/
1000
*
self
.
n_shift
=
int
(
self
.
model_config
.
stride_ms
/
1000
*
...
@@ -109,16 +109,15 @@ class PaddleASRConnectionHanddler:
...
@@ -109,16 +109,15 @@ class PaddleASRConnectionHanddler:
self
.
preprocess_args
=
{
"train"
:
False
}
self
.
preprocess_args
=
{
"train"
:
False
}
self
.
preprocessing
=
Transformation
(
self
.
preprocess_conf
)
self
.
preprocessing
=
Transformation
(
self
.
preprocess_conf
)
# frame window
samples length and frame shift samples length
# frame window
and frame shift, in samples unit
self
.
win_length
=
self
.
preprocess_conf
.
process
[
0
][
'win_length'
]
self
.
win_length
=
self
.
preprocess_conf
.
process
[
0
][
'win_length'
]
self
.
n_shift
=
self
.
preprocess_conf
.
process
[
0
][
'n_shift'
]
self
.
n_shift
=
self
.
preprocess_conf
.
process
[
0
][
'n_shift'
]
else
:
raise
ValueError
(
f
"Not supported:
{
self
.
model_type
}
"
)
def
extract_feat
(
self
,
samples
):
def
extract_feat
(
self
,
samples
):
# we compute the elapsed time of first char occuring
# we compute the elapsed time of first char occuring
# and we record the start time at the first pcm sample arraving
# and we record the start time at the first pcm sample arraving
# if self.first_char_occur_elapsed is not None:
# self.first_char_occur_elapsed = time.time()
if
"deepspeech2online"
in
self
.
model_type
:
if
"deepspeech2online"
in
self
.
model_type
:
# self.reamined_wav stores all the samples,
# self.reamined_wav stores all the samples,
...
@@ -154,28 +153,27 @@ class PaddleASRConnectionHanddler:
...
@@ -154,28 +153,27 @@ class PaddleASRConnectionHanddler:
spectrum
=
self
.
collate_fn_test
.
_normalizer
.
apply
(
spectrum
)
spectrum
=
self
.
collate_fn_test
.
_normalizer
.
apply
(
spectrum
)
# spectrum augment
# spectrum augment
audio
=
self
.
collate_fn_test
.
augmentation
.
transform_feature
(
feat
=
self
.
collate_fn_test
.
augmentation
.
transform_feature
(
spectrum
)
spectrum
)
audio_len
=
audio
.
shape
[
0
]
# audio_len is frame num
audio
=
paddle
.
to_tensor
(
audio
,
dtype
=
'float32'
)
frame_num
=
feat
.
shape
[
0
]
# audio_len = paddle.to_tensor(audio_len
)
feat
=
paddle
.
to_tensor
(
feat
,
dtype
=
'float32'
)
audio
=
paddle
.
unsqueeze
(
audio
,
axis
=
0
)
feat
=
paddle
.
unsqueeze
(
feat
,
axis
=
0
)
if
self
.
cached_feat
is
None
:
if
self
.
cached_feat
is
None
:
self
.
cached_feat
=
audio
self
.
cached_feat
=
feat
else
:
else
:
assert
(
len
(
audio
.
shape
)
==
3
)
assert
(
len
(
feat
.
shape
)
==
3
)
assert
(
len
(
self
.
cached_feat
.
shape
)
==
3
)
assert
(
len
(
self
.
cached_feat
.
shape
)
==
3
)
self
.
cached_feat
=
paddle
.
concat
(
self
.
cached_feat
=
paddle
.
concat
(
[
self
.
cached_feat
,
audio
],
axis
=
1
)
[
self
.
cached_feat
,
feat
],
axis
=
1
)
# set the feat device
# set the feat device
if
self
.
device
is
None
:
if
self
.
device
is
None
:
self
.
device
=
self
.
cached_feat
.
place
self
.
device
=
self
.
cached_feat
.
place
self
.
num_frames
+=
audio_len
self
.
num_frames
+=
frame_num
self
.
remained_wav
=
self
.
remained_wav
[
self
.
n_shift
*
audio_len
:]
self
.
remained_wav
=
self
.
remained_wav
[
self
.
n_shift
*
frame_num
:]
logger
.
info
(
logger
.
info
(
f
"process the audio feature success, the connection feat shape:
{
self
.
cached_feat
.
shape
}
"
f
"process the audio feature success, the connection feat shape:
{
self
.
cached_feat
.
shape
}
"
...
@@ -183,25 +181,30 @@ class PaddleASRConnectionHanddler:
...
@@ -183,25 +181,30 @@ class PaddleASRConnectionHanddler:
logger
.
info
(
logger
.
info
(
f
"After extract feat, the connection remain the audio samples:
{
self
.
remained_wav
.
shape
}
"
f
"After extract feat, the connection remain the audio samples:
{
self
.
remained_wav
.
shape
}
"
)
)
elif
"conformer_online"
in
self
.
model_type
:
elif
"conformer_online"
in
self
.
model_type
:
logger
.
info
(
"Online ASR extract the feat"
)
logger
.
info
(
"Online ASR extract the feat"
)
samples
=
np
.
frombuffer
(
samples
,
dtype
=
np
.
int16
)
samples
=
np
.
frombuffer
(
samples
,
dtype
=
np
.
int16
)
assert
samples
.
ndim
==
1
assert
samples
.
ndim
==
1
logger
.
info
(
f
"This package receive
{
samples
.
shape
[
0
]
}
pcm data"
)
self
.
num_samples
+=
samples
.
shape
[
0
]
self
.
num_samples
+=
samples
.
shape
[
0
]
logger
.
info
(
f
"This package receive
{
samples
.
shape
[
0
]
}
pcm data. Global samples:
{
self
.
num_samples
}
"
)
# self.reamined_wav stores all the samples,
# self.reamined_wav stores all the samples,
# include the original remained_wav and this package samples
# include the original remained_wav and this package samples
if
self
.
remained_wav
is
None
:
if
self
.
remained_wav
is
None
:
self
.
remained_wav
=
samples
self
.
remained_wav
=
samples
else
:
else
:
assert
self
.
remained_wav
.
ndim
==
1
assert
self
.
remained_wav
.
ndim
==
1
# (T,)
self
.
remained_wav
=
np
.
concatenate
([
self
.
remained_wav
,
samples
])
self
.
remained_wav
=
np
.
concatenate
([
self
.
remained_wav
,
samples
])
logger
.
info
(
logger
.
info
(
f
"The con
nection remain the audio sample
s:
{
self
.
remained_wav
.
shape
}
"
f
"The con
catenation of remain and now audio samples length i
s:
{
self
.
remained_wav
.
shape
}
"
)
)
if
len
(
self
.
remained_wav
)
<
self
.
win_length
:
if
len
(
self
.
remained_wav
)
<
self
.
win_length
:
# samples not enough for feature window
return
0
return
0
# fbank
# fbank
...
@@ -209,11 +212,13 @@ class PaddleASRConnectionHanddler:
...
@@ -209,11 +212,13 @@ class PaddleASRConnectionHanddler:
**
self
.
preprocess_args
)
**
self
.
preprocess_args
)
x_chunk
=
paddle
.
to_tensor
(
x_chunk
=
paddle
.
to_tensor
(
x_chunk
,
dtype
=
"float32"
).
unsqueeze
(
axis
=
0
)
x_chunk
,
dtype
=
"float32"
).
unsqueeze
(
axis
=
0
)
# feature cache
if
self
.
cached_feat
is
None
:
if
self
.
cached_feat
is
None
:
self
.
cached_feat
=
x_chunk
self
.
cached_feat
=
x_chunk
else
:
else
:
assert
(
len
(
x_chunk
.
shape
)
==
3
)
assert
(
len
(
x_chunk
.
shape
)
==
3
)
# (B,T,D)
assert
(
len
(
self
.
cached_feat
.
shape
)
==
3
)
assert
(
len
(
self
.
cached_feat
.
shape
)
==
3
)
# (B,T,D)
self
.
cached_feat
=
paddle
.
concat
(
self
.
cached_feat
=
paddle
.
concat
(
[
self
.
cached_feat
,
x_chunk
],
axis
=
1
)
[
self
.
cached_feat
,
x_chunk
],
axis
=
1
)
...
@@ -221,20 +226,28 @@ class PaddleASRConnectionHanddler:
...
@@ -221,20 +226,28 @@ class PaddleASRConnectionHanddler:
if
self
.
device
is
None
:
if
self
.
device
is
None
:
self
.
device
=
self
.
cached_feat
.
place
self
.
device
=
self
.
cached_feat
.
place
# cur frame step
num_frames
=
x_chunk
.
shape
[
1
]
num_frames
=
x_chunk
.
shape
[
1
]
# global frame step
self
.
num_frames
+=
num_frames
self
.
num_frames
+=
num_frames
# update remained wav
self
.
remained_wav
=
self
.
remained_wav
[
self
.
n_shift
*
num_frames
:]
self
.
remained_wav
=
self
.
remained_wav
[
self
.
n_shift
*
num_frames
:]
logger
.
info
(
logger
.
info
(
f
"process the audio feature success, the c
onnection
feat shape:
{
self
.
cached_feat
.
shape
}
"
f
"process the audio feature success, the c
ached
feat shape:
{
self
.
cached_feat
.
shape
}
"
)
)
logger
.
info
(
logger
.
info
(
f
"After extract feat, the c
onnection
remain the audio samples:
{
self
.
remained_wav
.
shape
}
"
f
"After extract feat, the c
ached
remain the audio samples:
{
self
.
remained_wav
.
shape
}
"
)
)
# logger.info(f"accumulate samples: {self.num_samples}")
logger
.
info
(
f
"global samples:
{
self
.
num_samples
}
"
)
logger
.
info
(
f
"global frames:
{
self
.
num_frames
}
"
)
else
:
raise
ValueError
(
f
"not supported:
{
self
.
model_type
}
"
)
def
reset
(
self
):
def
reset
(
self
):
if
"deepspeech2
online"
in
self
.
model_type
or
"deepspeech2offline
"
in
self
.
model_type
:
if
"deepspeech2"
in
self
.
model_type
:
# for deepspeech2
# for deepspeech2
self
.
chunk_state_h_box
=
copy
.
deepcopy
(
self
.
chunk_state_h_box
=
copy
.
deepcopy
(
self
.
asr_engine
.
executor
.
chunk_state_h_box
)
self
.
asr_engine
.
executor
.
chunk_state_h_box
)
...
@@ -242,35 +255,61 @@ class PaddleASRConnectionHanddler:
...
@@ -242,35 +255,61 @@ class PaddleASRConnectionHanddler:
self
.
asr_engine
.
executor
.
chunk_state_c_box
)
self
.
asr_engine
.
executor
.
chunk_state_c_box
)
self
.
decoder
.
reset_decoder
(
batch_size
=
1
)
self
.
decoder
.
reset_decoder
(
batch_size
=
1
)
# for conformer online
self
.
device
=
None
## common
# global sample and frame step
self
.
num_samples
=
0
self
.
num_frames
=
0
# cache for audio and feat
self
.
remained_wav
=
None
self
.
cached_feat
=
None
# partial/ending decoding results
self
.
result_transcripts
=
[
''
]
## conformer
# cache for conformer online
self
.
subsampling_cache
=
None
self
.
subsampling_cache
=
None
self
.
elayers_output_cache
=
None
self
.
elayers_output_cache
=
None
self
.
conformer_cnn_cache
=
None
self
.
conformer_cnn_cache
=
None
self
.
encoder_out
=
None
self
.
encoder_out
=
None
self
.
cached_feat
=
None
# conformer decoding state
self
.
remained_wav
=
None
self
.
chunk_num
=
0
# globa decoding chunk num
self
.
offset
=
0
self
.
offset
=
0
# global offset in decoding frame unit
self
.
num_samples
=
0
self
.
device
=
None
self
.
hyps
=
[]
self
.
hyps
=
[]
self
.
num_frames
=
0
self
.
chunk_num
=
0
# token timestamp result
self
.
global_frame_offset
=
0
self
.
result_transcripts
=
[
''
]
self
.
word_time_stamp
=
[]
self
.
word_time_stamp
=
[]
# one best timestamp viterbi prob is large.
self
.
time_stamp
=
[]
self
.
time_stamp
=
[]
self
.
first_char_occur_elapsed
=
None
def
decode
(
self
,
is_finished
=
False
):
def
decode
(
self
,
is_finished
=
False
):
"""advance decoding
Args:
is_finished (bool, optional): Is last frame or not. Defaults to False.
Raises:
Exception: when not support model.
Returns:
None: nothing
"""
if
"deepspeech2online"
in
self
.
model_type
:
if
"deepspeech2online"
in
self
.
model_type
:
# x_chunk 是特征数据
decoding_chunk_size
=
1
# decoding chunk size = 1. int decoding frame unit
decoding_chunk_size
=
1
# decoding_chunk_size=1 in deepspeech2 model
context
=
7
# context=7, in audio frame unit
context
=
7
# context=7 in deepspeech2 model
subsampling
=
4
# subsampling=4, in audio frame unit
subsampling
=
4
# subsampling=4 in deepspeech2 model
stride
=
subsampling
*
decoding_chunk_size
cached_feature_num
=
context
-
subsampling
cached_feature_num
=
context
-
subsampling
# decoding window for model
# decoding window for model
, in audio frame unit
decoding_window
=
(
decoding_chunk_size
-
1
)
*
subsampling
+
context
decoding_window
=
(
decoding_chunk_size
-
1
)
*
subsampling
+
context
# decoding stride for model, in audio frame unit
stride
=
subsampling
*
decoding_chunk_size
if
self
.
cached_feat
is
None
:
if
self
.
cached_feat
is
None
:
logger
.
info
(
"no audio feat, please input more pcm data"
)
logger
.
info
(
"no audio feat, please input more pcm data"
)
...
@@ -280,6 +319,7 @@ class PaddleASRConnectionHanddler:
...
@@ -280,6 +319,7 @@ class PaddleASRConnectionHanddler:
logger
.
info
(
logger
.
info
(
f
"Required decoding window
{
decoding_window
}
frames, and the connection has
{
num_frames
}
frames"
f
"Required decoding window
{
decoding_window
}
frames, and the connection has
{
num_frames
}
frames"
)
)
# the cached feat must be larger decoding_window
# the cached feat must be larger decoding_window
if
num_frames
<
decoding_window
and
not
is_finished
:
if
num_frames
<
decoding_window
and
not
is_finished
:
logger
.
info
(
logger
.
info
(
...
@@ -293,6 +333,7 @@ class PaddleASRConnectionHanddler:
...
@@ -293,6 +333,7 @@ class PaddleASRConnectionHanddler:
"flast {num_frames} is less than context {context} frames, and we cannot do model forward"
"flast {num_frames} is less than context {context} frames, and we cannot do model forward"
)
)
return
None
,
None
return
None
,
None
logger
.
info
(
"start to do model forward"
)
logger
.
info
(
"start to do model forward"
)
# num_frames - context + 1 ensure that current frame can get context window
# num_frames - context + 1 ensure that current frame can get context window
if
is_finished
:
if
is_finished
:
...
@@ -302,6 +343,7 @@ class PaddleASRConnectionHanddler:
...
@@ -302,6 +343,7 @@ class PaddleASRConnectionHanddler:
# we only process decoding_window frames for one chunk
# we only process decoding_window frames for one chunk
left_frames
=
decoding_window
left_frames
=
decoding_window
end
=
None
for
cur
in
range
(
0
,
num_frames
-
left_frames
+
1
,
stride
):
for
cur
in
range
(
0
,
num_frames
-
left_frames
+
1
,
stride
):
end
=
min
(
cur
+
decoding_window
,
num_frames
)
end
=
min
(
cur
+
decoding_window
,
num_frames
)
# extract the audio
# extract the audio
...
@@ -311,7 +353,9 @@ class PaddleASRConnectionHanddler:
...
@@ -311,7 +353,9 @@ class PaddleASRConnectionHanddler:
self
.
result_transcripts
=
[
trans_best
]
self
.
result_transcripts
=
[
trans_best
]
# update feat cache
self
.
cached_feat
=
self
.
cached_feat
[:,
end
-
cached_feature_num
:,
:]
self
.
cached_feat
=
self
.
cached_feat
[:,
end
-
cached_feature_num
:,
:]
# return trans_best[0]
# return trans_best[0]
elif
"conformer"
in
self
.
model_type
or
"transformer"
in
self
.
model_type
:
elif
"conformer"
in
self
.
model_type
or
"transformer"
in
self
.
model_type
:
try
:
try
:
...
@@ -328,7 +372,16 @@ class PaddleASRConnectionHanddler:
...
@@ -328,7 +372,16 @@ class PaddleASRConnectionHanddler:
@
paddle
.
no_grad
()
@
paddle
.
no_grad
()
def
decode_one_chunk
(
self
,
x_chunk
,
x_chunk_lens
):
def
decode_one_chunk
(
self
,
x_chunk
,
x_chunk_lens
):
logger
.
info
(
"start to decoce one chunk with deepspeech2 model"
)
"""forward one chunk frames
Args:
x_chunk (np.ndarray): (B,T,D), audio frames.
x_chunk_lens ([type]): (B,), audio frame lens
Returns:
logprob: poster probability.
"""
logger
.
info
(
"start to decoce one chunk for deepspeech2"
)
input_names
=
self
.
am_predictor
.
get_input_names
()
input_names
=
self
.
am_predictor
.
get_input_names
()
audio_handle
=
self
.
am_predictor
.
get_input_handle
(
input_names
[
0
])
audio_handle
=
self
.
am_predictor
.
get_input_handle
(
input_names
[
0
])
audio_len_handle
=
self
.
am_predictor
.
get_input_handle
(
input_names
[
1
])
audio_len_handle
=
self
.
am_predictor
.
get_input_handle
(
input_names
[
1
])
...
@@ -365,24 +418,32 @@ class PaddleASRConnectionHanddler:
...
@@ -365,24 +418,32 @@ class PaddleASRConnectionHanddler:
self
.
decoder
.
next
(
output_chunk_probs
,
output_chunk_lens
)
self
.
decoder
.
next
(
output_chunk_probs
,
output_chunk_lens
)
trans_best
,
trans_beam
=
self
.
decoder
.
decode
()
trans_best
,
trans_beam
=
self
.
decoder
.
decode
()
logger
.
info
(
f
"decode one best result:
{
trans_best
[
0
]
}
"
)
logger
.
info
(
f
"decode one best result
for deepspeech2
:
{
trans_best
[
0
]
}
"
)
return
trans_best
[
0
]
return
trans_best
[
0
]
@
paddle
.
no_grad
()
@
paddle
.
no_grad
()
def
advance_decoding
(
self
,
is_finished
=
False
):
def
advance_decoding
(
self
,
is_finished
=
False
):
logger
.
info
(
"start to decode with advanced_decoding method"
)
logger
.
info
(
"Conformer/Transformer: start to decode with advanced_decoding method"
)
cfg
=
self
.
ctc_decode_config
cfg
=
self
.
ctc_decode_config
# cur chunk size, in decoding frame unit
decoding_chunk_size
=
cfg
.
decoding_chunk_size
decoding_chunk_size
=
cfg
.
decoding_chunk_size
# using num of history chunks
num_decoding_left_chunks
=
cfg
.
num_decoding_left_chunks
num_decoding_left_chunks
=
cfg
.
num_decoding_left_chunks
assert
decoding_chunk_size
>
0
assert
decoding_chunk_size
>
0
subsampling
=
self
.
model
.
encoder
.
embed
.
subsampling_rate
subsampling
=
self
.
model
.
encoder
.
embed
.
subsampling_rate
context
=
self
.
model
.
encoder
.
embed
.
right_context
+
1
context
=
self
.
model
.
encoder
.
embed
.
right_context
+
1
stride
=
subsampling
*
decoding_chunk_size
cached_feature_num
=
context
-
subsampling
# processed chunk feature cached for next chunk
# decoding window for model
# processed chunk feature cached for next chunk
cached_feature_num
=
context
-
subsampling
# decoding stride, in audio frame unit
stride
=
subsampling
*
decoding_chunk_size
# decoding window, in audio frame unit
decoding_window
=
(
decoding_chunk_size
-
1
)
*
subsampling
+
context
decoding_window
=
(
decoding_chunk_size
-
1
)
*
subsampling
+
context
if
self
.
cached_feat
is
None
:
if
self
.
cached_feat
is
None
:
logger
.
info
(
"no audio feat, please input more pcm data"
)
logger
.
info
(
"no audio feat, please input more pcm data"
)
return
return
...
@@ -407,6 +468,7 @@ class PaddleASRConnectionHanddler:
...
@@ -407,6 +468,7 @@ class PaddleASRConnectionHanddler:
return
None
,
None
return
None
,
None
logger
.
info
(
"start to do model forward"
)
logger
.
info
(
"start to do model forward"
)
# hist of chunks, in deocding frame unit
required_cache_size
=
decoding_chunk_size
*
num_decoding_left_chunks
required_cache_size
=
decoding_chunk_size
*
num_decoding_left_chunks
outputs
=
[]
outputs
=
[]
...
@@ -423,8 +485,11 @@ class PaddleASRConnectionHanddler:
...
@@ -423,8 +485,11 @@ class PaddleASRConnectionHanddler:
for
cur
in
range
(
0
,
num_frames
-
left_frames
+
1
,
stride
):
for
cur
in
range
(
0
,
num_frames
-
left_frames
+
1
,
stride
):
end
=
min
(
cur
+
decoding_window
,
num_frames
)
end
=
min
(
cur
+
decoding_window
,
num_frames
)
# global chunk_num
self
.
chunk_num
+=
1
self
.
chunk_num
+=
1
# cur chunk
chunk_xs
=
self
.
cached_feat
[:,
cur
:
end
,
:]
chunk_xs
=
self
.
cached_feat
[:,
cur
:
end
,
:]
# forward chunk
(
y
,
self
.
subsampling_cache
,
self
.
elayers_output_cache
,
(
y
,
self
.
subsampling_cache
,
self
.
elayers_output_cache
,
self
.
conformer_cnn_cache
)
=
self
.
model
.
encoder
.
forward_chunk
(
self
.
conformer_cnn_cache
)
=
self
.
model
.
encoder
.
forward_chunk
(
chunk_xs
,
self
.
offset
,
required_cache_size
,
chunk_xs
,
self
.
offset
,
required_cache_size
,
...
@@ -432,7 +497,7 @@ class PaddleASRConnectionHanddler:
...
@@ -432,7 +497,7 @@ class PaddleASRConnectionHanddler:
self
.
conformer_cnn_cache
)
self
.
conformer_cnn_cache
)
outputs
.
append
(
y
)
outputs
.
append
(
y
)
# update the
offse
t
# update the
global offset, in decoding frame uni
t
self
.
offset
+=
y
.
shape
[
1
]
self
.
offset
+=
y
.
shape
[
1
]
ys
=
paddle
.
cat
(
outputs
,
1
)
ys
=
paddle
.
cat
(
outputs
,
1
)
...
@@ -445,12 +510,15 @@ class PaddleASRConnectionHanddler:
...
@@ -445,12 +510,15 @@ class PaddleASRConnectionHanddler:
ctc_probs
=
self
.
model
.
ctc
.
log_softmax
(
ys
)
# (1, maxlen, vocab_size)
ctc_probs
=
self
.
model
.
ctc
.
log_softmax
(
ys
)
# (1, maxlen, vocab_size)
ctc_probs
=
ctc_probs
.
squeeze
(
0
)
ctc_probs
=
ctc_probs
.
squeeze
(
0
)
# advance decoding
self
.
searcher
.
search
(
ctc_probs
,
self
.
cached_feat
.
place
)
self
.
searcher
.
search
(
ctc_probs
,
self
.
cached_feat
.
place
)
# get one best hyps
self
.
hyps
=
self
.
searcher
.
get_one_best_hyps
()
self
.
hyps
=
self
.
searcher
.
get_one_best_hyps
()
assert
self
.
cached_feat
.
shape
[
0
]
==
1
assert
self
.
cached_feat
.
shape
[
0
]
==
1
assert
end
>=
cached_feature_num
assert
end
>=
cached_feature_num
# advance cache of feat
self
.
cached_feat
=
self
.
cached_feat
[
0
,
end
-
self
.
cached_feat
=
self
.
cached_feat
[
0
,
end
-
cached_feature_num
:,
:].
unsqueeze
(
0
)
cached_feature_num
:,
:].
unsqueeze
(
0
)
assert
len
(
assert
len
(
...
@@ -462,50 +530,81 @@ class PaddleASRConnectionHanddler:
...
@@ -462,50 +530,81 @@ class PaddleASRConnectionHanddler:
)
)
def
update_result
(
self
):
def
update_result
(
self
):
"""Conformer/Transformer hyps to result.
"""
logger
.
info
(
"update the final result"
)
logger
.
info
(
"update the final result"
)
hyps
=
self
.
hyps
hyps
=
self
.
hyps
# output results and tokenids
self
.
result_transcripts
=
[
self
.
result_transcripts
=
[
self
.
text_feature
.
defeaturize
(
hyp
)
for
hyp
in
hyps
self
.
text_feature
.
defeaturize
(
hyp
)
for
hyp
in
hyps
]
]
self
.
result_tokenids
=
[
hyp
for
hyp
in
hyps
]
self
.
result_tokenids
=
[
hyp
for
hyp
in
hyps
]
def
get_result
(
self
):
def
get_result
(
self
):
"""return partial/ending asr result.
Returns:
str: one best result of partial/ending.
"""
if
len
(
self
.
result_transcripts
)
>
0
:
if
len
(
self
.
result_transcripts
)
>
0
:
return
self
.
result_transcripts
[
0
]
return
self
.
result_transcripts
[
0
]
else
:
else
:
return
''
return
''
def
get_word_time_stamp
(
self
):
def
get_word_time_stamp
(
self
):
"""return token timestamp result.
Returns:
list: List of ('w':token, 'bg':time, 'ed':time)
"""
return
self
.
word_time_stamp
return
self
.
word_time_stamp
@
paddle
.
no_grad
()
@
paddle
.
no_grad
()
def
rescoring
(
self
):
def
rescoring
(
self
):
if
"deepspeech2online"
in
self
.
model_type
or
"deepspeech2offline"
in
self
.
model_type
:
"""Second-Pass Decoding,
only for conformer and transformer model.
"""
if
"deepspeech2"
in
self
.
model_type
:
logger
.
info
(
"deepspeech2 not support rescoring decoding."
)
return
return
logger
.
info
(
"rescoring the final result"
)
if
"attention_rescoring"
!=
self
.
ctc_decode_config
.
decoding_method
:
if
"attention_rescoring"
!=
self
.
ctc_decode_config
.
decoding_method
:
logger
.
info
(
f
"decoding method not match:
{
self
.
ctc_decode_config
.
decoding_method
}
, need attention_rescoring"
)
return
return
logger
.
info
(
"rescoring the final result"
)
# last decoding for last audio
self
.
searcher
.
finalize_search
()
self
.
searcher
.
finalize_search
()
# update beam search results
self
.
update_result
()
self
.
update_result
()
beam_size
=
self
.
ctc_decode_config
.
beam_size
beam_size
=
self
.
ctc_decode_config
.
beam_size
hyps
=
self
.
searcher
.
get_hyps
()
hyps
=
self
.
searcher
.
get_hyps
()
if
hyps
is
None
or
len
(
hyps
)
==
0
:
if
hyps
is
None
or
len
(
hyps
)
==
0
:
logger
.
info
(
"No Hyps!"
)
return
return
# rescore by decoder post probability
# assert len(hyps) == beam_size
# assert len(hyps) == beam_size
# list of Tensor
hyp_list
=
[]
hyp_list
=
[]
for
hyp
in
hyps
:
for
hyp
in
hyps
:
hyp_content
=
hyp
[
0
]
hyp_content
=
hyp
[
0
]
# Prevent the hyp is empty
# Prevent the hyp is empty
if
len
(
hyp_content
)
==
0
:
if
len
(
hyp_content
)
==
0
:
hyp_content
=
(
self
.
model
.
ctc
.
blank_id
,
)
hyp_content
=
(
self
.
model
.
ctc
.
blank_id
,
)
hyp_content
=
paddle
.
to_tensor
(
hyp_content
=
paddle
.
to_tensor
(
hyp_content
,
place
=
self
.
device
,
dtype
=
paddle
.
long
)
hyp_content
,
place
=
self
.
device
,
dtype
=
paddle
.
long
)
hyp_list
.
append
(
hyp_content
)
hyp_list
.
append
(
hyp_content
)
hyps_pad
=
pad_sequence
(
hyp_list
,
True
,
self
.
model
.
ignore_id
)
hyps_pad
=
pad_sequence
(
hyp_list
,
batch_first
=
True
,
padding_value
=
self
.
model
.
ignore_id
)
hyps_lens
=
paddle
.
to_tensor
(
hyps_lens
=
paddle
.
to_tensor
(
[
len
(
hyp
[
0
])
for
hyp
in
hyps
],
place
=
self
.
device
,
[
len
(
hyp
[
0
])
for
hyp
in
hyps
],
place
=
self
.
device
,
dtype
=
paddle
.
long
)
# (beam_size,)
dtype
=
paddle
.
long
)
# (beam_size,)
...
@@ -531,10 +630,12 @@ class PaddleASRConnectionHanddler:
...
@@ -531,10 +630,12 @@ class PaddleASRConnectionHanddler:
score
=
0.0
score
=
0.0
for
j
,
w
in
enumerate
(
hyp
[
0
]):
for
j
,
w
in
enumerate
(
hyp
[
0
]):
score
+=
decoder_out
[
i
][
j
][
w
]
score
+=
decoder_out
[
i
][
j
][
w
]
# last decoder output token is `eos`, for laste decoder input token.
# last decoder output token is `eos`, for laste decoder input token.
score
+=
decoder_out
[
i
][
len
(
hyp
[
0
])][
self
.
model
.
eos
]
score
+=
decoder_out
[
i
][
len
(
hyp
[
0
])][
self
.
model
.
eos
]
# add ctc score (which in ln domain)
# add ctc score (which in ln domain)
score
+=
hyp
[
1
]
*
self
.
ctc_decode_config
.
ctc_weight
score
+=
hyp
[
1
]
*
self
.
ctc_decode_config
.
ctc_weight
if
score
>
best_score
:
if
score
>
best_score
:
best_score
=
score
best_score
=
score
best_index
=
i
best_index
=
i
...
@@ -542,43 +643,52 @@ class PaddleASRConnectionHanddler:
...
@@ -542,43 +643,52 @@ class PaddleASRConnectionHanddler:
# update the one best result
# update the one best result
# hyps stored the beam results and each fields is:
# hyps stored the beam results and each fields is:
logger
.
info
(
f
"best index:
{
best_index
}
"
)
logger
.
info
(
f
"best
hyp
index:
{
best_index
}
"
)
# logger.info(f'best result: {hyps[best_index]}')
# logger.info(f'best result: {hyps[best_index]}')
# the field of the hyps is:
# the field of the hyps is:
## asr results
# hyps[0][0]: the sentence word-id in the vocab with a tuple
# hyps[0][0]: the sentence word-id in the vocab with a tuple
# hyps[0][1]: the sentence decoding probability with all paths
# hyps[0][1]: the sentence decoding probability with all paths
## timestamp
# hyps[0][2]: viterbi_blank ending probability
# hyps[0][2]: viterbi_blank ending probability
# hyps[0][3]: viterbi_non_blank probability
# hyps[0][3]: viterbi_non_blank
dending
probability
# hyps[0][4]: current_token_prob,
# hyps[0][4]: current_token_prob,
# hyps[0][5]: times_viterbi_blank,
# hyps[0][5]: times_viterbi_blank
ending timestamp
,
# hyps[0][6]: times_titerbi_non_blank
# hyps[0][6]: times_titerbi_non_blank
encding timestamp.
self
.
hyps
=
[
hyps
[
best_index
][
0
]]
self
.
hyps
=
[
hyps
[
best_index
][
0
]]
logger
.
info
(
f
"best hyp ids:
{
self
.
hyps
}
"
)
# update the hyps time stamp
# update the hyps time stamp
self
.
time_stamp
=
hyps
[
best_index
][
5
]
if
hyps
[
best_index
][
2
]
>
hyps
[
self
.
time_stamp
=
hyps
[
best_index
][
5
]
if
hyps
[
best_index
][
2
]
>
hyps
[
best_index
][
3
]
else
hyps
[
best_index
][
6
]
best_index
][
3
]
else
hyps
[
best_index
][
6
]
logger
.
info
(
f
"time stamp:
{
self
.
time_stamp
}
"
)
logger
.
info
(
f
"time stamp:
{
self
.
time_stamp
}
"
)
# update one best result
self
.
update_result
()
self
.
update_result
()
# update each word start and end time stamp
# update each word start and end time stamp
frame_shift_in_ms
=
self
.
model
.
encoder
.
embed
.
subsampling_rate
*
self
.
n_shift
/
self
.
sample_rate
# decoding frame to audio frame
logger
.
info
(
f
"frame shift ms:
{
frame_shift_in_ms
}
"
)
frame_shift
=
self
.
model
.
encoder
.
embed
.
subsampling_rate
frame_shift_in_sec
=
frame_shift
*
(
self
.
n_shift
/
self
.
sample_rate
)
logger
.
info
(
f
"frame shift sec:
{
frame_shift_in_sec
}
"
)
word_time_stamp
=
[]
word_time_stamp
=
[]
for
idx
,
_
in
enumerate
(
self
.
time_stamp
):
for
idx
,
_
in
enumerate
(
self
.
time_stamp
):
start
=
(
self
.
time_stamp
[
idx
-
1
]
+
self
.
time_stamp
[
idx
]
start
=
(
self
.
time_stamp
[
idx
-
1
]
+
self
.
time_stamp
[
idx
]
)
/
2.0
if
idx
>
0
else
0
)
/
2.0
if
idx
>
0
else
0
start
=
start
*
frame_shift_in_
ms
start
=
start
*
frame_shift_in_
sec
end
=
(
self
.
time_stamp
[
idx
]
+
self
.
time_stamp
[
idx
+
1
]
end
=
(
self
.
time_stamp
[
idx
]
+
self
.
time_stamp
[
idx
+
1
]
)
/
2.0
if
idx
<
len
(
self
.
time_stamp
)
-
1
else
self
.
offset
)
/
2.0
if
idx
<
len
(
self
.
time_stamp
)
-
1
else
self
.
offset
end
=
end
*
frame_shift_in_ms
end
=
end
*
frame_shift_in_sec
word_time_stamp
.
append
({
word_time_stamp
.
append
({
"w"
:
self
.
result_transcripts
[
0
][
idx
],
"w"
:
self
.
result_transcripts
[
0
][
idx
],
"bg"
:
start
,
"bg"
:
start
,
"ed"
:
end
"ed"
:
end
})
})
# logger.info(f"{self.result_transcripts[0][idx]}, start: {start}, end: {end}")
# logger.info(f"{word_time_stamp[-1]}")
self
.
word_time_stamp
=
word_time_stamp
self
.
word_time_stamp
=
word_time_stamp
logger
.
info
(
f
"word time stamp:
{
self
.
word_time_stamp
}
"
)
logger
.
info
(
f
"word time stamp:
{
self
.
word_time_stamp
}
"
)
...
@@ -610,6 +720,7 @@ class ASRServerExecutor(ASRExecutor):
...
@@ -610,6 +720,7 @@ class ASRServerExecutor(ASRExecutor):
self
.
sample_rate
=
sample_rate
self
.
sample_rate
=
sample_rate
sample_rate_str
=
'16k'
if
sample_rate
==
16000
else
'8k'
sample_rate_str
=
'16k'
if
sample_rate
==
16000
else
'8k'
tag
=
model_type
+
'-'
+
lang
+
'-'
+
sample_rate_str
tag
=
model_type
+
'-'
+
lang
+
'-'
+
sample_rate_str
if
cfg_path
is
None
or
am_model
is
None
or
am_params
is
None
:
if
cfg_path
is
None
or
am_model
is
None
or
am_params
is
None
:
logger
.
info
(
f
"Load the pretrained model, tag =
{
tag
}
"
)
logger
.
info
(
f
"Load the pretrained model, tag =
{
tag
}
"
)
res_path
=
self
.
_get_pretrained_path
(
tag
)
# wenetspeech_zh
res_path
=
self
.
_get_pretrained_path
(
tag
)
# wenetspeech_zh
...
@@ -639,7 +750,7 @@ class ASRServerExecutor(ASRExecutor):
...
@@ -639,7 +750,7 @@ class ASRServerExecutor(ASRExecutor):
self
.
config
.
merge_from_file
(
self
.
cfg_path
)
self
.
config
.
merge_from_file
(
self
.
cfg_path
)
with
UpdateConfig
(
self
.
config
):
with
UpdateConfig
(
self
.
config
):
if
"deepspeech2
online"
in
model_type
or
"deepspeech2offline
"
in
model_type
:
if
"deepspeech2"
in
model_type
:
from
paddlespeech.s2t.io.collator
import
SpeechCollator
from
paddlespeech.s2t.io.collator
import
SpeechCollator
self
.
vocab
=
self
.
config
.
vocab_filepath
self
.
vocab
=
self
.
config
.
vocab_filepath
self
.
config
.
decode
.
lang_model_path
=
os
.
path
.
join
(
self
.
config
.
decode
.
lang_model_path
=
os
.
path
.
join
(
...
@@ -655,6 +766,7 @@ class ASRServerExecutor(ASRExecutor):
...
@@ -655,6 +766,7 @@ class ASRServerExecutor(ASRExecutor):
self
.
download_lm
(
self
.
download_lm
(
lm_url
,
lm_url
,
os
.
path
.
dirname
(
self
.
config
.
decode
.
lang_model_path
),
lm_md5
)
os
.
path
.
dirname
(
self
.
config
.
decode
.
lang_model_path
),
lm_md5
)
elif
"conformer"
in
model_type
or
"transformer"
in
model_type
:
elif
"conformer"
in
model_type
or
"transformer"
in
model_type
:
logger
.
info
(
"start to create the stream conformer asr engine"
)
logger
.
info
(
"start to create the stream conformer asr engine"
)
if
self
.
config
.
spm_model_prefix
:
if
self
.
config
.
spm_model_prefix
:
...
@@ -682,7 +794,8 @@ class ASRServerExecutor(ASRExecutor):
...
@@ -682,7 +794,8 @@ class ASRServerExecutor(ASRExecutor):
],
f
"we only support ctc_prefix_beam_search and attention_rescoring dedoding method, current decoding method is
{
self
.
config
.
decode
.
decoding_method
}
"
],
f
"we only support ctc_prefix_beam_search and attention_rescoring dedoding method, current decoding method is
{
self
.
config
.
decode
.
decoding_method
}
"
else
:
else
:
raise
Exception
(
"wrong type"
)
raise
Exception
(
"wrong type"
)
if
"deepspeech2online"
in
model_type
or
"deepspeech2offline"
in
model_type
:
if
"deepspeech2"
in
model_type
:
# AM predictor
# AM predictor
logger
.
info
(
"ASR engine start to init the am predictor"
)
logger
.
info
(
"ASR engine start to init the am predictor"
)
self
.
am_predictor_conf
=
am_predictor_conf
self
.
am_predictor_conf
=
am_predictor_conf
...
@@ -719,6 +832,7 @@ class ASRServerExecutor(ASRExecutor):
...
@@ -719,6 +832,7 @@ class ASRServerExecutor(ASRExecutor):
self
.
chunk_state_c_box
=
np
.
zeros
(
self
.
chunk_state_c_box
=
np
.
zeros
(
(
self
.
config
.
num_rnn_layers
,
1
,
self
.
config
.
rnn_layer_size
),
(
self
.
config
.
num_rnn_layers
,
1
,
self
.
config
.
rnn_layer_size
),
dtype
=
float32
)
dtype
=
float32
)
elif
"conformer"
in
model_type
or
"transformer"
in
model_type
:
elif
"conformer"
in
model_type
or
"transformer"
in
model_type
:
model_name
=
model_type
[:
model_type
.
rindex
(
model_name
=
model_type
[:
model_type
.
rindex
(
'_'
)]
# model_type: {model_name}_{dataset}
'_'
)]
# model_type: {model_name}_{dataset}
...
@@ -737,277 +851,14 @@ class ASRServerExecutor(ASRExecutor):
...
@@ -737,277 +851,14 @@ class ASRServerExecutor(ASRExecutor):
# update the ctc decoding
# update the ctc decoding
self
.
searcher
=
CTCPrefixBeamSearch
(
self
.
config
.
decode
)
self
.
searcher
=
CTCPrefixBeamSearch
(
self
.
config
.
decode
)
self
.
transformer_decode_reset
()
self
.
transformer_decode_reset
()
return
True
def
reset_decoder_and_chunk
(
self
):
"""reset decoder and chunk state for an new audio
"""
if
"deepspeech2online"
in
self
.
model_type
or
"deepspeech2offline"
in
self
.
model_type
:
self
.
decoder
.
reset_decoder
(
batch_size
=
1
)
# init state box, for new audio request
self
.
chunk_state_h_box
=
np
.
zeros
(
(
self
.
config
.
num_rnn_layers
,
1
,
self
.
config
.
rnn_layer_size
),
dtype
=
float32
)
self
.
chunk_state_c_box
=
np
.
zeros
(
(
self
.
config
.
num_rnn_layers
,
1
,
self
.
config
.
rnn_layer_size
),
dtype
=
float32
)
elif
"conformer"
in
self
.
model_type
or
"transformer"
in
self
.
model_type
:
self
.
transformer_decode_reset
()
def
decode_one_chunk
(
self
,
x_chunk
,
x_chunk_lens
,
model_type
:
str
):
"""decode one chunk
Args:
x_chunk (numpy.array): shape[B, T, D]
x_chunk_lens (numpy.array): shape[B]
model_type (str): online model type
Returns:
str: one best result
"""
logger
.
info
(
"start to decoce chunk by chunk"
)
if
"deepspeech2online"
in
model_type
:
input_names
=
self
.
am_predictor
.
get_input_names
()
audio_handle
=
self
.
am_predictor
.
get_input_handle
(
input_names
[
0
])
audio_len_handle
=
self
.
am_predictor
.
get_input_handle
(
input_names
[
1
])
h_box_handle
=
self
.
am_predictor
.
get_input_handle
(
input_names
[
2
])
c_box_handle
=
self
.
am_predictor
.
get_input_handle
(
input_names
[
3
])
audio_handle
.
reshape
(
x_chunk
.
shape
)
audio_handle
.
copy_from_cpu
(
x_chunk
)
audio_len_handle
.
reshape
(
x_chunk_lens
.
shape
)
audio_len_handle
.
copy_from_cpu
(
x_chunk_lens
)
h_box_handle
.
reshape
(
self
.
chunk_state_h_box
.
shape
)
h_box_handle
.
copy_from_cpu
(
self
.
chunk_state_h_box
)
c_box_handle
.
reshape
(
self
.
chunk_state_c_box
.
shape
)
c_box_handle
.
copy_from_cpu
(
self
.
chunk_state_c_box
)
output_names
=
self
.
am_predictor
.
get_output_names
()
output_handle
=
self
.
am_predictor
.
get_output_handle
(
output_names
[
0
])
output_lens_handle
=
self
.
am_predictor
.
get_output_handle
(
output_names
[
1
])
output_state_h_handle
=
self
.
am_predictor
.
get_output_handle
(
output_names
[
2
])
output_state_c_handle
=
self
.
am_predictor
.
get_output_handle
(
output_names
[
3
])
self
.
am_predictor
.
run
()
output_chunk_probs
=
output_handle
.
copy_to_cpu
()
output_chunk_lens
=
output_lens_handle
.
copy_to_cpu
()
self
.
chunk_state_h_box
=
output_state_h_handle
.
copy_to_cpu
()
self
.
chunk_state_c_box
=
output_state_c_handle
.
copy_to_cpu
()
self
.
decoder
.
next
(
output_chunk_probs
,
output_chunk_lens
)
trans_best
,
trans_beam
=
self
.
decoder
.
decode
()
logger
.
info
(
f
"decode one best result:
{
trans_best
[
0
]
}
"
)
return
trans_best
[
0
]
elif
"conformer"
in
model_type
or
"transformer"
in
model_type
:
try
:
logger
.
info
(
f
"we will use the transformer like model :
{
self
.
model_type
}
"
)
self
.
advanced_decoding
(
x_chunk
,
x_chunk_lens
)
self
.
update_result
()
return
self
.
result_transcripts
[
0
]
except
Exception
as
e
:
logger
.
exception
(
e
)
else
:
else
:
raise
Exception
(
"invalid model name
"
)
raise
ValueError
(
f
"Not support:
{
model_type
}
"
)
def
advanced_decoding
(
self
,
xs
:
paddle
.
Tensor
,
x_chunk_lens
):
return
True
logger
.
info
(
"start to decode with advanced_decoding method"
)
encoder_out
,
encoder_mask
=
self
.
encoder_forward
(
xs
)
ctc_probs
=
self
.
model
.
ctc
.
log_softmax
(
encoder_out
)
# (1, maxlen, vocab_size)
ctc_probs
=
ctc_probs
.
squeeze
(
0
)
self
.
searcher
.
search
(
ctc_probs
,
xs
.
place
)
# update the one best result
self
.
hyps
=
self
.
searcher
.
get_one_best_hyps
()
# now we supprot ctc_prefix_beam_search and attention_rescoring
if
"attention_rescoring"
in
self
.
config
.
decode
.
decoding_method
:
self
.
rescoring
(
encoder_out
,
xs
.
place
)
def
encoder_forward
(
self
,
xs
):
logger
.
info
(
"get the model out from the feat"
)
cfg
=
self
.
config
.
decode
decoding_chunk_size
=
cfg
.
decoding_chunk_size
num_decoding_left_chunks
=
cfg
.
num_decoding_left_chunks
assert
decoding_chunk_size
>
0
subsampling
=
self
.
model
.
encoder
.
embed
.
subsampling_rate
context
=
self
.
model
.
encoder
.
embed
.
right_context
+
1
stride
=
subsampling
*
decoding_chunk_size
# decoding window for model
decoding_window
=
(
decoding_chunk_size
-
1
)
*
subsampling
+
context
num_frames
=
xs
.
shape
[
1
]
required_cache_size
=
decoding_chunk_size
*
num_decoding_left_chunks
logger
.
info
(
"start to do model forward"
)
outputs
=
[]
# num_frames - context + 1 ensure that current frame can get context window
for
cur
in
range
(
0
,
num_frames
-
context
+
1
,
stride
):
end
=
min
(
cur
+
decoding_window
,
num_frames
)
chunk_xs
=
xs
[:,
cur
:
end
,
:]
(
y
,
self
.
subsampling_cache
,
self
.
elayers_output_cache
,
self
.
conformer_cnn_cache
)
=
self
.
model
.
encoder
.
forward_chunk
(
chunk_xs
,
self
.
offset
,
required_cache_size
,
self
.
subsampling_cache
,
self
.
elayers_output_cache
,
self
.
conformer_cnn_cache
)
outputs
.
append
(
y
)
self
.
offset
+=
y
.
shape
[
1
]
ys
=
paddle
.
cat
(
outputs
,
1
)
masks
=
paddle
.
ones
([
1
,
ys
.
shape
[
1
]],
dtype
=
paddle
.
bool
)
masks
=
masks
.
unsqueeze
(
1
)
return
ys
,
masks
def
rescoring
(
self
,
encoder_out
,
device
):
logger
.
info
(
"start to rescoring the hyps"
)
beam_size
=
self
.
config
.
decode
.
beam_size
hyps
=
self
.
searcher
.
get_hyps
()
assert
len
(
hyps
)
==
beam_size
hyp_list
=
[]
for
hyp
in
hyps
:
hyp_content
=
hyp
[
0
]
# Prevent the hyp is empty
if
len
(
hyp_content
)
==
0
:
hyp_content
=
(
self
.
model
.
ctc
.
blank_id
,
)
hyp_content
=
paddle
.
to_tensor
(
hyp_content
,
place
=
device
,
dtype
=
paddle
.
long
)
hyp_list
.
append
(
hyp_content
)
hyps_pad
=
pad_sequence
(
hyp_list
,
True
,
self
.
model
.
ignore_id
)
hyps_lens
=
paddle
.
to_tensor
(
[
len
(
hyp
[
0
])
for
hyp
in
hyps
],
place
=
device
,
dtype
=
paddle
.
long
)
# (beam_size,)
hyps_pad
,
_
=
add_sos_eos
(
hyps_pad
,
self
.
model
.
sos
,
self
.
model
.
eos
,
self
.
model
.
ignore_id
)
hyps_lens
=
hyps_lens
+
1
# Add <sos> at begining
encoder_out
=
encoder_out
.
repeat
(
beam_size
,
1
,
1
)
encoder_mask
=
paddle
.
ones
(
(
beam_size
,
1
,
encoder_out
.
shape
[
1
]),
dtype
=
paddle
.
bool
)
decoder_out
,
_
=
self
.
model
.
decoder
(
encoder_out
,
encoder_mask
,
hyps_pad
,
hyps_lens
)
# (beam_size, max_hyps_len, vocab_size)
# ctc score in ln domain
decoder_out
=
paddle
.
nn
.
functional
.
log_softmax
(
decoder_out
,
axis
=-
1
)
decoder_out
=
decoder_out
.
numpy
()
# Only use decoder score for rescoring
best_score
=
-
float
(
'inf'
)
best_index
=
0
# hyps is List[(Text=List[int], Score=float)], len(hyps)=beam_size
for
i
,
hyp
in
enumerate
(
hyps
):
score
=
0.0
for
j
,
w
in
enumerate
(
hyp
[
0
]):
score
+=
decoder_out
[
i
][
j
][
w
]
# last decoder output token is `eos`, for laste decoder input token.
score
+=
decoder_out
[
i
][
len
(
hyp
[
0
])][
self
.
model
.
eos
]
# add ctc score (which in ln domain)
score
+=
hyp
[
1
]
*
self
.
config
.
decode
.
ctc_weight
if
score
>
best_score
:
best_score
=
score
best_index
=
i
# update the one best result
self
.
hyps
=
[
hyps
[
best_index
][
0
]]
return
hyps
[
best_index
][
0
]
def
transformer_decode_reset
(
self
):
self
.
subsampling_cache
=
None
self
.
elayers_output_cache
=
None
self
.
conformer_cnn_cache
=
None
self
.
offset
=
0
# decoding reset
self
.
searcher
.
reset
()
def
update_result
(
self
):
logger
.
info
(
"update the final result"
)
hyps
=
self
.
hyps
self
.
result_transcripts
=
[
self
.
text_feature
.
defeaturize
(
hyp
)
for
hyp
in
hyps
]
self
.
result_tokenids
=
[
hyp
for
hyp
in
hyps
]
def
extract_feat
(
self
,
samples
,
sample_rate
):
"""extract feat
Args:
samples (numpy.array): numpy.float32
sample_rate (int): sample rate
Returns:
x_chunk (numpy.array): shape[B, T, D]
x_chunk_lens (numpy.array): shape[B]
"""
if
"deepspeech2online"
in
self
.
model_type
:
# pcm16 -> pcm 32
samples
=
pcm2float
(
samples
)
# read audio
speech_segment
=
SpeechSegment
.
from_pcm
(
samples
,
sample_rate
,
transcript
=
" "
)
# audio augment
self
.
collate_fn_test
.
augmentation
.
transform_audio
(
speech_segment
)
# extract speech feature
spectrum
,
transcript_part
=
self
.
collate_fn_test
.
_speech_featurizer
.
featurize
(
speech_segment
,
self
.
collate_fn_test
.
keep_transcription_text
)
# CMVN spectrum
if
self
.
collate_fn_test
.
_normalizer
:
spectrum
=
self
.
collate_fn_test
.
_normalizer
.
apply
(
spectrum
)
# spectrum augment
audio
=
self
.
collate_fn_test
.
augmentation
.
transform_feature
(
spectrum
)
audio_len
=
audio
.
shape
[
0
]
audio
=
paddle
.
to_tensor
(
audio
,
dtype
=
'float32'
)
# audio_len = paddle.to_tensor(audio_len)
audio
=
paddle
.
unsqueeze
(
audio
,
axis
=
0
)
x_chunk
=
audio
.
numpy
()
x_chunk_lens
=
np
.
array
([
audio_len
])
return
x_chunk
,
x_chunk_lens
elif
"conformer_online"
in
self
.
model_type
:
if
sample_rate
!=
self
.
sample_rate
:
logger
.
info
(
f
"audio sample rate
{
sample_rate
}
is not match,"
"the model sample_rate is {self.sample_rate}"
)
logger
.
info
(
f
"ASR Engine use the
{
self
.
model_type
}
to process"
)
logger
.
info
(
"Create the preprocess instance"
)
preprocess_conf
=
self
.
config
.
preprocess_config
preprocess_args
=
{
"train"
:
False
}
preprocessing
=
Transformation
(
preprocess_conf
)
logger
.
info
(
"Read the audio file"
)
logger
.
info
(
f
"audio shape:
{
samples
.
shape
}
"
)
# fbank
x_chunk
=
preprocessing
(
samples
,
**
preprocess_args
)
x_chunk_lens
=
paddle
.
to_tensor
(
x_chunk
.
shape
[
0
])
x_chunk
=
paddle
.
to_tensor
(
x_chunk
,
dtype
=
"float32"
).
unsqueeze
(
axis
=
0
)
logger
.
info
(
f
"process the audio feature success, feat shape:
{
x_chunk
.
shape
}
"
)
return
x_chunk
,
x_chunk_lens
class
ASREngine
(
BaseEngine
):
class
ASREngine
(
BaseEngine
):
"""ASR server
engin
e
"""ASR server
resourc
e
Args:
Args:
metaclass: Defaults to Singleton.
metaclass: Defaults to Singleton.
...
@@ -1015,7 +866,7 @@ class ASREngine(BaseEngine):
...
@@ -1015,7 +866,7 @@ class ASREngine(BaseEngine):
def
__init__
(
self
):
def
__init__
(
self
):
super
(
ASREngine
,
self
).
__init__
()
super
(
ASREngine
,
self
).
__init__
()
logger
.
info
(
"create the online asr engine instance"
)
logger
.
info
(
"create the online asr engine
resource
instance"
)
def
init
(
self
,
config
:
dict
)
->
bool
:
def
init
(
self
,
config
:
dict
)
->
bool
:
"""init engine resource
"""init engine resource
...
@@ -1026,17 +877,12 @@ class ASREngine(BaseEngine):
...
@@ -1026,17 +877,12 @@ class ASREngine(BaseEngine):
Returns:
Returns:
bool: init failed or success
bool: init failed or success
"""
"""
self
.
input
=
None
self
.
output
=
""
self
.
executor
=
ASRServerExecutor
()
self
.
config
=
config
self
.
config
=
config
self
.
executor
=
ASRServerExecutor
()
try
:
try
:
if
self
.
config
.
get
(
"device"
,
None
):
default_dev
=
paddle
.
get_device
()
self
.
device
=
self
.
config
.
device
paddle
.
set_device
(
self
.
config
.
get
(
"device"
,
default_dev
))
else
:
self
.
device
=
paddle
.
get_device
()
logger
.
info
(
f
"paddlespeech_server set the device:
{
self
.
device
}
"
)
paddle
.
set_device
(
self
.
device
)
except
BaseException
as
e
:
except
BaseException
as
e
:
logger
.
error
(
logger
.
error
(
f
"Set device failed, please check if device '
{
self
.
device
}
' is already used and the parameter 'device' in the yaml file"
f
"Set device failed, please check if device '
{
self
.
device
}
' is already used and the parameter 'device' in the yaml file"
...
@@ -1045,6 +891,8 @@ class ASREngine(BaseEngine):
...
@@ -1045,6 +891,8 @@ class ASREngine(BaseEngine):
"If all GPU or XPU is used, you can set the server to 'cpu'"
)
"If all GPU or XPU is used, you can set the server to 'cpu'"
)
sys
.
exit
(
-
1
)
sys
.
exit
(
-
1
)
logger
.
info
(
f
"paddlespeech_server set the device:
{
self
.
device
}
"
)
if
not
self
.
executor
.
_init_from_path
(
if
not
self
.
executor
.
_init_from_path
(
model_type
=
self
.
config
.
model_type
,
model_type
=
self
.
config
.
model_type
,
am_model
=
self
.
config
.
am_model
,
am_model
=
self
.
config
.
am_model
,
...
@@ -1062,42 +910,11 @@ class ASREngine(BaseEngine):
...
@@ -1062,42 +910,11 @@ class ASREngine(BaseEngine):
logger
.
info
(
"Initialize ASR server engine successfully."
)
logger
.
info
(
"Initialize ASR server engine successfully."
)
return
True
return
True
def
preprocess
(
self
,
def
preprocess
(
self
,
*
args
,
**
kwargs
):
samples
,
raise
NotImplementedError
(
"Online not using this."
)
sample_rate
,
model_type
=
"deepspeech2online_aishell-zh-16k"
):
"""preprocess
Args:
samples (numpy.array): numpy.float32
sample_rate (int): sample rate
Returns:
x_chunk (numpy.array): shape[B, T, D]
x_chunk_lens (numpy.array): shape[B]
"""
# if "deepspeech" in model_type:
x_chunk
,
x_chunk_lens
=
self
.
executor
.
extract_feat
(
samples
,
sample_rate
)
return
x_chunk
,
x_chunk_lens
def
run
(
self
,
x_chunk
,
x_chunk_lens
,
decoder_chunk_size
=
1
):
def
run
(
self
,
*
args
,
**
kwargs
):
"""run online engine
raise
NotImplementedError
(
"Online not using this."
)
Args:
x_chunk (numpy.array): shape[B, T, D]
x_chunk_lens (numpy.array): shape[B]
decoder_chunk_size(int)
"""
self
.
output
=
self
.
executor
.
decode_one_chunk
(
x_chunk
,
x_chunk_lens
,
self
.
config
.
model_type
)
def
postprocess
(
self
):
def
postprocess
(
self
):
"""postprocess
raise
NotImplementedError
(
"Online not using this."
)
"""
return
self
.
output
def
reset
(
self
):
"""reset engine decoder and inference state
"""
self
.
executor
.
reset_decoder_and_chunk
()
self
.
output
=
""
paddlespeech/server/restful/api.py
浏览文件 @
f07f57a3
...
@@ -17,12 +17,12 @@ from typing import List
...
@@ -17,12 +17,12 @@ from typing import List
from
fastapi
import
APIRouter
from
fastapi
import
APIRouter
from
paddlespeech.cli.log
import
logger
from
paddlespeech.cli.log
import
logger
from
paddlespeech.server.restful.acs_api
import
router
as
acs_router
from
paddlespeech.server.restful.asr_api
import
router
as
asr_router
from
paddlespeech.server.restful.asr_api
import
router
as
asr_router
from
paddlespeech.server.restful.cls_api
import
router
as
cls_router
from
paddlespeech.server.restful.cls_api
import
router
as
cls_router
from
paddlespeech.server.restful.text_api
import
router
as
text_router
from
paddlespeech.server.restful.text_api
import
router
as
text_router
from
paddlespeech.server.restful.tts_api
import
router
as
tts_router
from
paddlespeech.server.restful.tts_api
import
router
as
tts_router
from
paddlespeech.server.restful.vector_api
import
router
as
vec_router
from
paddlespeech.server.restful.vector_api
import
router
as
vec_router
from
paddlespeech.server.restful.acs_api
import
router
as
acs_router
_router
=
APIRouter
()
_router
=
APIRouter
()
...
...
paddlespeech/server/utils/audio_handler.py
浏览文件 @
f07f57a3
...
@@ -248,7 +248,7 @@ class ASRHttpHandler:
...
@@ -248,7 +248,7 @@ class ASRHttpHandler:
}
}
res
=
requests
.
post
(
url
=
self
.
url
,
data
=
json
.
dumps
(
data
))
res
=
requests
.
post
(
url
=
self
.
url
,
data
=
json
.
dumps
(
data
))
return
res
.
json
()
return
res
.
json
()
...
...
paddlespeech/server/utils/buffer.py
浏览文件 @
f07f57a3
...
@@ -12,6 +12,7 @@
...
@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# See the License for the specific language governing permissions and
# limitations under the License.
# limitations under the License.
class
Frame
(
object
):
class
Frame
(
object
):
"""Represents a "frame" of audio data."""
"""Represents a "frame" of audio data."""
...
@@ -45,7 +46,7 @@ class ChunkBuffer(object):
...
@@ -45,7 +46,7 @@ class ChunkBuffer(object):
self
.
shift_ms
=
shift_ms
self
.
shift_ms
=
shift_ms
self
.
sample_rate
=
sample_rate
self
.
sample_rate
=
sample_rate
self
.
sample_width
=
sample_width
# int16 = 2; float32 = 4
self
.
sample_width
=
sample_width
# int16 = 2; float32 = 4
self
.
window_sec
=
float
((
self
.
window_n
-
1
)
*
self
.
shift_ms
+
self
.
window_sec
=
float
((
self
.
window_n
-
1
)
*
self
.
shift_ms
+
self
.
window_ms
)
/
1000.0
self
.
window_ms
)
/
1000.0
self
.
shift_sec
=
float
(
self
.
shift_n
*
self
.
shift_ms
/
1000.0
)
self
.
shift_sec
=
float
(
self
.
shift_n
*
self
.
shift_ms
/
1000.0
)
...
@@ -77,8 +78,8 @@ class ChunkBuffer(object):
...
@@ -77,8 +78,8 @@ class ChunkBuffer(object):
offset
=
0
offset
=
0
while
offset
+
self
.
window_bytes
<=
len
(
audio
):
while
offset
+
self
.
window_bytes
<=
len
(
audio
):
yield
Frame
(
audio
[
offset
:
offset
+
self
.
window_bytes
],
self
.
timestamp
,
yield
Frame
(
audio
[
offset
:
offset
+
self
.
window_bytes
],
self
.
window_sec
)
self
.
timestamp
,
self
.
window_sec
)
self
.
timestamp
+=
self
.
shift_sec
self
.
timestamp
+=
self
.
shift_sec
offset
+=
self
.
shift_bytes
offset
+=
self
.
shift_bytes
...
...
paddlespeech/t2s/exps/speedyspeech/synthesize_e2e.py
浏览文件 @
f07f57a3
...
@@ -176,7 +176,10 @@ def main():
...
@@ -176,7 +176,10 @@ def main():
parser
.
add_argument
(
parser
.
add_argument
(
"--ngpu"
,
type
=
int
,
default
=
1
,
help
=
"if ngpu == 0, use cpu or xpu."
)
"--ngpu"
,
type
=
int
,
default
=
1
,
help
=
"if ngpu == 0, use cpu or xpu."
)
parser
.
add_argument
(
parser
.
add_argument
(
"--nxpu"
,
type
=
int
,
default
=
0
,
help
=
"if nxpu == 0 and ngpu == 0, use cpu."
)
"--nxpu"
,
type
=
int
,
default
=
0
,
help
=
"if nxpu == 0 and ngpu == 0, use cpu."
)
args
,
_
=
parser
.
parse_known_args
()
args
,
_
=
parser
.
parse_known_args
()
...
...
paddlespeech/t2s/exps/speedyspeech/train.py
浏览文件 @
f07f57a3
...
@@ -188,7 +188,10 @@ def main():
...
@@ -188,7 +188,10 @@ def main():
parser
.
add_argument
(
"--dev-metadata"
,
type
=
str
,
help
=
"dev data."
)
parser
.
add_argument
(
"--dev-metadata"
,
type
=
str
,
help
=
"dev data."
)
parser
.
add_argument
(
"--output-dir"
,
type
=
str
,
help
=
"output dir."
)
parser
.
add_argument
(
"--output-dir"
,
type
=
str
,
help
=
"output dir."
)
parser
.
add_argument
(
parser
.
add_argument
(
"--nxpu"
,
type
=
int
,
default
=
0
,
help
=
"if nxpu == 0 and ngpu == 0, use cpu."
)
"--nxpu"
,
type
=
int
,
default
=
0
,
help
=
"if nxpu == 0 and ngpu == 0, use cpu."
)
parser
.
add_argument
(
parser
.
add_argument
(
"--ngpu"
,
type
=
int
,
default
=
1
,
help
=
"if ngpu == 0, use cpu or xpu"
)
"--ngpu"
,
type
=
int
,
default
=
1
,
help
=
"if ngpu == 0, use cpu or xpu"
)
...
...
paddlespeech/t2s/modules/transformer/repeat.py
浏览文件 @
f07f57a3
...
@@ -36,4 +36,4 @@ def repeat(N, fn):
...
@@ -36,4 +36,4 @@ def repeat(N, fn):
Returns:
Returns:
MultiSequential: Repeated model instance.
MultiSequential: Repeated model instance.
"""
"""
return
MultiSequential
(
*
[
fn
(
n
)
for
n
in
range
(
N
)])
return
MultiSequential
(
*
[
fn
(
n
)
for
n
in
range
(
N
)])
setup.py
浏览文件 @
f07f57a3
...
@@ -98,7 +98,6 @@ requirements = {
...
@@ -98,7 +98,6 @@ requirements = {
}
}
def
check_call
(
cmd
:
str
,
shell
=
False
,
executable
=
None
):
def
check_call
(
cmd
:
str
,
shell
=
False
,
executable
=
None
):
try
:
try
:
sp
.
check_call
(
sp
.
check_call
(
...
@@ -112,12 +111,13 @@ def check_call(cmd: str, shell=False, executable=None):
...
@@ -112,12 +111,13 @@ def check_call(cmd: str, shell=False, executable=None):
file
=
sys
.
stderr
)
file
=
sys
.
stderr
)
raise
e
raise
e
def
check_output
(
cmd
:
str
,
shell
=
False
):
def
check_output
(
cmd
:
str
,
shell
=
False
):
try
:
try
:
out_bytes
=
sp
.
check_output
(
cmd
.
split
())
out_bytes
=
sp
.
check_output
(
cmd
.
split
())
except
sp
.
CalledProcessError
as
e
:
except
sp
.
CalledProcessError
as
e
:
out_bytes
=
e
.
output
# Output generated before error
out_bytes
=
e
.
output
# Output generated before error
code
=
e
.
returncode
# Return code
code
=
e
.
returncode
# Return code
print
(
print
(
f
"
{
__file__
}
:
{
inspect
.
currentframe
().
f_lineno
}
: CMD:
{
cmd
}
, Error:"
,
f
"
{
__file__
}
:
{
inspect
.
currentframe
().
f_lineno
}
: CMD:
{
cmd
}
, Error:"
,
out_bytes
,
out_bytes
,
...
@@ -146,6 +146,7 @@ def _remove(files: str):
...
@@ -146,6 +146,7 @@ def _remove(files: str):
for
f
in
files
:
for
f
in
files
:
f
.
unlink
()
f
.
unlink
()
################################# Install ##################################
################################# Install ##################################
...
@@ -308,6 +309,5 @@ setup_info = dict(
...
@@ -308,6 +309,5 @@ setup_info = dict(
]
]
})
})
with
version_info
():
with
version_info
():
setup
(
**
setup_info
)
setup
(
**
setup_info
)
tests/unit/cli/aishell_test_prepare.py
浏览文件 @
f07f57a3
...
@@ -20,7 +20,6 @@ of each audio file in the data set.
...
@@ -20,7 +20,6 @@ of each audio file in the data set.
"""
"""
import
argparse
import
argparse
import
codecs
import
codecs
import
json
import
os
import
os
from
pathlib
import
Path
from
pathlib
import
Path
...
@@ -89,7 +88,7 @@ def create_manifest(data_dir, manifest_path_prefix):
...
@@ -89,7 +88,7 @@ def create_manifest(data_dir, manifest_path_prefix):
duration
=
float
(
len
(
audio_data
)
/
samplerate
)
duration
=
float
(
len
(
audio_data
)
/
samplerate
)
text
=
transcript_dict
[
audio_id
]
text
=
transcript_dict
[
audio_id
]
json_lines
.
append
(
audio_path
)
json_lines
.
append
(
audio_path
)
reference_lines
.
append
(
str
(
total_num
+
1
)
+
"
\t
"
+
text
)
reference_lines
.
append
(
str
(
total_num
+
1
)
+
"
\t
"
+
text
)
total_sec
+=
duration
total_sec
+=
duration
total_text
+=
len
(
text
)
total_text
+=
len
(
text
)
...
@@ -106,6 +105,7 @@ def create_manifest(data_dir, manifest_path_prefix):
...
@@ -106,6 +105,7 @@ def create_manifest(data_dir, manifest_path_prefix):
manifest_dir
=
os
.
path
.
dirname
(
manifest_path_prefix
)
manifest_dir
=
os
.
path
.
dirname
(
manifest_path_prefix
)
def
prepare_dataset
(
url
,
md5sum
,
target_dir
,
manifest_path
=
None
):
def
prepare_dataset
(
url
,
md5sum
,
target_dir
,
manifest_path
=
None
):
"""Download, unpack and create manifest file."""
"""Download, unpack and create manifest file."""
data_dir
=
os
.
path
.
join
(
target_dir
,
'data_aishell'
)
data_dir
=
os
.
path
.
join
(
target_dir
,
'data_aishell'
)
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录