Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
DeepSpeech
提交
1a3c811f
D
DeepSpeech
项目概览
PaddlePaddle
/
DeepSpeech
大约 2 年 前同步成功
通知
210
Star
8425
Fork
1598
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
245
列表
看板
标记
里程碑
合并请求
3
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
D
DeepSpeech
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
245
Issue
245
列表
看板
标记
里程碑
合并请求
3
合并请求
3
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
1a3c811f
编写于
4月 08, 2022
作者:
L
lym0302
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
code format, test=doc
上级
759a9e61
变更
7
隐藏空白更改
内联
并排
Showing
7 changed file
with
73 addition
and
145 deletion
+73
-145
paddlespeech/server/engine/asr/online/asr_engine.py
paddlespeech/server/engine/asr/online/asr_engine.py
+2
-16
paddlespeech/server/engine/tts/online/tts_engine.py
paddlespeech/server/engine/tts/online/tts_engine.py
+38
-123
paddlespeech/server/tests/tts/online/ws_client.py
paddlespeech/server/tests/tts/online/ws_client.py
+2
-2
paddlespeech/server/tests/tts/online/ws_client_playaudio.py
paddlespeech/server/tests/tts/online/ws_client_playaudio.py
+2
-2
paddlespeech/server/utils/audio_process.py
paddlespeech/server/utils/audio_process.py
+14
-0
paddlespeech/server/utils/util.py
paddlespeech/server/utils/util.py
+13
-0
paddlespeech/server/ws/tts_socket.py
paddlespeech/server/ws/tts_socket.py
+2
-2
未找到文件。
paddlespeech/server/engine/asr/online/asr_engine.py
浏览文件 @
1a3c811f
...
@@ -27,6 +27,7 @@ from paddlespeech.s2t.frontend.speech import SpeechSegment
...
@@ -27,6 +27,7 @@ from paddlespeech.s2t.frontend.speech import SpeechSegment
from
paddlespeech.s2t.modules.ctc
import
CTCDecoder
from
paddlespeech.s2t.modules.ctc
import
CTCDecoder
from
paddlespeech.s2t.utils.utility
import
UpdateConfig
from
paddlespeech.s2t.utils.utility
import
UpdateConfig
from
paddlespeech.server.engine.base_engine
import
BaseEngine
from
paddlespeech.server.engine.base_engine
import
BaseEngine
from
paddlespeech.server.utils.audio_process
import
pcm2float
from
paddlespeech.server.utils.paddle_predictor
import
init_predictor
from
paddlespeech.server.utils.paddle_predictor
import
init_predictor
__all__
=
[
'ASREngine'
]
__all__
=
[
'ASREngine'
]
...
@@ -222,21 +223,6 @@ class ASRServerExecutor(ASRExecutor):
...
@@ -222,21 +223,6 @@ class ASRServerExecutor(ASRExecutor):
else
:
else
:
raise
Exception
(
"invalid model name"
)
raise
Exception
(
"invalid model name"
)
def
_pcm16to32
(
self
,
audio
):
"""pcm int16 to float32
Args:
audio(numpy.array): numpy.int16
Returns:
audio(numpy.array): numpy.float32
"""
if
audio
.
dtype
==
np
.
int16
:
audio
=
audio
.
astype
(
"float32"
)
bits
=
np
.
iinfo
(
np
.
int16
).
bits
audio
=
audio
/
(
2
**
(
bits
-
1
))
return
audio
def
extract_feat
(
self
,
samples
,
sample_rate
):
def
extract_feat
(
self
,
samples
,
sample_rate
):
"""extract feat
"""extract feat
...
@@ -249,7 +235,7 @@ class ASRServerExecutor(ASRExecutor):
...
@@ -249,7 +235,7 @@ class ASRServerExecutor(ASRExecutor):
x_chunk_lens (numpy.array): shape[B]
x_chunk_lens (numpy.array): shape[B]
"""
"""
# pcm16 -> pcm 32
# pcm16 -> pcm 32
samples
=
self
.
_pcm16to32
(
samples
)
samples
=
pcm2float
(
samples
)
# read audio
# read audio
speech_segment
=
SpeechSegment
.
from_pcm
(
speech_segment
=
SpeechSegment
.
from_pcm
(
...
...
paddlespeech/server/engine/tts/online/tts_engine.py
浏览文件 @
1a3c811f
...
@@ -12,29 +12,17 @@
...
@@ -12,29 +12,17 @@
# See the License for the specific language governing permissions and
# See the License for the specific language governing permissions and
# limitations under the License.
# limitations under the License.
import
base64
import
base64
import
io
import
time
import
time
import
librosa
import
numpy
as
np
import
numpy
as
np
import
paddle
import
paddle
import
soundfile
as
sf
from
scipy.io
import
wavfile
from
paddlespeech.cli.log
import
logger
from
paddlespeech.cli.log
import
logger
from
paddlespeech.cli.tts.infer
import
TTSExecutor
from
paddlespeech.cli.tts.infer
import
TTSExecutor
from
paddlespeech.server.engine.base_engine
import
BaseEngine
from
paddlespeech.server.engine.base_engine
import
BaseEngine
from
paddlespeech.server.utils.audio_process
import
change_speed
from
paddlespeech.server.utils.errors
import
ErrorCode
from
paddlespeech.server.utils.exception
import
ServerBaseException
from
paddlespeech.server.utils.audio_process
import
float2pcm
from
paddlespeech.server.utils.audio_process
import
float2pcm
from
paddlespeech.server.utils.config
import
get_config
from
paddlespeech.server.utils.util
import
denorm
from
paddlespeech.server.utils.util
import
get_chunks
from
paddlespeech.server.utils.util
import
get_chunks
import
math
__all__
=
[
'TTSEngine'
]
__all__
=
[
'TTSEngine'
]
...
@@ -44,15 +32,16 @@ class TTSServerExecutor(TTSExecutor):
...
@@ -44,15 +32,16 @@ class TTSServerExecutor(TTSExecutor):
pass
pass
@
paddle
.
no_grad
()
@
paddle
.
no_grad
()
def
infer
(
self
,
def
infer
(
text
:
str
,
self
,
lang
:
str
=
'zh'
,
text
:
str
,
am
:
str
=
'fastspeech2_csmsc'
,
lang
:
str
=
'zh'
,
spk_id
:
int
=
0
,
am
:
str
=
'fastspeech2_csmsc'
,
am_block
:
int
=
42
,
spk_id
:
int
=
0
,
am_pad
:
int
=
12
,
am_block
:
int
=
42
,
voc_block
:
int
=
14
,
am_pad
:
int
=
12
,
voc_pad
:
int
=
14
,):
voc_block
:
int
=
14
,
voc_pad
:
int
=
14
,
):
"""
"""
Model inference and result stored in self.output.
Model inference and result stored in self.output.
"""
"""
...
@@ -61,8 +50,6 @@ class TTSServerExecutor(TTSExecutor):
...
@@ -61,8 +50,6 @@ class TTSServerExecutor(TTSExecutor):
get_tone_ids
=
False
get_tone_ids
=
False
merge_sentences
=
False
merge_sentences
=
False
frontend_st
=
time
.
time
()
frontend_st
=
time
.
time
()
if
am_name
==
'speedyspeech'
:
get_tone_ids
=
True
if
lang
==
'zh'
:
if
lang
==
'zh'
:
input_ids
=
self
.
frontend
.
get_input_ids
(
input_ids
=
self
.
frontend
.
get_input_ids
(
text
,
text
,
...
@@ -95,7 +82,7 @@ class TTSServerExecutor(TTSExecutor):
...
@@ -95,7 +82,7 @@ class TTSServerExecutor(TTSExecutor):
else
:
else
:
mel
=
self
.
am_inference
(
part_phone_ids
)
mel
=
self
.
am_inference
(
part_phone_ids
)
am_et
=
time
.
time
()
am_et
=
time
.
time
()
# voc streaming
# voc streaming
voc_upsample
=
self
.
voc_config
.
n_shift
voc_upsample
=
self
.
voc_config
.
n_shift
mel_chunks
=
get_chunks
(
mel
,
voc_block
,
voc_pad
,
"voc"
)
mel_chunks
=
get_chunks
(
mel
,
voc_block
,
voc_pad
,
"voc"
)
...
@@ -103,17 +90,19 @@ class TTSServerExecutor(TTSExecutor):
...
@@ -103,17 +90,19 @@ class TTSServerExecutor(TTSExecutor):
voc_st
=
time
.
time
()
voc_st
=
time
.
time
()
for
i
,
mel_chunk
in
enumerate
(
mel_chunks
):
for
i
,
mel_chunk
in
enumerate
(
mel_chunks
):
sub_wav
=
self
.
voc_inference
(
mel_chunk
)
sub_wav
=
self
.
voc_inference
(
mel_chunk
)
front_pad
=
min
(
i
*
voc_block
,
voc_pad
)
front_pad
=
min
(
i
*
voc_block
,
voc_pad
)
if
i
==
0
:
if
i
==
0
:
sub_wav
=
sub_wav
[:
voc_block
*
voc_upsample
]
sub_wav
=
sub_wav
[:
voc_block
*
voc_upsample
]
elif
i
==
chunk_num
-
1
:
elif
i
==
chunk_num
-
1
:
sub_wav
=
sub_wav
[
front_pad
*
voc_upsample
:
]
sub_wav
=
sub_wav
[
front_pad
*
voc_upsample
:
]
else
:
else
:
sub_wav
=
sub_wav
[
front_pad
*
voc_upsample
:
(
front_pad
+
voc_block
)
*
voc_upsample
]
sub_wav
=
sub_wav
[
front_pad
*
voc_upsample
:(
front_pad
+
voc_block
)
*
voc_upsample
]
yield
sub_wav
yield
sub_wav
class
TTSEngine
(
BaseEngine
):
class
TTSEngine
(
BaseEngine
):
"""TTS server engine
"""TTS server engine
...
@@ -128,9 +117,11 @@ class TTSEngine(BaseEngine):
...
@@ -128,9 +117,11 @@ class TTSEngine(BaseEngine):
def
init
(
self
,
config
:
dict
)
->
bool
:
def
init
(
self
,
config
:
dict
)
->
bool
:
self
.
executor
=
TTSServerExecutor
()
self
.
executor
=
TTSServerExecutor
()
self
.
config
=
config
assert
"fastspeech2_csmsc"
in
config
.
am
and
(
config
.
voc
==
"hifigan_csmsc-zh"
or
config
.
voc
==
"mb_melgan_csmsc"
),
'Please check config, am support: fastspeech2, voc support: hifigan_csmsc-zh or mb_melgan_csmsc.'
try
:
try
:
self
.
config
=
config
if
self
.
config
.
device
:
if
self
.
config
.
device
:
self
.
device
=
self
.
config
.
device
self
.
device
=
self
.
config
.
device
else
:
else
:
...
@@ -176,86 +167,11 @@ class TTSEngine(BaseEngine):
...
@@ -176,86 +167,11 @@ class TTSEngine(BaseEngine):
def
preprocess
(
self
,
text_bese64
:
str
=
None
,
text_bytes
:
bytes
=
None
):
def
preprocess
(
self
,
text_bese64
:
str
=
None
,
text_bytes
:
bytes
=
None
):
# Convert byte to text
# Convert byte to text
if
text_bese64
:
if
text_bese64
:
text_bytes
=
base64
.
b64decode
(
text_bese64
)
# base64 to bytes
text_bytes
=
base64
.
b64decode
(
text_bese64
)
# base64 to bytes
text
=
text_bytes
.
decode
(
'utf-8'
)
# bytes to text
text
=
text_bytes
.
decode
(
'utf-8'
)
# bytes to text
return
text
return
text
def
postprocess
(
self
,
wav
,
original_fs
:
int
,
target_fs
:
int
=
0
,
volume
:
float
=
1.0
,
speed
:
float
=
1.0
,
audio_path
:
str
=
None
):
"""Post-processing operations, including speech, volume, sample rate, save audio file
Args:
wav (numpy(float)): Synthesized audio sample points
original_fs (int): original audio sample rate
target_fs (int): target audio sample rate
volume (float): target volume
speed (float): target speed
Raises:
ServerBaseException: Throws an exception if the change speed unsuccessfully.
Returns:
target_fs: target sample rate for synthesized audio.
wav_base64: The base64 format of the synthesized audio.
"""
# transform sample_rate
if
target_fs
==
0
or
target_fs
>
original_fs
:
target_fs
=
original_fs
wav_tar_fs
=
wav
logger
.
info
(
"The sample rate of synthesized audio is the same as model, which is {}Hz"
.
format
(
original_fs
))
else
:
wav_tar_fs
=
librosa
.
resample
(
np
.
squeeze
(
wav
),
original_fs
,
target_fs
)
logger
.
info
(
"The sample rate of model is {}Hz and the target sample rate is {}Hz. Converting the sample rate of the synthesized audio successfully."
.
format
(
original_fs
,
target_fs
))
# transform volume
wav_vol
=
wav_tar_fs
*
volume
logger
.
info
(
"Transform the volume of the audio successfully."
)
# transform speed
try
:
# windows not support soxbindings
wav_speed
=
change_speed
(
wav_vol
,
speed
,
target_fs
)
logger
.
info
(
"Transform the speed of the audio successfully."
)
except
ServerBaseException
:
raise
ServerBaseException
(
ErrorCode
.
SERVER_INTERNAL_ERR
,
"Failed to transform speed. Can not install soxbindings on your system.
\
You need to set speed value 1.0."
)
except
BaseException
:
logger
.
error
(
"Failed to transform speed."
)
# wav to base64
buf
=
io
.
BytesIO
()
wavfile
.
write
(
buf
,
target_fs
,
wav_speed
)
base64_bytes
=
base64
.
b64encode
(
buf
.
read
())
wav_base64
=
base64_bytes
.
decode
(
'utf-8'
)
logger
.
info
(
"Audio to string successfully."
)
# save audio
if
audio_path
is
not
None
:
if
audio_path
.
endswith
(
".wav"
):
sf
.
write
(
audio_path
,
wav_speed
,
target_fs
)
elif
audio_path
.
endswith
(
".pcm"
):
wav_norm
=
wav_speed
*
(
32767
/
max
(
0.001
,
np
.
max
(
np
.
abs
(
wav_speed
))))
with
open
(
audio_path
,
"wb"
)
as
f
:
f
.
write
(
wav_norm
.
astype
(
np
.
int16
))
logger
.
info
(
"Save audio to {} successfully."
.
format
(
audio_path
))
else
:
logger
.
info
(
"There is no need to save audio."
)
return
target_fs
,
wav_base64
def
run
(
self
,
def
run
(
self
,
sentence
:
str
,
sentence
:
str
,
spk_id
:
int
=
0
,
spk_id
:
int
=
0
,
...
@@ -275,31 +191,30 @@ class TTSEngine(BaseEngine):
...
@@ -275,31 +191,30 @@ class TTSEngine(BaseEngine):
save_path (str, optional): The save path of the synthesized audio.
save_path (str, optional): The save path of the synthesized audio.
None means do not save audio. Defaults to None.
None means do not save audio. Defaults to None.
Raises:
ServerBaseException: Throws an exception if tts inference unsuccessfully.
ServerBaseException: Throws an exception if postprocess unsuccessfully.
Returns:
Returns:
lang: model language
target_sample_rate: target sample rate for synthesized audio.
wav_base64: The base64 format of the synthesized audio.
wav_base64: The base64 format of the synthesized audio.
"""
"""
lang
=
self
.
config
.
lang
lang
=
self
.
config
.
lang
wav_list
=
[]
wav_list
=
[]
for
wav
in
self
.
executor
.
infer
(
text
=
sentence
,
lang
=
lang
,
am
=
self
.
config
.
am
,
spk_id
=
spk_id
,
am_block
=
self
.
am_block
,
am_pad
=
self
.
am_pad
,
voc_block
=
self
.
voc_block
,
voc_pad
=
self
.
voc_pad
):
for
wav
in
self
.
executor
.
infer
(
text
=
sentence
,
lang
=
lang
,
am
=
self
.
config
.
am
,
spk_id
=
spk_id
,
am_block
=
self
.
am_block
,
am_pad
=
self
.
am_pad
,
voc_block
=
self
.
voc_block
,
voc_pad
=
self
.
voc_pad
):
# wav type: <class 'numpy.ndarray'> float32, convert to pcm (base64)
# wav type: <class 'numpy.ndarray'> float32, convert to pcm (base64)
wav
=
float2pcm
(
wav
)
# float32 to int16
wav
=
float2pcm
(
wav
)
# float32 to int16
wav_bytes
=
wav
.
tobytes
()
# to bytes
wav_bytes
=
wav
.
tobytes
()
# to bytes
wav_base64
=
base64
.
b64encode
(
wav_bytes
).
decode
(
'utf8'
)
# to base64
wav_base64
=
base64
.
b64encode
(
wav_bytes
).
decode
(
'utf8'
)
# to base64
wav_list
.
append
(
wav
)
wav_list
.
append
(
wav
)
yield
wav_base64
wav_all
=
np
.
concatenate
(
wav_list
,
axis
=
0
)
logger
.
info
(
"The durations of audio is: {} s"
.
format
(
len
(
wav_all
)
/
self
.
executor
.
am_config
.
fs
))
yield
wav_base64
wav_all
=
np
.
concatenate
(
wav_list
,
axis
=
0
)
logger
.
info
(
"The durations of audio is: {} s"
.
format
(
len
(
wav_all
)
/
self
.
executor
.
am_config
.
fs
))
paddlespeech/server/tests/tts/online/ws_client.py
浏览文件 @
1a3c811f
...
@@ -25,7 +25,7 @@ st = 0.0
...
@@ -25,7 +25,7 @@ st = 0.0
all_bytes
=
b
''
all_bytes
=
b
''
class
Ws
_
Param
(
object
):
class
WsParam
(
object
):
# 初始化
# 初始化
def
__init__
(
self
,
text
,
server
=
"127.0.0.1"
,
port
=
8090
):
def
__init__
(
self
,
text
,
server
=
"127.0.0.1"
,
port
=
8090
):
self
.
server
=
server
self
.
server
=
server
...
@@ -116,7 +116,7 @@ if __name__ == "__main__":
...
@@ -116,7 +116,7 @@ if __name__ == "__main__":
print
(
"Sentence to be synthesized: "
,
args
.
text
)
print
(
"Sentence to be synthesized: "
,
args
.
text
)
print
(
"***************************************"
)
print
(
"***************************************"
)
wsParam
=
Ws
_
Param
(
text
=
args
.
text
,
server
=
args
.
server
,
port
=
args
.
port
)
wsParam
=
WsParam
(
text
=
args
.
text
,
server
=
args
.
server
,
port
=
args
.
port
)
websocket
.
enableTrace
(
False
)
websocket
.
enableTrace
(
False
)
wsUrl
=
wsParam
.
create_url
()
wsUrl
=
wsParam
.
create_url
()
...
...
paddlespeech/server/tests/tts/online/ws_client_playaudio.py
浏览文件 @
1a3c811f
...
@@ -32,7 +32,7 @@ st = 0.0
...
@@ -32,7 +32,7 @@ st = 0.0
all_bytes
=
0.0
all_bytes
=
0.0
class
Ws
_
Param
(
object
):
class
WsParam
(
object
):
# 初始化
# 初始化
def
__init__
(
self
,
text
,
server
=
"127.0.0.1"
,
port
=
8090
):
def
__init__
(
self
,
text
,
server
=
"127.0.0.1"
,
port
=
8090
):
self
.
server
=
server
self
.
server
=
server
...
@@ -144,7 +144,7 @@ if __name__ == "__main__":
...
@@ -144,7 +144,7 @@ if __name__ == "__main__":
print
(
"Sentence to be synthesized: "
,
args
.
text
)
print
(
"Sentence to be synthesized: "
,
args
.
text
)
print
(
"***************************************"
)
print
(
"***************************************"
)
wsParam
=
Ws
_
Param
(
text
=
args
.
text
,
server
=
args
.
server
,
port
=
args
.
port
)
wsParam
=
WsParam
(
text
=
args
.
text
,
server
=
args
.
server
,
port
=
args
.
port
)
websocket
.
enableTrace
(
False
)
websocket
.
enableTrace
(
False
)
wsUrl
=
wsParam
.
create_url
()
wsUrl
=
wsParam
.
create_url
()
...
...
paddlespeech/server/utils/audio_process.py
浏览文件 @
1a3c811f
...
@@ -126,3 +126,17 @@ def float2pcm(sig, dtype='int16'):
...
@@ -126,3 +126,17 @@ def float2pcm(sig, dtype='int16'):
abs_max
=
2
**
(
i
.
bits
-
1
)
abs_max
=
2
**
(
i
.
bits
-
1
)
offset
=
i
.
min
+
abs_max
offset
=
i
.
min
+
abs_max
return
(
sig
*
abs_max
+
offset
).
clip
(
i
.
min
,
i
.
max
).
astype
(
dtype
)
return
(
sig
*
abs_max
+
offset
).
clip
(
i
.
min
,
i
.
max
).
astype
(
dtype
)
def
pcm2float
(
data
):
"""pcm int16 to float32
Args:
audio(numpy.array): numpy.int16
Returns:
audio(numpy.array): numpy.float32
"""
if
data
.
dtype
==
np
.
int16
:
data
=
data
.
astype
(
"float32"
)
bits
=
np
.
iinfo
(
np
.
int16
).
bits
data
=
data
/
(
2
**
(
bits
-
1
))
return
data
paddlespeech/server/utils/util.py
浏览文件 @
1a3c811f
...
@@ -35,10 +35,23 @@ def self_check():
...
@@ -35,10 +35,23 @@ def self_check():
def
denorm
(
data
,
mean
,
std
):
def
denorm
(
data
,
mean
,
std
):
"""stream am model need to denorm
"""
return
data
*
std
+
mean
return
data
*
std
+
mean
def
get_chunks
(
data
,
block_size
,
pad_size
,
step
):
def
get_chunks
(
data
,
block_size
,
pad_size
,
step
):
"""Divide data into multiple chunks
Args:
data (tensor): data
block_size (int): [description]
pad_size (int): [description]
step (str): set "am" or "voc", generate chunk for step am or vocoder(voc)
Returns:
list: chunks list
"""
if
step
==
"am"
:
if
step
==
"am"
:
data_len
=
data
.
shape
[
1
]
data_len
=
data
.
shape
[
1
]
elif
step
==
"voc"
:
elif
step
==
"voc"
:
...
...
paddlespeech/server/ws/tts_socket.py
浏览文件 @
1a3c811f
...
@@ -44,11 +44,11 @@ async def websocket_endpoint(websocket: WebSocket):
...
@@ -44,11 +44,11 @@ async def websocket_endpoint(websocket: WebSocket):
sentence
=
tts_engine
.
preprocess
(
text_bese64
=
text_bese64
)
sentence
=
tts_engine
.
preprocess
(
text_bese64
=
text_bese64
)
# run
# run
wav
=
tts_engine
.
run
(
sentence
)
wav
_generator
=
tts_engine
.
run
(
sentence
)
while
True
:
while
True
:
try
:
try
:
tts_results
=
next
(
wav
)
tts_results
=
next
(
wav
_generator
)
resp
=
{
"status"
:
1
,
"audio"
:
tts_results
}
resp
=
{
"status"
:
1
,
"audio"
:
tts_results
}
await
websocket
.
send_json
(
resp
)
await
websocket
.
send_json
(
resp
)
logger
.
info
(
"streaming audio..."
)
logger
.
info
(
"streaming audio..."
)
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录